Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add api args to Bedrock provider #295

Merged
merged 11 commits into from
Feb 6, 2025
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# ellmer (development version)

* `chat_bedrock()` gains `api_args` argument (@billsanto, #295).

* New `content_pdf_file()` and `content_pdf_url()` allow you to upload PDFs to supported models. Models that currently support PDFs are Google Gemini and Claude Anthropic. With help from @walkerke and @andrie (#265).

* `Chat$get_model()` returns the model name (#299).
Expand Down
46 changes: 37 additions & 9 deletions R/provider-bedrock.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,53 @@ NULL
#' Chat with an AWS bedrock model
#'
#' @description
#' [AWS Bedrock](https://aws.amazon.com/bedrock/) provides a number of chat
#' based models, including those Anthropic's
#' [Claude](https://aws.amazon.com/bedrock/claude/).
#' [AWS Bedrock](https://aws.amazon.com/bedrock/) provides a number of
#' language models, including those from Anthropic's
#' [Claude](https://aws.amazon.com/bedrock/claude/), using the Bedrock
#' [Converse API](https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html).
#'
#' ## Authentication
#'
#' Authenthication is handled through \{paws.common\}, so if authenthication
#' Authentication is handled through \{paws.common\}, so if authentication
#' does not work for you automatically, you'll need to follow the advice
#' at <https://www.paws-r-sdk.com/#credentials>. In particular, if your
#' org uses AWS SSO, you'll need to run `aws sso login` at the terminal.
#'
#' @param profile AWS profile to use.
#' @param model ellmer provides a default model, but you'll typically need to
#' you'll specify a model that you actually have access to.
#'
#' If you're using [cross-region inference](https://aws.amazon.com/blogs/machine-learning/getting-started-with-cross-region-inference-in-amazon-bedrock/),
#' you'll need to use the inference profile ID, e.g.
#' `model="us.anthropic.claude-3-5-sonnet-20240620-v1:0"`.
#' @param api_args Named list of arbitrary extra arguments appended to the body
#' of every chat API call. Some useful arguments include:
#'
#' ```R
#' api_args = list(
#' inferenceConfig = list(
#' maxTokens = 100,
#' temperature = 0.7,
#' topP = 0.9,
#' topK = 20
#' )
#' )
#' ```
#' @inheritParams chat_openai
#' @inherit chat_openai return
#' @family chatbots
#' @export
#' @examples
#' \dontrun{
#' # Basic usage
#' chat <- chat_bedrock()
#' chat$chat("Tell me three jokes about statisticians")
#' }
chat_bedrock <- function(system_prompt = NULL,
turns = NULL,
model = NULL,
profile = NULL,
api_args = list(),
echo = NULL) {

check_installed("paws.common", "AWS authentication")
Expand All @@ -47,7 +69,8 @@ chat_bedrock <- function(system_prompt = NULL,
model = model,
profile = profile,
region = credentials$region,
cache = cache
cache = cache,
extra_args = api_args
)

Chat$new(provider = provider, turns = turns, echo = echo)
Expand Down Expand Up @@ -123,11 +146,14 @@ method(chat_request, ProviderBedrock) <- function(provider,
}

# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
req <- req_body_json(req, list(
body <- list(
messages = messages,
system = system,
toolConfig = toolConfig
))
)
extra_args <- utils::modifyList(provider@extra_args, extra_args)
body <- modify_list(body, extra_args)
req <- req_body_json(req, body)

req
}
Expand All @@ -145,15 +171,17 @@ method(stream_parse, ProviderBedrock) <- function(provider, event) {

body <- event$body
body$event_type <- event$headers$`:event-type`
body$p <- NULL # padding?
body$p <- NULL # padding? Looks like: "p": "abcdefghijklmnopqrstuvwxyzABCDEFGHIJ",

body
}

method(stream_text, ProviderBedrock) <- function(provider, event) {
if (event$event_type == "contentBlockDelta") {
event$delta$text
}
}

method(stream_merge_chunks, ProviderBedrock) <- function(provider, result, chunk) {
i <- chunk$contentBlockIndex + 1

Expand Down Expand Up @@ -324,7 +352,7 @@ paws_credentials <- function(profile, cache = aws_creds_cache(profile),
creds <- locate_aws_credentials(profile),
error = function(cnd) {
if (is_testing()) {
testthat::skip("Failed to locate AWS credentails")
testthat::skip("Failed to locate AWS credentials")
}
cli::cli_abort("No IAM credentials found.", parent = cnd)
}
Expand Down
7 changes: 7 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,10 @@ has_connect_viewer_token <- function(...) {
}
connectcreds::has_viewer_token(...)
}

modify_list <- function(x, y) {
if (is.null(x)) return(y)
if (is.null(y)) return(x)

utils::modifyList(x, y)
}
33 changes: 26 additions & 7 deletions man/chat_bedrock.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 16 additions & 1 deletion tests/testthat/_snaps/provider-bedrock.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# handles errors

Code
chat$chat("What is 1 + 1?", echo = FALSE)
Condition
Error in `req_perform()`:
! HTTP 400 Bad Request.
* STRING_VALUE cannot be converted to Float
Code
chat$chat("What is 1 + 1?", echo = TRUE)
Condition
Error in `req_perform_connection()`:
! HTTP 400 Bad Request.
* STRING_VALUE cannot be converted to Float

# defaults are reported

Code
Expand All @@ -19,6 +34,6 @@
Code
. <- chat$chat("What's in this image?", image_remote)
Condition
Error:
Error in `method(as_json, list(ellmer::ProviderBedrock, ellmer::ContentImageRemote))`:
! Bedrock doesn't support remote images

32 changes: 26 additions & 6 deletions tests/testthat/test-provider-bedrock.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,26 @@ test_that("can make simple streaming request", {
expect_match(paste0(unlist(resp), collapse = ""), "2")
})

test_that("can set api args", {
chat <- chat_bedrock(
api_args = list(inferenceConfig = list(maxTokens = 1)),
echo = FALSE
)
result <- chat$chat("Who are the reindeer?")
expect_true(nchar(result) < 10)
})

test_that("handles errors", {
chat <- chat_bedrock(
api_args = list(inferenceConfig = list(temperature = "hot")),
echo = FALSE
)
expect_snapshot(error = TRUE, {
chat$chat("What is 1 + 1?", echo = FALSE)
chat$chat("What is 1 + 1?", echo = TRUE)
})
})

# Common provider interface -----------------------------------------------

test_that("defaults are reported", {
Expand Down Expand Up @@ -61,27 +81,27 @@ test_that("AWS credential caching works as expected", {
locate_aws_credentials = function(profile) {
if (!is.null(profile) && profile == "test") {
list(
access_key = "key1",
access_key_id = "key1",
secret_key = "secret1",
expiration = Sys.time() + 3600
)
} else {
list(
access_key = "key2",
access_key_id = "key2",
secret_key = "secret2",
expiration = Sys.time() + 3600
)
}
}
)

creds1 <- paws_credentials(profile = "test")
creds2 <- paws_credentials(profile = NULL)
creds1 <- paws_credentials(profile = "test", reauth = TRUE)
creds2 <- paws_credentials(profile = NULL, reauth = TRUE)

# Verify different credentials were returned.
expect_false(identical(creds1, creds2))
expect_equal(creds1$access_key, "key1")
expect_equal(creds2$access_key, "key2")
expect_equal(creds1$access_key_id, "key1")
expect_equal(creds2$access_key_id, "key2")

# Verify cached credentials match original ones.
expect_identical(creds1, paws_credentials(profile = "test"))
Expand Down
Loading