From a09f67a5c47088c94fbaab21c9a11f674efb5c85 Mon Sep 17 00:00:00 2001 From: Romain Francois Date: Tue, 12 Mar 2024 11:20:13 +0100 Subject: [PATCH 1/4] rework chat to use ... --- NAMESPACE | 7 ++++++- R/chat.R | 46 ++++++++++++++++++++++++++++------------------ R/messages.R | 24 ++++++++++++++++++++++++ man/chat.Rd | 14 ++++---------- man/stream.Rd | 2 +- 5 files changed, 63 insertions(+), 30 deletions(-) create mode 100644 R/messages.R diff --git a/NAMESPACE b/NAMESPACE index 4b6c332..8f9cd00 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,6 +1,10 @@ # Generated by roxygen2: do not edit by hand -S3method(print,chat_tibble) +S3method(as.data.frame,chat_response) +S3method(as_message,character) +S3method(as_tibble,chat_response) +S3method(print,chat) +S3method(print,chat_tbl) export(chat) export(models) export(stream) @@ -13,3 +17,4 @@ importFrom(jsonlite,fromJSON) importFrom(purrr,map_chr) importFrom(purrr,map_dfr) importFrom(purrr,pluck) +importFrom(tibble,as_tibble) diff --git a/R/chat.R b/R/chat.R index 7cbde0e..190ec1a 100644 --- a/R/chat.R +++ b/R/chat.R @@ -1,9 +1,9 @@ #' Chat with the Mistral api #' +#' @param ... messages #' @param text some text #' @param model which model to use. See [models()] for more information about which models are available #' @param dry_run if TRUE the request is not performed -#' @param ... ignored #' @inheritParams httr2::req_perform #' #' @return A tibble with columns `role` and `content` with class `chat_tibble` or a request @@ -13,49 +13,59 @@ #' chat("Top 5 R packages", dry_run = TRUE) #' #' @export -chat <- function(text = "What are the top 5 R packages ?", model = "mistral-tiny", dry_run = FALSE, ..., error_call = current_env()) { - req <- req_chat(text, model, error_call = error_call, dry_run = dry_run) +chat <- function(..., model = "mistral-tiny", dry_run = FALSE, error_call = current_env()) { + req <- req_chat(..., model = model, error_call = error_call, dry_run = dry_run) if (is_true(dry_run)) { return(req) } resp <- req_mistral_perform(req, error_call = error_call) - resp_chat(resp, error_call = error_call) + class(resp) <- c("chat", class(resp)) + resp } -req_chat <- function(text = "What are the top 5 R packages ?", model = "mistral-tiny", stream = FALSE, dry_run = FALSE, error_call = caller_env()) { +#' @export +print.chat <- function(x, ...) { + writeLines(resp_body_json(resp)$choices[[1]]$message$content) + invisible(x) +} + +req_chat <- function(..., model = "mistral-tiny", stream = FALSE, dry_run = FALSE, error_call = caller_env()) { if (!is_true(dry_run)) { check_model(model, error_call = error_call) } + + messages <- as_messages(...) + request(mistral_base_url) |> req_url_path_append("v1", "chat", "completions") |> authenticate() |> req_body_json( list( model = model, - messages = list( - list( - role = "user", - content = text - ) - ), + messages = messages, stream = is_true(stream) ) ) } -resp_chat <- function(response, error_call = current_env()) { - data <- resp_body_json(response) +#' @export +as.data.frame.chat_response <- function(x, ...) { + df_req <- map_dfr(resp$request$body$data$messages, as.data.frame) + df_resp <- as.data.frame(resp_body_json(resp)$choices[[1]]$message[c("role", "content")]) - tib <- map_dfr(data$choices, \(choice) { - as_tibble(choice$message) - }) + rbind(df_req, df_resp) +} - class(tib) <- c("chat_tibble", class(tib)) +#' @importFrom tibble as_tibble +#' @export +as_tibble.chat_response <- function(x, ...) { + tib <- as_tibble(as.data.frame(x, ...)) + class(tib) <- c("chat_tbl", class(x)) tib } #' @export -print.chat_tibble <- function(x, ...) { +print.chat_tbl <- function(x, ...) { n <- nrow(x) for (i in seq_len(n)) { diff --git a/R/messages.R b/R/messages.R new file mode 100644 index 0000000..50ac6b8 --- /dev/null +++ b/R/messages.R @@ -0,0 +1,24 @@ +as_message <- function(x, name) { + UseMethod("as_message") +} + +#' @export +as_message.character <- function(x, name) { + if (identical(name, "")) { + name <- "user" + } + list(role = name, content = x) +} + +as_messages <- function(...) { + x <- dots_list(..., .named = FALSE) + names <- names(x) + + messages <- list() + for (i in seq_len(length(x))) { + messages <- append(messages, as_message(x[[i]], name = names[i])) + } + + list(messages) +} + diff --git a/man/chat.Rd b/man/chat.Rd index 67f8483..e5a745a 100644 --- a/man/chat.Rd +++ b/man/chat.Rd @@ -4,27 +4,21 @@ \alias{chat} \title{Chat with the Mistral api} \usage{ -chat( - text = "What are the top 5 R packages ?", - model = "mistral-tiny", - dry_run = FALSE, - ..., - error_call = current_env() -) +chat(..., model = "mistral-tiny", dry_run = FALSE, error_call = current_env()) } \arguments{ -\item{text}{some text} +\item{...}{messages} \item{model}{which model to use. See \code{\link[=models]{models()}} for more information about which models are available} \item{dry_run}{if TRUE the request is not performed} -\item{...}{ignored} - \item{error_call}{The execution environment of a currently running function, e.g. \code{caller_env()}. The function will be mentioned in error messages as the source of the error. See the \code{call} argument of \code{\link[rlang:abort]{abort()}} for more information.} + +\item{text}{some text} } \value{ A tibble with columns \code{role} and \code{content} with class \code{chat_tibble} or a request diff --git a/man/stream.Rd b/man/stream.Rd index 7e7ae64..335ab4f 100644 --- a/man/stream.Rd +++ b/man/stream.Rd @@ -19,7 +19,7 @@ stream( \item{dry_run}{if TRUE the request is not performed} -\item{...}{ignored} +\item{...}{messages} \item{error_call}{The execution environment of a currently running function, e.g. \code{caller_env()}. The function will be From d0215581f50f6c622e4069566f5d407abdc7bf22 Mon Sep 17 00:00:00 2001 From: Romain Francois Date: Sat, 30 Mar 2024 11:26:59 +0100 Subject: [PATCH 2/4] chat(messages = ) + as_messages() --- NAMESPACE | 7 +++++- R/chat.R | 12 +++++------ R/messages.R | 59 +++++++++++++++++++++++++++++++++++++++------------ R/zzz.R | 2 +- man/chat.Rd | 9 ++++++-- man/stream.Rd | 2 -- 6 files changed, 66 insertions(+), 25 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 8f9cd00..08ccb35 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,10 +1,13 @@ # Generated by roxygen2: do not edit by hand S3method(as.data.frame,chat_response) -S3method(as_message,character) +S3method(as_messages,character) +S3method(as_messages,list) +S3method(as_msg,character) S3method(as_tibble,chat_response) S3method(print,chat) S3method(print,chat_tbl) +export(as_messages) export(chat) export(models) export(stream) @@ -14,6 +17,8 @@ import(rlang) import(stringr) import(tibble) importFrom(jsonlite,fromJSON) +importFrom(purrr,list_flatten) +importFrom(purrr,map2) importFrom(purrr,map_chr) importFrom(purrr,map_dfr) importFrom(purrr,pluck) diff --git a/R/chat.R b/R/chat.R index 190ec1a..028cfed 100644 --- a/R/chat.R +++ b/R/chat.R @@ -1,6 +1,6 @@ #' Chat with the Mistral api #' -#' @param ... messages +#' @param messages Messages #' @param text some text #' @param model which model to use. See [models()] for more information about which models are available #' @param dry_run if TRUE the request is not performed @@ -13,8 +13,8 @@ #' chat("Top 5 R packages", dry_run = TRUE) #' #' @export -chat <- function(..., model = "mistral-tiny", dry_run = FALSE, error_call = current_env()) { - req <- req_chat(..., model = model, error_call = error_call, dry_run = dry_run) +chat <- function(messages, model = "mistral-tiny", dry_run = FALSE, error_call = current_env()) { + req <- req_chat(messages, model = model, error_call = error_call, dry_run = dry_run) if (is_true(dry_run)) { return(req) } @@ -25,16 +25,16 @@ chat <- function(..., model = "mistral-tiny", dry_run = FALSE, error_call = curr #' @export print.chat <- function(x, ...) { - writeLines(resp_body_json(resp)$choices[[1]]$message$content) + writeLines(resp_body_json(x)$choices[[1]]$message$content) invisible(x) } -req_chat <- function(..., model = "mistral-tiny", stream = FALSE, dry_run = FALSE, error_call = caller_env()) { +req_chat <- function(messages, model = "mistral-tiny", stream = FALSE, dry_run = FALSE, error_call = caller_env()) { if (!is_true(dry_run)) { check_model(model, error_call = error_call) } - messages <- as_messages(...) + messages <- as_messages(messages) request(mistral_base_url) |> req_url_path_append("v1", "chat", "completions") |> diff --git a/R/messages.R b/R/messages.R index 50ac6b8..92660f0 100644 --- a/R/messages.R +++ b/R/messages.R @@ -1,24 +1,57 @@ -as_message <- function(x, name) { - UseMethod("as_message") +#' @export +as_messages <- function(messages, ...) { + UseMethod("as_messages") +} + +#' @export +as_messages.character <- function(x, ..., error_call = current_env()) { + check_scalar_string(x, error_call = error_call) + check_unnamed_string(x, error_call = error_call) + + list( + list(role = "user", content = x) + ) +} + +#' @export +as_messages.list <- function(x, ..., error_call = caller_env()) { + check_dots_empty() + + bits <- map2(x, names2(x), as_msg, error_call = error_call) + out <- list_flatten(bits) + names(out) <- NULL + out +} + +as_msg <- function(x, name, error_call = caller_env()) { + UseMethod("as_msg") } #' @export -as_message.character <- function(x, name) { +as_msg.character <- function(x, name, error_call = caller_env()) { + check_scalar_string(x, error_call = error_call) + role <- check_role(name, error_call = error_call) + + list( + list(role = role, content = x) + ) +} + +check_role <- function(name = "", error_call = caller_env()) { if (identical(name, "")) { name <- "user" } - list(role = name, content = x) + name } -as_messages <- function(...) { - x <- dots_list(..., .named = FALSE) - names <- names(x) - - messages <- list() - for (i in seq_len(length(x))) { - messages <- append(messages, as_message(x[[i]], name = names[i])) +check_scalar_string <- function(x, error_call = caller_env()) { + if (length(x) != 1L) { + cli_abort("{.arg x} must be a single string, not length {.code {length(x)}}. ", call = error_call) } - - list(messages) } +check_unnamed_string <- function(x, error_call = caller_env()) { + if (!is.null(names(x))) { + cli_abort("{.arg x} must be unnamed", call = error_call) + } +} diff --git a/R/zzz.R b/R/zzz.R index 66e67f9..06c6070 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -3,7 +3,7 @@ #' @import httr2 #' @import tibble #' @import stringr -#' @importFrom purrr map_dfr map_chr pluck +#' @importFrom purrr map_dfr map_chr pluck map2 list_flatten NULL mistral_base_url <- "https://api.mistral.ai" diff --git a/man/chat.Rd b/man/chat.Rd index e5a745a..891f8fa 100644 --- a/man/chat.Rd +++ b/man/chat.Rd @@ -4,10 +4,15 @@ \alias{chat} \title{Chat with the Mistral api} \usage{ -chat(..., model = "mistral-tiny", dry_run = FALSE, error_call = current_env()) +chat( + messages, + model = "mistral-tiny", + dry_run = FALSE, + error_call = current_env() +) } \arguments{ -\item{...}{messages} +\item{messages}{Messages} \item{model}{which model to use. See \code{\link[=models]{models()}} for more information about which models are available} diff --git a/man/stream.Rd b/man/stream.Rd index 335ab4f..3100fb7 100644 --- a/man/stream.Rd +++ b/man/stream.Rd @@ -19,8 +19,6 @@ stream( \item{dry_run}{if TRUE the request is not performed} -\item{...}{messages} - \item{error_call}{The execution environment of a currently running function, e.g. \code{caller_env()}. The function will be mentioned in error messages as the source of the error. See the From e810ba7926b1a9e39cea4c2c7a1f7ab5d0bc46c9 Mon Sep 17 00:00:00 2001 From: Romain Francois Date: Sat, 30 Mar 2024 11:30:20 +0100 Subject: [PATCH 3/4] stream(messages = ) --- R/stream.R | 5 +++-- man/stream.Rd | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/R/stream.R b/R/stream.R index aaf793b..a5ed84b 100644 --- a/R/stream.R +++ b/R/stream.R @@ -3,10 +3,11 @@ #' @inheritParams chat #' #' @export -stream <- function(text, model = "mistral-tiny", dry_run = FALSE, ..., error_call = current_env()) { +stream <- function(messages, model = "mistral-tiny", dry_run = FALSE, ..., error_call = current_env()) { check_model(model, error_call = error_call) - req <- req_chat(text, model, stream = TRUE, error_call = error_call, dry_run = dry_run) + messages <- as_messages(messages) + req <- req_chat(messages, model, stream = TRUE, error_call = error_call, dry_run = dry_run) if (is_true(dry_run)) { return(req) } diff --git a/man/stream.Rd b/man/stream.Rd index 3100fb7..eeb3058 100644 --- a/man/stream.Rd +++ b/man/stream.Rd @@ -5,7 +5,7 @@ \title{stream} \usage{ stream( - text, + messages, model = "mistral-tiny", dry_run = FALSE, ..., @@ -13,7 +13,7 @@ stream( ) } \arguments{ -\item{text}{some text} +\item{messages}{Messages} \item{model}{which model to use. See \code{\link[=models]{models()}} for more information about which models are available} From 43b4f3f398bab80c3350b97e313de0cc1f6597d2 Mon Sep 17 00:00:00 2001 From: Romain Francois Date: Sat, 30 Mar 2024 11:55:07 +0100 Subject: [PATCH 4/4] check() --- R/chat.R | 20 ++++++++++++++------ R/messages.R | 29 +++++++++++++++++++++-------- man/as_messages.Rd | 27 +++++++++++++++++++++++++++ man/chat.Rd | 5 +++-- man/stream.Rd | 2 ++ 5 files changed, 67 insertions(+), 16 deletions(-) create mode 100644 man/as_messages.Rd diff --git a/R/chat.R b/R/chat.R index 028cfed..360e85e 100644 --- a/R/chat.R +++ b/R/chat.R @@ -1,10 +1,10 @@ #' Chat with the Mistral api #' #' @param messages Messages -#' @param text some text #' @param model which model to use. See [models()] for more information about which models are available #' @param dry_run if TRUE the request is not performed -#' @inheritParams httr2::req_perform +#' @inheritParams rlang::args_dots_empty +#' @inheritParams rlang::args_error_context #' #' @return A tibble with columns `role` and `content` with class `chat_tibble` or a request #' if this is a `dry_run` @@ -13,7 +13,9 @@ #' chat("Top 5 R packages", dry_run = TRUE) #' #' @export -chat <- function(messages, model = "mistral-tiny", dry_run = FALSE, error_call = current_env()) { +chat <- function(messages, model = "mistral-tiny", dry_run = FALSE, ..., error_call = current_env()) { + check_dots_empty() + req <- req_chat(messages, model = model, error_call = error_call, dry_run = dry_run) if (is_true(dry_run)) { return(req) @@ -29,7 +31,9 @@ print.chat <- function(x, ...) { invisible(x) } -req_chat <- function(messages, model = "mistral-tiny", stream = FALSE, dry_run = FALSE, error_call = caller_env()) { +req_chat <- function(messages, model = "mistral-tiny", stream = FALSE, dry_run = FALSE, ..., error_call = caller_env()) { + check_dots_empty() + if (!is_true(dry_run)) { check_model(model, error_call = error_call) } @@ -50,8 +54,12 @@ req_chat <- function(messages, model = "mistral-tiny", stream = FALSE, dry_run = #' @export as.data.frame.chat_response <- function(x, ...) { - df_req <- map_dfr(resp$request$body$data$messages, as.data.frame) - df_resp <- as.data.frame(resp_body_json(resp)$choices[[1]]$message[c("role", "content")]) + req_messages <- x$request$body$data$messages + df_req <- map_dfr(req_messages, as.data.frame) + + df_resp <- as.data.frame( + resp_body_json(x)$choices[[1]]$message[c("role", "content")] + ) rbind(df_req, df_resp) } diff --git a/R/messages.R b/R/messages.R index 92660f0..d5501c1 100644 --- a/R/messages.R +++ b/R/messages.R @@ -1,24 +1,37 @@ +#' Convert object into a messages list +#' +#' @param messages object to convert to messages +#' @param ... ignored +#' @inheritParams rlang::args_error_context +#' +#' @examples +#' as_messages("hello") +#' as_messages(list("hello")) +#' as_messages(list(assistant = "hello", user = "hello")) +#' #' @export -as_messages <- function(messages, ...) { +as_messages <- function(messages, ..., error_call = current_env()) { UseMethod("as_messages") } #' @export -as_messages.character <- function(x, ..., error_call = current_env()) { - check_scalar_string(x, error_call = error_call) - check_unnamed_string(x, error_call = error_call) +as_messages.character <- function(messages, ..., error_call = current_env()) { + check_dots_empty(call = error_call) + check_scalar_string(messages, error_call = error_call) + check_unnamed_string(messages, error_call = error_call) list( - list(role = "user", content = x) + list(role = "user", content = messages) ) } #' @export -as_messages.list <- function(x, ..., error_call = caller_env()) { +as_messages.list <- function(messages, ..., error_call = caller_env()) { check_dots_empty() - bits <- map2(x, names2(x), as_msg, error_call = error_call) - out <- list_flatten(bits) + out <- list_flatten( + map2(messages, names2(messages), as_msg, error_call = error_call) + ) names(out) <- NULL out } diff --git a/man/as_messages.Rd b/man/as_messages.Rd new file mode 100644 index 0000000..7d01a9e --- /dev/null +++ b/man/as_messages.Rd @@ -0,0 +1,27 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/messages.R +\name{as_messages} +\alias{as_messages} +\title{Convert object into a messages list} +\usage{ +as_messages(messages, ..., error_call = current_env()) +} +\arguments{ +\item{messages}{object to convert to messages} + +\item{...}{ignored} + +\item{error_call}{The execution environment of a currently +running function, e.g. \code{caller_env()}. The function will be +mentioned in error messages as the source of the error. See the +\code{call} argument of \code{\link[rlang:abort]{abort()}} for more information.} +} +\description{ +Convert object into a messages list +} +\examples{ +as_messages("hello") +as_messages(list("hello")) +as_messages(list(assistant = "hello", user = "hello")) + +} diff --git a/man/chat.Rd b/man/chat.Rd index 891f8fa..f34a1d8 100644 --- a/man/chat.Rd +++ b/man/chat.Rd @@ -8,6 +8,7 @@ chat( messages, model = "mistral-tiny", dry_run = FALSE, + ..., error_call = current_env() ) } @@ -18,12 +19,12 @@ chat( \item{dry_run}{if TRUE the request is not performed} +\item{...}{These dots are for future extensions and must be empty.} + \item{error_call}{The execution environment of a currently running function, e.g. \code{caller_env()}. The function will be mentioned in error messages as the source of the error. See the \code{call} argument of \code{\link[rlang:abort]{abort()}} for more information.} - -\item{text}{some text} } \value{ A tibble with columns \code{role} and \code{content} with class \code{chat_tibble} or a request diff --git a/man/stream.Rd b/man/stream.Rd index eeb3058..fee756c 100644 --- a/man/stream.Rd +++ b/man/stream.Rd @@ -19,6 +19,8 @@ stream( \item{dry_run}{if TRUE the request is not performed} +\item{...}{These dots are for future extensions and must be empty.} + \item{error_call}{The execution environment of a currently running function, e.g. \code{caller_env()}. The function will be mentioned in error messages as the source of the error. See the