bert-pytorch版源码详细解读.pdf
《bert-pytorch版源码详细解读.pdf》由会员分享,可在线阅读,更多相关《bert-pytorch版源码详细解读.pdf(6页珍藏版)》请在淘文阁 - 分享文档赚钱的网站上搜索。
1、bert-pytorch版源码详细解读前主要代码1.主函数class BertModel(nn.Module):def _init_(self,config:BertConfig):super(BertModel,self)._init_()self.embeddings=BERTEmbeddings(config)self.encoder=BERTEncoder(config)self.pooler=BERTPooler(config)def forward(self,input_ids,token_type_ids=None,attention_mask=None):if attentio
2、n_mask is None:attention_mask=torch.ones_like(input_ids)if token_type_ids is None:token_type_ids=torch.zeros_like(input_ids)#attention_mask的维度应保持和多头的hidden_states致#!个感觉这extended_attention_mask 还应该扩展下,感觉这个维度不太对!extended_attention_mask=attention_mask.unsqueeze(1).unsqueeze(2)extended_attention_mask=ex
3、tended_attention_mask.float()#mask部分token的权重直接给-10000,使其在self-att的时候基本不起作。extended_attention_mask=(1.0-extended_attention_mask)*-10000.0#根据input_ids,token_type_ids以及position_ids来确定初始embeddings embedding_output=self.embeddings(input_ids,token_type_ids)#核层,由以多层self_attention为主的神经络构成 all_encoder_layers
4、=self.encoder(embedding_output,extended_attention_mask)#最后层隐藏层 sequence_output=all_encoder_layers-1#取出最后层隐藏层的cls的表征,经过络层(self.pooler)后得到pooled_output pooled_output=self.pooler(sequence_output)return all_encoder_layers,pooled_output致讲下吧:般必传的三个参数input_idx,token_type_ids,attention_mask。维度均为(batch_size,
5、max_sent_length)input_idx就是每个token对应的idx,对应关系在预训练模型件集的vocab.txttoken_type_ids有两种取值(0对应sentenceA,1对应sentenceB)该tensor会在self.embeddings的时候和input_iput成的embedding相加成初始的embeddings。attention_mask有两种取值(1代表mask词,0代表mask掉的词)般来说在finetune阶段,我们会把padding部分都设成mask掉的词。其他基本也都注释了。2.BertEmbedding层class BERTEmbeddings
6、(nn.Module):def _init_(self,config):super(BERTEmbeddings,self)._init_()self.word_embeddings=nn.Embedding(config.vocab_size,config.hidden_size)self.position_embeddings=nn.Embedding(config.max_position_embeddings,config.hidden_size)self.token_type_embeddings=nn.Embedding(config.type_vocab_size,config.
7、hidden_size)self.LayerNorm=BERTLayerNorm(config)self.dropout=nn.Dropout(config.hidden_dropout_prob)def forward(self,input_ids,token_type_ids=None):#根据每个token的位置成position_ids,很直观 seq_length=input_ids.size(1)position_ids=torch.arange(seq_length,dtype=torch.long,device=input_ids.device)position_ids=pos
8、ition_ids.unsqueeze(0).expand_as(input_ids)if token_type_ids is None:token_type_ids=torch.zeros_like(input_ids)#这三个embeddings相信家可以参见下图就了然了 words_embeddings=self.word_embeddings(input_ids)position_embeddings=self.position_embeddings(position_ids)token_type_embeddings=self.token_type_embeddings(token_
9、type_ids)embeddings=words_embeddings+position_embeddings+token_type_embeddings#最后过个layerNorm和dropout层 embeddings=self.LayerNorm(embeddings)embeddings=self.dropout(embeddings)return embeddings3.BertEnocder层class BERTEncoder(nn.Module):def _init_(self,config):super(BERTEncoder,self)._init_()layer=BERT
10、Layer(config)self.layer=nn.ModuleList(copy.deepcopy(layer)for _ in range(config.num_hidden_layers)def forward(self,hidden_states,attention_mask):all_encoder_layers=for layer_module in self.layer:hidden_states=layer_module(hidden_states,attention_mask)all_encoder_layers.append(hidden_states)return al
11、l_encoder_layers class BERTLayer(nn.Module):def _init_(self,config):super(BERTLayer,self)._init_()self.attention=BERTAttention(config)self.intermediate=BERTIntermediate(config)self.output=BERTOutput(config)def forward(self,hidden_states,attention_mask):attention_output=self.attention(hidden_states,a
12、ttention_mask)intermediate_output=self.intermediate(attention_output)layer_output=self.output(intermediate_output,attention_output)return layer_outputBertEncoder层实质上就是由多个(num_hidden_layers)BertLayer层堆叠成。BertLayer由attention,intermediate和output三部分组成,下分别来看。3.1BertTAttention重头戏开始!详见注释,看完你会发现很简单。class BE
13、RTAttention(nn.Module):def _init_(self,config):super(BERTAttention,self)._init_()self.self=BERTSelfAttention(config)self.output=BERTSelfOutput(config)def forward(self,input_tensor,attention_mask):self_output=self.self(input_tensor,attention_mask)attention_output=self.output(self_output,input_tensor)
14、return attention_output class BERTSelfAttention(nn.Module):def _init_(self,config):super(BERTSelfAttention,self)._init_()if config.hidden_size%config.num_attention_heads!=0:raise ValueError(The hidden size(%d)is not a multiple of the number of attention heads(%d)%(config.hidden_size,config.num_atten
15、tion_heads)self.num_attention_heads=config.num_attention_heads#多头self_attention self.attention_head_size=int(config.hidden_size/config.num_attention_heads)#每个头的维度,般是768/12=64 self.all_head_size=self.num_attention_heads*self.attention_head_size self.query=nn.Linear(config.hidden_size,self.all_head_si
- 配套讲稿:
如PPT文件的首页显示word图标,表示该PPT已包含配套word讲稿。双击word图标可打开word文档。
- 特殊限制:
部分文档作品中含有的国旗、国徽等图片,仅作为作品整体效果示例展示,禁止商用。设计者仅对作品中独创性部分享有著作权。
- 关 键 词:
- bert pytorch 源码 详细 解读
限制150内