Skip to content

Commit

Permalink
828 streamline change chat model request (#883)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmarsh-scottlogic authored Apr 2, 2024
1 parent c00b4ee commit 4f31e75
Show file tree
Hide file tree
Showing 16 changed files with 175 additions and 84 deletions.
41 changes: 32 additions & 9 deletions backend/src/controller/modelController.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@ import { Response } from 'express';

import { OpenAiConfigureModelRequest } from '@src/models/api/OpenAiConfigureModelRequest';
import { OpenAiSetModelRequest } from '@src/models/api/OpenAiSetModelRequest';
import { MODEL_CONFIG_ID, modelConfigIds } from '@src/models/chat';
import {
MODEL_CONFIG_ID,
chatModelIds,
modelConfigIds,
} from '@src/models/chat';
import { ChatInfoMessage } from '@src/models/chatMessage';
import { LEVEL_NAMES } from '@src/models/level';
import { pushMessageToHistory } from '@src/utils/chat';
Expand All @@ -13,14 +17,32 @@ function handleSetModel(req: OpenAiSetModelRequest, res: Response) {
const { model } = req.body;

if (model === undefined) {
res.status(400).send();
} else {
const configuration =
req.body.configuration ?? req.session.chatModel.configuration;
req.session.chatModel = { id: model, configuration };
console.debug('GPT model set:', JSON.stringify(req.session.chatModel));
res.status(200).send();
sendErrorResponse(res, 400, 'Missing model');
return;
}

if (!chatModelIds.includes(model)) {
sendErrorResponse(res, 400, 'Invalid model');
return;
}

const configuration =
req.body.configuration ?? req.session.chatModel.configuration;
req.session.chatModel = { id: model, configuration };
console.debug('GPT model set:', JSON.stringify(req.session.chatModel));

const chatInfoMessage = {
infoMessage: `changed model to ${model}`,
chatMessageType: 'GENERIC_INFO',
} as ChatInfoMessage;
// for now, the chat model only changes in the sandbox level
req.session.levelState[LEVEL_NAMES.SANDBOX].chatHistory =
pushMessageToHistory(
req.session.levelState[LEVEL_NAMES.SANDBOX].chatHistory,
chatInfoMessage
);

res.send({ chatInfoMessage });
}

function handleConfigureModel(req: OpenAiConfigureModelRequest, res: Response) {
Expand Down Expand Up @@ -59,13 +81,14 @@ function handleConfigureModel(req: OpenAiConfigureModelRequest, res: Response) {
infoMessage: `changed ${configId} to ${value}`,
chatMessageType: 'GENERIC_INFO',
} as ChatInfoMessage;
// for now, the chat model only changes in the sandbox level
req.session.levelState[LEVEL_NAMES.SANDBOX].chatHistory =
pushMessageToHistory(
req.session.levelState[LEVEL_NAMES.SANDBOX].chatHistory,
chatInfoMessage
);

res.status(200).send({ chatInfoMessage });
res.send({ chatInfoMessage });
}

export { handleSetModel, handleConfigureModel };
8 changes: 3 additions & 5 deletions backend/src/langchain.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { OpenAI } from 'langchain/llms/openai';
import { PromptTemplate } from 'langchain/prompts';

import { getDocumentVectors } from './document';
import { CHAT_MODELS } from './models/chat';
import { CHAT_MODEL_ID } from './models/chat';
import { PromptEvaluationChainReply, QaChainReply } from './models/langchain';
import { LEVEL_NAMES } from './models/level';
import { getOpenAIKey, getValidOpenAIModels } from './openai';
Expand All @@ -30,10 +30,8 @@ function makePromptTemplate(
return PromptTemplate.fromTemplate(fullPrompt);
}

function getChatModel() {
return getValidOpenAIModels().includes(CHAT_MODELS.GPT_4)
? CHAT_MODELS.GPT_4
: CHAT_MODELS.GPT_3_5_TURBO;
function getChatModel(): CHAT_MODEL_ID {
return getValidOpenAIModels().includes('gpt-4') ? 'gpt-4' : 'gpt-3.5-turbo';
}

function initQAModel(level: LEVEL_NAMES, Prompt: string) {
Expand Down
9 changes: 6 additions & 3 deletions backend/src/models/api/OpenAiSetModelRequest.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import { Request } from 'express';

import { CHAT_MODELS, ChatModelConfigurations } from '@src/models/chat';
import { CHAT_MODEL_ID, ChatModelConfigurations } from '@src/models/chat';
import { ChatMessage } from '@src/models/chatMessage';

export type OpenAiSetModelRequest = Request<
never,
never,
{
model?: CHAT_MODELS;
chatInfoMessage: ChatMessage;
},
{
model?: CHAT_MODEL_ID;
configuration?: ChatModelConfigurations;
},
never
Expand Down
24 changes: 12 additions & 12 deletions backend/src/models/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,17 @@ import { ChatInfoMessage, ChatMessage } from './chatMessage';
import { DEFENCE_ID } from './defence';
import { EmailInfo } from './email';

enum CHAT_MODELS {
GPT_4_TURBO = 'gpt-4-1106-preview',
GPT_4 = 'gpt-4',
GPT_4_0613 = 'gpt-4-0613',
GPT_3_5_TURBO = 'gpt-3.5-turbo',
GPT_3_5_TURBO_0613 = 'gpt-3.5-turbo-0613',
GPT_3_5_TURBO_16K = 'gpt-3.5-turbo-16k',
GPT_3_5_TURBO_16K_0613 = 'gpt-3.5-turbo-16k-0613',
}
const chatModelIds = [
'gpt-4-1106-preview',
'gpt-4',
'gpt-4-0613',
'gpt-3.5-turbo',
] as const;

type CHAT_MODEL_ID = (typeof chatModelIds)[number];

type ChatModel = {
id: CHAT_MODELS;
id: CHAT_MODEL_ID;
configuration: ChatModelConfigurations;
};

Expand Down Expand Up @@ -106,7 +105,7 @@ interface LevelHandlerResponse {
}

const defaultChatModel: ChatModel = {
id: CHAT_MODELS.GPT_3_5_TURBO,
id: 'gpt-3.5-turbo',
configuration: {
temperature: 1,
topP: 1,
Expand All @@ -131,4 +130,5 @@ export type {
SingleDefenceReport,
MODEL_CONFIG_ID,
};
export { CHAT_MODELS, defaultChatModel, modelConfigIds };
export { defaultChatModel, modelConfigIds, chatModelIds };
export type { CHAT_MODEL_ID };
13 changes: 7 additions & 6 deletions backend/src/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@ import { getQAPromptFromConfig } from './defence';
import { sendEmail } from './email';
import { queryDocuments } from './langchain';
import {
CHAT_MODELS,
CHAT_MODEL_ID,
ChatGptReply,
ChatModel,
ChatResponse,
FunctionCallResponse,
ToolCallResponse,
chatModelIds,
} from './models/chat';
import { ChatMessage } from './models/chatMessage';
import { QaLlmDefence } from './models/defence';
Expand Down Expand Up @@ -84,10 +85,10 @@ const chatGptTools: ChatCompletionTool[] = [

// list of valid chat models for the api key
const validOpenAiModels = (() => {
let validModels: string[] = [];
let validModels: CHAT_MODEL_ID[] = [];
return {
get: () => validModels,
set: (models: string[]) => {
set: (models: CHAT_MODEL_ID[]) => {
validModels = models;
},
};
Expand Down Expand Up @@ -117,8 +118,8 @@ async function getValidModelsFromOpenAI() {

// get the model ids that are supported by our app. Non-chat models like Dall-e and whisper are not supported.
const validModels = models.data
.map((model) => model.id)
.filter((id) => Object.values(CHAT_MODELS).includes(id as CHAT_MODELS))
.map((model) => model.id as CHAT_MODEL_ID)
.filter((id) => chatModelIds.includes(id))
.sort();

validOpenAiModels.set(validModels);
Expand Down Expand Up @@ -283,7 +284,7 @@ async function chatGptChatCompletion(

function getChatCompletionsInContextWindow(
chatHistory: ChatMessage[],
gptModel: CHAT_MODELS
gptModel: CHAT_MODEL_ID
): ChatCompletionMessageParam[] {
const completions = chatHistory
.map((chatMessage) =>
Expand Down
17 changes: 8 additions & 9 deletions backend/src/utils/token.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,17 @@ import {
} from 'openai/resources/chat/completions';
import { promptTokensEstimate, stringTokens } from 'openai-chat-tokens';

import { CHAT_MODELS } from '@src/models/chat';
import { CHAT_MODEL_ID } from '@src/models/chat';
import { chatGptTools } from '@src/openai';

// The size of each model's context window in number of tokens. https://platform.openai.com/docs/models
const chatModelMaxTokens = {
[CHAT_MODELS.GPT_4_TURBO]: 128000,
[CHAT_MODELS.GPT_4]: 8192,
[CHAT_MODELS.GPT_4_0613]: 8192,
[CHAT_MODELS.GPT_3_5_TURBO]: 4097,
[CHAT_MODELS.GPT_3_5_TURBO_0613]: 4097,
[CHAT_MODELS.GPT_3_5_TURBO_16K]: 16385,
[CHAT_MODELS.GPT_3_5_TURBO_16K_0613]: 16385,
const chatModelMaxTokens: {
[key in CHAT_MODEL_ID]: number;
} = {
'gpt-4': 8192,
'gpt-4-1106-preview': 128000,
'gpt-4-0613': 8192,
'gpt-3.5-turbo': 16385,
};

const TOKENS_PER_TOOL_CALL = 4;
Expand Down
4 changes: 2 additions & 2 deletions backend/test/integration/openai.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { expect, jest, test, describe } from '@jest/globals';

import { CHAT_MODELS, ChatModel } from '@src/models/chat';
import { ChatModel } from '@src/models/chat';
import { ChatMessage } from '@src/models/chatMessage';
import { chatGptSendMessage } from '@src/openai';

Expand Down Expand Up @@ -54,7 +54,7 @@ describe('OpenAI Integration Tests', () => {
},
];
const chatModel: ChatModel = {
id: CHAT_MODELS.GPT_4,
id: 'gpt-4',
configuration: {
temperature: 1,
topP: 1,
Expand Down
65 changes: 63 additions & 2 deletions backend/test/unit/controller/modelController.test.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import { expect, jest, test, describe } from '@jest/globals';
import { Response } from 'express';

import { handleConfigureModel } from '@src/controller/modelController';
import {
handleConfigureModel,
handleSetModel,
} from '@src/controller/modelController';
import { OpenAiConfigureModelRequest } from '@src/models/api/OpenAiConfigureModelRequest';
import { OpenAiSetModelRequest } from '@src/models/api/OpenAiSetModelRequest';
import { modelConfigIds } from '@src/models/chat';
import { ChatMessage } from '@src/models/chatMessage';
import { LEVEL_NAMES, LevelState } from '@src/models/level';
Expand All @@ -14,6 +18,64 @@ function responseMock() {
} as unknown as Response;
}

describe('handleSetModel', () => {
test('WHEN passed sensible parameters THEN sets model AND adds info message to chat history AND responds with info message', () => {
const req = {
body: {
model: 'gpt-4',
},
session: {
chatModel: {
id: 'gpt-3.5-turbo',
},
levelState: [{}, {}, {}, { chatHistory: [] } as unknown as LevelState],
},
} as OpenAiSetModelRequest;
const res = responseMock();

handleSetModel(req, res);

expect(req.session.chatModel.id).toBe('gpt-4');

const expectedInfoMessage = {
infoMessage: 'changed model to gpt-4',
chatMessageType: 'GENERIC_INFO',
} as ChatMessage;
expect(
req.session.levelState[LEVEL_NAMES.SANDBOX].chatHistory.at(-1)
).toEqual(expectedInfoMessage);
expect(res.send).toHaveBeenCalledWith({
chatInfoMessage: expectedInfoMessage,
});
});

test('WHEN missing model THEN does not set model', () => {
const req = {
body: {},
} as OpenAiSetModelRequest;
const res = responseMock();

handleSetModel(req, res);

expect(res.status).toHaveBeenCalledWith(400);
expect(res.send).toHaveBeenCalledWith('Missing model');
});

test('WHEN model is invalid THEN does not set model', () => {
const req = {
body: {
model: 'invalid model',
},
} as unknown as OpenAiSetModelRequest;
const res = responseMock();

handleSetModel(req, res);

expect(res.status).toHaveBeenCalledWith(400);
expect(res.send).toHaveBeenCalledWith('Invalid model');
});
});

describe('handleConfigureModel', () => {
test('WHEN passed sensible parameters THEN configures model AND adds info message to chat history AND responds with info message', () => {
const req = {
Expand All @@ -37,7 +99,6 @@ describe('handleConfigureModel', () => {

handleConfigureModel(req, res);

expect(res.status).toHaveBeenCalledWith(200);
expect(req.session.chatModel.configuration.topP).toBe(0.5);

const expectedInfoMessage = {
Expand Down
7 changes: 1 addition & 6 deletions backend/test/unit/openai.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,7 @@ describe('getValidModelsFromOpenAI', () => {
{ id: 'da-vinci-1' },
{ id: 'da-vinci-2' },
];
const expectedValidModels = [
'gpt-3.5-turbo',
'gpt-3.5-turbo-0613',
'gpt-4',
'gpt-4-0613',
];
const expectedValidModels = ['gpt-3.5-turbo', 'gpt-4', 'gpt-4-0613'];

mockListFn.mockResolvedValueOnce({
data: mockModelList,
Expand Down
7 changes: 2 additions & 5 deletions frontend/src/components/ControlPanel/ControlPanel.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import DefenceBox from '@src/components/DefenceBox/DefenceBox';
import ModelBox from '@src/components/ModelBox/ModelBox';
import DetailElement from '@src/components/ThemedButtons/DetailElement';
import ThemedButton from '@src/components/ThemedButtons/ThemedButton';
import { ChatMessage, ChatModel } from '@src/models/chat';
import { CHAT_MODEL_ID, ChatMessage, ChatModel } from '@src/models/chat';
import {
DEFENCE_ID,
DefenceConfigItem,
Expand All @@ -24,13 +24,12 @@ function ControlPanel({
resetDefenceConfiguration,
setDefenceConfiguration,
openDocumentViewer,
addInfoMessage,
addChatMessage,
}: {
currentLevel: LEVEL_NAMES;
defences: Defence[];
chatModel?: ChatModel;
setChatModelId: (modelId: string) => void;
setChatModelId: (modelId: CHAT_MODEL_ID) => void;
chatModelOptions: string[];
toggleDefence: (defence: Defence) => void;
resetDefenceConfiguration: (
Expand All @@ -42,7 +41,6 @@ function ControlPanel({
config: DefenceConfigItem[]
) => Promise<boolean>;
openDocumentViewer: () => void;
addInfoMessage: (message: string) => void;
addChatMessage: (chatMessage: ChatMessage) => void;
}) {
const configurableDefences =
Expand Down Expand Up @@ -101,7 +99,6 @@ function ControlPanel({
chatModel={chatModel}
setChatModelId={setChatModelId}
chatModelOptions={chatModelOptions}
addInfoMessage={addInfoMessage}
addChatMessage={addChatMessage}
/>
)}
Expand Down
Loading

0 comments on commit 4f31e75

Please sign in to comment.