Skip to content

Commit

Permalink
fix #49
Browse files Browse the repository at this point in the history
  • Loading branch information
firefart committed Mar 9, 2024
1 parent ae46f5d commit 0d246a4
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 14 deletions.
4 changes: 2 additions & 2 deletions internal/cmd/tcpscanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func httpScan(ctx context.Context, opts TCPScannerOpts, ip netip.Addr, port uint
if err := helper.ConnectionWrite(ctx, tlsConn, []byte(httpRequest), opts.Timeout); err != nil {
return fmt.Errorf("error on sending TLS data: %w", err)
}
data, err := helper.ConnectionRead(ctx, tlsConn, opts.Timeout)
data, err := helper.ConnectionReadAll(ctx, tlsConn, opts.Timeout)
if err != nil {
return fmt.Errorf("error on reading after sending TLS data: %w", err)
}
Expand All @@ -120,7 +120,7 @@ func httpScan(ctx context.Context, opts TCPScannerOpts, ip netip.Addr, port uint
if err := helper.ConnectionWrite(ctx, dataConnection, []byte(httpRequest), opts.Timeout); err != nil {
return fmt.Errorf("error on sending data: %w", err)
}
data, err := helper.ConnectionRead(ctx, dataConnection, opts.Timeout)
data, err := helper.ConnectionReadAll(ctx, dataConnection, opts.Timeout)
if err != nil {
return fmt.Errorf("error on reading after sending data: %w", err)
}
Expand Down
4 changes: 2 additions & 2 deletions internal/cmd/udpscanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ func snmpScan(ctx context.Context, opts UDPScannerOpts, ip netip.Addr, port uint
return fmt.Errorf("error on sending SNMP request: %w", err)
}

resp, err := helper.ConnectionRead(ctx, remote, opts.Timeout)
resp, err := helper.ConnectionReadAll(ctx, remote, opts.Timeout)
if err != nil {
// ignore timeouts
if errors.Is(err, helper.ErrTimeout) {
Expand Down Expand Up @@ -245,7 +245,7 @@ func dnsScan(ctx context.Context, opts UDPScannerOpts, ip netip.Addr, port uint1
return fmt.Errorf("error on sending DNS request: %w", err)
}

resp, err := helper.ConnectionRead(ctx, remote, opts.Timeout)
resp, err := helper.ConnectionReadAll(ctx, remote, opts.Timeout)
if err != nil {
// ignore timeouts
if errors.Is(err, helper.ErrTimeout) {
Expand Down
24 changes: 22 additions & 2 deletions internal/connection.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package internal

import (
"bufio"
"context"
"crypto/tls"
"fmt"
"net"
"slices"
"time"

"github.com/firefart/stunner/internal/helper"
Expand Down Expand Up @@ -73,11 +75,29 @@ func (s *Stun) SendAndReceive(ctx context.Context, logger DebugLogger, conn net.
if err != nil {
return nil, fmt.Errorf("Send: %w", err)
}
buffer, err := helper.ConnectionRead(ctx, conn, timeout)

// need this otherwise the read call is blocking forever
if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil {
return nil, fmt.Errorf("could not set read deadline: %v", err)
}

r := bufio.NewReader(conn)
// read the header first to get the message length
header, err := helper.ConnectionRead(ctx, r, headerSize, timeout)
if err != nil {
return nil, fmt.Errorf("ConnectionRead Header: %w", err)
}
headerParsed := parseHeader(header)
expectedPacketSize := int(headerParsed.MessageLength) // + headerSize
logger.Debugf("expectedPacketSize %d", expectedPacketSize)

// only read the message length and leave potential additional data on the connection
// for later read operations
buffer, err := helper.ConnectionRead(ctx, r, expectedPacketSize, timeout)
if err != nil {
return nil, fmt.Errorf("ConnectionRead: %w", err)
}
resp, err := fromBytes(buffer)
resp, err := fromBytes(slices.Concat(header, buffer))
if err != nil {
return nil, fmt.Errorf("fromBytes: %w", err)
}
Expand Down
31 changes: 23 additions & 8 deletions internal/helper/connection.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package helper

import (
"bufio"
"context"
"errors"
"fmt"
Expand All @@ -11,26 +12,39 @@ import (

var ErrTimeout = errors.New("timeout occurred. you can try to increase the timeout if the server responds too slowly")

// ConnectionRead reads all data from a connection
func ConnectionRead(ctx context.Context, conn net.Conn, timeout time.Duration) ([]byte, error) {
// ConnectionReadAll reads all data from a connection
func ConnectionReadAll(ctx context.Context, conn net.Conn, timeout time.Duration) ([]byte, error) {
// need this otherwise the read call is blocking forever
if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil {
return nil, fmt.Errorf("could not set read deadline: %v", err)
}
return connectionRead(ctx, bufio.NewReader(conn), nil, timeout)
}

// ConnectionRead reads the data from the connection up to maxSizeToRead
func ConnectionRead(ctx context.Context, r *bufio.Reader, maxSizeToRead int, timeout time.Duration) ([]byte, error) {
return connectionRead(ctx, r, &maxSizeToRead, timeout)
}

func connectionRead(ctx context.Context, r *bufio.Reader, maxSizeToRead *int, timeout time.Duration) ([]byte, error) {
var ret []byte

ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()

// need this otherwise the read call is blocking forever
if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil {
return nil, fmt.Errorf("could not set read deadline: %v", err)
bufLen := 1024
if maxSizeToRead != nil && *maxSizeToRead < bufLen {
bufLen = *maxSizeToRead
}

bufLen := 1024
buf := make([]byte, bufLen)
alreadyRead := 0
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
i, err := conn.Read(buf)
i, err := r.Read(buf)
if err != nil {
if err != io.EOF {
// also return read data on timeout so caller can use it
Expand All @@ -41,9 +55,10 @@ func ConnectionRead(ctx context.Context, conn net.Conn, timeout time.Duration) (
}
return ret, nil
}
alreadyRead += i
ret = append(ret, buf[:i]...)
// we've read all data, bail out
if i < bufLen {
if i < bufLen || (maxSizeToRead != nil && (alreadyRead >= *maxSizeToRead)) {
return ret, nil
}
}
Expand Down

0 comments on commit 0d246a4

Please sign in to comment.