Skip to content

Commit

Permalink
make other optional, test image classify
Browse files Browse the repository at this point in the history
  • Loading branch information
pelikhan committed Jan 17, 2025
1 parent 52deaaf commit e4a0848
Showing 4 changed files with 91 additions and 38 deletions.
13 changes: 10 additions & 3 deletions docs/src/content/docs/reference/scripts/classify.md
Original file line number Diff line number Diff line change
@@ -49,9 +49,16 @@ Each label id should be a single word that encodes into a single token. This all

### `other` label

A `other` label is automatically added to the list
of label to give an escape route for the LLM when
it is not able to classify the text.
A `other` label can be automatically added to the list
of label to give an escape route for the LLM when it is not able to classify the text.

```js
const res = await classify(
"...",
{ ... },
{ other: true }
)
```

## Model and other options

68 changes: 43 additions & 25 deletions packages/cli/src/runtime.ts
Original file line number Diff line number Diff line change
@@ -11,6 +11,14 @@ import { pipeline } from "@huggingface/transformers"
// symbols exported as is
export { delay, uniq, uniqBy, z, pipeline, chunk, groupBy }

export type ClassifyOptions = {
/**
* Inject a 'other' label
*/
other?: boolean
ctx?: ChatGenerationContext
} & Omit<PromptGeneratorOptions, "choices">

/**
* Classify prompt
*
@@ -23,10 +31,7 @@ export { delay, uniq, uniqBy, z, pipeline, chunk, groupBy }
export async function classify<L extends Record<string, string>>(
text: string | PromptGenerator,
labels: L,
options?: {
instructions?: string
ctx?: ChatGenerationContext
} & Omit<PromptGeneratorOptions, "choices">
options?: ClassifyOptions
): Promise<{
label: keyof typeof labels | "other"
entropy?: number
@@ -35,18 +40,24 @@ export async function classify<L extends Record<string, string>>(
answer: string
logprobs?: Record<keyof typeof labels | "other", Logprob>
}> {
const { instructions, ...rest } = options || {}
const entries = Object.entries(labels).map(([k, v]) => [
k.trim().toLowerCase(),
v,
])
if (!entries.length) throw Error("classify must have at least one label")
const { other, ...rest } = options || {}

const entries = Object.entries({
...labels,
...(other
? {
other: "This label is used when the text does not fit any of the available labels.",
}
: {}),
}).map(([k, v]) => [k.trim().toLowerCase(), v])

if (entries.length < 2)
throw Error("classify must have at least two label (including other)")

const choices = entries.map(([k]) => k)
const allChoices = uniq<keyof typeof labels | "other">([
...choices,
"other",
])
const allChoices = uniq<keyof typeof labels | "other">(choices)
const ctx = options?.ctx || env.generator

const res = await ctx.runPrompt(
async (_) => {
_.$`## Expert Classifier
@@ -58,26 +69,33 @@ For each label, you will find a short description. Use these descriptions to gui
_.$`## Labels
You must classify the data as one of the following labels.
${entries.map(([id, descr]) => `- Label '${id}': ${descr}`).join("\n")}
- Label 'other': This label is used when the text does not fit any of the available labels.
## Output
Provide a short justification for your choice
and output the label as your last word.
Provide a single sentence justification for your choice.
and output the label as a single word on the last line. Do not emit "Label".
`
_.fence(
`- Label 'yes': funny
- Label 'no': not funny
DATA:
Why did the chicken cross the road? Because moo.
Output:
It's a classic joke but the ending does not relate to the start of the joke.
no
`,
{ language: "example" }
)
if (typeof text === "string") _.def("DATA", text)
else await text(_)
if (options?.instructions) {
_.$`## Additional instructions
${instructions}
`
}
},
{
model: "classify",
choices: [...choices, "other"],
choices: choices,
label: `classify ${choices.join(", ")}`,
cache: "classify",
logprobs: true,
topLogprobs: Math.min(3, choices.length),
system: [
12 changes: 8 additions & 4 deletions packages/core/src/chat.ts
Original file line number Diff line number Diff line change
@@ -838,22 +838,26 @@ async function choicesToLogitBias(
disableFallback: true,
})) || {}
if (!encode) {
logVerbose(
`unabled to compute logit bias, no token encoder found for ${model}`
logWarn(
`unable to compute logit bias, no token encoder found for ${model}`
)
trace.warn(
`unabled to compute logit bias, no token encoder found for ${model}`
`unable to compute logit bias, no token encoder found for ${model}`
)
return undefined
}
const logit_bias: Record<number, number> = Object.fromEntries(
choices.map((c) => {
const { token, weight } = typeof c === "string" ? { token: c } : c
const encoded = typeof token === "number" ? [token] : encode(token)
if (encoded.length !== 1)
if (encoded.length !== 1) {
logWarn(
`choice ${c} tokenizes to ${encoded.join(", ")} (expected one token)`
)
trace.warn(
`choice ${c} tokenizes to ${encoded.join(", ")} (expected one token)`
)
}
return [encoded[0], isNaN(weight) ? CHOICE_LOGIT_BIAS : weight] as [
number,
number,
36 changes: 30 additions & 6 deletions packages/sample/genaisrc/classify.genai.mjs
Original file line number Diff line number Diff line change
@@ -1,9 +1,33 @@
import { classify } from "genaiscript/runtime"

const res = await classify("The app crashes when I try to upload a file.", {
bug: "a software defect",
feat: "a feature request",
qa: "an inquiry about how to use the software",
})
const qa = await classify(
"The app crashes when I try to upload a file.",
{
bug: "a software defect",
feat: "a feature request",
qa: "an inquiry about how to use the software",
},
{ other: true }
)

console.log(res)
console.log(qa)

const joke = await classify(
"Why did the chicken cross the roard? To fry in the sun.",
{
yes: "funny",
no: "not funny",
}
)

console.log(joke)

const robots = await classify(
(_) => _.defImages("src/robots.jpg"),
{
object: "Depicts objects, machines, robots, toys, ...",
animal: "Animals, pets, monsters",
},
{ other: true }
)
console.log(robots)

0 comments on commit e4a0848

Please sign in to comment.