Skip to content

Commit

Permalink
feat(agent): added example of agent
Browse files Browse the repository at this point in the history
  • Loading branch information
sebastianmusial committed Dec 28, 2023
1 parent 86be28f commit b61dfe9
Show file tree
Hide file tree
Showing 23 changed files with 260 additions and 49 deletions.
1 change: 1 addition & 0 deletions .env.dist
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
OPENAI_API_KEY=
ASSISTANT_ID=
POKEMON_API_URL=
4 changes: 4 additions & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,14 @@
"test:e2e": "jest --config ./test/jest-e2e.json"
},
"dependencies": {
"@nestjs/axios": "^3.0.1",
"@nestjs/common": "^10.0.0",
"@nestjs/config": "^3.1.1",
"@nestjs/core": "^10.0.0",
"@nestjs/platform-express": "^10.0.0",
"axios": "^1.6.3",
"class-transformer": "^0.5.1",
"class-validator": "^0.14.0",
"envfile": "^7.0.0",
"openai": "^4.20.0",
"reflect-metadata": "^0.1.13",
Expand Down
2 changes: 0 additions & 2 deletions src/assistant/agent.model.ts

This file was deleted.

15 changes: 0 additions & 15 deletions src/assistant/agent.service.ts

This file was deleted.

19 changes: 19 additions & 0 deletions src/assistant/agent/agent.base.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import { OnModuleInit } from '@nestjs/common';
import { AssistantCreateParams } from 'openai/resources/beta';
import { AgentService } from './agent.service';
import { AgentData } from './agent.model';

export class AgentBase implements OnModuleInit {
definition: AssistantCreateParams.AssistantToolsFunction;

onModuleInit(): void {
this.agentService.add(this.definition, this.output.bind(this));
}

constructor(protected readonly agentService: AgentService) {}

// eslint-disable-next-line @typescript-eslint/no-unused-vars
async output(data: AgentData): Promise<string> {
return '';
}
}
7 changes: 7 additions & 0 deletions src/assistant/agent/agent.model.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
export type Agent = (data: AgentData) => Promise<string>;
export type Agents = Record<string, Agent>;

export interface AgentData {
threadId: string;
params: string;
}
8 changes: 8 additions & 0 deletions src/assistant/agent/agent.module.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import { Module } from '@nestjs/common';
import { AgentService } from './agent.service';

@Module({
providers: [AgentService],
exports: [AgentService],
})
export class AgentModule {}
21 changes: 21 additions & 0 deletions src/assistant/agent/agent.service.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import { Injectable } from '@nestjs/common';
import { Agent, Agents } from './agent.model';
import { AssistantCreateParams } from 'openai/resources/beta';

@Injectable()
export class AgentService {
public agents: Agents = {};
public tools: AssistantCreateParams.AssistantToolsFunction[] = [];

add(
definition: AssistantCreateParams.AssistantToolsFunction,
fn: Agent,
): void {
this.tools.push(definition);
this.agents[definition.function.name] = fn;
}

get(name: string): Agent {
return this.agents[name];
}
}
File renamed without changes.
2 changes: 1 addition & 1 deletion src/assistant/assistant-files.service.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { Inject, Injectable } from '@nestjs/common';
import { FileObject } from 'openai/resources';
import { createReadStream } from 'fs';
import { AiService } from './ai.service';
import { AiService } from './ai/ai.service';
import { AssistantConfig } from './assistant.model';

@Injectable()
Expand Down
3 changes: 3 additions & 0 deletions src/assistant/assistant-memory.service.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { Injectable, Logger } from '@nestjs/common';
import { writeFile, readFile } from 'fs/promises';
import * as envfile from 'envfile';
import * as process from 'process';

@Injectable()
export class AssistantMemoryService {
Expand All @@ -16,6 +17,8 @@ export class AssistantMemoryService {
ASSISTANT_ID: id,
};

process.env.ASSISTANT_ID = id;

await writeFile(sourcePath, envfile.stringify(newVariables));
} catch (error) {
this.logger.error(`Can't save variable: ${error}`);
Expand Down
10 changes: 5 additions & 5 deletions src/assistant/assistant.module.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import { DynamicModule, Module, OnModuleInit } from '@nestjs/common';
import { AssistantService } from './assistant.service';
import { ChatbotService } from './chatbot.service';
import { AiService } from './ai.service';
import { RunService } from './run.service';
import { AgentService } from './agent.service';
import { ChatbotService } from './chatbot/chatbot.service';
import { AiService } from './ai/ai.service';
import { RunService } from './run/run.service';
import { AssistantConfig } from './assistant.model';
import { AssistantFilesService } from './assistant-files.service';
import { AssistantMemoryService } from './assistant-memory.service';
import { AgentModule } from './agent/agent.module';

const sharedServices = [
AiService,
Expand All @@ -15,10 +15,10 @@ const sharedServices = [
AssistantMemoryService,
ChatbotService,
RunService,
AgentService,
];

@Module({
imports: [AgentModule],
providers: [...sharedServices],
exports: [...sharedServices],
})
Expand Down
28 changes: 20 additions & 8 deletions src/assistant/assistant.service.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import { Inject, Injectable, Logger } from '@nestjs/common';
import { Assistant, AssistantCreateParams } from 'openai/resources/beta';
import { AiService } from './ai.service';
import { AiService } from './ai/ai.service';
import { AssistantConfig } from './assistant.model';
import { AssistantFilesService } from './assistant-files.service';
import { AssistantMemoryService } from './assistant-memory.service';
import { AgentService } from './agent/agent.service';

@Injectable()
export class AssistantService {
Expand All @@ -16,30 +17,41 @@ export class AssistantService {
private readonly aiService: AiService,
private readonly assistantFilesService: AssistantFilesService,
private readonly assistantMemoryService: AssistantMemoryService,
private readonly agentService: AgentService,
) {}

getParams(): AssistantCreateParams {
return {
...this.config.params,
tools: [...(this.config.params.tools || []), ...this.agentService.tools],
};
}

async init(): Promise<void> {
const { id, params, options } = this.config;
const { id, options } = this.config;

if (!id) {
await this.create();
return await this.create();
}

try {
this.assistant = await this.assistants.update(id, params, options);
this.assistant = await this.assistants.update(
id,
this.getParams(),
options,
);
} catch (e) {
await this.create();
}
}

async update(params: Partial<AssistantCreateParams>): Promise<void> {
this.assistant = await this.assistants.update(this.assistant.id, {
...params,
});
this.assistant = await this.assistants.update(this.assistant.id, params);
}

async create(): Promise<void> {
const { params, options } = this.config;
const { options } = this.config;
const params = this.getParams();
this.assistant = await this.assistants.create(params, options);

if (this.config.files?.length) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import { Injectable } from '@nestjs/common';
import { AiService } from './ai.service';
import { AssistantService } from './assistant.service';
import { AiService } from '../ai/ai.service';
import { AssistantService } from '../assistant.service';
import {
MessageContentText,
MessageCreateParams,
Run,
ThreadMessage,
} from 'openai/resources/beta/threads';
import { RunService } from './run.service';
import { RunService } from '../run/run.service';

@Injectable()
export class ChatbotService {
Expand Down
31 changes: 18 additions & 13 deletions src/assistant/run.service.ts → src/assistant/run/run.service.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { Injectable } from '@nestjs/common';
import { Run, RunSubmitToolOutputsParams } from 'openai/resources/beta/threads';
import { AiService } from './ai.service';
import { AgentService } from './agent.service';
import { AiService } from '../ai/ai.service';
import { AgentService } from '../agent/agent.service';

@Injectable()
export class RunService {
Expand All @@ -13,6 +13,11 @@ export class RunService {
private readonly agentsService: AgentService,
) {}

async continueRun(run: Run): Promise<Run> {
await new Promise(resolve => setTimeout(resolve, this.timeout));
return this.threads.runs.retrieve(run.thread_id, run.id);
}

async resolve(run: Run): Promise<void> {
while (true)
switch (run.status) {
Expand All @@ -24,10 +29,10 @@ export class RunService {
return;
case 'requires_action':
await this.submitAction(run);
run = await this.continueRun(run);
continue;
default:
await new Promise(resolve => setTimeout(resolve, this.timeout));
run = await this.threads.runs.retrieve(run.thread_id, run.id);
run = await this.continueRun(run);
}
}

Expand All @@ -37,15 +42,15 @@ export class RunService {
}

const toolCalls = run.required_action.submit_tool_outputs.tool_calls || [];
const outputs: RunSubmitToolOutputsParams.ToolOutput[] = [];

for (const toolCall of toolCalls) {
const { name, arguments: arg } = toolCall.function;
const agent = this.agentsService.get(name);
const output = await agent(arg);

outputs.push({ tool_call_id: toolCall.id, output });
}
const outputs: RunSubmitToolOutputsParams.ToolOutput[] = await Promise.all(
toolCalls.map(async toolCall => {
const { name, arguments: params } = toolCall.function;
const agent = this.agentsService.get(name);
const output = await agent({ params, threadId: run.thread_id });

return { tool_call_id: toolCall.id, output };
}),
);

await this.threads.runs.submitToolOutputs(run.thread_id, run.id, {
tool_outputs: outputs,
Expand Down
7 changes: 7 additions & 0 deletions src/chat/agents/agents.module.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import { Module } from '@nestjs/common';
import { PokemonModule } from './pokemon/pokemon.module';

@Module({
imports: [PokemonModule],
})
export class AgentsModule {}
46 changes: 46 additions & 0 deletions src/chat/agents/pokemon/get-pokemon.agent.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import { Injectable } from '@nestjs/common';
import { AssistantCreateParams } from 'openai/resources/beta';
import { AgentData } from '../../../assistant/agent/agent.model';
import { GetPokemonParamsDto } from './get-pokemon.model';
import { AgentService } from '../../../assistant/agent/agent.service';
import { PokemonService } from './pokemon.service';
import { AgentBase } from '../../../assistant/agent/agent.base';

@Injectable()
export class GetPokemonAgent extends AgentBase {
definition: AssistantCreateParams.AssistantToolsFunction = {
type: 'function',
function: {
name: 'getPokemon',
description: 'Get pokemon stats and types',
parameters: {
type: 'object',
properties: {
name: {
type: 'string',
description: 'The name of the pokemon provided by user',
},
},
required: ['name'],
},
},
};

constructor(
protected readonly agentService: AgentService,
private readonly pokemonService: PokemonService,
) {
super(agentService);
}

async output(data: AgentData): Promise<string> {
try {
const parsedData = JSON.parse(data?.params) as GetPokemonParamsDto;
const pokemon = await this.pokemonService.getPokemon(parsedData?.name);

return JSON.stringify(pokemon);
} catch (errors) {
return `Invalid data: ${JSON.stringify(errors)}`;
}
}
}
6 changes: 6 additions & 0 deletions src/chat/agents/pokemon/get-pokemon.model.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import { IsNotEmpty } from 'class-validator';

export class GetPokemonParamsDto {
@IsNotEmpty()
name: string;
}
12 changes: 12 additions & 0 deletions src/chat/agents/pokemon/pokemon.module.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import { Module } from '@nestjs/common';
import { HttpModule } from '@nestjs/axios';
import { ConfigModule } from '@nestjs/config';
import { PokemonService } from './pokemon.service';
import { GetPokemonAgent } from './get-pokemon.agent';
import { AgentModule } from '../../../assistant/agent/agent.module';

@Module({
imports: [ConfigModule, HttpModule, AgentModule],
providers: [PokemonService, GetPokemonAgent],
})
export class PokemonModule {}
22 changes: 22 additions & 0 deletions src/chat/agents/pokemon/pokemon.service.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import { Injectable } from '@nestjs/common';
import { HttpService } from '@nestjs/axios';
import { firstValueFrom, map } from 'rxjs';
import { ConfigService } from '@nestjs/config';

@Injectable()
export class PokemonService {
apiUrl = this.configService.get('POKEMON_API_URL');

constructor(
private readonly httpService: HttpService,
private readonly configService: ConfigService,
) {}

async getPokemon(name: string): Promise<string> {
return firstValueFrom(
this.httpService
.get(`${this.apiUrl}/pokemon/${name.toLowerCase()}`)
.pipe(map(res => res.data)),
);
}
}
Loading

0 comments on commit b61dfe9

Please sign in to comment.