support for anthropic models on vertex (#1122)
* support for anthropic models on vertex

* Update CHANGELOG.md
jjallaire authored Jan 26, 2025
1 parent 788787b commit d3f6cb4
Showing 3 changed files with 47 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -6,6 +6,7 @@
- Add `@wraps` to functions wrapped by Inspect decorators to preserve type information.
- Hugging Face: Add support for stop sequences for HF models.
- Docker: More robust parsing of version strings (handle development versions).
- Vertex: Support for Anthropic models hosted on Vertex.
- Bugfix: Fix issue w/ approvals for samples with id==0.
- Bugfix: Use "plain" display when running eval_async() outside of eval().

11 changes: 11 additions & 0 deletions docs/models.qmd
@@ -272,6 +272,17 @@ $ inspect eval eval.py -M "vertex_init_args={'project': 'my-project', location:

Vertex AI provides the same `safety_settings` outlined in the [Google] provider.

#### Anthropic on Vertex

To use Anthropic models on Vertex, use the standard `anthropic` model provider and prefix the model name with `vertex/` (e.g. `anthropic/vertex/claude-3-5-sonnet-v2@20241022`). You should also set two environment variables indicating your project ID and region. Here is a complete example:

```bash
export ANTHROPIC_VERTEX_PROJECT_ID=project-12345
export ANTHROPIC_VERTEX_REGION=us-east5
inspect eval ctf.py --model anthropic/vertex/claude-3-5-sonnet-v2@20241022
```


### Hugging Face {#sec-hugging-face-transformers}

The Hugging Face provider implements support for local models using the [transformers](https://pypi.org/project/transformers/) package. You can use any Hugging Face model by specifying it with the `hf/` prefix. For example:
39 changes: 35 additions & 4 deletions src/inspect_ai/model/_providers/anthropic.py
@@ -14,8 +14,10 @@
APIConnectionError,
AsyncAnthropic,
AsyncAnthropicBedrock,
AsyncAnthropicVertex,
BadRequestError,
InternalServerError,
NotGiven,
RateLimitError,
)
from anthropic.types import (
@@ -65,14 +67,20 @@ def __init__(
api_key: str | None = None,
config: GenerateConfig = GenerateConfig(),
bedrock: bool = False,
vertex: bool = False,
**model_args: Any,
):
# extract any service prefix from model name
parts = model_name.split("/")
if len(parts) > 1:
-            service = parts[0]
-            bedrock = service == "bedrock"
+            self.service: str | None = parts[0]
model_name = "/".join(parts[1:])
elif bedrock:
self.service = "bedrock"
elif vertex:
self.service = "vertex"
else:
self.service = None

# call super
super().__init__(
@@ -84,7 +92,7 @@ def __init__(
)

# create client
-        if bedrock:
+        if self.is_bedrock():
base_url = model_base_url(
base_url, ["ANTHROPIC_BEDROCK_BASE_URL", "BEDROCK_ANTHROPIC_BASE_URL"]
)
@@ -95,14 +103,31 @@
if base_region is None:
aws_region = os.environ.get("AWS_DEFAULT_REGION", None)

-        self.client: AsyncAnthropic | AsyncAnthropicBedrock = AsyncAnthropicBedrock(
+        self.client: (
+            AsyncAnthropic | AsyncAnthropicBedrock | AsyncAnthropicVertex
+        ) = AsyncAnthropicBedrock(
base_url=base_url,
max_retries=(
config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES
),
aws_region=aws_region,
**model_args,
)
elif self.is_vertex():
base_url = model_base_url(
base_url, ["ANTHROPIC_VERTEX_BASE_URL", "VERTEX_ANTHROPIC_BASE_URL"]
)
region = os.environ.get("ANTHROPIC_VERTEX_REGION", NotGiven())
project_id = os.environ.get("ANTHROPIC_VERTEX_PROJECT_ID", NotGiven())
self.client = AsyncAnthropicVertex(
region=region,
project_id=project_id,
base_url=base_url,
max_retries=(
config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES
),
**model_args,
)
else:
# resolve api_key
if not self.api_key:
@@ -119,6 +144,12 @@ def __init__(
**model_args,
)

def is_bedrock(self) -> bool:
return self.service == "bedrock"

def is_vertex(self) -> bool:
return self.service == "vertex"

async def generate(
self,
input: list[ChatMessage],
