Feature: dns resolve domain through nameserver-policy (#1406)

This commit is contained in:
gVisor bot 2021-05-19 11:17:35 +08:00
parent 1a435b79e6
commit a409e7f2aa
4 changed files with 62 additions and 3 deletions

View File

@ -23,7 +23,7 @@ type DomainTrie struct {
root *Node root *Node
} }
func validAndSplitDomain(domain string) ([]string, bool) { func ValidAndSplitDomain(domain string) ([]string, bool) {
if domain != "" && domain[len(domain)-1] == '.' { if domain != "" && domain[len(domain)-1] == '.' {
return nil, false return nil, false
} }
@ -54,7 +54,7 @@ func validAndSplitDomain(domain string) ([]string, bool) {
// 4. .example.com // 4. .example.com
// 5. +.example.com // 5. +.example.com
func (t *DomainTrie) Insert(domain string, data interface{}) error { func (t *DomainTrie) Insert(domain string, data interface{}) error {
parts, valid := validAndSplitDomain(domain) parts, valid := ValidAndSplitDomain(domain)
if !valid { if !valid {
return ErrInvalidDomain return ErrInvalidDomain
} }
@ -91,7 +91,7 @@ func (t *DomainTrie) insert(parts []string, data interface{}) {
// 2. wildcard domain // 2. wildcard domain
// 2. dot wildcard domain // 2. dot wildcard domain
func (t *DomainTrie) Search(domain string) *Node { func (t *DomainTrie) Search(domain string) *Node {
parts, valid := validAndSplitDomain(domain) parts, valid := ValidAndSplitDomain(domain)
if !valid || parts[0] == "" { if !valid || parts[0] == "" {
return nil return nil
} }

View File

@ -64,6 +64,7 @@ type DNS struct {
DefaultNameserver []dns.NameServer `yaml:"default-nameserver"` DefaultNameserver []dns.NameServer `yaml:"default-nameserver"`
FakeIPRange *fakeip.Pool FakeIPRange *fakeip.Pool
Hosts *trie.DomainTrie Hosts *trie.DomainTrie
NameServerPolicy map[string]dns.NameServer
} }
// FallbackFilter config // FallbackFilter config
@ -106,6 +107,7 @@ type RawDNS struct {
FakeIPRange string `yaml:"fake-ip-range"` FakeIPRange string `yaml:"fake-ip-range"`
FakeIPFilter []string `yaml:"fake-ip-filter"` FakeIPFilter []string `yaml:"fake-ip-filter"`
DefaultNameserver []string `yaml:"default-nameserver"` DefaultNameserver []string `yaml:"default-nameserver"`
NameServerPolicy map[string]string `yaml:"nameserver-policy"`
} }
type RawFallbackFilter struct { type RawFallbackFilter struct {
@ -500,6 +502,23 @@ func parseNameServer(servers []string) ([]dns.NameServer, error) {
return nameservers, nil return nameservers, nil
} }
func parseNameServerPolicy(nsPolicy map[string]string) (map[string]dns.NameServer, error) {
policy := map[string]dns.NameServer{}
for domain, server := range nsPolicy {
nameservers, err := parseNameServer([]string{server})
if err != nil {
return nil, err
}
if _, valid := trie.ValidAndSplitDomain(domain); !valid {
return nil, fmt.Errorf("DNS ResoverRule invalid domain: %s", domain)
}
policy[domain] = nameservers[0]
}
return policy, nil
}
func parseFallbackIPCIDR(ips []string) ([]*net.IPNet, error) { func parseFallbackIPCIDR(ips []string) ([]*net.IPNet, error) {
ipNets := []*net.IPNet{} ipNets := []*net.IPNet{}
@ -537,6 +556,10 @@ func parseDNS(cfg RawDNS, hosts *trie.DomainTrie) (*DNS, error) {
return nil, err return nil, err
} }
if dnsCfg.NameServerPolicy, err = parseNameServerPolicy(cfg.NameServerPolicy); err != nil {
return nil, err
}
if len(cfg.DefaultNameserver) == 0 { if len(cfg.DefaultNameserver) == 0 {
return nil, errors.New("default nameserver should have at least one nameserver") return nil, errors.New("default nameserver should have at least one nameserver")
} }

View File

@ -43,6 +43,7 @@ type Resolver struct {
fallbackIPFilters []fallbackIPFilter fallbackIPFilters []fallbackIPFilter
group singleflight.Group group singleflight.Group
lruCache *cache.LruCache lruCache *cache.LruCache
policy *trie.DomainTrie
} }
// ResolveIP request with TypeA and TypeAAAA, priority return TypeA // ResolveIP request with TypeA and TypeAAAA, priority return TypeA
@ -131,6 +132,9 @@ func (r *Resolver) exchangeWithoutCache(m *D.Msg) (msg *D.Msg, err error) {
return r.ipExchange(m) return r.ipExchange(m)
} }
if matched := r.matchPolicy(m); len(matched) != 0 {
return r.batchExchange(matched, m)
}
return r.batchExchange(r.main, m) return r.batchExchange(r.main, m)
}) })
@ -172,6 +176,24 @@ func (r *Resolver) batchExchange(clients []dnsClient, m *D.Msg) (msg *D.Msg, err
return return
} }
func (r *Resolver) matchPolicy(m *D.Msg) []dnsClient {
if r.policy == nil {
return nil
}
domain := r.msgToDomain(m)
if domain == "" {
return nil
}
record := r.policy.Search(domain)
if record == nil {
return nil
}
return record.Data.([]dnsClient)
}
func (r *Resolver) shouldOnlyQueryFallback(m *D.Msg) bool { func (r *Resolver) shouldOnlyQueryFallback(m *D.Msg) bool {
if r.fallback == nil || len(r.fallbackDomainFilters) == 0 { if r.fallback == nil || len(r.fallbackDomainFilters) == 0 {
return false return false
@ -194,6 +216,11 @@ func (r *Resolver) shouldOnlyQueryFallback(m *D.Msg) bool {
func (r *Resolver) ipExchange(m *D.Msg) (msg *D.Msg, err error) { func (r *Resolver) ipExchange(m *D.Msg) (msg *D.Msg, err error) {
if matched := r.matchPolicy(m); len(matched) != 0 {
res := <-r.asyncExchange(matched, m)
return res.Msg, res.Error
}
onlyFallback := r.shouldOnlyQueryFallback(m) onlyFallback := r.shouldOnlyQueryFallback(m)
if onlyFallback { if onlyFallback {
@ -293,6 +320,7 @@ type Config struct {
FallbackFilter FallbackFilter FallbackFilter FallbackFilter
Pool *fakeip.Pool Pool *fakeip.Pool
Hosts *trie.DomainTrie Hosts *trie.DomainTrie
Policy map[string]NameServer
} }
func NewResolver(config Config) *Resolver { func NewResolver(config Config) *Resolver {
@ -312,6 +340,13 @@ func NewResolver(config Config) *Resolver {
r.fallback = transform(config.Fallback, defaultResolver) r.fallback = transform(config.Fallback, defaultResolver)
} }
if len(config.Policy) != 0 {
r.policy = trie.New()
for domain, nameserver := range config.Policy {
r.policy.Insert(domain, transform([]NameServer{nameserver}, defaultResolver))
}
}
fallbackIPFilters := []fallbackIPFilter{} fallbackIPFilters := []fallbackIPFilter{}
if config.FallbackFilter.GeoIP { if config.FallbackFilter.GeoIP {
fallbackIPFilters = append(fallbackIPFilters, &geoipFilter{}) fallbackIPFilters = append(fallbackIPFilters, &geoipFilter{})

View File

@ -128,6 +128,7 @@ func updateDNS(c *config.DNS) {
Domain: c.FallbackFilter.Domain, Domain: c.FallbackFilter.Domain,
}, },
Default: c.DefaultNameserver, Default: c.DefaultNameserver,
Policy: c.NameServerPolicy,
} }
r := dns.NewResolver(cfg) r := dns.NewResolver(cfg)