Skip to content

Commit

Permalink
Editable sequence breakers (#1092)
Browse files Browse the repository at this point in the history
- add Gemini 2.0 Flash
- fix image model retrieval for reel
- add DRY sequence breakers to the UI
- OpenAI chat format using prompt templates
  • Loading branch information
sceuick authored Dec 14, 2024
1 parent 4a6ec0d commit 62ad0b4
Show file tree
Hide file tree
Showing 21 changed files with 246 additions and 90 deletions.
28 changes: 14 additions & 14 deletions .vscode/tasks.json
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
{
"version": "2.0.0",
"tasks": [
{
"type": "typescript",
"tsconfig": "tsconfig.json",
"option": "watch",
"problemMatcher": ["$tsc-watch"],
"group": {
"kind": "build",
"isDefault": true
},
"isBackground": true,
"label": "tsc: watch - tsconfig.json"
},
// {
// "type": "typescript",
// "tsconfig": "tsconfig.json",
// "option": "watch",
// "problemMatcher": ["$tsc-watch"],
// "group": {
// "kind": "build",
// "isDefault": true
// },
// "isBackground": true,
// "label": "tsc: watch - tsconfig.json"
// },
{
"type": "typescript",
"tsconfig": "srv.tsconfig.json",
Expand All @@ -27,13 +27,13 @@
{
"type": "shell",
"command": "pnpm",
"args": ["tsc", "-p", "tsconfig.json", "--noEmit"],
"args": ["tsc", "-p", "tsconfig.json", "--noEmit", "-w"],
"problemMatcher": "$tsc-watch",
"label": "tsc: watch no-emit",
"isBackground": true,
"group": {
"kind": "build",
"isDefault": false
"isDefault": true
}
}
]
Expand Down
4 changes: 4 additions & 0 deletions common/adapters.ts
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ export const JSON_SCHEMA_SUPPORTED: { [key in AIAdapter | ThirdPartyFormat]?: bo
export const THIRDPARTY_HANDLERS: { [svc in ThirdPartyFormat]: AIAdapter } = {
openai: 'openai',
'openai-chat': 'openai',
'openai-chatv2': 'openai',
claude: 'claude',
aphrodite: 'kobold',
exllamav2: 'kobold',
Expand All @@ -109,6 +110,7 @@ export const THIRDPARTY_FORMATS = [
'kobold',
'openai',
'openai-chat',
'openai-chatv2',
'claude',
'ooba',
'llamacpp',
Expand Down Expand Up @@ -277,13 +279,15 @@ export const GOOGLE_MODELS = {
GEMINI_15_FLASH_002: { id: 'gemini-1.5-flash-002', label: 'Gemini 1.5 Flash 002' },
GEMINI_15_FLASH_8B: { id: 'gemini-1.5-flash-8b', label: 'Gemini 1.5 Flash 8B' },
GEMINI_EXP_1114: { id: 'gemini-exp-1114', label: 'Gemini Exp 1114' },
GEMINI_20_FLASH: { id: 'gemini-2.0-flash-exp', label: 'Gemini 2.0 Flash' },
}

export const GOOGLE_LIMITS: Record<string, number> = {
'gemini-1.5-pro': 2097152,
'gemini-1.0-pro-latest': 32768,
'gemini-1.5-flash': 1048576,
'gemini-1.5-flash-8b': 1048576,
'gemini-2.0-flash-exp': 1048576,
}

/** Note: claude-v1 and claude-instant-v1 not included as they may point
Expand Down
4 changes: 2 additions & 2 deletions common/prompt-order.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ export function promptOrderToSections(opts: OrderOptions) {
? formatHolders[opts.format] || formatHolders.Universal
: formatHolders.Universal

const system = holders.system
const defs = order.map((o) => holders[o.placeholder]).join('\n')
const system = holders.system_prompt || holders.system
const defs = order.map((o) => getOrderHolder(opts.format!, o.placeholder)).join('\n')
const history = holders.history
const post = holders.post

Expand Down
6 changes: 4 additions & 2 deletions common/prompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,8 @@ export async function createPromptParts(opts: PromptOpts, encoder: TokenCounter)
return { lines: lines.reverse(), parts, template: prompt }
}

export type AssembledPrompt = Awaited<ReturnType<typeof assemblePrompt>>

/**
* This is only ever invoked server-side
*
Expand All @@ -281,7 +283,7 @@ export async function assemblePrompt(
const template = getTemplate(opts)

const history = { lines, order: 'asc' } as const
let { parsed, inserts, length } = await injectPlaceholders(template, {
let { parsed, inserts, length, sections } = await injectPlaceholders(template, {
opts,
parts,
history,
Expand All @@ -291,7 +293,7 @@ export async function assemblePrompt(
jsonValues: opts.jsonValues,
})

return { lines: history.lines, prompt: parsed, inserts, parts, post, length }
return { lines: history.lines, prompt: parsed, inserts, parts, post, length, sections }
}

export function getTemplate(opts: Pick<GenerateRequestV2, 'settings' | 'chat'>) {
Expand Down
1 change: 1 addition & 0 deletions common/requests/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ function startRequest(request: GenerateRequestV2, prompt: string) {
switch (request.settings!.thirdPartyFormat) {
case 'openai':
case 'openai-chat':
case 'openai-chatv2':
return handleOAI(opts, payload)

case 'aphrodite':
Expand Down
5 changes: 3 additions & 2 deletions common/requests/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,9 @@ export async function* handleOAI(opts: PayloadOpts, payload: any) {
}

const suffix = gen.thirdPartyUrl?.endsWith('/') ? '' : '/'
const urlPath =
gen.thirdPartyFormat === 'openai-chat' ? `${suffix}chat/completions` : `${suffix}completions`
const urlPath = gen.thirdPartyFormat?.startsWith('openai-chat')
? `${suffix}chat/completions`
: `${suffix}completions`
const fullUrl = `${gen.thirdPartyUrl}${urlPath}`

if (!gen.streamResponse) {
Expand Down
37 changes: 26 additions & 11 deletions common/template-parser.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ export type TemplateOpts = {
sections?: {
flags: { [key in Section]?: boolean }
sections: { [key in Section]: string[] }
done: boolean
}

/**
Expand Down Expand Up @@ -184,6 +185,7 @@ export async function parseTemplate(
const sections: TemplateOpts['sections'] = {
flags: {},
sections: { system: [], history: [], post: [] },
done: false,
}

opts.sections = sections
Expand All @@ -205,6 +207,7 @@ export async function parseTemplate(
const ast = parser.parse(template, {}) as PNode[]
readInserts(opts, ast)
let output = render(template, opts, ast)
opts.sections.done = true
let unusedTokens = 0
let linesAddedCount = 0

Expand Down Expand Up @@ -333,16 +336,23 @@ function render(template: string, opts: TemplateOpts, existingAst?: PNode[]) {
}

const output: string[] = []
let prevMarker: Section = 'system'

for (let i = 0; i < ast.length; i++) {
const parent = ast[i]

const result = renderNode(parent, opts)

const marker = getMarker(parent)
fillSection(opts, marker, result)
const marker = getMarker(opts, parent, prevMarker)
prevMarker = marker

if (result) output.push(result)
if (!opts.sections?.done) {
fillSection(opts, marker, result)
}

if (result) {
output.push(result)
}
}
return output.join('').replace(/\n\n+/g, '\n\n')
} catch (err) {
Expand Down Expand Up @@ -623,6 +633,9 @@ function renderIterator(holder: IterableHolder, children: CNode[], opts: Templat
if (isHistory && opts.limit?.output) {
const id = HISTORY_MARKER
opts.limit.output[id] = { src: holder, lines: output }
if (opts.sections) {
opts.sections.flags.history = true
}
return id
}

Expand Down Expand Up @@ -788,38 +801,40 @@ function fillSection(opts: TemplateOpts, marker: Section | undefined, result: st
const flags = opts.sections.flags
const sections = opts.sections.sections

if (!flags.system) {
if (!flags.system && marker === 'system') {
sections.system.push(result)
return
}

if (marker === 'history') {
flags.system = true
flags.history = true
return
}

sections.post.push(result)
}

function getMarker(node: PNode): Section | undefined {
if (typeof node === 'string') return
function getMarker(opts: TemplateOpts, node: PNode, previous: Section): Section {
if (!opts.sections) return previous
if (opts.sections.flags.history) return 'post'

if (typeof node === 'string') return previous

switch (node.kind) {
case 'placeholder': {
if (node.value === 'history') return 'history'
if (node.value === 'system_prompt') return 'system'
return
return previous
}

case 'each':
if (node.value === 'history') return 'history'
return
return previous

case 'if':
if (node.value === 'system_prompt') return 'system'
return
return previous
}

return
return previous
}
1 change: 1 addition & 0 deletions srv/adapter/agnaistic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,7 @@ export function getHandlers(settings: Partial<AppSchema.GenSettings>) {
return handlers[settings.thirdPartyFormat!]

case 'openai-chat':
case 'openai-chatv2':
return handlers.openai

case 'featherless':
Expand Down
4 changes: 4 additions & 0 deletions srv/adapter/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import {
import { getCachedSubscriptionModels } from '../db/subscriptions'
import { sendOne } from '../api/ws'
import { ResponseSchema } from '/common/types/library'
import { toChatMessages } from './template-chat-payload'

let version = ''

Expand Down Expand Up @@ -232,6 +233,7 @@ export async function createInferenceStream(opts: InferenceRequest) {
guidance: opts.guidance,
previous: opts.previous,
placeholders: opts.placeholders,
characters: {},
lists: opts.lists,
jsonSchema: opts.jsonSchema,
imageData: opts.imageData,
Expand Down Expand Up @@ -392,6 +394,7 @@ export async function createChatStream(
*/

const prompt = await assemblePrompt(opts, opts.parts, opts.lines, encoder)
const messages = await toChatMessages(opts, prompt, encoder)

const size = encoder(
[
Expand Down Expand Up @@ -421,6 +424,7 @@ export async function createChatStream(
log,
members: opts.members.concat(opts.sender),
prompt: prompt.prompt,
messages,
parts: prompt.parts,
sender: opts.sender,
mappedSettings,
Expand Down
24 changes: 13 additions & 11 deletions srv/adapter/openai.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import { sanitiseAndTrim } from '/common/requests/util'
import { ChatRole, CompletionItem, ModelAdapter } from './type'
import { ChatRole, ModelAdapter } from './type'
import { defaultPresets } from '../../common/presets'
import { OPENAI_CHAT_MODELS, OPENAI_MODELS } from '../../common/adapters'
import { AppSchema } from '../../common/types/schema'
import { config } from '../config'
import { AppLog } from '../middleware'
import { requestFullCompletion, toChatCompletionPayload } from './chat-completion'
import { decryptText } from '../db/util'
Expand Down Expand Up @@ -42,23 +41,26 @@ export const handleOAI: ModelAdapter = async function* (opts) {
stream: (gen.streamResponse && kind !== 'summary') ?? defaultPresets.openai.streamResponse,
temperature: gen.temp ?? defaultPresets.openai.temp,
max_tokens: maxResponseLength,
max_completion_tokens: maxResponseLength,
top_p: gen.topP ?? 1,
stop: [`\n${handle}:`].concat(gen.stopSequences!),
}

body.presence_penalty = gen.presencePenalty ?? defaultPresets.openai.presencePenalty
body.frequency_penalty = gen.frequencyPenalty ?? defaultPresets.openai.frequencyPenalty

const useChat =
(isThirdParty && gen.thirdPartyFormat === 'openai-chat') || !!OPENAI_CHAT_MODELS[oaiModel]
const isChatFormat =
gen.thirdPartyFormat === 'openai-chat' || gen.thirdPartyFormat == 'openai-chatv2'
const useChat = (isThirdParty && isChatFormat) || !!OPENAI_CHAT_MODELS[oaiModel]
if (useChat) {
const messages: CompletionItem[] = config.inference.flatChatCompletion
? [{ role: 'system', content: opts.prompt }]
: await toChatCompletionPayload(
opts,
getTokenCounter('openai', OPENAI_MODELS.Turbo),
body.max_tokens
)
const messages =
gen.thirdPartyFormat === 'openai-chatv2' && opts.messages
? opts.messages
: await toChatCompletionPayload(
opts,
getTokenCounter('openai', OPENAI_MODELS.Turbo),
body.max_tokens
)

body.messages = messages
yield { prompt: messages }
Expand Down
16 changes: 9 additions & 7 deletions srv/adapter/payloads.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,18 @@ function getBasePayload(opts: AdapterProps, stops: string[] = []) {

const json_schema = opts.jsonSchema ? toJsonSchema(opts.jsonSchema) : undefined

const characterNames = Object.values(opts.characters || {}).map((c) => c.name)
const characterNames = Object.values(opts.characters || {})
.map((c) => c.name.split(' '))
.concat(opts.members.map((m) => m.handle.split(' ')))
.flat()

const sequenceBreakers = Array.from(
new Set([
opts.char.name,
opts.replyAs.name,
...characterNames,
...opts.members.map((m) => m.handle),
]).values()
new Set(
[opts.char.name.split(' '), opts.replyAs.name.split(' '), ...characterNames].flat()
).values()
)
.concat(gen.drySequenceBreakers || [])
.flat()
.filter((t) => !!t)

if (!gen.temp) {
Expand Down
Loading

0 comments on commit 62ad0b4

Please sign in to comment.