Skip to content

Commit

Permalink
create a new context for AI Root Nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
netroy committed Dec 17, 2024
1 parent ad39243 commit c8006fb
Show file tree
Hide file tree
Showing 59 changed files with 788 additions and 369 deletions.
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import type { BaseChatMemory } from '@langchain/community/memory/chat_memory';
import type { BaseOutputParser } from '@langchain/core/output_parsers';
import { PromptTemplate } from '@langchain/core/prompts';
import { initializeAgentExecutorWithOptions } from 'langchain/agents';
import { CombiningOutputParser } from 'langchain/output_parsers';
import type { IExecuteFunctions, INodeExecutionData } from 'n8n-workflow';
import { NodeConnectionType, NodeOperationError } from 'n8n-workflow';
import { NodeOperationError } from 'n8n-workflow';

import { isChatInstance, getPromptInputByType, getConnectedTools } from '@utils/helpers';
import { getOptionalOutputParsers } from '@utils/output_parsers/N8nOutputParser';
Expand All @@ -18,15 +17,13 @@ export async function conversationalAgentExecute(
nodeVersion: number,
): Promise<INodeExecutionData[][]> {
this.logger.debug('Executing Conversational Agent');
const model = await this.getInputConnectionData(NodeConnectionType.AiLanguageModel, 0);
const model = await this.aiRootNodeContext.getModel();

if (!isChatInstance(model)) {
throw new NodeOperationError(this.getNode(), 'Conversational Agent requires Chat Model');
}

const memory = (await this.getInputConnectionData(NodeConnectionType.AiMemory, 0)) as
| BaseChatMemory
| undefined;
const memory = await this.aiRootNodeContext.getMemory();

const tools = await getConnectedTools(this, nodeVersion >= 1.5, true, true);
const outputParsers = await getOptionalOutputParsers(this);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,9 @@ import { PromptTemplate } from '@langchain/core/prompts';
import { ChatOpenAI } from '@langchain/openai';
import type { AgentExecutorInput } from 'langchain/agents';
import { AgentExecutor, OpenAIAgent } from 'langchain/agents';
import { BufferMemory, type BaseChatMemory } from 'langchain/memory';
import { BufferMemory } from 'langchain/memory';
import { CombiningOutputParser } from 'langchain/output_parsers';
import {
type IExecuteFunctions,
type INodeExecutionData,
NodeConnectionType,
NodeOperationError,
} from 'n8n-workflow';
import { type IExecuteFunctions, type INodeExecutionData, NodeOperationError } from 'n8n-workflow';

import { getConnectedTools, getPromptInputByType } from '@utils/helpers';
import { getOptionalOutputParsers } from '@utils/output_parsers/N8nOutputParser';
Expand All @@ -23,20 +18,15 @@ export async function openAiFunctionsAgentExecute(
nodeVersion: number,
): Promise<INodeExecutionData[][]> {
this.logger.debug('Executing OpenAi Functions Agent');
const model = (await this.getInputConnectionData(
NodeConnectionType.AiLanguageModel,
0,
)) as ChatOpenAI;
const model = await this.aiRootNodeContext.getModel();

if (!(model instanceof ChatOpenAI)) {
throw new NodeOperationError(
this.getNode(),
'OpenAI Functions Agent requires OpenAI Chat Model',
);
}
const memory = (await this.getInputConnectionData(NodeConnectionType.AiMemory, 0)) as
| BaseChatMemory
| undefined;
const memory = await this.aiRootNodeContext.getMemory();
const tools = await getConnectedTools(this, nodeVersion >= 1.5, false);
const outputParsers = await getOptionalOutputParsers(this);
const options = this.getNodeParameter('options', 0, {}) as {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,8 @@
import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
import type { BaseOutputParser } from '@langchain/core/output_parsers';
import { PromptTemplate } from '@langchain/core/prompts';
import { PlanAndExecuteAgentExecutor } from 'langchain/experimental/plan_and_execute';
import { CombiningOutputParser } from 'langchain/output_parsers';
import {
type IExecuteFunctions,
type INodeExecutionData,
NodeConnectionType,
NodeOperationError,
} from 'n8n-workflow';
import { type IExecuteFunctions, type INodeExecutionData, NodeOperationError } from 'n8n-workflow';

import { getConnectedTools, getPromptInputByType } from '@utils/helpers';
import { getOptionalOutputParsers } from '@utils/output_parsers/N8nOutputParser';
Expand All @@ -22,11 +16,7 @@ export async function planAndExecuteAgentExecute(
nodeVersion: number,
): Promise<INodeExecutionData[][]> {
this.logger.debug('Executing PlanAndExecute Agent');
const model = (await this.getInputConnectionData(
NodeConnectionType.AiLanguageModel,
0,
)) as BaseChatModel;

const model = await this.aiRootNodeContext.getModel();
const tools = await getConnectedTools(this, nodeVersion >= 1.5, true, true);

await checkForStructuredTools(tools, this.getNode(), 'Plan & Execute Agent');
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,8 @@
import type { BaseLanguageModel } from '@langchain/core/language_models/base';
import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
import type { BaseOutputParser } from '@langchain/core/output_parsers';
import { PromptTemplate } from '@langchain/core/prompts';
import { AgentExecutor, ChatAgent, ZeroShotAgent } from 'langchain/agents';
import { CombiningOutputParser } from 'langchain/output_parsers';
import {
type IExecuteFunctions,
type INodeExecutionData,
NodeConnectionType,
NodeOperationError,
} from 'n8n-workflow';
import { type IExecuteFunctions, type INodeExecutionData, NodeOperationError } from 'n8n-workflow';

import { getConnectedTools, getPromptInputByType, isChatInstance } from '@utils/helpers';
import { getOptionalOutputParsers } from '@utils/output_parsers/N8nOutputParser';
Expand All @@ -24,9 +17,7 @@ export async function reActAgentAgentExecute(
): Promise<INodeExecutionData[][]> {
this.logger.debug('Executing ReAct Agent');

const model = (await this.getInputConnectionData(NodeConnectionType.AiLanguageModel, 0)) as
| BaseLanguageModel
| BaseChatModel;
const model = await this.aiRootNodeContext.getModel();

const tools = await getConnectedTools(this, nodeVersion >= 1.5, true, true);

Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import type { BaseChatMemory } from '@langchain/community/memory/chat_memory';
import type { BaseLanguageModel } from '@langchain/core/language_models/base';
import type { DataSource } from '@n8n/typeorm';
import type { SqlCreatePromptArgs } from 'langchain/agents/toolkits/sql';
import { SqlToolkit, createSqlAgent } from 'langchain/agents/toolkits/sql';
import { SqlDatabase } from 'langchain/sql_db';
import {
type IExecuteFunctions,
type INodeExecutionData,
NodeConnectionType,
NodeOperationError,
type IDataObject,
} from 'n8n-workflow';
Expand All @@ -31,10 +28,8 @@ export async function sqlAgentAgentExecute(
): Promise<INodeExecutionData[][]> {
this.logger.debug('Executing SQL Agent');

const model = (await this.getInputConnectionData(
NodeConnectionType.AiLanguageModel,
0,
)) as BaseLanguageModel;
const model = await this.aiRootNodeContext.getModel();

const items = this.getInputData();

const returnData: INodeExecutionData[] = [];
Expand Down Expand Up @@ -113,9 +108,7 @@ export async function sqlAgentAgentExecute(
const toolkit = new SqlToolkit(dbInstance, model);
const agentExecutor = createSqlAgent(model, toolkit, agentOptions);

const memory = (await this.getInputConnectionData(NodeConnectionType.AiMemory, 0)) as
| BaseChatMemory
| undefined;
const memory = await this.aiRootNodeContext.getMemory();

agentExecutor.memory = memory;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import type { BaseChatMemory } from '@langchain/community/memory/chat_memory';
import { HumanMessage } from '@langchain/core/messages';
import type { BaseMessage } from '@langchain/core/messages';
import type { BaseMessagePromptTemplateLike } from '@langchain/core/prompts';
Expand All @@ -9,9 +8,8 @@ import { DynamicStructuredTool } from '@langchain/core/tools';
import type { AgentAction, AgentFinish } from 'langchain/agents';
import { AgentExecutor, createToolCallingAgent } from 'langchain/agents';
import { omit } from 'lodash';
import { BINARY_ENCODING, jsonParse, NodeConnectionType, NodeOperationError } from 'n8n-workflow';
import type { IExecuteFunctions, INodeExecutionData } from 'n8n-workflow';
import type { ZodObject } from 'zod';
import { BINARY_ENCODING, jsonParse, NodeOperationError } from 'n8n-workflow';
import type { IExecuteFunctions, INodeExecutionData, ZodObjectAny } from 'n8n-workflow';
import { z } from 'zod';

import { isChatInstance, getPromptInputByType, getConnectedTools } from '@utils/helpers';
Expand All @@ -22,9 +20,8 @@ import {

import { SYSTEM_MESSAGE } from './prompt';

function getOutputParserSchema(outputParser: N8nOutputParser): ZodObject<any, any, any, any> {
const schema =
(outputParser.getSchema() as ZodObject<any, any, any, any>) ?? z.object({ text: z.string() });
function getOutputParserSchema(outputParser: N8nOutputParser): ZodObjectAny {
const schema = (outputParser.getSchema() as ZodObjectAny) ?? z.object({ text: z.string() });

return schema;
}
Expand Down Expand Up @@ -98,7 +95,7 @@ function fixEmptyContentMessage(steps: AgentFinish | AgentAction[]) {

export async function toolsAgentExecute(this: IExecuteFunctions): Promise<INodeExecutionData[][]> {
this.logger.debug('Executing Tools Agent');
const model = await this.getInputConnectionData(NodeConnectionType.AiLanguageModel, 0);
const model = await this.aiRootNodeContext.getModel();

if (!isChatInstance(model) || !model.bindTools) {
throw new NodeOperationError(
Expand All @@ -107,9 +104,7 @@ export async function toolsAgentExecute(this: IExecuteFunctions): Promise<INodeE
);
}

const memory = (await this.getInputConnectionData(NodeConnectionType.AiMemory, 0)) as
| BaseChatMemory
| undefined;
const memory = await this.aiRootNodeContext.getMemory();

const tools = (await getConnectedTools(this, true, false)) as Array<DynamicStructuredTool | Tool>;
const outputParser = (await getOptionalOutputParsers(this))?.[0];
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import type { BaseOutputParser } from '@langchain/core/output_parsers';
import type { DynamicStructuredTool, Tool } from 'langchain/tools';
import { NodeOperationError, type IExecuteFunctions, type INode } from 'n8n-workflow';
import type { z } from 'zod';

type ZodObjectAny = z.ZodObject<any, any, any, any>;
import type { IExecuteFunctions, INode, ZodObjectAny } from 'n8n-workflow';
import { NodeOperationError } from 'n8n-workflow';

export async function extractParsedOutput(
ctx: IExecuteFunctions,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,8 @@ async function getImageMessage(
}

const bufferData = await context.helpers.getBinaryDataBuffer(itemIndex, binaryDataKey);
const model = (await context.getInputConnectionData(
NodeConnectionType.AiLanguageModel,
0,
)) as BaseLanguageModel;

const model = await context.aiRootNodeContext.getModel();
const dataURI = `data:image/jpeg;base64,${bufferData.toString('base64')}`;

const directUriModels = [ChatGoogleGenerativeAI, ChatOllama];
Expand All @@ -108,7 +106,7 @@ async function getImageMessage(
async function getChainPromptTemplate(
context: IExecuteFunctions,
itemIndex: number,
llm: BaseLanguageModel | BaseChatModel,
model: BaseLanguageModel | BaseChatModel,
messages?: MessagesTemplate[],
formatInstructions?: string,
query?: string,
Expand All @@ -119,7 +117,7 @@ async function getChainPromptTemplate(
partialVariables: formatInstructions ? { formatInstructions } : undefined,
});

if (isChatInstance(llm)) {
if (isChatInstance(model)) {
const parsedMessages = await Promise.all(
(messages ?? []).map(async (message) => {
const messageClass = [
Expand Down Expand Up @@ -166,12 +164,12 @@ async function getChainPromptTemplate(

async function createSimpleLLMChain(
context: IExecuteFunctions,
llm: BaseLanguageModel,
model: BaseLanguageModel,
query: string,
prompt: ChatPromptTemplate | PromptTemplate,
): Promise<string[]> {
const chain = new LLMChain({
llm,
llm: model,
prompt,
}).withConfig(getTracingConfig(context));

Expand All @@ -187,22 +185,22 @@ async function getChain(
context: IExecuteFunctions,
itemIndex: number,
query: string,
llm: BaseLanguageModel,
model: BaseLanguageModel,
outputParsers: N8nOutputParser[],
messages?: MessagesTemplate[],
): Promise<unknown[]> {
const chatTemplate: ChatPromptTemplate | PromptTemplate = await getChainPromptTemplate(
context,
itemIndex,
llm,
model,
messages,
undefined,
query,
);

// If there are no output parsers, create a simple LLM chain and execute the query
if (!outputParsers.length) {
return await createSimpleLLMChain(context, llm, query, chatTemplate);
return await createSimpleLLMChain(context, model, query, chatTemplate);
}

// If there's only one output parser, use it; otherwise, create a combined output parser
Expand All @@ -215,13 +213,13 @@ async function getChain(
const prompt = await getChainPromptTemplate(
context,
itemIndex,
llm,
model,
messages,
formatInstructions,
query,
);

const chain = prompt.pipe(llm).pipe(combinedOutputParser);
const chain = prompt.pipe(model).pipe(combinedOutputParser);
const response = (await chain.withConfig(getTracingConfig(context)).invoke({ query })) as
| string
| string[];
Expand Down Expand Up @@ -515,10 +513,8 @@ export class ChainLlm implements INodeType {
const items = this.getInputData();

const returnData: INodeExecutionData[] = [];
const llm = (await this.getInputConnectionData(
NodeConnectionType.AiLanguageModel,
0,
)) as BaseLanguageModel;

const model = await this.aiRootNodeContext.getModel();

const outputParsers = await getOptionalOutputParsers(this);

Expand All @@ -545,7 +541,7 @@ export class ChainLlm implements INodeType {
throw new NodeOperationError(this.getNode(), "The 'prompt' parameter is empty.");
}

const responses = await getChain(this, itemIndex, prompt, llm, outputParsers, messages);
const responses = await getChain(this, itemIndex, prompt, model, outputParsers, messages);

responses.forEach((response) => {
let data: IDataObject;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import type { BaseLanguageModel } from '@langchain/core/language_models/base';
import {
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
PromptTemplate,
} from '@langchain/core/prompts';
import type { BaseRetriever } from '@langchain/core/retrievers';
import { RetrievalQAChain } from 'langchain/chains';
import {
NodeConnectionType,
Expand Down Expand Up @@ -161,15 +159,8 @@ export class ChainRetrievalQa implements INodeType {
async execute(this: IExecuteFunctions): Promise<INodeExecutionData[][]> {
this.logger.debug('Executing Retrieval QA Chain');

const model = (await this.getInputConnectionData(
NodeConnectionType.AiLanguageModel,
0,
)) as BaseLanguageModel;

const retriever = (await this.getInputConnectionData(
NodeConnectionType.AiRetriever,
0,
)) as BaseRetriever;
const model = await this.aiRootNodeContext.getModel();
const retriever = await this.aiRootNodeContext.getRetriever();

const items = this.getInputData();

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import type { Document } from '@langchain/core/documents';
import type { BaseLanguageModel } from '@langchain/core/language_models/base';
import { PromptTemplate } from '@langchain/core/prompts';
import type { SummarizationChainParams } from 'langchain/chains';
import { loadSummarizationChain } from 'langchain/chains';
Expand Down Expand Up @@ -166,14 +165,11 @@ export class ChainSummarizationV1 implements INodeType {
this.logger.debug('Executing Vector Store QA Chain');
const type = this.getNodeParameter('type', 0) as 'map_reduce' | 'stuff' | 'refine';

const model = (await this.getInputConnectionData(
NodeConnectionType.AiLanguageModel,
0,
)) as BaseLanguageModel;
const model = await this.aiRootNodeContext.getModel();

const documentInput = (await this.getInputConnectionData(NodeConnectionType.AiDocument, 0)) as
| N8nJsonLoader
| Array<Document<Record<string, unknown>>>;
const documentInput = await this.aiRootNodeContext.getDocument<
N8nJsonLoader | Array<Document<Record<string, unknown>>>
>();

const options = this.getNodeParameter('options', 0, {}) as {
prompt?: string;
Expand Down
Loading

0 comments on commit c8006fb

Please sign in to comment.