Skip to content

Commit fb44bcc

Browse files
authored
Update model and add whitelist (#7)
1 parent 51ee221 commit fb44bcc

6 files changed

Lines changed: 54 additions & 24 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
.filter
2+
.venv
23
__pycache__
34
build/
45
dist/

demo.ipynb

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
" \"election results\",\n",
3232
" \"presidential election polls 2024\",\n",
3333
" \"what books do you recommend\",\n",
34+
" \"sf weather tomorrow\"\n",
3435
"]\n",
3536
"\n",
3637
"preds = query_filter.predict(queries)\n",
@@ -41,7 +42,7 @@
4142
],
4243
"metadata": {
4344
"kernelspec": {
44-
"display_name": ".filter",
45+
"display_name": ".venv",
4546
"language": "python",
4647
"name": "python3"
4748
},

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ authors = [
33
{name = "Noah Jackson", email = "noah@trytako.com"},
44
]
55
dependencies = [
6-
"en-tako-query-filter @ https://huggingface.co/TakoData/en_tako_query_filter/resolve/0.0.1/en_tako_query_filter-any-py3-none-any.whl",
6+
"en-tako-query-analyzer @ https://huggingface.co/TakoData/en_tako_query_analyzer/resolve/0.0.4/en_tako_query_analyzer-any-py3-none-any.whl",
77
"ipykernel~=6.29.5",
88
"jupyter~=1.1.1",
99
"nbstripout~=0.7.1",
@@ -13,7 +13,7 @@ description = "Combines models to predict which queries Tako's API should handle
1313
name = "tako-query-filter"
1414
readme = "README.md"
1515
requires-python = ">=3.10.14,<3.13"
16-
version = "0.2.1"
16+
version = "0.3.0"
1717

1818
[tool.setuptools.packages.find]
1919
include = ["tako_query_filter*"]

src/tako_query_filter/filter.py

Lines changed: 38 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,25 @@
1+
import hashlib
12
import json
23
import re
34
from typing import Iterable, List
5+
46
import spacy
5-
import hashlib
7+
68
from tako_query_filter.keywords import keywords
9+
from tako_query_filter.whitelist import whitelist
710

811

912
class TakoQueryFilter:
1013
def __init__(
1114
self,
1215
keyword_hashes: Iterable[str] = keywords,
16+
whitelist_hashes: Iterable[str] = whitelist,
1317
):
14-
self.nlp = spacy.load("en_tako_query_filter")
18+
self.nlp = spacy.load("en_tako_query_analyzer")
1519
self.keywords_hashes = set(keyword_hashes)
20+
self.whitelist_hashes = set(whitelist_hashes)
1621
self.keyword_match_score = 0.9
22+
self.whitelist_match_score = 0.8
1723

1824
@classmethod
1925
def load_with_keywords(
@@ -38,30 +44,42 @@ def predict(
3844
queries: List[str],
3945
) -> List[int]:
4046
probs = self.predict_proba(queries)
41-
predictions = [1 if p > 0.5 else 0 for p in probs]
47+
predictions = [1 if p > 0.3 else 0 for p in probs]
4248
return predictions
4349

4450
def predict_proba(
4551
self,
4652
queries: List[str],
4753
) -> List[float]:
48-
preds = self.nlp.pipe(queries)
49-
50-
probs = []
51-
for pred in preds:
52-
accept = pred.cats["ACCEPT"]
53-
reject = pred.cats["REJECT"]
54-
# Just to be safe, normalize the probabilities
55-
probs.append(accept / (accept + reject))
56-
57-
# Check keywords
58-
for i, query in enumerate(queries):
59-
split_query = self._split_query(query)
60-
split_hashes = {self._hash_string(split) for split in split_query}
61-
if any(split_hash in self.keywords_hashes for split_hash in split_hashes):
62-
probs[i] = self.keyword_match_score
63-
64-
return probs
54+
with self.nlp.select_pipes(enable=["tok2vec", "ner", "textcat_classify"]):
55+
preds = self.nlp.pipe(queries)
56+
57+
probs = []
58+
for pred in preds:
59+
accept = pred.cats["ACCEPT"]
60+
reject = pred.cats["REJECT"]
61+
# Just to be safe, normalize the probabilities
62+
probs.append(accept / (accept + reject))
63+
64+
# Check whitelist
65+
for i, query in enumerate(queries):
66+
split_query = query.lower().split()
67+
if any(
68+
self._hash_string(split) in self.whitelist_hashes
69+
for split in split_query
70+
):
71+
probs[i] = self.whitelist_match_score
72+
73+
# Check keywords
74+
for i, query in enumerate(queries):
75+
split_query = self._split_query(query)
76+
split_hashes = {self._hash_string(split) for split in split_query}
77+
if any(
78+
split_hash in self.keywords_hashes for split_hash in split_hashes
79+
):
80+
probs[i] = self.keyword_match_score
81+
82+
return probs
6583

6684
def _split_query(self, query: str) -> List[str]:
6785
split_keywords = ["vs", "vs.", "versus", "or", "and"]

src/tako_query_filter/keywords.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -457286,5 +457286,5 @@
457286457286
"67dccfe9332640119fc7644463f5a40b",
457287457287
"44ef39be655ad5f53c05885d8959dbf3",
457288457288
"eaafc99a7ad0fa65d19d5cc7f76403bf",
457289-
"2092310ac8727c5cfedb4c2fcecac7f4"
457289+
"2092310ac8727c5cfedb4c2fcecac7f4",
457290457290
]

src/tako_query_filter/whitelist.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
whitelist = [
2+
"aab92e69374e4c7b8c6741fe02e574b9",
3+
"23678db5efde9ab46bce8c23a6d91b50",
4+
"533c5ba8368075db8f6ef201546bd71a",
5+
"f3639baeb4530db03ef930eb16073f61",
6+
"2b93fbdf27d43547bec8794054c28e00",
7+
"7e25b972e192b01004b62346ee9975a5",
8+
"3811727de5b0ddf6ae30defe2ca4d2c2",
9+
"559608508b42a01c1068fae4fcdc2aef",
10+
]

0 commit comments

Comments
 (0)