Skip to content

Commit fa2c2a7

Browse files
committed
Stop treating RST_STREAM as EOF
In hq, it's important to know when the stream is reset as opposed to ending naturally. You need to know this so that you can acknowledge any potentially unread header blocks. This changes the code so that resets cause subsequent reads to immediately fail with a recognizable error code. Receiving RST_STREAM stops further reading. It already caused existing data to be discarded. Finally, when a read fails on a reset stream, the state moves to the previously unused ResetRead state.
1 parent b47c785 commit fa2c2a7

File tree

3 files changed

+19
-9
lines changed

3 files changed

+19
-9
lines changed

connection_test.go

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ func (h *testReceiveHandler) StreamReadable(s RecvStream) {
405405
break
406406
case ErrorWouldBlock:
407407
return
408-
case ErrorStreamIsClosed, ErrorConnIsClosed, io.EOF:
408+
case ErrorStreamReset, ErrorConnIsClosed, io.EOF:
409409
h.done = true
410410
return
411411
default:
@@ -565,7 +565,7 @@ func TestSendReceiveStreamRst(t *testing.T) {
565565
ss := pair.server.GetStream(4)
566566
b := make([]byte, 1024)
567567
n, err = ss.Read(b)
568-
assertEquals(t, err, io.EOF)
568+
assertEquals(t, err, ErrorStreamReset)
569569
assertEquals(t, 0, n)
570570
}
571571

@@ -783,10 +783,11 @@ func TestUnidirectionalStreamRst(t *testing.T) {
783783
err = inputAll(client)
784784
assertNotError(t, err, "packets should be OK")
785785

786+
assertEquals(t, cstream.RecvState(), RecvStreamStateResetRecvd)
786787
n, err = cstream.Read(d)
787-
assertEquals(t, err, io.EOF)
788+
assertEquals(t, err, ErrorStreamReset)
788789
assertEquals(t, n, 0)
789-
assertEquals(t, cstream.RecvState(), RecvStreamStateResetRecvd)
790+
assertEquals(t, cstream.RecvState(), RecvStreamStateResetRead)
790791
}
791792

792793
func TestUnidirectionalStreamRstImmediate(t *testing.T) {
@@ -801,11 +802,12 @@ func TestUnidirectionalStreamRstImmediate(t *testing.T) {
801802
assertNotError(t, err, "packets should be OK")
802803

803804
cstream := pair.client.GetRecvStream(sstream.Id())
805+
assertEquals(t, cstream.RecvState(), RecvStreamStateResetRecvd)
804806
var d [3]byte
805807
n, err := cstream.Read(d[:])
806-
assertEquals(t, err, io.EOF)
808+
assertEquals(t, err, ErrorStreamReset)
807809
assertEquals(t, n, 0)
808-
assertEquals(t, cstream.RecvState(), RecvStreamStateResetRecvd)
810+
assertEquals(t, cstream.RecvState(), RecvStreamStateResetRead)
809811
}
810812

811813
func TestUnidirectionalStopSending(t *testing.T) {
@@ -839,10 +841,11 @@ func TestUnidirectionalStopSending(t *testing.T) {
839841
err = inputAll(pair.server)
840842
assertNotError(t, err, "packets should be OK")
841843

844+
assertEquals(t, sstream.RecvState(), RecvStreamStateResetRecvd)
842845
n, err = sstream.Read(d)
843-
assertEquals(t, err, io.EOF)
846+
assertEquals(t, err, ErrorStreamReset)
844847
assertEquals(t, n, 0)
845-
assertEquals(t, sstream.RecvState(), RecvStreamStateResetRecvd)
848+
assertEquals(t, sstream.RecvState(), RecvStreamStateResetRead)
846849
}
847850

848851
func TestBidirectionalStopSending(t *testing.T) {

errors.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ var ErrorDestroyConnection = fatalError("Terminate connection")
8686
var ErrorReceivedVersionNegotiation = fatalError("Received a version negotiation packet advertising a different version than ours")
8787
var ErrorConnIsClosed = fatalError("Connection is closed")
8888
var ErrorConnIsClosing = nonFatalError("Connection is closing")
89+
var ErrorStreamReset = fatalError("Stream was reset")
8990
var ErrorStreamIsClosed = fatalError("Stream is closed")
9091
var ErrorInvalidPacket = nonFatalError("Invalid packet")
9192
var ErrorConnectionTimedOut = fatalError("Connection timed out")

stream.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ const (
5252
RecvStreamStateDataRecvd = RecvStreamState(2) // Not tracked
5353
RecvStreamStateResetRecvd = RecvStreamState(3)
5454
RecvStreamStateDataRead = RecvStreamState(4)
55-
RecvStreamStateResetRead = RecvStreamState(5) // Not tracked
55+
RecvStreamStateResetRead = RecvStreamState(5)
5656
)
5757

5858
// String produces a nice string from a RecvStreamState.
@@ -392,6 +392,12 @@ func (s *recvStreamBase) read(b []byte) (int, error) {
392392
s.log(logTypeStream, "Reading len=%v read offset=%v available chunks=%v",
393393
len(b), s.readOffset, len(s.chunks))
394394

395+
if s.state == RecvStreamStateResetRecvd {
396+
s.log(logTypeStream, "Reading stopped for RST_STREAM")
397+
s.setRecvState(RecvStreamStateResetRead)
398+
return 0, ErrorStreamReset
399+
}
400+
395401
read := 0
396402

397403
for len(b) > 0 {

0 commit comments

Comments
 (0)