-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdecoder.py
More file actions
71 lines (61 loc) · 2.17 KB
/
decoder.py
File metadata and controls
71 lines (61 loc) · 2.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
"""
RNN-T Decoder (Prediction Network) — paper Section 3.2
The paper uses a single-LSTM-layer decoder.
It takes the previously predicted token and outputs a hidden state
that is combined with the encoder output in the Joint Network.
Architecture:
Embedding(vocab_size, d_model)
→ LSTM(d_model, decoder_dim, num_layers=1)
→ output: (batch, 1, decoder_dim)
"""
import torch
import torch.nn as nn
class RNNTDecoder(nn.Module):
"""
Single-layer LSTM prediction network for RNN-T.
Args:
vocab_size : number of tokens (characters + blank)
embed_dim : embedding dimension (= d_model for simplicity)
decoder_dim : LSTM hidden size (320 for S, 640 for M/L — Table 1)
num_layers : paper uses 1
"""
def __init__(
self,
vocab_size: int,
embed_dim: int = 144,
decoder_dim: int = 320,
num_layers: int = 1,
):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
self.lstm = nn.LSTM(
input_size = embed_dim,
hidden_size = decoder_dim,
num_layers = num_layers,
batch_first = True,
)
self.decoder_dim = decoder_dim
def forward(
self,
targets: torch.Tensor, # (B, U)
hidden: tuple | None = None,
) -> tuple[torch.Tensor, tuple]:
"""
Args:
targets : (B, U) — token indices (SOS-prepended during training)
hidden : optional LSTM hidden state for streaming
Returns:
out : (B, U, decoder_dim)
hidden : updated LSTM state
"""
emb = self.embedding(targets) # (B, U, embed_dim)
out, hidden = self.lstm(emb, hidden) # (B, U, decoder_dim)
return out, hidden
# Sanity check: push a random batch through the decoder and verify the
# output and recurrent-state shapes instead of only printing them.
if __name__ == "__main__":
    vocab = 30
    decoder = RNNTDecoder(vocab_size=vocab, embed_dim=144, decoder_dim=320)
    targets = torch.randint(0, vocab, (2, 5))  # batch=2, U=5 tokens
    out, (h_n, c_n) = decoder(targets)
    print(f"Input : {targets.shape}")
    print(f"Output : {out.shape}")  # (2, 5, 320)
    # Enforce the expected shapes so a regression fails loudly.
    assert out.shape == (2, 5, 320), out.shape
    # nn.LSTM state tensors are (num_layers, batch, hidden_size) for batched input.
    assert h_n.shape == c_n.shape == (1, 2, 320), (h_n.shape, c_n.shape)