Skip to content

Commit 9c13876

Browse files
authored
chore(generative-ai): add query and aggregation accuracy tests COMPASS-7850 (#5692)
1 parent f00ff74 commit 9c13876

File tree

1 file changed

+104
-2
lines changed

1 file changed

+104
-2
lines changed

packages/compass-generative-ai/scripts/ai-accuracy-tests.ts

+104-2
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,8 @@ type TestOptions = {
245245
collectionName: string;
246246
includeSampleDocuments?: boolean;
247247
userInput: string;
248+
// When supplied, overrides the default minimum accuracy for this test (a fraction from 0 to 1).
249+
minAccuracyForTest?: number;
248250
assertResponse?: (responseContent: unknown) => Promise<void>;
249251
assertResult?: (responseContent: Document[]) => Promise<void> | void;
250252
acceptAggregationResponse?: boolean;
@@ -509,7 +511,7 @@ async function pushResultsToDB({
509511
}
510512
}
511513

512-
const tests = [
514+
const tests: TestOptions[] = [
513515
{
514516
type: 'query',
515517
databaseName: 'netflix',
@@ -750,6 +752,106 @@ const tests = [
750752
},
751753
]),
752754
},
755+
{
756+
type: 'aggregation',
757+
databaseName: 'sample_airbnb',
758+
collectionName: 'listingsAndReviews',
759+
// TODO(COMPASS-7763): GPT-4 generates better results for this input.
760+
// When we've swapped over we can increase the accuracy for this test.
761+
// For now it gives low accuracy. gpt-3.5-turbo usually tries to
762+
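// use $expr in a $project stage which is not valid syntax.
// NOTE(review): the prompt asks for a field named "washerPercentage" but the
// assertions below expect `tvPercentage` - confirm which name is intended.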
763+
minAccuracyForTest: 0,
764+
userInput:
765+
'what percentage of listings have a "Washer" in their amenities? Only consider listings with more than 2 beds. Return is as a string named "washerPercentage" like "75%", rounded to the nearest whole number.',
766+
assertResult: anyOf([
767+
isDeepStrictEqualTo([
768+
{
769+
_id: null,
770+
tvPercentage: '67%',
771+
},
772+
]),
773+
isDeepStrictEqualTo([
774+
{
775+
tvPercentage: '67%',
776+
},
777+
]),
778+
]),
779+
},
780+
781+
{
782+
type: 'query',
783+
databaseName: 'NYC',
784+
collectionName: 'parking_2015',
785+
// TODO(COMPASS-7763): GPT-4 generates better results for this input.
786+
// When we've swapped over we can increase the accuracy for this test.
787+
// For now it gives low accuracy.
788+
minAccuracyForTest: 0.5,
789+
userInput:
790+
'Write a query that does the following: "find all of the parking incidents that occurred on an ave (match all ways to write ave). Give me an array of all of the plate ids involved, in an object with their summons number and vehicle make and body type. Put the vehicle make and body type into lower case. No _id, sorted by the summons number lowest first.',
791+
assertResult: anyOf([
792+
isDeepStrictEqualTo([
793+
{
794+
'Summons Number': {
795+
$numberLong: '7093881087',
796+
},
797+
'Plate ID': 'FPG1269',
798+
'Vehicle Make': 'gmc',
799+
'Vehicle Body Type': 'subn',
800+
},
801+
{
802+
'Summons Number': {
803+
$numberLong: '7623830399',
804+
},
805+
'Plate ID': 'T645263C',
806+
'Vehicle Make': 'chevr',
807+
'Vehicle Body Type': 'subn',
808+
},
809+
{
810+
'Summons Number': {
811+
$numberLong: '7721537642',
812+
},
813+
'Plate ID': 'GMX1207',
814+
'Vehicle Make': 'honda',
815+
'Vehicle Body Type': '4dsd',
816+
},
817+
{
818+
'Summons Number': {
819+
$numberLong: '7784786281',
820+
},
821+
'Plate ID': 'DRW5164',
822+
'Vehicle Make': 'acura',
823+
'Vehicle Body Type': '4dsd',
824+
},
825+
]),
826+
827+
isDeepStrictEqualTo([
828+
{
829+
'Summons Number': 7093881087,
830+
'Plate ID': 'FPG1269',
831+
'Vehicle Make': 'gmc',
832+
'Vehicle Body Type': 'subn',
833+
},
834+
{
835+
'Summons Number': 7623830399,
836+
'Plate ID': 'T645263C',
837+
'Vehicle Make': 'chevr',
838+
'Vehicle Body Type': 'subn',
839+
},
840+
{
841+
'Summons Number': 7721537642,
842+
'Plate ID': 'GMX1207',
843+
'Vehicle Make': 'honda',
844+
'Vehicle Body Type': '4dsd',
845+
},
846+
{
847+
'Summons Number': 7784786281,
848+
'Plate ID': 'DRW5164',
849+
'Vehicle Make': 'acura',
850+
'Vehicle Body Type': '4dsd',
851+
},
852+
]),
853+
]),
854+
},
753855
];
754856
async function main() {
755857
try {
@@ -771,7 +873,7 @@ async function main() {
771873
// usageStats
772874
} = await runTest(test);
773875
const minAccuracy = DEFAULT_MIN_ACCURACY;
774-
const failed = accuracy < minAccuracy;
876+
const failed = accuracy < (test.minAccuracyForTest ?? minAccuracy);
775877

776878
results.push({
777879
Type: test.type.slice(0, 1).toUpperCase(),

0 commit comments

Comments
 (0)