Skip to content

Name, author, and paper for the special UnregularizedRNN #437

@tesla-cat

Description

@tesla-cat
class UnregularizedRNN(nn.Module):
    """Rate-based recurrent cell advanced by one Euler step per ``forward`` call.

    The hidden state is updated as
    ``hidden + (1 / tau_over_dt) * (-hidden + J @ r + B @ input + bx)``,
    where ``r = nonlinearity(hidden)``, and the output is a linear readout of
    the firing rate computed from the updated state.

    Args:
        input_size: Number of input features (columns of ``B``).
        hidden_size: Number of hidden units.
        output_size: Number of features produced by the linear readout.
        g: Scale of the recurrent weights; ``J`` is drawn from a standard
            normal and multiplied by ``g / sqrt(hidden_size)``.
        h: Scale of the input weights; ``B`` is drawn from a standard normal
            and multiplied by ``h / sqrt(input_size)``.
        tau_over_dt: Divisor of the Euler step (the step size is
            ``1 / tau_over_dt``).
    """

    def __init__(self, input_size, hidden_size, output_size, g, h, tau_over_dt=5):
        super().__init__()
        self.hidden_size = hidden_size
        self.tau_over_dt = tau_over_dt
        self.output_linear = nn.Linear(hidden_size, output_size)

        # Weight initialization: standard-normal draws scaled by the
        # square root of the fan-in.
        self.J = nn.Parameter(
            torch.randn(hidden_size, hidden_size) * (g / hidden_size ** 0.5)
        )
        self.B = nn.Parameter(
            torch.randn(hidden_size, input_size) * (h / input_size ** 0.5)
        )
        self.bx = nn.Parameter(torch.zeros(hidden_size))

        # Nonlinearity (``rectified_tanh`` is defined elsewhere in the project).
        self.nonlinearity = rectified_tanh

    def forward(self, input, hidden):
        """Advance the hidden state by one step and read out the output.

        Args:
            input: Tensor whose last dimension is ``input_size``.
            hidden: Tensor whose last dimension is ``hidden_size``; it is not
                modified in place.

        Returns:
            A tuple ``(output, hidden)`` with the linear readout of the new
            firing rate and the updated hidden state.
        """
        # Firing rate implied by the current hidden state.
        firing_rate_before = self.nonlinearity(hidden)

        # Multiplying by the transposed weights keeps the batch dimension
        # first, so no transpose round-trip of the activations is needed.
        recurrent_drive = firing_rate_before @ self.J.t()
        input_drive = input @ self.B.t()
        total_drive = recurrent_drive + input_drive + self.bx

        # Euler integration for the continuous-time update.
        hidden = hidden + (1 / self.tau_over_dt) * (total_drive - hidden)

        # Firing rate after the update, projected linearly to form the output.
        firing_rate = self.nonlinearity(hidden)
        output = self.output_linear(firing_rate)

        return output, hidden

    def init_hidden(self, batch_size):
        """Return an all-zero hidden state of shape ``(batch_size, hidden_size)``.

        The state is created with the dtype and device of the recurrent
        weights so it stays usable after the module has been moved or cast.
        """
        return torch.zeros(
            batch_size,
            self.hidden_size,
            dtype=self.J.dtype,
            device=self.J.device,
        )
  • Or, equivalently, following the math equation in the tutorial:
class UnRegRNN(nn.Module):
    """Compact rate-based recurrent cell advanced by one Euler step per call.

    The update is ``h + (1 / tt) * (-h + B @ x + J @ act(h) + b)``, evaluated
    with the batch dimension first, followed by a linear readout of
    ``act(h)`` for the new state.

    Args:
        I: Number of input features (columns of ``B``).
        H: Number of hidden units.
        O: Number of features produced by the linear readout.
        g: Scale of the recurrent weights; ``J`` is a standard-normal draw
            multiplied by ``g / sqrt(H)``.
        h: Scale of the input weights; ``B`` is a standard-normal draw
            multiplied by ``h / sqrt(I)``.
        tt: Divisor of the Euler step (the step size is ``1 / tt``).
    """

    def __init__(self, I, H, O, g, h, tt=5):
        super().__init__()
        self.H, self.tt = H, tt
        self.out = nn.Linear(H, O)
        self.J = nn.Parameter(tc.randn(H, H) * (g / np.sqrt(H)))
        self.B = nn.Parameter(tc.randn(H, I) * (h / np.sqrt(I)))
        self.b = nn.Parameter(tc.zeros(H))
        # ``rectified_tanh`` is defined elsewhere in the project.
        self.act = rectified_tanh

    def init_h(self, B):
        """Return an all-zero hidden state of shape ``(B, H)``.

        The state takes the dtype and device of the recurrent weights.
        """
        return tc.zeros(B, self.H, dtype=self.J.dtype, device=self.J.device)

    def forward(self, x: tc.Tensor, h):
        """Advance the hidden state by one step and read out the output.

        Args:
            x: Tensor whose last dimension is ``I``.
            h: Tensor whose last dimension is ``H`` (e.g. from ``init_h``).

        Returns:
            A tuple ``(output, h)`` with the linear readout and the new state.
        """
        # The state from ``init_h`` is (batch, H), so the weights are applied
        # transposed on the right; left-multiplying by ``J``/``B`` would
        # require column-vector states and fails for batch sizes other than H.
        drive = x @ self.B.t() + self.act(h) @ self.J.t() + self.b
        # Out-of-place update: the caller's tensor is left untouched, which
        # avoids aliasing of stored states and in-place autograd errors.
        h = h + (1 / self.tt) * (drive - h)
        return self.out(self.act(h)), h

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions