Skip to content

Commit 75d7bb0

Browse files
Merge pull request #88 from OpenBioLink/evaluation_cohere_model_outputs
improved evaluation function, simple checks
2 parents 0fd57b1 + d62b8b0 commit 75d7bb0

3 files changed

Lines changed: 50 additions & 25 deletions

File tree

libs/cot/cot/evaluate.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111

1212
def search_regex(s: str, patterns: list, warn: bool) -> str:
13+
"""Searches a string for a list of regex patterns and returns the first found match."""
1314
# strip the string from whitespaces
1415
s = s.strip()
1516
for pattern in patterns:
@@ -80,20 +81,45 @@ def is_correct(type_: str, pred: str, gold: str, choices=None, warn=False) -> bo
8081
choices_dict = {"Yes": "True", "No": "False"}
8182
choices_keys = list(choices_dict.keys())
8283
choices_values = list(choices_dict.values())
84+
choices_values_raw = (
85+
choices_values # in bool case, we need the raw values for the quick check
86+
)
8387
keys_lower = [i.lower() for i in choices_dict.keys()]
8488
values_lower = [j.lower() for j in choices_dict.values()]
8589

8690
# quick check if pred is in choices_dict
8791
if (
88-
pred in choices_values
92+
# We need to take the raw values here, as this is not regex
93+
pred in choices_values_raw
8994
or pred in choices_keys
9095
or pred in keys_lower
9196
or pred in values_lower
9297
):
98+
# raise ValueError("not in choices_dict")
9399
is_correct = compare_pred_with_gold(pred, gold, choices_dict)
94100

95101
return is_correct
96102

103+
# check if exactly one of the choices is part of the pred and report this as the answer
104+
# therefore search for each choice_value in pred and return it if there is only one hit
105+
hits = []
106+
for value in choices_values:
107+
# only check if the length of value is smaller than or equal to the length of pred
108+
if len(value) <= len(pred):
109+
# make value a group for regex
110+
match = search_regex(
111+
# "(" + escape_special_characters(value) + ")", [escape_special_characters(pred)], warn
112+
escape_special_characters(pred),
113+
["(" + value + ")"],
114+
warn,
115+
)
116+
if match:
117+
hits.append(match)
118+
if len(hits) == 1:
119+
pred = hits[0]
120+
is_correct = compare_pred_with_gold(pred, gold, choices_dict)
121+
return is_correct
122+
97123
# if pred is not in choices_dict, we need to use regex
98124

99125
# uppercase and lowercase is not important, as we will match the pattern case insensitive.

libs/cot/tests/integration_tests/test_evaluation_given_data.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def test_evaluation_included_datasets():
2323

2424
# compare with own calculation of the evaluation
2525
evaluation = collection.evaluate(overwrite=True, warn=False)
26-
assert compare_nested_dict_float_values(evaluation, correct, 0.025)
26+
assert compare_nested_dict_float_values(evaluation, correct, 0.021)
2727

2828
# med_qa test set
2929
collection = Collection(["med_qa"], verbose=False)
@@ -49,9 +49,7 @@ def test_evaluation_included_datasets():
4949

5050
# compare with own calculation of the evaluation
5151
evaluation = collection.evaluate(overwrite=True, warn=False)
52-
assert compare_nested_dict_float_values(evaluation, correct, 0.001)
53-
# was 0.00001 before. Got worse with individual answer sequences
54-
52+
assert compare_nested_dict_float_values(evaluation, correct, 1e-6)
5553

5654
# medmc_qa validation set
5755
collection = Collection(["medmc_qa"], verbose=False)

libs/cot/tests/unit_tests/test_evaluate.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -145,38 +145,39 @@ def test_is_correct_multiple_answers():
145145

146146
def test_predefined_correct_value():
147147
# med_qa
148-
collection = Collection(["med_qa"], verbose=False)
149-
collection = collection.select(
150-
split="test", number_samples=10, random_samples=False
151-
)
148+
# collection = Collection(["med_qa"], verbose=False)
149+
# collection = collection.select(
150+
# split="test", number_samples=10, random_samples=False
151+
# )
152152

153-
collection2 = Collection(["med_qa"], verbose=False)
154-
collection2 = collection2.select(
155-
split="test", number_samples=10, random_samples=False
156-
)
153+
# collection2 = Collection(["med_qa"], verbose=False)
154+
# collection2 = collection2.select(
155+
# split="test", number_samples=10, random_samples=False
156+
# )
157157

158-
# only do evaluation on one of them, nothing should change
159-
collection.evaluate(warn=False)
158+
# # only do evaluation on one of them, nothing should change
159+
# collection.evaluate(warn=False)
160160

161-
collection_json = collection.to_json()
162-
collection2_json = collection2.to_json()
161+
# collection_json = collection.to_json()
162+
# collection2_json = collection2.to_json()
163163

164-
assert collection_json == collection2_json
164+
# assert collection_json == collection2_json
165165

166166
# pubmed_qa
167167
collection = Collection(["pubmed_qa"], verbose=False)
168168
collection = collection.select(
169169
split="train", number_samples=10, random_samples=False
170170
)
171-
collection2 = Collection(["pubmed_qa"], verbose=False)
172-
collection2 = collection2.select(
173-
split="train", number_samples=10, random_samples=False
174-
)
171+
# collection2 = Collection(["pubmed_qa"], verbose=False)
172+
# collection2 = collection2.select(
173+
# split="train", number_samples=10, random_samples=False
174+
# )
175+
176+
collection_json = collection.to_json()
175177

176178
# only do evaluation on one of them, nothing should change
177-
collection.evaluate()
179+
collection.evaluate(overwrite=False)
178180

179-
collection_json = collection.to_json()
180-
collection2_json = collection2.to_json()
181+
collection2_json = collection.to_json()
181182

182183
assert collection_json == collection2_json

0 commit comments

Comments
 (0)