Commit d1ddc8d

llm: validate user-defined model data and better typing/bugfix
1 parent 1534e6e commit d1ddc8d

10 files changed: +404 -18 lines

src/packages/frontend/account/user-defined-llm.tsx

Lines changed: 22 additions & 0 deletions
@@ -4,6 +4,7 @@ import {
   Flex,
   Form,
   Input,
+  InputNumber,
   List,
   Modal,
   Popconfirm,
@@ -33,6 +34,7 @@ import { LanguageModelVendorAvatar } from "@cocalc/frontend/components/language-
 import { webapp_client } from "@cocalc/frontend/webapp-client";
 import { OTHER_SETTINGS_USER_DEFINED_LLM as KEY } from "@cocalc/util/db-schema/defaults";
 import {
+  FALLBACK_MAX_TOKENS,
   LLM_PROVIDER,
   SERVICES,
   UserDefinedLLM,
@@ -350,6 +352,26 @@ export function UserDefinedLLMComponent({ style, on_change }: Props) {
         >
           <Input />
         </Form.Item>
+        <Form.Item
+          label="Max Tokens"
+          name="max_tokens"
+          help={`Context window size in tokens. Leave empty to use default (${FALLBACK_MAX_TOKENS}). Valid range: 1000-2000000.`}
+          rules={[
+            {
+              type: "number",
+              min: 1000,
+              max: 2000000,
+              message: "Must be between 1000 and 2000000",
+            },
+          ]}
+        >
+          <InputNumber
+            min={1000}
+            max={2000000}
+            placeholder={`${FALLBACK_MAX_TOKENS} (default)`}
+            style={{ width: "100%" }}
+          />
+        </Form.Item>
       </Form>
     </Modal>
   );
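The new Max Tokens field validates the user-supplied context window (1000-2000000; empty falls back to FALLBACK_MAX_TOKENS). Below is a minimal sketch of the same bounds check applied programmatically, e.g. before persisting a user-defined model; the helper name normalizeMaxTokens and its fallback argument are illustrative, not part of this commit:

// Hypothetical helper mirroring the Form.Item rules above (illustration only).
const MIN_MAX_TOKENS = 1000;
const MAX_MAX_TOKENS = 2000000;

function normalizeMaxTokens(value: unknown, fallback: number): number {
  // Empty input falls back to the default, matching the form's "leave empty" behavior.
  if (value == null || value === "") {
    return fallback;
  }
  const n = typeof value === "number" ? value : Number(value);
  if (!Number.isFinite(n) || n < MIN_MAX_TOKENS || n > MAX_MAX_TOKENS) {
    throw new Error(`max_tokens must be between ${MIN_MAX_TOKENS} and ${MAX_MAX_TOKENS}`);
  }
  return Math.floor(n);
}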

src/packages/frontend/misc/llm.ts

Lines changed: 11 additions & 2 deletions
@@ -9,7 +9,11 @@ import { estimateTokenCount, sliceByTokens } from "tokenx";

 import type { History } from "@cocalc/frontend/client/types";
 import type { LanguageModel } from "@cocalc/util/db-schema/llm-utils";
-import { getMaxTokens } from "@cocalc/util/db-schema/llm-utils";
+import {
+  getMaxTokens,
+  isUserDefinedModel,
+} from "@cocalc/util/db-schema/llm-utils";
+import { getUserDefinedLLMByModel } from "@cocalc/frontend/frame-editors/llm/use-userdefined-llm";

 import { timed } from "./timing";

@@ -95,7 +99,12 @@ const truncateHistoryImpl = (
   if (maxTokens <= 0) {
     return [];
   }
-  const modelMaxTokens = getMaxTokens(model);
+  // Try to get user-defined config if this is a user model
+  const userConfig = isUserDefinedModel(model)
+    ? getUserDefinedLLMByModel(model)
+    : null;
+
+  const modelMaxTokens = getMaxTokens(model, userConfig ?? undefined);
   const maxLength = modelMaxTokens * APPROX_CHARACTERS_PER_TOKEN;
   for (let i = 0; i < history.length; i++) {
     // Performance: ensure all entries in history are reasonably short, so they don't
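History truncation now resolves the context window from the user-defined config when the model is user-defined. A minimal sketch of the assumed behavior of the two-argument getMaxTokens call follows; the real implementation lives in @cocalc/util/db-schema/llm-utils, and the config shape and fallback value shown here are assumptions for illustration:

// Sketch only: prefer a user-configured max_tokens, otherwise fall back.
interface UserLLMConfigSketch {
  max_tokens?: number;
}

const FALLBACK_MAX_TOKENS_SKETCH = 8000; // placeholder value for illustration

function getMaxTokensSketch(model: string, config?: UserLLMConfigSketch): number {
  if (config?.max_tokens != null && config.max_tokens > 0) {
    return config.max_tokens; // user-defined context window wins
  }
  // otherwise: look up the per-model default for `model` (lookup omitted) or fall back
  return FALLBACK_MAX_TOKENS_SKETCH;
}
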
Lines changed: 9 additions & 0 deletions
@@ -1,2 +1,11 @@
 require("@testing-library/jest-dom");
 process.env.COCALC_TEST_MODE = true;
+
+// Polyfill TextEncoder and TextDecoder for Jest/jsdom environment
+// These are needed by @msgpack/msgpack and other libraries
+const { TextEncoder, TextDecoder } = require("util");
+global.TextEncoder = TextEncoder;
+global.TextDecoder = TextDecoder;
+
+// Define DEBUG global (normally provided by rspack in production)
+global.DEBUG = false;

src/packages/server/llm/evaluate-lc.ts

Lines changed: 5 additions & 4 deletions
@@ -11,12 +11,14 @@ import { ServerSettings } from "@cocalc/database/settings/server-settings";
 import {
   ANTHROPIC_VERSION,
   AnthropicModel,
+  FALLBACK_MAX_TOKENS,
   fromCustomOpenAIModel,
   GOOGLE_MODEL_TO_ID,
   GoogleModel,
   isAnthropicModel,
   isCustomOpenAI,
   isGoogleModel,
+  isGoogleThinkingModel,
   isMistralModel,
   isOpenAIModel,
   isXaiModel,
@@ -172,10 +174,9 @@ export const PROVIDER_CONFIGS = {
     return new ChatGoogleGenerativeAI({
       model: modelName,
       apiKey,
-      maxOutputTokens: options.maxTokens,
-      // Only enable thinking tokens for Gemini 2.5 models
-      ...(modelName === "gemini-2.5-flash" || modelName === "gemini-2.5-pro"
-        ? { maxReasoningTokens: 1024 }
+      // Enable thinking tokens for Gemini 2.5+ models
+      ...(isGoogleThinkingModel(modelName)
+        ? { maxReasoningTokens: FALLBACK_MAX_TOKENS }
         : {}),
       streaming: options.stream != null,
     });
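The hard-coded Gemini 2.5 name check is replaced by the isGoogleThinkingModel predicate, and the reasoning-token budget now uses FALLBACK_MAX_TOKENS instead of a fixed 1024. A rough sketch of what such a predicate could look like; this is an assumption, the real helper lives in @cocalc/util/db-schema/llm-utils and may differ:

// Assumed sketch: treat Gemini 2.5 and newer as "thinking" models.
function isGoogleThinkingModelSketch(model: string): boolean {
  const match = model.match(/^gemini-(\d+)\.(\d+)/);
  if (match == null) {
    return false;
  }
  const major = Number(match[1]);
  const minor = Number(match[2]);
  return major > 2 || (major === 2 && minor >= 5);
}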

src/packages/server/llm/ollama.ts

Lines changed: 20 additions & 4 deletions
@@ -1,4 +1,4 @@
-import type { Ollama } from "@langchain/ollama";
+import { Ollama } from "@langchain/ollama";
 import {
   ChatPromptTemplate,
   MessagesPlaceholder,
@@ -22,27 +22,43 @@ interface OllamaOpts {
   model: string; // this must be ollama-[model]
   stream?: Stream;
   maxTokens?: number;
+  endpoint?: string; // optional endpoint for user-defined models
 }

 export async function evaluateOllama(
   opts: Readonly<OllamaOpts>,
   client?: Ollama,
 ): Promise<ChatOutput> {
-  if (client == null && !isOllamaLLM(opts.model)) {
+  if (client == null && !isOllamaLLM(opts.model) && !opts.endpoint) {
     throw new Error(`model ${opts.model} not supported`);
   }
   const model = fromOllamaModel(opts.model);
-  const { system, history, input, maxTokens, stream } = opts;
+  const { system, history, input, maxTokens, stream, endpoint } = opts;
   log.debug("evaluateOllama", {
     input,
     history,
     system,
     model,
     stream: stream != null,
     maxTokens,
+    endpoint,
   });

-  const ollama = client ?? (await getOllama(model));
+  // Create Ollama client: use provided client, or create from endpoint, or get from server settings
+  let ollama: Ollama;
+  if (client != null) {
+    ollama = client;
+  } else if (endpoint != null) {
+    // User-defined Ollama model with custom endpoint
+    ollama = new Ollama({
+      baseUrl: endpoint,
+      model,
+      keepAlive: "24h",
+    });
+  } else {
+    // Platform Ollama model from server settings
+    ollama = await getOllama(model);
+  }

   const historyMessagesKey = "history";

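A usage sketch of the new endpoint option; the URL, prompt, and import path are illustrative and assume a caller inside the same package:

// Sketch: evaluate a user-defined Ollama model against a custom endpoint.
import { evaluateOllama } from "./ollama";

async function example(): Promise<void> {
  const output = await evaluateOllama({
    input: "Summarize the release notes.",
    model: "ollama-llama3", // "ollama-" prefix, as produced by toOllamaModel in user-defined.ts
    endpoint: "http://localhost:11434", // bypasses getOllama / server settings
    maxTokens: 32000,
  });
  console.log(output);
}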
src/packages/server/llm/test/mock2.test.ts

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,75 @@ describe("evaluateWithLangChain (LangChain mocked)", () => {
313313
});
314314
});
315315

316+
test("user-defined Ollama with custom max_tokens", async () => {
317+
const ollamaConfig = [
318+
{
319+
id: 1,
320+
service: "ollama",
321+
model: "llama3",
322+
display: "User Llama3",
323+
endpoint: "http://localhost:11434",
324+
apiKey: "",
325+
max_tokens: 32000,
326+
},
327+
];
328+
329+
mockCallback2.mockResolvedValueOnce({
330+
other_settings: {
331+
[OTHER_SETTINGS_USER_DEFINED_LLM]: JSON.stringify(ollamaConfig),
332+
},
333+
});
334+
335+
await evaluateUserDefinedLLM(
336+
{
337+
input: "hello",
338+
model: "user-ollama-llama3",
339+
},
340+
userAccountId,
341+
);
342+
343+
expect(mockOllama).toHaveBeenCalledWith({
344+
baseUrl: "http://localhost:11434",
345+
model: "llama3",
346+
keepAlive: "24h",
347+
});
348+
});
349+
350+
test("user-defined Google with custom max_tokens", async () => {
351+
const googleConfig = [
352+
{
353+
id: 1,
354+
service: "google",
355+
model: "gemini-2.5-flash",
356+
display: "User Gemini Flash",
357+
endpoint: "",
358+
apiKey: "user-google-key",
359+
max_tokens: 128000,
360+
},
361+
];
362+
363+
mockCallback2.mockResolvedValueOnce({
364+
other_settings: {
365+
[OTHER_SETTINGS_USER_DEFINED_LLM]: JSON.stringify(googleConfig),
366+
},
367+
});
368+
369+
await evaluateUserDefinedLLM(
370+
{
371+
input: "hi",
372+
model: "user-google-gemini-2.5-flash",
373+
},
374+
userAccountId,
375+
);
376+
377+
expect(mockChatGoogle).toHaveBeenCalledWith(
378+
expect.objectContaining({
379+
apiKey: "user-google-key",
380+
model: "gemini-2.5-flash",
381+
}),
382+
);
383+
});
384+
316385
test("ollama streams with configured model", async () => {
317386
streamChunks = ["hi", " there"];
318387
const stream = jest.fn();
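These assertions rely on module mocks for the LangChain clients that are set up earlier in mock2.test.ts. A sketch of the kind of wiring assumed here (the instance stubs are simplified; the real mocks also provide streaming behavior):

// Assumed mock wiring: replace the constructors so tests can assert on their options.
jest.mock("@langchain/ollama", () => ({
  Ollama: jest.fn().mockImplementation(() => ({})),
}));
jest.mock("@langchain/google-genai", () => ({
  ChatGoogleGenerativeAI: jest.fn().mockImplementation(() => ({})),
}));

import { Ollama } from "@langchain/ollama";
import { ChatGoogleGenerativeAI } from "@langchain/google-genai";

const mockOllama = Ollama as unknown as jest.Mock;
const mockChatGoogle = ChatGoogleGenerativeAI as unknown as jest.Mock;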

src/packages/server/llm/test/models.test.ts

Lines changed: 16 additions & 0 deletions
@@ -260,4 +260,20 @@ test_llm("user")("User-defined LLMs", () => {
     },
     LLM_TIMEOUT,
   );
+
+  // Test user-defined model with custom max_tokens
+  test_llm_case("google")(
+    "user-defined model with custom max_tokens (requires COCALC_TEST_GOOGLE_GENAI_KEY)",
+    async () => {
+      await testUserDefinedLLM({
+        service: "google",
+        display: "Test Gemini Flash with custom max_tokens",
+        endpoint: "",
+        model: "gemini-2.5-flash",
+        apiKey: process.env.COCALC_TEST_GOOGLE_GENAI_KEY!,
+        max_tokens: 128000, // Custom large context window
+      });
+    },
+    LLM_TIMEOUT,
+  );
 });

src/packages/server/llm/user-defined.ts

Lines changed: 64 additions & 3 deletions
@@ -17,6 +17,14 @@ import { evaluateWithLangChain } from "./evaluate-lc";

 const log = getLogger("llm:userdefined");

+const REDACTED_VALUE = "[redacted]";
+const SENSITIVE_KEYS = new Set([
+  "apiKey",
+  "openAIApiKey",
+  "azureOpenAIApiKey",
+  "api_key",
+]);
+
 interface UserDefinedOpts {
   input: string; // new input that user types
   system?: string; // extra setup that we add for relevance and context
@@ -30,7 +38,7 @@ export async function evaluateUserDefinedLLM(
   opts: Readonly<UserDefinedOpts>,
   account_id?: string,
 ) {
-  log.debug(`evaluateUserDefinedLLM[${account_id}]`, opts);
+  log.debug(`evaluateUserDefinedLLM[${account_id}]`, redactSensitive(opts));

   const { user_defined_llm } = await getServerSettings();
   if (!user_defined_llm) {
@@ -48,7 +56,7 @@ export async function evaluateUserDefinedLLM(
   }

   const conf = await getConfig(account_id, um.service, um.model);
-  log.debug("conf", conf);
+  log.debug("conf", redactSensitive(conf));
   if (conf == null) {
     throw new Error(`Unable to retrieve user defined model ${model}`);
   }
@@ -61,6 +69,8 @@ export async function evaluateUserDefinedLLM(
       return await evaluateOllama({
         ...opts,
         model: toOllamaModel(conf.model),
+        endpoint,
+        maxTokens: conf.max_tokens,
       });
     }
     case "openai":
@@ -76,6 +86,7 @@ export async function evaluateUserDefinedLLM(
           apiKey,
           endpoint: endpoint || undefined, // don't pass along empty strings!
           service,
+          maxTokens: conf.max_tokens, // Use max_tokens from config
         },
         "user",
       );
@@ -106,8 +117,58 @@ async function getConfig(
       }
     }
   } catch (err) {
-    log.error("Failed to parse user defined llm", user_llm_json, err);
+    log.error(
+      "Failed to parse user defined llm",
+      redactUserLLMJson(user_llm_json),
+      err,
+    );
     throw err;
   }
   return null;
 }
+
+function redactSensitive(value: any): any {
+  if (value == null) {
+    return value;
+  }
+  if (typeof value === "function") {
+    return value;
+  }
+  if (typeof value !== "object") {
+    return value;
+  }
+  if (Array.isArray(value)) {
+    return value.map((item) => redactSensitive(item));
+  }
+  if (value instanceof Date) {
+    return value;
+  }
+  const output: Record<string, any> = {};
+  for (const [key, val] of Object.entries(value)) {
+    output[key] = SENSITIVE_KEYS.has(key)
+      ? REDACTED_VALUE
+      : redactSensitive(val);
+  }
+  return output;
+}
+
+function redactUserLLMJson(value: unknown): unknown {
+  if (typeof value !== "string") {
+    return value;
+  }
+  try {
+    const parsed = JSON.parse(value);
+    return JSON.stringify(redactSensitive(parsed));
+  } catch (_err) {
+    return redactSensitiveString(value);
+  }
+}
+
+function redactSensitiveString(value: string): string {
+  let redacted = value;
+  for (const key of SENSITIVE_KEYS) {
+    const regex = new RegExp(`("${key}"\\s*:\\s*")([^"]*)(")`, "g");
+    redacted = redacted.replace(regex, `$1${REDACTED_VALUE}$3`);
+  }
+  return redacted;
+}
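A usage example of the redaction helpers added above (the config object and values are illustrative):

// redactSensitive walks objects/arrays recursively and replaces any key listed
// in SENSITIVE_KEYS with "[redacted]"; all other values pass through unchanged.
const conf = {
  service: "google",
  model: "gemini-2.5-flash",
  apiKey: "sk-secret",
  nested: [{ api_key: "also-secret", max_tokens: 128000 }],
};

console.log(redactSensitive(conf));
// -> { service: "google", model: "gemini-2.5-flash", apiKey: "[redacted]",
//      nested: [ { api_key: "[redacted]", max_tokens: 128000 } ] }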
