A Detailed Walkthrough of the BERT Source Code (PyTorch Version)
Let's start with the main code.

1. The main function

```python
class BertModel(nn.Module):
    def __init__(self, config: BertConfig):
        super(BertModel, self).__init__()
        self.embeddings = BERTEmbeddings(config)
        self.encoder = BERTEncoder(config)
        self.pooler = BERTPooler(config)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        # attention_mask has to stay broadcastable against the multi-head hidden_states
        # (personal aside: it feels like extended_attention_mask should be expanded further;
        #  the shape does not look quite right to me)
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.float()
        # masked tokens get a weight of -10000, so they contribute almost nothing in self-attention
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        # build the initial embeddings from input_ids, token_type_ids and position_ids
        embedding_output = self.embeddings(input_ids, token_type_ids)
        # the core: a network made up mostly of stacked self-attention layers
        all_encoder_layers = self.encoder(embedding_output, extended_attention_mask)
        # the last hidden layer
        sequence_output = all_encoder_layers[-1]
        # take the [CLS] representation from the last hidden layer and pass it through self.pooler
        pooled_output = self.pooler(sequence_output)
        return all_encoder_layers, pooled_output
```

A rough walkthrough: the three arguments you normally have to pass are input_ids, token_type_ids and attention_mask, all of shape (batch_size, max_sent_length). input_ids holds the index of each token; the mapping is given by vocab.txt in the pretrained model files. token_type_ids takes two values (0 for sentence A, 1 for sentence B); inside self.embeddings this tensor is embedded and added to the embedding produced from input_ids to form the initial embeddings. attention_mask also takes two values (1 for tokens to keep, 0 for tokens that are masked out); during fine-tuning we usually mark all padding positions as masked out. The rest is covered by the comments in the code.
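Regarding the aside about the shape of extended_attention_mask: a quick check I find helpful (my own toy example, not code from the repo) is to verify that the (batch_size, 1, 1, seq_len) mask broadcasts cleanly against the (batch_size, num_heads, seq_len, seq_len) attention scores:

```python
import torch

batch_size, num_heads, seq_len = 8, 12, 512
attention_mask = torch.ones(batch_size, seq_len)
attention_mask[:, 300:] = 0  # pretend everything after position 300 is padding

extended = attention_mask.unsqueeze(1).unsqueeze(2).float()  # (8, 1, 1, 512)
extended = (1.0 - extended) * -10000.0                       # 0 for kept tokens, -10000 for padding

attention_scores = torch.randn(batch_size, num_heads, seq_len, seq_len)
masked_scores = attention_scores + extended                  # broadcasts over heads and query positions
print(masked_scores.shape)                                   # torch.Size([8, 12, 512, 512])
```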
2. The BertEmbedding layer

```python
class BERTEmbeddings(nn.Module):
    def __init__(self, config):
        super(BERTEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
        self.LayerNorm = BERTLayerNorm(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None):
        # build position_ids from each token's position; very intuitive
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        # the three embeddings; the usual BERT embedding diagram makes these clear
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        embeddings = words_embeddings + position_embeddings + token_type_embeddings
        # finally a LayerNorm and a dropout
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
```

3. The BertEncoder layer

```python
class BERTEncoder(nn.Module):
    def __init__(self, config):
        super(BERTEncoder, self).__init__()
        layer = BERTLayer(config)
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask):
        all_encoder_layers = []
        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states, attention_mask)
            all_encoder_layers.append(hidden_states)
        return all_encoder_layers


class BERTLayer(nn.Module):
    def __init__(self, config):
        super(BERTLayer, self).__init__()
        self.attention = BERTAttention(config)
        self.intermediate = BERTIntermediate(config)
        self.output = BERTOutput(config)

    def forward(self, hidden_states, attention_mask):
        attention_output = self.attention(hidden_states, attention_mask)
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
```

The BertEncoder is essentially a stack of num_hidden_layers BertLayer modules. Each BertLayer consists of three parts (attention, intermediate and output), which we look at in turn below.

3.1 BertAttention

Here comes the main event! The details are in the comments; once you read through them you will see it is actually quite simple.

```python
class BERTAttention(nn.Module):
    def __init__(self, config):
        super(BERTAttention, self).__init__()
        self.self = BERTSelfAttention(config)
        self.output = BERTSelfOutput(config)

    def forward(self, input_tensor, attention_mask):
        self_output = self.self(input_tensor, attention_mask)
        attention_output = self.output(self_output, input_tensor)
        return attention_output


class BERTSelfAttention(nn.Module):
    def __init__(self, config):
        super(BERTSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention heads (%d)"
                % (config.hidden_size, config.num_attention_heads))
        self.num_attention_heads = config.num_attention_heads  # multi-head self-attention
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)  # per-head size, typically 768 / 12 = 64
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask):
        # the classic Q, K, V projections
        # (batch_size, max_sen_length, hidden_size) -> (batch_size, max_sen_length, hidden_size)
        # (8, 512, 768) -> (8, 512, 768)
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)
        # reshape into multiple heads; note that this happens only after Q, K, V are produced
        # (batch_size, max_sen_length, hidden_size) -> (batch_size, num_attention_heads, max_sen_length, attention_head_size)
        # (8, 512, 768) -> (8, 12, 512, 64)
        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)
        # multiply Q and K; matmul acts only on the last two dimensions
        # (batch_size, num_attention_heads, max_sen_length, attention_head_size) * (batch_size, num_attention_heads, attention_head_size, max_sen_length)
        #   -> (batch_size, num_attention_heads, max_sen_length, max_sen_length)
        # (8, 12, 512, 64) * (8, 12, 64, 512) -> (8, 12, 512, 512)
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        # divide by the square root of the head size, so the variance of the Q*K scores stays
        # around 1 and the softmax does not lead to vanishing gradients
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # the attention_mask passed in earlier does its job here: masked tokens get a "weight"
        # of -10000, which becomes essentially 0 after softmax
        attention_scores = attention_scores + attention_mask
        # softmax plus a dropout, nothing special
        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        attention_probs = self.dropout(attention_probs)
        # finally multiply by V, completing the classic softmax(QK^T / sqrt(d_k)) * V
        # (8, 12, 512, 512) * (8, 12, 512, 64) -> (8, 12, 512, 64)
        context_layer = torch.matmul(attention_probs, value_layer)
        # then restore the original shape
        # (8, 12, 512, 64) -> (8, 512, 12, 64) -> (8, 512, 768)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        return context_layer


class BERTSelfOutput(nn.Module):
    def __init__(self, config):
        super(BERTSelfOutput, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = BERTLayerNorm(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        # a plain fully connected layer plus dropout, then a residual connection and LayerNorm
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
```
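To make the shape reshuffling above concrete, here is a small standalone sketch (toy tensors of my own, not code from the repo) showing that splitting into heads as in transpose_for_scores and merging back as at the end of BERTSelfAttention.forward is a lossless round trip:

```python
import torch

batch_size, seq_len, hidden_size = 8, 512, 768
num_heads, head_size = 12, 64          # 12 * 64 == 768

x = torch.randn(batch_size, seq_len, hidden_size)

# split into heads, as in transpose_for_scores
x_heads = x.view(batch_size, seq_len, num_heads, head_size).permute(0, 2, 1, 3)
print(x_heads.shape)                   # torch.Size([8, 12, 512, 64])

# merge back, as at the end of the self-attention forward pass
x_back = x_heads.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_len, hidden_size)
print(torch.equal(x, x_back))          # True: the round trip loses nothing
```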
3.2 BertIntermediate & BertOutput

```python
class BERTIntermediate(nn.Module):
    def __init__(self, config):
        super(BERTIntermediate, self).__init__()
        # I was never quite sure what intermediate_size was for; it turns out self-attention
        # is followed by two more fully connected layers, BERTIntermediate and BERTOutput
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.intermediate_act_fn = gelu

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class BERTOutput(nn.Module):
    def __init__(self, config):
        super(BERTOutput, self).__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = BERTLayerNorm(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
```

Personal note: this residual-connection layout is quite different from the Transformer encoder I had looked at before, so the structure does not exactly match the Transformer's encoder. After this, all that remains is the wrap-up in the main forward pass, which I will not repeat here.

4. Supplementary code

Here is the code for the related classes used along the way (LayerNorm and friends).

4.1 BertLayerNorm

```python
class BERTLayerNorm(nn.Module):
    def __init__(self, config, variance_epsilon=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root)."""
        super(BERTLayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(config.hidden_size))
        self.beta = nn.Parameter(torch.zeros(config.hidden_size))
        self.variance_epsilon = variance_epsilon

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.gamma * x + self.beta
```

1. BatchNorm normalizes across multiple samples, whereas LayerNorm normalizes within a single sample.
2. Besides the normalization itself, BertLayerNorm also applies the learnable gamma (scale) and beta (shift) transformation.

4.2 BertPooler

```python
class BERTPooler(nn.Module):
    def __init__(self, config):
        super(BERTPooler, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # take the [CLS] token and pass it through a fully connected layer and an activation
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output
```

As mentioned above, BertPooler is designed specifically for the [CLS] token.

4.3 gelu

```python
def gelu(x):
    """Implementation of the gelu activation function."""
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
```

4.4 transpose_for_scores

```python
def transpose_for_scores(self, x):
    new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
    x = x.view(*new_x_shape)
    return x.permute(0, 2, 1, 3)
```

Summary

To sum up, BertModel is simply BERTEmbeddings (word + position + token type embeddings, followed by LayerNorm and dropout), then BERTEncoder (a stack of num_hidden_layers BERTLayers, each built from BERTAttention, BERTIntermediate and BERTOutput), and finally BERTPooler (a linear layer plus tanh over the [CLS] representation).
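As a closing sanity check, here is a minimal sketch of running the whole stack end to end. It assumes all the classes above (including the BertConfig class referenced in the type annotation) are defined in one module with torch, math and copy imported; the configuration values are made up purely for a quick shape check and are not the real bert-base settings.

```python
import torch
from types import SimpleNamespace

# a tiny, made-up configuration just for shape checking; the annotation in
# BertModel.__init__ is only a hint, so a plain namespace works here
config = SimpleNamespace(
    vocab_size=100, hidden_size=32, max_position_embeddings=64, type_vocab_size=2,
    num_hidden_layers=2, num_attention_heads=4, intermediate_size=64,
    hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1,
)

model = BertModel(config)
input_ids = torch.randint(0, config.vocab_size, (3, 16))  # (batch_size, max_sent_length)
attention_mask = torch.ones(3, 16)
attention_mask[:, 10:] = 0                                # mask out the "padding" tail

all_encoder_layers, pooled_output = model(input_ids, attention_mask=attention_mask)
print(len(all_encoder_layers))        # 2: one hidden state per BERTLayer
print(all_encoder_layers[-1].shape)   # torch.Size([3, 16, 32])
print(pooled_output.shape)            # torch.Size([3, 32])
```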