Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ai): add message renderer #5873

Merged
merged 7 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions .changeset/heavy-dots-applaud.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
---
"@aws-amplify/ui-react-ai": minor
---

feat(ai): add message renderer

```tsx
<AIConversation
messages={messages}
handleSendMessage={sendMessage}
isLoading={isLoading}
messageRenderer={{
text: (message) => <ReactMarkdown>{message}</ReactMarkdown>,
}}
/>
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
import amplifyOutputs from '@environments/ai/gen2/amplify_outputs';
export default amplifyOutputs;
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import { Amplify } from 'aws-amplify';
import { createAIHooks, AIConversation } from '@aws-amplify/ui-react-ai';
import { generateClient } from 'aws-amplify/api';
import '@aws-amplify/ui-react/styles.css';
import '@aws-amplify/ui-react-ai/ai-conversation-styles.css';

import outputs from './amplify_outputs';
import type { Schema } from '@environments/ai/gen2/amplify/data/resource';
import { Authenticator, Card, Text } from '@aws-amplify/ui-react';
import Image from 'next/image';

const client = generateClient<Schema>({ authMode: 'userPool' });
const { useAIConversation } = createAIHooks(client);

Amplify.configure(outputs);

function arrayBufferToBase64(buffer: ArrayBuffer) {
let binary = '';
const bytes = new Uint8Array(buffer);
const len = bytes.byteLength;
for (let i = 0; i < len; i++) {
binary += String.fromCharCode(bytes[i]);
}
return window.btoa(binary);
}

function convertBufferToBase64(buffer: ArrayBuffer, format: string): string {
let base64string = '';
// Use node-based buffer if available
// fall back on browser if not
if (typeof Buffer !== 'undefined') {
base64string = Buffer.from(new Uint8Array(buffer)).toString('base64');
} else {
base64string = arrayBufferToBase64(buffer);
}
return `data:image/${format};base64,${base64string}`;
}

function Chat() {
const [
{
data: { messages },
isLoading,
},
sendMessage,
] = useAIConversation('pirateChat');

return (
<Card variation="outlined" width="50%" height="300px" margin="0 auto">
<AIConversation
messages={messages}
handleSendMessage={sendMessage}
isLoading={isLoading}
allowAttachments
messageRenderer={{
text: (message) => <Text className="testing">{message}</Text>,
image: (image) => (
<Image
className="testing"
width={200}
height={200}
src={convertBufferToBase64(image.source.bytes, image.format)}
alt=""
/>
),
}}
suggestedPrompts={[
{
inputText: 'hello',
header: 'hello',
},
{
inputText: 'how are you?',
header: 'how are you?',
},
]}
variant="bubble"
/>
</Card>
);
}

export default function Example() {
return (
<Authenticator>
<Chat />
</Authenticator>
);
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ function AIConversationBase({
isLoading,
displayText,
allowAttachments,
messageRenderer,
}: AIConversationBaseProps): JSX.Element {
useSetUserAgent({
componentName: 'AIConversation',
Expand Down Expand Up @@ -78,6 +79,7 @@ function AIConversationBase({
},
displayText,
allowAttachments,
messageRenderer,
};

return (
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import { createContextUtilities } from '@aws-amplify/ui-react-core';
import { MessageRenderer } from '../types';

export const {
MessageRendererContext,
MessageRendererProvider,
useMessageRenderer,
} = createContextUtilities<MessageRenderer>({
contextName: 'MessageRenderer',
defaultValue: undefined,
errorMessage:
'`useMessageRenderer` must be used with an AIConversation component',
});
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,11 @@ export {
RESPONSE_COMPONENT_PREFIX,
} from './ResponseComponentsContext';
export { SendMessageContextProvider } from './SendMessageContext';
export {
MessageRendererProvider,
MessageRendererContext,
useMessageRenderer,
} from './MessageRenderContext';
export { AttachmentProvider, AttachmentContext } from './AttachmentContext';

export * from './elements';
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ export function createAIConversation(input: AIConversationInput = {}): {
controls,
displayText,
allowAttachments,
messageRenderer,
} = input;

function AIConversation(props: AIConversationProps): JSX.Element {
Expand All @@ -48,6 +49,7 @@ export function createAIConversation(input: AIConversationInput = {}): {
avatars,
handleSendMessage,
isLoading,
messageRenderer,
};
return (
<AIConversationProvider {...providerProps}>
Expand Down
8 changes: 7 additions & 1 deletion packages/react-ai/src/components/AIConversation/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import {
} from './views';
import { DisplayTextTemplate } from '@aws-amplify/ui';
import { AIConversationDisplayText } from './displayText';
import { ConversationMessage, SendMessage } from '../../types';
import { ConversationMessage, ImageContent, SendMessage } from '../../types';
import { ControlsContextProps } from './context/ControlsContext';

export interface Controls {
Expand All @@ -32,6 +32,7 @@ export interface AIConversationInput {
variant?: MessageVariant;
controls?: ControlsContextProps;
allowAttachments?: boolean;
messageRenderer?: MessageRenderer;
}

export interface AIConversationProps {
Expand All @@ -54,6 +55,11 @@ export interface AIConversation {

export type MessageVariant = 'bubble' | 'default';

export interface MessageRenderer {
text?: (message: string) => React.JSX.Element;
calebpollman marked this conversation as resolved.
Show resolved Hide resolved
image?: (image: ImageContent) => React.JSX.Element;
}

export interface Avatar {
username?: string;
avatar?: React.ReactNode;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import React from 'react';
import { withBaseElementProps } from '@aws-amplify/ui-react-core/elements';

import {
MessageRendererContext,
MessagesContext,
MessageVariantContext,
RoleContext,
Expand Down Expand Up @@ -63,25 +64,30 @@ const ContentContainer: typeof View = React.forwardRef(

export const MessageControl: MessageControl = ({ message }) => {
const responseComponents = React.useContext(ResponseComponentsContext);
const messageRenderer = React.useContext(MessageRendererContext);
return (
<ContentContainer>
{message.content.map((content, index) => {
if (content.text) {
return (
return messageRenderer?.text ? (
messageRenderer?.text(content.text)
) : (
<TextContent data-testid={'text-content'} key={index}>
{content.text}
</TextContent>
);
} else if (content.image) {
return (
return messageRenderer?.image ? (
messageRenderer?.image(content.image)
) : (
<MediaContent
data-testid={'image-content'}
key={index}
src={convertBufferToBase64(
content.image?.source.bytes,
content.image?.format
)}
></MediaContent>
/>
);
} else if (content.toolUse) {
// For now tool use is limited to custom response components
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import { MessagesControl, MessageControl } from '../MessagesControl';
import { convertBufferToBase64 } from '../../../utils';
import { ConversationMessage } from '../../../../../types';
import { ResponseComponentsProvider } from '../../../context/ResponseComponentsContext';
import { MessageRendererProvider } from '../../../context';

const AITextMessage: ConversationMessage = {
conversationId: 'foobar',
Expand Down Expand Up @@ -387,4 +388,28 @@ describe('MessageControl', () => {
const { container } = render(<MessageControl message={ToolUseMessage} />);
expect(container.firstChild).toBeEmptyDOMElement();
});

it('uses text message renderer if passed', () => {
render(
<MessageRendererProvider
text={(message) => <div data-testid="custom-message">{message}</div>}
>
<MessageControl message={AITextMessage} />
</MessageRendererProvider>
);
const message = screen.getByTestId('custom-message');
expect(message).toBeInTheDocument();
});

it('uses image message renderer if passed', () => {
render(
<MessageRendererProvider
image={() => <img data-testid="custom-message" />}
>
<MessageControl message={AIImageMessage} />
</MessageRendererProvider>
);
const message = screen.getByTestId('custom-message');
expect(message).toBeInTheDocument();
});
});
Loading