From 015dea75ed02067b2167b2802ddb1db3078c1c65 Mon Sep 17 00:00:00 2001 From: Rob Gordon Date: Sun, 21 Jul 2024 22:43:36 -0400 Subject: [PATCH] Create new AI Toolbar (#690) * Add new AI Toolbar and endpoints * Pass model into processResult; update ai packages; repairText on result * Add translations * Track session activity on AI creation * Speed up writer text * Fix chart input E2E --- api/package.json | 4 +- api/prompt/_edit.ts | 42 +++++ api/prompt/_shared.ts | 103 ++++++++++ api/prompt/convert.ts | 102 ++-------- api/prompt/edit.ts | 78 ++++---- api/prompt/prompt.ts | 29 +++ app/e2e/pro.spec.ts | 2 +- app/package.json | 3 +- app/src/components/AiToolbar.tsx | 2 +- app/src/components/AiToolbar2.tsx | 147 +++++++++++++++ app/src/components/ConvertToFlowchart.tsx | 9 +- app/src/components/LoadFileButton.tsx | 1 - app/src/components/TextEditor.tsx | 42 ++++- app/src/lib/constants.ts | 1 + ...tizeOnPaste.test.ts => repairText.test.ts} | 20 +- .../lib/{sanitizeOnPaste.ts => repairText.ts} | 10 +- .../lib/{convertToFlowchart.ts => runAi.ts} | 103 +++++----- app/src/lib/usePromptStore.ts | 129 ++++++++++++- app/src/lib/writeEditorText.ts | 2 +- app/src/locales/de/messages.js | 2 +- app/src/locales/de/messages.po | 78 +++++--- app/src/locales/en/messages.js | 2 +- app/src/locales/en/messages.po | 74 +++++--- app/src/locales/es/messages.js | 2 +- app/src/locales/es/messages.po | 76 +++++--- app/src/locales/fr/messages.js | 2 +- app/src/locales/fr/messages.po | 78 +++++--- app/src/locales/hi/messages.js | 2 +- app/src/locales/hi/messages.po | 76 +++++--- app/src/locales/ko/messages.js | 2 +- app/src/locales/ko/messages.po | 76 +++++--- app/src/locales/pt-br/messages.js | 2 +- app/src/locales/pt-br/messages.po | 76 +++++--- app/src/locales/zh/messages.js | 2 +- app/src/locales/zh/messages.po | 76 +++++--- app/src/pages/EditHosted.tsx | 9 +- app/src/pages/New.tsx | 12 +- app/src/pages/Sandbox.tsx | 10 +- app/src/pages/createExamples.tsx | 12 ++ app/src/ui/Shared.tsx | 15 +- pnpm-lock.yaml | 178 ++++++++---------- 41 files changed, 1135 insertions(+), 556 deletions(-) create mode 100644 api/prompt/_edit.ts create mode 100644 api/prompt/_shared.ts create mode 100644 api/prompt/prompt.ts create mode 100644 app/src/components/AiToolbar2.tsx rename app/src/lib/{sanitizeOnPaste.test.ts => repairText.test.ts} (53%) rename app/src/lib/{sanitizeOnPaste.ts => repairText.ts} (75%) rename app/src/lib/{convertToFlowchart.ts => runAi.ts} (52%) create mode 100644 app/src/pages/createExamples.tsx diff --git a/api/package.json b/api/package.json index 97d6cc2de..05c1875d0 100644 --- a/api/package.json +++ b/api/package.json @@ -11,14 +11,14 @@ "author": "", "license": "ISC", "dependencies": { - "@ai-sdk/openai": "^0.0.9", + "@ai-sdk/openai": "^0.0.37", "@notionhq/client": "^0.4.13", "@octokit/core": "^4.2.0", "@sendgrid/mail": "^7.4.6", "@supabase/supabase-js": "^2.31.0", "@upstash/ratelimit": "^1.1.3", "@vercel/kv": "^1.0.1", - "ai": "^3.2.19", + "ai": "^3.2.32", "ajv": "^8.12.0", "axios": "^0.27.2", "csv-parse": "^5.3.6", diff --git a/api/prompt/_edit.ts b/api/prompt/_edit.ts new file mode 100644 index 000000000..201263606 --- /dev/null +++ b/api/prompt/_edit.ts @@ -0,0 +1,42 @@ +import { VercelApiHandler } from "@vercel/node"; +import { llmMany } from "../_lib/_llm"; +import { z } from "zod"; + +const nodeSchema = z.object({ + // id: z.string(), + // classes: z.string(), + label: z.string(), +}); + +const edgeSchema = z.object({ + from: z.string(), + to: z.string(), + label: z.string().optional().default(""), +}); + +const graphSchema = z.object({ + nodes: z.array(nodeSchema), + edges: z.array(edgeSchema), +}); + +const handler: VercelApiHandler = async (req, res) => { + const { graph, prompt } = req.body; + if (!graph || !prompt) { + throw new Error("Missing graph or prompt"); + } + + const result = await llmMany( + `${prompt} + +Here is the current state of the flowchart: +${JSON.stringify(graph, null, 2)} +`, + { + updateGraph: graphSchema, + } + ); + + res.json(result); +}; + +export default handler; diff --git a/api/prompt/_shared.ts b/api/prompt/_shared.ts new file mode 100644 index 000000000..069c0d6a3 --- /dev/null +++ b/api/prompt/_shared.ts @@ -0,0 +1,103 @@ +import { z } from "zod"; +import { streamText } from "ai"; +import { stripe } from "../_lib/_stripe"; +import { kv } from "@vercel/kv"; +import { Ratelimit } from "@upstash/ratelimit"; +import { openai } from "@ai-sdk/openai"; + +export const reqSchema = z.object({ + prompt: z.string().min(1), + document: z.string(), +}); + +export async function handleRateLimit(req: Request) { + const ip = getIp(req); + let isPro = false, + customerId: null | string = null; + + const token = req.headers.get("Authorization"); + + if (token) { + const sid = token.split(" ")[1]; + const sub = await stripe.subscriptions.retrieve(sid); + if (sub.status === "active" || sub.status === "trialing") { + isPro = true; + customerId = sub.customer as string; + } + } + + const ratelimit = new Ratelimit({ + redis: kv, + limiter: isPro + ? Ratelimit.slidingWindow(3, "1m") + : Ratelimit.fixedWindow(3, "30d"), + }); + + const rateLimitKey = isPro ? `pro_${customerId}` : `unauth_${ip}`; + const { success, limit, reset, remaining } = await ratelimit.limit( + rateLimitKey + ); + + if (!success) { + return new Response("You have reached your request limit.", { + status: 429, + headers: { + "X-RateLimit-Limit": limit.toString(), + "X-RateLimit-Remaining": remaining.toString(), + "X-RateLimit-Reset": reset.toString(), + }, + }); + } + + return null; +} + +export async function processRequest( + req: Request, + systemMessage: string, + content: string, + model: Parameters[0] = "gpt-4-turbo" +) { + const rateLimitResponse = await handleRateLimit(req); + if (rateLimitResponse) return rateLimitResponse; + + const result = await streamText({ + model: openai.chat(model), + system: systemMessage, + temperature: 1, + messages: [ + { + role: "user", + content, + }, + ], + }); + + return result.toTextStreamResponse(); +} + +function getIp(req: Request) { + return ( + req.headers.get("x-real-ip") || + req.headers.get("cf-connecting-ip") || + req.headers.get("x-forwarded-for") || + req.headers.get("x-client-ip") || + req.headers.get("x-cluster-client-ip") || + req.headers.get("forwarded-for") || + req.headers.get("forwarded") || + req.headers.get("via") || + req.headers.get("x-forwarded") || + req.headers.get + ); +} + +export const systemMessageStyle = `You can style nodes using classes at the end of a node. Available styles include: +- Colors: .color_blue, .color_red, .color_green, .color_yellow, .color_orange +- Shapes: .shape_circle, .shape_diamond, .shape_ellipse, .shape_right-rhomboid`; + +export const systemMessageExample = `Node A + Node B .shape_circle + \\(Secret Node) + Node C + label from c to d: Node D .color_green.shape_diamond + label from d to a: (Node A)`; diff --git a/api/prompt/convert.ts b/api/prompt/convert.ts index 0c0e06387..b1feb0e2b 100644 --- a/api/prompt/convert.ts +++ b/api/prompt/convert.ts @@ -1,97 +1,32 @@ -import { z } from "zod"; -import { streamText } from "ai"; -import { stripe } from "../_lib/_stripe"; -import { kv } from "@vercel/kv"; -import { Ratelimit } from "@upstash/ratelimit"; -import { openai } from "@ai-sdk/openai"; +import { processRequest, reqSchema } from "./_shared"; export const config = { runtime: "edge", }; -const reqSchema = z.object({ - prompt: z.string().min(1), -}); +const systemMessage = `You are the Flowchart Fun creation assistant. When I give you a document respond with a diagram in Flowchart Fun syntax. The Flowchart Fun syntax you use indentation to express a tree shaped graph. You use text before a colon to labels to edges. You link back to earlier nodes by referring to their label in parentheses. The following characters must be escaped when used in a node or edge label: (,:,#, and .\n\nYou can style nodes using classes at the end of a node. Available styles include: +- Colors: .color_blue, .color_red, .color_green, .color_yellow +- Shapes: .shape_circle, .shape_diamond, .shape_hexagon -const systemMessage = `You are the Flowchart Fun creation assistant. When I give you a document respond with a diagram in Flowchart Fun syntax. The Flowchart Fun syntax you use indentation to express a tree shaped graph. You use text before a colon to labels to edges. You link back to earlier nodes by referring to their label in parentheses. The following characters must be escaped when used in a node or edge label: (,:,#, and .\n\nHere is a very simple graph illustrating the syntax: +Here is a very simple graph illustrating the syntax: - Node A - Node B + Node A .color_blue + Node B .shape_circle \\(Secret Node) - Node C - label from c to d: Node D + Node C .color_green + label from c to d: Node D .shape_diamond label from d to a: (Node A) Note: Don't provide any explanation. Don't wrap your response in a code block.`; -export default async function handler(req: Request) { - const ip = getIp(req); - - let isPro = false, - customerId: null | string = null; - - // Check for auth token - const token = req.headers.get("Authorization"); - - if (token) { - // get sid from token - const sid = token.split(" ")[1]; - - // check if subscription is active or trialing - const sub = await stripe.subscriptions.retrieve(sid); - if (sub.status === "active" || sub.status === "trialing") { - isPro = true; - customerId = sub.customer as string; - } - } - - // Implement rate-limiting based on IP for unauthorized users and customerId for authorized users - // Initialize Upstash Ratelimit - const ratelimit = new Ratelimit({ - redis: kv, - limiter: isPro - ? Ratelimit.slidingWindow(3, "1m") // Pro users: 3 requests per minute - : Ratelimit.fixedWindow(2, "30d"), // Unauthenticated users: 2 requests per month - }); - - // Determine the key for rate limiting - const rateLimitKey = isPro ? `pro_${customerId}` : `unauth_${ip}`; - - // Check the rate limit - const { success, limit, reset, remaining } = await ratelimit.limit( - rateLimitKey - ); - - if (!success) { - return new Response("You have reached your request limit.", { - status: 429, - headers: { - "X-RateLimit-Limit": limit.toString(), - "X-RateLimit-Remaining": remaining.toString(), - "X-RateLimit-Reset": reset.toString(), - }, - }); - } +export default async function handler(req: Request) { const body = await req.json(); const parsed = reqSchema.safeParse(body); - if (!parsed.success) { return new Response(JSON.stringify(parsed.error), { status: 400 }); } - const result = await streamText({ - model: openai.chat("gpt-4-turbo"), - system: systemMessage, - temperature: 0.15, - messages: [ - { - role: "user", - content: getContent(parsed.data.prompt), - }, - ], - }); - - return result.toTextStreamResponse(); + return processRequest(req, systemMessage, getContent(parsed.data.prompt)); } function getContent(prompt: string): string { @@ -100,18 +35,3 @@ function getContent(prompt: string): string { prompt ); } - -function getIp(req: Request) { - return ( - req.headers.get("x-real-ip") || - req.headers.get("cf-connecting-ip") || - req.headers.get("x-forwarded-for") || - req.headers.get("x-client-ip") || - req.headers.get("x-cluster-client-ip") || - req.headers.get("forwarded-for") || - req.headers.get("forwarded") || - req.headers.get("via") || - req.headers.get("x-forwarded") || - req.headers.get - ); -} diff --git a/api/prompt/edit.ts b/api/prompt/edit.ts index 201263606..1efdf7164 100644 --- a/api/prompt/edit.ts +++ b/api/prompt/edit.ts @@ -1,42 +1,44 @@ -import { VercelApiHandler } from "@vercel/node"; -import { llmMany } from "../_lib/_llm"; -import { z } from "zod"; - -const nodeSchema = z.object({ - // id: z.string(), - // classes: z.string(), - label: z.string(), -}); - -const edgeSchema = z.object({ - from: z.string(), - to: z.string(), - label: z.string().optional().default(""), -}); - -const graphSchema = z.object({ - nodes: z.array(nodeSchema), - edges: z.array(edgeSchema), -}); - -const handler: VercelApiHandler = async (req, res) => { - const { graph, prompt } = req.body; - if (!graph || !prompt) { - throw new Error("Missing graph or prompt"); +import { processRequest, reqSchema } from "./_shared"; + +export const config = { + runtime: "edge", +}; + +const systemMessage = `You are an AI document editor specializing in Flowchart Fun syntax. When given a document and editing instructions, return the same document with only the requested changes. Do not make any additional changes, including whitespace changes, beyond what is explicitly requested. Preserve all original formatting and content except where modifications are necessary. + +Flowchart Fun Syntax: +- Use indentation to express a tree-shaped graph. +- Text before a colon represents labels for edges. +- Link back to earlier nodes by referring to their label in parentheses. +- Escape the following characters when used in a node or edge label: (,:,#, and \\. +- Use classes at the end of a node to apply styles. (e.g., .color_blue,.shape_circle) + +Example: + Node A + Node B .color_blue + \\(Secret Node) + Node C + label from c to d: Node D + label from d to a: (Node A) + +When editing, ensure that the Flowchart Fun syntax remains valid and consistent.`; + +export default async function handler(req: Request) { + const body = await req.json(); + const parsed = reqSchema.safeParse(body); + + if (!parsed.success) { + return new Response(JSON.stringify(parsed.error), { status: 400 }); } - const result = await llmMany( - `${prompt} - -Here is the current state of the flowchart: -${JSON.stringify(graph, null, 2)} -`, - { - updateGraph: graphSchema, - } + return processRequest( + req, + systemMessage, + getContent(parsed.data.prompt, parsed.data.document), + "gpt-4-turbo-2024-04-09" ); +} - res.json(result); -}; - -export default handler; +function getContent(prompt: string, document: string): string { + return `Edit the following document according to these instructions:\n\nInstructions: ${prompt}\n\nDocument:\n${document}`; +} diff --git a/api/prompt/prompt.ts b/api/prompt/prompt.ts new file mode 100644 index 000000000..4d0a8f47f --- /dev/null +++ b/api/prompt/prompt.ts @@ -0,0 +1,29 @@ +import { + processRequest, + reqSchema, + systemMessageExample, + systemMessageStyle, +} from "./_shared"; + +export const config = { + runtime: "edge", +}; + +const systemMessage = `You are the Flowchart Fun creation assistant. When I give you a prompt, respond with a diagram in Flowchart Fun syntax. The Flowchart Fun syntax uses indentation to express a tree-shaped graph. Use text before a colon to label edges. Link back to earlier nodes by referring to their label in parentheses. The following characters must be escaped when used in a node or edge label: (,:,#, and .\n\n${systemMessageStyle} + +Here is a very simple graph illustrating the syntax: + +${systemMessageExample} + +Note: Don't provide any explanation. Don't wrap your response in a code block.`; + +export default async function handler(req: Request) { + const body = await req.json(); + const parsed = reqSchema.safeParse(body); + + if (!parsed.success) { + return new Response(JSON.stringify(parsed.error), { status: 400 }); + } + + return processRequest(req, systemMessage, parsed.data.prompt); +} diff --git a/app/e2e/pro.spec.ts b/app/e2e/pro.spec.ts index dc9847883..924bacfe4 100644 --- a/app/e2e/pro.spec.ts +++ b/app/e2e/pro.spec.ts @@ -132,7 +132,7 @@ test("Create chart from AI", async () => { test("Create chart from imported data", async () => { try { await page.getByRole("link", { name: "New" }).click(); - await page.getByRole("button", { name: "Create" }).click(); + await page.getByTestId("Create Chart").click(); await page.waitForURL(new RegExp(`${BASE_URL}/u/\\d+`)); await page.getByRole("button", { name: "Import Data" }).click(); diff --git a/app/package.json b/app/package.json index a255775cd..7ab353132 100644 --- a/app/package.json +++ b/app/package.json @@ -28,6 +28,7 @@ "theme:schema:generate": "pnpx ts-json-schema-generator --path 'src/lib/FFTheme.ts' --type 'FFTheme' -f tsconfig.json --strict-tuples -o src/lib/FFTheme.schema.json --minify --no-top-ref" }, "dependencies": { + "@ai-sdk/openai": "^0.0.37", "@formkit/auto-animate": "1.0.0-beta.6", "@lingui/core": "^3.8.9", "@lingui/react": "^3.8.9", @@ -59,7 +60,7 @@ "@svgr/webpack": "^6.3.1", "@tone-row/slang": "^1.2.35", "@tone-row/strip-comments": "^2.0.1", - "ai": "^3.2.19", + "ai": "^3.2.32", "buffer": "^6.0.3", "classnames": "^2.3.2", "construct-style-sheets-polyfill": "^3.1.0", diff --git a/app/src/components/AiToolbar.tsx b/app/src/components/AiToolbar.tsx index f73157095..61a313345 100644 --- a/app/src/components/AiToolbar.tsx +++ b/app/src/components/AiToolbar.tsx @@ -34,7 +34,7 @@ export function AiToolbar() { } }, [userPasted]); - const convertIsRunning = usePromptStore((s) => s.convertIsRunning); + const convertIsRunning = usePromptStore((s) => s.isRunning); // Qualities for displaying Convert to Flowchart button: // OR diff --git a/app/src/components/AiToolbar2.tsx b/app/src/components/AiToolbar2.tsx new file mode 100644 index 000000000..3884e5a37 --- /dev/null +++ b/app/src/components/AiToolbar2.tsx @@ -0,0 +1,147 @@ +import { Button2, IconButton2, Textarea } from "../ui/Shared"; +import { CaretDown, CaretUp, MagicWand } from "phosphor-react"; +import cx from "classnames"; +import { t } from "@lingui/macro"; +import { createExamples } from "../pages/createExamples"; +import { + Mode, + acceptDiff, + rejectDiff, + setCurrentText, + setIsOpen, + setMode, + usePromptStore, + useRunAiWithStore, +} from "../lib/usePromptStore"; + +function getModeDescription(mode: Mode): string { + const prompts = createExamples(); + switch (mode) { + case "prompt": + return `E.G., "${prompts[0]}"`; + case "convert": + return t`Paste your document or outline here to convert it into an organized flowchart.`; + case "edit": + return t`Use this mode to modify and enhance your current chart.`; + } +} + +function getModeTitle(mode: Mode): string { + switch (mode) { + case "prompt": + return t`Prompt`; + case "convert": + return t`Convert`; + case "edit": + return t`Edit`; + } +} + +export function AiToolbar2() { + const isOpen = usePromptStore((state) => state.isOpen); + const currentMode = usePromptStore((state) => state.mode); + const isRunning = usePromptStore((state) => state.isRunning); + const runAiWithStore = useRunAiWithStore(); + const diff = usePromptStore((state) => state.diff); + + const toggleOpen = () => setIsOpen(!isOpen); + + const handleModeChange = (mode: Mode) => { + setMode(mode); + if (!isOpen) setIsOpen(true); + }; + + const currentText = usePromptStore((state) => state.currentText); + + const showAcceptDiffButton = diff && !isRunning; + + return ( +
+
+
+ + {!showAcceptDiffButton ? ( + (["prompt", "convert", "edit"] as Mode[]).map((mode) => ( + handleModeChange(mode)} + className={cx({ + "hover:bg-white dark:hover:bg-neutral-700": + mode !== currentMode, + "dark:bg-purple-700 dark:text-purple-100": + mode === currentMode, + })} + > + {getModeTitle(mode)} + + )) + ) : ( + + Keep changes? + + )} +
+ {!showAcceptDiffButton ? ( + + {!isOpen ? : } + + ) : ( +
+ + Accept + + + Reject + +
+ )} +
+ {isOpen && ( +
+

+ {getModeDescription(currentMode)} +

+