From 89fe3f42e297837c96c061aa66a3fad70ca0e438 Mon Sep 17 00:00:00 2001 From: andreasvogt89 <30302212+andreasvogt89@users.noreply.github.com> Date: Wed, 21 Sep 2022 20:27:37 +0200 Subject: [PATCH 1/3] option for allowing retransmissions at the packet server --- server-packet.go | 30 ++++++----- server-packet_test.go | 114 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 131 insertions(+), 13 deletions(-) diff --git a/server-packet.go b/server-packet.go index 1347d43..cccd7d8 100644 --- a/server-packet.go +++ b/server-packet.go @@ -46,6 +46,9 @@ type PacketServer struct { // This should only be set to true for debugging purposes. InsecureSkipVerify bool + // Don't block dupplicated requests (retransmissions). + AllowRetransmission bool + // ErrorLog specifies an optional logger for errors // around packet accepting, processing, and validation. // If nil, logging is done via the log package's standard logger. @@ -171,26 +174,27 @@ func (s *PacketServer) Serve(conn net.PacketConn) error { IP: remoteAddr.String(), Identifier: packet.Identifier, } - - requestsLock.Lock() - if _, ok := requests[key]; ok { + if !s.AllowRetransmission { + requestsLock.Lock() + if _, ok := requests[key]; ok { + requestsLock.Unlock() + return + } + requests[key] = struct{}{} requestsLock.Unlock() - return - } - requests[key] = struct{}{} - requestsLock.Unlock() + //clean up afterwards + defer func() { + requestsLock.Lock() + delete(requests, key) + requestsLock.Unlock() + }() + } response := packetResponseWriter{ conn: conn, addr: remoteAddr, } - defer func() { - requestsLock.Lock() - delete(requests, key) - requestsLock.Unlock() - }() - request := Request{ LocalAddr: conn.LocalAddr(), RemoteAddr: remoteAddr, diff --git a/server-packet_test.go b/server-packet_test.go index 7497aaa..2b0812b 100644 --- a/server-packet_test.go +++ b/server-packet_test.go @@ -186,3 +186,117 @@ func TestPacketServer_singleUse(t *testing.T) { t.Fatalf("got err %v; expecting ErrServerShutdown", err) } } + +func TestPacketServer_AllowRetransmission(t *testing.T) { + addr, err := net.ResolveUDPAddr("udp", "localhost:0") + if err != nil { + t.Fatal(err) + } + pc, err := net.ListenUDP("udp", addr) + if err != nil { + t.Fatal(err) + } + + secret := []byte("123456790") + var receivedRequests = 0 + var identifiers = make(map[byte]struct{}) + server := PacketServer{ + SecretSource: StaticSecretSource(secret), + AllowRetransmission: true, + Handler: HandlerFunc(func(w ResponseWriter, r *Request) { + receivedRequests++ + if _, ok := identifiers[r.Identifier]; ok { + return + } + identifiers[r.Identifier] = struct{}{} + time.Sleep(time.Millisecond * 200) + w.Write(r.Response(CodeAccessReject)) + }), + } + + var clientErr error + go func(rr *int) { + defer server.Shutdown(context.Background()) + + packet := New(CodeAccessRequest, secret) + client := Client{ + Retry: time.Millisecond * 10, + } + response, err := client.Exchange(context.Background(), packet, pc.LocalAddr().String()) + if err != nil { + clientErr = err + return + } + if response.Code != CodeAccessReject { + clientErr = fmt.Errorf("got response code %v; expecting CodeAccessReject", response.Code) + } + if receivedRequests < 2 { + clientErr = fmt.Errorf("got %d requests; expecting at least 2", receivedRequests) + } + }(&receivedRequests) + + if err := server.Serve(pc); err != ErrServerShutdown { + t.Fatal(err) + } + + server.Shutdown(context.Background()) + if clientErr != nil { + t.Fatal(clientErr) + } +} + +func TestPacketServer_BlockRetransmission(t *testing.T) { + addr, err := net.ResolveUDPAddr("udp", "localhost:0") + if err != nil { + t.Fatal(err) + } + pc, err := net.ListenUDP("udp", addr) + if err != nil { + t.Fatal(err) + } + var receivedRequests = 0 + var identifiers = make(map[byte]struct{}) + secret := []byte("123456790") + server := PacketServer{ + SecretSource: StaticSecretSource(secret), + //AllowRetransmission: true, + Handler: HandlerFunc(func(w ResponseWriter, r *Request) { + receivedRequests++ + if _, ok := identifiers[r.Identifier]; ok { + return + } + time.Sleep(time.Millisecond * 200) + w.Write(r.Response(CodeAccessReject)) + }), + } + + var clientErr error + go func(rr *int) { + defer server.Shutdown(context.Background()) + + packet := New(CodeAccessRequest, secret) + client := Client{ + Retry: time.Millisecond * 10, + } + response, err := client.Exchange(context.Background(), packet, pc.LocalAddr().String()) + if err != nil { + clientErr = err + return + } + if response.Code != CodeAccessReject { + clientErr = fmt.Errorf("got response code %v; expecting CodeAccessReject", response.Code) + } + if receivedRequests != 2 { + clientErr = fmt.Errorf("got %d requests; expecting only 1", receivedRequests) + } + }(&receivedRequests) + + if err := server.Serve(pc); err != ErrServerShutdown { + t.Fatal(err) + } + + server.Shutdown(context.Background()) + if clientErr != nil { + t.Fatal(clientErr) + } +} From 65f650c577602ee19eb36c3408abb160090be1fd Mon Sep 17 00:00:00 2001 From: andreasvogt89 <30302212+andreasvogt89@users.noreply.github.com> Date: Wed, 21 Sep 2022 20:53:09 +0200 Subject: [PATCH 2/3] there is no need to create a key while allowing retransmissions --- server-packet.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/server-packet.go b/server-packet.go index cccd7d8..6d4f65b 100644 --- a/server-packet.go +++ b/server-packet.go @@ -170,11 +170,11 @@ func (s *PacketServer) Serve(conn net.PacketConn) error { return } - key := requestKey{ - IP: remoteAddr.String(), - Identifier: packet.Identifier, - } if !s.AllowRetransmission { + key := requestKey{ + IP: remoteAddr.String(), + Identifier: packet.Identifier, + } requestsLock.Lock() if _, ok := requests[key]; ok { requestsLock.Unlock() From 95a0bb06b0181310a3e0365c85850b4b1eae95cd Mon Sep 17 00:00:00 2001 From: Vogt Andreas Date: Thu, 22 Sep 2022 12:45:05 +0200 Subject: [PATCH 3/3] Unittest correction --- server-packet_test.go | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/server-packet_test.go b/server-packet_test.go index 2b0812b..440234c 100644 --- a/server-packet_test.go +++ b/server-packet_test.go @@ -215,7 +215,7 @@ func TestPacketServer_AllowRetransmission(t *testing.T) { } var clientErr error - go func(rr *int) { + go func(rr int) { defer server.Shutdown(context.Background()) packet := New(CodeAccessRequest, secret) @@ -233,7 +233,7 @@ func TestPacketServer_AllowRetransmission(t *testing.T) { if receivedRequests < 2 { clientErr = fmt.Errorf("got %d requests; expecting at least 2", receivedRequests) } - }(&receivedRequests) + }(receivedRequests) if err := server.Serve(pc); err != ErrServerShutdown { t.Fatal(err) @@ -259,19 +259,18 @@ func TestPacketServer_BlockRetransmission(t *testing.T) { secret := []byte("123456790") server := PacketServer{ SecretSource: StaticSecretSource(secret), - //AllowRetransmission: true, Handler: HandlerFunc(func(w ResponseWriter, r *Request) { receivedRequests++ if _, ok := identifiers[r.Identifier]; ok { return } - time.Sleep(time.Millisecond * 200) + time.Sleep(time.Millisecond * 500) w.Write(r.Response(CodeAccessReject)) }), } var clientErr error - go func(rr *int) { + go func(rr int) { defer server.Shutdown(context.Background()) packet := New(CodeAccessRequest, secret) @@ -286,10 +285,10 @@ func TestPacketServer_BlockRetransmission(t *testing.T) { if response.Code != CodeAccessReject { clientErr = fmt.Errorf("got response code %v; expecting CodeAccessReject", response.Code) } - if receivedRequests != 2 { + if receivedRequests != 1 { clientErr = fmt.Errorf("got %d requests; expecting only 1", receivedRequests) } - }(&receivedRequests) + }(receivedRequests) if err := server.Serve(pc); err != ErrServerShutdown { t.Fatal(err)