Skip to content
8 changes: 4 additions & 4 deletions internal/lsp/lsproto/jsonrpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,6 @@ func (m *Message) UnmarshalJSON(data []byte) error {
var err error
if len(raw.Params) > 0 {
params, err = unmarshalParams(raw.Method, raw.Params)
if err != nil {
return fmt.Errorf("%w: %w", ErrorCodeInvalidRequest, err)
}
}

if raw.ID == nil {
Expand All @@ -74,6 +71,9 @@ func (m *Message) UnmarshalJSON(data []byte) error {
Params: params,
}

if err != nil {
return fmt.Errorf("%w: %w", ErrorCodeInvalidParams, err)
}
return nil
}

Expand Down Expand Up @@ -124,7 +124,7 @@ func (r *RequestMessage) UnmarshalJSON(data []byte) error {

type ResponseMessage struct {
JSONRPC jsonrpc.JSONRPCVersion `json:"jsonrpc"`
ID *jsonrpc.ID `json:"id,omitzero"`
ID *jsonrpc.ID `json:"id"`
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Per the spec, if we fail to decode a request or notification (i.e. the whole thing, not just params), we are allowed to return an error response with null id.

Result any `json:"result,omitzero"`
Error *jsonrpc.ResponseError `json:"error,omitzero"`
}
Expand Down
24 changes: 18 additions & 6 deletions internal/lsp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ func (r *lspReader) Read() (*lsproto.Message, error) {

req := &lsproto.Message{}
if err := json.Unmarshal(data, req); err != nil {
if errors.Is(err, lsproto.ErrorCodeInvalidParams) {
return req, fmt.Errorf("%w: %w", lsproto.ErrorCodeInvalidParams, err)
}
return nil, fmt.Errorf("%w: %w", lsproto.ErrorCodeInvalidRequest, err)
}

Expand Down Expand Up @@ -334,8 +337,14 @@ func (s *Server) readLoop(ctx context.Context) error {
}
msg, err := s.read()
if err != nil {
if errors.Is(err, lsproto.ErrorCodeInvalidRequest) {
if err := s.sendError(nil, err); err != nil {
if errors.Is(err, lsproto.ErrorCodeInvalidRequest) || errors.Is(err, lsproto.ErrorCodeInvalidParams) {
var id *jsonrpc.ID
if errors.Is(err, lsproto.ErrorCodeInvalidParams) {
if msg != nil && msg.Kind == jsonrpc.MessageKindRequest {
id = msg.AsRequest().ID
}
}
if err := s.sendError(id, err); err != nil {
return err
}
continue
Expand Down Expand Up @@ -500,6 +509,12 @@ func (s *Server) sendResult(id *jsonrpc.ID, result any) error {
}

func (s *Server) sendError(id *jsonrpc.ID, err error) error {
// Do not send error response for notifications,
// except for parse errors which may occur before determining if the message is a request or notification.
if id == nil && !errors.Is(err, lsproto.ErrorCodeInvalidRequest) {
s.logger.Errorf("error handling notification: %s", err)
return nil
}
code := lsproto.ErrorCodeInternalError
if errCode, ok := errors.AsType[lsproto.ErrorCode](err); ok {
code = errCode
Expand Down Expand Up @@ -546,10 +561,7 @@ func (s *Server) handleRequestOrNotification(ctx context.Context, req *lsproto.R
return err
}
s.logger.Warn("unknown method '", req.Method, "'")
if req.ID != nil {
return s.sendError(req.ID, lsproto.ErrorCodeInvalidRequest)
}
return nil
return s.sendError(req.ID, lsproto.ErrorCodeInvalidRequest)
}

type handlerMap map[lsproto.Method]func(*Server, context.Context, *lsproto.RequestMessage) error
Expand Down