From 17fe860f7f397de0f8c881ba0b1143d94239bdf6 Mon Sep 17 00:00:00 2001
From: Mike Kosek
Date: Sat, 17 Sep 2022 11:38:31 +0200
Subject: [PATCH 1/3] Put the connection right back in to allow the connection
 to be reused while requests are in flight

---
 upstream/upstream_dot.go | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/upstream/upstream_dot.go b/upstream/upstream_dot.go
index 380478d34..daea925c2 100644
--- a/upstream/upstream_dot.go
+++ b/upstream/upstream_dot.go
@@ -54,6 +54,8 @@ func (p *dnsOverTLS) Exchange(m *dns.Msg) (reply *dns.Msg, err error) {
 	p.RLock()
 	poolConn, err := p.pool.Get()
+	// Put the connection right back in to allow the connection to be reused while requests are in flight
+	p.pool.Put(poolConn)
 	p.RUnlock()
 	if err != nil {
 		return nil, fmt.Errorf("getting connection to %s: %w", p.Address(), err)
 	}
@@ -82,11 +84,6 @@ func (p *dnsOverTLS) Exchange(m *dns.Msg) (reply *dns.Msg, err error) {
 		logFinish(p.Address(), err)
 	}
 
-	if err == nil {
-		p.RLock()
-		p.pool.Put(poolConn)
-		p.RUnlock()
-	}
 
 	return reply, err
 }
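
In isolation, this first patch changes the ownership model of a pooled connection: it no longer belongs to a single request. Previously, Get() removed the connection from the pool and Put() returned it only after a successful exchange, so concurrent Exchange calls could never share a connection. Putting it back immediately after Get() lets the next caller pick up the same connection while this request is still in flight. A minimal standalone sketch of that pattern (the pool type below is a hypothetical stand-in, not the library's actual TLSPool):

    package main

    import (
    	"fmt"
    	"sync"
    )

    // pool is a hypothetical stand-in for TLSPool; it hands out strings
    // instead of net.Conns to keep the sketch self-contained.
    type pool struct {
    	mu    sync.Mutex
    	conns []string
    }

    func (p *pool) Get() string {
    	p.mu.Lock()
    	defer p.mu.Unlock()
    	if n := len(p.conns); n > 0 {
    		c := p.conns[n-1]
    		p.conns = p.conns[:n-1]
    		return c
    	}
    	return "fresh connection"
    }

    func (p *pool) Put(c string) {
    	p.mu.Lock()
    	defer p.mu.Unlock()
    	p.conns = append(p.conns, c)
    }

    func main() {
    	p := &pool{}
    	var wg sync.WaitGroup
    	for i := 0; i < 3; i++ {
    		wg.Add(1)
    		go func(id int) {
    			defer wg.Done()
    			c := p.Get()
    			// Put the connection right back, as the patch does: another
    			// caller may reuse it while this request is still in flight.
    			p.Put(c)
    			fmt.Printf("request %d sent on %q\n", id, c)
    		}(i)
    	}
    	wg.Wait()
    }

On its own this is not yet safe for DNS: once two requests share one connection, their responses can arrive interleaved, which is exactly what the next patch addresses.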
From 06c05563005b99f1560ec31fea41a844ad54b50a Mon Sep 17 00:00:00 2001
From: 42SK <42SK@users.noreply.github.com>
Date: Fri, 30 Sep 2022 14:07:48 +0200
Subject: [PATCH 2/3] upstream: DoT: Add proper handling of out-of-order
 responses

When processing multiple queries through a single DoT upstream, we might
get a DNS ID mismatch if responses are received out of order (cf. PR
#269). We fix this by storing all responses with unknown IDs in a map
from which they can be retrieved later on.
---
 upstream/upstream_dot.go       | 56 +++++++++++++++++++++++++---------
 upstream/upstream_pool.go      | 28 ++++++++++++-----
 upstream/upstream_pool_test.go | 22 ++++++-------
 3 files changed, 73 insertions(+), 33 deletions(-)

diff --git a/upstream/upstream_dot.go b/upstream/upstream_dot.go
index daea925c2..7b1ebd4a9 100644
--- a/upstream/upstream_dot.go
+++ b/upstream/upstream_dot.go
@@ -2,8 +2,8 @@ package upstream
 
 import (
 	"fmt"
-	"net"
 	"net/url"
+	"runtime"
 	"sync"
 
 	"github.com/AdguardTeam/golibs/errors"
@@ -53,16 +53,16 @@ func (p *dnsOverTLS) Exchange(m *dns.Msg) (reply *dns.Msg, err error) {
 	}
 
 	p.RLock()
-	poolConn, err := p.pool.Get()
+	poolConnAndStore, err := p.pool.Get()
 	// Put the connection right back in to allow the connection to be reused while requests are in flight
-	p.pool.Put(poolConn)
+	p.pool.Put(poolConnAndStore)
 	p.RUnlock()
 	if err != nil {
 		return nil, fmt.Errorf("getting connection to %s: %w", p.Address(), err)
 	}
 
 	logBegin(p.Address(), m)
-	reply, err = p.exchangeConn(poolConn, m)
+	reply, err = p.exchangeConn(poolConnAndStore, m)
 	logFinish(p.Address(), err)
 	if err != nil {
 		log.Tracef("The TLS connection is expired due to %s", err)
@@ -72,7 +72,7 @@ func (p *dnsOverTLS) Exchange(m *dns.Msg) (reply *dns.Msg, err error) {
 		// We are forcing creation of a new connection instead of calling Get() again
 		// as there's no guarantee that other pooled connections are intact
 		p.RLock()
-		poolConn, err = p.pool.Create()
+		poolConnAndStore, err = p.pool.Create()
 		p.RUnlock()
 		if err != nil {
 			return nil, fmt.Errorf("creating new connection to %s: %w", p.Address(), err)
@@ -80,36 +80,64 @@ func (p *dnsOverTLS) Exchange(m *dns.Msg) (reply *dns.Msg, err error) {
 
 		// Retry sending the DNS request
 		logBegin(p.Address(), m)
-		reply, err = p.exchangeConn(poolConn, m)
+		reply, err = p.exchangeConn(poolConnAndStore, m)
 		logFinish(p.Address(), err)
 	}
 
 
 	return reply, err
 }
 
-func (p *dnsOverTLS) exchangeConn(conn net.Conn, m *dns.Msg) (reply *dns.Msg, err error) {
+func (p *dnsOverTLS) exchangeConn(connAndStore *connAndStore, m *dns.Msg) (reply *dns.Msg, err error) {
 	defer func() {
 		if err == nil {
 			return
 		}
 
-		if cerr := conn.Close(); cerr != nil {
+		if cerr := connAndStore.conn.Close(); cerr != nil {
 			err = &errors.Pair{Returned: err, Deferred: cerr}
 		}
 	}()
 
-	dnsConn := dns.Conn{Conn: conn}
+	dnsConn := dns.Conn{Conn: connAndStore.conn}
 	err = dnsConn.WriteMsg(m)
 	if err != nil {
 		return nil, fmt.Errorf("sending request to %s: %w", p.Address(), err)
 	}
 
-	reply, err = dnsConn.ReadMsg()
-	if err != nil {
-		return nil, fmt.Errorf("reading response from %s: %w", p.Address(), err)
-	} else if reply.Id != m.Id {
-		err = dns.ErrId
+	// Since we might receive out-of-order responses when processing multiple queries through a single upstream (cf.
+	// PR #269), we will store all responses that don't match our DNS ID and retry until we find the response we are
+	// looking for (either by receiving it directly or by finding it in the stored responses).
+	responseFound := false
+	present := false
+	for !responseFound {
+		connAndStore.Lock()
+
+		// Has someone already received our response?
+		reply, present = connAndStore.store[m.Id]
+		if present { // matching response in store
+			log.Tracef("Found matching ID in store for request %d", m.Id)
+			delete(connAndStore.store, m.Id) // delete response from store
+			responseFound = true
+		} else { // no matching response in store
+			reply, err = dnsConn.ReadMsg()
+			if err != nil {
+				connAndStore.Unlock()
+				return nil, fmt.Errorf("reading response from %s: %w", p.Address(), err)
+			} else if reply.Id != m.Id {
+				// Not the response we were looking for -> store it for later use
+				log.Tracef("Received unknown ID %d, storing in store for later use", reply.Id)
+				connAndStore.store[reply.Id] = reply
+			} else {
+				responseFound = true
+			}
+		}
+		connAndStore.Unlock()
+
+		// Yield to the scheduler if we added something to the store
+		if !responseFound {
+			runtime.Gosched()
+		}
 	}
 
 	return reply, err
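
What the new loop in exchangeConn implements is, in effect, a per-connection mailbox: whichever goroutine happens to be reading the socket delivers replies addressed to other queries into a shared store, and every waiter checks that store before reading itself. The same logic condensed into a helper, as a sketch only (awaitReply is a hypothetical name; it assumes the connAndStore type added to upstream_pool.go below and the imports already present in upstream_dot.go):

    // awaitReply blocks until the reply carrying the given DNS ID arrives,
    // either off the wire or via the store filled by another goroutine.
    func awaitReply(c *connAndStore, dnsConn *dns.Conn, id uint16) (*dns.Msg, error) {
    	for {
    		c.Lock()
    		// Someone else may have read our reply off the wire already.
    		if reply, ok := c.store[id]; ok {
    			delete(c.store, id)
    			c.Unlock()
    			return reply, nil
    		}
    		reply, err := dnsConn.ReadMsg()
    		if err != nil {
    			c.Unlock()
    			return nil, err
    		}
    		if reply.Id == id {
    			c.Unlock()
    			return reply, nil
    		}
    		// A reply for some other in-flight query: park it in the store.
    		c.store[reply.Id] = reply
    		c.Unlock()
    		runtime.Gosched() // give the owner a chance to collect it
    	}
    }

Note that ReadMsg is called while the per-connection lock is held, so only one goroutine blocks on the socket at a time; the others cycle through runtime.Gosched() until their ID turns up in the store.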
diff --git a/upstream/upstream_pool.go b/upstream/upstream_pool.go
index 9c3ddd3bc..5fc406a63 100644
--- a/upstream/upstream_pool.go
+++ b/upstream/upstream_pool.go
@@ -9,6 +9,7 @@ import (
 	"time"
 
 	"github.com/AdguardTeam/golibs/log"
+	"github.com/miekg/dns"
 )
 
 // dialTimeout is the global timeout for establishing a TLS connection.
@@ -36,16 +37,24 @@ type TLSPool struct {
 	boot *bootstrapper
 
 	// conns is the list of connections available in the pool.
-	conns []net.Conn
+	conns []*connAndStore
 	// connsMutex protects conns.
 	connsMutex sync.Mutex
 }
 
+// connAndStore is a struct that assigns a store for out-of-order responses to each connection.
+// We need this to process multiple queries through a single upstream (cf. PR #269).
+type connAndStore struct {
+	conn net.Conn
+	store map[uint16]*dns.Msg // needed to save out-of-order responses when reusing the connection
+	sync.Mutex // protects store
+}
+
 // Get gets a connection from the pool (if there's one available) or creates
 // a new TLS connection.
-func (n *TLSPool) Get() (net.Conn, error) {
+func (n *TLSPool) Get() (*connAndStore, error) {
 	// Get the connection from the slice inside the lock.
-	var c net.Conn
+	var c *connAndStore
 	n.connsMutex.Lock()
 	num := len(n.conns)
 	if num > 0 {
@@ -57,11 +66,11 @@ func (n *TLSPool) Get() (net.Conn, error) {
 
 	// If we got connection from the slice, update deadline and return it.
 	if c != nil {
-		err := c.SetDeadline(time.Now().Add(dialTimeout))
+		err := c.conn.SetDeadline(time.Now().Add(dialTimeout))
 
 		// If deadLine can't be updated it means that connection was already closed
 		if err == nil {
-			log.Tracef("Returning existing connection to %s with updated deadLine", c.RemoteAddr())
+			log.Tracef("Returning existing connection to %s with updated deadLine", c.conn.RemoteAddr())
 			return c, nil
 		}
 	}
@@ -70,7 +79,7 @@ func (n *TLSPool) Get() (net.Conn, error) {
 }
 
 // Create creates a new connection for the pool (but does not put it there).
-func (n *TLSPool) Create() (net.Conn, error) {
+func (n *TLSPool) Create() (*connAndStore, error) {
 	tlsConfig, dialContext, err := n.boot.get()
 	if err != nil {
 		return nil, err
@@ -82,11 +91,14 @@ func (n *TLSPool) Create() (net.Conn, error) {
 		return nil, fmt.Errorf("connecting to %s: %w", tlsConfig.ServerName, err)
 	}
 
-	return conn, nil
+	// Initialize the store
+	store := make(map[uint16]*dns.Msg)
+
+	return &connAndStore{conn: conn, store: store}, nil
 }
 
 // Put returns the connection to the pool.
-func (n *TLSPool) Put(c net.Conn) {
+func (n *TLSPool) Put(c *connAndStore) {
 	if c == nil {
 		return
 	}
diff --git a/upstream/upstream_pool_test.go b/upstream/upstream_pool_test.go
index 85b73ad3b..d2ca0dfa9 100644
--- a/upstream/upstream_pool_test.go
+++ b/upstream/upstream_pool_test.go
@@ -31,9 +31,9 @@ func TestTLSPoolReconnect(t *testing.T) {
 
 	// Now let's close the pooled connection and return it back to the pool.
 	p := u.(*dnsOverTLS)
-	conn, _ := p.pool.Get()
-	conn.Close()
-	p.pool.Put(conn)
+	connAndStore, _ := p.pool.Get()
+	connAndStore.conn.Close()
+	p.pool.Put(connAndStore)
 
 	// Send the second test message.
 	req = createTestMessage()
@@ -72,42 +72,42 @@ func TestTLSPoolDeadLine(t *testing.T) {
 	p := u.(*dnsOverTLS)
 
 	// Now let's get connection from the pool and use it
-	conn, err := p.pool.Get()
+	connAndStore, err := p.pool.Get()
 	if err != nil {
 		t.Fatalf("couldn't get connection from pool: %s", err)
 	}
-	response, err = p.exchangeConn(conn, req)
+	response, err = p.exchangeConn(connAndStore, req)
 	if err != nil {
 		t.Fatalf("first DNS message failed: %s", err)
 	}
 	requireResponse(t, req, response)
 
 	// Update connection's deadLine and put it back to the pool
-	err = conn.SetDeadline(time.Now().Add(10 * time.Hour))
+	err = connAndStore.conn.SetDeadline(time.Now().Add(10 * time.Hour))
 	if err != nil {
 		t.Fatalf("can't set new deadLine for connection. Looks like it's already closed: %s", err)
 	}
-	p.pool.Put(conn)
+	p.pool.Put(connAndStore)
 
 	// Get connection from the pool and reuse it
-	conn, err = p.pool.Get()
+	connAndStore, err = p.pool.Get()
 	if err != nil {
 		t.Fatalf("couldn't get connection from pool: %s", err)
 	}
-	response, err = p.exchangeConn(conn, req)
+	response, err = p.exchangeConn(connAndStore, req)
 	if err != nil {
 		t.Fatalf("first DNS message failed: %s", err)
 	}
 	requireResponse(t, req, response)
 
 	// Set connection's deadLine to the past and try to reuse it
-	err = conn.SetDeadline(time.Now().Add(-10 * time.Hour))
+	err = connAndStore.conn.SetDeadline(time.Now().Add(-10 * time.Hour))
 	if err != nil {
 		t.Fatalf("can't set new deadLine for connection. Looks like it's already closed: %s", err)
 	}
 
 	// Connection with expired deadLine can't be used
-	response, err = p.exchangeConn(conn, req)
+	response, err = p.exchangeConn(connAndStore, req)
 	if err == nil {
 		t.Fatalf("this connection should be already closed, got response %s", response)
 	}
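
The reworked tests above still drive the pool from a single goroutine. A test for what this patch actually enables, several queries in flight on one DoT upstream at the same time, could look roughly like the following sketch (not part of the patch; it assumes the package's createTestMessage and requireResponse helpers seen above, plus this package's AddressToUpstream constructor and Options type):

    func TestTLSPoolParallelExchange(t *testing.T) {
    	u, err := AddressToUpstream("tls://1.1.1.1", &Options{Timeout: 5 * time.Second})
    	if err != nil {
    		t.Fatalf("couldn't create upstream: %s", err)
    	}

    	var wg sync.WaitGroup
    	for i := 0; i < 10; i++ {
    		wg.Add(1)
    		go func() {
    			defer wg.Done()
    			req := createTestMessage()
    			reply, err := u.Exchange(req)
    			if err != nil {
    				// Errorf, not Fatalf: Fatalf must not be called from
    				// goroutines other than the one running the test.
    				t.Errorf("parallel exchange failed: %s", err)
    				return
    			}
    			requireResponse(t, req, reply)
    		}()
    	}
    	wg.Wait()
    }

With the response store in place, all ten goroutines can share one pooled TLS connection and still each receive the reply that carries their own DNS ID.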
From 51583cb0b7752e5526d3936c7a52e2cb7984d5ea Mon Sep 17 00:00:00 2001
From: 42SK <42SK@users.noreply.github.com>
Date: Fri, 30 Sep 2022 15:31:52 +0200
Subject: [PATCH 3/3] upstream: DoT: Match QNAME, QCLASS, and QTYPE fields
 from response to query

In order to deal with response reordering, RFC 7766 [1] requires that we
match the QNAME, QCLASS, and QTYPE fields in the response to the query.

[1] https://www.rfc-editor.org/rfc/rfc7766#section-7
---
 upstream/upstream_dot.go | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/upstream/upstream_dot.go b/upstream/upstream_dot.go
index 7b1ebd4a9..42a0e3543 100644
--- a/upstream/upstream_dot.go
+++ b/upstream/upstream_dot.go
@@ -140,5 +140,22 @@ func (p *dnsOverTLS) exchangeConn(connAndStore *connAndStore, m *dns.Msg) (repl
 		}
 	}
 
+	// Match response QNAME, QCLASS, and QTYPE to query according to RFC 7766
+	// (https://www.rfc-editor.org/rfc/rfc7766#section-7)
+	if len(reply.Question) != 0 && len(m.Question) != 0 {
+		if reply.Question[0].Name != m.Question[0].Name {
+			err = fmt.Errorf("query and response QNAME do not match; received %s, expected %s", reply.Question[0].Name, m.Question[0].Name)
+			return reply, err
+		}
+		if reply.Question[0].Qtype != m.Question[0].Qtype {
+			err = fmt.Errorf("query and response QTYPE do not match; received %d, expected %d", reply.Question[0].Qtype, m.Question[0].Qtype)
+			return reply, err
+		}
+		if reply.Question[0].Qclass != m.Question[0].Qclass {
+			err = fmt.Errorf("query and response QCLASS do not match; received %d, expected %d", reply.Question[0].Qclass, m.Question[0].Qclass)
+			return reply, err
+		}
+	}
+
 	return reply, err
 }
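
For readability, here is the added check factored out into a standalone helper, as a sketch only (matchesQuery is a hypothetical name; the patch keeps the logic inline in exchangeConn):

    // matchesQuery returns a non-nil error when reply's question section does
    // not match query's, per RFC 7766, section 7.
    func matchesQuery(query, reply *dns.Msg) error {
    	if len(query.Question) == 0 || len(reply.Question) == 0 {
    		// Nothing to compare; the patch skips the check in this case.
    		return nil
    	}
    	q, r := query.Question[0], reply.Question[0]
    	if r.Name != q.Name {
    		return fmt.Errorf("query and response QNAME do not match; received %s, expected %s", r.Name, q.Name)
    	}
    	if r.Qtype != q.Qtype {
    		return fmt.Errorf("query and response QTYPE do not match; received %d, expected %d", r.Qtype, q.Qtype)
    	}
    	if r.Qclass != q.Qclass {
    		return fmt.Errorf("query and response QCLASS do not match; received %d, expected %d", r.Qclass, q.Qclass)
    	}
    	return nil
    }

Together with the ID check in the store loop of patch 2, this covers both halves of the RFC 7766 guidance: the DNS ID pairs a response with a waiting query, and the question section confirms the pairing is genuine.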