After calling the add_last function, the code length is 4:
loop = args.max_length
config['save_path'] = args.save_path + f'-{loop}-fit'
config['code_length'] = loop + 1
config['prev_model'] = checkpoint
add_last(f'{checkpoint}.code', args.code_num, f'{checkpoint}.code.last')
config['prev_id'] = f'{checkpoint}.code.last'
config['epochs'] = 1000
config['loss_w'] = 3
checkpoint = train(config)
test_dr(config)
However, code_logits is truncated to length 3:
if self.code_length == 1:
    return_code_logits = None
else:
    return_code_logits = code_logits[:, :-1].contiguous()
So I get this error:
Traceback (most recent call last):
  File "/home/pengwenjun.pwj/GenRet/run.py", line 1195, in <module>
    main()
  File "/home/pengwenjun.pwj/GenRet/run.py", line 1188, in main
    checkpoint = train(config)
  File "/home/pengwenjun.pwj/GenRet/run.py", line 719, in train
    losses = OurTrainer.train_step(model, batch, gathered=False)
  File "/home/pengwenjun.pwj/GenRet/run.py", line 420, in train_step
    query_code_loss = F.cross_entropy(query_outputs.code_logits.view(-1, code_number),
  File "/opt/conda/envs/python3.10/lib/python3.10/site-packages/torch/nn/functional.py", line 3104, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
ValueError: Expected input batch_size (384) to match target batch_size (512).
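The two numbers in the message seem to line up with the lengths above: 384 = 128 × 3 and 512 = 128 × 4. Here is a minimal sketch of that arithmetic. It assumes a per-step batch of 128 (inferred from the numbers, not read from run.py) and assumes the target codes are flattened over all 4 positions. `flattened_rows` is a hypothetical helper used only for illustration.

```python
# Sketch: reproduce the row counts in the ValueError from the shapes alone.
# batch_size = 128 is an inferred value, not something shown in run.py.

def flattened_rows(batch_size: int, code_length: int) -> int:
    """Rows left after flattening a (batch, length, ...) tensor over batch and length."""
    return batch_size * code_length

batch_size = 128      # assumption: 384 / 3 == 512 / 4 == 128
logits_length = 3     # code_logits[:, :-1] keeps 3 of the 4 positions
target_length = 4     # code length after add_last (assumed to be used as the target)

input_rows = flattened_rows(batch_size, logits_length)
target_rows = flattened_rows(batch_size, target_length)
print(input_rows, target_rows)  # 384 512
```

If this reading is right, the logits and the target codes cover a different number of positions per sample, so the two would need to be aligned to the same positions before the cross_entropy call.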
How can I solve this?