Skip to content

Commit

Permalink
Merge pull request #23 from tadascience/chat_ellipsis
Browse files Browse the repository at this point in the history
`chat(messages = )`
  • Loading branch information
romainfrancois authored Mar 30, 2024
2 parents 61f9ada + 43b4f3f commit 1fab858
Show file tree
Hide file tree
Showing 8 changed files with 156 additions and 30 deletions.
12 changes: 11 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
# Generated by roxygen2: do not edit by hand

S3method(print,chat_tibble)
S3method(as.data.frame,chat_response)
S3method(as_messages,character)
S3method(as_messages,list)
S3method(as_msg,character)
S3method(as_tibble,chat_response)
S3method(print,chat)
S3method(print,chat_tbl)
export(as_messages)
export(chat)
export(models)
export(stream)
Expand All @@ -10,6 +17,9 @@ import(rlang)
import(stringr)
import(tibble)
importFrom(jsonlite,fromJSON)
importFrom(purrr,list_flatten)
importFrom(purrr,map2)
importFrom(purrr,map_chr)
importFrom(purrr,map_dfr)
importFrom(purrr,pluck)
importFrom(tibble,as_tibble)
58 changes: 38 additions & 20 deletions R/chat.R
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#' Chat with the Mistral api
#'
#' @param text some text
#' @param messages Messages
#' @param model which model to use. See [models()] for more information about which models are available
#' @param dry_run if TRUE the request is not performed
#' @param ... ignored
#' @inheritParams httr2::req_perform
#' @inheritParams rlang::args_dots_empty
#' @inheritParams rlang::args_error_context
#'
#' @return A tibble with columns `role` and `content` with class `chat_tibble` or a request
#' if this is a `dry_run`
Expand All @@ -13,49 +13,67 @@
#' chat("Top 5 R packages", dry_run = TRUE)
#'
#' @export
chat <- function(text = "What are the top 5 R packages ?", model = "mistral-tiny", dry_run = FALSE, ..., error_call = current_env()) {
req <- req_chat(text, model, error_call = error_call, dry_run = dry_run)
chat <- function(messages, model = "mistral-tiny", dry_run = FALSE, ..., error_call = current_env()) {
check_dots_empty()

req <- req_chat(messages, model = model, error_call = error_call, dry_run = dry_run)
if (is_true(dry_run)) {
return(req)
}
resp <- req_mistral_perform(req, error_call = error_call)
resp_chat(resp, error_call = error_call)
class(resp) <- c("chat", class(resp))
resp
}

#' @export
print.chat <- function(x, ...) {
writeLines(resp_body_json(x)$choices[[1]]$message$content)
invisible(x)
}

req_chat <- function(text = "What are the top 5 R packages ?", model = "mistral-tiny", stream = FALSE, dry_run = FALSE, error_call = caller_env()) {
req_chat <- function(messages, model = "mistral-tiny", stream = FALSE, dry_run = FALSE, ..., error_call = caller_env()) {
check_dots_empty()

if (!is_true(dry_run)) {
check_model(model, error_call = error_call)
}

messages <- as_messages(messages)

request(mistral_base_url) |>
req_url_path_append("v1", "chat", "completions") |>
authenticate() |>
req_body_json(
list(
model = model,
messages = list(
list(
role = "user",
content = text
)
),
messages = messages,
stream = is_true(stream)
)
)
}

resp_chat <- function(response, error_call = current_env()) {
data <- resp_body_json(response)
#' @export
as.data.frame.chat_response <- function(x, ...) {
req_messages <- x$request$body$data$messages
df_req <- map_dfr(req_messages, as.data.frame)

df_resp <- as.data.frame(
resp_body_json(x)$choices[[1]]$message[c("role", "content")]
)

tib <- map_dfr(data$choices, \(choice) {
as_tibble(choice$message)
})
rbind(df_req, df_resp)
}

class(tib) <- c("chat_tibble", class(tib))
#' @importFrom tibble as_tibble
#' @export
as_tibble.chat_response <- function(x, ...) {
tib <- as_tibble(as.data.frame(x, ...))
class(tib) <- c("chat_tbl", class(x))
tib
}

#' @export
print.chat_tibble <- function(x, ...) {
print.chat_tbl <- function(x, ...) {
n <- nrow(x)

for (i in seq_len(n)) {
Expand Down
70 changes: 70 additions & 0 deletions R/messages.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#' Convert object into a messages list
#'
#' @param messages object to convert to messages
#' @param ... ignored
#' @inheritParams rlang::args_error_context
#'
#' @examples
#' as_messages("hello")
#' as_messages(list("hello"))
#' as_messages(list(assistant = "hello", user = "hello"))
#'
#' @export
as_messages <- function(messages, ..., error_call = current_env()) {
UseMethod("as_messages")
}

#' @export
as_messages.character <- function(messages, ..., error_call = current_env()) {
check_dots_empty(call = error_call)
check_scalar_string(messages, error_call = error_call)
check_unnamed_string(messages, error_call = error_call)

list(
list(role = "user", content = messages)
)
}

#' @export
as_messages.list <- function(messages, ..., error_call = caller_env()) {
check_dots_empty()

out <- list_flatten(
map2(messages, names2(messages), as_msg, error_call = error_call)
)
names(out) <- NULL
out
}

as_msg <- function(x, name, error_call = caller_env()) {
UseMethod("as_msg")
}

#' @export
as_msg.character <- function(x, name, error_call = caller_env()) {
check_scalar_string(x, error_call = error_call)
role <- check_role(name, error_call = error_call)

list(
list(role = role, content = x)
)
}

check_role <- function(name = "", error_call = caller_env()) {
if (identical(name, "")) {
name <- "user"
}
name
}

check_scalar_string <- function(x, error_call = caller_env()) {
if (length(x) != 1L) {
cli_abort("{.arg x} must be a single string, not length {.code {length(x)}}. ", call = error_call)
}
}

check_unnamed_string <- function(x, error_call = caller_env()) {
if (!is.null(names(x))) {
cli_abort("{.arg x} must be unnamed", call = error_call)
}
}
5 changes: 3 additions & 2 deletions R/stream.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
#' @inheritParams chat
#'
#' @export
stream <- function(text, model = "mistral-tiny", dry_run = FALSE, ..., error_call = current_env()) {
stream <- function(messages, model = "mistral-tiny", dry_run = FALSE, ..., error_call = current_env()) {
check_model(model, error_call = error_call)

req <- req_chat(text, model, stream = TRUE, error_call = error_call, dry_run = dry_run)
messages <- as_messages(messages)
req <- req_chat(messages, model, stream = TRUE, error_call = error_call, dry_run = dry_run)
if (is_true(dry_run)) {
return(req)
}
Expand Down
2 changes: 1 addition & 1 deletion R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#' @import httr2
#' @import tibble
#' @import stringr
#' @importFrom purrr map_dfr map_chr pluck
#' @importFrom purrr map_dfr map_chr pluck map2 list_flatten
NULL

mistral_base_url <- "https://api.mistral.ai"
Expand Down
27 changes: 27 additions & 0 deletions man/as_messages.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions man/chat.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions man/stream.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 1fab858

Please sign in to comment.