From f20bdf99b45810ca7f2afff62aa03130bc3548d4 Mon Sep 17 00:00:00 2001 From: pelikhan Date: Sat, 18 Jan 2025 14:00:35 -0800 Subject: [PATCH] handle files --- packages/cli/src/runtime.ts | 29 +++++++++--------- packages/sample/genaisrc/gradio.genai.mjs | 36 +++++++++++++++++------ 2 files changed, 41 insertions(+), 24 deletions(-) diff --git a/packages/cli/src/runtime.ts b/packages/cli/src/runtime.ts index dbd638b843..e2e39c1228 100644 --- a/packages/cli/src/runtime.ts +++ b/packages/cli/src/runtime.ts @@ -206,13 +206,16 @@ export async function gradioConnect( }) const handleFile = (v: unknown) => { + if (!v) return v if (v instanceof File || v instanceof Blob || v instanceof Buffer) return handle_file(v) const f = v as WorkspaceFile if (typeof f === "object" && f?.filename) { const { filename, content, encoding } = f - if (!content) return handle_file((v as WorkspaceFile).filename) - else { + if (!content) { + const f = handle_file((v as WorkspaceFile).filename) + return f + } else { const bytes = Buffer.from(content, encoding || "utf8") return handle_file(bytes) } @@ -221,6 +224,7 @@ export async function gradioConnect( } const handleFiles = (payload: unknown[] | Record) => { + if (!payload) return payload if (Array.isArray(payload)) { const result = [] for (let i = 0; i < payload.length; ++i) @@ -245,23 +249,18 @@ export async function gradioConnect( const config = app.config const api: GradioApiInfo = await app.view_api() - const run = async (options?: GradioSubmitOptions): Promise => { - for await (const status of submit(options)) { - if (status.type === "data") { - const data = status.data - return data - } else { - console.debug(`gradio ${space}: ${status.type}`) - } - } - return undefined + const predict = async (options?: GradioSubmitOptions): Promise => { + const { endpoint = "/predict", payload = undefined } = options || {} + const payloadWithFiles = handleFiles(payload) + const res = await app.predict(endpoint, payloadWithFiles) + return res.data } return { config, submit, handleFile, - run, + predict, api, } } @@ -276,7 +275,7 @@ export function defGradioTool( args: Record, endpointInfo?: GradioEndpointInfo ) => Awaitable>, - renderer: (data: unknown[]) => Awaitable, + renderer: (data: unknown) => Awaitable, options?: GradioClientOptions & Pick ) { const { endpoint = "/predict", ...restOptions } = options || {} @@ -296,7 +295,7 @@ export function defGradioTool( endpoint, payload, } - const data = await app.run(req) + const data = await app.predict(req) return await renderer(data) }) } diff --git a/packages/sample/genaisrc/gradio.genai.mjs b/packages/sample/genaisrc/gradio.genai.mjs index 7a99b8ffd2..ca6b07fc7b 100644 --- a/packages/sample/genaisrc/gradio.genai.mjs +++ b/packages/sample/genaisrc/gradio.genai.mjs @@ -3,14 +3,30 @@ script({ }) import { defGradioTool, gradioConnect } from "genaiscript/runtime" +const put = env.vars.prompt || "A rabbit is wearing a space suit" -const captioner = await gradioConnect("hysts/ViTPose-transformers") -console.log(await captioner.api.named_endpoints) -const caption = await captioner.run({ - endpoint: "/process_image", - payload: { image: env.files[0] }, -}) -console.log(caption) +try { + const prompter = await gradioConnect("microsoft/Promptist") + const predictions = await prompter.predict({ + payload: [put], + }) + const newPrompt = predictions[0] + console.log({ newPrompt }) +} catch (e) { + console.error(e) +} + +try { + const captioner = await gradioConnect("hysts/ViTPose-transformers") + const predictions = await captioner.predict({ + endpoint: "/process_image", + payload: { image: env.files[0] }, + }) + const caption = predictions?.[0] + console.log({ caption }) +} catch (e) { + console.error(e) +} // see https://github.com/freddyaboulton/gradio-tools defGradioTool( @@ -24,9 +40,11 @@ defGradioTool( //console.debug(info) return [query] }, - (data) => data?.[0] + (data) => { + console.log(data) + return data?.[0] + } ) -const put = env.vars.prompt || "A rabbit is wearing a space suit" def("PROMPT", put) $`Improve the for a Stable Diffusion model.`