Skip to content

Commit

Permalink
Add Cpp Context Traits to Completions Prompt (#12821)
Browse files Browse the repository at this point in the history
- Move related files code to its own module
- Refactor related files provider code to enable unit testing
- Add Cpp context traits to completions prompt
  • Loading branch information
kuchungmsft authored Oct 14, 2024
1 parent f09715f commit 05c9176
Show file tree
Hide file tree
Showing 4 changed files with 512 additions and 82 deletions.
134 changes: 134 additions & 0 deletions Extension/src/LanguageServer/copilotProviders.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
/* --------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All Rights Reserved.
* See 'LICENSE' in the project root for license information.
* ------------------------------------------------------------------------------------------ */
'use strict';

import * as vscode from 'vscode';
import * as util from '../common';
import * as telemetry from '../telemetry';
import { ChatContextResult, GetIncludesResult } from './client';
import { getActiveClient } from './extension';

let isRelatedFilesApiEnabled: boolean | undefined;

export interface CopilotTrait {
name: string;
value: string;
includeInPrompt?: boolean;
promptTextOverride?: string;
}

export interface CopilotApi {
registerRelatedFilesProvider(
providerId: { extensionId: string; languageId: string },
callback: (
uri: vscode.Uri,
context: { flags: Record<string, unknown> },
cancellationToken: vscode.CancellationToken
) => Promise<{ entries: vscode.Uri[]; traits?: CopilotTrait[] }>
): Disposable;
}

export async function registerRelatedFilesProvider(): Promise<void> {
if (!await getIsRelatedFilesApiEnabled()) {
return;
}

const api = await getCopilotApi();
if (util.extensionContext && api) {
try {
for (const languageId of ['c', 'cpp', 'cuda-cpp']) {
api.registerRelatedFilesProvider(
{ extensionId: util.extensionContext.extension.id, languageId },
async (_uri: vscode.Uri, context: { flags: Record<string, unknown> }, token: vscode.CancellationToken) => {

const getIncludesHandler = async () => (await getIncludesWithCancellation(1, token))?.includedFiles.map(file => vscode.Uri.file(file)) ?? [];
const getTraitsHandler = async () => {
const chatContext: ChatContextResult | undefined = await (getActiveClient().getChatContext(token) ?? undefined);

if (!chatContext) {
return undefined;
}

let traits: CopilotTrait[] = [
{ name: "language", value: chatContext.language, includeInPrompt: true, promptTextOverride: `The language is ${chatContext.language}.` },
{ name: "compiler", value: chatContext.compiler, includeInPrompt: true, promptTextOverride: `This project compiles using ${chatContext.compiler}.` },
{ name: "standardVersion", value: chatContext.standardVersion, includeInPrompt: true, promptTextOverride: `This project uses the ${chatContext.standardVersion} language standard.` },
{ name: "targetPlatform", value: chatContext.targetPlatform, includeInPrompt: true, promptTextOverride: `This build targets ${chatContext.targetPlatform}.` },
{ name: "targetArchitecture", value: chatContext.targetArchitecture, includeInPrompt: true, promptTextOverride: `This build targets ${chatContext.targetArchitecture}.` }
];

const excludeTraits = context.flags.copilotcppExcludeTraits as string[] ?? [];
traits = traits.filter(trait => !excludeTraits.includes(trait.name));

return traits.length > 0 ? traits : undefined;
};

// Call both handlers in parallel
const traitsPromise = ((context.flags.copilotcppTraits as boolean) ?? false) ? getTraitsHandler() : Promise.resolve(undefined);
const includesPromise = getIncludesHandler();

return { entries: await includesPromise, traits: await traitsPromise };
}
);
}
} catch {
console.log("Failed to register Copilot related files provider.");
}
}
}

export async function registerRelatedFilesCommands(commandDisposables: vscode.Disposable[], enabled: boolean): Promise<void> {
if (await getIsRelatedFilesApiEnabled()) {
commandDisposables.push(vscode.commands.registerCommand('C_Cpp.getIncludes', enabled ? (maxDepth: number) => getIncludes(maxDepth) : () => Promise.resolve()));
}
}

async function getIncludesWithCancellation(maxDepth: number, token: vscode.CancellationToken): Promise<GetIncludesResult> {
const activeClient = getActiveClient();
const includes = await activeClient.getIncludes(maxDepth, token);
const wksFolder = activeClient.RootUri?.toString();

if (!wksFolder) {
return includes;
}

includes.includedFiles = includes.includedFiles.filter(header => vscode.Uri.file(header).toString().startsWith(wksFolder));
return includes;
}

async function getIncludes(maxDepth: number): Promise<GetIncludesResult> {
const tokenSource = new vscode.CancellationTokenSource();
try {
const includes = await getIncludesWithCancellation(maxDepth, tokenSource.token);
return includes;
} finally {
tokenSource.dispose();
}
}

async function getIsRelatedFilesApiEnabled(): Promise<boolean> {
if (isRelatedFilesApiEnabled === undefined) {
isRelatedFilesApiEnabled = await telemetry.isExperimentEnabled("CppToolsRelatedFilesApi");
}

return isRelatedFilesApiEnabled;
}

export async function getCopilotApi(): Promise<CopilotApi | undefined> {
const copilotExtension = vscode.extensions.getExtension<CopilotApi>('github.copilot');
if (!copilotExtension) {
return undefined;
}

if (!copilotExtension.isActive) {
try {
return await copilotExtension.activate();
} catch {
return undefined;
}
} else {
return copilotExtension.exports;
}
}
86 changes: 6 additions & 80 deletions Extension/src/LanguageServer/extension.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ import * as util from '../common';
import { getCrashCallStacksChannel } from '../logger';
import { PlatformInformation } from '../platform';
import * as telemetry from '../telemetry';
import { Client, DefaultClient, DoxygenCodeActionCommandArguments, GetIncludesResult, openFileVersions } from './client';
import { Client, DefaultClient, DoxygenCodeActionCommandArguments, openFileVersions } from './client';
import { ClientCollection } from './clientCollection';
import { CodeActionDiagnosticInfo, CodeAnalysisDiagnosticIdentifiersAndUri, codeAnalysisAllFixes, codeAnalysisCodeToFixes, codeAnalysisFileToCodeActions } from './codeAnalysis';
import { registerRelatedFilesCommands, registerRelatedFilesProvider } from './copilotProviders';
import { CppBuildTaskProvider } from './cppBuildTaskProvider';
import { getCustomConfigProviders } from './customProviders';
import { getLanguageConfig } from './languageConfig';
Expand All @@ -33,24 +34,6 @@ import { CppSettings } from './settings';
import { LanguageStatusUI, getUI } from './ui';
import { makeLspRange, rangeEquals, showInstallCompilerWalkthrough } from './utils';

interface CopilotTrait {
name: string;
value: string;
includeInPrompt?: boolean;
promptTextOverride?: string;
}

interface CopilotApi {
registerRelatedFilesProvider(
providerId: { extensionId: string; languageId: string },
callback: (
uri: vscode.Uri,
context: { flags: Record<string, unknown> },
cancellationToken: vscode.CancellationToken
) => Promise<{ entries: vscode.Uri[]; traits?: CopilotTrait[] }>
): Disposable;
}

nls.config({ messageFormat: nls.MessageFormat.bundle, bundleFormat: nls.BundleFormat.standalone })();
const localize: nls.LocalizeFunc = nls.loadMessageBundle();
export const CppSourceStr: string = "C/C++";
Expand Down Expand Up @@ -201,8 +184,7 @@ export async function activate(): Promise<void> {

void clients.ActiveClient.ready.then(() => intervalTimer = global.setInterval(onInterval, 2500));

const isRelatedFilesApiEnabled = await telemetry.isExperimentEnabled("CppToolsRelatedFilesApi");
registerCommands(true, isRelatedFilesApiEnabled);
await registerCommands(true);

vscode.tasks.onDidStartTask(() => getActiveClient().PauseCodeAnalysis());

Expand Down Expand Up @@ -274,22 +256,7 @@ export async function activate(): Promise<void> {
disposables.push(tool);
}

if (isRelatedFilesApiEnabled) {
const api = await getCopilotApi();
if (util.extensionContext && api) {
try {
for (const languageId of ['c', 'cpp', 'cuda-cpp']) {
api.registerRelatedFilesProvider(
{ extensionId: util.extensionContext.extension.id, languageId },
async (_uri: vscode.Uri, _context: { flags: Record<string, unknown> }, token: vscode.CancellationToken) =>
({ entries: (await getIncludesWithCancellation(1, token))?.includedFiles.map(file => vscode.Uri.file(file)) ?? [] })
);
}
} catch {
console.log("Failed to register Copilot related files provider.");
}
}
}
await registerRelatedFilesProvider();
}

export function updateLanguageConfigurations(): void {
Expand Down Expand Up @@ -386,7 +353,7 @@ function onInterval(): void {
/**
* registered commands
*/
export function registerCommands(enabled: boolean, isRelatedFilesApiEnabled: boolean): void {
export async function registerCommands(enabled: boolean): Promise<void> {
commandDisposables.forEach(d => d.dispose());
commandDisposables.length = 0;
commandDisposables.push(vscode.commands.registerCommand('C_Cpp.SwitchHeaderSource', enabled ? onSwitchHeaderSource : onDisabledCommand));
Expand Down Expand Up @@ -445,9 +412,7 @@ export function registerCommands(enabled: boolean, isRelatedFilesApiEnabled: boo
commandDisposables.push(vscode.commands.registerCommand('C_Cpp.ExtractToMemberFunction', enabled ? () => onExtractToFunction(false, true) : onDisabledCommand));
commandDisposables.push(vscode.commands.registerCommand('C_Cpp.ExpandSelection', enabled ? (r: Range) => onExpandSelection(r) : onDisabledCommand));

if (!isRelatedFilesApiEnabled) {
commandDisposables.push(vscode.commands.registerCommand('C_Cpp.getIncludes', enabled ? (maxDepth: number) => getIncludes(maxDepth) : () => Promise.resolve()));
}
await registerRelatedFilesCommands(commandDisposables, enabled);
}

function onDisabledCommand() {
Expand Down Expand Up @@ -1412,42 +1377,3 @@ export async function preReleaseCheck(): Promise<void> {
}
}
}

export async function getIncludesWithCancellation(maxDepth: number, token: vscode.CancellationToken): Promise<GetIncludesResult> {
const includes = await clients.ActiveClient.getIncludes(maxDepth, token);
const wksFolder = clients.ActiveClient.RootUri?.toString();

if (!wksFolder) {
return includes;
}

includes.includedFiles = includes.includedFiles.filter(header => vscode.Uri.file(header).toString().startsWith(wksFolder));
return includes;
}

async function getIncludes(maxDepth: number): Promise<GetIncludesResult> {
const tokenSource = new vscode.CancellationTokenSource();
try {
const includes = await getIncludesWithCancellation(maxDepth, tokenSource.token);
return includes;
} finally {
tokenSource.dispose();
}
}

async function getCopilotApi(): Promise<CopilotApi | undefined> {
const copilotExtension = vscode.extensions.getExtension<CopilotApi>('github.copilot');
if (!copilotExtension) {
return undefined;
}

if (!copilotExtension.isActive) {
try {
return await copilotExtension.activate();
} catch {
return undefined;
}
} else {
return copilotExtension.exports;
}
}
3 changes: 1 addition & 2 deletions Extension/src/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,7 @@ export async function activate(context: vscode.ExtensionContext): Promise<CppToo
if (shouldActivateLanguageServer) {
await LanguageServer.activate();
} else if (isIntelliSenseEngineDisabled) {
const isRelatedFilesApiEnabled = await Telemetry.isExperimentEnabled("CppToolsRelatedFilesApi");
LanguageServer.registerCommands(false, isRelatedFilesApiEnabled);
await LanguageServer.registerCommands(false);
// The check here for isIntelliSenseEngineDisabled avoids logging
// the message on old Macs that we've already displayed a warning for.
log(localize("intellisense.disabled", "intelliSenseEngine is disabled"));
Expand Down
Loading

0 comments on commit 05c9176

Please sign in to comment.