Skip to content

Commit

Permalink
change model dropdown
Browse files Browse the repository at this point in the history
  • Loading branch information
WHALEEYE committed Oct 29, 2024
1 parent 293804e commit 98926f4
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions gui/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,24 @@
import customtkinter as ctk

from crab import Experiment
from crab.agents.backend_models import OpenAIModel, ClaudeModel, GeminiModel
from crab.agents.backend_models import ClaudeModel, GeminiModel, OpenAIModel
from crab.agents.policies import SingleAgentPolicy
from gui.utils import get_benchmark

warnings.filterwarnings("ignore")

AVAILABLE_MODELS = {
"gpt-4o": ("OpenAIModel", "gpt-4o"),
"gpt-4turbo": ("OpenAIModel", "gpt-4-turbo"),
"gemini": ("GeminiModel", "gemini-1.5-pro-latest"),
"claude": ("ClaudeModel", "claude-3-opus-20240229"),
"GPT-4o": ("OpenAIModel", "gpt-4o"),
"GPT-4 Turbo": ("OpenAIModel", "gpt-4-turbo"),
"Gemini": ("GeminiModel", "gemini-1.5-pro-latest"),
"Claude": ("ClaudeModel", "claude-3-opus-20240229"),
}


def get_model_instance(model_key: str):
if model_key not in AVAILABLE_MODELS:
raise ValueError(f"Model {model_key} not supported")

model_config = AVAILABLE_MODELS[model_key]
model_class_name = model_config[0]
model_name = model_config[1]
Expand All @@ -46,12 +47,13 @@ def get_model_instance(model_key: str):
elif model_class_name == "ClaudeModel":
return ClaudeModel(model=model_name, history_messages_len=2)


def assign_task():
task_description = input_entry.get()
input_entry.delete(0, "end")
display_message(task_description)

model = get_model_instance(selected_model.get())
model = get_model_instance(model_dropdown.get())
agent_policy = SingleAgentPolicy(model_backend=model)

task_id = str(uuid4())
Expand Down Expand Up @@ -80,7 +82,8 @@ def display_message(message, sender="user"):


if __name__ == "__main__":
# TODO: Handle JSON decode error from environment action endpoint and display model response in GUI
# TODO: Handle JSON decode error from environment action endpoint and
# display model response in GUI
log_dir = (Path(__file__).parent / "logs").resolve()

ctk.set_appearance_mode("System")
Expand All @@ -93,15 +96,14 @@ def display_message(message, sender="user"):
model_frame = ctk.CTkFrame(app)
model_frame.pack(pady=10, padx=10, fill="x")

model_label = ctk.CTkLabel(model_frame, text="Select Model:")
model_label = ctk.CTkLabel(model_frame, text="Model")
model_label.pack(side="left", padx=(0, 10))

selected_model = ctk.StringVar(value="gpt-4o")
model_dropdown = ctk.CTkOptionMenu(
model_frame,
values=list(AVAILABLE_MODELS.keys()),
variable=selected_model,
)
model_dropdown.set(next(iter(AVAILABLE_MODELS)))
model_dropdown.pack(side="left", fill="x", expand=True)

chat_display_frame = ctk.CTkFrame(app, width=380, height=380)
Expand Down

0 comments on commit 98926f4

Please sign in to comment.