diff --git a/internal/lsp/server.go b/internal/lsp/server.go index 1d829849c3..df82f21d79 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -424,32 +424,40 @@ func (s *Server) dispatchLoop(ctx context.Context) error { s.pendingClientRequestsMu.Unlock() } - handle := func() { - if err := s.handleRequestOrNotification(requestCtx, req); err != nil { - if errors.Is(err, context.Canceled) { - if err := s.sendError(req.ID, lsproto.ErrorCodeRequestCancelled); err != nil { - lspExit(err) - } - } else if errors.Is(err, io.EOF) { - lspExit(nil) - } else { - if err := s.sendError(req.ID, err); err != nil { - lspExit(err) - } + handleError := func(err error) { + if errors.Is(err, context.Canceled) { + if err := s.sendError(req.ID, lsproto.ErrorCodeRequestCancelled); err != nil { + lspExit(err) + } + } else if errors.Is(err, io.EOF) { + lspExit(nil) + } else { + if err := s.sendError(req.ID, err); err != nil { + lspExit(err) } } + } + removeRequest := func() { if req.ID != nil { s.pendingClientRequestsMu.Lock() + defer s.pendingClientRequestsMu.Unlock() delete(s.pendingClientRequests, *req.ID) - s.pendingClientRequestsMu.Unlock() } } - if isBlockingMethod(req.Method) { - handle() + if doAsyncWork, err := s.handleRequestOrNotification(requestCtx, req); err != nil { + handleError(err) + removeRequest() + } else if doAsyncWork != nil { + go func() { + if lsError := doAsyncWork(); lsError != nil { + handleError(lsError) + } + removeRequest() + }() } else { - go handle() + removeRequest() } } } @@ -547,24 +555,46 @@ func (s *Server) send(msg *lsproto.Message) error { } } -func (s *Server) handleRequestOrNotification(ctx context.Context, req *lsproto.RequestMessage) error { +// handleRequestOrNotification looks up the handler for the given request or notification, executes its synchronous work +// and returns any asynchronous work as a function to be executed by the caller. +func (s *Server) handleRequestOrNotification(ctx context.Context, req *lsproto.RequestMessage) (func() error, error) { ctx = lsproto.WithClientCapabilities(ctx, &s.clientCapabilities) if handler := handlers()[req.Method]; handler != nil { start := time.Now() - err := handler(s, ctx, req) + doAsyncWork, err := handler(s, ctx, req) idStr := "" if req.ID != nil { idStr = " (" + req.ID.String() + ")" } + if err != nil { + s.logger.Error("error handling method '", req.Method, "'", idStr, ": ", err) + return nil, err + } + if doAsyncWork != nil { + return func() error { + if ctx.Err() != nil { + return ctx.Err() + } + asyncWorkErr := doAsyncWork() + s.logger.Info(core.IfElse(asyncWorkErr != nil, "error handling method '", "handled method '"), req.Method, "'", idStr, " in ", time.Since(start)) + return asyncWorkErr + }, nil + } s.logger.Info("handled method '", req.Method, "'", idStr, " in ", time.Since(start)) - return err + return nil, nil } s.logger.Warn("unknown method '", req.Method, "'") - return s.sendError(req.ID, lsproto.ErrorCodeInvalidRequest) + if req.ID != nil { + return nil, s.sendError(req.ID, lsproto.ErrorCodeInvalidRequest) + } + return nil, nil } -type handlerMap map[lsproto.Method]func(*Server, context.Context, *lsproto.RequestMessage) error +// handlerMap maps LSP method to a handler function. The handler function executes any work that must be done synchronously +// before other requests/notifications can be processed, and returns any additional work as a function to be executed +// asynchronously after the synchronous work is complete. +type handlerMap map[lsproto.Method]func(*Server, context.Context, *lsproto.RequestMessage) (func() error, error) var handlers = sync.OnceValue(func() handlerMap { handlers := make(handlerMap) @@ -627,9 +657,9 @@ var handlers = sync.OnceValue(func() handlerMap { }) func registerNotificationHandler[Req any](handlers handlerMap, info lsproto.NotificationInfo[Req], fn func(*Server, context.Context, Req) error) { - handlers[info.Method] = func(s *Server, ctx context.Context, req *lsproto.RequestMessage) error { + handlers[info.Method] = func(s *Server, ctx context.Context, req *lsproto.RequestMessage) (func() error, error) { if s.session == nil && req.Method != lsproto.MethodInitialized { - return lsproto.ErrorCodeServerNotInitialized + return nil, lsproto.ErrorCodeServerNotInitialized } var params Req @@ -638,9 +668,9 @@ func registerNotificationHandler[Req any](handlers handlerMap, info lsproto.Noti params = req.Params.(Req) } if err := fn(s, ctx, params); err != nil { - return err + return nil, err } - return ctx.Err() + return nil, ctx.Err() } } @@ -649,9 +679,9 @@ func registerRequestHandler[Req, Resp any]( info lsproto.RequestInfo[Req, Resp], fn func(*Server, context.Context, Req, *lsproto.RequestMessage) (Resp, error), ) { - handlers[info.Method] = func(s *Server, ctx context.Context, req *lsproto.RequestMessage) error { + handlers[info.Method] = func(s *Server, ctx context.Context, req *lsproto.RequestMessage) (func() error, error) { if s.session == nil && req.Method != lsproto.MethodInitialize { - return lsproto.ErrorCodeServerNotInitialized + return nil, lsproto.ErrorCodeServerNotInitialized } var params Req @@ -661,17 +691,17 @@ func registerRequestHandler[Req, Resp any]( } resp, err := fn(s, ctx, params, req) if err != nil { - return err + return nil, err } if ctx.Err() != nil { - return ctx.Err() + return nil, ctx.Err() } - return s.sendResult(req.ID, resp) + return nil, s.sendResult(req.ID, resp) } } func registerLanguageServiceDocumentRequestHandler[Req lsproto.HasTextDocumentURI, Resp any](handlers handlerMap, info lsproto.RequestInfo[Req, Resp], fn func(*Server, context.Context, *ls.LanguageService, Req) (Resp, error)) { - handlers[info.Method] = func(s *Server, ctx context.Context, req *lsproto.RequestMessage) error { + handlers[info.Method] = func(s *Server, ctx context.Context, req *lsproto.RequestMessage) (func() error, error) { var params Req // Ignore empty params. if req.Params != nil { @@ -679,22 +709,24 @@ func registerLanguageServiceDocumentRequestHandler[Req lsproto.HasTextDocumentUR } ls, err := s.session.GetLanguageService(ctx, params.TextDocumentURI()) if err != nil { - return err - } - defer s.recover(ctx, req) - resp, err := fn(s, ctx, ls, params) - if err != nil { - return err + return nil, err } - if ctx.Err() != nil { - return ctx.Err() - } - return s.sendResult(req.ID, resp) + return func() error { + defer s.recover(ctx, req) + resp, lsErr := fn(s, ctx, ls, params) + if lsErr != nil { + return lsErr + } + if ctx.Err() != nil { + return ctx.Err() + } + return s.sendResult(req.ID, resp) + }, nil } } func registerLanguageServiceWithAutoImportsRequestHandler[Req lsproto.HasTextDocumentURI, Resp any](handlers handlerMap, info lsproto.RequestInfo[Req, Resp], fn func(*Server, context.Context, *ls.LanguageService, Req) (Resp, error)) { - handlers[info.Method] = func(s *Server, ctx context.Context, req *lsproto.RequestMessage) error { + handlers[info.Method] = func(s *Server, ctx context.Context, req *lsproto.RequestMessage) (func() error, error) { var params Req // Ignore empty params. if req.Params != nil { @@ -702,30 +734,32 @@ func registerLanguageServiceWithAutoImportsRequestHandler[Req lsproto.HasTextDoc } languageService, err := s.session.GetLanguageService(ctx, params.TextDocumentURI()) if err != nil { - return err + return nil, err } - defer s.recover(ctx, req) - resp, err := fn(s, ctx, languageService, params) - if errors.Is(err, ls.ErrNeedsAutoImports) { - languageService, err = s.session.GetLanguageServiceWithAutoImports(ctx, params.TextDocumentURI()) - if err != nil { - return err + return func() error { + defer s.recover(ctx, req) + resp, lsErr := fn(s, ctx, languageService, params) + if errors.Is(lsErr, ls.ErrNeedsAutoImports) { + languageService, lsErr = s.session.GetLanguageServiceWithAutoImports(ctx, params.TextDocumentURI()) + if lsErr != nil { + return lsErr + } + if ctx.Err() != nil { + return ctx.Err() + } + resp, lsErr = fn(s, ctx, languageService, params) + if errors.Is(lsErr, ls.ErrNeedsAutoImports) { + panic(info.Method + " returned ErrNeedsAutoImports even after enabling auto imports") + } + } + if lsErr != nil { + return lsErr } if ctx.Err() != nil { return ctx.Err() } - resp, err = fn(s, ctx, languageService, params) - if errors.Is(err, ls.ErrNeedsAutoImports) { - panic(info.Method + " returned ErrNeedsAutoImports even after enabling auto imports") - } - } - if err != nil { - return err - } - if ctx.Err() != nil { - return ctx.Err() - } - return s.sendResult(req.ID, resp) + return s.sendResult(req.ID, resp) + }, nil } } @@ -734,7 +768,7 @@ func registerMultiProjectReferenceRequestHandler[Req lsproto.HasTextDocumentPosi info lsproto.RequestInfo[Req, Resp], fn func(*ls.LanguageService, context.Context, Req, ls.CrossProjectOrchestrator) (Resp, error), ) { - handlers[info.Method] = func(s *Server, ctx context.Context, req *lsproto.RequestMessage) error { + handlers[info.Method] = func(s *Server, ctx context.Context, req *lsproto.RequestMessage) (func() error, error) { var params Req // Ignore empty params. if req.Params != nil { @@ -743,14 +777,16 @@ func registerMultiProjectReferenceRequestHandler[Req lsproto.HasTextDocumentPosi // !!! sheetal: multiple projects that contain the file through symlinks defaultLs, orchestrator, err := s.getLanguageServiceAndCrossProjectOrchestrator(ctx, params.TextDocumentURI(), req) if err != nil { - return err + return nil, err } - defer s.recover(ctx, req) - resp, err := fn(defaultLs, ctx, params, orchestrator) - if err != nil { - return err - } - return s.sendResult(req.ID, resp) + return func() error { + defer s.recover(ctx, req) + resp, lsErr := fn(defaultLs, ctx, params, orchestrator) + if lsErr != nil { + return lsErr + } + return s.sendResult(req.ID, resp) + }, nil } } @@ -1355,22 +1391,6 @@ func (s *Server) NpmInstall(cwd string, args []string) ([]byte, error) { return s.npmInstall(cwd, args) } -func isBlockingMethod(method lsproto.Method) bool { - switch method { - case lsproto.MethodInitialize, - lsproto.MethodInitialized, - lsproto.MethodTextDocumentDidOpen, - lsproto.MethodTextDocumentDidChange, - lsproto.MethodTextDocumentDidSave, - lsproto.MethodTextDocumentDidClose, - lsproto.MethodWorkspaceDidChangeWatchedFiles, - lsproto.MethodWorkspaceDidChangeConfiguration, - lsproto.MethodWorkspaceConfiguration: - return true - } - return false -} - // Developer/debugging command handlers func (s *Server) handleRunGC(_ context.Context, _ any, _ *lsproto.RequestMessage) (lsproto.RunGCResponse, error) {