From c510cd92cf69df3cdc224b930490274ce1b6e2ac Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Sat, 16 Apr 2022 07:59:23 -0700 Subject: [PATCH] net: permit use of Resolver.PreferGo, netgo on Windows and Plan 9 Fixes #33097 Change-Id: I2e55c7c113683814521f2068e0922b63c62ea5d8 Reviewed-on: https://go-review.googlesource.com/c/go/+/400654 Run-TryBot: Brad Fitzpatrick Reviewed-by: Damien Neil Auto-Submit: Brad Fitzpatrick Reviewed-by: Ian Lance Taylor TryBot-Result: Gopher Robot Reviewed-by: Ian Lance Taylor --- src/net/addrselect.go | 2 - src/net/cgo_stub.go | 2 - src/net/conf.go | 36 +++- src/net/dnsclient_unix.go | 24 ++- src/net/dnsconfig.go | 43 ++++ src/net/dnsconfig_unix.go | 36 +--- src/net/dnsconfig_windows.go | 58 ++++++ src/net/lookup.go | 226 +++++++++++++++++++++ src/net/lookup_plan9.go | 52 ++++- src/net/lookup_unix.go | 205 +------------------ src/net/lookup_windows.go | 93 ++++++--- src/net/net.go | 9 +- src/net/net_fake.go | 6 + src/net/netgo.go | 9 + src/net/nss.go | 2 - src/net/resolverdialfunc_test.go | 328 +++++++++++++++++++++++++++++++ 16 files changed, 838 insertions(+), 293 deletions(-) create mode 100644 src/net/dnsconfig.go create mode 100644 src/net/dnsconfig_windows.go create mode 100644 src/net/netgo.go create mode 100644 src/net/resolverdialfunc_test.go diff --git a/src/net/addrselect.go b/src/net/addrselect.go index 8accdb89e14f4..59380b94868fa 100644 --- a/src/net/addrselect.go +++ b/src/net/addrselect.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build unix - // Minimal RFC 6724 address selection. package net diff --git a/src/net/cgo_stub.go b/src/net/cgo_stub.go index cc84ca47aed0b..298d829f6fa6d 100644 --- a/src/net/cgo_stub.go +++ b/src/net/cgo_stub.go @@ -8,8 +8,6 @@ package net import "context" -func init() { netGo = true } - type addrinfoErrno int func (eai addrinfoErrno) Error() string { return "" } diff --git a/src/net/conf.go b/src/net/conf.go index 9d4752173e1af..b08bbc7d7a1c8 100644 --- a/src/net/conf.go +++ b/src/net/conf.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build unix +//go:build !js package net @@ -21,7 +21,7 @@ type conf struct { forceCgoLookupHost bool netGo bool // go DNS resolution forced - netCgo bool // cgo DNS resolution forced + netCgo bool // non-go DNS resolution forced (cgo, or win32) // machine has an /etc/mdns.allow file hasMDNSAllow bool @@ -49,9 +49,23 @@ func initConfVal() { confVal.dnsDebugLevel = debugLevel confVal.netGo = netGo || dnsMode == "go" confVal.netCgo = netCgo || dnsMode == "cgo" + if !confVal.netGo && !confVal.netCgo && (runtime.GOOS == "windows" || runtime.GOOS == "plan9") { + // Neither of these platforms actually use cgo. + // + // The meaning of "cgo" mode in the net package is + // really "the native OS way", which for libc meant + // cgo on the original platforms that motivated + // PreferGo support before Windows and Plan9 got support, + // at which time the GODEBUG=netdns=go and GODEBUG=netdns=cgo + // names were already kinda locked in. + confVal.netCgo = true + } if confVal.dnsDebugLevel > 0 { defer func() { + if confVal.dnsDebugLevel > 1 { + println("go package net: confVal.netCgo =", confVal.netCgo, " netGo =", confVal.netGo) + } switch { case confVal.netGo: if netGo { @@ -75,6 +89,10 @@ func initConfVal() { return } + if runtime.GOOS == "windows" || runtime.GOOS == "plan9" { + return + } + // If any environment-specified resolver options are specified, // force cgo. Note that LOCALDOMAIN can change behavior merely // by being specified with the empty string. @@ -129,7 +147,19 @@ func (c *conf) hostLookupOrder(r *Resolver, hostname string) (ret hostLookupOrde } fallbackOrder := hostLookupCgo if c.netGo || r.preferGo() { - fallbackOrder = hostLookupFilesDNS + switch c.goos { + case "windows": + // TODO(bradfitz): implement files-based + // lookup on Windows too? I guess /etc/hosts + // kinda exists on Windows. But for now, only + // do DNS. + fallbackOrder = hostLookupDNS + default: + fallbackOrder = hostLookupFilesDNS + } + } + if c.goos == "windows" || c.goos == "plan9" { + return fallbackOrder } if c.forceCgoLookupHost || c.resolv.unknownOpt || c.goos == "android" { return fallbackOrder diff --git a/src/net/dnsclient_unix.go b/src/net/dnsclient_unix.go index 15dbc25830276..9ade767ace6dd 100644 --- a/src/net/dnsclient_unix.go +++ b/src/net/dnsclient_unix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build unix +//go:build !js // DNS client: see RFC 1035. // Has to be linked into package net for Dial. @@ -20,6 +20,7 @@ import ( "internal/itoa" "io" "os" + "runtime" "sync" "time" @@ -378,12 +379,21 @@ func (conf *resolverConfig) tryUpdate(name string) { } conf.lastChecked = now - var mtime time.Time - if fi, err := os.Stat(name); err == nil { - mtime = fi.ModTime() - } - if mtime.Equal(conf.dnsConfig.mtime) { - return + switch runtime.GOOS { + case "windows": + // There's no file on disk, so don't bother checking + // and failing. + // + // The Windows implementation of dnsReadConfig (called + // below) ignores the name. + default: + var mtime time.Time + if fi, err := os.Stat(name); err == nil { + mtime = fi.ModTime() + } + if mtime.Equal(conf.dnsConfig.mtime) { + return + } } dnsConf := dnsReadConfig(name) diff --git a/src/net/dnsconfig.go b/src/net/dnsconfig.go new file mode 100644 index 0000000000000..091b5483013f5 --- /dev/null +++ b/src/net/dnsconfig.go @@ -0,0 +1,43 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "os" + "sync/atomic" + "time" +) + +var ( + defaultNS = []string{"127.0.0.1:53", "[::1]:53"} + getHostname = os.Hostname // variable for testing +) + +type dnsConfig struct { + servers []string // server addresses (in host:port form) to use + search []string // rooted suffixes to append to local name + ndots int // number of dots in name to trigger absolute lookup + timeout time.Duration // wait before giving up on a query, including retries + attempts int // lost packets before giving up on server + rotate bool // round robin among servers + unknownOpt bool // anything unknown was encountered + lookup []string // OpenBSD top-level database "lookup" order + err error // any error that occurs during open of resolv.conf + mtime time.Time // time of resolv.conf modification + soffset uint32 // used by serverOffset + singleRequest bool // use sequential A and AAAA queries instead of parallel queries + useTCP bool // force usage of TCP for DNS resolutions +} + +// serverOffset returns an offset that can be used to determine +// indices of servers in c.servers when making queries. +// When the rotate option is enabled, this offset increases. +// Otherwise it is always 0. +func (c *dnsConfig) serverOffset() uint32 { + if c.rotate { + return atomic.AddUint32(&c.soffset, 1) - 1 // return 0 to start + } + return 0 +} diff --git a/src/net/dnsconfig_unix.go b/src/net/dnsconfig_unix.go index 7552bc51e653a..94cd09ec71066 100644 --- a/src/net/dnsconfig_unix.go +++ b/src/net/dnsconfig_unix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build unix +//go:build !js && !windows // Read system DNS config from /etc/resolv.conf @@ -10,32 +10,9 @@ package net import ( "internal/bytealg" - "os" - "sync/atomic" "time" ) -var ( - defaultNS = []string{"127.0.0.1:53", "[::1]:53"} - getHostname = os.Hostname // variable for testing -) - -type dnsConfig struct { - servers []string // server addresses (in host:port form) to use - search []string // rooted suffixes to append to local name - ndots int // number of dots in name to trigger absolute lookup - timeout time.Duration // wait before giving up on a query, including retries - attempts int // lost packets before giving up on server - rotate bool // round robin among servers - unknownOpt bool // anything unknown was encountered - lookup []string // OpenBSD top-level database "lookup" order - err error // any error that occurs during open of resolv.conf - mtime time.Time // time of resolv.conf modification - soffset uint32 // used by serverOffset - singleRequest bool // use sequential A and AAAA queries instead of parallel queries - useTCP bool // force usage of TCP for DNS resolutions -} - // See resolv.conf(5) on a Linux machine. func dnsReadConfig(filename string) *dnsConfig { conf := &dnsConfig{ @@ -156,17 +133,6 @@ func dnsReadConfig(filename string) *dnsConfig { return conf } -// serverOffset returns an offset that can be used to determine -// indices of servers in c.servers when making queries. -// When the rotate option is enabled, this offset increases. -// Otherwise it is always 0. -func (c *dnsConfig) serverOffset() uint32 { - if c.rotate { - return atomic.AddUint32(&c.soffset, 1) - 1 // return 0 to start - } - return 0 -} - func dnsDefaultSearch() []string { hn, err := getHostname() if err != nil { diff --git a/src/net/dnsconfig_windows.go b/src/net/dnsconfig_windows.go new file mode 100644 index 0000000000000..5d640da1d740c --- /dev/null +++ b/src/net/dnsconfig_windows.go @@ -0,0 +1,58 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "syscall" + "time" +) + +func dnsReadConfig(ignoredFilename string) (conf *dnsConfig) { + conf = &dnsConfig{ + ndots: 1, + timeout: 5 * time.Second, + attempts: 2, + } + defer func() { + if len(conf.servers) == 0 { + conf.servers = defaultNS + } + }() + aas, err := adapterAddresses() + if err != nil { + return + } + // TODO(bradfitz): this just collects all the DNS servers on all + // the interfaces in some random order. It should order it by + // default route, or only use the default route(s) instead. + // In practice, however, it mostly works. + for _, aa := range aas { + for dns := aa.FirstDnsServerAddress; dns != nil; dns = dns.Next { + sa, err := dns.Address.Sockaddr.Sockaddr() + if err != nil { + continue + } + var ip IP + switch sa := sa.(type) { + case *syscall.SockaddrInet4: + ip = IPv4(sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]) + case *syscall.SockaddrInet6: + ip = make(IP, IPv6len) + copy(ip, sa.Addr[:]) + if ip[0] == 0xfe && ip[1] == 0xc0 { + // Ignore these fec0/10 ones. Windows seems to + // populate them as defaults on its misc rando + // interfaces. + continue + } + default: + // Unexpected type. + continue + } + conf.servers = append(conf.servers, JoinHostPort(ip.String(), "53")) + } + } + return conf +} diff --git a/src/net/lookup.go b/src/net/lookup.go index 6fa90f354d466..7f3d20126c902 100644 --- a/src/net/lookup.go +++ b/src/net/lookup.go @@ -10,6 +10,8 @@ import ( "internal/singleflight" "net/netip" "sync" + + "golang.org/x/net/dns/dnsmessage" ) // protocols contains minimal mappings between internet protocol @@ -665,3 +667,227 @@ func (r *Resolver) LookupAddr(ctx context.Context, addr string) ([]string, error // method receives DNS records which contain invalid DNS names. This may be returned alongside // results which have had the malformed records filtered out. var errMalformedDNSRecordsDetail = "DNS response contained records which contain invalid names" + +// dial makes a new connection to the provided server (which must be +// an IP address) with the provided network type, using either r.Dial +// (if both r and r.Dial are non-nil) or else Dialer.DialContext. +func (r *Resolver) dial(ctx context.Context, network, server string) (Conn, error) { + // Calling Dial here is scary -- we have to be sure not to + // dial a name that will require a DNS lookup, or Dial will + // call back here to translate it. The DNS config parser has + // already checked that all the cfg.servers are IP + // addresses, which Dial will use without a DNS lookup. + var c Conn + var err error + if r != nil && r.Dial != nil { + c, err = r.Dial(ctx, network, server) + } else { + var d Dialer + c, err = d.DialContext(ctx, network, server) + } + if err != nil { + return nil, mapErr(err) + } + return c, nil +} + +// goLookupSRV returns the SRV records for a target name, built either +// from its component service ("sip"), protocol ("tcp"), and name +// ("example.com."), or from name directly (if service and proto are +// both empty). +// +// In either case, the returned target name ("_sip._tcp.example.com.") +// is also returned on success. +// +// The records are sorted by weight. +func (r *Resolver) goLookupSRV(ctx context.Context, service, proto, name string) (target string, srvs []*SRV, err error) { + if service == "" && proto == "" { + target = name + } else { + target = "_" + service + "._" + proto + "." + name + } + p, server, err := r.lookup(ctx, target, dnsmessage.TypeSRV) + if err != nil { + return "", nil, err + } + var cname dnsmessage.Name + for { + h, err := p.AnswerHeader() + if err == dnsmessage.ErrSectionDone { + break + } + if err != nil { + return "", nil, &DNSError{ + Err: "cannot unmarshal DNS message", + Name: name, + Server: server, + } + } + if h.Type != dnsmessage.TypeSRV { + if err := p.SkipAnswer(); err != nil { + return "", nil, &DNSError{ + Err: "cannot unmarshal DNS message", + Name: name, + Server: server, + } + } + continue + } + if cname.Length == 0 && h.Name.Length != 0 { + cname = h.Name + } + srv, err := p.SRVResource() + if err != nil { + return "", nil, &DNSError{ + Err: "cannot unmarshal DNS message", + Name: name, + Server: server, + } + } + srvs = append(srvs, &SRV{Target: srv.Target.String(), Port: srv.Port, Priority: srv.Priority, Weight: srv.Weight}) + } + byPriorityWeight(srvs).sort() + return cname.String(), srvs, nil +} + +// goLookupMX returns the MX records for name. +func (r *Resolver) goLookupMX(ctx context.Context, name string) ([]*MX, error) { + p, server, err := r.lookup(ctx, name, dnsmessage.TypeMX) + if err != nil { + return nil, err + } + var mxs []*MX + for { + h, err := p.AnswerHeader() + if err == dnsmessage.ErrSectionDone { + break + } + if err != nil { + return nil, &DNSError{ + Err: "cannot unmarshal DNS message", + Name: name, + Server: server, + } + } + if h.Type != dnsmessage.TypeMX { + if err := p.SkipAnswer(); err != nil { + return nil, &DNSError{ + Err: "cannot unmarshal DNS message", + Name: name, + Server: server, + } + } + continue + } + mx, err := p.MXResource() + if err != nil { + return nil, &DNSError{ + Err: "cannot unmarshal DNS message", + Name: name, + Server: server, + } + } + mxs = append(mxs, &MX{Host: mx.MX.String(), Pref: mx.Pref}) + + } + byPref(mxs).sort() + return mxs, nil +} + +// goLookupNS returns the NS records for name. +func (r *Resolver) goLookupNS(ctx context.Context, name string) ([]*NS, error) { + p, server, err := r.lookup(ctx, name, dnsmessage.TypeNS) + if err != nil { + return nil, err + } + var nss []*NS + for { + h, err := p.AnswerHeader() + if err == dnsmessage.ErrSectionDone { + break + } + if err != nil { + return nil, &DNSError{ + Err: "cannot unmarshal DNS message", + Name: name, + Server: server, + } + } + if h.Type != dnsmessage.TypeNS { + if err := p.SkipAnswer(); err != nil { + return nil, &DNSError{ + Err: "cannot unmarshal DNS message", + Name: name, + Server: server, + } + } + continue + } + ns, err := p.NSResource() + if err != nil { + return nil, &DNSError{ + Err: "cannot unmarshal DNS message", + Name: name, + Server: server, + } + } + nss = append(nss, &NS{Host: ns.NS.String()}) + } + return nss, nil +} + +// goLookupTXT returns the TXT records from name. +func (r *Resolver) goLookupTXT(ctx context.Context, name string) ([]string, error) { + p, server, err := r.lookup(ctx, name, dnsmessage.TypeTXT) + if err != nil { + return nil, err + } + var txts []string + for { + h, err := p.AnswerHeader() + if err == dnsmessage.ErrSectionDone { + break + } + if err != nil { + return nil, &DNSError{ + Err: "cannot unmarshal DNS message", + Name: name, + Server: server, + } + } + if h.Type != dnsmessage.TypeTXT { + if err := p.SkipAnswer(); err != nil { + return nil, &DNSError{ + Err: "cannot unmarshal DNS message", + Name: name, + Server: server, + } + } + continue + } + txt, err := p.TXTResource() + if err != nil { + return nil, &DNSError{ + Err: "cannot unmarshal DNS message", + Name: name, + Server: server, + } + } + // Multiple strings in one TXT record need to be + // concatenated without separator to be consistent + // with previous Go resolver. + n := 0 + for _, s := range txt.TXT { + n += len(s) + } + txtJoin := make([]byte, 0, n) + for _, s := range txt.TXT { + txtJoin = append(txtJoin, s...) + } + if len(txts) == 0 { + txts = make([]string, 0, 1) + } + txts = append(txts, string(txtJoin)) + } + return txts, nil +} diff --git a/src/net/lookup_plan9.go b/src/net/lookup_plan9.go index d43a03b778d47..445b1294e352d 100644 --- a/src/net/lookup_plan9.go +++ b/src/net/lookup_plan9.go @@ -179,7 +179,27 @@ loop: return } -func (r *Resolver) lookupIP(ctx context.Context, _, host string) (addrs []IPAddr, err error) { +// preferGoOverPlan9 reports whether the resolver should use the +// "PreferGo" implementation rather than asking plan9 services +// for the answers. +func (r *Resolver) preferGoOverPlan9() bool { + conf := systemConf() + order := conf.hostLookupOrder(r, "") // name is unused + + // TODO(bradfitz): for now we only permit use of the PreferGo + // implementation when there's a non-nil Resolver with a + // non-nil Dialer. This is a sign that they the code is trying + // to use their DNS-speaking net.Conn (such as an in-memory + // DNS cache) and they don't want to actually hit the network. + // Once we add support for looking the default DNS servers + // from plan9, though, then we can relax this. + return order != hostLookupCgo && r != nil && r.Dial != nil +} + +func (r *Resolver) lookupIP(ctx context.Context, network, host string) (addrs []IPAddr, err error) { + if r.preferGoOverPlan9() { + return r.goLookupIP(ctx, network, host) + } lits, err := r.lookupHost(ctx, host) if err != nil { return @@ -223,7 +243,10 @@ func (*Resolver) lookupPort(ctx context.Context, network, service string) (port return 0, unknownPortError } -func (*Resolver) lookupCNAME(ctx context.Context, name string) (cname string, err error) { +func (r *Resolver) lookupCNAME(ctx context.Context, name string) (cname string, err error) { + if r.preferGoOverPlan9() { + return r.goLookupCNAME(ctx, name) + } lines, err := queryDNS(ctx, name, "cname") if err != nil { if stringsHasSuffix(err.Error(), "dns failure") || stringsHasSuffix(err.Error(), "resource does not exist; negrcode 0") { @@ -240,7 +263,10 @@ func (*Resolver) lookupCNAME(ctx context.Context, name string) (cname string, er return "", errors.New("bad response from ndb/dns") } -func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (cname string, addrs []*SRV, err error) { +func (r *Resolver) lookupSRV(ctx context.Context, service, proto, name string) (cname string, addrs []*SRV, err error) { + if r.preferGoOverPlan9() { + return r.goLookupSRV(ctx, service, proto, name) + } var target string if service == "" && proto == "" { target = name @@ -269,7 +295,10 @@ func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (cn return } -func (*Resolver) lookupMX(ctx context.Context, name string) (mx []*MX, err error) { +func (r *Resolver) lookupMX(ctx context.Context, name string) (mx []*MX, err error) { + if r.preferGoOverPlan9() { + return r.goLookupMX(ctx, name) + } lines, err := queryDNS(ctx, name, "mx") if err != nil { return @@ -287,7 +316,10 @@ func (*Resolver) lookupMX(ctx context.Context, name string) (mx []*MX, err error return } -func (*Resolver) lookupNS(ctx context.Context, name string) (ns []*NS, err error) { +func (r *Resolver) lookupNS(ctx context.Context, name string) (ns []*NS, err error) { + if r.preferGoOverPlan9() { + return r.goLookupNS(ctx, name) + } lines, err := queryDNS(ctx, name, "ns") if err != nil { return @@ -302,7 +334,10 @@ func (*Resolver) lookupNS(ctx context.Context, name string) (ns []*NS, err error return } -func (*Resolver) lookupTXT(ctx context.Context, name string) (txt []string, err error) { +func (r *Resolver) lookupTXT(ctx context.Context, name string) (txt []string, err error) { + if r.preferGoOverPlan9() { + return r.goLookupTXT(ctx, name) + } lines, err := queryDNS(ctx, name, "txt") if err != nil { return @@ -315,7 +350,10 @@ func (*Resolver) lookupTXT(ctx context.Context, name string) (txt []string, err return } -func (*Resolver) lookupAddr(ctx context.Context, addr string) (name []string, err error) { +func (r *Resolver) lookupAddr(ctx context.Context, addr string) (name []string, err error) { + if r.preferGoOverPlan9() { + return r.goLookupPTR(ctx, addr) + } arpa, err := reverseaddr(addr) if err != nil { return diff --git a/src/net/lookup_unix.go b/src/net/lookup_unix.go index ad4164d86517a..4b885e938a06f 100644 --- a/src/net/lookup_unix.go +++ b/src/net/lookup_unix.go @@ -11,8 +11,6 @@ import ( "internal/bytealg" "sync" "syscall" - - "golang.org/x/net/dns/dnsmessage" ) var onceReadProtocols sync.Once @@ -55,26 +53,6 @@ func lookupProtocol(_ context.Context, name string) (int, error) { return lookupProtocolMap(name) } -func (r *Resolver) dial(ctx context.Context, network, server string) (Conn, error) { - // Calling Dial here is scary -- we have to be sure not to - // dial a name that will require a DNS lookup, or Dial will - // call back here to translate it. The DNS config parser has - // already checked that all the cfg.servers are IP - // addresses, which Dial will use without a DNS lookup. - var c Conn - var err error - if r != nil && r.Dial != nil { - c, err = r.Dial(ctx, network, server) - } else { - var d Dialer - c, err = d.DialContext(ctx, network, server) - } - if err != nil { - return nil, mapErr(err) - } - return c, nil -} - func (r *Resolver) lookupHost(ctx context.Context, host string) (addrs []string, err error) { order := systemConf().hostLookupOrder(r, host) if !r.preferGo() && order == hostLookupCgo { @@ -129,194 +107,19 @@ func (r *Resolver) lookupCNAME(ctx context.Context, name string) (string, error) } func (r *Resolver) lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) { - var target string - if service == "" && proto == "" { - target = name - } else { - target = "_" + service + "._" + proto + "." + name - } - p, server, err := r.lookup(ctx, target, dnsmessage.TypeSRV) - if err != nil { - return "", nil, err - } - var srvs []*SRV - var cname dnsmessage.Name - for { - h, err := p.AnswerHeader() - if err == dnsmessage.ErrSectionDone { - break - } - if err != nil { - return "", nil, &DNSError{ - Err: "cannot unmarshal DNS message", - Name: name, - Server: server, - } - } - if h.Type != dnsmessage.TypeSRV { - if err := p.SkipAnswer(); err != nil { - return "", nil, &DNSError{ - Err: "cannot unmarshal DNS message", - Name: name, - Server: server, - } - } - continue - } - if cname.Length == 0 && h.Name.Length != 0 { - cname = h.Name - } - srv, err := p.SRVResource() - if err != nil { - return "", nil, &DNSError{ - Err: "cannot unmarshal DNS message", - Name: name, - Server: server, - } - } - srvs = append(srvs, &SRV{Target: srv.Target.String(), Port: srv.Port, Priority: srv.Priority, Weight: srv.Weight}) - } - byPriorityWeight(srvs).sort() - return cname.String(), srvs, nil + return r.goLookupSRV(ctx, service, proto, name) } func (r *Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) { - p, server, err := r.lookup(ctx, name, dnsmessage.TypeMX) - if err != nil { - return nil, err - } - var mxs []*MX - for { - h, err := p.AnswerHeader() - if err == dnsmessage.ErrSectionDone { - break - } - if err != nil { - return nil, &DNSError{ - Err: "cannot unmarshal DNS message", - Name: name, - Server: server, - } - } - if h.Type != dnsmessage.TypeMX { - if err := p.SkipAnswer(); err != nil { - return nil, &DNSError{ - Err: "cannot unmarshal DNS message", - Name: name, - Server: server, - } - } - continue - } - mx, err := p.MXResource() - if err != nil { - return nil, &DNSError{ - Err: "cannot unmarshal DNS message", - Name: name, - Server: server, - } - } - mxs = append(mxs, &MX{Host: mx.MX.String(), Pref: mx.Pref}) - - } - byPref(mxs).sort() - return mxs, nil + return r.goLookupMX(ctx, name) } func (r *Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) { - p, server, err := r.lookup(ctx, name, dnsmessage.TypeNS) - if err != nil { - return nil, err - } - var nss []*NS - for { - h, err := p.AnswerHeader() - if err == dnsmessage.ErrSectionDone { - break - } - if err != nil { - return nil, &DNSError{ - Err: "cannot unmarshal DNS message", - Name: name, - Server: server, - } - } - if h.Type != dnsmessage.TypeNS { - if err := p.SkipAnswer(); err != nil { - return nil, &DNSError{ - Err: "cannot unmarshal DNS message", - Name: name, - Server: server, - } - } - continue - } - ns, err := p.NSResource() - if err != nil { - return nil, &DNSError{ - Err: "cannot unmarshal DNS message", - Name: name, - Server: server, - } - } - nss = append(nss, &NS{Host: ns.NS.String()}) - } - return nss, nil + return r.goLookupNS(ctx, name) } func (r *Resolver) lookupTXT(ctx context.Context, name string) ([]string, error) { - p, server, err := r.lookup(ctx, name, dnsmessage.TypeTXT) - if err != nil { - return nil, err - } - var txts []string - for { - h, err := p.AnswerHeader() - if err == dnsmessage.ErrSectionDone { - break - } - if err != nil { - return nil, &DNSError{ - Err: "cannot unmarshal DNS message", - Name: name, - Server: server, - } - } - if h.Type != dnsmessage.TypeTXT { - if err := p.SkipAnswer(); err != nil { - return nil, &DNSError{ - Err: "cannot unmarshal DNS message", - Name: name, - Server: server, - } - } - continue - } - txt, err := p.TXTResource() - if err != nil { - return nil, &DNSError{ - Err: "cannot unmarshal DNS message", - Name: name, - Server: server, - } - } - // Multiple strings in one TXT record need to be - // concatenated without separator to be consistent - // with previous Go resolver. - n := 0 - for _, s := range txt.TXT { - n += len(s) - } - txtJoin := make([]byte, 0, n) - for _, s := range txt.TXT { - txtJoin = append(txtJoin, s...) - } - if len(txts) == 0 { - txts = make([]string, 0, 1) - } - txts = append(txts, string(txtJoin)) - } - return txts, nil + return r.goLookupTXT(ctx, name) } func (r *Resolver) lookupAddr(ctx context.Context, addr string) ([]string, error) { diff --git a/src/net/lookup_windows.go b/src/net/lookup_windows.go index 27e5f86910e0f..051f47da392c3 100644 --- a/src/net/lookup_windows.go +++ b/src/net/lookup_windows.go @@ -82,7 +82,19 @@ func (r *Resolver) lookupHost(ctx context.Context, name string) ([]string, error return addrs, nil } +// preferGoOverWindows reports whether the resolver should use the +// pure Go implementation rather than making win32 calls to ask the +// kernel for its answer. +func (r *Resolver) preferGoOverWindows() bool { + conf := systemConf() + order := conf.hostLookupOrder(r, "") // name is unused + return order != hostLookupCgo +} + func (r *Resolver) lookupIP(ctx context.Context, network, name string) ([]IPAddr, error) { + if r.preferGoOverWindows() { + return r.goLookupIP(ctx, network, name) + } // TODO(bradfitz,brainman): use ctx more. See TODO below. var family int32 = syscall.AF_UNSPEC @@ -169,7 +181,7 @@ func (r *Resolver) lookupIP(ctx context.Context, network, name string) ([]IPAddr } func (r *Resolver) lookupPort(ctx context.Context, network, service string) (int, error) { - if r.preferGo() { + if r.preferGoOverWindows() { return lookupPortMap(network, service) } @@ -217,12 +229,15 @@ func (r *Resolver) lookupPort(ctx context.Context, network, service string) (int return 0, &DNSError{Err: syscall.EINVAL.Error(), Name: network + "/" + service} } -func (*Resolver) lookupCNAME(ctx context.Context, name string) (string, error) { +func (r *Resolver) lookupCNAME(ctx context.Context, name string) (string, error) { + if r.preferGoOverWindows() { + return r.goLookupCNAME(ctx, name) + } // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. acquireThread() defer releaseThread() - var r *syscall.DNSRecord - e := syscall.DnsQuery(name, syscall.DNS_TYPE_CNAME, 0, nil, &r, nil) + var rec *syscall.DNSRecord + e := syscall.DnsQuery(name, syscall.DNS_TYPE_CNAME, 0, nil, &rec, nil) // windows returns DNS_INFO_NO_RECORDS if there are no CNAME-s if errno, ok := e.(syscall.Errno); ok && errno == syscall.DNS_INFO_NO_RECORDS { // if there are no aliases, the canonical name is the input name @@ -231,14 +246,17 @@ func (*Resolver) lookupCNAME(ctx context.Context, name string) (string, error) { if e != nil { return "", &DNSError{Err: winError("dnsquery", e).Error(), Name: name} } - defer syscall.DnsRecordListFree(r, 1) + defer syscall.DnsRecordListFree(rec, 1) - resolved := resolveCNAME(syscall.StringToUTF16Ptr(name), r) + resolved := resolveCNAME(syscall.StringToUTF16Ptr(name), rec) cname := windows.UTF16PtrToString(resolved) return absDomainName(cname), nil } -func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) { +func (r *Resolver) lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) { + if r.preferGoOverWindows() { + return r.goLookupSRV(ctx, service, proto, name) + } // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. acquireThread() defer releaseThread() @@ -248,15 +266,15 @@ func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (st } else { target = "_" + service + "._" + proto + "." + name } - var r *syscall.DNSRecord - e := syscall.DnsQuery(target, syscall.DNS_TYPE_SRV, 0, nil, &r, nil) + var rec *syscall.DNSRecord + e := syscall.DnsQuery(target, syscall.DNS_TYPE_SRV, 0, nil, &rec, nil) if e != nil { return "", nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: target} } - defer syscall.DnsRecordListFree(r, 1) + defer syscall.DnsRecordListFree(rec, 1) srvs := make([]*SRV, 0, 10) - for _, p := range validRecs(r, syscall.DNS_TYPE_SRV, target) { + for _, p := range validRecs(rec, syscall.DNS_TYPE_SRV, target) { v := (*syscall.DNSSRVData)(unsafe.Pointer(&p.Data[0])) srvs = append(srvs, &SRV{absDomainName(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Target))[:])), v.Port, v.Priority, v.Weight}) } @@ -264,19 +282,22 @@ func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (st return absDomainName(target), srvs, nil } -func (*Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) { +func (r *Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) { + if r.preferGoOverWindows() { + return r.goLookupMX(ctx, name) + } // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. acquireThread() defer releaseThread() - var r *syscall.DNSRecord - e := syscall.DnsQuery(name, syscall.DNS_TYPE_MX, 0, nil, &r, nil) + var rec *syscall.DNSRecord + e := syscall.DnsQuery(name, syscall.DNS_TYPE_MX, 0, nil, &rec, nil) if e != nil { return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: name} } - defer syscall.DnsRecordListFree(r, 1) + defer syscall.DnsRecordListFree(rec, 1) mxs := make([]*MX, 0, 10) - for _, p := range validRecs(r, syscall.DNS_TYPE_MX, name) { + for _, p := range validRecs(rec, syscall.DNS_TYPE_MX, name) { v := (*syscall.DNSMXData)(unsafe.Pointer(&p.Data[0])) mxs = append(mxs, &MX{absDomainName(windows.UTF16PtrToString(v.NameExchange)), v.Preference}) } @@ -284,38 +305,44 @@ func (*Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) { return mxs, nil } -func (*Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) { +func (r *Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) { + if r.preferGoOverWindows() { + return r.goLookupNS(ctx, name) + } // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. acquireThread() defer releaseThread() - var r *syscall.DNSRecord - e := syscall.DnsQuery(name, syscall.DNS_TYPE_NS, 0, nil, &r, nil) + var rec *syscall.DNSRecord + e := syscall.DnsQuery(name, syscall.DNS_TYPE_NS, 0, nil, &rec, nil) if e != nil { return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: name} } - defer syscall.DnsRecordListFree(r, 1) + defer syscall.DnsRecordListFree(rec, 1) nss := make([]*NS, 0, 10) - for _, p := range validRecs(r, syscall.DNS_TYPE_NS, name) { + for _, p := range validRecs(rec, syscall.DNS_TYPE_NS, name) { v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0])) nss = append(nss, &NS{absDomainName(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]))}) } return nss, nil } -func (*Resolver) lookupTXT(ctx context.Context, name string) ([]string, error) { +func (r *Resolver) lookupTXT(ctx context.Context, name string) ([]string, error) { + if r.preferGoOverWindows() { + return r.lookupTXT(ctx, name) + } // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. acquireThread() defer releaseThread() - var r *syscall.DNSRecord - e := syscall.DnsQuery(name, syscall.DNS_TYPE_TEXT, 0, nil, &r, nil) + var rec *syscall.DNSRecord + e := syscall.DnsQuery(name, syscall.DNS_TYPE_TEXT, 0, nil, &rec, nil) if e != nil { return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: name} } - defer syscall.DnsRecordListFree(r, 1) + defer syscall.DnsRecordListFree(rec, 1) txts := make([]string, 0, 10) - for _, p := range validRecs(r, syscall.DNS_TYPE_TEXT, name) { + for _, p := range validRecs(rec, syscall.DNS_TYPE_TEXT, name) { d := (*syscall.DNSTXTData)(unsafe.Pointer(&p.Data[0])) s := "" for _, v := range (*[1 << 10]*uint16)(unsafe.Pointer(&(d.StringArray[0])))[:d.StringCount:d.StringCount] { @@ -326,7 +353,11 @@ func (*Resolver) lookupTXT(ctx context.Context, name string) ([]string, error) { return txts, nil } -func (*Resolver) lookupAddr(ctx context.Context, addr string) ([]string, error) { +func (r *Resolver) lookupAddr(ctx context.Context, addr string) ([]string, error) { + if r.preferGoOverWindows() { + return r.goLookupPTR(ctx, addr) + } + // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. acquireThread() defer releaseThread() @@ -334,15 +365,15 @@ func (*Resolver) lookupAddr(ctx context.Context, addr string) ([]string, error) if err != nil { return nil, err } - var r *syscall.DNSRecord - e := syscall.DnsQuery(arpa, syscall.DNS_TYPE_PTR, 0, nil, &r, nil) + var rec *syscall.DNSRecord + e := syscall.DnsQuery(arpa, syscall.DNS_TYPE_PTR, 0, nil, &rec, nil) if e != nil { return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: addr} } - defer syscall.DnsRecordListFree(r, 1) + defer syscall.DnsRecordListFree(rec, 1) ptrs := make([]string, 0, 10) - for _, p := range validRecs(r, syscall.DNS_TYPE_PTR, arpa) { + for _, p := range validRecs(rec, syscall.DNS_TYPE_PTR, arpa) { v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0])) ptrs = append(ptrs, absDomainName(windows.UTF16PtrToString(v.Host))) } diff --git a/src/net/net.go b/src/net/net.go index 7a97b9dcfd2a6..ff56c31c56343 100644 --- a/src/net/net.go +++ b/src/net/net.go @@ -61,7 +61,7 @@ The resolver decision can be overridden by setting the netdns value of the GODEBUG environment variable (see package runtime) to go or cgo, as in: export GODEBUG=netdns=go # force pure Go resolver - export GODEBUG=netdns=cgo # force cgo resolver + export GODEBUG=netdns=cgo # force native resolver (cgo, win32) The decision can also be forced while building the Go source tree by setting the netgo or netcgo build tag. @@ -73,7 +73,8 @@ join the two settings by a plus sign, as in GODEBUG=netdns=go+1. On Plan 9, the resolver always accesses /net/cs and /net/dns. -On Windows, the resolver always uses C library functions, such as GetAddrInfo and DnsQuery. +On Windows, in Go 1.18.x and earlier, the resolver always used C +library functions, such as GetAddrInfo and DnsQuery. */ package net @@ -588,7 +589,9 @@ func (e InvalidAddrError) Temporary() bool { return false } // // TODO(iant): We could consider changing this to os.ErrDeadlineExceeded // in the future, if we make -// errors.Is(os.ErrDeadlineExceeded, context.DeadlineExceeded) +// +// errors.Is(os.ErrDeadlineExceeded, context.DeadlineExceeded) +// // return true. var errTimeout error = &timeoutError{} diff --git a/src/net/net_fake.go b/src/net/net_fake.go index ee5644c67f087..6d07d6297a4c5 100644 --- a/src/net/net_fake.go +++ b/src/net/net_fake.go @@ -16,6 +16,8 @@ import ( "sync" "syscall" "time" + + "golang.org/x/net/dns/dnsmessage" ) var listenersMu sync.Mutex @@ -314,3 +316,7 @@ func (fd *netFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oob func (fd *netFD) dup() (f *os.File, err error) { return nil, syscall.ENOSYS } + +func (r *Resolver) lookup(ctx context.Context, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) { + panic("unreachable") +} diff --git a/src/net/netgo.go b/src/net/netgo.go new file mode 100644 index 0000000000000..f91c91b614858 --- /dev/null +++ b/src/net/netgo.go @@ -0,0 +1,9 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build netgo + +package net + +func init() { netGo = true } diff --git a/src/net/nss.go b/src/net/nss.go index 5df71dc268fad..c4c608fb61864 100644 --- a/src/net/nss.go +++ b/src/net/nss.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build unix - package net import ( diff --git a/src/net/resolverdialfunc_test.go b/src/net/resolverdialfunc_test.go new file mode 100644 index 0000000000000..034c636eb6937 --- /dev/null +++ b/src/net/resolverdialfunc_test.go @@ -0,0 +1,328 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !js + +// Test that Resolver.Dial can be a func returning an in-memory net.Conn +// speaking DNS. + +package net + +import ( + "bytes" + "context" + "errors" + "fmt" + "reflect" + "sort" + "testing" + "time" + + "golang.org/x/net/dns/dnsmessage" +) + +func TestResolverDialFunc(t *testing.T) { + r := &Resolver{ + PreferGo: true, + Dial: newResolverDialFunc(&resolverDialHandler{ + StartDial: func(network, address string) error { + t.Logf("StartDial(%q, %q) ...", network, address) + return nil + }, + Question: func(h dnsmessage.Header, q dnsmessage.Question) { + t.Logf("Header: %+v for %q (type=%v, class=%v)", h, + q.Name.String(), q.Type, q.Class) + }, + // TODO: add test without HandleA* hooks specified at all, that Go + // doesn't issue retries; map to something terminal. + HandleA: func(w AWriter, name string) error { + w.AddIP([4]byte{1, 2, 3, 4}) + w.AddIP([4]byte{5, 6, 7, 8}) + return nil + }, + HandleAAAA: func(w AAAAWriter, name string) error { + w.AddIP([16]byte{1: 1, 15: 15}) + w.AddIP([16]byte{2: 2, 14: 14}) + return nil + }, + HandleSRV: func(w SRVWriter, name string) error { + w.AddSRV(1, 2, 80, "foo.bar.") + w.AddSRV(2, 3, 81, "bar.baz.") + return nil + }, + }), + } + ctx := context.Background() + const fakeDomain = "something-that-is-a-not-a-real-domain.fake-tld." + + t.Run("LookupIP", func(t *testing.T) { + ips, err := r.LookupIP(ctx, "ip", fakeDomain) + if err != nil { + t.Fatal(err) + } + if got, want := sortedIPStrings(ips), []string{"0:200::e00", "1.2.3.4", "1::f", "5.6.7.8"}; !reflect.DeepEqual(got, want) { + t.Errorf("LookupIP wrong.\n got: %q\nwant: %q\n", got, want) + } + }) + + t.Run("LookupSRV", func(t *testing.T) { + _, got, err := r.LookupSRV(ctx, "some-service", "tcp", fakeDomain) + if err != nil { + t.Fatal(err) + } + want := []*SRV{ + { + Target: "foo.bar.", + Port: 80, + Priority: 1, + Weight: 2, + }, + { + Target: "bar.baz.", + Port: 81, + Priority: 2, + Weight: 3, + }, + } + if !reflect.DeepEqual(got, want) { + t.Errorf("wrong result. got:") + for _, r := range got { + t.Logf(" - %+v", r) + } + } + }) +} + +func sortedIPStrings(ips []IP) []string { + ret := make([]string, len(ips)) + for i, ip := range ips { + ret[i] = ip.String() + } + sort.Strings(ret) + return ret +} + +func newResolverDialFunc(h *resolverDialHandler) func(ctx context.Context, network, address string) (Conn, error) { + return func(ctx context.Context, network, address string) (Conn, error) { + a := &resolverFuncConn{ + h: h, + network: network, + address: address, + ttl: 10, // 10 second default if unset + } + if h.StartDial != nil { + if err := h.StartDial(network, address); err != nil { + return nil, err + } + } + return a, nil + } +} + +type resolverDialHandler struct { + // StartDial, if non-nil, is called when Go first calls Resolver.Dial. + // Any error returned aborts the dial and is returned unwrapped. + StartDial func(network, address string) error + + Question func(dnsmessage.Header, dnsmessage.Question) + + // err may be ErrNotExist or ErrRefused; others map to SERVFAIL (RCode2). + // A nil error means success. + HandleA func(w AWriter, name string) error + HandleAAAA func(w AAAAWriter, name string) error + HandleSRV func(w SRVWriter, name string) error +} + +type ResponseWriter struct{ a *resolverFuncConn } + +func (w ResponseWriter) header() dnsmessage.ResourceHeader { + q := w.a.q + return dnsmessage.ResourceHeader{ + Name: q.Name, + Type: q.Type, + Class: q.Class, + TTL: w.a.ttl, + } +} + +// SetTTL sets the TTL for subsequent written resources. +// Once a resource has been written, SetTTL calls are no-ops. +// That is, it can only be called at most once, before anything +// else is written. +func (w ResponseWriter) SetTTL(seconds uint32) { + // ... intention is last one wins and mutates all previously + // written records too, but that's a little annoying. + // But it's also annoying if the requirement is it needs to be set + // last. + // And it's also annoying if it's possible for users to set + // different TTLs per Answer. + if w.a.wrote { + return + } + w.a.ttl = seconds + +} + +type AWriter struct{ ResponseWriter } + +func (w AWriter) AddIP(v4 [4]byte) { + w.a.wrote = true + err := w.a.builder.AResource(w.header(), dnsmessage.AResource{A: v4}) + if err != nil { + panic(err) + } +} + +type AAAAWriter struct{ ResponseWriter } + +func (w AAAAWriter) AddIP(v6 [16]byte) { + w.a.wrote = true + err := w.a.builder.AAAAResource(w.header(), dnsmessage.AAAAResource{AAAA: v6}) + if err != nil { + panic(err) + } +} + +type SRVWriter struct{ ResponseWriter } + +// AddSRV adds a SRV record. The target name must end in a period and +// be 63 bytes or fewer. +func (w SRVWriter) AddSRV(priority, weight, port uint16, target string) error { + targetName, err := dnsmessage.NewName(target) + if err != nil { + return err + } + w.a.wrote = true + err = w.a.builder.SRVResource(w.header(), dnsmessage.SRVResource{ + Priority: priority, + Weight: weight, + Port: port, + Target: targetName, + }) + if err != nil { + panic(err) // internal fault, not user + } + return nil +} + +var ( + ErrNotExist = errors.New("name does not exist") // maps to RCode3, NXDOMAIN + ErrRefused = errors.New("refused") // maps to RCode5, REFUSED +) + +type resolverFuncConn struct { + h *resolverDialHandler + ctx context.Context + network string + address string + builder *dnsmessage.Builder + q dnsmessage.Question + ttl uint32 + wrote bool + + rbuf bytes.Buffer +} + +func (*resolverFuncConn) Close() error { return nil } +func (*resolverFuncConn) LocalAddr() Addr { return someaddr{} } +func (*resolverFuncConn) RemoteAddr() Addr { return someaddr{} } +func (*resolverFuncConn) SetDeadline(t time.Time) error { return nil } +func (*resolverFuncConn) SetReadDeadline(t time.Time) error { return nil } +func (*resolverFuncConn) SetWriteDeadline(t time.Time) error { return nil } + +func (a *resolverFuncConn) Read(p []byte) (n int, err error) { + return a.rbuf.Read(p) +} + +func (a *resolverFuncConn) Write(packet []byte) (n int, err error) { + if len(packet) < 2 { + return 0, fmt.Errorf("short write of %d bytes; want 2+", len(packet)) + } + reqLen := int(packet[0])<<8 | int(packet[1]) + req := packet[2:] + if len(req) != reqLen { + return 0, fmt.Errorf("packet declared length %d doesn't match body length %d", reqLen, len(req)) + } + + var parser dnsmessage.Parser + h, err := parser.Start(req) + if err != nil { + // TODO: hook + return 0, err + } + q, err := parser.Question() + hadQ := (err == nil) + if err == nil && a.h.Question != nil { + a.h.Question(h, q) + } + if err != nil && err != dnsmessage.ErrSectionDone { + return 0, err + } + + resh := h + resh.Response = true + resh.Authoritative = true + if hadQ { + resh.RCode = dnsmessage.RCodeSuccess + } else { + resh.RCode = dnsmessage.RCodeNotImplemented + } + a.rbuf.Grow(514) + a.rbuf.WriteByte('X') // reserved header for beu16 length + a.rbuf.WriteByte('Y') // reserved header for beu16 length + builder := dnsmessage.NewBuilder(a.rbuf.Bytes(), resh) + a.builder = &builder + if hadQ { + a.q = q + a.builder.StartQuestions() + err := a.builder.Question(q) + if err != nil { + return 0, fmt.Errorf("Question: %w", err) + } + a.builder.StartAnswers() + switch q.Type { + case dnsmessage.TypeA: + if a.h.HandleA != nil { + resh.RCode = mapRCode(a.h.HandleA(AWriter{ResponseWriter{a}}, q.Name.String())) + } + case dnsmessage.TypeAAAA: + if a.h.HandleAAAA != nil { + resh.RCode = mapRCode(a.h.HandleAAAA(AAAAWriter{ResponseWriter{a}}, q.Name.String())) + } + case dnsmessage.TypeSRV: + if a.h.HandleSRV != nil { + resh.RCode = mapRCode(a.h.HandleSRV(SRVWriter{ResponseWriter{a}}, q.Name.String())) + } + } + } + tcpRes, err := builder.Finish() + if err != nil { + return 0, fmt.Errorf("Finish: %w", err) + } + + n = len(tcpRes) - 2 + tcpRes[0] = byte(n >> 8) + tcpRes[1] = byte(n) + a.rbuf.Write(tcpRes[2:]) + + return len(packet), nil +} + +type someaddr struct{} + +func (someaddr) Network() string { return "unused" } +func (someaddr) String() string { return "unused-someaddr" } + +func mapRCode(err error) dnsmessage.RCode { + switch err { + case nil: + return dnsmessage.RCodeSuccess + case ErrNotExist: + return dnsmessage.RCodeNameError + case ErrRefused: + return dnsmessage.RCodeRefused + default: + return dnsmessage.RCodeServerFailure + } +}