diff --git a/Cargo.toml b/Cargo.toml index 5d4964c..7844cb6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "openai-api-rs" -version = "0.1.1" +version = "0.1.2" edition = "2021" authors = ["Dongri Jin "] license = "MIT" diff --git a/README.md b/README.md index 09521b2..3d3a2b6 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ Cargo.toml ```toml [dependencies] -openai-api-rs = "0.1" +openai-api-rs = "0.1.2" ``` ## Example: @@ -12,6 +12,30 @@ openai-api-rs = "0.1" export OPENAI_API_KEY={YOUR_API} ``` +### Chat +```rust +use openai_api_rs::v1::api::Client; +use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest}; +use std::env; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string()); + let req = ChatCompletionRequest { + model: chat_completion::GPT3_5_TURBO.to_string(), + messages: vec![chat_completion::ChatCompletionMessage { + role: chat_completion::MessageRole::user, + content: String::from("NFTとは?"), + }], + }; + let result = client.chat_completion(req).await?; + println!("{:?}", result.choices[0].message.content); + + Ok(()) +} +``` + +### Completion ```rust use openai_api_rs::v1::completion::{self, CompletionRequest}; use openai_api_rs::v1::api::Client; diff --git a/examples/chat_completion.rs b/examples/chat_completion.rs new file mode 100644 index 0000000..e4f965b --- /dev/null +++ b/examples/chat_completion.rs @@ -0,0 +1,21 @@ +use openai_api_rs::v1::api::Client; +use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest}; +use std::env; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string()); + let req = ChatCompletionRequest { + model: chat_completion::GPT3_5_TURBO.to_string(), + messages: vec![chat_completion::ChatCompletionMessage { + role: chat_completion::MessageRole::user, + content: String::from("NFTとは?"), + }], + }; + let result = client.chat_completion(req).await?; + println!("{:?}", result.choices[0].message.content); + + Ok(()) +} + +// OPENAI_API_KEY=xxxx cargo run --package openai-api-rs --example chat_completion diff --git a/examples/completion.rs b/examples/completion.rs index 4c95bdc..ecb20af 100644 --- a/examples/completion.rs +++ b/examples/completion.rs @@ -1,5 +1,5 @@ -use openai_api_rs::v1::completion::{self, CompletionRequest}; use openai_api_rs::v1::api::Client; +use openai_api_rs::v1::completion::{self, CompletionRequest}; use std::env; #[tokio::main] @@ -22,11 +22,11 @@ async fn main() -> Result<(), Box> { best_of: None, logit_bias: None, user: None, - }; + }; let result = client.completion(req).await?; println!("{:}", result.choices[0].text); Ok(()) } -// cargo run --package openai-rs --example completion +// OPENAI_API_KEY=xxxx cargo run --package openai-api-rs --example completion diff --git a/src/v1/api.rs b/src/v1/api.rs index 1fb6b37..32a260e 100644 --- a/src/v1/api.rs +++ b/src/v1/api.rs @@ -1,32 +1,22 @@ - +use crate::v1::chat_completion::{ChatCompletionRequest, ChatCompletionResponse}; use crate::v1::completion::{CompletionRequest, CompletionResponse}; use crate::v1::edit::{EditRequest, EditResponse}; -use crate::v1::image::{ - ImageGenerationRequest, - ImageGenerationResponse, - ImageEditRequest, - ImageEditResponse, - ImageVariationRequest, - ImageVariationResponse, -}; use crate::v1::embedding::{EmbeddingRequest, EmbeddingResponse}; use crate::v1::file::{ - FileListResponse, - FileUploadRequest, + FileDeleteRequest, FileDeleteResponse, FileListResponse, FileRetrieveContentRequest, + FileRetrieveContentResponse, FileRetrieveRequest, FileRetrieveResponse, FileUploadRequest, FileUploadResponse, - FileDeleteRequest, - FileDeleteResponse, - FileRetrieveRequest, - FileRetrieveResponse, - FileRetrieveContentRequest, - FileRetrieveContentResponse, +}; +use crate::v1::image::{ + ImageEditRequest, ImageEditResponse, ImageGenerationRequest, ImageGenerationResponse, + ImageVariationRequest, ImageVariationResponse, }; use reqwest::Response; -const APU_URL_V1: &str = "https://api.openai.com/v1"; +const APU_URL_V1: &str = "https://api.openai.com/v1"; pub struct Client { - pub api_key: String, + pub api_key: String, } impl Client { @@ -34,25 +24,30 @@ impl Client { Self { api_key } } - pub async fn post(&self, path: &str, params: &T) -> Result> { + pub async fn post( + &self, + path: &str, + params: &T, + ) -> Result> { let client = reqwest::Client::new(); - let url = format!("{}{}", APU_URL_V1, path); + let url = format!("{APU_URL_V1}{path}"); let res = client .post(&url) .header(reqwest::header::CONTENT_TYPE, "application/json") - .header(reqwest::header::AUTHORIZATION, "Bearer ".to_owned() + &self.api_key) + .header( + reqwest::header::AUTHORIZATION, + "Bearer ".to_owned() + &self.api_key, + ) .json(¶ms) .send() .await; match res { Ok(res) => match res.status().is_success() { true => Ok(res), - false => { - Err(Box::new(std::io::Error::new( - std::io::ErrorKind::Other, - format!("{}: {}", res.status(), res.text().await.unwrap()) - ))) - }, + false => Err(Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + format!("{}: {}", res.status(), res.text().await.unwrap()), + ))), }, Err(e) => Err(Box::new(e)), } @@ -60,22 +55,23 @@ impl Client { pub async fn get(&self, path: &str) -> Result> { let client = reqwest::Client::new(); - let url = format!("{}{}", APU_URL_V1, path); + let url = format!("{APU_URL_V1}{path}"); let res = client .get(&url) .header(reqwest::header::CONTENT_TYPE, "application/json") - .header(reqwest::header::AUTHORIZATION, "Bearer ".to_owned() + &self.api_key) + .header( + reqwest::header::AUTHORIZATION, + "Bearer ".to_owned() + &self.api_key, + ) .send() .await; match res { Ok(res) => match res.status().is_success() { true => Ok(res), - false => { - Err(Box::new(std::io::Error::new( - std::io::ErrorKind::Other, - format!("{}: {}", res.status(), res.text().await.unwrap()) - ))) - }, + false => Err(Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + format!("{}: {}", res.status(), res.text().await.unwrap()), + ))), }, Err(e) => Err(Box::new(e)), } @@ -83,37 +79,39 @@ impl Client { pub async fn delete(&self, path: &str) -> Result> { let client = reqwest::Client::new(); - let url = format!("{}{}", APU_URL_V1, path); + let url = format!("{APU_URL_V1}{path}"); let res = client .delete(&url) .header(reqwest::header::CONTENT_TYPE, "application/json") - .header(reqwest::header::AUTHORIZATION, "Bearer ".to_owned() + &self.api_key) + .header( + reqwest::header::AUTHORIZATION, + "Bearer ".to_owned() + &self.api_key, + ) .send() .await; match res { Ok(res) => match res.status().is_success() { true => Ok(res), - false => { - Err(Box::new(std::io::Error::new( - std::io::ErrorKind::Other, - format!("{}: {}", res.status(), res.text().await.unwrap()) - ))) - }, + false => Err(Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + format!("{}: {}", res.status(), res.text().await.unwrap()), + ))), }, Err(e) => Err(Box::new(e)), } } - pub async fn completion(&self, req: CompletionRequest) -> Result> { + pub async fn completion( + &self, + req: CompletionRequest, + ) -> Result> { let res = self.post("/completions", &req).await; match res { Ok(res) => { let r = res.json::().await?; - return Ok(r); - }, - Err(e) => { - return Err(e); - }, + Ok(r) + } + Err(e) => Err(e), } } @@ -122,63 +120,65 @@ impl Client { match res { Ok(res) => { let r = res.json::().await?; - return Ok(r); - }, - Err(e) => { - return Err(e); - }, + Ok(r) + } + Err(e) => Err(e), } } - pub async fn image_generation(&self, req: ImageGenerationRequest) -> Result> { + pub async fn image_generation( + &self, + req: ImageGenerationRequest, + ) -> Result> { let res = self.post("/images/generations", &req).await; match res { Ok(res) => { let r = res.json::().await?; - return Ok(r); - }, - Err(e) => { - return Err(e); - }, + Ok(r) + } + Err(e) => Err(e), } } - pub async fn image_edit(&self, req: ImageEditRequest) -> Result> { + pub async fn image_edit( + &self, + req: ImageEditRequest, + ) -> Result> { let res = self.post("/images/edits", &req).await; match res { Ok(res) => { let r = res.json::().await?; - return Ok(r); - }, - Err(e) => { - return Err(e); - }, + Ok(r) + } + Err(e) => Err(e), } } - pub async fn image_variation(&self, req: ImageVariationRequest) -> Result> { + pub async fn image_variation( + &self, + req: ImageVariationRequest, + ) -> Result> { let res = self.post("/images/variations", &req).await; match res { Ok(res) => { let r = res.json::().await?; - return Ok(r); - }, - Err(e) => { - return Err(e); - }, + Ok(r) + } + Err(e) => Err(e), } } - pub async fn embedding(&self, req: EmbeddingRequest) -> Result> { + pub async fn embedding( + &self, + req: EmbeddingRequest, + ) -> Result> { let res = self.post("/embeddings", &req).await; match res { Ok(res) => { let r = res.json::().await?; - return Ok(r); - }, - Err(e) => { - return Err(e); - }, + Ok(r) + } + Err(e) => Err(e), } } @@ -187,64 +187,81 @@ impl Client { match res { Ok(res) => { let r = res.json::().await?; - return Ok(r); - }, - Err(e) => { - return Err(e); - }, + Ok(r) + } + Err(e) => Err(e), } } - - pub async fn file_upload(&self, req: FileUploadRequest) -> Result> { + + pub async fn file_upload( + &self, + req: FileUploadRequest, + ) -> Result> { let res = self.post("/files", &req).await; match res { Ok(res) => { let r = res.json::().await?; - return Ok(r); - }, - Err(e) => { - return Err(e); - }, + Ok(r) + } + Err(e) => Err(e), } } - - pub async fn file_delete(&self, req: FileDeleteRequest) -> Result> { + + pub async fn file_delete( + &self, + req: FileDeleteRequest, + ) -> Result> { let res = self.delete(&format!("{}/{}", "/files", req.file_id)).await; match res { Ok(res) => { let r = res.json::().await?; - return Ok(r); - }, - Err(e) => { - return Err(e); - }, + Ok(r) + } + Err(e) => Err(e), } } - pub async fn file_retrieve(&self, req: FileRetrieveRequest) -> Result> { + pub async fn file_retrieve( + &self, + req: FileRetrieveRequest, + ) -> Result> { let res = self.get(&format!("{}/{}", "/files", req.file_id)).await; match res { Ok(res) => { let r = res.json::().await?; - return Ok(r); - }, - Err(e) => { - return Err(e); - }, + Ok(r) + } + Err(e) => Err(e), } } - pub async fn file_retrieve_content(&self, req: FileRetrieveContentRequest) -> Result> { - let res = self.get(&format!("{}/{}/content", "/files", req.file_id)).await; + pub async fn file_retrieve_content( + &self, + req: FileRetrieveContentRequest, + ) -> Result> { + let res = self + .get(&format!("{}/{}/content", "/files", req.file_id)) + .await; match res { Ok(res) => { let r = res.json::().await?; - return Ok(r); - }, - Err(e) => { - return Err(e); - }, + Ok(r) + } + Err(e) => Err(e), } } + pub async fn chat_completion( + &self, + req: ChatCompletionRequest, + ) -> Result> { + let res = self.post("/chat/completions", &req).await; + match res { + Ok(res) => { + let r = res.json::().await?; + Ok(r) + } + Err(e) => Err(e), + } + } } diff --git a/src/v1/chat_completion.rs b/src/v1/chat_completion.rs new file mode 100644 index 0000000..d3cfa8f --- /dev/null +++ b/src/v1/chat_completion.rs @@ -0,0 +1,43 @@ +use serde::{Deserialize, Serialize}; + +use crate::v1::common; + +pub const GPT3_5_TURBO: &str = "gpt-3.5-turbo"; +pub const GPT3_5_TURBO_0301: &str = "gpt-3.5-turbo-0301"; + +#[derive(Debug, Serialize)] +pub struct ChatCompletionRequest { + pub model: String, + pub messages: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +#[allow(non_camel_case_types)] +pub enum MessageRole { + user, + system, + assistant, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ChatCompletionMessage { + pub role: MessageRole, + pub content: String, +} + +#[derive(Debug, Deserialize)] +pub struct ChatCompletionChoice { + pub index: i64, + pub message: ChatCompletionMessage, + pub finish_reason: String, +} + +#[derive(Debug, Deserialize)] +pub struct ChatCompletionResponse { + pub id: String, + pub object: String, + pub created: i64, + pub model: String, + pub choices: Vec, + pub usage: common::Usage, +} diff --git a/src/v1/common.rs b/src/v1/common.rs index 7a87acc..05f500b 100644 --- a/src/v1/common.rs +++ b/src/v1/common.rs @@ -1,10 +1,8 @@ - -use serde::{Deserialize}; - +use serde::Deserialize; #[derive(Debug, Deserialize)] pub struct Usage { - pub prompt_tokens: i32, - pub completion_tokens: i32, - pub total_tokens: i32, + pub prompt_tokens: i32, + pub completion_tokens: i32, + pub total_tokens: i32, } diff --git a/src/v1/completion.rs b/src/v1/completion.rs index 592cdd1..4e22d9f 100644 --- a/src/v1/completion.rs +++ b/src/v1/completion.rs @@ -1,79 +1,79 @@ -use serde::{Serialize, Deserialize}; -use std::option::Option; +use serde::{Deserialize, Serialize}; use std::collections::HashMap; +use std::option::Option; use crate::v1::common; -pub const GPT3_TEXT_DAVINCI_003: &str = "text-davinci-003"; -pub const GPT3_TEXT_DAVINCI_002: &str = "text-davinci-002"; -pub const GPT3_TEXT_CURIE_001: &str = "text-curie-001"; -pub const GPT3_TEXT_BABBAGE_001: &str = "text-babbage-001"; -pub const GPT3_TEXT_ADA_001: &str = "text-ada-001"; -pub const GPT3_TEXT_DAVINCI_001: &str = "text-davinci-001"; -pub const GPT3_DAVINCI_INSTRUCT_BETA: &str = "davinci-instruct-beta"; -pub const GPT3_DAVINCI: &str = "davinci"; -pub const GPT3_CURIE_INSTRUCT_BETA: &str = "curie-instruct-beta"; -pub const GPT3_CURIE: &str = "curie"; -pub const GPT3_ADA: &str = "ada"; -pub const GPT3_BABBAGE: &str = "babbage"; +pub const GPT3_TEXT_DAVINCI_003: &str = "text-davinci-003"; +pub const GPT3_TEXT_DAVINCI_002: &str = "text-davinci-002"; +pub const GPT3_TEXT_CURIE_001: &str = "text-curie-001"; +pub const GPT3_TEXT_BABBAGE_001: &str = "text-babbage-001"; +pub const GPT3_TEXT_ADA_001: &str = "text-ada-001"; +pub const GPT3_TEXT_DAVINCI_001: &str = "text-davinci-001"; +pub const GPT3_DAVINCI_INSTRUCT_BETA: &str = "davinci-instruct-beta"; +pub const GPT3_DAVINCI: &str = "davinci"; +pub const GPT3_CURIE_INSTRUCT_BETA: &str = "curie-instruct-beta"; +pub const GPT3_CURIE: &str = "curie"; +pub const GPT3_ADA: &str = "ada"; +pub const GPT3_BABBAGE: &str = "babbage"; #[derive(Debug, Serialize)] pub struct CompletionRequest { - pub model: String, + pub model: String, #[serde(skip_serializing_if = "Option::is_none")] - pub prompt: Option, + pub prompt: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub suffix: Option, + pub suffix: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub max_tokens: Option, + pub max_tokens: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub temperature: Option, + pub temperature: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub top_p: Option, + pub top_p: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub n: Option, + pub n: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub stream: Option, + pub stream: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub logprobs: Option, + pub logprobs: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub echo: Option, + pub echo: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub stop: Option>, + pub stop: Option>, #[serde(skip_serializing_if = "Option::is_none")] - pub presence_penalty: Option, + pub presence_penalty: Option, #[serde(skip_serializing_if = "Option::is_none")] pub frequency_penalty: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub best_of: Option, + pub best_of: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub logit_bias: Option>, + pub logit_bias: Option>, #[serde(skip_serializing_if = "Option::is_none")] - pub user: Option, + pub user: Option, } #[derive(Debug, Deserialize)] pub struct CompletionChoice { - pub text: String, - pub index: i64, + pub text: String, + pub index: i64, pub finish_reason: String, - pub logprobs: Option, + pub logprobs: Option, } #[derive(Debug, Deserialize)] pub struct LogprobResult { - pub tokens: Vec, - pub token_logprobs: Vec, - pub top_logprobs: Vec>, - pub text_offset: Vec, + pub tokens: Vec, + pub token_logprobs: Vec, + pub top_logprobs: Vec>, + pub text_offset: Vec, } #[derive(Debug, Deserialize)] pub struct CompletionResponse { - pub id: String, - pub object: String, + pub id: String, + pub object: String, pub created: i64, - pub model: String, + pub model: String, pub choices: Vec, - pub usage: common::Usage, + pub usage: common::Usage, } diff --git a/src/v1/edit.rs b/src/v1/edit.rs index c562c52..82870cf 100644 --- a/src/v1/edit.rs +++ b/src/v1/edit.rs @@ -1,4 +1,4 @@ -use serde::{Serialize, Deserialize}; +use serde::{Deserialize, Serialize}; use std::option::Option; use crate::v1::common; @@ -18,15 +18,15 @@ pub struct EditRequest { } #[derive(Debug, Deserialize)] -pub struct EditChoice{ - pub text: String, - pub index: i32, +pub struct EditChoice { + pub text: String, + pub index: i32, } #[derive(Debug, Deserialize)] pub struct EditResponse { - pub object: String, + pub object: String, pub created: i64, - pub usage: common::Usage, + pub usage: common::Usage, pub choices: Vec, } diff --git a/src/v1/embedding.rs b/src/v1/embedding.rs index c49b06f..27d65f2 100644 --- a/src/v1/embedding.rs +++ b/src/v1/embedding.rs @@ -1,14 +1,14 @@ -use serde::{Serialize, Deserialize}; +use serde::{Deserialize, Serialize}; use std::option::Option; use crate::v1::common; #[derive(Debug, Deserialize)] -pub struct EmbeddingData{ - pub object: String, - pub embedding: Vec, - pub index: i32, - pub usage: common::Usage, +pub struct EmbeddingData { + pub object: String, + pub embedding: Vec, + pub index: i32, + pub usage: common::Usage, } #[derive(Debug, Serialize)] @@ -19,9 +19,8 @@ pub struct EmbeddingRequest { pub user: Option, } - #[derive(Debug, Deserialize)] pub struct EmbeddingResponse { pub object: String, - pub data: Vec, + pub data: Vec, } diff --git a/src/v1/file.rs b/src/v1/file.rs index 9437700..5eff6cc 100644 --- a/src/v1/file.rs +++ b/src/v1/file.rs @@ -1,8 +1,8 @@ -use serde::{Serialize, Deserialize}; +use serde::{Deserialize, Serialize}; #[derive(Debug, Deserialize)] -pub struct FileData{ - pub id: String, +pub struct FileData { + pub id: String, pub oejct: String, pub bytes: i32, pub created_at: i64, @@ -13,10 +13,9 @@ pub struct FileData{ #[derive(Debug, Deserialize)] pub struct FileListResponse { pub object: String, - pub data: Vec, + pub data: Vec, } - #[derive(Debug, Serialize)] pub struct FileUploadRequest { pub file: String, @@ -33,7 +32,6 @@ pub struct FileUploadResponse { pub purpose: String, } - #[derive(Debug, Serialize)] pub struct FileDeleteRequest { pub file_id: String, @@ -61,7 +59,6 @@ pub struct FileRetrieveResponse { pub purpose: String, } - #[derive(Debug, Serialize)] pub struct FileRetrieveContentRequest { pub file_id: String, diff --git a/src/v1/image.rs b/src/v1/image.rs index ac7cc4e..9e03a43 100644 --- a/src/v1/image.rs +++ b/src/v1/image.rs @@ -1,9 +1,9 @@ -use serde::{Serialize, Deserialize}; +use serde::{Deserialize, Serialize}; use std::option::Option; #[derive(Debug, Deserialize)] -pub struct ImageData{ - pub url: String, +pub struct ImageData { + pub url: String, } #[derive(Debug, Serialize)] @@ -19,11 +19,10 @@ pub struct ImageGenerationRequest { pub user: Option, } - #[derive(Debug, Deserialize)] pub struct ImageGenerationResponse { pub created: i64, - pub data: Vec, + pub data: Vec, } #[derive(Debug, Serialize)] @@ -45,7 +44,7 @@ pub struct ImageEditRequest { #[derive(Debug, Deserialize)] pub struct ImageEditResponse { pub created: i64, - pub data: Vec, + pub data: Vec, } #[derive(Debug, Serialize)] @@ -64,5 +63,5 @@ pub struct ImageVariationRequest { #[derive(Debug, Deserialize)] pub struct ImageVariationResponse { pub created: i64, - pub data: Vec, + pub data: Vec, } diff --git a/src/v1/mod.rs b/src/v1/mod.rs index 0c0728d..8c84b8a 100644 --- a/src/v1/mod.rs +++ b/src/v1/mod.rs @@ -1,9 +1,10 @@ pub mod common; +pub mod chat_completion; pub mod completion; pub mod edit; -pub mod image; pub mod embedding; pub mod file; +pub mod image; pub mod api;