|
| 1 | +/* -------------------------------------------------------------------------------------------- |
| 2 | + * Copyright (c) Microsoft Corporation. All Rights Reserved. |
| 3 | + * See 'LICENSE' in the project root for license information. |
| 4 | + * ------------------------------------------------------------------------------------------ */ |
| 5 | +import { CodeSnippet, ContextResolver, ResolveRequest } from '@github/copilot-language-server'; |
| 6 | +import * as vscode from 'vscode'; |
| 7 | +import { DocumentSelector } from 'vscode-languageserver-protocol'; |
| 8 | +import { getOutputChannelLogger, Logger } from '../logger'; |
| 9 | +import * as telemetry from '../telemetry'; |
| 10 | +import { CopilotCompletionContextTelemetry } from './copilotCompletionContextTelemetry'; |
| 11 | +import { getCopilotApi } from './copilotProviders'; |
| 12 | +import { clients } from './extension'; |
| 13 | + |
| 14 | +class DefaultValueFallback extends Error { |
| 15 | + static readonly DefaultValue = "DefaultValue"; |
| 16 | + constructor() { super(DefaultValueFallback.DefaultValue); } |
| 17 | +} |
| 18 | + |
| 19 | +class CancellationError extends Error { |
| 20 | + static readonly Canceled = "Canceled"; |
| 21 | + constructor() { super(CancellationError.Canceled); } |
| 22 | +} |
| 23 | + |
| 24 | +class CopilotContextProviderException extends Error { |
| 25 | +} |
| 26 | + |
| 27 | +class WellKnownErrors extends Error { |
| 28 | + static readonly ClientNotFound = "ClientNotFound"; |
| 29 | + private constructor(message: string) { super(message); } |
| 30 | + public static clientNotFound(): Error { |
| 31 | + return new WellKnownErrors(WellKnownErrors.ClientNotFound); |
| 32 | + } |
| 33 | +} |
| 34 | + |
| 35 | +// Mutually exclusive values for the kind of snippets. They either are: |
| 36 | +// - computed. |
| 37 | +// - obtained from the cache. |
| 38 | +// - missing and the computation is taking too long and no cache is present (cache miss). The value |
| 39 | +// is asynchronously computed and stored in cache. |
| 40 | +// - the token is signaled as cancelled, in which case all the operations are aborted. |
| 41 | +// - an unknown state. |
| 42 | +enum SnippetsKind { |
| 43 | + Computed = 'computed', |
| 44 | + GotFromCache = 'gotFromCacheHit', |
| 45 | + MissingCacheMiss = 'missingCacheMiss', |
| 46 | + Canceled = 'canceled', |
| 47 | + Unknown = 'unknown' |
| 48 | +} |
| 49 | + |
| 50 | +export class CopilotCompletionContextProvider implements ContextResolver<CodeSnippet> { |
| 51 | + private static readonly providerId = 'cppTools'; |
| 52 | + private readonly completionContextCache: Map<string, CodeSnippet[]> = new Map<string, CodeSnippet[]>(); |
| 53 | + private static readonly defaultCppDocumentSelector: DocumentSelector = [{ language: 'cpp' }, { language: 'c' }, { language: 'cuda-cpp' }]; |
| 54 | + private static readonly defaultTimeBudgetFactor: number = 0.5; |
| 55 | + private completionContextCancellation = new vscode.CancellationTokenSource(); |
| 56 | + private contextProviderDisposable: vscode.Disposable | undefined; |
| 57 | + |
| 58 | + private async waitForCompletionWithTimeoutAndCancellation<T>(promise: Promise<T>, defaultValue: T | undefined, |
| 59 | + timeout: number, token: vscode.CancellationToken): Promise<[T | undefined, SnippetsKind]> { |
| 60 | + const defaultValuePromise = new Promise<T>((_resolve, reject) => setTimeout(() => { |
| 61 | + if (token.isCancellationRequested) { |
| 62 | + reject(new CancellationError()); |
| 63 | + } else { |
| 64 | + reject(new DefaultValueFallback()); |
| 65 | + } |
| 66 | + }, timeout)); |
| 67 | + const cancellationPromise = new Promise<T>((_, reject) => { |
| 68 | + token.onCancellationRequested(() => { |
| 69 | + reject(new CancellationError()); |
| 70 | + }); |
| 71 | + }); |
| 72 | + let snippetsOrNothing: T | undefined; |
| 73 | + try { |
| 74 | + snippetsOrNothing = await Promise.race([promise, cancellationPromise, defaultValuePromise]); |
| 75 | + } catch (e) { |
| 76 | + if (e instanceof DefaultValueFallback) { |
| 77 | + return [defaultValue, defaultValue !== undefined ? SnippetsKind.GotFromCache : SnippetsKind.MissingCacheMiss]; |
| 78 | + } else if (e instanceof CancellationError) { |
| 79 | + return [undefined, SnippetsKind.Canceled]; |
| 80 | + } else { |
| 81 | + throw e; |
| 82 | + } |
| 83 | + } |
| 84 | + |
| 85 | + return [snippetsOrNothing, SnippetsKind.Computed]; |
| 86 | + } |
| 87 | + |
| 88 | + // Get the completion context with a timeout and a cancellation token. |
| 89 | + // The cancellationToken indicates that the value should not be returned nor cached. |
| 90 | + private async getCompletionContextWithCancellation(documentUri: string, caretOffset: number, |
| 91 | + startTime: number, out: Logger, telemetry: CopilotCompletionContextTelemetry, token: vscode.CancellationToken): Promise<CodeSnippet[]> { |
| 92 | + try { |
| 93 | + const docUri = vscode.Uri.parse(documentUri); |
| 94 | + const client = clients.getClientFor(docUri); |
| 95 | + if (!client) { throw WellKnownErrors.clientNotFound(); } |
| 96 | + const getContextStartTime = performance.now(); |
| 97 | + const snippets = await client.getCompletionContext(docUri, caretOffset, token); |
| 98 | + |
| 99 | + const codeSnippets = snippets.context.map((item) => { |
| 100 | + if (token.isCancellationRequested) { |
| 101 | + telemetry.addInternalCanceled(); |
| 102 | + throw new CancellationError(); |
| 103 | + } |
| 104 | + return { |
| 105 | + importance: item.importance, uri: item.uri, value: item.text |
| 106 | + }; |
| 107 | + }); |
| 108 | + |
| 109 | + this.completionContextCache.set(documentUri, codeSnippets); |
| 110 | + const duration = CopilotCompletionContextProvider.getRoundedDuration(startTime); |
| 111 | + out.appendLine(`Copilot: getCompletionContextWithCancellation(): ${codeSnippets.length} snippets cached in [ms]: ${duration}`); |
| 112 | + telemetry.addSnippetCount(codeSnippets.length); |
| 113 | + telemetry.addCacheComputedElapsed(duration); |
| 114 | + telemetry.addComputeContextElapsed(CopilotCompletionContextProvider.getRoundedDuration(getContextStartTime)); |
| 115 | + return codeSnippets; |
| 116 | + } catch (e) { |
| 117 | + if (e instanceof CancellationError) { |
| 118 | + telemetry.addInternalCanceled(CopilotCompletionContextProvider.getRoundedDuration(startTime)); |
| 119 | + throw e; |
| 120 | + } else if (e instanceof vscode.CancellationError || (e as Error)?.message === CancellationError.Canceled) { |
| 121 | + telemetry.addCopilotCanceled(CopilotCompletionContextProvider.getRoundedDuration(startTime)); |
| 122 | + throw e; |
| 123 | + } |
| 124 | + |
| 125 | + if (e instanceof WellKnownErrors) { |
| 126 | + telemetry.addWellKnownError(e.message); |
| 127 | + } |
| 128 | + |
| 129 | + const err = e as Error; |
| 130 | + out.appendLine(`Copilot: getCompletionContextWithCancellation(): Error: '${err?.message}', stack '${err?.stack}`); |
| 131 | + telemetry.addError(); |
| 132 | + return []; |
| 133 | + } finally { |
| 134 | + telemetry.file(); |
| 135 | + } |
| 136 | + } |
| 137 | + |
| 138 | + private async fetchTimeBudgetFactor(context: ResolveRequest): Promise<number> { |
| 139 | + const budgetFactor = context.activeExperiments.get("CppToolsCopilotTimeBudget"); |
| 140 | + return (budgetFactor as number) !== undefined ? budgetFactor as number : CopilotCompletionContextProvider.defaultTimeBudgetFactor; |
| 141 | + } |
| 142 | + |
| 143 | + private static getRoundedDuration(startTime: number): number { |
| 144 | + return Math.round(performance.now() - startTime); |
| 145 | + } |
| 146 | + |
| 147 | + public static async Create() { |
| 148 | + const copilotCompletionProvider = new CopilotCompletionContextProvider(); |
| 149 | + await copilotCompletionProvider.registerCopilotContextProvider(); |
| 150 | + return copilotCompletionProvider; |
| 151 | + } |
| 152 | + |
| 153 | + public dispose(): void { |
| 154 | + this.completionContextCancellation.cancel(); |
| 155 | + this.contextProviderDisposable?.dispose(); |
| 156 | + } |
| 157 | + |
| 158 | + public removeFile(fileUri: string): void { |
| 159 | + this.completionContextCache.delete(fileUri); |
| 160 | + } |
| 161 | + |
| 162 | + public async resolve(context: ResolveRequest, copilotCancel: vscode.CancellationToken): Promise<CodeSnippet[]> { |
| 163 | + const resolveStartTime = performance.now(); |
| 164 | + const out: Logger = getOutputChannelLogger(); |
| 165 | + const timeBudgetFactor = await this.fetchTimeBudgetFactor(context); |
| 166 | + const telemetry = new CopilotCompletionContextTelemetry(); |
| 167 | + let codeSnippets: CodeSnippet[] | undefined; |
| 168 | + let codeSnippetsKind: SnippetsKind = SnippetsKind.Unknown; |
| 169 | + try { |
| 170 | + this.completionContextCancellation.cancel(); |
| 171 | + this.completionContextCancellation = new vscode.CancellationTokenSource(); |
| 172 | + const docUri = context.documentContext.uri; |
| 173 | + const cachedValue: CodeSnippet[] | undefined = this.completionContextCache.get(docUri.toString()); |
| 174 | + const computeSnippetsPromise = this.getCompletionContextWithCancellation(docUri, |
| 175 | + context.documentContext.offset, resolveStartTime, out, telemetry.fork(), this.completionContextCancellation.token); |
| 176 | + [codeSnippets, codeSnippetsKind] = await this.waitForCompletionWithTimeoutAndCancellation( |
| 177 | + computeSnippetsPromise, cachedValue, context.timeBudget * timeBudgetFactor, copilotCancel); |
| 178 | + if (codeSnippetsKind === SnippetsKind.Canceled) { |
| 179 | + const duration: number = CopilotCompletionContextProvider.getRoundedDuration(resolveStartTime); |
| 180 | + out.appendLine(`Copilot: getCompletionContext(): cancelled, elapsed time (ms) : ${duration}`); |
| 181 | + telemetry.addInternalCanceled(duration); |
| 182 | + throw new CancellationError(); |
| 183 | + } |
| 184 | + telemetry.addSnippetCount(codeSnippets?.length); |
| 185 | + return codeSnippets ?? []; |
| 186 | + } catch (e: any) { |
| 187 | + if (e instanceof CancellationError) { |
| 188 | + throw e; |
| 189 | + } |
| 190 | + |
| 191 | + // For any other exception's type, it is an error. |
| 192 | + telemetry.addError(); |
| 193 | + throw e; |
| 194 | + } finally { |
| 195 | + telemetry.addKind(codeSnippetsKind.toString()); |
| 196 | + const duration: number = CopilotCompletionContextProvider.getRoundedDuration(resolveStartTime); |
| 197 | + if (codeSnippets === undefined) { |
| 198 | + out.appendLine(`Copilot: getCompletionContext(): no snippets provided (${codeSnippetsKind.toString()}), elapsed time (ms): ${duration}`); |
| 199 | + } else { |
| 200 | + out.appendLine(`Copilot: getCompletionContext(): provided ${codeSnippets?.length} snippets (${codeSnippetsKind.toString()}), elapsed time (ms): ${duration}`); |
| 201 | + } |
| 202 | + telemetry.addResolvedElapsed(duration); |
| 203 | + telemetry.addCacheSize(this.completionContextCache.size); |
| 204 | + telemetry.file(); |
| 205 | + } |
| 206 | + } |
| 207 | + |
| 208 | + public async registerCopilotContextProvider(): Promise<void> { |
| 209 | + try { |
| 210 | + const isCustomSnippetProviderApiEnabled = await telemetry.isExperimentEnabled("CppToolsCustomSnippetsApi"); |
| 211 | + if (isCustomSnippetProviderApiEnabled) { |
| 212 | + const copilotApi = await getCopilotApi(); |
| 213 | + if (!copilotApi) { throw new CopilotContextProviderException("getCopilotApi() returned null."); } |
| 214 | + const contextAPI = await copilotApi.getContextProviderAPI("v1"); |
| 215 | + if (!contextAPI) { throw new CopilotContextProviderException("getContextProviderAPI(v1) returned null."); } |
| 216 | + this.contextProviderDisposable = contextAPI.registerContextProvider({ |
| 217 | + id: CopilotCompletionContextProvider.providerId, |
| 218 | + selector: CopilotCompletionContextProvider.defaultCppDocumentSelector, |
| 219 | + resolver: this |
| 220 | + }); |
| 221 | + } |
| 222 | + } catch (e) { |
| 223 | + console.warn("Failed to register the Copilot Context Provider."); |
| 224 | + let msg = "Failed to register the Copilot Context Provider"; |
| 225 | + if (e instanceof CopilotContextProviderException) { |
| 226 | + msg = msg + ": " + e.message; |
| 227 | + } |
| 228 | + telemetry.logCopilotEvent("registerCopilotContextProviderError", { "message": msg }); |
| 229 | + } |
| 230 | + } |
| 231 | +} |
0 commit comments