1+ import hashlib
12import json
23import re
34from typing import Iterable , List
5+
46import spacy
5- import hashlib
7+
68from tako_query_filter .keywords import keywords
9+ from tako_query_filter .whitelist import whitelist
710
811
912class TakoQueryFilter :
def __init__(
    self,
    keyword_hashes: Iterable[str] = keywords,
    whitelist_hashes: Iterable[str] = whitelist,
):
    """Load the spaCy pipeline and store the hash sets used for score overrides.

    Args:
        keyword_hashes: Hashes of keyword strings. ``predict_proba`` assigns
            ``keyword_match_score`` to any query with a part whose hash is
            in this collection.
        whitelist_hashes: Hashes of whitelisted tokens. ``predict_proba``
            assigns ``whitelist_match_score`` to any query with a
            lower-cased, whitespace-separated token whose hash is in this
            collection.
    """
    self.nlp = spacy.load("en_tako_query_analyzer")

    # Copy into sets: membership tests are O(1) and later changes to the
    # caller's iterable (or the module-level defaults) cannot leak in.
    self.keywords_hashes = set(keyword_hashes)
    self.whitelist_hashes = set(whitelist_hashes)

    # Fixed scores that replace the model output on a hash match.
    self.keyword_match_score = 0.9
    self.whitelist_match_score = 0.8
1723
1824 @classmethod
1925 def load_with_keywords (
@@ -38,30 +44,42 @@ def predict(
3844 queries : List [str ],
3945 ) -> List [int ]:
4046 probs = self .predict_proba (queries )
41- predictions = [1 if p > 0.5 else 0 for p in probs ]
47+ predictions = [1 if p > 0.3 else 0 for p in probs ]
4248 return predictions
4349
def predict_proba(
    self,
    queries: List[str],
) -> List[float]:
    """Return one acceptance score per query, in input order.

    The base score is the text classifier's ``ACCEPT`` category normalized
    against ``ACCEPT + REJECT``. Two hash-based overrides are then applied:

    1. If any lower-cased, whitespace-separated token of the query hashes
       into ``self.whitelist_hashes``, the score becomes
       ``self.whitelist_match_score``.
    2. If any part returned by ``self._split_query`` hashes into
       ``self.keywords_hashes``, the score becomes
       ``self.keyword_match_score``. This check runs last, so a keyword
       match takes precedence over a whitelist match.

    Args:
        queries: Query strings to score.

    Returns:
        A list of floats with the same length as ``queries``.
    """
    probs: List[float] = []

    # nlp.pipe() yields docs lazily, so it must be consumed while the
    # select_pipes() context is still active; otherwise the restricted
    # component set could already be restored when the docs are processed.
    with self.nlp.select_pipes(enable=["tok2vec", "ner", "textcat_classify"]):
        for doc in self.nlp.pipe(queries):
            accept = doc.cats["ACCEPT"]
            reject = doc.cats["REJECT"]
            total = accept + reject
            # Just to be safe, normalize the probabilities. A zero total
            # carries no signal; score it 0.0 instead of raising
            # ZeroDivisionError and losing the whole batch.
            probs.append(accept / total if total else 0.0)

    for i, query in enumerate(queries):
        # Check whitelist
        if any(
            self._hash_string(token) in self.whitelist_hashes
            for token in query.lower().split()
        ):
            probs[i] = self.whitelist_match_score

        # Check keywords (applied after the whitelist so it wins on overlap)
        part_hashes = {self._hash_string(part) for part in self._split_query(query)}
        if not part_hashes.isdisjoint(self.keywords_hashes):
            probs[i] = self.keyword_match_score

    return probs
6583
6684 def _split_query (self , query : str ) -> List [str ]:
6785 split_keywords = ["vs" , "vs." , "versus" , "or" , "and" ]
0 commit comments