Skip to content

Commit 06bd3aa

Browse files
committed
Implement resolve(server)
1 parent 5944fc3 commit 06bd3aa

File tree

3 files changed

+62
-41
lines changed

3 files changed

+62
-41
lines changed

adapter/inbound.go

+8-6
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,14 @@ type InboundContext struct {
5050
// Deprecated
5151
InboundOptions option.InboundOptions
5252
UDPDisableDomainUnmapping bool
53-
DestinationAddresses []netip.Addr
54-
SourceGeoIPCode string
55-
GeoIPCode string
56-
ProcessInfo *process.Info
57-
QueryType uint16
58-
FakeIP bool
53+
DNSServer string
54+
55+
DestinationAddresses []netip.Addr
56+
SourceGeoIPCode string
57+
GeoIPCode string
58+
ProcessInfo *process.Info
59+
QueryType uint16
60+
FakeIP bool
5961

6062
// rule cache
6163

route/route.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,7 @@ func (r *Router) actionSniff(
584584

585585
func (r *Router) actionResolve(ctx context.Context, metadata *adapter.InboundContext, action *rule.RuleActionResolve) error {
586586
if metadata.Destination.IsFqdn() {
587-
// TODO: check if WithContext is necessary
587+
metadata.DNSServer = action.Server
588588
addresses, err := r.Lookup(adapter.WithContext(ctx, metadata), metadata.Destination.Fqdn, action.Strategy)
589589
if err != nil {
590590
return err

route/route_dns.go

+53-34
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,20 @@ func (r *Router) Lookup(ctx context.Context, domain string, strategy dns.DomainS
185185
cached bool
186186
err error
187187
)
188+
printResult := func() {
189+
if err != nil {
190+
if errors.Is(err, dns.ErrResponseRejectedCached) {
191+
r.dnsLogger.DebugContext(ctx, "response rejected for ", domain, " (cached)")
192+
} else if errors.Is(err, dns.ErrResponseRejected) {
193+
r.dnsLogger.DebugContext(ctx, "response rejected for ", domain)
194+
} else {
195+
r.dnsLogger.ErrorContext(ctx, E.Cause(err, "lookup failed for ", domain))
196+
}
197+
} else if len(responseAddrs) == 0 {
198+
r.dnsLogger.ErrorContext(ctx, "lookup failed for ", domain, ": empty result")
199+
err = dns.RCodeNameError
200+
}
201+
}
188202
responseAddrs, cached = r.dnsClient.LookupCache(ctx, domain, strategy)
189203
if cached {
190204
if len(responseAddrs) == 0 {
@@ -196,46 +210,51 @@ func (r *Router) Lookup(ctx context.Context, domain string, strategy dns.DomainS
196210
ctx, metadata := adapter.ExtendContext(ctx)
197211
metadata.Destination = M.Socksaddr{}
198212
metadata.Domain = domain
199-
var (
200-
transport dns.Transport
201-
options dns.QueryOptions
202-
rule adapter.DNSRule
203-
ruleIndex int
204-
)
205-
ruleIndex = -1
206-
for {
207-
dnsCtx := adapter.OverrideContext(ctx)
208-
var addressLimit bool
209-
transport, options, rule, ruleIndex = r.matchDNS(ctx, false, ruleIndex, true)
210-
if strategy != dns.DomainStrategyAsIS {
211-
options.Strategy = strategy
212-
}
213-
if rule != nil && rule.WithAddressLimit() {
214-
addressLimit = true
215-
responseAddrs, err = r.dnsClient.LookupWithResponseCheck(dnsCtx, transport, domain, options, func(responseAddrs []netip.Addr) bool {
216-
metadata.DestinationAddresses = responseAddrs
217-
return rule.MatchAddressLimit(metadata)
218-
})
219-
} else {
220-
addressLimit = false
221-
responseAddrs, err = r.dnsClient.Lookup(dnsCtx, transport, domain, options)
213+
if metadata.DNSServer != "" {
214+
transport, loaded := r.transportMap[metadata.DNSServer]
215+
if !loaded {
216+
return nil, E.New("transport not found: ", metadata.DNSServer)
222217
}
223-
if err != nil {
224-
if errors.Is(err, dns.ErrResponseRejectedCached) {
225-
r.dnsLogger.DebugContext(ctx, "response rejected for ", domain, " (cached)")
226-
} else if errors.Is(err, dns.ErrResponseRejected) {
227-
r.dnsLogger.DebugContext(ctx, "response rejected for ", domain)
218+
if strategy == dns.DomainStrategyAsIS {
219+
if transportDomainStrategy, loaded := r.transportDomainStrategy[transport]; loaded {
220+
strategy = transportDomainStrategy
228221
} else {
229-
r.dnsLogger.ErrorContext(ctx, E.Cause(err, "lookup failed for ", domain))
222+
strategy = r.defaultDomainStrategy
230223
}
231-
} else if len(responseAddrs) == 0 {
232-
r.dnsLogger.ErrorContext(ctx, "lookup failed for ", domain, ": empty result")
233-
err = dns.RCodeNameError
234224
}
235-
if !addressLimit || err == nil {
236-
break
225+
responseAddrs, err = r.dnsClient.Lookup(ctx, transport, domain, dns.QueryOptions{Strategy: strategy})
226+
} else {
227+
var (
228+
transport dns.Transport
229+
options dns.QueryOptions
230+
rule adapter.DNSRule
231+
ruleIndex int
232+
)
233+
ruleIndex = -1
234+
for {
235+
dnsCtx := adapter.OverrideContext(ctx)
236+
var addressLimit bool
237+
transport, options, rule, ruleIndex = r.matchDNS(ctx, false, ruleIndex, true)
238+
if strategy != dns.DomainStrategyAsIS {
239+
options.Strategy = strategy
240+
}
241+
if rule != nil && rule.WithAddressLimit() {
242+
addressLimit = true
243+
responseAddrs, err = r.dnsClient.LookupWithResponseCheck(dnsCtx, transport, domain, options, func(responseAddrs []netip.Addr) bool {
244+
metadata.DestinationAddresses = responseAddrs
245+
return rule.MatchAddressLimit(metadata)
246+
})
247+
} else {
248+
addressLimit = false
249+
responseAddrs, err = r.dnsClient.Lookup(dnsCtx, transport, domain, options)
250+
}
251+
if !addressLimit || err == nil {
252+
break
253+
}
254+
printResult()
237255
}
238256
}
257+
printResult()
239258
if len(responseAddrs) > 0 {
240259
r.dnsLogger.InfoContext(ctx, "lookup succeed for ", domain, ": ", strings.Join(F.MapToString(responseAddrs), " "))
241260
}

0 commit comments

Comments
 (0)