
shape inconsistency between code.last and code_logits #16

@Wenjun-Peng

Description


After the add_last function runs, the length of code is 4:

loop = args.max_length
config['save_path'] = args.save_path + f'-{loop}-fit'
config['code_length'] = loop + 1
config['prev_model'] = checkpoint
add_last(f'{checkpoint}.code', args.code_num, f'{checkpoint}.code.last')
config['prev_id'] = f'{checkpoint}.code.last'
config['epochs'] = 1000
config['loss_w'] = 3
checkpoint = train(config)
test_dr(config)

However, code_logits is truncated to length 3:

if self.code_length == 1:
    return_code_logits = None
else:
    return_code_logits = code_logits[:, :-1].contiguous()
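For illustration, the slicing above drops the last position along the code-length axis, so a length-4 code_logits tensor comes out with only 3 positions (the shapes here are toy values, not the repo's actual sizes):

```python
import torch

# Toy shapes (assumed for illustration): batch 2, code_length 4, vocab 5
code_logits = torch.randn(2, 4, 5)

# Same slicing as in the model: keep all but the last position
return_code_logits = code_logits[:, :-1].contiguous()

print(return_code_logits.shape)  # torch.Size([2, 3, 5])
```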

So I get this error:

Traceback (most recent call last):
  File "/home/pengwenjun.pwj/GenRet/run.py", line 1195, in <module>
    main()
  File "/home/pengwenjun.pwj/GenRet/run.py", line 1188, in main
    checkpoint = train(config)
  File "/home/pengwenjun.pwj/GenRet/run.py", line 719, in train
    losses = OurTrainer.train_step(model, batch, gathered=False)
  File "/home/pengwenjun.pwj/GenRet/run.py", line 420, in train_step
    query_code_loss = F.cross_entropy(query_outputs.code_logits.view(-1, code_number),
  File "/opt/conda/envs/python3.10/lib/python3.10/site-packages/torch/nn/functional.py", line 3104, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
ValueError: Expected input batch_size (384) to match target batch_size (512).
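The mismatch can be reproduced in isolation: with a batch of 128, the 3-position logits flatten to 384 rows while the 4-token targets flatten to 512, which F.cross_entropy rejects (code_number here is a placeholder value, not the one used in the repo):

```python
import torch
import torch.nn.functional as F

batch, code_number = 128, 10  # code_number is a placeholder vocab size
code_logits = torch.randn(batch, 3, code_number)   # truncated: 3 positions
codes = torch.randint(0, code_number, (batch, 4))  # after add_last: 4 tokens

try:
    F.cross_entropy(code_logits.view(-1, code_number), codes.view(-1))
except ValueError as e:
    # ValueError: Expected input batch_size (384) to match target batch_size (512).
    print(e)
```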

How can I solve this?
