diff --git a/Extension/src/LanguageServer/client.ts b/Extension/src/LanguageServer/client.ts index d937cff044..c5c30fe740 100644 --- a/Extension/src/LanguageServer/client.ts +++ b/Extension/src/LanguageServer/client.ts @@ -53,9 +53,10 @@ import { } from './codeAnalysis'; import { Location, TextEdit, WorkspaceEdit } from './commonTypes'; import * as configs from './configurations'; +import { CopilotCompletionContextProvider } from './copilotCompletionContextProvider'; import { DataBinding } from './dataBinding'; import { cachedEditorConfigSettings, getEditorConfigSettings } from './editorConfig'; -import { CppSourceStr, clients, configPrefix, updateLanguageConfigurations, usesCrashHandler, watchForCrashes } from './extension'; +import { CppSourceStr, SnippetEntry, clients, configPrefix, updateLanguageConfigurations, usesCrashHandler, watchForCrashes } from './extension'; import { LocalizeStringParams, getLocaleId, getLocalizedString } from './localization'; import { PersistentFolderState, PersistentWorkspaceState } from './persistentState'; import { RequestCancelled, ServerCancelled, createProtocolFilter } from './protocolFilter'; @@ -554,6 +555,15 @@ export interface ProjectContextResult { fileContext: FileContextResult; } +export interface CompletionContextsResult { + context: SnippetEntry[]; +} + +export interface CompletionContextParams { + file: string; + caretOffset: number; +} + // Requests const PreInitializationRequest: RequestType = new RequestType('cpptools/preinitialize'); const InitializationRequest: RequestType = new RequestType('cpptools/initialize'); @@ -575,6 +585,7 @@ const ChangeCppPropertiesRequest: RequestType = const IncludesRequest: RequestType = new RequestType('cpptools/getIncludes'); const CppContextRequest: RequestType = new RequestType('cpptools/getChatContext'); const ProjectContextRequest: RequestType = new RequestType('cpptools/getProjectContext'); +const CompletionContextRequest: RequestType = new RequestType('cpptools/getCompletionContext'); // Notifications to the server const DidOpenNotification: NotificationType = new NotificationType('textDocument/didOpen'); @@ -807,6 +818,7 @@ export interface Client { getIncludes(maxDepth: number): Promise; getChatContext(uri: vscode.Uri, token: vscode.CancellationToken): Promise; getProjectContext(uri: vscode.Uri): Promise; + getCompletionContext(fileName: vscode.Uri, caretOffset: number, token: vscode.CancellationToken): Promise; } export function createClient(workspaceFolder?: vscode.WorkspaceFolder): Client { @@ -839,7 +851,7 @@ export class DefaultClient implements Client { private settingsTracker: SettingsTracker; private loggingLevel: number = 1; private configurationProvider?: string; - + private copilotCompletionProvider?: CopilotCompletionContextProvider; public lastCustomBrowseConfiguration: PersistentFolderState | undefined; public lastCustomBrowseConfigurationProviderId: PersistentFolderState | undefined; public lastCustomBrowseConfigurationProviderVersion: PersistentFolderState | undefined; @@ -1298,6 +1310,8 @@ export class DefaultClient implements Client { this.semanticTokensProviderDisposable = vscode.languages.registerDocumentSemanticTokensProvider(util.documentSelector, this.semanticTokensProvider, semanticTokensLegend); } + this.copilotCompletionProvider = await CopilotCompletionContextProvider.Create(); + // Listen for messages from the language server. this.registerNotifications(); @@ -1807,6 +1821,7 @@ export class DefaultClient implements Client { if (diagnosticsCollectionIntelliSense) { diagnosticsCollectionIntelliSense.delete(document.uri); } + this.copilotCompletionProvider?.removeFile(uri); openFileVersions.delete(uri); } @@ -2255,6 +2270,12 @@ export class DefaultClient implements Client { () => this.languageClient.sendRequest(CppContextRequest, params, token), token); } + public async getCompletionContext(file: vscode.Uri, caretOffset: number, token: vscode.CancellationToken): Promise { + await withCancellation(this.ready, token); + return DefaultClient.withLspCancellationHandling( + () => this.languageClient.sendRequest(CompletionContextRequest, { file: file.toString(), caretOffset }, token), token); + } + /** * a Promise that can be awaited to know when it's ok to proceed. * @@ -4159,4 +4180,5 @@ class NullClient implements Client { getIncludes(maxDepth: number): Promise { return Promise.resolve({} as GetIncludesResult); } getChatContext(uri: vscode.Uri, token: vscode.CancellationToken): Promise { return Promise.resolve({} as ChatContextResult); } getProjectContext(uri: vscode.Uri): Promise { return Promise.resolve({} as ProjectContextResult); } + getCompletionContext(file: vscode.Uri, caretOffset: number, token: vscode.CancellationToken): Promise { return Promise.resolve({} as CompletionContextsResult); } } diff --git a/Extension/src/LanguageServer/copilotCompletionContextProvider.ts b/Extension/src/LanguageServer/copilotCompletionContextProvider.ts new file mode 100644 index 0000000000..5c3e67b27f --- /dev/null +++ b/Extension/src/LanguageServer/copilotCompletionContextProvider.ts @@ -0,0 +1,185 @@ +/* -------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All Rights Reserved. + * See 'LICENSE' in the project root for license information. + * ------------------------------------------------------------------------------------------ */ +import * as vscode from 'vscode'; +import { DocumentSelector } from 'vscode-languageserver-protocol'; +import { getOutputChannelLogger, Logger } from '../logger'; +import * as telemetry from '../telemetry'; +import { CopilotContextTelemetry } from './copilotContextTelemetry'; +import { getCopilotApi } from './copilotProviders'; +import { clients } from './extension'; +import { CodeSnippet, CompletionContext, ContextProviderApiV1, ContextResolver } from './tmp/contextProviderV1'; + +class DefaultValueFallback extends Error { + static readonly DefaultValue = "DefaultValue"; + constructor() { super(DefaultValueFallback.DefaultValue); } +} + +class CancellationError extends Error { + static readonly Cancelled = "Cancelled"; + constructor() { super(CancellationError.Cancelled); } +} + +// Mutually exclusive values for the kind of snippets. They either are: +// - computed. +// - obtained from the cache. +// - missing and the computation is taking too long and no cache is present (cache miss). The value +// is asynchronously computed and stored in cache. +// - the token is signaled as cancelled, in which case all the operations are aborted. +// - an unknown state. +enum SnippetsKind { + Computed = 'computed', + GotFromCache = 'gotFromCacheHit', + MissingCacheMiss = 'missingCacheMiss', + Cancelled = 'cancelled', + Unknown = 'unknown' +} + +export class CopilotCompletionContextProvider implements ContextResolver { + private static readonly providerId = 'cppTools'; + private readonly completionContextCache: Map = new Map(); + private static readonly defaultCppDocumentSelector: DocumentSelector = [{ language: 'cpp' }, { language: 'c' }, { language: 'cuda-cpp' }]; + private static readonly defaultTimeBudgetFactor: number = 0.5; + private completionContextCancellation = new vscode.CancellationTokenSource(); + + // Get the default value if the timeout expires, but throws an exception if the token is cancelled. + private async waitForCompletionWithTimeoutAndCancellation(promise: Promise, defaultValue: T | undefined, + timeout: number, token: vscode.CancellationToken): Promise<[T | undefined, SnippetsKind]> { + const defaultValuePromise = new Promise((resolve, reject) => setTimeout(() => { + if (token.isCancellationRequested) { + reject(new CancellationError()); + } else { + reject(new DefaultValueFallback()); + } + }, timeout)); + const cancellationPromise = new Promise((_, reject) => { + token.onCancellationRequested(() => { + reject(new CancellationError()); + }); + }); + let snippetsOrNothing: T | undefined; + try { + snippetsOrNothing = await Promise.race([promise, cancellationPromise, defaultValuePromise]); + } catch (e) { + if (e instanceof DefaultValueFallback) { + return [defaultValue, defaultValue !== undefined ? SnippetsKind.GotFromCache : SnippetsKind.MissingCacheMiss]; + } else if (e instanceof CancellationError) { + return [undefined, SnippetsKind.Cancelled]; + } else { + throw e; + } + } + + return [snippetsOrNothing, SnippetsKind.Computed]; + } + + // Get the completion context with a timeout and a cancellation token. + // The cancellationToken indicates that the value should not be returned nor cached. + private async getCompletionContextWithCancellation(documentUri: string, caretOffset: number, + startTime: number, out: Logger, telemetry: CopilotContextTelemetry, token: vscode.CancellationToken): Promise { + try { + const docUri = vscode.Uri.parse(documentUri); + const snippets = await clients.getClientFor(docUri).getCompletionContext(docUri, caretOffset, token); + + const codeSnippets = snippets.context.map((item) => { + if (token.isCancellationRequested) { + telemetry.addCancelledLate(); + throw new CancellationError(); + } + return { + importance: item.importance, uri: item.uri, value: item.text + }; + }); + + this.completionContextCache.set(documentUri, codeSnippets); + const duration: number = performance.now() - startTime; + out.appendLine(`Copilot: getCompletionContextWithCancellation(): Cached in [ms]: ${duration}`); + telemetry.addSnippetCount(codeSnippets?.length); + telemetry.addCacheComputedElapsed(duration); + + return codeSnippets; + } catch (e) { + const err = e as Error; + out.appendLine(`Copilot: getCompletionContextWithCancellation(): Error: '${err?.message}', stack '${err?.stack}`); + telemetry.addError(); + return []; + } + } + + private async fetchTimeBudgetFactor(context: CompletionContext): Promise { + const budgetFactor = context.activeExperiments.get("CppToolsCopilotTimeBudget"); + return (budgetFactor as number) !== undefined ? budgetFactor as number : CopilotCompletionContextProvider.defaultTimeBudgetFactor; + } + + public static async Create() { + const copilotCompletionProvider = new CopilotCompletionContextProvider(); + await copilotCompletionProvider.registerCopilotContextProvider(); + return copilotCompletionProvider; + } + + public removeFile(fileUri: string): void { + this.completionContextCache.delete(fileUri); + } + + public async resolve(context: CompletionContext, copilotAborts: vscode.CancellationToken): Promise { + const startTime = performance.now(); + const out: Logger = getOutputChannelLogger(); + const timeBudgetFactor = await this.fetchTimeBudgetFactor(context); + const telemetry = new CopilotContextTelemetry(); + let codeSnippets: CodeSnippet[] | undefined; + let codeSnippetsKind: SnippetsKind = SnippetsKind.Unknown; + try { + this.completionContextCancellation.cancel(); + this.completionContextCancellation = new vscode.CancellationTokenSource(); + const docUri = context.documentContext.uri; + const cachedValue: CodeSnippet[] | undefined = this.completionContextCache.get(docUri.toString()); + const snippetsPromise = this.getCompletionContextWithCancellation(docUri, + context.documentContext.offset, startTime, out, telemetry.fork(), this.completionContextCancellation.token); + [codeSnippets, codeSnippetsKind] = await this.waitForCompletionWithTimeoutAndCancellation( + snippetsPromise, cachedValue, context.timeBudget * timeBudgetFactor, copilotAborts); + if (codeSnippetsKind === SnippetsKind.Cancelled) { + const duration: number = performance.now() - startTime; + out.appendLine(`Copilot: getCompletionContext(): cancelled, elapsed time (ms) : ${duration}`); + telemetry.addCancelled(); + telemetry.addCancellationElapsed(duration); + throw new CancellationError(); + } + telemetry.addSnippetCount(codeSnippets?.length); + return codeSnippets ?? []; + } catch (e: any) { + telemetry.addError(); + throw e; + } finally { + telemetry.addKind(codeSnippetsKind.toString()); + const duration: number = performance.now() - startTime; + if (codeSnippets === undefined) { + out.appendLine(`Copilot: getCompletionContext(): no snkppets provided (${codeSnippetsKind.toString()}), elapsed time (ms): ${duration}`); + } else { + out.appendLine(`Copilot: getCompletionContext(): provided ${codeSnippets?.length} snippets (${codeSnippetsKind.toString()}), elapsed time (ms): ${duration}`); + } + telemetry.addResolvedElapsed(duration); + telemetry.addCacheSize(this.completionContextCache.size); + // //?? TODO telemetry.file(); + } + + return []; + } + + public async registerCopilotContextProvider(): Promise { + try { + const isCustomSnippetProviderApiEnabled = await telemetry.isExperimentEnabled("CppToolsCustomSnippetsApi"); + if (isCustomSnippetProviderApiEnabled) { + const contextAPI = (await getCopilotApi() as any).getContextProviderAPI('v1') as ContextProviderApiV1; + contextAPI.registerContextProvider({ + id: CopilotCompletionContextProvider.providerId, + selector: CopilotCompletionContextProvider.defaultCppDocumentSelector, + resolver: this + }); + } + } catch { + console.warn("Failed to register the Copilot Context Provider."); + telemetry.logCopilotEvent("registerCopilotContextProviderError", { "message": "Failed to register the Copilot Context Provider." }); + } + } +} diff --git a/Extension/src/LanguageServer/copilotContextTelemetry.ts b/Extension/src/LanguageServer/copilotContextTelemetry.ts new file mode 100644 index 0000000000..5160ca5fae --- /dev/null +++ b/Extension/src/LanguageServer/copilotContextTelemetry.ts @@ -0,0 +1,72 @@ +/* -------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All Rights Reserved. + * See 'LICENSE' in the project root for license information. + * ------------------------------------------------------------------------------------------ */ +import { randomUUID } from 'crypto'; +import * as telemetry from '../telemetry'; + +export class CopilotContextTelemetry { + private static readonly correlationIdKey = 'correlationId'; + private static readonly copilotEventName = 'copilotContextProvider'; + private readonly metrics: Record = {}; + private readonly properties: Record = {}; + private readonly id: string; + constructor(correlationId?: string) { + this.id = correlationId ?? randomUUID().toString(); + } + + private addMetric(key: string, value: number): void { + this.metrics[key] = value; + } + + private addProperty(key: string, value: string): void { + this.properties[key] = value; + } + + public addCancelled(): void { + this.addProperty('cancelled', 'true'); + } + + public addCancellationElapsed(duration: number): void { + this.addMetric('cancellationElapsedMs', duration); + } + + public addCancelledLate(): void { + this.addProperty('cancelledLate', 'true'); + } + + public addError(): void { + this.addProperty('error', 'true'); + } + + public addKind(snippetsKind: string): void { + this.addProperty('kind', snippetsKind.toString()); + } + + public addResolvedElapsed(duration: number): void { + this.addMetric('overallResolveElapsedMs', duration); + } + + public addCacheSize(size: number): void { + this.addMetric('cacheSize', size); + } + + public addCacheComputedElapsed(duration: number): void { + this.addMetric('cacheComputedElapsedMs', duration); + } + + // count can be undefined, in which case the count is set to -1 to indicate + // snippets are not available (different than having 0 snippets). + public addSnippetCount(count?: number) { + this.addMetric('snippetsCount', count ?? -1); + } + + public file(): void { + this.properties[CopilotContextTelemetry.correlationIdKey] = this.id; + telemetry.logCopilotEvent(CopilotContextTelemetry.copilotEventName, this.properties, this.metrics); + } + + public fork(): CopilotContextTelemetry { + return new CopilotContextTelemetry(this.id); + } +} diff --git a/Extension/src/LanguageServer/extension.ts b/Extension/src/LanguageServer/extension.ts index 8bc64f82f8..37ec03fbce 100644 --- a/Extension/src/LanguageServer/extension.ts +++ b/Extension/src/LanguageServer/extension.ts @@ -34,6 +34,14 @@ import { CppSettings } from './settings'; import { LanguageStatusUI, getUI } from './ui'; import { makeLspRange, rangeEquals, showInstallCompilerWalkthrough } from './utils'; +export interface SnippetEntry { + uri: string; + text: string; + startLine: number; + endLine: number; + importance: number; +} + nls.config({ messageFormat: nls.MessageFormat.bundle, bundleFormat: nls.BundleFormat.standalone })(); const localize: nls.LocalizeFunc = nls.loadMessageBundle(); export const CppSourceStr: string = "C/C++";