class BertLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        # self-attention
        self.attention = BertAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        # cross-attention is added here
        if self.add_cross_attention:
            if not self.is_decoder:
                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
            # note that cross-attention and self-attention use the same BertAttention class, i.e. it is reused
            self.crossattention = BertAttention(config, position_embedding_type="absolute")
        # the feed-forward (fully connected) sub-layer of the layer
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)
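The two config flags above decide whether a layer gets a cross-attention block at all. A minimal sketch, assuming a transformers 4.x install (the import path and config values are only for illustration):

import torch
from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertLayer

# only when is_decoder=True and add_cross_attention=True does BertLayer build crossattention
decoder_layer = BertLayer(BertConfig(is_decoder=True, add_cross_attention=True))
encoder_layer = BertLayer(BertConfig())  # plain encoder layer, no cross-attention

print(hasattr(decoder_layer, "crossattention"))  # True
print(hasattr(encoder_layer, "crossattention"))  # False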
if self.is_decoder and encoder_hidden_states is not None:
    if not hasattr(self, "crossattention"):
        raise ValueError(
            f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
        )
    # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
    cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
    cross_attention_outputs = self.crossattention(
        attention_output,          # the self-attention output is the query input of cross-attention
        attention_mask,
        head_mask,
        encoder_hidden_states,     # the encoder hidden states are the key and value inputs
        encoder_attention_mask,
        cross_attn_past_key_value,
        output_attentions,
    )
    attention_output = cross_attention_outputs[0]
    outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights
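The important detail in the call above is how the inputs are split: attention_output (the decoder's self-attention output) provides the queries, while encoder_hidden_states provide the keys and values. A minimal single-head sketch in plain PyTorch (the dimensions, the single head, and the absence of masking are simplifying assumptions, not the library implementation):

import torch
import torch.nn.functional as F

hidden_size = 768
q_proj = torch.nn.Linear(hidden_size, hidden_size)
k_proj = torch.nn.Linear(hidden_size, hidden_size)
v_proj = torch.nn.Linear(hidden_size, hidden_size)

decoder_hidden = torch.randn(1, 5, hidden_size)   # stands in for the self-attention output
encoder_hidden = torch.randn(1, 9, hidden_size)   # stands in for the encoder hidden states

q = q_proj(decoder_hidden)   # queries from the decoder side
k = k_proj(encoder_hidden)   # keys from the encoder side
v = v_proj(encoder_hidden)   # values from the encoder side

scores = q @ k.transpose(-1, -2) / hidden_size ** 0.5   # (1, 5, 9): one score per decoder/encoder position pair
context = F.softmax(scores, dim=-1) @ v                 # (1, 5, 768): weighted sum of encoder values
print(context.shape)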
if not config.is_decoder:
    logger.warning("If you want to use `BertGenerationDecoder` as a standalone, add `is_decoder=True.`")

# BertGenerationEncoder is reused; config.is_decoder switches its behavior between encoder and decoder mode
self.bert = BertGenerationEncoder(config)
self.lm_head = BertGenerationOnlyLMHead(config)
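Putting the pieces together, the following sketch reconstructs the kind of call that produces the output shown next. It assumes transformers with sentencepiece installed; the input sentence is an arbitrary placeholder, so the sequence-length dimension of the printed shape depends on how many tokens it yields:

import torch
from transformers import BertGenerationTokenizer, BertGenerationConfig, BertGenerationDecoder

checkpoint = "google/bert_for_seq_generation_L-24_bbc_encoder"
tokenizer = BertGenerationTokenizer.from_pretrained(checkpoint)  # loading the sentencepiece model emits the normalizer.cc INFO line

config = BertGenerationConfig.from_pretrained(checkpoint)
config.is_decoder = True  # switch the reused BertGenerationEncoder into decoder mode
model = BertGenerationDecoder.from_pretrained(checkpoint, config=config)  # lm_head weights are absent from the encoder checkpoint, hence the warning

inputs = tokenizer("Hello world", return_token_type_ids=False, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

print(outputs.logits.shape)  # (batch_size, sequence_length, vocab_size)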
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.
Some weights of BertGenerationDecoder were not initialized from the model checkpoint at google/bert_for_seq_generation_L-24_bbc_encoder and are newly initialized: ['lm_head.bias', 'lm_head.decoder.weight', 'lm_head.decoder.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
torch.Size([1, 3, 50358])
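The shape torch.Size([1, 3, 50358]) means one score over the 50358-token vocabulary for each of the 3 input positions in the single batch element. A small, self-contained sketch of how such logits are typically consumed (the tensor here is random and only stands in for the real output; with a freshly initialized lm_head the real predictions are meaningless until the model is trained, as the warning says):

import torch

logits = torch.randn(1, 3, 50358)                 # (batch_size, sequence_length, vocab_size)
next_token_logits = logits[:, -1, :]              # scores at the last position
next_token_id = next_token_logits.argmax(dim=-1)  # greedy next-token choice
print(next_token_id.shape)                        # torch.Size([1])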