Skip to content

Commit

Permalink
Add rejected DNS response cache support
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Feb 14, 2024
1 parent 8dc88e9 commit 9a328d8
Showing 1 changed file with 87 additions and 12 deletions.
99 changes: 87 additions & 12 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,29 @@ import (
const DefaultTTL = 600

var (
ErrNoRawSupport = E.New("no raw query support by current transport")
ErrNotCached = E.New("not cached")
ErrResponseRejected = E.New("response rejected")
ErrNoRawSupport = E.New("no raw query support by current transport")
ErrNotCached = E.New("not cached")
ErrResponseRejected = E.New("response rejected")
ErrResponseRejectedCached = E.Extend(ErrResponseRejected, "cached")
)

type Client struct {
disableCache bool
disableExpire bool
independentCache bool
rdrc RDRCStore
initRDRCFunc func() RDRCStore
logger logger.ContextLogger
cache *cache.LruCache[dns.Question, *dns.Msg]
transportCache *cache.LruCache[transportCacheKey, *dns.Msg]
}

type RDRCStore interface {
LoadRDRC(transportName string, qName string) (rejected bool)
SaveRDRC(transportName string, qName string) error
SaveRDRCAsync(transportName string, qName string, logger logger.Logger)
}

type transportCacheKey struct {
dns.Question
transportName string
Expand All @@ -42,6 +51,7 @@ type ClientOptions struct {
DisableCache bool
DisableExpire bool
IndependentCache bool
RDRC func() RDRCStore
Logger logger.ContextLogger
}

Expand All @@ -50,6 +60,7 @@ func NewClient(options ClientOptions) *Client {
disableCache: options.DisableCache,
disableExpire: options.DisableExpire,
independentCache: options.IndependentCache,
initRDRCFunc: options.RDRC,
logger: options.Logger,
}
if !client.disableCache {
Expand All @@ -62,6 +73,12 @@ func NewClient(options ClientOptions) *Client {
return client
}

func (c *Client) Start() {
if c.initRDRCFunc != nil {
c.rdrc = c.initRDRCFunc()
}
}

func (c *Client) Exchange(ctx context.Context, transport Transport, message *dns.Msg, strategy DomainStrategy) (*dns.Msg, error) {
return c.ExchangeWithResponseCheck(ctx, transport, message, strategy, nil)
}
Expand Down Expand Up @@ -121,11 +138,20 @@ func (c *Client) ExchangeWithResponseCheck(ctx context.Context, transport Transp
if loaded {
SetClientSubnet(message, clientSubnet, true)
}
if responseChecker != nil && c.rdrc != nil {
rejected := c.rdrc.LoadRDRC(transport.Name(), question.Name)
if rejected {
return nil, ErrResponseRejectedCached
}
}
response, err := transport.Exchange(ctx, message)
if err != nil {
return nil, err
}
if responseChecker != nil && !responseChecker(response) {
if c.rdrc != nil {
c.rdrc.SaveRDRCAsync(transport.Name(), question.Name, c.logger)
}
return response, ErrResponseRejected
}
var timeToLive int
Expand Down Expand Up @@ -154,14 +180,6 @@ func (c *Client) ExchangeWithResponseCheck(ctx context.Context, transport Transp
return response, err
}

func (c *Client) ExchangeCache(ctx context.Context, message *dns.Msg) (*dns.Msg, bool) {
response, cached := c.exchangeCache(ctx, message)
if !cached {
return nil, false
}
return response, true
}

func (c *Client) Lookup(ctx context.Context, transport Transport, domain string, strategy DomainStrategy) ([]netip.Addr, error) {
return c.LookupWithResponseCheck(ctx, transport, domain, strategy, nil)
}
Expand Down Expand Up @@ -238,12 +256,21 @@ func (c *Client) LookupWithResponseCheck(ctx context.Context, transport Transpor
}
}
}
if responseChecker != nil && c.rdrc != nil {
rejected := c.rdrc.LoadRDRC(transport.Name(), dnsName)
if rejected {
return nil, ErrResponseRejectedCached
}
}
var rCode int
response, err := transport.Lookup(ctx, domain, strategy)
if err != nil {
return nil, wrapError(err)
}
if responseChecker != nil && !responseChecker(response) {
if c.rdrc != nil {
c.rdrc.SaveRDRCAsync(transport.Name(), dnsName, c.logger)
}
return response, ErrResponseRejected
}
header := dns.MsgHdr{
Expand Down Expand Up @@ -326,7 +353,55 @@ func (c *Client) ClearCache() {
}
}

func (c *Client) exchangeCache(ctx context.Context, message *dns.Msg) (*dns.Msg, bool) {
func (c *Client) LookupCache(ctx context.Context, domain string, strategy DomainStrategy) ([]netip.Addr, bool) {
if c.independentCache {
return nil, false
}
disableCache := c.disableCache || DisableCacheFromContext(ctx)
if disableCache {
return nil, false
}
if dns.IsFqdn(domain) {
domain = domain[:len(domain)-1]
}
dnsName := dns.Fqdn(domain)
if strategy == DomainStrategyUseIPv4 {
response, err := c.questionCache(dns.Question{
Name: dnsName,
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}, nil)
if err != ErrNotCached {
return response, true
}
} else if strategy == DomainStrategyUseIPv6 {
response, err := c.questionCache(dns.Question{
Name: dnsName,
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
}, nil)
if err != ErrNotCached {
return response, true
}
} else {
response4, _ := c.questionCache(dns.Question{
Name: dnsName,
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}, nil)
response6, _ := c.questionCache(dns.Question{
Name: dnsName,
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
}, nil)
if len(response4) > 0 || len(response6) > 0 {
return sortAddresses(response4, response6, strategy), true
}
}
return nil, false
}

func (c *Client) ExchangeCache(ctx context.Context, message *dns.Msg) (*dns.Msg, bool) {
if c.independentCache || len(message.Question) != 1 {
return nil, false
}
Expand Down

0 comments on commit 9a328d8

Please sign in to comment.