From 7f588935eac0bcd2c7165d7d484c0a02a0c99df6 Mon Sep 17 00:00:00 2001 From: Skyxim Date: Sun, 12 Mar 2023 15:00:59 +0800 Subject: [PATCH] feta: add hosts support domain and mulitple ip (#439) * feat: host support domain and multiple ips * chore: append local address via `clash` * chore: update hosts demo * chore: unified parse mixed string and array * fix: flatten cname * chore: adjust logic * chore: reuse code * chore: use cname in tunnel * chore: try use domain mapping when normal dns * chore: format code --- common/utils/slice.go | 34 ++++++++++ component/resolver/host.go | 112 +++++++++++++++++++++++++++++++++ component/resolver/resolver.go | 23 ++++--- config/config.go | 74 +++++++++++++--------- dns/middleware.go | 79 +++++++++++++++-------- dns/resolver.go | 4 +- dns/util.go | 2 +- docs/config.yaml | 3 + hub/executor/executor.go | 4 +- hub/hub.go | 2 +- tunnel/tunnel.go | 17 +++-- 11 files changed, 278 insertions(+), 76 deletions(-) create mode 100644 common/utils/slice.go create mode 100644 component/resolver/host.go diff --git a/common/utils/slice.go b/common/utils/slice.go new file mode 100644 index 00000000..1b0fa494 --- /dev/null +++ b/common/utils/slice.go @@ -0,0 +1,34 @@ +package utils + +import ( + "errors" + "fmt" + "reflect" +) + +func Filter[T comparable](tSlice []T, filter func(t T) bool) []T { + result := make([]T, 0) + for _, t := range tSlice { + if filter(t) { + result = append(result, t) + } + } + return result +} + +func ToStringSlice(value any) ([]string, error) { + strArr := make([]string, 0) + switch reflect.TypeOf(value).Kind() { + case reflect.Slice, reflect.Array: + origin := reflect.ValueOf(value) + for i := 0; i < origin.Len(); i++ { + item := fmt.Sprintf("%v", origin.Index(i)) + strArr = append(strArr, item) + } + case reflect.String: + strArr = append(strArr, fmt.Sprintf("%v", value)) + default: + return nil, errors.New("value format error, must be string or array") + } + return strArr, nil +} diff --git a/component/resolver/host.go b/component/resolver/host.go new file mode 100644 index 00000000..ca90cd27 --- /dev/null +++ b/component/resolver/host.go @@ -0,0 +1,112 @@ +package resolver + +import ( + "errors" + "math/rand" + "net/netip" + "strings" + + "github.com/Dreamacro/clash/common/utils" + "github.com/Dreamacro/clash/component/trie" +) + +type Hosts struct { + *trie.DomainTrie[HostValue] +} + +func NewHosts(hosts *trie.DomainTrie[HostValue]) Hosts { + return Hosts{ + hosts, + } +} + +func (h *Hosts) Search(domain string, isDomain bool) (*HostValue, bool) { + value := h.DomainTrie.Search(domain) + if value == nil { + return nil, false + } + hostValue := value.Data() + for { + if isDomain && hostValue.IsDomain { + return &hostValue, true + } else { + if node := h.DomainTrie.Search(hostValue.Domain); node != nil { + hostValue = node.Data() + } else { + break + } + } + } + if isDomain == hostValue.IsDomain { + return &hostValue, true + } + return &hostValue, false +} + +type HostValue struct { + IsDomain bool + IPs []netip.Addr + Domain string +} + +func NewHostValue(value any) (HostValue, error) { + isDomain := true + ips := make([]netip.Addr, 0) + domain := "" + if valueArr, err := utils.ToStringSlice(value); err != nil { + return HostValue{}, err + } else { + if len(valueArr) > 1 { + isDomain = false + for _, str := range valueArr { + if ip, err := netip.ParseAddr(str); err == nil { + ips = append(ips, ip) + } else { + return HostValue{}, err + } + } + } else if len(valueArr) == 1 { + host := valueArr[0] + if ip, err := netip.ParseAddr(host); err == nil { + ips = append(ips, ip) + isDomain = false + } else { + domain = host + } + } + } + if isDomain { + return NewHostValueByDomain(domain) + } else { + return NewHostValueByIPs(ips) + } +} + +func NewHostValueByIPs(ips []netip.Addr) (HostValue, error) { + if len(ips) == 0 { + return HostValue{}, errors.New("ip list is empty") + } + return HostValue{ + IsDomain: false, + IPs: ips, + }, nil +} + +func NewHostValueByDomain(domain string) (HostValue, error) { + domain = strings.Trim(domain, ".") + item := strings.Split(domain, ".") + if len(item) < 2 { + return HostValue{}, errors.New("invaild domain") + } + return HostValue{ + IsDomain: true, + Domain: domain, + }, nil +} + +func (hv HostValue) RandIP() (netip.Addr, error) { + if hv.IsDomain { + return netip.Addr{}, errors.New("value type is error") + } + return hv.IPs[rand.Intn(len(hv.IPs)-1)], nil +} diff --git a/component/resolver/resolver.go b/component/resolver/resolver.go index 6ae2d7c2..f5872ad7 100644 --- a/component/resolver/resolver.go +++ b/component/resolver/resolver.go @@ -9,6 +9,7 @@ import ( "strings" "time" + "github.com/Dreamacro/clash/common/utils" "github.com/Dreamacro/clash/component/trie" "github.com/miekg/dns" @@ -27,7 +28,7 @@ var ( DisableIPv6 = true // DefaultHosts aim to resolve hosts - DefaultHosts = trie.New[netip.Addr]() + DefaultHosts = NewHosts(trie.New[HostValue]()) // DefaultDNSTimeout defined the default dns request timeout DefaultDNSTimeout = time.Second * 5 @@ -51,9 +52,11 @@ type Resolver interface { // LookupIPv4WithResolver same as LookupIPv4, but with a resolver func LookupIPv4WithResolver(ctx context.Context, host string, r Resolver) ([]netip.Addr, error) { - if node := DefaultHosts.Search(host); node != nil { - if ip := node.Data(); ip.Is4() { - return []netip.Addr{node.Data()}, nil + if node, ok := DefaultHosts.Search(host, false); ok { + if addrs := utils.Filter(node.IPs, func(ip netip.Addr) bool { + return ip.Is4() + }); len(addrs) > 0 { + return addrs, nil } } @@ -106,9 +109,11 @@ func LookupIPv6WithResolver(ctx context.Context, host string, r Resolver) ([]net return nil, ErrIPv6Disabled } - if node := DefaultHosts.Search(host); node != nil { - if ip := node.Data(); ip.Is6() { - return []netip.Addr{ip}, nil + if node, ok := DefaultHosts.Search(host, false); ok { + if addrs := utils.Filter(node.IPs, func(ip netip.Addr) bool { + return ip.Is6() + }); len(addrs) > 0 { + return addrs, nil } } @@ -155,8 +160,8 @@ func ResolveIPv6(ctx context.Context, host string) (netip.Addr, error) { // LookupIPWithResolver same as LookupIP, but with a resolver func LookupIPWithResolver(ctx context.Context, host string, r Resolver) ([]netip.Addr, error) { - if node := DefaultHosts.Search(host); node != nil { - return []netip.Addr{node.Data()}, nil + if node, ok := DefaultHosts.Search(host, false); ok { + return node.IPs, nil } if r != nil { diff --git a/config/config.go b/config/config.go index 817d0d64..45dd31f5 100644 --- a/config/config.go +++ b/config/config.go @@ -9,7 +9,6 @@ import ( "net/netip" "net/url" "os" - "reflect" "runtime" "strconv" "strings" @@ -26,6 +25,7 @@ import ( "github.com/Dreamacro/clash/component/geodata" "github.com/Dreamacro/clash/component/geodata/router" P "github.com/Dreamacro/clash/component/process" + "github.com/Dreamacro/clash/component/resolver" SNIFF "github.com/Dreamacro/clash/component/sniffer" tlsC "github.com/Dreamacro/clash/component/tls" "github.com/Dreamacro/clash/component/trie" @@ -100,7 +100,7 @@ type DNS struct { EnhancedMode C.DNSMode `yaml:"enhanced-mode"` DefaultNameserver []dns.NameServer `yaml:"default-nameserver"` FakeIPRange *fakeip.Pool - Hosts *trie.DomainTrie[netip.Addr] + Hosts *trie.DomainTrie[resolver.HostValue] NameServerPolicy map[string][]dns.NameServer ProxyServerNameserver []dns.NameServer } @@ -154,7 +154,7 @@ type Config struct { IPTables *IPTables DNS *DNS Experimental *Experimental - Hosts *trie.DomainTrie[netip.Addr] + Hosts *trie.DomainTrie[resolver.HostValue] Profile *Profile Rules []C.Rule SubRules map[string][]C.Rule @@ -265,7 +265,7 @@ type RawConfig struct { Sniffer RawSniffer `yaml:"sniffer"` ProxyProvider map[string]map[string]any `yaml:"proxy-providers"` RuleProvider map[string]map[string]any `yaml:"rule-providers"` - Hosts map[string]string `yaml:"hosts"` + Hosts map[string]any `yaml:"hosts"` DNS RawDNS `yaml:"dns"` Tun RawTun `yaml:"tun"` TuicServer RawTuicServer `yaml:"tuic-server"` @@ -339,7 +339,7 @@ func UnmarshalRawConfig(buf []byte) (*RawConfig, error) { UnifiedDelay: false, Authentication: []string{}, LogLevel: log.INFO, - Hosts: map[string]string{}, + Hosts: map[string]any{}, Rule: []string{}, Proxy: []map[string]any{}, ProxyGroup: []map[string]any{}, @@ -827,21 +827,47 @@ func parseRules(rulesConfig []string, proxies map[string]C.Proxy, subRules map[s return rules, nil } -func parseHosts(cfg *RawConfig) (*trie.DomainTrie[netip.Addr], error) { - tree := trie.New[netip.Addr]() +func parseHosts(cfg *RawConfig) (*trie.DomainTrie[resolver.HostValue], error) { + tree := trie.New[resolver.HostValue]() // add default hosts - if err := tree.Insert("localhost", netip.AddrFrom4([4]byte{127, 0, 0, 1})); err != nil { + hostValue, _ := resolver.NewHostValueByIPs( + []netip.Addr{netip.AddrFrom4([4]byte{127, 0, 0, 1})}) + if err := tree.Insert("localhost", hostValue); err != nil { log.Errorln("insert localhost to host error: %s", err.Error()) } if len(cfg.Hosts) != 0 { - for domain, ipStr := range cfg.Hosts { - ip, err := netip.ParseAddr(ipStr) - if err != nil { - return nil, fmt.Errorf("%s is not a valid IP", ipStr) + for domain, anyValue := range cfg.Hosts { + if str, ok := anyValue.(string); ok && str == "clash" { + if addrs, err := net.InterfaceAddrs(); err != nil { + log.Errorln("insert clash to host error: %s", err) + } else { + ips := make([]netip.Addr, 0) + for _, addr := range addrs { + if ipnet, ok := addr.(*net.IPNet); ok { + if ip, err := netip.ParseAddr(ipnet.IP.String()); err == nil { + ips = append(ips, ip) + } + } + } + anyValue = ips + } } - _ = tree.Insert(domain, ip) + value, err := resolver.NewHostValue(anyValue) + if err != nil { + return nil, fmt.Errorf("%s is not a valid value", anyValue) + } + if value.IsDomain { + node := tree.Search(value.Domain) + for node != nil && node.Data().IsDomain { + if node.Data().Domain == domain { + return nil, fmt.Errorf("%s, there is a cycle in domain name mapping", domain) + } + node = tree.Search(node.Data().Domain) + } + } + _ = tree.Insert(domain, value) } } tree.Optimize() @@ -961,24 +987,12 @@ func parseNameServerPolicy(nsPolicy map[string]any, preferH3 bool) (map[string][ policy := map[string][]dns.NameServer{} for domain, server := range nsPolicy { - var ( - nameservers []dns.NameServer - err error - ) - switch reflect.TypeOf(server).Kind() { - case reflect.Slice, reflect.Array: - origin := reflect.ValueOf(server) - servers := make([]string, 0) - for i := 0; i < origin.Len(); i++ { - servers = append(servers, fmt.Sprintf("%v", origin.Index(i))) - } - nameservers, err = parseNameServer(servers, preferH3) - case reflect.String: - nameservers, err = parseNameServer([]string{fmt.Sprintf("%v", server)}, preferH3) - default: - return nil, errors.New("server format error, must be string or array") + servers, err := utils.ToStringSlice(server) + if err != nil { + return nil, err } + nameservers, err := parseNameServer(servers, preferH3) if err != nil { return nil, err } @@ -1041,7 +1055,7 @@ func parseFallbackGeoSite(countries []string, rules []C.Rule) ([]*router.DomainM return sites, nil } -func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[netip.Addr], rules []C.Rule) (*DNS, error) { +func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[resolver.HostValue], rules []C.Rule) (*DNS, error) { cfg := rawCfg.DNS if cfg.Enable && len(cfg.NameServer) == 0 { return nil, fmt.Errorf("if DNS configuration is turned on, NameServer cannot be empty") diff --git a/dns/middleware.go b/dns/middleware.go index 7dc9622d..f2dd9c96 100644 --- a/dns/middleware.go +++ b/dns/middleware.go @@ -8,7 +8,7 @@ import ( "github.com/Dreamacro/clash/common/cache" "github.com/Dreamacro/clash/common/nnip" "github.com/Dreamacro/clash/component/fakeip" - "github.com/Dreamacro/clash/component/trie" + R "github.com/Dreamacro/clash/component/resolver" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/context" "github.com/Dreamacro/clash/log" @@ -21,7 +21,7 @@ type ( middleware func(next handler) handler ) -func withHosts(hosts *trie.DomainTrie[netip.Addr], mapping *cache.LruCache[netip.Addr, string]) middleware { +func withHosts(hosts R.Hosts, mapping *cache.LruCache[netip.Addr, string]) middleware { return func(next handler) handler { return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) { q := r.Question[0] @@ -31,40 +31,68 @@ func withHosts(hosts *trie.DomainTrie[netip.Addr], mapping *cache.LruCache[netip } host := strings.TrimRight(q.Name, ".") - - record := hosts.Search(host) - if record == nil { + handleCName := func(resp *D.Msg, domain string) { + rr := &D.CNAME{} + rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeCNAME, Class: D.ClassINET, Ttl: 10} + rr.Target = domain + "." + resp.Answer = append([]D.RR{rr}, resp.Answer...) + } + record, ok := hosts.Search(host, q.Qtype != D.TypeA && q.Qtype != D.TypeAAAA) + if !ok { + if record != nil && record.IsDomain { + // replace request domain + newR := r.Copy() + newR.Question[0].Name = record.Domain + "." + resp, err := next(ctx, newR) + if err == nil { + resp.Id = r.Id + resp.Question = r.Question + handleCName(resp, record.Domain) + } + return resp, err + } return next(ctx, r) } - ip := record.Data() msg := r.Copy() - - if ip.Is4() && q.Qtype == D.TypeA { - rr := &D.A{} - rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: 10} - rr.A = ip.AsSlice() - - msg.Answer = []D.RR{rr} - } else if q.Qtype == D.TypeAAAA { - rr := &D.AAAA{} - rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeAAAA, Class: D.ClassINET, Ttl: 10} - ip := ip.As16() - rr.AAAA = ip[:] - msg.Answer = []D.RR{rr} - } else { - return next(ctx, r) + handleIPs := func() { + for _, ipAddr := range record.IPs { + if ipAddr.Is4() && q.Qtype == D.TypeA { + rr := &D.A{} + rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: 10} + rr.A = ipAddr.AsSlice() + msg.Answer = append(msg.Answer, rr) + if mapping != nil { + mapping.SetWithExpire(ipAddr, host, time.Now().Add(time.Second*10)) + } + } else if q.Qtype == D.TypeAAAA { + rr := &D.AAAA{} + rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeAAAA, Class: D.ClassINET, Ttl: 10} + ip := ipAddr.As16() + rr.AAAA = ip[:] + msg.Answer = append(msg.Answer, rr) + if mapping != nil { + mapping.SetWithExpire(ipAddr, host, time.Now().Add(time.Second*10)) + } + } + } } - if mapping != nil { - mapping.SetWithExpire(ip, host, time.Now().Add(time.Second*10)) + switch q.Qtype { + case D.TypeA: + handleIPs() + case D.TypeAAAA: + handleIPs() + case D.TypeCNAME: + handleCName(r, record.Domain) + default: + return next(ctx, r) } ctx.SetType(context.DNSTypeHost) msg.SetRcode(r, D.RcodeSuccess) msg.Authoritative = true msg.RecursionAvailable = true - return msg, nil } } @@ -149,6 +177,7 @@ func withFakeIP(fakePool *fakeip.Pool) middleware { func withResolver(resolver *Resolver) handler { return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) { ctx.SetType(context.DNSTypeRaw) + q := r.Question[0] // return a empty AAAA msg when ipv6 disabled @@ -183,7 +212,7 @@ func NewHandler(resolver *Resolver, mapper *ResolverEnhancer) handler { middlewares := []middleware{} if resolver.hosts != nil { - middlewares = append(middlewares, withHosts(resolver.hosts, mapper.mapping)) + middlewares = append(middlewares, withHosts(R.NewHosts(resolver.hosts), mapper.mapping)) } if mapper.mode == C.DNSFakeIP { diff --git a/dns/resolver.go b/dns/resolver.go index 69725870..57f581a5 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -43,7 +43,7 @@ type geositePolicyRecord struct { type Resolver struct { ipv6 bool ipv6Timeout time.Duration - hosts *trie.DomainTrie[netip.Addr] + hosts *trie.DomainTrie[resolver.HostValue] main []dnsClient fallback []dnsClient fallbackDomainFilters []fallbackDomainFilter @@ -430,7 +430,7 @@ type Config struct { EnhancedMode C.DNSMode FallbackFilter FallbackFilter Pool *fakeip.Pool - Hosts *trie.DomainTrie[netip.Addr] + Hosts *trie.DomainTrie[resolver.HostValue] Policy map[string][]NameServer } diff --git a/dns/util.go b/dns/util.go index 203ab615..4821195d 100644 --- a/dns/util.go +++ b/dns/util.go @@ -66,7 +66,7 @@ func setMsgTTL(msg *D.Msg, ttl uint32) { } func isIPRequest(q D.Question) bool { - return q.Qclass == D.ClassINET && (q.Qtype == D.TypeA || q.Qtype == D.TypeAAAA) + return q.Qclass == D.ClassINET && (q.Qtype == D.TypeA || q.Qtype == D.TypeAAAA || q.Qtype == D.TypeCNAME) } func transform(servers []NameServer, resolver *Resolver) []dnsClient { diff --git a/docs/config.yaml b/docs/config.yaml index b5af798b..b52cc63f 100644 --- a/docs/config.yaml +++ b/docs/config.yaml @@ -58,6 +58,9 @@ hosts: # '*.clash.dev': 127.0.0.1 # '.dev': 127.0.0.1 # 'alpha.clash.dev': '::1' +# test.com: [1.1.1.1, 2.2.2.2] +# clash.lan: clash # clash 为特别字段,将加入本地所有网卡的地址 +# baidu.com: google.com # 只允许配置一个别名 profile: # 存储 select 选择记录 store-selected: false diff --git a/hub/executor/executor.go b/hub/executor/executor.go index 1bd22385..21a25ecd 100644 --- a/hub/executor/executor.go +++ b/hub/executor/executor.go @@ -226,8 +226,8 @@ func updateDNS(c *config.DNS, generalIPv6 bool) { dns.ReCreateServer(c.Listen, r, m) } -func updateHosts(tree *trie.DomainTrie[netip.Addr]) { - resolver.DefaultHosts = tree +func updateHosts(tree *trie.DomainTrie[resolver.HostValue]) { + resolver.DefaultHosts = resolver.NewHosts(tree) } func updateProxies(proxies map[string]C.Proxy, providers map[string]provider.ProxyProvider) { diff --git a/hub/hub.go b/hub/hub.go index e4c414fa..bd228fad 100644 --- a/hub/hub.go +++ b/hub/hub.go @@ -44,7 +44,7 @@ func Parse(options ...Option) error { if cfg.General.ExternalController != "" { go route.Start(cfg.General.ExternalController, cfg.General.ExternalControllerTLS, - cfg.General.Secret, cfg.TLS.Certificate, cfg.TLS.PrivateKey,cfg.General.LogLevel==log.DEBUG) + cfg.General.Secret, cfg.TLS.Certificate, cfg.TLS.PrivateKey, cfg.General.LogLevel == log.DEBUG) } executor.ApplyConfig(cfg, true) diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index d80893c9..b686eae6 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -201,13 +201,18 @@ func preHandleMetadata(metadata *C.Metadata) error { if resolver.FakeIPEnabled() { metadata.DstIP = netip.Addr{} metadata.DNSMode = C.DNSFakeIP - } else if node := resolver.DefaultHosts.Search(host); node != nil { + } else if node, ok := resolver.DefaultHosts.Search(host, false); ok { // redir-host should lookup the hosts - metadata.DstIP = node.Data() + metadata.DstIP, _ = node.RandIP() + } else if node != nil && node.IsDomain { + metadata.Host = node.Domain } } else if resolver.IsFakeIP(metadata.DstIP) { return fmt.Errorf("fake DNS record %s missing", metadata.DstIP) } + } else if node, ok := resolver.DefaultHosts.Search(metadata.Host, true); ok { + // try use domain mapping + metadata.Host = node.Domain } return nil @@ -392,8 +397,8 @@ func handleTCPConn(connCtx C.ConnContext) { dialMetadata := metadata if len(metadata.Host) > 0 { - if node := resolver.DefaultHosts.Search(metadata.Host); node != nil { - if dstIp := node.Data(); !FakeIPRange().Contains(dstIp) { + if node, ok := resolver.DefaultHosts.Search(metadata.Host, false); ok { + if dstIp, _ := node.RandIP(); !FakeIPRange().Contains(dstIp) { dialMetadata.DstIP = dstIp dialMetadata.DNSMode = C.DNSHosts dialMetadata = dialMetadata.Pure() @@ -498,8 +503,8 @@ func match(metadata *C.Metadata) (C.Proxy, C.Rule, error) { processFound bool ) - if node := resolver.DefaultHosts.Search(metadata.Host); node != nil { - metadata.DstIP = node.Data() + if node, ok := resolver.DefaultHosts.Search(metadata.Host, false); ok { + metadata.DstIP, _ = node.RandIP() resolved = true }