diff --git a/quic/transport_http3.go b/quic/transport_http3.go index c7dd9f8..b90693e 100644 --- a/quic/transport_http3.go +++ b/quic/transport_http3.go @@ -80,8 +80,9 @@ func (t *HTTP3Transport) Raw() bool { } func (t *HTTP3Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { - message.Id = 0 - rawMessage, err := message.Pack() + exMessage := *message + exMessage.Id = 0 + rawMessage, err := exMessage.Pack() if err != nil { return nil, err } diff --git a/quic/transport_quic.go b/quic/transport_quic.go index fb35058..eb701fd 100644 --- a/quic/transport_quic.go +++ b/quic/transport_quic.go @@ -121,7 +121,7 @@ func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, ) for i := 0; i < 2; i++ { conn, err = t.openConnection() - if conn == nil { + if err != nil { return nil, err } response, err = t.exchange(ctx, message, conn) @@ -138,8 +138,9 @@ func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, } func (t *Transport) exchange(ctx context.Context, message *mDNS.Msg, conn quic.Connection) (*mDNS.Msg, error) { - message.Id = 0 - rawMessage, err := message.Pack() + exMessage := *message + exMessage.Id = 0 + rawMessage, err := exMessage.Pack() if err != nil { return nil, err } diff --git a/transport_base.go b/transport_base.go index ae51842..8586cb6 100644 --- a/transport_base.go +++ b/transport_base.go @@ -114,35 +114,39 @@ func (t *myTransportAdapter) recvLoop(conn *dnsConnection) { func (t *myTransportAdapter) Exchange(ctx context.Context, message *dns.Msg) (*dns.Msg, error) { messageId := message.Id - var response *dns.Msg - var err error - for attempts := 0; attempts < 3; attempts++ { - response, err = t.exchange(ctx, message) - if err != nil && !common.Done(ctx) { - continue + var ( + conn *dnsConnection + err error + ) + for attempts := 0; attempts < 2; attempts++ { + conn, err = t.open(t.ctx) + if err == nil { + break } - break } - if err == nil { - response.Id = messageId + if err != nil { + return nil, err } - return response, err -} - -func (t *myTransportAdapter) exchange(ctx context.Context, message *dns.Msg) (*dns.Msg, error) { - conn, err := t.open(t.ctx) + response, err := t.exchange(ctx, conn, message) if err != nil { return nil, err } + response.Id = messageId + return response, nil +} + +func (t *myTransportAdapter) exchange(ctx context.Context, conn *dnsConnection, message *dns.Msg) (*dns.Msg, error) { + messageId := message.Id callback := make(chan *dns.Msg) + exMessage := *message conn.access.Lock() conn.queryId++ - message.Id = conn.queryId - conn.callbacks[message.Id] = callback + exMessage.Id = conn.queryId + conn.callbacks[exMessage.Id] = callback conn.access.Unlock() - defer t.cleanup(conn, message.Id, callback) + defer t.cleanup(conn, exMessage.Id, callback) conn.writeAccess.Lock() - err = t.handler.WriteMessage(conn, message) + err := t.handler.WriteMessage(conn, &exMessage) conn.writeAccess.Unlock() if err != nil { conn.cancel() @@ -150,6 +154,7 @@ func (t *myTransportAdapter) exchange(ctx context.Context, message *dns.Msg) (*d } select { case response := <-callback: + response.Id = messageId return response, nil case <-conn.ctx.Done(): return nil, E.Errors(conn.err, conn.ctx.Err()) diff --git a/transport_https.go b/transport_https.go index fd332f9..885a053 100644 --- a/transport_https.go +++ b/transport_https.go @@ -69,8 +69,9 @@ func (t *HTTPSTransport) Raw() bool { } func (t *HTTPSTransport) Exchange(ctx context.Context, message *dns.Msg) (*dns.Msg, error) { - message.Id = 0 - rawMessage, err := message.Pack() + exMessage := *message + exMessage.Id = 0 + rawMessage, err := exMessage.Pack() if err != nil { return nil, err }