diff --git a/pkg/grpc/gateway/request.go b/pkg/grpc/gateway/request.go index 3a2c540..d73eea3 100644 --- a/pkg/grpc/gateway/request.go +++ b/pkg/grpc/gateway/request.go @@ -62,7 +62,7 @@ func DoRequest[T any](ctx context.Context, req *resty.Request) (*T, error) { func DoStreamingRequest[T any](ctx context.Context, c Client, req *resty.Request) (<-chan *T, <-chan error, error) { var resBody T if _, ok := any(&resBody).(*httpbody.HttpBody); ok { - resCh, errCh, err := doHTTPStreamingRequest(ctx, req) + resCh, errCh, err := doHTTPStreamingRequest(ctx, c, req) if err != nil { return nil, nil, err } @@ -79,29 +79,7 @@ func DoStreamingRequest[T any](ctx context.Context, c Client, req *resty.Request return nil, nil, fmt.Errorf("send request: %w", err) } if rawRes.IsError() { - body := rawRes.RawBody() - defer func() { _ = body.Close() }() - data, err := io.ReadAll(body) - if err != nil { - return nil, nil, fmt.Errorf("read error response body: %w", err) - } - - var res streamingResponse - if err := json.Unmarshal(data, &res); err != nil { - return nil, nil, fmt.Errorf("unmarshal raw response: %w", err) - } - rawErrRes, ok := res[streamingResponseErrorKey] - if !ok { - return nil, nil, errors.New(string(data)) - } - var errResp rpcstatus.Status - if err := c.Unmarshal(rawErrRes, &errResp); err != nil { - return nil, nil, fmt.Errorf("unmarshal error response: %w", err) - } - if err := status.ErrorProto(&errResp); err != nil { - return nil, nil, err - } - return nil, nil, status.Error(HTTPStatusToCode(rawRes.StatusCode()), rawRes.String()) + return nil, nil, wrapStreamingResponseError(c, rawRes) } resCh := make(chan *T) @@ -166,22 +144,18 @@ func doHTTPRequest(ctx context.Context, req *resty.Request) (any, error) { }, nil } -func doHTTPStreamingRequest(ctx context.Context, req *resty.Request) (any, <-chan error, error) { +func doHTTPStreamingRequest(ctx context.Context, c Client, req *resty.Request) (any, <-chan error, error) { res, err := req.SetContext(ctx). + SetHeader("Accept", "text/event-stream"). + SetHeader("Cache-Control", "no-cache"). + SetHeader("Connection", "keep-alive"). SetDoNotParseResponse(true). Send() if err != nil { return nil, nil, fmt.Errorf("send request: %w", err) } if res.IsError() { - errResp, ok := res.Error().(*rpcstatus.Status) - if !ok { - return nil, nil, fmt.Errorf("cast error response: %s", res.String()) - } - if err := status.ErrorProto(errResp); err != nil { - return nil, nil, err - } - return nil, nil, status.Error(HTTPStatusToCode(res.StatusCode()), errResp.String()) + return nil, nil, wrapStreamingResponseError(c, res) } resCh := make(chan *httpbody.HttpBody) @@ -219,3 +193,29 @@ func doHTTPStreamingRequest(ctx context.Context, req *resty.Request) (any, <-cha }() return resCh, errCh, nil } + +func wrapStreamingResponseError(c Client, resp *resty.Response) error { + body := resp.RawBody() + defer func() { _ = body.Close() }() + data, err := io.ReadAll(body) + if err != nil { + return fmt.Errorf("read error response body: %w", err) + } + + var streamingResp streamingResponse + if err := json.Unmarshal(data, &streamingResp); err != nil { + return fmt.Errorf("unmarshal raw response: %w", err) + } + rawErrRes, ok := streamingResp[streamingResponseErrorKey] + if !ok { + return errors.New(string(data)) + } + var errResp rpcstatus.Status + if err := c.Unmarshal(rawErrRes, &errResp); err != nil { + return fmt.Errorf("unmarshal error response: %w", err) + } + if err := status.ErrorProto(&errResp); err != nil { + return err + } + return status.Error(HTTPStatusToCode(resp.StatusCode()), resp.String()) +}