diff --git a/techqa_metrics.py b/techqa_metrics.py index edcf2eb..ea4b6d0 100644 --- a/techqa_metrics.py +++ b/techqa_metrics.py @@ -163,6 +163,8 @@ def predict_output(device, eval_features: List[TechQaInputFeature], eval_dataset 'attention_mask': batch[1], 'token_type_ids': batch[2] } + if model_type in ["roberta"]: + del inputs["token_type_ids"] outputs = model(**inputs) feature_vector_indeces = batch[5] @@ -213,4 +215,4 @@ def _generate_predictions(nbest_spans_tracker: BestSpanTracker, corpus: Dict, prediction.start_offset:prediction.end_offset], 'start_offset': prediction.start_offset, 'end_offset': prediction.end_offset, - 'score': prediction.score} for prediction in top_spans] \ No newline at end of file + 'score': prediction.score} for prediction in top_spans]