From e5e4bddef3285a98e8e8da2e332e6f0f3cc5e74b Mon Sep 17 00:00:00 2001
From: sheldonhull
Date: Tue, 14 Jan 2025 08:02:50 -0600
Subject: [PATCH] feat: add support for o1 models in openai and azure (#368)

* feat: add support for o1 models in openai and azure

Add support for OpenAI o1 models by using `max_completion_tokens` instead
of `max_tokens`.

* **mods.go**
  - Add a check in the `startCompletionCmd` function to determine whether the
    model is an o1 model and, if so, switch from `max_tokens` to
    `max_completion_tokens`.

* **config.go**
  - Add a new field, `MaxCompletionTokens`, to the `Config` struct to store
    the value for the `max_completion_tokens` parameter.

* **config_template.yml**
  - Add entries for the `o1-preview` and `o1-mini` models, with aliases, under
    the `openai` section, each with `max-input-chars` set to 128000.
  - Add the same entries and aliases under the `azure` section.

* fix: adjust o1 model prefix check and remap system messages to user messages

* fix: max tokens
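For illustration only (not part of the patch): with a go-openai version that
exposes `MaxCompletionTokens` on `ChatCompletionRequest`, a request for an o1
model ends up shaped roughly like the sketch below; the model name and the
literal limit are placeholder values.

```go
package main

import (
	"fmt"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	// o1 models reject the legacy max_tokens parameter, so MaxTokens is
	// left at its zero value (omitted from the JSON body) and
	// MaxCompletionTokens carries the output cap instead.
	req := openai.ChatCompletionRequest{
		Model:               "o1-mini",
		MaxCompletionTokens: 100,
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "hello"},
		},
	}
	fmt.Printf("model=%s max_completion_tokens=%d\n", req.Model, req.MaxCompletionTokens)
}
```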
---------

Co-authored-by: Carlos Alexandro Becker
---
 config.go           | 95 +++++++++++++++++++++++----------------------
 config_template.yml | 22 ++++++++++-
 mods.go             |  8 ++++
 stream.go           | 15 ++++++-
 4 files changed, 91 insertions(+), 49 deletions(-)

diff --git a/config.go b/config.go
index b306aa63..64e612e4 100644
--- a/config.go
+++ b/config.go
@@ -131,53 +131,54 @@ func (ft *FormatText) UnmarshalYAML(unmarshal func(interface{}) error) error {
 
 // Config holds the main configuration and is mapped to the YAML settings file.
 type Config struct {
-	Model             string     `yaml:"default-model" env:"MODEL"`
-	Format            bool       `yaml:"format" env:"FORMAT"`
-	FormatText        FormatText `yaml:"format-text"`
-	FormatAs          string     `yaml:"format-as" env:"FORMAT_AS"`
-	Raw               bool       `yaml:"raw" env:"RAW"`
-	Quiet             bool       `yaml:"quiet" env:"QUIET"`
-	MaxTokens         int        `yaml:"max-tokens" env:"MAX_TOKENS"`
-	MaxInputChars     int        `yaml:"max-input-chars" env:"MAX_INPUT_CHARS"`
-	Temperature       float32    `yaml:"temp" env:"TEMP"`
-	Stop              []string   `yaml:"stop" env:"STOP"`
-	TopP              float32    `yaml:"topp" env:"TOPP"`
-	TopK              int        `yaml:"topk" env:"TOPK"`
-	NoLimit           bool       `yaml:"no-limit" env:"NO_LIMIT"`
-	CachePath         string     `yaml:"cache-path" env:"CACHE_PATH"`
-	NoCache           bool       `yaml:"no-cache" env:"NO_CACHE"`
-	IncludePromptArgs bool       `yaml:"include-prompt-args" env:"INCLUDE_PROMPT_ARGS"`
-	IncludePrompt     int        `yaml:"include-prompt" env:"INCLUDE_PROMPT"`
-	MaxRetries        int        `yaml:"max-retries" env:"MAX_RETRIES"`
-	WordWrap          int        `yaml:"word-wrap" env:"WORD_WRAP"`
-	Fanciness         uint       `yaml:"fanciness" env:"FANCINESS"`
-	StatusText        string     `yaml:"status-text" env:"STATUS_TEXT"`
-	HTTPProxy         string     `yaml:"http-proxy" env:"HTTP_PROXY"`
-	APIs              APIs       `yaml:"apis"`
-	System            string     `yaml:"system"`
-	Role              string     `yaml:"role" env:"ROLE"`
-	AskModel          bool
-	API               string
-	Models            map[string]Model
-	Roles             map[string][]string
-	ShowHelp          bool
-	ResetSettings     bool
-	Prefix            string
-	Version           bool
-	Settings          bool
-	Dirs              bool
-	Theme             string
-	SettingsPath      string
-	ContinueLast      bool
-	Continue          string
-	Title             string
-	ShowLast          bool
-	Show              string
-	List              bool
-	ListRoles         bool
-	Delete            string
-	DeleteOlderThan   time.Duration
-	User              string
+	Model               string     `yaml:"default-model" env:"MODEL"`
+	Format              bool       `yaml:"format" env:"FORMAT"`
+	FormatText          FormatText `yaml:"format-text"`
+	FormatAs            string     `yaml:"format-as" env:"FORMAT_AS"`
+	Raw                 bool       `yaml:"raw" env:"RAW"`
+	Quiet               bool       `yaml:"quiet" env:"QUIET"`
+	MaxTokens           int        `yaml:"max-tokens" env:"MAX_TOKENS"`
+	MaxCompletionTokens int        `yaml:"max-completion-tokens" env:"MAX_COMPLETION_TOKENS"`
+	MaxInputChars       int        `yaml:"max-input-chars" env:"MAX_INPUT_CHARS"`
+	Temperature         float32    `yaml:"temp" env:"TEMP"`
+	Stop                []string   `yaml:"stop" env:"STOP"`
+	TopP                float32    `yaml:"topp" env:"TOPP"`
+	TopK                int        `yaml:"topk" env:"TOPK"`
+	NoLimit             bool       `yaml:"no-limit" env:"NO_LIMIT"`
+	CachePath           string     `yaml:"cache-path" env:"CACHE_PATH"`
+	NoCache             bool       `yaml:"no-cache" env:"NO_CACHE"`
+	IncludePromptArgs   bool       `yaml:"include-prompt-args" env:"INCLUDE_PROMPT_ARGS"`
+	IncludePrompt       int        `yaml:"include-prompt" env:"INCLUDE_PROMPT"`
+	MaxRetries          int        `yaml:"max-retries" env:"MAX_RETRIES"`
+	WordWrap            int        `yaml:"word-wrap" env:"WORD_WRAP"`
+	Fanciness           uint       `yaml:"fanciness" env:"FANCINESS"`
+	StatusText          string     `yaml:"status-text" env:"STATUS_TEXT"`
+	HTTPProxy           string     `yaml:"http-proxy" env:"HTTP_PROXY"`
+	APIs                APIs       `yaml:"apis"`
+	System              string     `yaml:"system"`
+	Role                string     `yaml:"role" env:"ROLE"`
+	AskModel            bool
+	API                 string
+	Models              map[string]Model
+	Roles               map[string][]string
+	ShowHelp            bool
+	ResetSettings       bool
+	Prefix              string
+	Version             bool
+	Settings            bool
+	Dirs                bool
+	Theme               string
+	SettingsPath        string
+	ContinueLast        bool
+	Continue            string
+	Title               string
+	ShowLast            bool
+	Show                string
+	List                bool
+	ListRoles           bool
+	Delete              string
+	DeleteOlderThan     time.Duration
+	User                string
 	cacheReadFromID, cacheWriteToID, cacheWriteToTitle string
 }
 
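A standalone sketch (not part of the patch) of what the new `yaml` struct tag
does: it maps the `max-completion-tokens` settings key onto the new field.
`miniConfig` here is a hypothetical, trimmed-down stand-in for the real
`Config` struct above.

```go
package main

import (
	"fmt"

	"gopkg.in/yaml.v3"
)

// miniConfig keeps only the two token-limit fields this patch touches.
type miniConfig struct {
	MaxTokens           int `yaml:"max-tokens"`
	MaxCompletionTokens int `yaml:"max-completion-tokens"`
}

func main() {
	raw := []byte("max-completion-tokens: 100\n")
	var c miniConfig
	if err := yaml.Unmarshal(raw, &c); err != nil {
		panic(err)
	}
	// MaxTokens stays 0 here, mirroring how o1 requests omit max_tokens.
	fmt.Printf("max-tokens=%d max-completion-tokens=%d\n", c.MaxTokens, c.MaxCompletionTokens)
}
```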
diff --git a/config_template.yml b/config_template.yml
index 5738dae8..59b9f111 100644
--- a/config_template.yml
+++ b/config_template.yml
@@ -47,6 +47,8 @@ theme: charm
 max-input-chars: 12250
 # {{ index .Help "max-tokens" }}
 # max-tokens: 100
+# {{ index .Help "max-completion-tokens" }}
+max-completion-tokens: 100
 # {{ index .Help "apis" }}
 apis:
   openai:
@@ -91,6 +93,12 @@ apis:
         aliases: ["35"]
         max-input-chars: 12250
         fallback:
+      o1-preview:
+        aliases: ["o1-preview"]
+        max-input-chars: 128000
+      o1-mini:
+        aliases: ["o1-mini"]
+        max-input-chars: 128000
   copilot:
     base-url: https://api.githubcopilot.com
     models:
@@ -112,6 +120,12 @@ apis:
       claude-3.5-sonnet:
         aliases: ["claude3.5-sonnet", "sonnet-3.5", "claude-3-5-sonnet"]
         max-input-chars: 680000
+      o1-preview:
+        aliases: ["o1-preview"]
+        max-input-chars: 128000
+      o1-mini:
+        aliases: ["o1-mini"]
+        max-input-chars: 128000
   anthropic:
     base-url: https://api.anthropic.com/v1
     api-key:
@@ -158,7 +172,7 @@ apis:
     base-url: https://api.perplexity.ai
     api-key:
     api-key-env: PERPLEXITY_API_KEY
-    models:  # https://docs.perplexity.ai/guides/model-cards
+    models: # https://docs.perplexity.ai/guides/model-cards
       llama-3.1-sonar-small-128k-online:
         aliases: ["llam31-small"]
         max-input-chars: 127072
@@ -265,6 +279,12 @@ apis:
         aliases: ["az35"]
         max-input-chars: 12250
         fallback:
+      o1-preview:
+        aliases: ["o1-preview"]
+        max-input-chars: 128000
+      o1-mini:
+        aliases: ["o1-mini"]
+        max-input-chars: 128000
   runpod:
     # https://docs.runpod.io/serverless/workers/vllm/openai-compatibility
     base-url: https://api.runpod.ai/v2/${YOUR_ENDPOINT}/openai/v1
diff --git a/mods.go b/mods.go
index 950277d9..f25eda8f 100644
--- a/mods.go
+++ b/mods.go
@@ -395,6 +395,14 @@ func (m *Mods) startCompletionCmd(content string) tea.Cmd {
 		mod.MaxChars = cfg.MaxInputChars
 	}
 
+	// Check if the model is an o1 model and, if so, unset the max_tokens
+	// parameter, which o1 models do not support; max_completion_tokens is
+	// used for them instead. Released o1 model names have no dashed
+	// prefix, so matching on the "o1" prefix is sufficient.
+	if strings.HasPrefix(mod.Name, "o1") {
+		cfg.MaxTokens = 0
+	}
+
 	switch mod.API {
 	case "anthropic":
 		return m.createAnthropicStream(content, accfg, mod)
diff --git a/stream.go b/stream.go
index de713e97..8143462d 100644
--- a/stream.go
+++ b/stream.go
@@ -21,9 +21,22 @@ func (m *Mods) createOpenAIStream(content string, ccfg openai.ClientConfig, mod
 		return err
 	}
 
+	// Remap system messages to user messages; the o1 beta rejects the system role.
+	messages := []openai.ChatCompletionMessage{}
+	for _, message := range m.messages {
+		if message.Role == openai.ChatMessageRoleSystem {
+			messages = append(messages, openai.ChatCompletionMessage{
+				Role:    openai.ChatMessageRoleUser,
+				Content: message.Content,
+			})
+		} else {
+			messages = append(messages, message)
+		}
+	}
+
 	req := openai.ChatCompletionRequest{
 		Model:    mod.Name,
-		Messages: m.messages,
+		Messages: messages,
 		Stream:   true,
 		User:     cfg.User,
 	}
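Usage note (not part of the patch): once applied, an o1 model is selected like
any other configured model, e.g. `mods -m o1-mini "write a haiku"` via the
existing `--model` flag, and the new `max-completion-tokens` settings key (or
the `MAX_COMPLETION_TOKENS` environment variable named in the struct tag) caps
output length for models that reject `max_tokens`.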