-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathopenAI_model.py
More file actions
40 lines (29 loc) · 1.21 KB
/
openAI_model.py
File metadata and controls
40 lines (29 loc) · 1.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import os
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import logging
import model_pb2
from data_models import *
from models.basemodel import AIModel
logger = logging.getLogger("app")
# NOTE: This needs to be defined in the environment this model is running in.
model_definition = model_pb2.modelDefinition()
model_definition.needs_text = True
model_definition.needs_image = False
model_definition.can_text = True
model_definition.can_image = False
model_definition.modelID = "GPT4_turbo"
class OpenAIModel(AIModel):
def get_model_definition(self) -> model_pb2.modelDefinition:
return model_definition
def publish_metrics(self, metrics_json: str) -> None:
logger.info(metrics_json)
async def get_response(self, message: TaskInput) -> TaskOutput:
model = ChatOpenAI(model="gpt-4o")
AIresponse = model.invoke(message.model_dump()["messages"])
print(f"AIresponse: {AIresponse.content}")
taskResponse = TaskOutput()
taskResponse.text = AIresponse.content
return taskResponse
ai_model = AIModel()