diff --git a/R/provider-bedrock.R b/R/provider-bedrock.R
index a40d893a..bee4660a 100644
--- a/R/provider-bedrock.R
+++ b/R/provider-bedrock.R
@@ -7,37 +7,117 @@ NULL
 #' Chat with an AWS bedrock model
 #'
 #' @description
-#' [AWS Bedrock](https://aws.amazon.com/bedrock/) provides a number of chat
-#' based models, including those Anthropic's
-#' [Claude](https://aws.amazon.com/bedrock/claude/).
+#' [AWS Bedrock](https://aws.amazon.com/bedrock/) provides a number of
+#' language models, including Anthropic's
+#' [Claude](https://aws.amazon.com/bedrock/claude/) family, accessed
+#' through the Bedrock
+#' [Converse API](https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html).
+#' ellmer supplies a default model, but you'll usually need to use the
+#' `model` argument to pick a model that you actually have access to.
+#' If using [cross-region inference](https://aws.amazon.com/blogs/machine-learning/getting-started-with-cross-region-inference-in-amazon-bedrock/),
+#' pass the inference profile ID as the `model` argument, e.g.,
+#' `model = "us.anthropic.claude-3-5-sonnet-20240620-v1:0"`.
+#' For examples of tool usage, asynchronous input, and other advanced
+#' features, see the package
+#' [vignettes](https://posit-dev.github.io/ellmer/vignettes/).
 #'
 #' ## Authentication
-#'
-#' Authenthication is handled through \{paws.common\}, so if authenthication
+#'
+#' Authentication is handled through \{paws.common\}, so if authentication
 #' does not work for you automatically, you'll need to follow the advice
 #' at <https://www.paws-r-sdk.com/#credentials>. In particular, if your
 #' org uses AWS SSO, you'll need to run `aws sso login` at the terminal.
 #'
 #' @param profile AWS profile to use.
+#' @param api_args Optional named list of arguments passed to the Bedrock
+#' Converse API to customize model behavior, e.g.,
+#' `api_args = list(temperature = 0.2, max_tokens = 512)`. Valid arguments
+#' are `temperature`, `top_p`, `top_k`, `stop_sequences`, and `max_tokens`,
+#' though not every model supports every parameter; check the AWS Bedrock
+#' documentation for your model. Different model families (Claude, Nova,
+#' Llama, etc.) natively use different names for the same concept (e.g.,
+#' `max_tokens`, `max_new_tokens`, or `max_gen_len`), but ellmer always
+#' uses the names above and the Converse API maps them to each model's
+#' native parameter.
+#' @param verbose Logical. When `TRUE`, prints AWS credential details
+#' (secrets redacted) and request/response headers and bodies for debugging.
 #' @inheritParams chat_openai
 #' @inherit chat_openai return
 #' @family chatbots
 #' @export
 #' @examples
 #' \dontrun{
+#' # Basic usage
 #' chat <- chat_bedrock()
 #' chat$chat("Tell me three jokes about statisticians")
+#'
+#' # Using custom API parameters
+#' chat <- chat_bedrock(
+#'   model = "us.meta.llama3-2-3b-instruct-v1:0",
+#'   api_args = list(
+#'     temperature = 0.7,
+#'     max_tokens = 2000
+#'   )
+#' )
+#'
+#' # Enable verbose output for debugging requests and responses
+#' chat <- chat_bedrock(verbose = TRUE)
+#'
+#' # Custom system prompt with API parameters
+#' chat <- chat_bedrock(
+#'   system_prompt = "You are a helpful data science assistant",
+#'   api_args = list(temperature = 0.5)
+#' )
+#'
+#' # Use a non-default AWS profile in ~/.aws/credentials
+#' chat <- chat_bedrock(profile = "my_profile_name")
+#'
+#' # Image interpretation when using a vision-capable model
+#' chat <- chat_bedrock(
+#'   model = "us.meta.llama3-2-11b-instruct-v1:0"
+#' )
+#' chat$chat(
+#'   "What's in this image?",
+#'   content_image_file("path/to/image.jpg")
+#' )
+#'
+#' # The echo argument ("none", "text", or "all") determines whether input
+#' # and/or output is echoed to the console. Note that "none" uses a
+#' # non-streaming endpoint, whereas "text", "all", and TRUE use a streaming
+#' # endpoint; use verbose = TRUE to verify which endpoint is used.
+#' chat <- chat_bedrock(verbose = TRUE)
+#' chat$chat("What is 1 + 1?") # Streaming response
+#' resp <- chat$chat("What is 1 + 1?", echo = "none") # Non-streaming response
+#' resp # View response
+#'
+#' # Use echo = "none" in the client constructor to suppress streaming output
+#' chat <- chat_bedrock(echo = "none")
+#' resp <- chat$chat("What is 1 + 1?") # Non-streaming response
+#' resp # View response
+#' chat$chat("What is 1 + 1?", echo = TRUE) # Overrides client echo arg, uses streaming
+#'
+#' # $stream() returns a generator; collect and concatenate the chunks
+#' # to get the full response.
+#' resp <- chat$stream("What is the capital of France?") # a generator object
+#' chunks <- coro::collect(resp) # list of partial text responses
+#' complete_response <- paste(chunks, collapse = "") # full text, no echo
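+#'
+#' # Alternatively (a sketch of the same coro-based pattern), iterate over
+#' # the generator and print chunks as they arrive:
+#' coro::loop(for (chunk in chat$stream("Tell me a joke")) {
+#'   cat(chunk)
+#' })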
 #' }
 chat_bedrock <- function(system_prompt = NULL,
                          turns = NULL,
                          model = NULL,
                          profile = NULL,
-                         echo = NULL) {
+                         echo = NULL,
+                         api_args = NULL,
+                         verbose = FALSE) {
   check_installed("paws.common", "AWS authentication")
   cache <- aws_creds_cache(profile)
   credentials <- paws_credentials(profile, cache = cache)
 
   turns <- normalize_turns(turns, system_prompt)
   model <- set_default(model, "anthropic.claude-3-5-sonnet-20240620-v1:0")
+
+  # Validate api_args after the default model has been applied, since the
+  # checks are model-specific
+  if (!is.null(api_args)) {
+    validate_parameters(api_args, model)
+  }
+
   echo <- check_echo(echo)
@@ -47,7 +127,9 @@ chat_bedrock <- function(system_prompt = NULL,
     model = model,
     profile = profile,
     region = credentials$region,
-    cache = cache
+    cache = cache,
+    api_args = api_args %||% list(),
+    verbose = verbose
   )
 
   Chat$new(provider = provider, turns = turns, echo = echo)
@@ -60,10 +142,41 @@ ProviderBedrock <- new_class(
     model = prop_string(),
     profile = prop_string(allow_null = TRUE),
     region = prop_string(),
-    cache = class_list
+    cache = class_list,
+    api_args = class_list,
+    verbose = class_logical
   )
 )
 
+validate_parameters <- function(api_args, model) {
+  # Llama models don't accept these parameters
+  if (grepl("llama", model, ignore.case = TRUE)) {
+    if (!is.null(api_args$top_k)) {
+      cli::cli_abort("{.arg top_k} is not supported for Llama models")
+    }
+    if (!is.null(api_args$stop_sequences)) {
+      cli::cli_abort("{.arg stop_sequences} is not supported for Llama models")
+    }
+  }
+
+  # Validate temperature
+  if (!is.null(api_args$temperature) &&
+      (!is.numeric(api_args$temperature) ||
+         api_args$temperature < 0 || api_args$temperature > 1)) {
+    cli::cli_abort("{.arg temperature} must be a number between 0 and 1, inclusive")
+  }
+
+  # Validate top_p
+  if (!is.null(api_args$top_p) &&
+      (!is.numeric(api_args$top_p) ||
+         api_args$top_p < 0 || api_args$top_p > 1)) {
+    cli::cli_abort("{.arg top_p} must be a number between 0 and 1, inclusive")
+  }
+
+  # Validate top_k
+  if (!is.null(api_args$top_k)) {
+    if (!is.numeric(api_args$top_k) || api_args$top_k <= 0 ||
+        api_args$top_k %% 1 != 0) {
+      cli::cli_abort("{.arg top_k} must be a positive integer")
+    }
+  }
+}
+
 method(chat_request, ProviderBedrock) <- function(provider,
                                                   stream = TRUE,
                                                   turns = list(),
@@ -71,6 +184,11 @@ method(chat_request, ProviderBedrock) <- function(provider,
                                                   type = NULL,
                                                   extra_args = list()) {
+  # Re-validate, since api_args can also be set directly on the provider
+  if (length(provider@api_args) > 0) {
+    validate_parameters(provider@api_args, provider@model)
+  }
+
   req <- request(paste0(
     "https://bedrock-runtime.", provider@region, ".amazonaws.com"
   ))
@@ -88,6 +206,16 @@ method(chat_request, ProviderBedrock) <- function(provider,
     aws_session_token = creds$session_token
   )
 
+  if (provider@verbose) {
+    cli::cli_h3("AWS Credentials")
+    cli::cli_alert_info(paste0(
+      "Profile: ", provider@profile,
+      "; Key: ", creds$access_key_id,
+      "; Secret: ", substr(creds$secret_access_key, 1, 2), "****",
+      "; Session: ", substr(creds$session_token, 1, 2), "****",
+      "; Region: ", provider@region
+    ))
+  }
+
   req <- req_error(req, body = function(resp) {
     body <- resp_body_json(resp)
     body$Message %||% body$message
@@ -123,16 +251,55 @@
   }
 
   # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
-  req <- req_body_json(req, list(
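+  # For example, api_args = list(temperature = 0.2, max_tokens = 512) ends
+  # up as inferenceConfig = {"temperature": 0.2, "maxTokens": 512} in the
+  # JSON request body assembled below.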
+  # Build the request body
+  body <- list(
     messages = messages,
     system = system,
     toolConfig = toolConfig
-  ))
+  )
+
+  # Add inference configuration from api_args if present, translating
+  # ellmer's snake_case names to the Converse API's camelCase ones
+  if (length(provider@api_args) > 0) {
+    arg_map <- c(
+      max_tokens = "maxTokens",
+      temperature = "temperature",
+      top_p = "topP",
+      stop_sequences = "stopSequences"
+    )
+    inference_config <- list()
+    for (arg in names(arg_map)) {
+      if (!is.null(provider@api_args[[arg]])) {
+        inference_config[[arg_map[[arg]]]] <- provider@api_args[[arg]]
+      }
+    }
+
+    # Only add inferenceConfig if we have parameters
+    if (length(inference_config) > 0) {
+      body$inferenceConfig <- inference_config
+    }
+
+    # top_k is not part of the Converse inferenceConfig schema, so pass it
+    # via additionalModelRequestFields, which forwards model-native fields
+    # (top_k is the native name for Anthropic models)
+    if (!is.null(provider@api_args$top_k)) {
+      body$additionalModelRequestFields <- list(top_k = provider@api_args$top_k)
+    }
+  }
+
+  req <- req_body_json(req, body)
+
+  if (provider@verbose) {
+    cli::cli_h3("Request Body")
+    cat(jsonlite::toJSON(body, auto_unbox = TRUE, pretty = TRUE), "\n")
+    req <- httr2::req_verbose(req)
+  }
 
   req
 }
 
 method(chat_resp_stream, ProviderBedrock) <- function(provider, resp) {
+  if (provider@verbose) {
+    cli::cli_h3("Response Stream")
+  }
   resp_stream_aws(resp)
 }
 
@@ -145,15 +312,22 @@ method(stream_parse, ProviderBedrock) <- function(provider, event) {
   body <- event$body
   body$event_type <- event$headers$`:event-type`
-  body$p <- NULL # padding?
+  body$p <- NULL # padding? Example value: "p": "abcdefghijklmnopqrstuvwxyzABCDEFGHIJ"
+
+  if (provider@verbose) {
+    cli::cli_h3("Response Chunk")
+    cat(jsonlite::toJSON(body, auto_unbox = TRUE, pretty = TRUE), "\n")
+  }
 
   body
 }
+
 method(stream_text, ProviderBedrock) <- function(provider, event) {
   if (event$event_type == "contentBlockDelta") {
     event$delta$text
   }
 }
+
 method(stream_merge_chunks, ProviderBedrock) <- function(provider, result, chunk) {
   i <- chunk$contentBlockIndex + 1
@@ -200,6 +374,12 @@ method(stream_merge_chunks, ProviderBedrock) <- function(provider, result, chunk
 }
 
 method(value_turn, ProviderBedrock) <- function(provider, result, has_type = FALSE) {
+  # Print the full response body if verbose mode is enabled
+  if (provider@verbose) {
+    cli::cli_h3("Response Body")
+    cat(jsonlite::toJSON(result, auto_unbox = TRUE, pretty = TRUE), "\n")
+  }
+
   contents <- lapply(result$output$message$content, function(content) {
     if (has_name(content, "text")) {
       ContentText(content$text)
@@ -310,7 +490,7 @@ paws_credentials <- function(profile, cache = aws_creds_cache(profile),
     creds <- locate_aws_credentials(profile),
     error = function(cnd) {
       if (is_testing()) {
-        testthat::skip("Failed to locate AWS credentails")
+        testthat::skip("Failed to locate AWS credentials")
       }
       cli::cli_abort("No IAM credentials found.", parent = cnd)
     }
diff --git a/man/chat_bedrock.Rd b/man/chat_bedrock.Rd
index 0ea1f062..45496451 100644
--- a/man/chat_bedrock.Rd
+++ b/man/chat_bedrock.Rd
@@ -9,7 +9,9 @@ chat_bedrock(
   turns = NULL,
   model = NULL,
   profile = NULL,
-  echo = NULL
+  echo = NULL,
+  api_args = NULL,
+  verbose = FALSE
 )
 }
 \arguments{
@@ -34,17 +36,41 @@ the console).
 }
 
 Note this only affects the \code{chat()} method.}
+
+\item{api_args}{Optional named list of arguments passed to the Bedrock
+Converse API to customize model behavior, e.g.,
+\code{api_args = list(temperature = 0.2, max_tokens = 512)}. Valid arguments
+are \code{temperature}, \code{top_p}, \code{top_k}, \code{stop_sequences},
+and \code{max_tokens}, though not every model supports every parameter;
+check the AWS Bedrock documentation for your model. Different model
+families (Claude, Nova, Llama, etc.) natively use different names for the
+same concept (e.g., \code{max_tokens}, \code{max_new_tokens}, or
+\code{max_gen_len}), but ellmer always uses the names above and the
+Converse API maps them to each model's native parameter.}
+
+\item{verbose}{Logical. When \code{TRUE}, prints AWS credential details
+(secrets redacted) and request/response headers and bodies for debugging.}
 }
 \value{
 A \link{Chat} object.
 }
 \description{
-\href{https://aws.amazon.com/bedrock/}{AWS Bedrock} provides a number of chat
-based models, including those Anthropic's
-\href{https://aws.amazon.com/bedrock/claude/}{Claude}.
+\href{https://aws.amazon.com/bedrock/}{AWS Bedrock} provides a number of
+language models, including Anthropic's
+\href{https://aws.amazon.com/bedrock/claude/}{Claude} family, accessed
+through the Bedrock
+\href{https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html}{Converse API}.
+ellmer supplies a default model, but you'll usually need to use the
+\code{model} argument to pick a model that you actually have access to.
+If using \href{https://aws.amazon.com/blogs/machine-learning/getting-started-with-cross-region-inference-in-amazon-bedrock/}{cross-region inference},
+pass the inference profile ID as the \code{model} argument, e.g.,
+\code{model = "us.anthropic.claude-3-5-sonnet-20240620-v1:0"}.
+For examples of tool usage, asynchronous input, and other advanced
+features, see the package
+\href{https://posit-dev.github.io/ellmer/vignettes/}{vignettes}.
 \subsection{Authentication}{
-Authenthication is handled through \{paws.common\}, so if authenthication
+Authentication is handled through \{paws.common\}, so if authentication
 does not work for you automatically, you'll need to follow the advice
 at \url{https://www.paws-r-sdk.com/#credentials}. In particular, if your
 org uses AWS SSO, you'll need to run \verb{aws sso login} at the terminal.
@@ -52,8 +78,59 @@ org uses AWS SSO, you'll need to run \verb{aws sso login} at the terminal.
 }
 \examples{
 \dontrun{
+# Basic usage
 chat <- chat_bedrock()
 chat$chat("Tell me three jokes about statisticians")
+
+# Using custom API parameters
+chat <- chat_bedrock(
+  model = "us.meta.llama3-2-3b-instruct-v1:0",
+  api_args = list(
+    temperature = 0.7,
+    max_tokens = 2000
+  )
+)
+
+# Enable verbose output for debugging requests and responses
+chat <- chat_bedrock(verbose = TRUE)
+
+# Custom system prompt with API parameters
+chat <- chat_bedrock(
+  system_prompt = "You are a helpful data science assistant",
+  api_args = list(temperature = 0.5)
+)
+
+# Use a non-default AWS profile in ~/.aws/credentials
+chat <- chat_bedrock(profile = "my_profile_name")
+
+# Image interpretation when using a vision-capable model
+chat <- chat_bedrock(
+  model = "us.meta.llama3-2-11b-instruct-v1:0"
+)
+chat$chat(
+  "What's in this image?",
+  content_image_file("path/to/image.jpg")
+)
+
+# The echo argument ("none", "text", or "all") determines whether input
+# and/or output is echoed to the console. Note that "none" uses a
+# non-streaming endpoint, whereas "text", "all", and TRUE use a streaming
+# endpoint; use verbose = TRUE to verify which endpoint is used.
+chat <- chat_bedrock(verbose = TRUE)
+chat$chat("What is 1 + 1?") # Streaming response
+resp <- chat$chat("What is 1 + 1?", echo = "none") # Non-streaming response
+resp # View response
+
+# Use echo = "none" in the client constructor to suppress streaming output
+chat <- chat_bedrock(echo = "none")
+resp <- chat$chat("What is 1 + 1?") # Non-streaming response
+resp # View response
+chat$chat("What is 1 + 1?", echo = TRUE) # Overrides client echo arg, uses streaming
+
+# $stream() returns a generator; collect and concatenate the chunks
+# to get the full response.
+resp <- chat$stream("What is the capital of France?") # a generator object
+chunks <- coro::collect(resp) # list of partial text responses
+complete_response <- paste(chunks, collapse = "") # full text, no echo
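+
+# Alternatively (a sketch of the same coro-based pattern), iterate over
+# the generator and print chunks as they arrive:
+coro::loop(for (chunk in chat$stream("Tell me a joke")) {
+  cat(chunk)
+})
 }
 }
 \seealso{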