diff --git a/.gitignore b/.gitignore index f7f8ac3..f1298ea 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /.idea/ /vendor/ +.DS_Store diff --git a/client.go b/client.go index e369906..022fd77 100644 --- a/client.go +++ b/client.go @@ -8,11 +8,12 @@ import ( "time" "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/cache" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" "github.com/sagernet/sing/common/task" + "github.com/sagernet/sing/contrab/freelru" + "github.com/sagernet/sing/contrab/maphash" "github.com/miekg/dns" ) @@ -37,8 +38,8 @@ type Client struct { rdrc RDRCStore initRDRCFunc func() RDRCStore logger logger.ContextLogger - cache *cache.LruCache[dns.Question, *dns.Msg] - transportCache *cache.LruCache[transportCacheKey, *dns.Msg] + cache freelru.Cache[dns.Question, *dns.Msg] + transportCache freelru.Cache[transportCacheKey, *dns.Msg] } type RDRCStore interface { @@ -57,6 +58,7 @@ type ClientOptions struct { DisableCache bool DisableExpire bool IndependentCache bool + CacheCapacity uint32 RDRC func() RDRCStore Logger logger.ContextLogger } @@ -73,11 +75,15 @@ func NewClient(options ClientOptions) *Client { if client.timeout == 0 { client.timeout = DefaultTimeout } + cacheCapacity := options.CacheCapacity + if cacheCapacity < 1024 { + cacheCapacity = 1024 + } if !client.disableCache { if !client.independentCache { - client.cache = cache.New[dns.Question, *dns.Msg]() + client.cache = common.Must1(freelru.NewSharded[dns.Question, *dns.Msg](cacheCapacity, maphash.NewHasher[dns.Question]().Hash32)) } else { - client.transportCache = cache.New[transportCacheKey, *dns.Msg]() + client.transportCache = common.Must1(freelru.NewSharded[transportCacheKey, *dns.Msg](cacheCapacity, maphash.NewHasher[transportCacheKey]().Hash32)) } } return client @@ -89,11 +95,11 @@ func (c *Client) Start() { } } -func (c *Client) Exchange(ctx context.Context, transport Transport, message *dns.Msg, strategy DomainStrategy) (*dns.Msg, error) { - return c.ExchangeWithResponseCheck(ctx, transport, message, strategy, nil) +func (c *Client) Exchange(ctx context.Context, transport Transport, message *dns.Msg, options QueryOptions) (*dns.Msg, error) { + return c.ExchangeWithResponseCheck(ctx, transport, message, options, nil) } -func (c *Client) ExchangeWithResponseCheck(ctx context.Context, transport Transport, message *dns.Msg, strategy DomainStrategy, responseChecker func(response *dns.Msg) bool) (*dns.Msg, error) { +func (c *Client) ExchangeWithResponseCheck(ctx context.Context, transport Transport, message *dns.Msg, options QueryOptions, responseChecker func(response *dns.Msg) bool) (*dns.Msg, error) { if len(message.Question) == 0 { if c.logger != nil { c.logger.WarnContext(ctx, "bad question size: ", len(message.Question)) @@ -109,15 +115,14 @@ func (c *Client) ExchangeWithResponseCheck(ctx context.Context, transport Transp return &responseMessage, nil } question := message.Question[0] - clientSubnet, clientSubnetLoaded := ClientSubnetFromContext(ctx) - if clientSubnetLoaded { - message = SetClientSubnet(message, clientSubnet, true) + if options.ClientSubnet.IsValid() { + message = SetClientSubnet(message, options.ClientSubnet, true) } isSimpleRequest := len(message.Question) == 1 && len(message.Ns) == 0 && len(message.Extra) == 0 && - !clientSubnetLoaded - disableCache := !isSimpleRequest || c.disableCache || DisableCacheFromContext(ctx) + !options.ClientSubnet.IsValid() + disableCache := !isSimpleRequest || c.disableCache || options.DisableCache if !disableCache { response, ttl := c.loadResponse(question, transport) if response != nil { @@ -126,7 +131,7 @@ func (c *Client) ExchangeWithResponseCheck(ctx context.Context, transport Transp return response, nil } } - if question.Qtype == dns.TypeA && strategy == DomainStrategyUseIPv6 || question.Qtype == dns.TypeAAAA && strategy == DomainStrategyUseIPv4 { + if question.Qtype == dns.TypeA && options.Strategy == DomainStrategyUseIPv6 || question.Qtype == dns.TypeAAAA && options.Strategy == DomainStrategyUseIPv4 { responseMessage := dns.Msg{ MsgHdr: dns.MsgHdr{ Id: message.Id, @@ -142,7 +147,7 @@ func (c *Client) ExchangeWithResponseCheck(ctx context.Context, transport Transp } if !transport.Raw() { if question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA { - return c.exchangeToLookup(ctx, transport, message, question) + return c.exchangeToLookup(ctx, transport, message, question, options) } return nil, ErrNoRawSupport } @@ -171,7 +176,7 @@ func (c *Client) ExchangeWithResponseCheck(ctx context.Context, transport Transp return response, ErrResponseRejected } if question.Qtype == dns.TypeHTTPS { - if strategy == DomainStrategyUseIPv4 || strategy == DomainStrategyUseIPv6 { + if options.Strategy == DomainStrategyUseIPv4 || options.Strategy == DomainStrategyUseIPv6 { for _, rr := range response.Answer { https, isHTTPS := rr.(*dns.HTTPS) if !isHTTPS { @@ -179,7 +184,7 @@ func (c *Client) ExchangeWithResponseCheck(ctx context.Context, transport Transp } content := https.SVCB content.Value = common.Filter(content.Value, func(it dns.SVCBKeyValue) bool { - if strategy == DomainStrategyUseIPv4 { + if options.Strategy == DomainStrategyUseIPv4 { return it.Key() != dns.SVCB_IPV6HINT } else { return it.Key() != dns.SVCB_IPV4HINT @@ -197,8 +202,8 @@ func (c *Client) ExchangeWithResponseCheck(ctx context.Context, transport Transp } } } - if rewriteTTL, loaded := RewriteTTLFromContext(ctx); loaded { - timeToLive = int(rewriteTTL) + if options.RewriteTTL != nil { + timeToLive = int(*options.RewriteTTL) } for _, recordList := range [][]dns.RR{response.Answer, response.Ns, response.Extra} { for _, record := range recordList { @@ -213,26 +218,26 @@ func (c *Client) ExchangeWithResponseCheck(ctx context.Context, transport Transp return response, err } -func (c *Client) Lookup(ctx context.Context, transport Transport, domain string, strategy DomainStrategy) ([]netip.Addr, error) { - return c.LookupWithResponseCheck(ctx, transport, domain, strategy, nil) +func (c *Client) Lookup(ctx context.Context, transport Transport, domain string, options QueryOptions) ([]netip.Addr, error) { + return c.LookupWithResponseCheck(ctx, transport, domain, options, nil) } -func (c *Client) LookupWithResponseCheck(ctx context.Context, transport Transport, domain string, strategy DomainStrategy, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error) { +func (c *Client) LookupWithResponseCheck(ctx context.Context, transport Transport, domain string, options QueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error) { if dns.IsFqdn(domain) { domain = domain[:len(domain)-1] } dnsName := dns.Fqdn(domain) if transport.Raw() { - if strategy == DomainStrategyUseIPv4 { - return c.lookupToExchange(ctx, transport, dnsName, dns.TypeA, strategy, responseChecker) - } else if strategy == DomainStrategyUseIPv6 { - return c.lookupToExchange(ctx, transport, dnsName, dns.TypeAAAA, strategy, responseChecker) + if options.Strategy == DomainStrategyUseIPv4 { + return c.lookupToExchange(ctx, transport, dnsName, dns.TypeA, options, responseChecker) + } else if options.Strategy == DomainStrategyUseIPv6 { + return c.lookupToExchange(ctx, transport, dnsName, dns.TypeAAAA, options, responseChecker) } var response4 []netip.Addr var response6 []netip.Addr var group task.Group group.Append("exchange4", func(ctx context.Context) error { - response, err := c.lookupToExchange(ctx, transport, dnsName, dns.TypeA, strategy, responseChecker) + response, err := c.lookupToExchange(ctx, transport, dnsName, dns.TypeA, options, responseChecker) if err != nil { return err } @@ -240,7 +245,7 @@ func (c *Client) LookupWithResponseCheck(ctx context.Context, transport Transpor return nil }) group.Append("exchange6", func(ctx context.Context) error { - response, err := c.lookupToExchange(ctx, transport, dnsName, dns.TypeAAAA, strategy, responseChecker) + response, err := c.lookupToExchange(ctx, transport, dnsName, dns.TypeAAAA, options, responseChecker) if err != nil { return err } @@ -251,11 +256,11 @@ func (c *Client) LookupWithResponseCheck(ctx context.Context, transport Transpor if len(response4) == 0 && len(response6) == 0 { return nil, err } - return sortAddresses(response4, response6, strategy), nil + return sortAddresses(response4, response6, options.Strategy), nil } - disableCache := c.disableCache || DisableCacheFromContext(ctx) + disableCache := c.disableCache || options.DisableCache if !disableCache { - if strategy == DomainStrategyUseIPv4 { + if options.Strategy == DomainStrategyUseIPv4 { response, err := c.questionCache(dns.Question{ Name: dnsName, Qtype: dns.TypeA, @@ -264,7 +269,7 @@ func (c *Client) LookupWithResponseCheck(ctx context.Context, transport Transpor if err != ErrNotCached { return response, err } - } else if strategy == DomainStrategyUseIPv6 { + } else if options.Strategy == DomainStrategyUseIPv6 { response, err := c.questionCache(dns.Question{ Name: dnsName, Qtype: dns.TypeAAAA, @@ -285,16 +290,16 @@ func (c *Client) LookupWithResponseCheck(ctx context.Context, transport Transpor Qclass: dns.ClassINET, }, transport) if len(response4) > 0 || len(response6) > 0 { - return sortAddresses(response4, response6, strategy), nil + return sortAddresses(response4, response6, options.Strategy), nil } } } if responseChecker != nil && c.rdrc != nil { var rejected bool - if strategy != DomainStrategyUseIPv6 { + if options.Strategy != DomainStrategyUseIPv6 { rejected = c.rdrc.LoadRDRC(transport.Name(), dnsName, dns.TypeA) } - if !rejected && strategy != DomainStrategyUseIPv4 { + if !rejected && options.Strategy != DomainStrategyUseIPv4 { rejected = c.rdrc.LoadRDRC(transport.Name(), dnsName, dns.TypeAAAA) } if rejected { @@ -303,7 +308,7 @@ func (c *Client) LookupWithResponseCheck(ctx context.Context, transport Transpor } ctx, cancel := context.WithTimeout(ctx, c.timeout) var rCode int - response, err := transport.Lookup(ctx, domain, strategy) + response, err := transport.Lookup(ctx, domain, options.Strategy) cancel() if err != nil { return nil, wrapError(err) @@ -329,12 +334,12 @@ func (c *Client) LookupWithResponseCheck(ctx context.Context, transport Transpor } if !disableCache { var timeToLive uint32 - if rewriteTTL, loaded := RewriteTTLFromContext(ctx); loaded { - timeToLive = rewriteTTL + if options.RewriteTTL != nil { + timeToLive = *options.RewriteTTL } else { timeToLive = DefaultTTL } - if strategy != DomainStrategyUseIPv6 { + if options.Strategy != DomainStrategyUseIPv6 { question4 := dns.Question{ Name: dnsName, Qtype: dns.TypeA, @@ -362,7 +367,7 @@ func (c *Client) LookupWithResponseCheck(ctx context.Context, transport Transpor } c.storeCache(transport, question4, message4, int(timeToLive)) } - if strategy != DomainStrategyUseIPv4 { + if options.Strategy != DomainStrategyUseIPv4 { question6 := dns.Question{ Name: dnsName, Qtype: dns.TypeAAAA, @@ -396,19 +401,15 @@ func (c *Client) LookupWithResponseCheck(ctx context.Context, transport Transpor func (c *Client) ClearCache() { if c.cache != nil { - c.cache.Clear() + c.cache.Purge() } if c.transportCache != nil { - c.transportCache.Clear() + c.transportCache.Purge() } } func (c *Client) LookupCache(ctx context.Context, domain string, strategy DomainStrategy) ([]netip.Addr, bool) { - if c.independentCache { - return nil, false - } - disableCache := c.disableCache || DisableCacheFromContext(ctx) - if disableCache { + if c.disableCache || c.independentCache { return nil, false } if dns.IsFqdn(domain) { @@ -452,19 +453,10 @@ func (c *Client) LookupCache(ctx context.Context, domain string, strategy Domain } func (c *Client) ExchangeCache(ctx context.Context, message *dns.Msg) (*dns.Msg, bool) { - if c.independentCache || len(message.Question) != 1 { + if c.disableCache || c.independentCache || len(message.Question) != 1 { return nil, false } question := message.Question[0] - _, clientSubnetLoaded := transportNameFromContext(ctx) - isSimpleRequest := len(message.Question) == 1 && - len(message.Ns) == 0 && - len(message.Extra) == 0 && - !clientSubnetLoaded - disableCache := !isSimpleRequest || c.disableCache || DisableCacheFromContext(ctx) - if disableCache { - return nil, false - } response, ttl := c.loadResponse(question, nil) if response == nil { return nil, false @@ -488,54 +480,52 @@ func (c *Client) storeCache(transport Transport, question dns.Question, message } if c.disableExpire { if !c.independentCache { - c.cache.Store(question, message) + c.cache.Add(question, message) } else { - c.transportCache.Store(transportCacheKey{ + c.transportCache.Add(transportCacheKey{ Question: question, transportName: transport.Name(), }, message) } return } - expireAt := time.Now().Add(time.Second * time.Duration(timeToLive)) if !c.independentCache { - c.cache.StoreWithExpire(question, message, expireAt) + c.cache.AddWithLifetime(question, message, time.Second*time.Duration(timeToLive)) } else { - c.transportCache.StoreWithExpire(transportCacheKey{ + c.transportCache.AddWithLifetime(transportCacheKey{ Question: question, transportName: transport.Name(), - }, message, expireAt) + }, message, time.Second*time.Duration(timeToLive)) } } -func (c *Client) exchangeToLookup(ctx context.Context, transport Transport, message *dns.Msg, question dns.Question) (*dns.Msg, error) { +func (c *Client) exchangeToLookup(ctx context.Context, transport Transport, message *dns.Msg, question dns.Question, options QueryOptions) (*dns.Msg, error) { domain := question.Name - var strategy DomainStrategy if question.Qtype == dns.TypeA { - strategy = DomainStrategyUseIPv4 + options.Strategy = DomainStrategyUseIPv4 } else { - strategy = DomainStrategyUseIPv6 + options.Strategy = DomainStrategyUseIPv6 } - result, err := c.Lookup(ctx, transport, domain, strategy) + result, err := c.Lookup(ctx, transport, domain, options) if err != nil { return nil, wrapError(err) } var timeToLive uint32 - if rewriteTTL, loaded := RewriteTTLFromContext(ctx); loaded { - timeToLive = rewriteTTL + if options.RewriteTTL != nil { + timeToLive = *options.RewriteTTL } else { timeToLive = DefaultTTL } return FixedResponse(message.Id, question, result, timeToLive), nil } -func (c *Client) lookupToExchange(ctx context.Context, transport Transport, name string, qType uint16, strategy DomainStrategy, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error) { +func (c *Client) lookupToExchange(ctx context.Context, transport Transport, name string, qType uint16, options QueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error) { question := dns.Question{ Name: name, Qtype: qType, Qclass: dns.ClassINET, } - disableCache := c.disableCache || DisableCacheFromContext(ctx) + disableCache := c.disableCache || options.DisableCache if !disableCache { cachedAddresses, err := c.questionCache(question, transport) if err != ErrNotCached { @@ -553,7 +543,7 @@ func (c *Client) lookupToExchange(ctx context.Context, transport Transport, name err error ) if responseChecker != nil { - response, err = c.ExchangeWithResponseCheck(ctx, transport, &message, strategy, func(response *dns.Msg) bool { + response, err = c.ExchangeWithResponseCheck(ctx, transport, &message, options, func(response *dns.Msg) bool { addresses, addrErr := MessageToAddresses(response) if addrErr != nil { return false @@ -561,7 +551,7 @@ func (c *Client) lookupToExchange(ctx context.Context, transport Transport, name return responseChecker(addresses) }) } else { - response, err = c.Exchange(ctx, transport, &message, strategy) + response, err = c.Exchange(ctx, transport, &message, options) } if err != nil { return nil, err @@ -584,9 +574,9 @@ func (c *Client) loadResponse(question dns.Question, transport Transport) (*dns. ) if c.disableExpire { if !c.independentCache { - response, loaded = c.cache.Load(question) + response, loaded = c.cache.Get(question) } else { - response, loaded = c.transportCache.Load(transportCacheKey{ + response, loaded = c.transportCache.Get(transportCacheKey{ Question: question, transportName: transport.Name(), }) @@ -598,9 +588,9 @@ func (c *Client) loadResponse(question dns.Question, transport Transport) (*dns. } else { var expireAt time.Time if !c.independentCache { - response, expireAt, loaded = c.cache.LoadWithExpire(question) + response, expireAt, loaded = c.cache.GetWithLifetime(question) } else { - response, expireAt, loaded = c.transportCache.LoadWithExpire(transportCacheKey{ + response, expireAt, loaded = c.transportCache.GetWithLifetime(transportCacheKey{ Question: question, transportName: transport.Name(), }) @@ -611,9 +601,9 @@ func (c *Client) loadResponse(question dns.Question, transport Transport) (*dns. timeNow := time.Now() if timeNow.After(expireAt) { if !c.independentCache { - c.cache.Delete(question) + c.cache.Remove(question) } else { - c.transportCache.Delete(transportCacheKey{ + c.transportCache.Remove(transportCacheKey{ Question: question, transportName: transport.Name(), }) @@ -684,3 +674,14 @@ func wrapError(err error) error { } return err } + +type transportKey struct{} + +func contextWithTransportName(ctx context.Context, transportName string) context.Context { + return context.WithValue(ctx, transportKey{}, transportName) +} + +func transportNameFromContext(ctx context.Context) (string, bool) { + value, loaded := ctx.Value(transportKey{}).(string) + return value, loaded +} diff --git a/client_options.go b/client_options.go new file mode 100644 index 0000000..1eccb26 --- /dev/null +++ b/client_options.go @@ -0,0 +1,10 @@ +package dns + +import "net/netip" + +type QueryOptions struct { + Strategy DomainStrategy + DisableCache bool + RewriteTTL *uint32 + ClientSubnet netip.Prefix +} diff --git a/dialer.go b/dialer.go index 23a89ab..65ac5fe 100644 --- a/dialer.go +++ b/dialer.go @@ -25,7 +25,9 @@ func (d *DialerWrapper) DialContext(ctx context.Context, network string, destina if destination.IsIP() { return d.dialer.DialContext(ctx, network, destination) } - addresses, err := d.client.Lookup(ctx, d.transport, destination.Fqdn, d.strategy) + addresses, err := d.client.Lookup(ctx, d.transport, destination.Fqdn, QueryOptions{ + Strategy: d.strategy, + }) if err != nil { return nil, err } @@ -36,7 +38,9 @@ func (d *DialerWrapper) ListenPacket(ctx context.Context, destination M.Socksadd if destination.IsIP() { return d.dialer.ListenPacket(ctx, destination) } - addresses, err := d.client.Lookup(ctx, d.transport, destination.Fqdn, d.strategy) + addresses, err := d.client.Lookup(ctx, d.transport, destination.Fqdn, QueryOptions{ + Strategy: d.strategy, + }) if err != nil { return nil, err } diff --git a/extensions.go b/extensions.go deleted file mode 100644 index 86bbbd5..0000000 --- a/extensions.go +++ /dev/null @@ -1,56 +0,0 @@ -package dns - -import ( - "context" - "net/netip" -) - -type disableCacheKey struct{} - -func ContextWithDisableCache(ctx context.Context, val bool) context.Context { - return context.WithValue(ctx, (*disableCacheKey)(nil), val) -} - -func DisableCacheFromContext(ctx context.Context) bool { - val := ctx.Value((*disableCacheKey)(nil)) - if val == nil { - return false - } - return val.(bool) -} - -type rewriteTTLKey struct{} - -func ContextWithRewriteTTL(ctx context.Context, val uint32) context.Context { - return context.WithValue(ctx, (*rewriteTTLKey)(nil), val) -} - -func RewriteTTLFromContext(ctx context.Context) (uint32, bool) { - val := ctx.Value((*rewriteTTLKey)(nil)) - if val == nil { - return 0, false - } - return val.(uint32), true -} - -type transportKey struct{} - -func contextWithTransportName(ctx context.Context, transportName string) context.Context { - return context.WithValue(ctx, transportKey{}, transportName) -} - -func transportNameFromContext(ctx context.Context) (string, bool) { - value, loaded := ctx.Value(transportKey{}).(string) - return value, loaded -} - -type clientSubnetKey struct{} - -func ContextWithClientSubnet(ctx context.Context, clientSubnet netip.Prefix) context.Context { - return context.WithValue(ctx, clientSubnetKey{}, clientSubnet) -} - -func ClientSubnetFromContext(ctx context.Context) (netip.Prefix, bool) { - clientSubnet, ok := ctx.Value(clientSubnetKey{}).(netip.Prefix) - return clientSubnet, ok -} diff --git a/go.mod b/go.mod index 2daf4ea..4b4537e 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.20 require ( github.com/miekg/dns v1.1.62 github.com/sagernet/quic-go v0.47.0-beta.2 - github.com/sagernet/sing v0.5.0 + github.com/sagernet/sing v0.6.0-alpha.10 github.com/stretchr/testify v1.9.0 ) diff --git a/go.sum b/go.sum index 41ca92a..2664cc3 100644 --- a/go.sum +++ b/go.sum @@ -21,8 +21,8 @@ github.com/quic-go/qtls-go1-20 v0.4.1 h1:D33340mCNDAIKBqXuAvexTNMUByrYmFYVfKfDN5 github.com/quic-go/qtls-go1-20 v0.4.1/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= github.com/sagernet/quic-go v0.47.0-beta.2 h1:1tCGWFOSaXIeuQaHrwOMJIYvlupjTcaVInGQw5ArULU= github.com/sagernet/quic-go v0.47.0-beta.2/go.mod h1:bLVKvElSEMNv7pu7SZHscW02TYigzQ5lQu3Nh4wNh8Q= -github.com/sagernet/sing v0.5.0 h1:soo2wVwLcieKWWKIksFNK6CCAojUgAppqQVwyRYGkEM= -github.com/sagernet/sing v0.5.0/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= +github.com/sagernet/sing v0.6.0-alpha.10 h1:gIUiFof6SDDcAg3m3pUOshOPZBLC7z9VCgDt4Tzs24g= +github.com/sagernet/sing v0.6.0-alpha.10/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= diff --git a/quic/transport_quic.go b/quic/transport_quic.go index b5b765f..8d86afb 100644 --- a/quic/transport_quic.go +++ b/quic/transport_quic.go @@ -159,6 +159,7 @@ func (t *Transport) exchange(ctx context.Context, message *mDNS.Msg, conn quic.C if err != nil { return nil, err } + _ = stream.Close() buffer.Reset() _, err = buffer.ReadFullFrom(stream, 2) if err != nil { diff --git a/transport_test.go b/transport_test.go index bdaf0a3..0ac29ed 100644 --- a/transport_test.go +++ b/transport_test.go @@ -37,7 +37,9 @@ func TestTransports(t *testing.T) { client := dns.NewClient(dns.ClientOptions{ Logger: logger.NOP(), }) - addresses, err := client.Lookup(context.Background(), transport, "cloudflare.com", dns.DomainStrategyAsIS) + addresses, err := client.Lookup(context.Background(), transport, "cloudflare.com", dns.QueryOptions{ + Strategy: dns.DomainStrategyUseIPv4, + }) require.NoError(t, err) require.NotEmpty(t, addresses) })