Skip to content

Commit

Permalink
fix: fix the issue which the server-side did not report failure after…
Browse files Browse the repository at this point in the history
… a panic occurred during request processing
  • Loading branch information
YangruiEmma committed Jan 4, 2025
1 parent 11073fe commit fe8890e
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 13 deletions.
27 changes: 16 additions & 11 deletions pkg/remote/trans/default_server_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,14 @@ func (t *svrTransHandler) Write(ctx context.Context, conn net.Conn, sendMsg remo
func (t *svrTransHandler) Read(ctx context.Context, conn net.Conn, recvMsg remote.Message) (nctx context.Context, err error) {
var bufReader remote.ByteBuffer
defer func() {
if r := recover(); r != nil {
stack := string(debug.Stack())
panicErr := kerrors.ErrPanic.WithCauseAndStack(fmt.Errorf("[happened in Read] %s", r), stack)
rpcStats := rpcinfo.AsMutableRPCStats(recvMsg.RPCInfo().Stats())
rpcStats.SetPanicked(panicErr)
err = remote.NewTransError(remote.ProtocolError, panicErr)
nctx = ctx
}
t.ext.ReleaseBuffer(bufReader, err)
rpcinfo.Record(ctx, recvMsg.RPCInfo(), stats.ReadFinish, err)
}()
Expand Down Expand Up @@ -133,9 +141,8 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) (err error)
var sendMsg remote.Message
closeConnOutsideIfErr := true
defer func() {
panicErr := recover()
var wrapErr error
if panicErr != nil {
var panicErr error
if r := recover(); r != nil {
stack := string(debug.Stack())
if conn != nil {
ri := rpcinfo.GetRPCInfo(ctx)
Expand All @@ -144,10 +151,9 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) (err error)
} else {
klog.CtxErrorf(ctx, "KITEX: panic happened, error=%v\nstack=%s", panicErr, stack)
}
if err != nil {
wrapErr = kerrors.ErrPanic.WithCauseAndStack(fmt.Errorf("[happened in OnRead] %s, last error=%s", panicErr, err.Error()), stack)
} else {
wrapErr = kerrors.ErrPanic.WithCauseAndStack(fmt.Errorf("[happened in OnRead] %s", panicErr), stack)
panicErr = kerrors.ErrPanic.WithCauseAndStack(fmt.Errorf("[happened in OnRead] %s", panicErr), stack)
if err == nil {
err = panicErr
}
}
t.finishTracer(ctx, ri, err, panicErr)
Expand All @@ -158,10 +164,9 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) (err error)
if rpcinfo.PoolEnabled() {
t.opt.InitOrResetRPCInfoFunc(ri, conn.RemoteAddr())
}
if wrapErr != nil {
err = wrapErr
}
if err != nil && !closeConnOutsideIfErr {
// when error is not nil, outside will close conn,
// set err to nil to indicate that this kind of error does not require closing the connection
err = nil
}
}()
Expand All @@ -186,7 +191,7 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) (err error)
// reply processing
var methodInfo serviceinfo.MethodInfo
if methodInfo, err = GetMethodInfo(ri, svcInfo); err != nil {
// it won't be err, because the method has been checked in decode, err check here just do defensive inspection
// it won't be error, because the method has been checked in decode, err check here just do defensive inspection
t.writeErrorReplyIfNeeded(ctx, recvMsg, conn, err, ri, true)
// for proxy case, need read actual remoteAddr, error print must exec after writeErrorReplyIfNeeded,
// t.OnError(ctx, err, conn) will be executed at outer function when transServer close the conn
Expand Down
54 changes: 54 additions & 0 deletions pkg/remote/trans/default_server_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"errors"
"net"
"strings"
"testing"

"github.com/golang/mock/gomock"
Expand Down Expand Up @@ -211,6 +212,59 @@ func TestSvrTransHandlerReadErr(t *testing.T) {
test.Assert(t, errors.Is(err, mockErr))
}

func TestSvrTransHandlerReadPanic(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

mockTracer := stats.NewMockTracer(ctrl)
mockTracer.EXPECT().Start(gomock.Any()).DoAndReturn(func(ctx context.Context) context.Context { return ctx }).AnyTimes()
mockTracer.EXPECT().Finish(gomock.Any()).DoAndReturn(func(ctx context.Context) {
err := rpcinfo.GetRPCInfo(ctx).Stats().Error()
test.Assert(t, err != nil)
}).AnyTimes()

buf := remote.NewReaderWriterBuffer(1024)
ext := &MockExtension{
NewWriteByteBufferFunc: func(ctx context.Context, conn net.Conn, msg remote.Message) remote.ByteBuffer {
return buf
},
NewReadByteBufferFunc: func(ctx context.Context, conn net.Conn, msg remote.Message) remote.ByteBuffer {
return buf
},
}

tracerCtl := &rpcinfo.TraceController{}
tracerCtl.Append(mockTracer)
opt := &remote.ServerOption{
Codec: &MockCodec{
EncodeFunc: func(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error {
return nil
},
DecodeFunc: func(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error {
panic("mock")
},
},
SvcSearcher: svcSearcher,
TargetSvcInfo: svcInfo,
TracerCtl: tracerCtl,
InitOrResetRPCInfoFunc: func(ri rpcinfo.RPCInfo, addr net.Addr) rpcinfo.RPCInfo {
rpcinfo.AsMutableEndpointInfo(ri.From()).SetAddress(addr)
return ri
},
}
ri := rpcinfo.NewRPCInfo(rpcinfo.EmptyEndpointInfo(), rpcinfo.FromBasicInfo(&rpcinfo.EndpointBasicInfo{}),
rpcinfo.NewInvocation("", ""), nil, rpcinfo.NewRPCStats())
ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri)

svrHandler, err := NewDefaultSvrTransHandler(opt, ext)
test.Assert(t, err == nil)
pl := remote.NewTransPipeline(svrHandler)
svrHandler.SetPipeline(pl)
err = svrHandler.OnRead(ctx, &mocks.Conn{})
test.Assert(t, err != nil)
test.Assert(t, strings.Contains(err.Error(), "panic"))
}

func TestSvrTransHandlerOnReadHeartbeat(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
Expand Down
4 changes: 2 additions & 2 deletions pkg/remote/trans/netpoll/server_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,8 @@ func TestInvokeErr(t *testing.T) {
test.Assert(t, isInvoked)
}

// TestPanicAfterRead test server_handler not panic after read
func TestPanicAfterRead(t *testing.T) {
// TestPipelineNilPanic test server_handler that TransPipeline is nil
func TestPipelineNilPanic(t *testing.T) {
// 1. prepare mock data
var isWriteBufFlushed bool
var isReaderBufReleased bool
Expand Down

0 comments on commit fe8890e

Please sign in to comment.