Why use test_loader in the train_fn() function? #1

@Lucmon


Very nice work, but I am a little confused about some details.
In mnist_runner.py, train_fn() looks like this:

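def train_fn(state, params, timesteps):
    # base_net, train_loader and test_loader are not defined in this
    # function; they come from the enclosing scope in mnist_runner.py.
    net = Net()
    copy_params(base_net, net)

    # Train net for `timesteps` steps on train_loader using `params`.
    train_net(params, train_loader, net, timesteps, meta_train=True)

    # The returned loss is measured on test_loader, not train_loader.
    avg_loss = test_net(test_loader, net, timesteps)
    compute = timesteps
    return avg_loss, compute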

I think the purpose of this function is to do a forward pass and compute the loss on the training dataset. However, the argument passed to test_net() is test_loader. Why not use train_loader directly?
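To make the distinction in the question concrete, here is a minimal, hypothetical sketch in plain Python. It is not the repository's code: `make_data`, `mse`, `inner_train` and `outer_objective` are names invented for illustration, and the toy model is an assumption. It trains a one-parameter model with SGD, where the learning rate stands in for the meta-parameter `params`, and then scores the result either on the data it was trained on or on a separate held-out split. The first number measures how well the inner loop fit its own training data; the second measures how well the result generalizes. Which of the two mnist_runner.py intends to optimize is exactly what the question asks, and this sketch does not settle it.

```python
import random

# Hypothetical toy setup (NOT the code from mnist_runner.py): fit y = w * x
# with plain SGD, where the learning rate plays the role of the meta-parameter.

def make_data(n, seed):
    rng = random.Random(seed)
    xs = [rng.uniform(-1.0, 1.0) for _ in range(n)]
    # True slope is 2.0, plus a little Gaussian noise.
    return [(x, 2.0 * x + rng.gauss(0.0, 0.1)) for x in xs]

def mse(w, data):
    # Mean squared error of the model y_hat = w * x on `data`.
    return sum((w * x - y) ** 2 for x, y in data) / len(data)

def inner_train(lr, data, timesteps):
    """Inner loop: run `timesteps` SGD steps on `data`, starting from w = 0."""
    w = 0.0
    for t in range(timesteps):
        x, y = data[t % len(data)]
        grad = 2.0 * (w * x - y) * x  # derivative of (w*x - y)^2 w.r.t. w
        w -= lr * grad
    return w

def outer_objective(lr, timesteps, train_data, eval_data):
    """Train on train_data, then report (loss on eval_data, compute)."""
    w = inner_train(lr, train_data, timesteps)
    return mse(w, eval_data), timesteps

train_data = make_data(50, seed=0)
heldout_data = make_data(50, seed=1)

for lr in (0.01, 0.1, 0.5):
    train_loss, _ = outer_objective(lr, 100, train_data, train_data)
    heldout_loss, _ = outer_objective(lr, 100, train_data, heldout_data)
    print(f"lr={lr}: train-set loss={train_loss:.4f}, held-out loss={heldout_loss:.4f}")
```

Passing the training split as `eval_data` gives the training-loss objective the question expected; passing a held-out split mirrors the structure of train_fn(), which trains on train_loader but returns the loss on test_loader.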

Thanks!
