diff --git a/FlagEmbedding/abc/evaluation/evaluator.py b/FlagEmbedding/abc/evaluation/evaluator.py
index 22044152..50762983 100644
--- a/FlagEmbedding/abc/evaluation/evaluator.py
+++ b/FlagEmbedding/abc/evaluation/evaluator.py
@@ -10,7 +10,7 @@
 
 from .data_loader import AbsEvalDataLoader
 from .searcher import EvalRetriever, EvalReranker
-from .utils import evaluate_metrics, evaluate_mrr
+from .utils import evaluate_metrics, evaluate_mrr, evaluate_recall_cap
 
 logger = logging.getLogger(__name__)
 
@@ -340,12 +340,18 @@ def compute_metrics(
             results=search_results,
             k_values=k_values,
         )
+        recall_cap = evaluate_recall_cap(
+            qrels=qrels,
+            results=search_results,
+            k_values=k_values,
+        )
         scores = {
             **{f"ndcg_at_{k.split('@')[1]}": v for (k, v) in ndcg.items()},
             **{f"map_at_{k.split('@')[1]}": v for (k, v) in _map.items()},
             **{f"recall_at_{k.split('@')[1]}": v for (k, v) in recall.items()},
             **{f"precision_at_{k.split('@')[1]}": v for (k, v) in precision.items()},
             **{f"mrr_at_{k.split('@')[1]}": v for (k, v) in mrr.items()},
+            **{f"recall_cap_at_{k.split('@')[1]}": v for (k, v) in recall_cap.items()},
         }
         return scores
 
diff --git a/FlagEmbedding/abc/evaluation/utils.py b/FlagEmbedding/abc/evaluation/utils.py
index 4508b481..9ed3545d 100644
--- a/FlagEmbedding/abc/evaluation/utils.py
+++ b/FlagEmbedding/abc/evaluation/utils.py
@@ -52,6 +52,45 @@
     return mrr
 
 
+# Modified from https://github.com/beir-cellar/beir/blob/f062f038c4bfd19a8ca942a9910b1e0d218759d4/beir/retrieval/custom_metrics.py#L33
+def evaluate_recall_cap(
+    qrels: Dict[str, Dict[str, int]],
+    results: Dict[str, Dict[str, float]],
+    k_values: List[int]
+) -> Tuple[Dict[str, float]]:
+    """Compute capped recall.
+
+    Args:
+        qrels (Dict[str, Dict[str, int]]): Ground truth relevance.
+        results (Dict[str, Dict[str, float]]): Search results to evaluate.
+        k_values (List[int]): Cutoffs.
+
+    Returns:
+        Tuple[Dict[str, float]]: Capped recall results at provided k values.
+    """
+    capped_recall = {}
+
+    for k in k_values:
+        capped_recall[f"R_cap@{k}"] = 0.0
+
+    k_max = max(k_values)
+    logging.info("\n")
+
+    for query_id, doc_scores in results.items():
+        top_hits = sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)[0:k_max]
+        query_relevant_docs = [doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0]
+        for k in k_values:
+            retrieved_docs = [row[0] for row in top_hits[0:k] if qrels[query_id].get(row[0], 0) > 0]
+            denominator = min(len(query_relevant_docs), k)
+            capped_recall[f"R_cap@{k}"] += (len(retrieved_docs) / denominator)
+
+    for k in k_values:
+        capped_recall[f"R_cap@{k}"] = round(capped_recall[f"R_cap@{k}"]/len(qrels), 5)
+        logging.info("R_cap@{}: {:.4f}".format(k, capped_recall[f"R_cap@{k}"]))
+
+    return capped_recall
+
+
 # Modified from https://github.com/embeddings-benchmark/mteb/blob/18f730696451a5aaa026494cecf288fd5cde9fd0/mteb/evaluation/evaluators/RetrievalEvaluator.py#L501
 def evaluate_metrics(
     qrels: Dict[str, Dict[str, int]],
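
Usage note (not part of the patch): capped recall divides the number of relevant documents retrieved in the top k by min(#relevant, k) rather than by #relevant, so a query with fewer relevant documents than k can still score 1.0. The sketch below shows how the added evaluate_recall_cap behaves on made-up toy qrels/results; it assumes the patched module is importable under the path shown in the diff.

    # Minimal sketch, assuming the patched utils module is on the import path.
    from FlagEmbedding.abc.evaluation.utils import evaluate_recall_cap

    # Toy ground truth: "q1" has two relevant docs, "q2" has one.
    qrels = {
        "q1": {"d1": 1, "d2": 1, "d3": 0},
        "q2": {"d7": 1},
    }
    # Retrieval scores (higher is better); only the induced ranking matters.
    results = {
        "q1": {"d1": 0.9, "d5": 0.8, "d2": 0.7},
        "q2": {"d6": 0.9, "d7": 0.5},
    }

    scores = evaluate_recall_cap(qrels=qrels, results=results, k_values=[1, 2])
    # q1 at k=1: 1 relevant retrieved / min(2, 1) = 1.0  (plain recall@1 would be 0.5)
    # q2 at k=1: 0 relevant retrieved / min(1, 1) = 0.0
    # q1 at k=2: 1 / min(2, 2) = 0.5;  q2 at k=2: 1 / min(1, 2) = 1.0
    print(scores)  # {'R_cap@1': 0.5, 'R_cap@2': 0.75}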