From 2fa13866a0a852163af3fe4b9b6d16b4055864d2 Mon Sep 17 00:00:00 2001 From: liushaokong Date: Tue, 26 Apr 2022 10:09:31 +0800 Subject: [PATCH] update transformer.py, add encdec_attention cache --- thumt/models/transformer.py | 43 ++++++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/thumt/models/transformer.py b/thumt/models/transformer.py index f9aa631..b9571be 100644 --- a/thumt/models/transformer.py +++ b/thumt/models/transformer.py @@ -15,7 +15,9 @@ class AttentionSubLayer(modules.Module): - def __init__(self, params, name="attention"): + def __init__(self, + params, + name="attention"): super(AttentionSubLayer, self).__init__(name=name) self.dropout = params.residual_dropout @@ -23,7 +25,11 @@ def __init__(self, params, name="attention"): with utils.scope(name): self.attention = modules.MultiHeadAttention( - params.hidden_size, params.num_heads, params.attention_dropout) + params.hidden_size, + params.num_heads, + params.attention_dropout, + ) + self.layer_norm = modules.LayerNorm(params.hidden_size) def forward(self, x, bias, memory=None, state=None): @@ -33,11 +39,11 @@ def forward(self, x, bias, memory=None, state=None): y = x if self.training or state is None: - y = self.attention(y, bias, memory, None) - else: + y = self.attention(y, bias, memory, kv=None) + else: kv = [state["k"], state["v"]] y, k, v = self.attention(y, bias, memory, kv) - state["k"], state["v"] = k, v + state["k"], state["v"] = k, v # update state here y = nn.functional.dropout(y, self.dropout, self.training) @@ -82,11 +88,11 @@ def __init__(self, params, name="layer"): super(TransformerEncoderLayer, self).__init__(name=name) with utils.scope(name): - self.self_attention = AttentionSubLayer(params) + self.self_attention = AttentionSubLayer(params) self.feed_forward = FFNSubLayer(params) def forward(self, x, bias): - x = self.self_attention(x, bias) + x = self.self_attention(x, bias) # memory=None, kv=None x = self.feed_forward(x) return x 
@@ -100,12 +106,15 @@ def __init__(self, params, name="layer"): self.self_attention = AttentionSubLayer(params, name="self_attention") self.encdec_attention = AttentionSubLayer(params, - name="encdec_attention") + name="encdec_attention") self.feed_forward = FFNSubLayer(params) - def __call__(self, x, attn_bias, encdec_bias, memory, state=None): - x = self.self_attention(x, attn_bias, state=state) - x = self.encdec_attention(x, encdec_bias, memory) + def __call__(self, x, attn_bias, encdec_bias, memory, + self_attention_state=None, + encdec_attention_state=None + ): + x = self.self_attention(x, attn_bias, state=self_attention_state) # memory=None + x = self.encdec_attention(x, encdec_bias, memory, state=encdec_attention_state) # pass the encoder-decoder attention cache; it is ignored during training x = self.feed_forward(x) return x @@ -157,7 +166,9 @@ def forward(self, x, attn_bias, encdec_bias, memory, state=None): for i, layer in enumerate(self.layers): if state is not None: x = layer(x, attn_bias, encdec_bias, memory, - state["decoder"]["layer_%d" % i]) + state["self_attention_kv"]["layer_%d" % i], # self_attention + state["encdec_attention_kv"]["layer_%d" % i] # encdec_attention + ) else: x = layer(x, attn_bias, encdec_bias, memory, None) @@ -320,13 +331,19 @@ def forward(self, features, labels, mode="train", level="sentence"): def empty_state(self, batch_size, device): state = { - "decoder": { + "self_attention_kv": { "layer_%d" % i: { "k": torch.zeros([batch_size, 0, self.hidden_size], device=device), "v": torch.zeros([batch_size, 0, self.hidden_size], device=device) } for i in range(self.num_decoder_layers) + }, + "encdec_attention_kv": { + "layer_%d" % i: { + "k": None, + "v": None + } for i in range(self.num_decoder_layers) } }