forked from SCE-Development/ChatSCE
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathretrieval.py
More file actions
19 lines (15 loc) · 778 Bytes
/
retrieval.py
File metadata and controls
19 lines (15 loc) · 778 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import ast

import numpy as np
import pandas as pd

import embeds
# Load the embedded data.
embedded_df = pd.read_csv('embedded.csv')
# pandas reads a non-numeric CSV cell back as text, so an embedding saved as
# "[0.1, 0.2, ...]" arrives as a string. Parse each cell into a list of floats
# before stacking; otherwise NumPy builds a 1-D array of strings that cannot be
# used for the dot products and row norms computed below. Cells that are
# already sequences are passed through unchanged.
embeddings = np.array(
    [
        ast.literal_eval(cell) if isinstance(cell, str) else cell
        for cell in embedded_df['ada_embedding']
    ],
    dtype=float,
)
# Function to find the closest embedding to a given query
def find_closest_embedding(embeddings, query):
    """Return the index and cosine similarity of the row closest to *query*.

    Args:
        embeddings: 2-D array-like with one embedding per row.
        query: Value handed to ``embeds.get_embedding`` to be embedded with
            the ``text-embedding-3-small`` model (presumably the query text;
            confirm against ``embeds``).

    Returns:
        A tuple ``(max_index, similarity)``: the row index with the highest
        cosine similarity to the query embedding, and that similarity score.

    Raises:
        ValueError: If *embeddings* is not a non-empty 2-D array.
    """
    embeddings = np.asarray(embeddings, dtype=float)
    # Validate before embedding the query so bad input fails fast with a clear
    # message instead of an opaque axis/argmax error after the embedding call.
    if embeddings.ndim != 2 or embeddings.shape[0] == 0:
        raise ValueError(
            "embeddings must be a non-empty 2-D array of shape (n_items, dim)"
        )

    # When given a query converts the query into an embedding
    query_embedding = np.asarray(
        embeds.get_embedding(query, model='text-embedding-3-small'),
        dtype=float,
    )

    # Cosine similarity = dot product / (row norm * query norm).
    dot_products = np.dot(embeddings, query_embedding)
    denominators = (
        np.linalg.norm(embeddings, axis=1) * np.linalg.norm(query_embedding)
    )
    # A zero-length vector has no defined cosine similarity. Dividing blindly
    # would yield NaN, and np.argmax would then pick the NaN entry as the
    # "best" match. Such entries are scored 0.0 instead.
    similarities = np.divide(
        dot_products,
        denominators,
        out=np.zeros_like(dot_products),
        where=denominators > 0,
    )

    # Find index of maximum similarity score
    max_index = np.argmax(similarities)
    return max_index, similarities[max_index]