Skip to content

Commit

Permalink
expand documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerjthomas9 committed Feb 24, 2024
1 parent fdac41a commit 9e96ae3
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 17 deletions.
2 changes: 1 addition & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ makedocs(;
canonical="https://tylerjthomas9.github.io/GoogleGenAI.jl",
assets=String[],
),
pages=["Home" => "index.md"],
pages=["Home" => "index.md", "API" => "api.md"],
warnonly=true,
)

Expand Down
12 changes: 12 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
```@meta
CurrentModule = GoogleGenAI
```

# API
```
GoogleProvider
generate_content
count_tokens
embed_content
list_models
```
93 changes: 77 additions & 16 deletions src/GoogleGenAI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,20 @@ using HTTP

abstract type AbstractGoogleProvider end

"""
Base.@kwdef struct GoogleProvider <: AbstractGoogleProvider
api_key::String = ""
base_url::String = "https://generativelanguage.googleapis.com"
api_version::String = "v1beta"
end
A configuration object used to set up and authenticate requests to the Google Generative Language API.
# Fields
- `api_key::String`: Your Google API key.
- `base_url::String`: The base URL for the Google Generative Language API. The default is set to `"https://generativelanguage.googleapis.com"`.
- `api_version::String`: The version of the API you wish to access. The default is set to `"v1beta"`.
"""
Base.@kwdef struct GoogleProvider <: AbstractGoogleProvider
api_key::String = ""
base_url::String = "https://generativelanguage.googleapis.com"
Expand All @@ -20,6 +34,7 @@ end

struct GoogleEmbeddingResponse
values::Vector{Float64}
response_status::Int
end

#TODO: Add support for exception
Expand Down Expand Up @@ -65,25 +80,29 @@ function _parse_response(response::HTTP.Messages.Response)
end

"""
generate_content(provider::GoogleProvider, model_name::String, input::String; kwargs...)
generate_content(api_key::String, model_name::String, input::String; kwargs...)
generate_content(provider::GooglePAbstractGoogleProviderrovider, model_name::String, input::String; kwargs...) -> GoogleTextResponse
generate_content(api_key::String, model_name::String, input::String; kwargs...) -> GoogleTextResponse
Generate text using the specified model.
# Arguments
- `provider::GoogleProvider`: The provider to use for the request.
- `model_name::String`: The model to use for the request.
- `input::String`: The input prompt to use for the request.
- `provider::AbstractGoogleProvider`: The provider instance containing API key and base URL information.
- `api_key::String`: Your Google API key as a string.
- `model_name::String`: The name of the model to use for generating content.
- `input::String`: The input prompt based on which the text is generated.
# Keyword Arguments
- `temperature::Float64`: The temperature for randomness in generation.
- `candidate_count::Int`: The number of candidates to consider. (Only one can be specified right now)
- `max_output_tokens::Int`: The maximum number of output tokens.
- `stop_sequences::Vector{String}`: Stop sequences to halt text generation.
- `safety_settings::Vector{Dict}`: Safety settings for generated text.
- `temperature::Float64` (optional): Controls the randomness in the generation process. Higher values result in more random outputs. Typically ranges between 0 and 1.
- `candidate_count::Int` (optional): The number of generation candidates to consider. Currently, only one candidate can be specified.
- `max_output_tokens::Int` (optional): The maximum number of tokens that the generated content should contain.
- `stop_sequences::Vector{String}` (optional): A list of sequences where the generation should stop. Useful for defining natural endpoints in generated content.
- `safety_settings::Vector{Dict}` (optional): Settings to control the safety aspects of the generated content, such as filtering out unsafe or inappropriate content.
# Returns
- `GoogleTextResponse`
"""
function generate_content(
provider::GoogleProvider, model_name::String, input::String; kwargs...
provider::AbstractGoogleProvider, model_name::String, input::String; kwargs...
)
endpoint = "models/$model_name:generateContent"

Expand All @@ -99,7 +118,6 @@ function generate_content(
else
safety_settings = nothing
end
println([Dict("parts" => [Dict("text" => input)])])
body = Dict(
"contents" => [Dict("parts" => [Dict("text" => input)])],
"generationConfig" => generation_config,
Expand All @@ -113,7 +131,22 @@ function generate_content(api_key::String, model_name::String, input::String; kw
return generate_content(GoogleProvider(; api_key), model_name, input; kwargs...)
end

function count_tokens(provider::GoogleProvider, model_name::String, input::String)
"""
count_tokens(provider::AbstractGoogleProvider, model_name::String, input::String) -> Int
count_tokens(api_key::String, model_name::String, input::String) -> Int
Calculate the number of tokens generated by the specified model for a given input string.
# Arguments
- `provider::AbstractGoogleProvider`: The provider instance containing API key and base URL information.
- `api_key::String`: Your Google API key as a string.
- `model_name::String`: The name of the model to use for generating content.
- `input::String`: The input prompt based on which the text is generated.
# Returns
- `Int`: The total number of tokens that the given input string would be broken into by the specified model's tokenizer.
"""
function count_tokens(provider::AbstractGoogleProvider, model_name::String, input::String)
endpoint = "models/$model_name:countTokens"
body = Dict("contents" => [Dict("parts" => [Dict("text" => input)])])
response = _request(provider, endpoint, :POST, body)
Expand All @@ -125,7 +158,22 @@ function count_tokens(api_key::String, model_name::String, input::String)
end

#TODO: Do we want an embeddings struct, or just the array of embeddings?
function embed_content(provider::GoogleProvider, model_name::String, input::String)
"""
embed_content(provider::AbstractGoogleProvider, model_name::String, input::String) -> GoogleEmbeddingResponse
embed_content(api_key::String, model_name::String, input::String) -> GoogleEmbeddingResponse
Generate an embedding for the given input text using the specified model.
# Arguments
- `provider::AbstractGoogleProvider`: The provider instance containing API key and base URL information.
- `api_key::String`: Your Google API key as a string.
- `model_name::String`: The name of the model to use for generating content.
- `input::String`: The input prompt based on which the text is generated.
# Returns
- `GoogleEmbeddingResponse`
"""
function embed_content(provider::AbstractGoogleProvider, model_name::String, input::String)
endpoint = "models/$model_name:embedContent"
body = Dict(
"model" => "models/$model_name",
Expand All @@ -135,13 +183,26 @@ function embed_content(provider::GoogleProvider, model_name::String, input::Stri
embedding_values = get(
get(JSON3.read(response.body), "embedding", Dict()), "values", Vector{Float64}()
)
return GoogleEmbeddingResponse(embedding_values)
return GoogleEmbeddingResponse(embedding_values, response.status)
end
function embed_content(api_key::String, model_name::String, input::String)
return embed_content(GoogleProvider(; api_key), model_name, input)
end

function list_models(provider::GoogleProvider)
"""
list_models(provider::AbstractGoogleProvider) -> Vector{Dict}
list_models(api_key::String) -> Vector{Dict}
Retrieve a list of available models along with their details from the Google AI API.
# Arguments
- `provider::AbstractGoogleProvider`: The provider instance containing API key and base URL information.
- `api_key::String`: Your Google API key as a string.
# Returns
- `Vector{Dict}`: A list of dictionaries, each containing details about an available model.
"""
function list_models(provider::AbstractGoogleProvider)
endpoint = "models"
response = _request(provider, endpoint, :GET, Dict())
parsed_response = JSON3.read(response.body)
Expand Down

0 comments on commit 9e96ae3

Please sign in to comment.