-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathembedding_demo.py
More file actions
39 lines (30 loc) · 1.98 KB
/
embedding_demo.py
File metadata and controls
39 lines (30 loc) · 1.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from sentence_transformers import SentenceTransformer
import torch
# Interactive semantic-search demo: embed a small fixed corpus once, then
# repeatedly read a query from stdin and print the closest corpus sentence.

# Load the embedding model a single time. The script previously built two
# SentenceTransformer instances ("sentence-transformers/all-MiniLM-L6-v2" and
# "all-MiniLM-L6-v2") and only ever used the second one; both names are
# expected to resolve to the same checkpoint, so the extra load only cost
# start-up time and memory. `model` is kept as an alias so the module-level
# name still exists.
embedder = SentenceTransformer("all-MiniLM-L6-v2")
model = embedder

# Number of best matches to print for each query.
TOP_K = 1

corpus = [
    "Machine learning is a field of study that gives computers the ability to learn without being explicitly programmed.",
    "Deep learning is part of a broader family of machine learning methods based on artificial neural networks with representation learning.",
    "Neural networks are computing systems vaguely inspired by the biological neural networks that constitute animal brains.",
    "Mars rovers are robotic vehicles designed to travel on the surface of Mars to collect data and perform experiments.",
    "The James Webb Space Telescope is the largest optical telescope in space, designed to conduct infrared astronomy.",
    "SpaceX's Starship is designed to be a fully reusable transportation system capable of carrying humans to Mars and beyond.",
    "Global warming is the long-term heating of Earth's climate system observed since the pre-industrial period due to human activities.",
    "Renewable energy sources include solar, wind, hydro, and geothermal power that naturally replenish over time.",
    "Carbon capture technologies aim to collect CO2 emissions before they enter the atmosphere and store them underground.",
]

# The corpus never changes, so embed it once up front instead of per query.
corpus_embeddings = embedder.encode_document(corpus, convert_to_tensor=True)

while True:
    try:
        query = input("Provide query: ")
    except (EOFError, KeyboardInterrupt):
        # Ctrl-D / Ctrl-C at the prompt ends the demo without a traceback.
        print()
        break

    if not query.strip():
        # Nothing to search for; ask again rather than ranking an empty string.
        continue

    query_embedding = embedder.encode_query(query, convert_to_tensor=True)

    # similarity() is called with a single query, so row 0 holds the scores
    # of that query against every corpus sentence.
    similarity_scores = embedder.similarity(query_embedding, corpus_embeddings)[0]

    # Clamp k so torch.topk cannot be asked for more entries than exist.
    scores, indices = torch.topk(similarity_scores, k=min(TOP_K, len(corpus)))

    print("\nQuery:", query)
    print("Most similar sentence in corpus:")
    for score, idx in zip(scores, indices):
        print(f"(Score: {score:.4f})", corpus[idx])

    # Alternative to the manual similarity + topk steps above (requires
    # `from sentence_transformers import util`):
    #     hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=5)
    #     hits = hits[0]  # Get the hits for the first query