Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implements support for OpenRouter #263

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -191,4 +191,4 @@ trajectories/*
.vscode/**

# PyCharm
.idea/
.idea/
74 changes: 74 additions & 0 deletions sweagent/agent/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,6 +691,78 @@ def query(self, history: list[dict[str, str]]) -> str:
self.update_stats(input_tokens, output_tokens)
return response

class OpenRouterModel(BaseModel):
    """
    Model backend that queries OpenRouter through the OpenAI client.

    Model names are registered with an `openrouter/` prefix; the prefix is
    stripped from `self.api_model` before the name is sent to the API.
    """

    # Per-model context size and per-token costs, keyed by the prefixed model name.
    # NOTE(review): after the `openrouter/` prefix is stripped, the gpt-3.5 entries are
    # sent as bare names (e.g. "gpt-3.5-turbo-0125"), while the llama entry keeps a
    # vendor qualifier ("meta-llama/..."). Confirm OpenRouter resolves the bare names.
    MODELS = {
        "openrouter/gpt-3.5-turbo-0125": {
            "max_context": 16_385,
            "cost_per_input_token": 5e-07,
            "cost_per_output_token": 1.5e-06,
        },
        "openrouter/gpt-3.5-turbo-1106": {
            "max_context": 16_385,
            "cost_per_input_token": 1.5e-06,
            "cost_per_output_token": 2e-06,
        },
        "openrouter/meta-llama/llama-3-70b-instruct": {
            "max_context": 8_192,
            "cost_per_input_token": 0.59e-06,
            "cost_per_output_token": 0.79e-06,
        },
        # Add more models as needed
    }

    def __init__(self, args: ModelArguments, commands: list[Command]):
        """
        Build an OpenAI client pointed at the OpenRouter endpoint.

        The API key is read from the `OPENROUTER_API_KEY` entry of `keys.cfg`
        in the current working directory.
        """
        super().__init__(args, commands)

        # Set the OpenRouter API Key
        cfg = config.Config(os.path.join(os.getcwd(), "keys.cfg"))
        self.api_base_url = 'https://openrouter.ai/api/v1'
        self.api_key = cfg["OPENROUTER_API_KEY"]
        self.client = OpenAI(api_key=self.api_key, base_url=self.api_base_url)

    def history_to_messages(
        self, history: list[dict[str, str]], is_demonstration: bool = False
    ) -> Union[str, list[dict[str, str]]]:
        """
        Create `messages` by filtering out all keys except for role/content per `history` turn.

        When `is_demonstration` is true, system entries are dropped and the remaining
        contents are returned as a single newline-joined string instead of a list.
        """
        # Remove system messages if it is a demonstration
        if is_demonstration:
            history = [entry for entry in history if entry["role"] != "system"]
            return '\n'.join(entry["content"] for entry in history)
        # Return history components with just role, content fields
        return [
            {k: v for k, v in entry.items() if k in ("role", "content")}
            for entry in history
        ]

    @retry(
        wait=wait_random_exponential(min=1, max=15),
        reraise=True,
        stop=stop_after_attempt(3),
        retry=retry_if_not_exception_type((CostLimitExceededError, RuntimeError)),
    )
    def query(self, history: list[dict[str, str]]) -> str:
        """
        Query the API with the given `history` and return the response.

        As OpenRouter shares compatibility with the OpenAI client, the request is
        issued the same way as an OpenAI chat completion.

        Raises:
            CostLimitExceededError: if the client raises `BadRequestError`. The
                original error is attached as the cause.
        """
        try:
            # Perform the API call through the OpenAI-compatible client
            response = self.client.chat.completions.create(
                messages=self.history_to_messages(history),
                model=self.api_model.removeprefix('openrouter/'),
                temperature=self.args.temperature,
                top_p=self.args.top_p,
            )
        except BadRequestError as err:
            # NOTE(review): every bad request is reported as a context-window overflow;
            # chaining keeps the underlying API error visible in the traceback.
            raise CostLimitExceededError(
                f"Context window ({self.model_metadata['max_context']} tokens) exceeded"
            ) from err
        # Calculate + update costs, return response
        input_tokens = response.usage.prompt_tokens
        output_tokens = response.usage.completion_tokens
        self.update_stats(input_tokens, output_tokens)
        return response.choices[0].message.content


class HumanModel(BaseModel):
MODELS = {"human": {}}
Expand Down Expand Up @@ -849,5 +921,7 @@ def get_model(args: ModelArguments, commands: Optional[list[Command]] = None):
return OllamaModel(args, commands)
elif args.model_name in TogetherModel.SHORTCUTS:
return TogetherModel(args, commands)
elif args.model_name.startswith("openrouter"):
return OpenRouterModel(args, commands)
else:
raise ValueError(f"Invalid model name: {args.model_name}")