forked from worldbank/WhatsApp-RAG-Example
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
97 lines (81 loc) · 3.35 KB
/
utils.py
File metadata and controls
97 lines (81 loc) · 3.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# Standard library import
import logging
from decouple import config
import os
from dotenv import load_dotenv
# NOTE(review): both python-decouple (`config`) and python-dotenv are used to
# read the same .env settings — one of the two is likely redundant; confirm.
load_dotenv()
# Third-party imports
# NOTE(review): several of these names (Client, MessagingResponse,
# PlainTextResponse, ChatPromptTemplate) are not referenced in this file;
# they may be re-exported for importers — verify before removing.
from twilio.rest import Client
from twilio.twiml.messaging_response import MessagingResponse
from fastapi.responses import PlainTextResponse
from langchain_openai import OpenAI
from langchain_openai import OpenAIEmbeddings
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.chat_message_histories import ChatMessageHistory
# Local imports
from ensemble import ensemble_retriever_from_docs
from rag_chain import make_rag_chain, get_question
from local_loader import load_pdf_files
from basic_chain import basic_chain, get_model
from splitter import split_documents
from vector_store import create_vector_db
from memory import create_memory_chain
# Find your Account SID and Auth Token at twilio.com/console
# and set the environment variables. See http://twil.io/secure
# Credentials and API keys are resolved from the environment / .env file.
account_sid = config("TWILIO_ACCOUNT_SID")
auth_token = config("TWILIO_AUTH_TOKEN")
openai_api_key = config("OPENAI_API_KEY")
twilio_number = config('TWILIO_NUMBER')
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def get_retriever(openai_api_key=None):
    """Build the ensemble retriever over the local PDF corpus.

    Loads all PDF documents with ``load_pdf_files``, embeds them with
    OpenAI's ``text-embedding-3-small`` model, and wires both into
    ``ensemble_retriever_from_docs``.

    Args:
        openai_api_key: API key forwarded to the embeddings client.

    Returns:
        The retriever produced by ``ensemble_retriever_from_docs``.
    """
    embedder = OpenAIEmbeddings(
        openai_api_key=openai_api_key,
        model="text-embedding-3-small",
    )
    return ensemble_retriever_from_docs(load_pdf_files(), embeddings=embedder)
def get_chain(openai_api_key=None, huggingfacehub_api_token=None):
    """Assemble the conversational RAG pipeline.

    Combines the ChatGPT model, the ensemble retriever, and an in-memory
    chat history into a memory-aware RAG chain whose output is parsed to
    a plain string.

    Args:
        openai_api_key: Forwarded to the retriever's embeddings client.
        huggingfacehub_api_token: Accepted for interface compatibility;
            not used in this implementation.

    Returns:
        A runnable chain: memory chain piped into ``StrOutputParser``.
    """
    llm = get_model("ChatGPT")
    history = ChatMessageHistory()
    retriever = get_retriever(openai_api_key=openai_api_key)
    rag = make_rag_chain(llm, retriever)
    return create_memory_chain(llm, rag, history) | StrOutputParser()
def run_rag_query(query):
    """Run a single question through the memory-backed RAG chain.

    Builds a fresh chain (using the module-level ``openai_api_key``) and
    invokes it under a fixed session id.

    Args:
        query: The user question to answer.

    Returns:
        The chain's string response.
    """
    chain = get_chain(openai_api_key=openai_api_key)
    payload = {"question": query}
    session_cfg = {"configurable": {"session_id": "foo"}}
    return chain.invoke(payload, config=session_cfg)
def search_wikipedia(query):
    """Search Wikipedia through the LangChain API
    and use the OpenAI LLM wrapper and retrieve
    the agent result based on the received query.

    Args:
        query: The question to put to the ReAct agent.

    Returns:
        The agent's final answer string (the ``"output"`` field of the
        executor's result).
    """
    # Local imports keep the (heavy) agent stack out of the module import path.
    from langchain import hub
    from langchain_community.agent_toolkits.load_tools import load_tools
    from langchain.agents import AgentExecutor, create_react_agent

    # BUG FIX: the original called an undefined name `hub_pull`, which raised
    # NameError at runtime. The ReAct prompt comes from LangChain Hub via
    # `hub.pull`.
    prompt = hub.pull("hwchase17/react")
    llm = OpenAI(temperature=0, openai_api_key=openai_api_key)
    tools = load_tools(["wikipedia"], llm=llm)
    agent = create_react_agent(llm=llm, tools=tools, prompt=prompt)
    agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
    output = agent_executor.invoke({"input": str(query)})
    return output['output']
def main():
    """Smoke-test the RAG chain with a couple of sample questions.

    Builds the chain once, then answers each question and prints the
    output type and the response.
    """
    chain = get_chain(openai_api_key=openai_api_key)
    # NOTE(review): "Anthranx" is likely a typo for "Anthrax" — confirm the
    # intended retrieval query before changing it.
    questions = [
        "Are there any disease outbreaks in Zambia?",
        "When did Anthranx start in Zambia?",
    ]
    for question in questions:
        print("\n--- QUESTION: ", question)
        # BUG FIX: the original invoked the chain TWICE per question (once to
        # inspect the type, once to print), doubling LLM cost. Invoke once and
        # reuse the result. Also use the same payload/config shape as
        # run_rag_query, since the memory chain is invoked with a
        # {"question": ...} dict and a session_id config there.
        output = chain.invoke(
            {"question": question},
            config={"configurable": {"session_id": "foo"}},
        )
        print('OUTPUT TYPE==>', type(output))
        print("* RAG:\n", output)
if __name__ == '__main__':
    # Silence the HuggingFace tokenizers fork-parallelism warning.
    # NOTE(review): this env var is read when the tokenizers library is
    # imported; if one of the imports above already pulled it in, setting it
    # here may be too late — confirm.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    main()