Skip to content

Commit

Permalink
feat(dns): Support DNS queryStrategy config per NameServer.
Browse files Browse the repository at this point in the history
  • Loading branch information
cty123 authored and yuhan6665 committed Sep 22, 2023
1 parent cf575be commit 4f6042c
Show file tree
Hide file tree
Showing 10 changed files with 379 additions and 168 deletions.
183 changes: 98 additions & 85 deletions app/dns/config.pb.go

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions app/dns/config.proto
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ message NameServer {
repeated PriorityDomain prioritized_domain = 2;
repeated xray.app.router.GeoIP geoip = 3;
repeated OriginalRule original_rules = 4;
QueryStrategy query_strategy = 7;
}

enum DomainMatchingType {
Expand Down
46 changes: 37 additions & 9 deletions app/dns/nameserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ type Client struct {
var errExpectedIPNonMatch = errors.New("expectIPs not match")

// NewServer creates a name server object according to the network destination url.
func NewServer(dest net.Destination, dispatcher routing.Dispatcher) (Server, error) {
func NewServer(dest net.Destination, dispatcher routing.Dispatcher, queryStrategy QueryStrategy) (Server, error) {
if address := dest.Address; address.Family().IsDomain() {
u, err := url.Parse(address.Domain())
if err != nil {
Expand All @@ -45,15 +45,15 @@ func NewServer(dest net.Destination, dispatcher routing.Dispatcher) (Server, err
case strings.EqualFold(u.String(), "localhost"):
return NewLocalNameServer(), nil
case strings.EqualFold(u.Scheme, "https"): // DOH Remote mode
return NewDoHNameServer(u, dispatcher)
return NewDoHNameServer(u, dispatcher, queryStrategy)
case strings.EqualFold(u.Scheme, "https+local"): // DOH Local mode
return NewDoHLocalNameServer(u), nil
return NewDoHLocalNameServer(u, queryStrategy), nil
case strings.EqualFold(u.Scheme, "quic+local"): // DNS-over-QUIC Local mode
return NewQUICNameServer(u)
return NewQUICNameServer(u, queryStrategy)
case strings.EqualFold(u.Scheme, "tcp"): // DNS-over-TCP Remote mode
return NewTCPNameServer(u, dispatcher)
return NewTCPNameServer(u, dispatcher, queryStrategy)
case strings.EqualFold(u.Scheme, "tcp+local"): // DNS-over-TCP Local mode
return NewTCPLocalNameServer(u)
return NewTCPLocalNameServer(u, queryStrategy)
case strings.EqualFold(u.String(), "fakedns"):
return NewFakeDNSServer(), nil
}
Expand All @@ -68,12 +68,19 @@ func NewServer(dest net.Destination, dispatcher routing.Dispatcher) (Server, err
}

// NewClient creates a DNS client managing a name server with client IP, domain rules and expected IPs.
func NewClient(ctx context.Context, ns *NameServer, clientIP net.IP, container router.GeoIPMatcherContainer, matcherInfos *[]*DomainMatcherInfo, updateDomainRule func(strmatcher.Matcher, int, []*DomainMatcherInfo) error) (*Client, error) {
func NewClient(
ctx context.Context,
ns *NameServer,
clientIP net.IP,
container router.GeoIPMatcherContainer,
matcherInfos *[]*DomainMatcherInfo,
updateDomainRule func(strmatcher.Matcher, int, []*DomainMatcherInfo) error,
) (*Client, error) {
client := &Client{}

err := core.RequireFeatures(ctx, func(dispatcher routing.Dispatcher) error {
// Create a new server for each client for now
server, err := NewServer(ns.Address.AsDestination(), dispatcher)
server, err := NewServer(ns.Address.AsDestination(), dispatcher, ns.GetQueryStrategy())
if err != nil {
return newError("failed to create nameserver").Base(err).AtWarning()
}
Expand Down Expand Up @@ -160,7 +167,7 @@ func NewClient(ctx context.Context, ns *NameServer, clientIP net.IP, container r
func NewSimpleClient(ctx context.Context, endpoint *net.Endpoint, clientIP net.IP) (*Client, error) {
client := &Client{}
err := core.RequireFeatures(ctx, func(dispatcher routing.Dispatcher) error {
server, err := NewServer(endpoint.AsDestination(), dispatcher)
server, err := NewServer(endpoint.AsDestination(), dispatcher, QueryStrategy_USE_IP)
if err != nil {
return newError("failed to create nameserver").Base(err).AtWarning()
}
Expand Down Expand Up @@ -218,3 +225,24 @@ func (c *Client) MatchExpectedIPs(domain string, ips []net.IP) ([]net.IP, error)
newError("domain ", domain, " expectIPs ", newIps, " matched at server ", c.Name()).AtDebug().WriteToLog()
return newIps, nil
}

func ResolveIpOptionOverride(queryStrategy QueryStrategy, ipOption dns.IPOption) dns.IPOption {
switch queryStrategy {
case QueryStrategy_USE_IP:
return ipOption
case QueryStrategy_USE_IP4:
return dns.IPOption{
IPv4Enable: ipOption.IPv4Enable,
IPv6Enable: false,
FakeEnable: false,
}
case QueryStrategy_USE_IP6:
return dns.IPOption{
IPv4Enable: false,
IPv6Enable: ipOption.IPv6Enable,
FakeEnable: false,
}
default:
return ipOption
}
}
38 changes: 22 additions & 16 deletions app/dns/nameserver_doh.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,20 @@ import (
type DoHNameServer struct {
dispatcher routing.Dispatcher
sync.RWMutex
ips map[string]*record
pub *pubsub.Service
cleanup *task.Periodic
reqID uint32
httpClient *http.Client
dohURL string
name string
ips map[string]*record
pub *pubsub.Service
cleanup *task.Periodic
reqID uint32
httpClient *http.Client
dohURL string
name string
queryStrategy QueryStrategy
}

// NewDoHNameServer creates DOH server object for remote resolving.
func NewDoHNameServer(url *url.URL, dispatcher routing.Dispatcher) (*DoHNameServer, error) {
func NewDoHNameServer(url *url.URL, dispatcher routing.Dispatcher, queryStrategy QueryStrategy) (*DoHNameServer, error) {
newError("DNS: created Remote DOH client for ", url.String()).AtInfo().WriteToLog()
s := baseDOHNameServer(url, "DOH")
s := baseDOHNameServer(url, "DOH", queryStrategy)

s.dispatcher = dispatcher
tr := &http.Transport{
Expand Down Expand Up @@ -90,9 +91,9 @@ func NewDoHNameServer(url *url.URL, dispatcher routing.Dispatcher) (*DoHNameServ
}

// NewDoHLocalNameServer creates DOH client object for local resolving
func NewDoHLocalNameServer(url *url.URL) *DoHNameServer {
func NewDoHLocalNameServer(url *url.URL, queryStrategy QueryStrategy) *DoHNameServer {
url.Scheme = "https"
s := baseDOHNameServer(url, "DOHL")
s := baseDOHNameServer(url, "DOHL", queryStrategy)
tr := &http.Transport{
IdleConnTimeout: 90 * time.Second,
ForceAttemptHTTP2: true,
Expand Down Expand Up @@ -122,12 +123,13 @@ func NewDoHLocalNameServer(url *url.URL) *DoHNameServer {
return s
}

func baseDOHNameServer(url *url.URL, prefix string) *DoHNameServer {
func baseDOHNameServer(url *url.URL, prefix string, queryStrategy QueryStrategy) *DoHNameServer {
s := &DoHNameServer{
ips: make(map[string]*record),
pub: pubsub.NewService(),
name: prefix + "//" + url.Host,
dohURL: url.String(),
ips: make(map[string]*record),
pub: pubsub.NewService(),
name: prefix + "//" + url.Host,
dohURL: url.String(),
queryStrategy: queryStrategy,
}
s.cleanup = &task.Periodic{
Interval: time.Minute,
Expand Down Expand Up @@ -353,6 +355,10 @@ func (s *DoHNameServer) findIPsForDomain(domain string, option dns_feature.IPOpt
// QueryIP implements Server.
func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, disableCache bool) ([]net.IP, error) { // nolint: dupl
fqdn := Fqdn(domain)
option = ResolveIpOptionOverride(s.queryStrategy, option)
if !option.IPv4Enable && !option.IPv6Enable {
return nil, dns_feature.ErrEmptyResponse
}

if disableCache {
newError("DNS cache is disabled. Querying IP for ", domain, " at ", s.name).AtDebug().WriteToLog()
Expand Down
50 changes: 48 additions & 2 deletions app/dns/nameserver_doh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func TestDOHNameServer(t *testing.T) {
url, err := url.Parse("https+local://1.1.1.1/dns-query")
common.Must(err)

s := NewDoHLocalNameServer(url)
s := NewDoHLocalNameServer(url, QueryStrategy_USE_IP)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
ips, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns_feature.IPOption{
IPv4Enable: true,
Expand All @@ -34,7 +34,7 @@ func TestDOHNameServerWithCache(t *testing.T) {
url, err := url.Parse("https+local://1.1.1.1/dns-query")
common.Must(err)

s := NewDoHLocalNameServer(url)
s := NewDoHLocalNameServer(url, QueryStrategy_USE_IP)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
ips, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns_feature.IPOption{
IPv4Enable: true,
Expand All @@ -57,3 +57,49 @@ func TestDOHNameServerWithCache(t *testing.T) {
t.Fatal(r)
}
}

func TestDOHNameServerWithIPv4Override(t *testing.T) {
url, err := url.Parse("https+local://1.1.1.1/dns-query")
common.Must(err)

s := NewDoHLocalNameServer(url, QueryStrategy_USE_IP4)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
ips, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns_feature.IPOption{
IPv4Enable: true,
IPv6Enable: true,
}, false)
cancel()
common.Must(err)
if len(ips) == 0 {
t.Error("expect some ips, but got 0")
}

for _, ip := range ips {
if len(ip) != net.IPv4len {
t.Error("expect only IPv4 response from DNS query")
}
}
}

func TestDOHNameServerWithIPv6Override(t *testing.T) {
url, err := url.Parse("https+local://1.1.1.1/dns-query")
common.Must(err)

s := NewDoHLocalNameServer(url, QueryStrategy_USE_IP6)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
ips, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns_feature.IPOption{
IPv4Enable: true,
IPv6Enable: true,
}, false)
cancel()
common.Must(err)
if len(ips) == 0 {
t.Error("expect some ips, but got 0")
}

for _, ip := range ips {
if len(ip) != net.IPv6len {
t.Error("expect only IPv6 response from DNS query")
}
}
}
30 changes: 18 additions & 12 deletions app/dns/nameserver_quic.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,18 @@ const handshakeTimeout = time.Second * 8
// QUICNameServer implemented DNS over QUIC
type QUICNameServer struct {
sync.RWMutex
ips map[string]*record
pub *pubsub.Service
cleanup *task.Periodic
reqID uint32
name string
destination *net.Destination
connection quic.Connection
ips map[string]*record
pub *pubsub.Service
cleanup *task.Periodic
reqID uint32
name string
destination *net.Destination
connection quic.Connection
queryStrategy QueryStrategy
}

// NewQUICNameServer creates DNS-over-QUIC client object for local resolving
func NewQUICNameServer(url *url.URL) (*QUICNameServer, error) {
func NewQUICNameServer(url *url.URL, queryStrategy QueryStrategy) (*QUICNameServer, error) {
newError("DNS: created Local DNS-over-QUIC client for ", url.String()).AtInfo().WriteToLog()

var err error
Expand All @@ -55,10 +56,11 @@ func NewQUICNameServer(url *url.URL) (*QUICNameServer, error) {
dest := net.UDPDestination(net.ParseAddress(url.Hostname()), port)

s := &QUICNameServer{
ips: make(map[string]*record),
pub: pubsub.NewService(),
name: url.String(),
destination: &dest,
ips: make(map[string]*record),
pub: pubsub.NewService(),
name: url.String(),
destination: &dest,
queryStrategy: queryStrategy,
}
s.cleanup = &task.Periodic{
Interval: time.Minute,
Expand Down Expand Up @@ -269,6 +271,10 @@ func (s *QUICNameServer) findIPsForDomain(domain string, option dns_feature.IPOp
// QueryIP is called from dns.Server->queryIPTimeout
func (s *QUICNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, disableCache bool) ([]net.IP, error) {
fqdn := Fqdn(domain)
option = ResolveIpOptionOverride(s.queryStrategy, option)
if !option.IPv4Enable && !option.IPv6Enable {
return nil, dns_feature.ErrEmptyResponse
}

if disableCache {
newError("DNS cache is disabled. Querying IP for ", domain, " at ", s.name).AtDebug().WriteToLog()
Expand Down
48 changes: 47 additions & 1 deletion app/dns/nameserver_quic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
func TestQUICNameServer(t *testing.T) {
url, err := url.Parse("quic://dns.adguard.com")
common.Must(err)
s, err := NewQUICNameServer(url)
s, err := NewQUICNameServer(url, QueryStrategy_USE_IP)
common.Must(err)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
ips, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns.IPOption{
Expand All @@ -40,3 +40,49 @@ func TestQUICNameServer(t *testing.T) {
t.Fatal(r)
}
}

func TestQUICNameServerWithIPv4Override(t *testing.T) {
url, err := url.Parse("quic://dns.adguard.com")
common.Must(err)
s, err := NewQUICNameServer(url, QueryStrategy_USE_IP4)
common.Must(err)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
ips, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns.IPOption{
IPv4Enable: true,
IPv6Enable: true,
}, false)
cancel()
common.Must(err)
if len(ips) == 0 {
t.Error("expect some ips, but got 0")
}

for _, ip := range ips {
if len(ip) != net.IPv4len {
t.Error("expect only IPv4 response from DNS query")
}
}
}

func TestQUICNameServerWithIPv6Override(t *testing.T) {
url, err := url.Parse("quic://dns.adguard.com")
common.Must(err)
s, err := NewQUICNameServer(url, QueryStrategy_USE_IP6)
common.Must(err)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
ips, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns.IPOption{
IPv4Enable: true,
IPv6Enable: true,
}, false)
cancel()
common.Must(err)
if len(ips) == 0 {
t.Error("expect some ips, but got 0")
}

for _, ip := range ips {
if len(ip) != net.IPv6len {
t.Error("expect only IPv6 response from DNS query")
}
}
}
Loading

0 comments on commit 4f6042c

Please sign in to comment.