Skip to content

Commit

Permalink
Remove extra_args arg from chat_perform (#314)
Browse files Browse the repository at this point in the history
And apply `provider@extra_args` to the request body in the same way in all providers.

Fixes #313
  • Loading branch information
hadley authored Feb 6, 2025
1 parent 766ec20 commit e18bd8e
Show file tree
Hide file tree
Showing 11 changed files with 39 additions and 57 deletions.
6 changes: 2 additions & 4 deletions R/httr2.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ chat_perform <- function(provider,
mode = c("value", "stream", "async-stream", "async-value"),
turns,
tools = list(),
type = NULL,
extra_args = list()) {
type = NULL) {

mode <- arg_match(mode)
stream <- mode %in% c("stream", "async-stream")
Expand All @@ -16,8 +15,7 @@ chat_perform <- function(provider,
turns = turns,
tools = tools,
stream = stream,
type = type,
extra_args = extra_args
type = type
)

switch(mode,
Expand Down
12 changes: 5 additions & 7 deletions R/provider-azure.R
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,7 @@ method(chat_request, ProviderAzure) <- function(provider,
stream = TRUE,
turns = list(),
tools = list(),
type = NULL,
extra_args = list()) {
type = NULL) {

req <- request(provider@base_url)
req <- req_url_path_append(req, "/chat/completions")
Expand Down Expand Up @@ -179,7 +178,6 @@ method(chat_request, ProviderAzure) <- function(provider,

messages <- compact(unlist(as_json(provider, turns), recursive = FALSE))
tools <- as_json(provider, unname(tools))
extra_args <- utils::modifyList(provider@extra_args, extra_args)

if (!is.null(type)) {
response_format <- list(
Expand All @@ -194,17 +192,17 @@ method(chat_request, ProviderAzure) <- function(provider,
response_format <- NULL
}

data <- compact(list2(
body <- compact(list2(
messages = messages,
model = provider@model,
seed = provider@seed,
stream = stream,
stream_options = if (stream) list(include_usage = TRUE),
tools = tools,
response_format = response_format,
!!!extra_args
response_format = response_format
))
req <- req_body_json(req, data)
body <- modify_list(body, provider@extra_args)
req <- req_body_json(req, body)

req
}
Expand Down
6 changes: 2 additions & 4 deletions R/provider-bedrock.R
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,7 @@ method(chat_request, ProviderBedrock) <- function(provider,
stream = TRUE,
turns = list(),
tools = list(),
type = NULL,
extra_args = list()) {
type = NULL) {

req <- request(paste0(
"https://bedrock-runtime.", provider@region, ".amazonaws.com"
Expand Down Expand Up @@ -151,8 +150,7 @@ method(chat_request, ProviderBedrock) <- function(provider,
system = system,
toolConfig = toolConfig
)
extra_args <- utils::modifyList(provider@extra_args, extra_args)
body <- modify_list(body, extra_args)
body <- modify_list(body, provider@extra_args)
req <- req_body_json(req, body)

req
Expand Down
6 changes: 2 additions & 4 deletions R/provider-claude.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ method(chat_request, ProviderClaude) <- function(provider,
stream = TRUE,
turns = list(),
tools = list(),
type = NULL,
extra_args = list()) {
type = NULL) {

req <- request(provider@base_url)
# https://docs.anthropic.com/en/api/messages
Expand Down Expand Up @@ -125,7 +124,6 @@ method(chat_request, ProviderClaude) <- function(provider,
}
tools <- as_json(provider, unname(tools))

extra_args <- utils::modifyList(provider@extra_args, extra_args)
body <- compact(list2(
model = provider@model,
system = system,
Expand All @@ -134,8 +132,8 @@ method(chat_request, ProviderClaude) <- function(provider,
max_tokens = provider@max_tokens,
tools = tools,
tool_choice = tool_choice,
!!!extra_args
))
body <- modify_list(body, provider@extra_args)
req <- req_body_json(req, body)

req
Expand Down
9 changes: 4 additions & 5 deletions R/provider-cortex.R
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,7 @@ method(chat_request, ProviderCortex) <- function(provider,
stream = TRUE,
turns = list(),
tools = list(),
type = NULL,
extra_args = list()) {
type = NULL) {
if (length(tools) != 0) {
cli::cli_abort("Tools are not supported by Cortex.")
}
Expand Down Expand Up @@ -180,10 +179,10 @@ method(chat_request, ProviderCortex) <- function(provider,
# Cortex does not yet support multi-turn chats.
turns <- tail(turns, n = 1)
messages <- as_json(provider, turns)
extra_args <- utils::modifyList(provider@extra_args, extra_args)

data <- compact(list2(messages = messages, stream = stream, !!!extra_args))
req <- req_body_json(req, data)
body <- list(messages = messages, stream = stream)
body <- modify_list(body, provider@extra_args)
req <- req_body_json(req, body)

req
}
Expand Down
12 changes: 5 additions & 7 deletions R/provider-databricks.R
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,7 @@ method(chat_request, ProviderDatabricks) <- function(provider,
stream = TRUE,
turns = list(),
tools = list(),
type = NULL,
extra_args = list()) {
type = NULL) {
req <- request(provider@base_url)
# Note: this API endpoint is undocumented and seems to exist primarily for
# compatibility with the OpenAI Python SDK. The documented endpoint is
Expand All @@ -108,7 +107,6 @@ method(chat_request, ProviderDatabricks) <- function(provider,

messages <- compact(unlist(as_json(provider, turns), recursive = FALSE))
tools <- as_json(provider, unname(tools))
extra_args <- utils::modifyList(provider@extra_args, extra_args)

if (!is.null(type)) {
response_format <- list(
Expand All @@ -123,15 +121,15 @@ method(chat_request, ProviderDatabricks) <- function(provider,
response_format <- NULL
}

data <- compact(list2(
body <- compact(list(
messages = messages,
model = provider@model,
stream = stream,
tools = tools,
response_format = response_format,
!!!extra_args
response_format = response_format
))
req <- req_body_json(req, data)
body <- modify_list(body, provider@extra_args)
req <- req_body_json(req, body)

req
}
Expand Down
11 changes: 5 additions & 6 deletions R/provider-gemini.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,7 @@ method(chat_request, ProviderGemini) <- function(provider,
stream = TRUE,
turns = list(),
tools = list(),
type = NULL,
extra_args = list()) {
type = NULL) {


req <- request(provider@base_url)
Expand Down Expand Up @@ -107,15 +106,15 @@ method(chat_request, ProviderGemini) <- function(provider,
} else {
tools <- NULL
}
extra_args <- utils::modifyList(provider@extra_args, extra_args)

body <- compact(list2(
body <- compact(list(
contents = contents,
tools = tools,
systemInstruction = system,
generationConfig = generation_config,
!!!extra_args
generationConfig = generation_config
))
body <- modify_list(body, provider@extra_args)

req <- req_body_json(req, body)

req
Expand Down
12 changes: 5 additions & 7 deletions R/provider-openai.R
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,7 @@ method(chat_request, ProviderOpenAI) <- function(provider,
stream = TRUE,
turns = list(),
tools = list(),
type = NULL,
extra_args = list()) {
type = NULL) {

req <- request(provider@base_url)
req <- req_url_path_append(req, "/chat/completions")
Expand All @@ -123,7 +122,6 @@ method(chat_request, ProviderOpenAI) <- function(provider,

messages <- compact(unlist(as_json(provider, turns), recursive = FALSE))
tools <- as_json(provider, unname(tools))
extra_args <- utils::modifyList(provider@extra_args, extra_args)

if (!is.null(type)) {
response_format <- list(
Expand All @@ -138,17 +136,17 @@ method(chat_request, ProviderOpenAI) <- function(provider,
response_format <- NULL
}

data <- compact(list2(
body <- compact(list(
messages = messages,
model = provider@model,
seed = provider@seed,
stream = stream,
stream_options = if (stream) list(include_usage = TRUE),
tools = tools,
response_format = response_format,
!!!extra_args
response_format = response_format
))
req <- req_body_json(req, data)
body <- utils::modifyList(body, provider@extra_args)
req <- req_body_json(req, body)

req
}
Expand Down
6 changes: 2 additions & 4 deletions R/provider-openrouter.R
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,14 @@ method(chat_request, ProviderOpenRouter) <- function(
stream = TRUE,
turns = list(),
tools = list(),
type = NULL,
extra_args = list()
type = NULL
) {
req <- chat_request(
super(provider, ProviderOpenAI),
stream = stream,
turns = turns,
tools = tools,
type = type,
extra_args = extra_args
type = type
)

# https://openrouter.ai/docs/api-keys
Expand Down
14 changes: 6 additions & 8 deletions R/provider-snowflake.R
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,7 @@ method(chat_request, ProviderSnowflake) <- function(provider,
stream = TRUE,
turns = list(),
tools = list(),
type = NULL,
extra_args = list()) {
type = NULL) {
if (length(tools) != 0) {
cli::cli_abort(
"Tool calling is not supported.",
Expand Down Expand Up @@ -114,15 +113,14 @@ method(chat_request, ProviderSnowflake) <- function(provider,
req <- req_error(req, body = function(resp) resp_body_json(resp)$message)

messages <- as_json(provider, turns)
extra_args <- utils::modifyList(provider@extra_args, extra_args)

data <- compact(list2(
body <- list(
messages = messages,
model = provider@model,
stream = stream,
!!!extra_args
))
req <- req_body_json(req, data)
stream = stream
)
body <- modify_list(body, provider@extra_args)
req <- req_body_json(req, body)

req
}
Expand Down
2 changes: 1 addition & 1 deletion R/provider.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ Provider <- new_class(
# Create a request------------------------------------

chat_request <- new_generic("chat_request", "provider",
function(provider, stream = TRUE, turns = list(), tools = list(), type = NULL, extra_args = list()) {
function(provider, stream = TRUE, turns = list(), tools = list(), type = NULL) {
S7_dispatch()
}
)
Expand Down

0 comments on commit e18bd8e

Please sign in to comment.