diff --git a/NAMESPACE b/NAMESPACE index 4b6c332..08ccb35 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,6 +1,13 @@ # Generated by roxygen2: do not edit by hand -S3method(print,chat_tibble) +S3method(as.data.frame,chat_response) +S3method(as_messages,character) +S3method(as_messages,list) +S3method(as_msg,character) +S3method(as_tibble,chat_response) +S3method(print,chat) +S3method(print,chat_tbl) +export(as_messages) export(chat) export(models) export(stream) @@ -10,6 +17,9 @@ import(rlang) import(stringr) import(tibble) importFrom(jsonlite,fromJSON) +importFrom(purrr,list_flatten) +importFrom(purrr,map2) importFrom(purrr,map_chr) importFrom(purrr,map_dfr) importFrom(purrr,pluck) +importFrom(tibble,as_tibble) diff --git a/R/chat.R b/R/chat.R index 7cbde0e..360e85e 100644 --- a/R/chat.R +++ b/R/chat.R @@ -1,10 +1,10 @@ #' Chat with the Mistral api #' -#' @param text some text +#' @param messages Messages #' @param model which model to use. See [models()] for more information about which models are available #' @param dry_run if TRUE the request is not performed -#' @param ... ignored -#' @inheritParams httr2::req_perform +#' @inheritParams rlang::args_dots_empty +#' @inheritParams rlang::args_error_context #' #' @return A tibble with columns `role` and `content` with class `chat_tibble` or a request #' if this is a `dry_run` @@ -13,49 +13,67 @@ #' chat("Top 5 R packages", dry_run = TRUE) #' #' @export -chat <- function(text = "What are the top 5 R packages ?", model = "mistral-tiny", dry_run = FALSE, ..., error_call = current_env()) { - req <- req_chat(text, model, error_call = error_call, dry_run = dry_run) +chat <- function(messages, model = "mistral-tiny", dry_run = FALSE, ..., error_call = current_env()) { + check_dots_empty() + + req <- req_chat(messages, model = model, error_call = error_call, dry_run = dry_run) if (is_true(dry_run)) { return(req) } resp <- req_mistral_perform(req, error_call = error_call) - resp_chat(resp, error_call = error_call) + class(resp) <- c("chat", class(resp)) + resp +} + +#' @export +print.chat <- function(x, ...) { + writeLines(resp_body_json(x)$choices[[1]]$message$content) + invisible(x) } -req_chat <- function(text = "What are the top 5 R packages ?", model = "mistral-tiny", stream = FALSE, dry_run = FALSE, error_call = caller_env()) { +req_chat <- function(messages, model = "mistral-tiny", stream = FALSE, dry_run = FALSE, ..., error_call = caller_env()) { + check_dots_empty() + if (!is_true(dry_run)) { check_model(model, error_call = error_call) } + + messages <- as_messages(messages) + request(mistral_base_url) |> req_url_path_append("v1", "chat", "completions") |> authenticate() |> req_body_json( list( model = model, - messages = list( - list( - role = "user", - content = text - ) - ), + messages = messages, stream = is_true(stream) ) ) } -resp_chat <- function(response, error_call = current_env()) { - data <- resp_body_json(response) +#' @export +as.data.frame.chat_response <- function(x, ...) { + req_messages <- x$request$body$data$messages + df_req <- map_dfr(req_messages, as.data.frame) + + df_resp <- as.data.frame( + resp_body_json(x)$choices[[1]]$message[c("role", "content")] + ) - tib <- map_dfr(data$choices, \(choice) { - as_tibble(choice$message) - }) + rbind(df_req, df_resp) +} - class(tib) <- c("chat_tibble", class(tib)) +#' @importFrom tibble as_tibble +#' @export +as_tibble.chat_response <- function(x, ...) { + tib <- as_tibble(as.data.frame(x, ...)) + class(tib) <- c("chat_tbl", class(x)) tib } #' @export -print.chat_tibble <- function(x, ...) { +print.chat_tbl <- function(x, ...) { n <- nrow(x) for (i in seq_len(n)) { diff --git a/R/messages.R b/R/messages.R new file mode 100644 index 0000000..d5501c1 --- /dev/null +++ b/R/messages.R @@ -0,0 +1,70 @@ +#' Convert object into a messages list +#' +#' @param messages object to convert to messages +#' @param ... ignored +#' @inheritParams rlang::args_error_context +#' +#' @examples +#' as_messages("hello") +#' as_messages(list("hello")) +#' as_messages(list(assistant = "hello", user = "hello")) +#' +#' @export +as_messages <- function(messages, ..., error_call = current_env()) { + UseMethod("as_messages") +} + +#' @export +as_messages.character <- function(messages, ..., error_call = current_env()) { + check_dots_empty(call = error_call) + check_scalar_string(messages, error_call = error_call) + check_unnamed_string(messages, error_call = error_call) + + list( + list(role = "user", content = messages) + ) +} + +#' @export +as_messages.list <- function(messages, ..., error_call = caller_env()) { + check_dots_empty() + + out <- list_flatten( + map2(messages, names2(messages), as_msg, error_call = error_call) + ) + names(out) <- NULL + out +} + +as_msg <- function(x, name, error_call = caller_env()) { + UseMethod("as_msg") +} + +#' @export +as_msg.character <- function(x, name, error_call = caller_env()) { + check_scalar_string(x, error_call = error_call) + role <- check_role(name, error_call = error_call) + + list( + list(role = role, content = x) + ) +} + +check_role <- function(name = "", error_call = caller_env()) { + if (identical(name, "")) { + name <- "user" + } + name +} + +check_scalar_string <- function(x, error_call = caller_env()) { + if (length(x) != 1L) { + cli_abort("{.arg x} must be a single string, not length {.code {length(x)}}. ", call = error_call) + } +} + +check_unnamed_string <- function(x, error_call = caller_env()) { + if (!is.null(names(x))) { + cli_abort("{.arg x} must be unnamed", call = error_call) + } +} diff --git a/R/stream.R b/R/stream.R index aaf793b..a5ed84b 100644 --- a/R/stream.R +++ b/R/stream.R @@ -3,10 +3,11 @@ #' @inheritParams chat #' #' @export -stream <- function(text, model = "mistral-tiny", dry_run = FALSE, ..., error_call = current_env()) { +stream <- function(messages, model = "mistral-tiny", dry_run = FALSE, ..., error_call = current_env()) { check_model(model, error_call = error_call) - req <- req_chat(text, model, stream = TRUE, error_call = error_call, dry_run = dry_run) + messages <- as_messages(messages) + req <- req_chat(messages, model, stream = TRUE, error_call = error_call, dry_run = dry_run) if (is_true(dry_run)) { return(req) } diff --git a/R/zzz.R b/R/zzz.R index 66e67f9..06c6070 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -3,7 +3,7 @@ #' @import httr2 #' @import tibble #' @import stringr -#' @importFrom purrr map_dfr map_chr pluck +#' @importFrom purrr map_dfr map_chr pluck map2 list_flatten NULL mistral_base_url <- "https://api.mistral.ai" diff --git a/man/as_messages.Rd b/man/as_messages.Rd new file mode 100644 index 0000000..7d01a9e --- /dev/null +++ b/man/as_messages.Rd @@ -0,0 +1,27 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/messages.R +\name{as_messages} +\alias{as_messages} +\title{Convert object into a messages list} +\usage{ +as_messages(messages, ..., error_call = current_env()) +} +\arguments{ +\item{messages}{object to convert to messages} + +\item{...}{ignored} + +\item{error_call}{The execution environment of a currently +running function, e.g. \code{caller_env()}. The function will be +mentioned in error messages as the source of the error. See the +\code{call} argument of \code{\link[rlang:abort]{abort()}} for more information.} +} +\description{ +Convert object into a messages list +} +\examples{ +as_messages("hello") +as_messages(list("hello")) +as_messages(list(assistant = "hello", user = "hello")) + +} diff --git a/man/chat.Rd b/man/chat.Rd index 67f8483..f34a1d8 100644 --- a/man/chat.Rd +++ b/man/chat.Rd @@ -5,7 +5,7 @@ \title{Chat with the Mistral api} \usage{ chat( - text = "What are the top 5 R packages ?", + messages, model = "mistral-tiny", dry_run = FALSE, ..., @@ -13,13 +13,13 @@ chat( ) } \arguments{ -\item{text}{some text} +\item{messages}{Messages} \item{model}{which model to use. See \code{\link[=models]{models()}} for more information about which models are available} \item{dry_run}{if TRUE the request is not performed} -\item{...}{ignored} +\item{...}{These dots are for future extensions and must be empty.} \item{error_call}{The execution environment of a currently running function, e.g. \code{caller_env()}. The function will be diff --git a/man/stream.Rd b/man/stream.Rd index 7e7ae64..fee756c 100644 --- a/man/stream.Rd +++ b/man/stream.Rd @@ -5,7 +5,7 @@ \title{stream} \usage{ stream( - text, + messages, model = "mistral-tiny", dry_run = FALSE, ..., @@ -13,13 +13,13 @@ stream( ) } \arguments{ -\item{text}{some text} +\item{messages}{Messages} \item{model}{which model to use. See \code{\link[=models]{models()}} for more information about which models are available} \item{dry_run}{if TRUE the request is not performed} -\item{...}{ignored} +\item{...}{These dots are for future extensions and must be empty.} \item{error_call}{The execution environment of a currently running function, e.g. \code{caller_env()}. The function will be