Skip to content

Commit

Permalink
fix #47
Browse files Browse the repository at this point in the history
  • Loading branch information
firefart committed Feb 8, 2024
1 parent e709bda commit c0b7fc9
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 159 deletions.
13 changes: 12 additions & 1 deletion internal/helper/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package helper
import (
"context"
"errors"
"fmt"
"io"
"net"
"time"
Expand All @@ -17,13 +18,18 @@ func ConnectionRead(ctx context.Context, conn net.Conn, timeout time.Duration) (
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()

// need this otherwise the read call is blocking forever
if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil {
return nil, fmt.Errorf("could not set read deadline: %v", err)
}

bufLen := 1024
buf := make([]byte, bufLen)
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
buf := make([]byte, bufLen)
i, err := conn.Read(buf)
if err != nil {
if err != io.EOF {
Expand Down Expand Up @@ -53,6 +59,11 @@ func ConnectionWrite(ctx context.Context, conn net.Conn, data []byte, timeout ti
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()

// need this otherwise the read call is blocking forever
if err := conn.SetWriteDeadline(time.Now().Add(timeout)); err != nil {
return fmt.Errorf("could not set write deadline: %v", err)
}

for {
select {
case <-ctx.Done():
Expand Down
45 changes: 43 additions & 2 deletions internal/socksimplementations/socksturntcphandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,18 +117,42 @@ func (s *SocksTurnTCPHandler) Refresh(ctx context.Context) {

const bufferLength = 1024 * 100

type readDeadline interface {
SetReadDeadline(time.Time) error
}
type writeDeadline interface {
SetWriteDeadline(time.Time) error
}

// ReadFromClient is used to copy data
func (s *SocksTurnTCPHandler) ReadFromClient(ctx context.Context, client io.ReadCloser, remote io.WriteCloser) error {
for {
// anonymous func for defer
// this might not be the fastest, but it does the trick
// in this case the timeout is per buffer read/write to support long
// running downloads.
err := func() error {
ctx, cancel := context.WithTimeout(ctx, s.Timeout)
timeOut := time.Now().Add(s.Timeout)

ctx, cancel := context.WithDeadline(ctx, timeOut)
defer cancel()

select {
case <-ctx.Done():
return ctx.Err()
default:
if c, ok := remote.(writeDeadline); ok {
if err := c.SetWriteDeadline(timeOut); err != nil {
return fmt.Errorf("could not set write deadline on remote: %v", err)
}
}

if c, ok := client.(readDeadline); ok {
if err := c.SetReadDeadline(timeOut); err != nil {
return fmt.Errorf("could not set read deadline on client: %v", err)
}
}

i, err := io.CopyN(remote, client, bufferLength)
if errors.Is(err, io.EOF) {
return nil
Expand All @@ -150,13 +174,30 @@ func (s *SocksTurnTCPHandler) ReadFromRemote(ctx context.Context, remote io.Read
for {
// anonymous func for defer
// this might not be the fastest, but it does the trick
// in this case the timeout is per buffer read/write to support long
// running downloads.
err := func() error {
ctx, cancel := context.WithTimeout(ctx, s.Timeout)
timeOut := time.Now().Add(s.Timeout)

ctx, cancel := context.WithDeadline(ctx, timeOut)
defer cancel()

select {
case <-ctx.Done():
return ctx.Err()
default:
if c, ok := client.(writeDeadline); ok {
if err := c.SetWriteDeadline(timeOut); err != nil {
return fmt.Errorf("could not set write deadline on client: %v", err)
}
}

if c, ok := remote.(readDeadline); ok {
if err := c.SetReadDeadline(timeOut); err != nil {
return fmt.Errorf("could not set read deadline on remote: %v", err)
}
}

i, err := io.CopyN(client, remote, bufferLength)
if errors.Is(err, io.EOF) {
return nil
Expand Down
152 changes: 0 additions & 152 deletions internal/socksimplementations/socksturnudphandler.go

This file was deleted.

8 changes: 4 additions & 4 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func main() {
&cli.StringFlag{Name: "turnserver", Aliases: []string{"s"}, Required: true, Usage: "turn server to connect to in the format host:port"},
&cli.BoolFlag{Name: "tls", Value: false, Usage: "Use TLS/DTLS on connecting to the STUN or TURN server"},
&cli.StringFlag{Name: "protocol", Value: "udp", Usage: "protocol to use when connecting to the TURN server. Supported values: tcp and udp"},
&cli.DurationFlag{Name: "timeout", Value: 5 * time.Second, Usage: "connect timeout to turn server"},
&cli.DurationFlag{Name: "timeout", Value: 2 * time.Second, Usage: "connect timeout to turn server"},
},
Before: func(ctx *cli.Context) error {
if ctx.Bool("debug") {
Expand Down Expand Up @@ -86,7 +86,7 @@ func main() {
&cli.StringFlag{Name: "turnserver", Aliases: []string{"s"}, Required: true, Usage: "turn server to connect to in the format host:port"},
&cli.BoolFlag{Name: "tls", Value: false, Usage: "Use TLS/DTLS on connecting to the STUN or TURN server"},
&cli.StringFlag{Name: "protocol", Value: "udp", Usage: "protocol to use when connecting to the TURN server. Supported values: tcp and udp"},
&cli.DurationFlag{Name: "timeout", Value: 5 * time.Second, Usage: "connect timeout to turn server"},
&cli.DurationFlag{Name: "timeout", Value: 2 * time.Second, Usage: "connect timeout to turn server"},
&cli.StringFlag{Name: "username", Aliases: []string{"u"}, Required: true, Usage: "username for the turn server"},
&cli.StringFlag{Name: "password", Aliases: []string{"p"}, Required: true, Usage: "password for the turn server"},
},
Expand Down Expand Up @@ -168,7 +168,7 @@ func main() {
&cli.StringFlag{Name: "turnserver", Aliases: []string{"s"}, Required: true, Usage: "turn server to connect to in the format host:port"},
&cli.BoolFlag{Name: "tls", Value: false, Usage: "Use TLS/DTLS on connecting to the STUN or TURN server"},
&cli.StringFlag{Name: "protocol", Value: "udp", Usage: "protocol to use when connecting to the TURN server. Supported values: tcp and udp"},
&cli.DurationFlag{Name: "timeout", Value: 5 * time.Second, Usage: "connect timeout to turn server"},
&cli.DurationFlag{Name: "timeout", Value: 2 * time.Second, Usage: "connect timeout to turn server"},
&cli.StringFlag{Name: "username", Aliases: []string{"u"}, Required: true, Usage: "username for the turn server"},
&cli.StringFlag{Name: "password", Aliases: []string{"p"}, Required: true, Usage: "password for the turn server"},
&cli.StringFlag{Name: "target", Aliases: []string{"t"}, Required: true, Usage: "Target to leak memory to in the form host:port. Should be a public server under your control"},
Expand Down Expand Up @@ -231,7 +231,7 @@ func main() {
&cli.StringFlag{Name: "turnserver", Aliases: []string{"s"}, Required: true, Usage: "turn server to connect to in the format host:port"},
&cli.BoolFlag{Name: "tls", Value: false, Usage: "Use TLS/DTLS on connecting to the STUN or TURN server"},
&cli.StringFlag{Name: "protocol", Value: "udp", Usage: "protocol to use when connecting to the TURN server. Supported values: tcp and udp"},
&cli.DurationFlag{Name: "timeout", Value: 5 * time.Second, Usage: "connect timeout to turn server"},
&cli.DurationFlag{Name: "timeout", Value: 2 * time.Second, Usage: "connect timeout to turn server"},
&cli.StringFlag{Name: "username", Aliases: []string{"u"}, Required: true, Usage: "username for the turn server"},
&cli.StringFlag{Name: "password", Aliases: []string{"p"}, Required: true, Usage: "password for the turn server"},
},
Expand Down

0 comments on commit c0b7fc9

Please sign in to comment.