diff --git a/client_truncate.go b/client_truncate.go new file mode 100644 index 0000000..50a9a38 --- /dev/null +++ b/client_truncate.go @@ -0,0 +1,37 @@ +package dns + +import "github.com/miekg/dns" + +func TruncateDNSMessage(request *dns.Msg, response *dns.Msg) (*dns.Msg, int) { + maxLen := 512 + if edns0Option := request.IsEdns0(); edns0Option != nil { + if udpSize := int(edns0Option.UDPSize()); udpSize > 0 { + maxLen = udpSize + } + } + return truncateDNSMessage(response, maxLen) +} + +func truncateDNSMessage(response *dns.Msg, maxLen int) (*dns.Msg, int) { + responseLen := response.Len() + if responseLen <= maxLen { + return response, responseLen + } + newResponse := *response + response = &newResponse + response.Compress = true + responseLen = response.Len() + if responseLen <= maxLen { + return response, responseLen + } + for len(response.Answer) > 0 && responseLen > maxLen { + response.Answer = response.Answer[:len(response.Answer)-1] + response.Truncated = true + responseLen = response.Len() + } + if responseLen > maxLen { + response.Ns = nil + response.Extra = nil + } + return response, response.Len() +} diff --git a/transport_https.go b/transport_https.go index 885a053..bde93b3 100644 --- a/transport_https.go +++ b/transport_https.go @@ -57,6 +57,7 @@ func (t *HTTPSTransport) Start() error { func (t *HTTPSTransport) Reset() { t.transport.CloseIdleConnections() + t.transport = t.transport.Clone() } func (t *HTTPSTransport) Close() error { diff --git a/transport_udp.go b/transport_udp.go index cb9e5a1..7bd349b 100644 --- a/transport_udp.go +++ b/transport_udp.go @@ -8,6 +8,7 @@ import ( "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" "github.com/miekg/dns" @@ -25,6 +26,8 @@ func init() { type UDPTransport struct { myTransportAdapter + tcpTransport *TCPTransport + logger logger.ContextLogger } func NewUDPTransport(options TransportOptions) (*UDPTransport, error) { @@ -40,13 +43,31 @@ func NewUDPTransport(options TransportOptions) (*UDPTransport, error) { if serverAddr.Port == 0 { serverAddr.Port = 53 } + tcpTransport, err := NewTCPTransport(options) + if err != nil { + return nil, err + } transport := &UDPTransport{ newAdapter(options, serverAddr), + tcpTransport, + options.Logger, } transport.handler = transport return transport, nil } +func (t *UDPTransport) Exchange(ctx context.Context, message *dns.Msg) (*dns.Msg, error) { + response, err := t.myTransportAdapter.Exchange(ctx, message) + if err != nil { + return nil, err + } + if response.Truncated { + t.logger.InfoContext(ctx, "response truncated, retrying with TCP") + return t.tcpTransport.Exchange(ctx, message) + } + return response, nil +} + func (t *UDPTransport) DialContext(ctx context.Context) (net.Conn, error) { return t.dialer.DialContext(ctx, "udp", t.serverAddr) }