Skip to content

Commit

Permalink
feat: ThreadMetadataRuntimeCore (#1135)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yonom authored Nov 5, 2024
1 parent d6dd541 commit f410c06
Show file tree
Hide file tree
Showing 9 changed files with 211 additions and 69 deletions.
1 change: 0 additions & 1 deletion packages/react/src/api/ThreadListItemRuntime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ export class ThreadListItemRuntimeImpl implements ThreadListItemRuntime {

public switchTo(): Promise<void> {
const state = this._core.getState();

return this._threadListBinding.switchToThread(state.threadId);
}

Expand Down
4 changes: 3 additions & 1 deletion packages/react/src/api/ThreadListRuntime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ const getThreadListItemState = (
const threadData = threadList.getThreadMetadataById(threadId);
if (!threadData) return SKIP_UPDATE;
return {
...threadData,
threadId: threadData.threadId,
title: threadData.title,
state: threadData.state,
isMain: threadList.mainThread.metadata.threadId === threadId,
};
};
Expand Down
16 changes: 5 additions & 11 deletions packages/react/src/runtimes/core/BaseThreadRuntimeCore.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import {
RuntimeCapabilities,
SubmittedFeedback,
ThreadRuntimeEventType,
ThreadMetadata,
ThreadMetadataRuntimeCore,
} from "../core/ThreadRuntimeCore";
import { DefaultEditComposerRuntimeCore } from "../composer/DefaultEditComposerRuntimeCore";
import { SpeechSynthesisAdapter } from "../speech/SpeechAdapterTypes";
Expand Down Expand Up @@ -57,18 +57,12 @@ export abstract class BaseThreadRuntimeCore implements ThreadRuntimeCore {
public readonly composer = new DefaultThreadComposerRuntimeCore(this);

constructor(
private configProvider: ModelConfigProvider,
private _metadata: ThreadMetadata,
private readonly _configProvider: ModelConfigProvider,
private readonly _metadata: ThreadMetadataRuntimeCore,
) {}

public getModelConfig() {
return this.configProvider.getModelConfig();
}

public updateMetadata(metadata: Partial<ThreadMetadata>) {
this._metadata = { ...this._metadata, ...metadata };
this._notifyEventSubscribers("metadata-update");
this._notifySubscribers();
return this._configProvider.getModelConfig();
}

private _editComposers = new Map<string, DefaultEditComposerRuntimeCore>();
Expand Down Expand Up @@ -188,7 +182,7 @@ export abstract class BaseThreadRuntimeCore implements ThreadRuntimeCore {

public unstable_on(event: ThreadRuntimeEventType, callback: () => void) {
if (event === "model-config-update") {
return this.configProvider.subscribe?.(callback) ?? (() => {});
return this._configProvider.subscribe?.(callback) ?? (() => {});
}

const subscribers = this._eventSubscribers.get(event);
Expand Down
19 changes: 13 additions & 6 deletions packages/react/src/runtimes/core/ThreadRuntimeCore.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -46,20 +46,27 @@ export type SubmittedFeedback = Readonly<{
export type ThreadMetadata = Readonly<{
threadId: string;
state: "archived" | "regular" | "new" | "deleted";
title?: string;
title?: string | undefined;
}>;

export type ThreadRuntimeEventType =
| "switched-to"
| "switched-away"
| "run-start"
| "model-config-update"
| "metadata-update";
| "model-config-update";

export type ThreadMetadataRuntimeCore = ThreadMetadata &
Readonly<{
create(title?: string): Promise<void>;
rename(newTitle: string): Promise<void>;
archive(): Promise<void>;
unarchive(): Promise<void>;
delete(): Promise<void>;
subscribe(callback: () => void): Unsubscribe;
}>;

export type ThreadRuntimeCore = Readonly<{
metadata: ThreadMetadata;

updateMetadata: (metadata: Partial<ThreadMetadata>) => void;
metadata: ThreadMetadataRuntimeCore;

getMessageById: (messageId: string) =>
| {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,19 @@ const DEFAULT_THREAD_ID = "DEFAULT_THREAD_ID";
export class ExternalStoreThreadListRuntimeCore
implements ThreadListRuntimeCore
{
private _threads: readonly string[] = [];
private _archivedThreads: readonly string[] = [];

public get newThread() {
return undefined;
}

public get threads() {
return this.adapter.threads?.map((t) => t.threadId) ?? EMPTY_ARRAY;
return this._threads;
}

public get archivedThreads() {
return this.adapter.archivedThreads?.map((t) => t.threadId) ?? EMPTY_ARRAY;
return this._archivedThreads;
}

private _mainThread: ExternalStoreThreadRuntimeCore;
Expand Down Expand Up @@ -69,17 +72,42 @@ export class ExternalStoreThreadListRuntimeCore
return;
}

if (previousAdapter.threads !== newThreads) {
this._threads =
this.adapter.threads?.map((t) => t.threadId) ?? EMPTY_ARRAY;
}

if (previousAdapter.archivedThreads !== newArchivedThreads) {
this._archivedThreads =
this.adapter.archivedThreads?.map((t) => t.threadId) ?? EMPTY_ARRAY;
}

if (previousAdapter.threadId !== newThreadId) {
this._mainThread._notifyEventSubscribers("switched-away");
this._mainThread = this.threadFactory(newThreadId);
this._mainThread._notifyEventSubscribers("switched-to");
}

const previousMainState = this._mainThread.metadata.state;
const mainState = this.archivedThreads.includes(
this._mainThread.metadata.threadId,
)
? "archived"
: "regular";

if (previousMainState !== mainState) {
if (mainState === "archived") {
this._mainThread.metadata.archive();
} else {
this._mainThread.metadata.unarchive();
}
}

this._notifySubscribers();
}

public async switchToThread(threadId: string): Promise<void> {
if (this._mainThread?.threadId === threadId) return;
if (this._mainThread?.metadata.threadId === threadId) return;
const onSwitchToThread = this.adapter.onSwitchToThread;
if (!onSwitchToThread)
throw new Error(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import {
ThreadRuntimeCore,
} from "../core/ThreadRuntimeCore";
import { BaseThreadRuntimeCore } from "../core/BaseThreadRuntimeCore";
import { LocalThreadMetadataRuntimeCore } from "../local/LocalThreadMetadataRuntimeCore";

const EMPTY_ARRAY = Object.freeze([]);

Expand Down Expand Up @@ -49,7 +50,6 @@ export class ExternalStoreThreadRuntimeCore
return this._capabilities;
}

public threadId!: string;
private _messages!: ThreadMessage[];
public isDisabled!: boolean;

Expand Down Expand Up @@ -80,8 +80,9 @@ export class ExternalStoreThreadRuntimeCore
threadId: string,
store: ExternalStoreAdapter<any>,
) {
super(configProvider, { threadId, state: "new" });
this.threadId = threadId;
const metadata = new LocalThreadMetadataRuntimeCore(threadId);
metadata.create();
super(configProvider, metadata);
this.setStore(store);
}

Expand Down
115 changes: 73 additions & 42 deletions packages/react/src/runtimes/local/LocalThreadListRuntimeCore.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ import { generateId } from "../../utils/idUtils";
import { LocalThreadRuntimeCore } from "./LocalThreadRuntimeCore";
import { ThreadMetadata } from "../core/ThreadRuntimeCore";

export type ThreadListAdapter = {
subscribe(callback: () => void): Unsubscribe;
};

export type LocalThreadData = {
runtime: LocalThreadRuntimeCore;
lastState: ThreadMetadata["state"];
Expand Down Expand Up @@ -67,14 +71,15 @@ export class LocalThreadListRuntimeCore implements ThreadListRuntimeCore {

public switchToNewThread(): Promise<void> {
if (this._newThread === undefined) {
let threadId;
let threadId: string;
do {
threadId = generateId();
} while (this._threadData.has(threadId));

const runtime = this._threadFactory(threadId, { messages: [] });
const dispose = runtime.unstable_on("metadata-update", () => {
this._syncState(threadId, runtime.metadata.state);
const dispose = runtime.metadata.subscribe(() => {
this._syncState(threadId, runtime.metadata);
threadId = runtime.metadata.threadId;
});
this._threadData.set(threadId, { runtime, lastState: "new", dispose });
this._newThread = threadId;
Expand All @@ -85,48 +90,69 @@ export class LocalThreadListRuntimeCore implements ThreadListRuntimeCore {
}

private async _syncState(
threadId: string,
state: "archived" | "regular" | "new" | "deleted",
lastThreadId: string,
{ state, threadId }: ThreadMetadata,
) {
const data = this._threadData.get(threadId);
const data = this._threadData.get(lastThreadId);
if (!data) throw new Error("Thread not found");
if (data.lastState === state) return;
const lastState = data.lastState;
if (lastState === state && lastThreadId === threadId) return;

if (state === "archived") {
this._archivedThreads = [
...this._archivedThreads,
data.runtime.metadata.threadId,
];
}
if (state === "regular") {
this._threads = [...this._threads, data.runtime.metadata.threadId];
}
if (state === "deleted") {
data.dispose();
this._threadData.delete(threadId);
}
if (state === "new") {
if (this._newThread) {
this.delete(this._newThread);
if (lastThreadId !== threadId) {
this._threadData.delete(lastThreadId);
if (lastState === "new") {
this._newThread = threadId;
}
if (lastState === "regular") {
this._threads = this._threads.map((t) =>
t === lastThreadId ? threadId : t,
);
}
if (lastState === "archived") {
this._archivedThreads = this._archivedThreads.map((t) =>
t === lastThreadId ? threadId : t,
);
}
this._newThread = threadId;
}

if (data.lastState === "regular") {
this._threads = this._threads.filter((t) => t !== threadId);
}
if (lastState !== state) {
if (lastState === "new") {
this._newThread = undefined;
}

if (data.lastState === "archived") {
this._archivedThreads = this._archivedThreads.filter(
(t) => t !== threadId,
);
}
if (lastState === "regular") {
this._threads = this._threads.filter((t) => t !== threadId);
}

if (data.lastState === "new") {
this._newThread = undefined;
if (lastState === "archived") {
this._archivedThreads = this._archivedThreads.filter(
(t) => t !== threadId,
);
}

if (state === "new") {
if (this._newThread) {
this.delete(this._newThread);
}
this._newThread = threadId;
}
if (state === "archived") {
this._archivedThreads = [
...this._archivedThreads,
data.runtime.metadata.threadId,
];
}
if (state === "regular") {
this._threads = [...this._threads, data.runtime.metadata.threadId];
}
if (state === "deleted") {
data.dispose();
this._threadData.delete(threadId);
}

data.lastState = state;
}

data.lastState = state;
this._notifySubscribers();

if (
Expand All @@ -142,37 +168,42 @@ export class LocalThreadListRuntimeCore implements ThreadListRuntimeCore {
}
}

public async rename(threadId: string, newTitle: string): Promise<void> {
public rename(threadId: string, newTitle: string): Promise<void> {
const data = this._threadData.get(threadId);
if (!data) throw new Error("Thread not found");
data.runtime.metadata.rename(newTitle);

data.runtime.updateMetadata({ title: newTitle });
return Promise.resolve();
}

public async archive(threadId: string): Promise<void> {
public archive(threadId: string): Promise<void> {
const data = this._threadData.get(threadId);
if (!data) throw new Error("Thread not found");
if (data.lastState !== "regular")
throw new Error("Thread is not yet created or archived");
data.runtime.updateMetadata({ state: "archived" });
data.runtime.metadata.archive();

return Promise.resolve();
}

public unarchive(threadId: string): Promise<void> {
const data = this._threadData.get(threadId);
if (!data) throw new Error("Thread not found");
if (data.lastState !== "archived")
throw new Error("Thread is not archived");
data.runtime.updateMetadata({ state: "regular" });
data.runtime.metadata.unarchive();

return Promise.resolve();
}

public async delete(threadId: string): Promise<void> {
public delete(threadId: string): Promise<void> {
const data = this._threadData.get(threadId);
if (!data) throw new Error("Thread not found");
if (data.lastState !== "regular" && data.lastState !== "archived")
throw new Error("Thread is not yet created or already deleted");
data.runtime.updateMetadata({ state: "deleted" });
data.runtime.metadata.delete();

return Promise.resolve();
}

private _subscriptions = new Set<() => void>();
Expand Down
Loading

0 comments on commit f410c06

Please sign in to comment.