diff --git a/Cargo.toml b/Cargo.toml index 9523548..1f1d3f2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,13 +11,24 @@ keywords = ["ai", "machine-learning", "openai", "library"] [dependencies] serde_json = "1.0.94" derive_builder = "0.20.0" -reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "multipart"], optional = true } -serde = { version = "1.0.157", features = ["derive"] } +reqwest = { version = "0.12", default-features = false, features = [ + "json", + "stream", + "multipart", +], optional = true } +serde = { version = "^1.0", features = ["derive"] } reqwest-eventsource = "0.6" -tokio = { version = "1.26.0", features = ["full"] } -anyhow = "1.0.70" +tokio = { version = "1.0", features = ["full"] } +anyhow = "1.0" futures-util = "0.3.28" bytes = "1.4.0" +schemars = "0.8" +either = { version = "1.8.1", features = ["serde"] } +serde-double-tag = "0.0.4" +log = "0.4" +strum = { version = "0.26", features = ["derive"] } +strum_macros = "0.26" +once_cell = "^1" [dev-dependencies] dotenvy = "0.15.7" diff --git a/src/assistants/assistants.rs b/src/assistants/assistants.rs new file mode 100644 index 0000000..e5a4e6d --- /dev/null +++ b/src/assistants/assistants.rs @@ -0,0 +1,144 @@ +use std::collections::HashMap; + +use schemars::schema::RootSchema; +use serde::{Deserialize, Serialize}; + +use crate::{ + client::{Empty, OpenAiClient}, + ApiResponseOrError, +}; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Assistant { + pub id: String, + pub object: String, + pub created_at: u32, + /// The name of the assistant. The maximum length is 256 characters. + pub name: Option, + /// ID of the model to use. You can use the List models API to see all of your available models, or see our Model overview for descriptions of them. + pub model: String, + /// The system instructions that the assistant uses. The maximum length is 256,000 characters. 
+ pub instructions: Option, + pub tools: Vec, + /// A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the code_interpreter tool requires a list of file IDs, while the file_search tool requires a list of vector store IDs. + pub tool_resources: Option, + /// Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maximum of 512 characters long. + pub metadata: Option>, + /// The default model to use for this assistant. + pub response_format: Option, +} + +#[derive(Debug, Clone, serde_double_tag::Deserialize, serde_double_tag::Serialize)] +#[serde(tag = "type")] +#[serde(rename_all = "snake_case")] +pub enum Tool { + CodeInterpreter, + Function(Function), + FileSearch(FileSearch), +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Function { + pub name: String, + pub description: String, + pub parameters: RootSchema, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct FunctionParameters { + pub title: String, + pub description: String, + #[serde(rename = "type")] + pub type_: String, + pub required: Vec, + pub properties: HashMap, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct FunctionProperty { + pub description: String, + #[serde(rename = "type")] + pub type_: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone, Default)] +pub struct FileSearch { + pub max_num_results: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone, Default)] +pub struct ToolResources { + pub code_interpreter: Option, + pub file_search: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct CodeInterpreterResources { + /// A list of file IDs made available to the `code_interpreter`` tool. There can be a maximum of 20 files associated with the tool. 
+ pub file_ids: Vec, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct FileSearchResources { + /// The ID of the vector store attached to this assistant. There can be a maximum of 1 vector store attached to the assistant. + pub vector_store_ids: Vec, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(rename_all = "snake_case")] +pub enum ResponseFormat { + Auto, +} + +#[derive(Serialize, Default, Debug, Clone)] +pub struct CreateAssistantRequest { + /// ID of the model to use. You can use the List models API to see all of your available models, or see our Model overview for descriptions of them. + pub model: String, + + /// The name of the assistant. The maximum length is 256 characters. + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + /// The description of the assistant. The maximum length is 256 characters. + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + /// The system instructions that the assistant uses. The maximum length is 256,000 characters. + #[serde(skip_serializing_if = "Option::is_none")] + pub instructions: Option, + /// A set of tools that the assistant can use. + pub tools: Vec, + /// A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the code_interpreter tool requires a list of file IDs, while the file_search tool requires a list of vector store IDs. + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_resources: Option, + /// Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maximum of 512 characters long. + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option>, + /// The default model to use for this assistant. 
+ #[serde(skip_serializing_if = "Option::is_none")] + pub response_format: Option, +} + +impl OpenAiClient { + pub async fn create_assistant( + &self, + request: CreateAssistantRequest, + ) -> ApiResponseOrError { + self.post("assistants", Some(request)).await + } + + pub async fn get_assistant(&self, assistant_id: &str) -> ApiResponseOrError { + self.get(format!("assistants/{}", assistant_id)).await + } + + pub async fn delete_assistant(&self, assistant_id: &str) -> ApiResponseOrError { + self.delete(format!("assistants/{}", assistant_id)).await + } + + pub async fn update_assistant( + &self, + assistant_id: &str, + request: CreateAssistantRequest, + ) -> ApiResponseOrError { + self.post(format!("assistants/{}", assistant_id), Some(request)) + .await + } +} diff --git a/src/assistants/files.rs b/src/assistants/files.rs new file mode 100644 index 0000000..6c95111 --- /dev/null +++ b/src/assistants/files.rs @@ -0,0 +1,49 @@ +use crate::{client::OpenAiClient, ApiResponseOrError}; +use reqwest::{ + multipart::{Form, Part}, + Body, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct File { + pub id: String, + pub object: String, + pub created_at: u32, + pub bytes: u32, + pub filename: String, + pub purpose: FilePurpose, +} + +#[derive(Debug, Serialize, Deserialize, Clone, strum_macros::Display)] +#[strum(serialize_all = "snake_case")] +#[serde(rename_all = "snake_case")] +pub enum FilePurpose { + Assistants, + AssistantsOutput, + Batch, + BatchOutput, + FineTune, + FineTuneResults, + Vision, +} + +impl OpenAiClient { + pub async fn upload_file>( + &self, + filename: &str, + mime_type: &str, + bytes: B, + purpose: FilePurpose, + ) -> ApiResponseOrError { + let file_part = Part::stream(bytes) + .file_name(filename.to_string()) + .mime_str(mime_type)?; + + let form = Form::new() + .part("file", file_part) + .text("purpose", purpose.to_string()); + + self.post_multipart("files", form).await + } +} diff --git 
a/src/assistants/messages.rs b/src/assistants/messages.rs new file mode 100644 index 0000000..4c001f5 --- /dev/null +++ b/src/assistants/messages.rs @@ -0,0 +1,116 @@ +use crate::{assistants::Tool, client::OpenAiClient, ApiResponseOrError}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Message { + pub id: String, + pub object: String, + pub created_at: u32, + /// The thread ID that this message belongs to. + pub thread_id: String, + /// The status of the message, which can be either in_progress, incomplete, or completed. + pub status: Option, + /// On an incomplete message, details about why the message is incomplete. + pub incomplete_details: Option, + /// The Unix timestamp (in seconds) for when the message was completed. + pub completed_at: Option, + /// The Unix timestamp (in seconds) for when the message was marked as incomplete. + pub incomplete_at: Option, + /// The entity that produced the message. One of user or assistant + pub role: Role, + /// The content of the message. + pub content: Vec, + /// The assistant that produced the message. + pub assistant_id: Option, + /// The ID of the run associated with the creation of this message. Value is null when messages are created manually using the create message or create thread endpoints. + pub run_id: Option, + /// A list of files attached to the message. + pub attachments: Option>, + /// A set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maximum of 512 characters long. 
+ pub metadata: Option>, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(rename_all = "snake_case")] +pub enum Status { + InProgress, + Incomplete, + Completed, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct IncompleteDetails { + pub reason: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(rename_all = "snake_case")] +pub enum Role { + User, + Assistant, +} + +#[derive(Debug, serde_double_tag::Serialize, serde_double_tag::Deserialize, Clone)] +#[serde(rename_all = "snake_case")] +#[serde(tag = "type")] +pub enum Content { + Text(Text), + ImageFile(ImageFile), + ImageUrl(ImageUrl), + Refusal(Refusal), +} + +#[derive(Debug, Serialize, Deserialize, Clone, Default)] +pub struct Text { + pub value: String, + pub annotations: Vec, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Annotation { + #[serde(rename = "type")] + pub kind: String, + pub text: String, + pub start_index: u32, + pub end_index: u32, + pub file_citation: FileCitation, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct FileCitation { + pub file_id: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ImageFile { + pub file_id: String, + pub detail: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ImageUrl { + pub image_url: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Refusal { + pub refusal: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Attachment { + pub file_id: String, + pub tools: Tool, +} + +impl OpenAiClient { + pub async fn list_messages( + &self, + thread_id: &str, + after_id: Option, + ) -> ApiResponseOrError> { + self.list(format!("threads/{thread_id}/messages"), after_id) + .await + } +} diff --git a/src/assistants/mod.rs b/src/assistants/mod.rs new file mode 100644 index 0000000..651476d --- /dev/null +++ b/src/assistants/mod.rs @@ -0,0 +1,8 @@ +pub mod assistants; +pub use assistants::*; + +pub mod 
files; +pub mod messages; +pub mod runs; +pub mod threads; +pub mod vector_stores; diff --git a/src/assistants/runs.rs b/src/assistants/runs.rs new file mode 100644 index 0000000..0386dd4 --- /dev/null +++ b/src/assistants/runs.rs @@ -0,0 +1,289 @@ +use derive_builder::Builder; +use either::Either; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +use crate::{assistants::Tool, chat::ToolCall, client::OpenAiClient, ApiResponseOrError}; + +use super::{ + messages::{Attachment, IncompleteDetails, Role}, + ResponseFormat, ToolResources, +}; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Run { + pub id: String, + pub object: String, + pub created_at: u32, + /// The ID of the assistant used for this run. + pub assistant_id: String, + /// The ID of the thread associated with this run. + pub thread_id: String, + /// The status of the run. + pub status: Status, + /// Details on the action required to continue the run. Will be null if no action is required. + pub required_action: Option, + + /// The last error that occurred during this run. + pub last_error: Option, + + /// The time at which the run will expire. + pub expires_at: Option, + /// The time at which the run was started. + pub started_at: Option, + /// The time at which the run was completed. + pub completed_at: Option, + /// The time at which the run was cancelled. + pub cancelled_at: Option, + /// The time at which the run was failed. + pub failed_at: Option, + /// The time at which the run was incomplete. + pub incomplete_details: Option, + + /// The model used for this run. + pub model: String, + + /// The instructions given to the assistant. + pub instructions: String, + + /// The tools used for this run. + pub tools: Vec, + + /// The usage of the run. + pub usage: Option, + + /// The truncation strategy used for this run. + pub truncation_strategy: Option, + + /// Whether to run tool calls in parallel. 
+ pub parallel_tool_calls: bool, + + /// The tool choice used for this run. + pub tool_choice: ToolChoice, + + /// Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maximum of 512 characters long. + pub metadata: Option>, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum Status { + Queued, + InProgress, + RequiresAction, + Cancelling, + Cancelled, + Failed, + Completed, + Incomplete, + Expired, +} + +impl Status { + pub fn is_terminal(&self) -> bool { + !matches!(self, Status::InProgress | Status::Queued) + } +} + +#[derive(Debug, serde_double_tag::Deserialize, serde_double_tag::Serialize, Clone)] +#[serde(rename_all = "snake_case")] +#[serde(tag = "type")] +pub enum RequiredAction { + SubmitToolOutputs { tool_calls: Vec }, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct LastError { + pub code: String, + pub message: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Usage { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct TruncationStrategy { + #[serde(rename = "type")] + pub kind: String, + pub last_messages: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(transparent)] +pub struct ToolChoice { + #[serde(with = "either::serde_untagged")] + pub inner: Either, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(rename_all = "snake_case")] +pub enum ToolChoiceStrategy { + None, + Auto, + Required, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(rename_all = "snake_case")] +pub enum ToolChoiceFunction { + FileSearch, + Function { name: String }, +} + +#[derive(Serialize, Builder, Debug, Clone, Default)] +#[builder(pattern = "owned")] +#[builder(name = 
"CreateThreadRunBuilder")] +#[builder(setter(strip_option, into))] +pub struct CreateThreadRunRequest { + /// ID of the assistant to use. + pub assistant_id: String, + + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(default)] + pub model: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(default)] + pub instructions: Option, + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(default)] + pub tools: Option>, + + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(default)] + pub tool_resources: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(default)] + pub metadata: Option>, + + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(default)] + pub parallel_tool_calls: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(default)] + pub response_format: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(default)] + pub tool_choice: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(default)] + pub max_completion_tokens: Option, + + /// the thread to create + pub thread: CreateThreadRequest, +} + +#[derive(Serialize, Builder, Debug, Clone, Default)] +#[builder(pattern = "owned")] +#[builder(name = "CreateThreadBuilder")] +#[builder(setter(strip_option, into))] +pub struct CreateThreadRequest { + pub messages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(default)] + pub tool_resources: Option, + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(default)] + pub metadata: Option>, +} + +#[derive(Serialize, Builder, Debug, Clone)] +#[builder(pattern = "owned")] +#[builder(name = "CreateThreadMessageBuilder")] +#[builder(setter(strip_option, into))] +pub struct CreateThreadMessageRequest { + pub role: Role, + pub content: String, + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(default)] + pub attachments: Option>, + #[serde(skip_serializing_if = 
"Option::is_none")] + #[builder(default)] + pub metadata: Option>, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct SubmitToolOutputsRequest { + pub tool_outputs: Vec, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ToolOutput { + pub tool_call_id: String, + pub output: String, +} + +#[derive(Serialize, Builder, Debug, Clone, Default)] +#[builder(pattern = "owned")] +#[builder(name = "CreateRunBuilder")] +#[builder(setter(strip_option, into))] +pub struct CreateRunRequest { + pub assistant_id: String, + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(default)] + pub additional_messages: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(default)] + pub tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(default)] + pub max_completion_tokens: Option, +} + +impl OpenAiClient { + pub async fn create_thread_run( + &self, + request: CreateThreadRunRequest, + ) -> ApiResponseOrError { + self.post(format!("threads/runs"), Some(request)).await + } + + pub async fn create_run( + &self, + thread_id: &str, + request: CreateRunRequest, + ) -> ApiResponseOrError { + self.post(format!("threads/{thread_id}/runs"), Some(request)) + .await + } + + pub async fn poll_run(&self, mut run: Run) -> ApiResponseOrError { + while !run.status.is_terminal() { + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + run = self + .get_run(run.thread_id.as_str(), run.id.as_str()) + .await?; + } + Ok(run) + } + + pub async fn get_run(&self, thread_id: &str, run_id: &str) -> ApiResponseOrError { + self.get(format!("threads/{thread_id}/runs/{run_id}")).await + } + + pub async fn submit_tool_outputs_and_poll( + &self, + run: Run, + request: SubmitToolOutputsRequest, + ) -> ApiResponseOrError { + let run: Run = self + .post( + format!( + "threads/{}/runs/{}/submit_tool_outputs", + run.thread_id, run.id + ), + Some(request), + ) + .await?; + + self.poll_run(run).await + } +} diff --git 
a/src/assistants/threads.rs b/src/assistants/threads.rs new file mode 100644 index 0000000..f4e8d6f --- /dev/null +++ b/src/assistants/threads.rs @@ -0,0 +1,15 @@ +use serde::Deserialize; +use std::collections::HashMap; + +use crate::assistants::ToolResources; + +#[derive(Debug, Deserialize, Clone)] +pub struct Thread { + pub id: String, + pub object: String, + pub created_at: u32, + /// A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the code_interpreter tool requires a list of file IDs, while the file_search tool requires a list of vector store IDs. + pub tool_resources: Option, + /// Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maximum of 512 characters long. + pub metadata: Option>, +} diff --git a/src/assistants/vector_stores.rs b/src/assistants/vector_stores.rs new file mode 100644 index 0000000..c822cd0 --- /dev/null +++ b/src/assistants/vector_stores.rs @@ -0,0 +1,95 @@ +use std::collections::HashMap; + +use crate::{client::OpenAiClient, ApiResponseOrError}; +use serde::{Deserialize, Serialize}; +use serde_json::json; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct VectorStore { + pub id: String, + pub object: String, + pub created_at: u32, + pub name: String, + pub usage_bytes: u32, + pub file_counts: FileCounts, + pub status: VectorStoreStatus, + pub expires_after: Option, + pub expires_at: Option, + pub last_active_at: Option, + pub metadata: Option>, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct FileCounts { + pub in_progress: u32, + pub completed: u32, + pub failed: u32, + pub cancelled: u32, + pub total: u32, +} + +#[derive(Debug, Serialize, Deserialize, Clone, strum_macros::Display)] +#[strum(serialize_all = "snake_case")] +#[serde(rename_all = "snake_case")] +pub enum 
VectorStoreStatus { + Expired, + InProgress, + Completed, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ExpiresAfter { + pub anchor: String, + pub days: u32, +} + +#[derive(Debug, Serialize, Deserialize, Clone, Default)] +pub struct CreateVectorStoreRequest { + pub name: String, + pub file_ids: Option>, + pub metadata: Option>, + pub expires_after: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct VectorStoreFile { + pub id: String, + pub object: String, + pub created_at: u32, + pub file_id: String, + pub vector_store_id: String, + pub usage_bytes: u32, + pub status: VectorStoreFileStatus, + pub last_error: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone, strum_macros::Display)] +#[strum(serialize_all = "snake_case")] +#[serde(rename_all = "snake_case")] +pub enum VectorStoreFileStatus { + InProgress, + Completed, + Cancelled, + Failed, +} + +impl OpenAiClient { + pub async fn create_vector_store( + &self, + params: CreateVectorStoreRequest, + ) -> ApiResponseOrError { + self.post("vector_stores", params).await + } + + pub async fn attach_file_to_vector_store( + &self, + vector_store_id: &str, + file_id: &str, + ) -> ApiResponseOrError { + self.post( + &format!("vector_stores/{}/files", vector_store_id), + json!({ "file_id": file_id }), + ) + .await + } +} diff --git a/src/chat.rs b/src/chat.rs index 270bdae..25047aa 100644 --- a/src/chat.rs +++ b/src/chat.rs @@ -1,7 +1,7 @@ //! Given a chat conversation, the model will return a chat completion response. 
use super::{openai_post, ApiResponseOrError, Credentials, Usage}; -use crate::openai_request_stream; +use crate::{client::OpenAiClient, openai_request_stream}; use derive_builder::Builder; use futures_util::StreamExt; use reqwest::Method; @@ -166,7 +166,7 @@ pub enum ChatCompletionMessageRole { Tool, } -#[derive(Serialize, Builder, Debug, Clone)] +#[derive(Serialize, Builder, Debug, Clone, Default)] #[builder(derive(Clone, Debug, PartialEq))] #[builder(pattern = "owned")] #[builder(name = "ChatCompletionBuilder")] @@ -174,62 +174,62 @@ pub enum ChatCompletionMessageRole { pub struct ChatCompletionRequest { /// ID of the model to use. Currently, only `gpt-3.5-turbo`, `gpt-3.5-turbo-0301` and `gpt-4` /// are supported. - model: String, + pub model: String, /// The messages to generate chat completions for, in the [chat format](https://platform.openai.com/docs/guides/chat/introduction). - messages: Vec, + pub messages: Vec, /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. /// /// We generally recommend altering this or `top_p` but not both. #[builder(default)] #[serde(skip_serializing_if = "Option::is_none")] - temperature: Option, + pub temperature: Option, /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. /// /// We generally recommend altering this or `temperature` but not both. #[builder(default)] #[serde(skip_serializing_if = "Option::is_none")] - top_p: Option, + pub top_p: Option, /// How many chat completion choices to generate for each input message. 
#[builder(default)] #[serde(skip_serializing_if = "Option::is_none")] - n: Option, + pub n: Option, #[builder(default)] #[serde(skip_serializing_if = "Option::is_none")] - stream: Option, + pub stream: Option, /// Up to 4 sequences where the API will stop generating further tokens. #[builder(default)] #[serde(skip_serializing_if = "Vec::is_empty")] - stop: Vec, + pub stop: Vec, /// This feature is in Beta. If specified, our system will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend. #[builder(default)] #[serde(skip_serializing_if = "Option::is_none")] - seed: Option, + pub seed: Option, /// The maximum number of tokens allowed for the generated answer. By default, the number of tokens the model can return will be (4096 - prompt tokens). #[builder(default)] #[serde(skip_serializing_if = "Option::is_none")] - max_tokens: Option, + pub max_tokens: Option, /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. /// /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details) #[builder(default)] #[serde(skip_serializing_if = "Option::is_none")] - presence_penalty: Option, + pub presence_penalty: Option, /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. 
/// /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details) #[builder(default)] #[serde(skip_serializing_if = "Option::is_none")] - frequency_penalty: Option, + pub frequency_penalty: Option, /// Modify the likelihood of specified tokens appearing in the completion. /// /// Accepts a json object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token. #[builder(default)] #[serde(skip_serializing_if = "Option::is_none")] - logit_bias: Option>, + pub logit_bias: Option>, /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids). #[builder(default)] #[serde(skip_serializing_if = "String::is_empty")] - user: String, + pub user: String, /// Describe functions that ChatGPT can call /// The latest models of ChatGPT support function calling, which allows you to define functions that can be called from the prompt. /// For example, you can define a function called "get_weather" that returns the weather in a given city @@ -238,7 +238,7 @@ pub struct ChatCompletionRequest { /// [See more information about function calling in ChatGPT.](https://platform.openai.com/docs/guides/gpt/function-calling) #[builder(default)] #[serde(skip_serializing_if = "Vec::is_empty")] - functions: Vec, + pub functions: Vec, /// A string or object of the function to call /// /// Controls how the model responds to function calls @@ -250,17 +250,17 @@ pub struct ChatCompletionRequest { /// "none" is the default when no functions are present. 
"auto" is the default if functions are present. #[builder(default)] #[serde(skip_serializing_if = "Option::is_none")] - function_call: Option, + pub function_call: Option, /// An object specifying the format that the model must output. Compatible with GPT-4 Turbo and all GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. /// Setting to { "type": "json_object" } enables JSON mode, which guarantees the message the model generates is valid JSON. /// Important: when using JSON mode, you must also instruct the model to produce JSON yourself via a system or user message. Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting in a long-running and seemingly "stuck" request. Also note that the message content may be partially cut off if finish_reason="length", which indicates the generation exceeded max_tokens or the conversation exceeded the max context length. #[builder(default)] #[serde(skip_serializing_if = "Option::is_none")] - response_format: Option, + pub response_format: Option, /// The credentials to use for this request. 
#[serde(skip_serializing)] #[builder(default)] - credentials: Option, + pub credentials: Option, } #[derive(Serialize, Debug, Clone, Eq, PartialEq)] @@ -302,6 +302,15 @@ impl ChatCompletion { } } +impl OpenAiClient { + pub async fn create_chat_completion( + &self, + request: ChatCompletionRequest, + ) -> ApiResponseOrError { + self.post("chat/completions", request).await + } +} + impl ChatCompletionDelta { pub async fn create( request: ChatCompletionRequest, diff --git a/src/client.rs b/src/client.rs new file mode 100644 index 0000000..bcf2ba9 --- /dev/null +++ b/src/client.rs @@ -0,0 +1,209 @@ +use std::str::FromStr; + +use crate::{ApiResponseOrError, Credentials, OpenAiError, DEFAULT_CREDENTIALS}; +use anyhow::Result; +use reqwest::{ + header::{HeaderName, HeaderValue, AUTHORIZATION}, + multipart::Form, + Client, Method, RequestBuilder, Response, +}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; + +#[derive(Clone)] +pub struct OpenAiClient { + credentials: Credentials, + client: Client, +} + +impl std::fmt::Debug for OpenAiClient { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "OpenAiClient") + } +} + +#[derive(Debug, Clone, Deserialize)] +struct OpenAiErrorWrapper { + error: OpenAiError, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Empty {} + +enum RequestBody { + Json(S), + Multipart(Form), + None, +} + +impl From> for RequestBody { + fn from(value: Option) -> Self { + match value { + Some(value) => RequestBody::Json(value), + None => RequestBody::None, + } + } +} + +impl OpenAiClient { + pub fn default() -> Result { + Self::new(DEFAULT_CREDENTIALS.read().unwrap().clone()) + } + + pub fn new(credentials: Credentials) -> Result { + let client = Client::builder() + .default_headers( + [ + ( + AUTHORIZATION, + HeaderValue::from_str(&format!("Bearer {}", credentials.api_key))?, + ), + ( + HeaderName::from_str("OpenAI-Beta")?, + HeaderValue::from_str("assistants=v2")?, + ), + ] + .into_iter() + 
.collect(), + ) + .build()?; + + Ok(Self { + credentials, + client, + }) + } + + fn request_builder(&self, method: Method, route: R) -> RequestBuilder + where + R: Into, + { + let url = format!("{}{}", self.credentials.base_url, route.into()); + log::debug!("OpenAI Request[{}] {}", method.to_string(), url); + + self.client.request(method.clone(), url.clone()) + } + + async fn request_inner( + &self, + method: Method, + route: R, + body: RequestBody, + ) -> Result + where + R: Into, + S: Serialize, + { + let mut request = self.request_builder(method.clone(), route); + + match body { + RequestBody::Json(body) => request = request.json(&body), + RequestBody::Multipart(body) => request = request.multipart(body), + RequestBody::None => (), + } + + let response = request.send().await?; + + log::debug!( + "OpenAI Response[{}] {} {}", + method.to_string(), + response.status().as_str(), + response.url() + ); + Ok(response) + } + + async fn request( + &self, + method: Method, + route: R, + body: B, + ) -> ApiResponseOrError + where + R: Into, + B: Into>, + S: Serialize, + T: DeserializeOwned, + { + let response = self.request_inner(method, route, body.into()).await?; + let api_response = if response.status().is_success() { + response.json::().await? 
+ } else { + let result = response.text().await?; + if let Ok(api_response) = serde_json::from_str::<OpenAiErrorWrapper>(&result) { + return Err(api_response.error); + } else { + return Err(OpenAiError::new(result, "unknown".to_string())); + } + }; + + Ok(api_response) + } + pub async fn get<R, T>(&self, route: R) -> ApiResponseOrError<T> + where + R: Into<String>, + T: DeserializeOwned, + { + self.request::<_, (), R, T>(Method::GET, route, None).await + } + + pub async fn post<R, S, T>(&self, route: R, body: S) -> ApiResponseOrError<T> + where + R: Into<String>, + S: Serialize, + T: DeserializeOwned, + { + self.request(Method::POST, route, Some(body)).await + } + + pub async fn post_multipart<R, T>(&self, route: R, form: Form) -> ApiResponseOrError<T> + where + R: Into<String>, + T: DeserializeOwned, + { + self.request::<_, (), R, T>(Method::POST, route, RequestBody::Multipart(form)) + .await + } + + pub async fn delete<R>(&self, route: R) -> ApiResponseOrError<Empty> + where + R: Into<String>, + { + self.request::<_, (), R, Empty>(Method::DELETE, route, None) + .await + } + + pub async fn list<R, T>(&self, route: R, after: Option<String>) -> ApiResponseOrError<Vec<T>> + where + R: Into<String>, + T: DeserializeOwned + std::fmt::Debug, + { + let mut route = if let Some(after) = after { + format!("{}?order=asc&after={after}", route.into()) + } else { + format!("{}?order=asc", route.into()) + }; + + let mut has_more = true; + let mut data = Vec::new(); + + while has_more { + let list: List<T> = self.get(&route).await?; + data.extend(list.data); + has_more = list.has_more; + route = format!( + "{route}?order=asc&after={}", + list.last_id.unwrap_or_default() + ); + } + + Ok(data) + } +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct List<T> { + pub first_id: Option<String>, + pub last_id: Option<String>, + pub data: Vec<T>, + pub has_more: bool, +} diff --git a/src/files.rs b/src/files.rs index 7c167fa..5b1c6b7 100644 --- a/src/files.rs +++ b/src/files.rs @@ -351,7 +351,7 @@ mod tests { assert_eq!(openapi_err.error_type, "io"); assert_eq!( openapi_err.message, - "No such file or directory (os 
error 2)" + Some("No such file or directory (os error 2)".to_string()) ) } diff --git a/src/lib.rs b/src/lib.rs index fd24555..b03db5c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,12 +1,15 @@ +use once_cell::sync::Lazy; use reqwest::multipart::Form; use reqwest::{header::AUTHORIZATION, Client, Method, RequestBuilder, Response}; use reqwest_eventsource::{CannotCloneRequestError, EventSource, RequestBuilderExt}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use std::env; use std::env::VarError; -use std::sync::{LazyLock, RwLock}; +use std::sync::RwLock; +pub mod assistants; pub mod chat; +pub mod client; pub mod completions; pub mod edits; pub mod embeddings; @@ -14,10 +17,10 @@ pub mod files; pub mod models; pub mod moderations; -pub static DEFAULT_BASE_URL: LazyLock<String> = - LazyLock::new(|| String::from("https://api.openai.com/v1/")); -static DEFAULT_CREDENTIALS: LazyLock<RwLock<Credentials>> = - LazyLock::new(|| RwLock::new(Credentials::from_env())); +pub static DEFAULT_BASE_URL: Lazy<String> = + Lazy::new(|| String::from("https://api.openai.com/v1/")); +static DEFAULT_CREDENTIALS: Lazy<RwLock<Credentials>> = + Lazy::new(|| RwLock::new(Credentials::from_env())); /// Holds the API key and base URL for an OpenAI-compatible API. 
#[derive(Debug, Clone, Eq, PartialEq)] @@ -61,7 +64,7 @@ impl Credentials { #[derive(Deserialize, Debug, Clone, Eq, PartialEq)] pub struct OpenAiError { - pub message: String, + pub message: Option<String>, #[serde(rename = "type")] pub error_type: String, pub param: Option<String>, @@ -71,7 +74,7 @@ pub struct OpenAiError { impl OpenAiError { fn new(message: String, error_type: String) -> OpenAiError { OpenAiError { - message, + message: Some(message), error_type, param: None, code: None, @@ -81,7 +84,12 @@ impl OpenAiError { impl std::fmt::Display for OpenAiError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str(&self.message) + f.write_str( + &self + .message + .as_ref() + .unwrap_or(&"empty error message".to_string()), + ) } } @@ -105,7 +113,7 @@ pub type ApiResponseOrError<T> = Result<T, OpenAiError>; impl From<reqwest::Error> for OpenAiError { fn from(value: reqwest::Error) -> Self { - OpenAiError::new(value.to_string(), "reqwest".to_string()) + OpenAiError::new(format!("{:?}", value), "reqwest".to_string()) } } @@ -150,6 +158,7 @@ where let mut request = client.request(method, format!("{}{route}", credentials.base_url)); request = builder(request); let response = request + .header("OpenAI-Beta", "assistants=v2") .header(AUTHORIZATION, format!("Bearer {}", credentials.api_key)) .send() .await?;