feat: add support for o1 models in openai and azure (#368)
* feat: add support for o1 models in openai and azure

Add support for OpenAI o1 models by using `max_completion_tokens` instead of `max_tokens`; a short request-shape sketch follows the file list below.

* **mods.go**
  - Add a check in the `startCompletionCmd` function to determine if the model is an o1 model and set the `max_completion_tokens` parameter accordingly.

* **config.go**
  - Add a new field `MaxCompletionTokens` to the `Config` struct to store the value for the `max_completion_tokens` parameter.

* **config_template.yml**
  - Add entries for `o1-preview` and `o1-mini` models under the `openai` section with `max-input-chars` set to 128000.
  - Add aliases for `o1-preview` and `o1-mini` models.
  - Add entries for `o1-preview` and `o1-mini` models under the `azure` section with `max-input-chars` set to 128000.
  - Add aliases for `o1-preview` and `o1-mini` models under the `azure` section.
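
For reference, the parameter swap looks roughly like this with the sashabaranov/go-openai client that mods builds on — a minimal sketch, assuming a client version recent enough to expose `MaxCompletionTokens`; model names and values are illustrative:

```go
package main

import (
	"fmt"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	msgs := []openai.ChatCompletionMessage{
		{Role: openai.ChatMessageRoleUser, Content: "hello"},
	}

	// Standard models cap output via max_tokens.
	std := openai.ChatCompletionRequest{
		Model:     openai.GPT4,
		Messages:  msgs,
		MaxTokens: 100, // serialized as "max_tokens"
	}

	// o1 models reject max_tokens; max_completion_tokens is the supported cap.
	o1 := openai.ChatCompletionRequest{
		Model:               "o1-mini",
		Messages:            msgs,
		MaxCompletionTokens: 100, // serialized as "max_completion_tokens"
	}

	fmt.Println(std.MaxTokens, o1.MaxCompletionTokens)
}
```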

* fix: adjust o1 model prefix check and remap system messages to user messages

* fix: max tokens

---------

Co-authored-by: Carlos Alexandro Becker <[email protected]>
sheldonhull and caarlos0 authored Jan 14, 2025
1 parent b52d41e commit e5e4bdd
Showing 4 changed files with 91 additions and 49 deletions.
95 changes: 48 additions & 47 deletions config.go
@@ -131,53 +131,54 @@ func (ft *FormatText) UnmarshalYAML(unmarshal func(interface{}) error) error {
(shown post-change; the paired deletions are the same fields, re-aligned for the longer new field name)

// Config holds the main configuration and is mapped to the YAML settings file.
type Config struct {
	Model               string     `yaml:"default-model" env:"MODEL"`
	Format              bool       `yaml:"format" env:"FORMAT"`
	FormatText          FormatText `yaml:"format-text"`
	FormatAs            string     `yaml:"format-as" env:"FORMAT_AS"`
	Raw                 bool       `yaml:"raw" env:"RAW"`
	Quiet               bool       `yaml:"quiet" env:"QUIET"`
	MaxTokens           int        `yaml:"max-tokens" env:"MAX_TOKENS"`
	MaxCompletionTokens int        `yaml:"max-completion-tokens" env:"MAX_COMPLETION_TOKENS"`
	MaxInputChars       int        `yaml:"max-input-chars" env:"MAX_INPUT_CHARS"`
	Temperature         float32    `yaml:"temp" env:"TEMP"`
	Stop                []string   `yaml:"stop" env:"STOP"`
	TopP                float32    `yaml:"topp" env:"TOPP"`
	TopK                int        `yaml:"topk" env:"TOPK"`
	NoLimit             bool       `yaml:"no-limit" env:"NO_LIMIT"`
	CachePath           string     `yaml:"cache-path" env:"CACHE_PATH"`
	NoCache             bool       `yaml:"no-cache" env:"NO_CACHE"`
	IncludePromptArgs   bool       `yaml:"include-prompt-args" env:"INCLUDE_PROMPT_ARGS"`
	IncludePrompt       int        `yaml:"include-prompt" env:"INCLUDE_PROMPT"`
	MaxRetries          int        `yaml:"max-retries" env:"MAX_RETRIES"`
	WordWrap            int        `yaml:"word-wrap" env:"WORD_WRAP"`
	Fanciness           uint       `yaml:"fanciness" env:"FANCINESS"`
	StatusText          string     `yaml:"status-text" env:"STATUS_TEXT"`
	HTTPProxy           string     `yaml:"http-proxy" env:"HTTP_PROXY"`
	APIs                APIs       `yaml:"apis"`
	System              string     `yaml:"system"`
	Role                string     `yaml:"role" env:"ROLE"`
	AskModel            bool
	API                 string
	Models              map[string]Model
	Roles               map[string][]string
	ShowHelp            bool
	ResetSettings       bool
	Prefix              string
	Version             bool
	Settings            bool
	Dirs                bool
	Theme               string
	SettingsPath        string
	ContinueLast        bool
	Continue            string
	Title               string
	ShowLast            bool
	Show                string
	List                bool
	ListRoles           bool
	Delete              string
	DeleteOlderThan     time.Duration
	User                string

	cacheReadFromID, cacheWriteToID, cacheWriteToTitle string
}
22 changes: 21 additions & 1 deletion config_template.yml
@@ -47,6 +47,8 @@ theme: charm
max-input-chars: 12250
# {{ index .Help "max-tokens" }}
# max-tokens: 100
# {{ index .Help "max-completion-tokens" }}
max-completion-tokens: 100
# {{ index .Help "apis" }}
apis:
  openai:
@@ -91,6 +93,12 @@ apis:
aliases: ["35"]
max-input-chars: 12250
fallback:
o1-preview:
aliases: ["o1-preview"]
max-input-chars: 128000
o1-mini:
aliases: ["o1-mini"]
max-input-chars: 128000
  copilot:
    base-url: https://api.githubcopilot.com
    models:
@@ -112,6 +120,12 @@ apis:
      claude-3.5-sonnet:
        aliases: ["claude3.5-sonnet", "sonnet-3.5", "claude-3-5-sonnet"]
        max-input-chars: 680000
      o1-preview:
        aliases: ["o1-preview"]
        max-input-chars: 128000
      o1-mini:
        aliases: ["o1-mini"]
        max-input-chars: 128000
  anthropic:
    base-url: https://api.anthropic.com/v1
    api-key:
@@ -158,7 +172,7 @@ apis:
    base-url: https://api.perplexity.ai
    api-key:
    api-key-env: PERPLEXITY_API_KEY
    models: # https://docs.perplexity.ai/guides/model-cards
      llama-3.1-sonar-small-128k-online:
        aliases: ["llam31-small"]
        max-input-chars: 127072
@@ -265,6 +279,12 @@ apis:
aliases: ["az35"]
max-input-chars: 12250
fallback:
o1-preview:
aliases: ["o1-preview"]
max-input-chars: 128000
o1-mini:
aliases: ["o1-mini"]
max-input-chars: 128000
  runpod:
    # https://docs.runpod.io/serverless/workers/vllm/openai-compatibility
    base-url: https://api.runpod.ai/v2/${YOUR_ENDPOINT}/openai/v1
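
With these template entries in place, the o1 models resolve like any other: pass the model name or one of its aliases to mods' `-m`/`--model` flag, or set it as `default-model`. The `max-input-chars: 128000` values mirror the o1 family's 128k context window, using the template's usual character-count approximation.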
8 changes: 8 additions & 0 deletions mods.go
@@ -395,6 +395,14 @@ func (m *Mods) startCompletionCmd(content string) tea.Cmd {
		mod.MaxChars = cfg.MaxInputChars
	}

	// o1 models don't support max_tokens, so unset it here and rely on
	// max_completion_tokens, which o1 does support. Released o1 model names
	// carry no dashed prefix, so matching on the "o1" prefix is enough.
	if strings.HasPrefix(mod.Name, "o1") {
		cfg.MaxTokens = 0
	}

	switch mod.API {
	case "anthropic":
		return m.createAnthropicStream(content, accfg, mod)
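
Downstream of this check, the zeroed `MaxTokens` simply drops out of the wire request: both cap fields are tagged `omitempty` in go-openai's `ChatCompletionRequest`. A sketch of the interplay, assuming the request builder forwards both config fields (not the commit's exact code):

```go
// With cfg.MaxTokens zeroed for o1, max_tokens is omitted from the JSON body
// and only max_completion_tokens is sent.
req := openai.ChatCompletionRequest{
	Model:               mod.Name,
	MaxTokens:           cfg.MaxTokens,           // 0 for o1 → dropped via omitempty
	MaxCompletionTokens: cfg.MaxCompletionTokens, // the cap o1 actually honors
}
```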
15 changes: 14 additions & 1 deletion stream.go
@@ -21,9 +21,22 @@ func (m *Mods) createOpenAIStream(content string, ccfg openai.ClientConfig, mod
		return err
	}

	// Remap system messages to user messages due to o1 beta limitations:
	// the o1 models reject the system role.
	messages := []openai.ChatCompletionMessage{}
	for _, message := range m.messages {
		if message.Role == openai.ChatMessageRoleSystem {
			messages = append(messages, openai.ChatCompletionMessage{
				Role:    openai.ChatMessageRoleUser,
				Content: message.Content,
			})
		} else {
			messages = append(messages, message)
		}
	}

	req := openai.ChatCompletionRequest{
		Model:    mod.Name,
		Messages: messages,
		Stream:   true,
		User:     cfg.User,
	}
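
The remapping loop above is functionally equivalent to this standalone helper for role and content — `remapForO1` is a hypothetical name, not part of the commit; shown only to make the transformation easy to test in isolation:

```go
// remapForO1 rewrites system messages as user messages, mirroring the loop
// in createOpenAIStream. Hypothetical helper for illustration.
func remapForO1(in []openai.ChatCompletionMessage) []openai.ChatCompletionMessage {
	out := make([]openai.ChatCompletionMessage, 0, len(in))
	for _, msg := range in {
		if msg.Role == openai.ChatMessageRoleSystem {
			msg.Role = openai.ChatMessageRoleUser // msg is a copy; input is untouched
		}
		out = append(out, msg)
	}
	return out
}
```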
