From 7553f59e5cb21ce5c008d8b19e61b62c6a6f2a80 Mon Sep 17 00:00:00 2001 From: abhi1nandy2 Date: Wed, 29 Sep 2021 18:54:14 +0200 Subject: [PATCH] condition for roberta, removing toekn_type_ids --- techqa_metrics.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/techqa_metrics.py b/techqa_metrics.py index edcf2eb..ea4b6d0 100644 --- a/techqa_metrics.py +++ b/techqa_metrics.py @@ -163,6 +163,8 @@ def predict_output(device, eval_features: List[TechQaInputFeature], eval_dataset 'attention_mask': batch[1], 'token_type_ids': batch[2] } + if model_type in ["roberta"]: + del inputs["token_type_ids"] outputs = model(**inputs) feature_vector_indeces = batch[5] @@ -213,4 +215,4 @@ def _generate_predictions(nbest_spans_tracker: BestSpanTracker, corpus: Dict, prediction.start_offset:prediction.end_offset], 'start_offset': prediction.start_offset, 'end_offset': prediction.end_offset, - 'score': prediction.score} for prediction in top_spans] \ No newline at end of file + 'score': prediction.score} for prediction in top_spans]