Skip to content

Commit

Permalink
handle files
Browse files Browse the repository at this point in the history
  • Loading branch information
pelikhan committed Jan 18, 2025
1 parent 0a7bd74 commit f20bdf9
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 24 deletions.
29 changes: 14 additions & 15 deletions packages/cli/src/runtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -206,13 +206,16 @@ export async function gradioConnect(
})

const handleFile = (v: unknown) => {
if (!v) return v
if (v instanceof File || v instanceof Blob || v instanceof Buffer)
return handle_file(v)
const f = v as WorkspaceFile
if (typeof f === "object" && f?.filename) {
const { filename, content, encoding } = f
if (!content) return handle_file((v as WorkspaceFile).filename)
else {
if (!content) {
const f = handle_file((v as WorkspaceFile).filename)
return f
} else {
const bytes = Buffer.from(content, encoding || "utf8")
return handle_file(bytes)
}
Expand All @@ -221,6 +224,7 @@ export async function gradioConnect(
}

const handleFiles = (payload: unknown[] | Record<string, unknown>) => {
if (!payload) return payload
if (Array.isArray(payload)) {
const result = []
for (let i = 0; i < payload.length; ++i)
Expand All @@ -245,23 +249,18 @@ export async function gradioConnect(
const config = app.config
const api: GradioApiInfo = await app.view_api()

const run = async (options?: GradioSubmitOptions): Promise<unknown[]> => {
for await (const status of submit(options)) {
if (status.type === "data") {
const data = status.data
return data
} else {
console.debug(`gradio ${space}: ${status.type}`)
}
}
return undefined
const predict = async (options?: GradioSubmitOptions): Promise<unknown> => {
const { endpoint = "/predict", payload = undefined } = options || {}
const payloadWithFiles = handleFiles(payload)
const res = await app.predict(endpoint, payloadWithFiles)
return res.data
}

return {
config,
submit,
handleFile,
run,
predict,
api,
}
}
Expand All @@ -276,7 +275,7 @@ export function defGradioTool(
args: Record<string, any>,
endpointInfo?: GradioEndpointInfo
) => Awaitable<unknown[] | Record<string, unknown>>,
renderer: (data: unknown[]) => Awaitable<ToolCallOutput>,
renderer: (data: unknown) => Awaitable<ToolCallOutput>,
options?: GradioClientOptions & Pick<GradioSubmitOptions, "endpoint">
) {
const { endpoint = "/predict", ...restOptions } = options || {}
Expand All @@ -296,7 +295,7 @@ export function defGradioTool(
endpoint,
payload,
}
const data = await app.run(req)
const data = await app.predict(req)
return await renderer(data)
})
}
36 changes: 27 additions & 9 deletions packages/sample/genaisrc/gradio.genai.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,30 @@ script({
})

import { defGradioTool, gradioConnect } from "genaiscript/runtime"
const put = env.vars.prompt || "A rabbit is wearing a space suit"

const captioner = await gradioConnect("hysts/ViTPose-transformers")
console.log(await captioner.api.named_endpoints)
const caption = await captioner.run({
endpoint: "/process_image",
payload: { image: env.files[0] },
})
console.log(caption)
try {
const prompter = await gradioConnect("microsoft/Promptist")
const predictions = await prompter.predict({
payload: [put],
})
const newPrompt = predictions[0]
console.log({ newPrompt })
} catch (e) {
console.error(e)
}

try {
const captioner = await gradioConnect("hysts/ViTPose-transformers")
const predictions = await captioner.predict({
endpoint: "/process_image",
payload: { image: env.files[0] },
})
const caption = predictions?.[0]
console.log({ caption })
} catch (e) {
console.error(e)
}

// see https://github.com/freddyaboulton/gradio-tools
defGradioTool(
Expand All @@ -24,9 +40,11 @@ defGradioTool(
//console.debug(info)
return [query]
},
(data) => data?.[0]
(data) => {
console.log(data)
return data?.[0]
}
)

const put = env.vars.prompt || "A rabbit is wearing a space suit"
def("PROMPT", put)
$`Improve the <PROMPT> for a Stable Diffusion model.`

0 comments on commit f20bdf9

Please sign in to comment.