Skip to content

Commit

Permalink
better escape functions
Browse files Browse the repository at this point in the history
  • Loading branch information
mishig25 committed Oct 22, 2024
1 parent 5bc694b commit 77e3ce2
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 88 deletions.
63 changes: 63 additions & 0 deletions packages/tasks/src/snippets/common.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks";

export interface StringifyMessagesOptions {
sep: string;
start: string;
end: string;
attributeKeyQuotes?: boolean;
customContentEscaper?: (str: string) => string;
}

export function stringifyMessages(messages: ChatCompletionInputMessage[], opts: StringifyMessagesOptions): string {
const keyRole = opts.attributeKeyQuotes ? `"role"` : "role";
const keyContent = opts.attributeKeyQuotes ? `"role"` : "role";

const messagesStringified = messages.map(({ role, content }) => {
if (typeof content === "string") {
content = JSON.stringify(content).slice(1, -1);
if (opts.customContentEscaper) {
content = opts.customContentEscaper(content);
}
return `{ ${keyRole}: "${role}", ${keyContent}: "${content}" }`;
} else {
2;
content = content.map(({ image_url, text, type }) => ({
type,
image_url,
...(text ? { text: JSON.stringify(text).slice(1, -1) } : undefined),
}));
content = JSON.stringify(content).slice(1, -1);
if (opts.customContentEscaper) {
content = opts.customContentEscaper(content);
}
return `{ ${keyRole}: "${role}", ${keyContent}: ${content} }`;
}
});

return opts.start + messagesStringified.join(opts.sep) + opts.end;
}

type PartialGenerationParameters = Partial<Pick<GenerationParameters, "temperature" | "max_tokens" | "top_p">>;

export interface StringifyGenerationConfigOptions {
sep: string;
start: string;
end: string;
attributeValueConnector: string;
attributeKeyQuotes?: boolean;
}

export function stringifyGenerationConfig(
config: PartialGenerationParameters,
opts: StringifyGenerationConfigOptions
): string {
const quote = opts.attributeKeyQuotes ? `"` : "";

return (
opts.start +
Object.entries(config)
.map(([key, val]) => `${quote}${key}${quote}${opts.attributeValueConnector}${quote}${val}${quote}`)
.join(opts.sep) +
opts.end
);
}
43 changes: 16 additions & 27 deletions packages/tasks/src/snippets/curl.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
import type { PipelineType } from "../pipelines.js";
import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks/index.js";
import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
import { getModelInputSnippet } from "./inputs.js";
import type {
GenerationConfigFormatter,
GenerationMessagesFormatter,
InferenceSnippet,
ModelDataMinimal,
} from "./types.js";
import type { InferenceSnippet, ModelDataMinimal } from "./types.js";

export const snippetBasic = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => ({
content: `curl https://api-inference.huggingface.co/models/${model.id} \\
Expand All @@ -16,25 +12,6 @@ export const snippetBasic = (model: ModelDataMinimal, accessToken: string): Infe
-H "Authorization: Bearer ${accessToken || `{API_TOKEN}`}"`,
});

const formatGenerationMessages: GenerationMessagesFormatter = ({ messages, sep, start, end }) =>
start +
messages
.map(({ role }) => {
// escape single quotes since single quotes is used to define http post body inside curl requests
// TODO: handle the case below
// content = content?.replace(/'/g, "'\\''");
return `{ "role": "${role}", "content": "test msg" }`;
})
.join(sep) +
end;

const formatGenerationConfig: GenerationConfigFormatter = ({ config, sep, start, end }) =>
start +
Object.entries(config)
.map(([key, val]) => `"${key}": ${val}`)
.join(sep) +
end;

export const snippetTextGeneration = (
model: ModelDataMinimal,
accessToken: string,
Expand Down Expand Up @@ -64,8 +41,20 @@ export const snippetTextGeneration = (
-H 'Content-Type: application/json' \\
--data '{
"model": "${model.id}",
"messages": ${formatGenerationMessages({ messages, sep: ",\n ", start: `[\n `, end: `\n]` })},
${formatGenerationConfig({ config, sep: ",\n ", start: "", end: "" })},
"messages": ${stringifyMessages(messages, {
sep: ",\n ",
start: `[\n `,
end: `\n]`,
attributeKeyQuotes: true,
customContentEscaper: (str) => str.replace(/'/g, "'\\''"),
})},
${stringifyGenerationConfig(config, {
sep: ",\n ",
start: "",
end: "",
attributeKeyQuotes: true,
attributeValueConnector: ": ",
})},
"stream": ${!!streaming}
}'`,
};
Expand Down
27 changes: 9 additions & 18 deletions packages/tasks/src/snippets/js.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
import type { PipelineType } from "../pipelines.js";
import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks/index.js";
import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
import { getModelInputSnippet } from "./inputs.js";
import type {
GenerationConfigFormatter,
GenerationMessagesFormatter,
InferenceSnippet,
ModelDataMinimal,
} from "./types.js";
import type { InferenceSnippet, ModelDataMinimal } from "./types.js";

export const snippetBasic = (model: ModelDataMinimal, accessToken: string): InferenceSnippet => ({
content: `async function query(data) {
Expand All @@ -30,16 +26,6 @@ query({"inputs": ${getModelInputSnippet(model)}}).then((response) => {
});`,
});

const formatGenerationMessages: GenerationMessagesFormatter = ({ messages, sep, start, end }) =>
start + messages.map(({ role, content }) => `{ role: "${role}", content: "${content}" }`).join(sep) + end;

const formatGenerationConfig: GenerationConfigFormatter = ({ config, sep, start, end }) =>
start +
Object.entries(config)
.map(([key, val]) => `${key}: ${val}`)
.join(sep) +
end;

export const snippetTextGeneration = (
model: ModelDataMinimal,
accessToken: string,
Expand All @@ -57,14 +43,19 @@ export const snippetTextGeneration = (
const messages: ChatCompletionInputMessage[] = opts?.messages ?? [
{ role: "user", content: "What is the capital of France?" },
];
const messagesStr = formatGenerationMessages({ messages, sep: ",\n\t\t", start: "[\n\t\t", end: "\n\t]" });
const messagesStr = stringifyMessages(messages, { sep: ",\n\t\t", start: "[\n\t\t", end: "\n\t]" });

const config = {
...(opts?.temperature ? { temperature: opts.temperature } : undefined),
max_tokens: opts?.max_tokens ?? 500,
...(opts?.top_p ? { top_p: opts.top_p } : undefined),
};
const configStr = formatGenerationConfig({ config, sep: ",\n\t", start: "", end: "" });
const configStr = stringifyGenerationConfig(config, {
sep: ",\n\t",
start: "",
end: "",
attributeValueConnector: ": ",
});

if (streaming) {
return [
Expand Down
33 changes: 15 additions & 18 deletions packages/tasks/src/snippets/python.ts
Original file line number Diff line number Diff line change
@@ -1,22 +1,8 @@
import type { PipelineType } from "../pipelines.js";
import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks/index.js";
import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
import { getModelInputSnippet } from "./inputs.js";
import type {
GenerationConfigFormatter,
GenerationMessagesFormatter,
InferenceSnippet,
ModelDataMinimal,
} from "./types.js";

const formatGenerationMessages: GenerationMessagesFormatter = ({ messages, sep, start, end }) =>
start + messages.map(({ role, content }) => `{ "role": "${role}", "content": "${content}" }`).join(sep) + end;

const formatGenerationConfig: GenerationConfigFormatter = ({ config, sep, start, end, connector }) =>
start +
Object.entries(config)
.map(([key, val]) => `${key}${connector}${val}`)
.join(sep) +
end;
import type { InferenceSnippet, ModelDataMinimal } from "./types.js";

export const snippetConversational = (
model: ModelDataMinimal,
Expand All @@ -33,14 +19,25 @@ export const snippetConversational = (
const messages: ChatCompletionInputMessage[] = opts?.messages ?? [
{ role: "user", content: "What is the capital of France?" },
];
const messagesStr = formatGenerationMessages({ messages, sep: ",\n\t", start: `[\n\t`, end: `\n]` });
const messagesStr = stringifyMessages(messages, {
sep: ",\n\t",
start: `[\n\t`,
end: `\n]`,
attributeKeyQuotes: true,
});

const config = {
...(opts?.temperature ? { temperature: opts.temperature } : undefined),
max_tokens: opts?.max_tokens ?? 500,
...(opts?.top_p ? { top_p: opts.top_p } : undefined),
};
const configStr = formatGenerationConfig({ config, sep: ",\n\t", start: "", end: "", connector: "=" });
const configStr = stringifyGenerationConfig(config, {
sep: ",\n\t",
start: "",
end: "",
attributeValueConnector: "=",
attributeKeyQuotes: true,
});

if (streaming) {
return [
Expand Down
25 changes: 0 additions & 25 deletions packages/tasks/src/snippets/types.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import type { ModelData } from "../model-data";
import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks";

/**
* Minimal model data required for snippets.
Expand All @@ -15,27 +14,3 @@ export interface InferenceSnippet {
content: string;
client?: string; // for instance: `client` could be `huggingface_hub` or `openai` client for Python snippets
}

interface GenerationSnippetDelimiter {
sep: string;
start: string;
end: string;
connector?: string;
}

type PartialGenerationParameters = Partial<Pick<GenerationParameters, "temperature" | "max_tokens" | "top_p">>;

export type GenerationMessagesFormatter = ({
messages,
sep,
start,
end,
}: GenerationSnippetDelimiter & { messages: ChatCompletionInputMessage[] }) => string;

export type GenerationConfigFormatter = ({
config,
sep,
start,
end,
connector,
}: GenerationSnippetDelimiter & { config: PartialGenerationParameters }) => string;

0 comments on commit 77e3ce2

Please sign in to comment.