feat: host support domain and multiple ips

This commit is contained in:
Skyxim 2023-03-12 10:53:38 +08:00
parent ae4d114802
commit 2f69b64d82
10 changed files with 188 additions and 49 deletions

11
common/utils/slice.go Normal file
View File

@ -0,0 +1,11 @@
package utils
func Filter[T comparable](tSlice []T, filter func(t T) bool) []T {
result := make([]T, 0)
for _, t := range tSlice {
if filter(t) {
result = append(result, t)
}
}
return result
}

View File

@ -0,0 +1,78 @@
package resolver
import (
"errors"
"fmt"
"math/rand"
"net/netip"
"reflect"
"strings"
)
type HostValue struct {
IsDomain bool
IPs []netip.Addr
Domain string
}
func NewHostValue(value any) (HostValue, error) {
isDomain := true
ips := make([]netip.Addr, 0)
domain := ""
switch reflect.TypeOf(value).Kind() {
case reflect.Slice, reflect.Array:
isDomain = false
origin := reflect.ValueOf(value)
for i := 0; i < origin.Len(); i++ {
if ip, err := netip.ParseAddr(fmt.Sprintf("%v", origin.Index(i))); err == nil {
ips = append(ips, ip)
} else {
return HostValue{}, err
}
}
case reflect.String:
host := fmt.Sprintf("%v", value)
if ip, err := netip.ParseAddr(host); err == nil {
ips = append(ips, ip)
isDomain = false
} else {
domain = host
}
default:
return HostValue{}, errors.New("value format error, must be string or array")
}
if isDomain {
return NewHostValueByDomain(domain)
} else {
return NewHostValueByIPs(ips)
}
}
func NewHostValueByIPs(ips []netip.Addr) (HostValue, error) {
if len(ips) == 0 {
return HostValue{}, errors.New("ip list is empty")
}
return HostValue{
IsDomain: false,
IPs: ips,
}, nil
}
func NewHostValueByDomain(domain string) (HostValue, error) {
domain = strings.Trim(domain, ".")
item := strings.Split(domain, ".")
if len(item) < 2 {
return HostValue{}, errors.New("invaild domain")
}
return HostValue{
IsDomain: true,
Domain: domain,
}, nil
}
func (hv HostValue) RandIP() (netip.Addr, error) {
if hv.IsDomain {
return netip.Addr{}, errors.New("value type is error")
}
return hv.IPs[rand.Intn(len(hv.IPs)-1)], nil
}

View File

@ -9,6 +9,7 @@ import (
"strings"
"time"
"github.com/Dreamacro/clash/common/utils"
"github.com/Dreamacro/clash/component/trie"
"github.com/miekg/dns"
@ -27,7 +28,7 @@ var (
DisableIPv6 = true
// DefaultHosts aim to resolve hosts
DefaultHosts = trie.New[netip.Addr]()
DefaultHosts = trie.New[HostValue]()
// DefaultDNSTimeout defined the default dns request timeout
DefaultDNSTimeout = time.Second * 5
@ -52,8 +53,14 @@ type Resolver interface {
// LookupIPv4WithResolver same as LookupIPv4, but with a resolver
func LookupIPv4WithResolver(ctx context.Context, host string, r Resolver) ([]netip.Addr, error) {
if node := DefaultHosts.Search(host); node != nil {
if ip := node.Data(); ip.Is4() {
return []netip.Addr{node.Data()}, nil
if value := node.Data(); !value.IsDomain {
if addrs := utils.Filter(value.IPs, func(ip netip.Addr) bool {
return ip.Is4()
}); len(addrs) > 0 {
return addrs, nil
}
}else{
return LookupIPv4WithResolver(ctx,value.Domain,r)
}
}
@ -107,8 +114,14 @@ func LookupIPv6WithResolver(ctx context.Context, host string, r Resolver) ([]net
}
if node := DefaultHosts.Search(host); node != nil {
if ip := node.Data(); ip.Is6() {
return []netip.Addr{ip}, nil
if value := node.Data(); !value.IsDomain {
if addrs := utils.Filter(value.IPs, func(ip netip.Addr) bool {
return ip.Is6()
}); len(addrs) > 0 {
return addrs, nil
}
}else{
return LookupIPv6WithResolver(ctx,value.Domain,r)
}
}
@ -156,7 +169,11 @@ func ResolveIPv6(ctx context.Context, host string) (netip.Addr, error) {
// LookupIPWithResolver same as LookupIP, but with a resolver
func LookupIPWithResolver(ctx context.Context, host string, r Resolver) ([]netip.Addr, error) {
if node := DefaultHosts.Search(host); node != nil {
return []netip.Addr{node.Data()}, nil
if !node.Data().IsDomain{
return node.Data().IPs, nil
}else{
return LookupIPWithResolver(ctx,node.Data().Domain,r)
}
}
if r != nil {

View File

@ -26,6 +26,7 @@ import (
"github.com/Dreamacro/clash/component/geodata"
"github.com/Dreamacro/clash/component/geodata/router"
P "github.com/Dreamacro/clash/component/process"
"github.com/Dreamacro/clash/component/resolver"
SNIFF "github.com/Dreamacro/clash/component/sniffer"
tlsC "github.com/Dreamacro/clash/component/tls"
"github.com/Dreamacro/clash/component/trie"
@ -100,7 +101,7 @@ type DNS struct {
EnhancedMode C.DNSMode `yaml:"enhanced-mode"`
DefaultNameserver []dns.NameServer `yaml:"default-nameserver"`
FakeIPRange *fakeip.Pool
Hosts *trie.DomainTrie[netip.Addr]
Hosts *trie.DomainTrie[resolver.HostValue]
NameServerPolicy map[string][]dns.NameServer
ProxyServerNameserver []dns.NameServer
}
@ -154,7 +155,7 @@ type Config struct {
IPTables *IPTables
DNS *DNS
Experimental *Experimental
Hosts *trie.DomainTrie[netip.Addr]
Hosts *trie.DomainTrie[resolver.HostValue]
Profile *Profile
Rules []C.Rule
SubRules map[string][]C.Rule
@ -265,7 +266,7 @@ type RawConfig struct {
Sniffer RawSniffer `yaml:"sniffer"`
ProxyProvider map[string]map[string]any `yaml:"proxy-providers"`
RuleProvider map[string]map[string]any `yaml:"rule-providers"`
Hosts map[string]string `yaml:"hosts"`
Hosts map[string]any `yaml:"hosts"`
DNS RawDNS `yaml:"dns"`
Tun RawTun `yaml:"tun"`
TuicServer RawTuicServer `yaml:"tuic-server"`
@ -339,7 +340,7 @@ func UnmarshalRawConfig(buf []byte) (*RawConfig, error) {
UnifiedDelay: false,
Authentication: []string{},
LogLevel: log.INFO,
Hosts: map[string]string{},
Hosts: map[string]any{},
Rule: []string{},
Proxy: []map[string]any{},
ProxyGroup: []map[string]any{},
@ -827,21 +828,32 @@ func parseRules(rulesConfig []string, proxies map[string]C.Proxy, subRules map[s
return rules, nil
}
func parseHosts(cfg *RawConfig) (*trie.DomainTrie[netip.Addr], error) {
tree := trie.New[netip.Addr]()
func parseHosts(cfg *RawConfig) (*trie.DomainTrie[resolver.HostValue], error) {
tree := trie.New[resolver.HostValue]()
// add default hosts
if err := tree.Insert("localhost", netip.AddrFrom4([4]byte{127, 0, 0, 1})); err != nil {
hostValue, _ := resolver.NewHostValueByIPs(
[]netip.Addr{netip.AddrFrom4([4]byte{127, 0, 0, 1})})
if err := tree.Insert("localhost", hostValue); err != nil {
log.Errorln("insert localhost to host error: %s", err.Error())
}
if len(cfg.Hosts) != 0 {
for domain, ipStr := range cfg.Hosts {
ip, err := netip.ParseAddr(ipStr)
for domain, valueStr := range cfg.Hosts {
value, err := resolver.NewHostValue(valueStr)
if err != nil {
return nil, fmt.Errorf("%s is not a valid IP", ipStr)
return nil, fmt.Errorf("%s is not a valid value", valueStr)
}
_ = tree.Insert(domain, ip)
if value.IsDomain {
node := tree.Search(value.Domain)
for node != nil && node.Data().IsDomain {
if node.Data().Domain == domain {
return nil, fmt.Errorf("%s, there is a cycle in domain name mapping", domain)
}
node = tree.Search(node.Data().Domain)
}
}
_ = tree.Insert(domain, value)
}
}
tree.Optimize()
@ -1041,7 +1053,7 @@ func parseFallbackGeoSite(countries []string, rules []C.Rule) ([]*router.DomainM
return sites, nil
}
func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[netip.Addr], rules []C.Rule) (*DNS, error) {
func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[resolver.HostValue], rules []C.Rule) (*DNS, error) {
cfg := rawCfg.DNS
if cfg.Enable && len(cfg.NameServer) == 0 {
return nil, fmt.Errorf("if DNS configuration is turned on, NameServer cannot be empty")

View File

@ -124,4 +124,4 @@ const (
HTTPVersion2 HTTPVersion = "h2"
// HTTPVersion3 is HTTP/3.
HTTPVersion3 HTTPVersion = "h3"
)
)

View File

@ -8,6 +8,7 @@ import (
"github.com/Dreamacro/clash/common/cache"
"github.com/Dreamacro/clash/common/nnip"
"github.com/Dreamacro/clash/component/fakeip"
"github.com/Dreamacro/clash/component/resolver"
"github.com/Dreamacro/clash/component/trie"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/context"
@ -21,7 +22,7 @@ type (
middleware func(next handler) handler
)
func withHosts(hosts *trie.DomainTrie[netip.Addr], mapping *cache.LruCache[netip.Addr, string]) middleware {
func withHosts(hosts *trie.DomainTrie[resolver.HostValue], mapping *cache.LruCache[netip.Addr, string]) middleware {
return func(next handler) handler {
return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
q := r.Question[0]
@ -37,29 +38,38 @@ func withHosts(hosts *trie.DomainTrie[netip.Addr], mapping *cache.LruCache[netip
return next(ctx, r)
}
ip := record.Data()
hostValue := record.Data()
msg := r.Copy()
if ip.Is4() && q.Qtype == D.TypeA {
rr := &D.A{}
rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: 10}
rr.A = ip.AsSlice()
msg.Answer = []D.RR{rr}
} else if q.Qtype == D.TypeAAAA {
rr := &D.AAAA{}
rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeAAAA, Class: D.ClassINET, Ttl: 10}
ip := ip.As16()
rr.AAAA = ip[:]
msg.Answer = []D.RR{rr}
if !hostValue.IsDomain {
for _, ipAddr := range hostValue.IPs {
if ipAddr.Is4() && q.Qtype == D.TypeA {
rr := &D.A{}
rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: 10}
rr.A = ipAddr.AsSlice()
msg.Answer = append(msg.Answer, rr)
if mapping != nil {
mapping.SetWithExpire(ipAddr, host, time.Now().Add(time.Second*10))
}
} else if q.Qtype == D.TypeAAAA {
rr := &D.AAAA{}
rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeAAAA, Class: D.ClassINET, Ttl: 10}
ip := ipAddr.As16()
rr.AAAA = ip[:]
msg.Answer = append(msg.Answer, rr)
if mapping != nil {
mapping.SetWithExpire(ipAddr, host, time.Now().Add(time.Second*10))
}
}
}
} else if q.Qtype == D.TypeCNAME {
rr := &D.CNAME{}
rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeCNAME, Class: D.ClassINET, Ttl: 10}
rr.Target = hostValue.Domain+"."
msg.Answer = append(msg.Answer, rr)
} else {
return next(ctx, r)
}
if mapping != nil {
mapping.SetWithExpire(ip, host, time.Now().Add(time.Second*10))
}
ctx.SetType(context.DNSTypeHost)
msg.SetRcode(r, D.RcodeSuccess)
msg.Authoritative = true

View File

@ -43,7 +43,7 @@ type geositePolicyRecord struct {
type Resolver struct {
ipv6 bool
ipv6Timeout time.Duration
hosts *trie.DomainTrie[netip.Addr]
hosts *trie.DomainTrie[resolver.HostValue]
main []dnsClient
fallback []dnsClient
fallbackDomainFilters []fallbackDomainFilter
@ -430,7 +430,7 @@ type Config struct {
EnhancedMode C.DNSMode
FallbackFilter FallbackFilter
Pool *fakeip.Pool
Hosts *trie.DomainTrie[netip.Addr]
Hosts *trie.DomainTrie[resolver.HostValue]
Policy map[string][]NameServer
}

View File

@ -66,7 +66,7 @@ func setMsgTTL(msg *D.Msg, ttl uint32) {
}
func isIPRequest(q D.Question) bool {
return q.Qclass == D.ClassINET && (q.Qtype == D.TypeA || q.Qtype == D.TypeAAAA)
return q.Qclass == D.ClassINET && (q.Qtype == D.TypeA || q.Qtype == D.TypeAAAA||q.Qtype==D.TypeCNAME)
}
func transform(servers []NameServer, resolver *Resolver) []dnsClient {

View File

@ -226,7 +226,7 @@ func updateDNS(c *config.DNS, generalIPv6 bool) {
dns.ReCreateServer(c.Listen, r, m)
}
func updateHosts(tree *trie.DomainTrie[netip.Addr]) {
func updateHosts(tree *trie.DomainTrie[resolver.HostValue]) {
resolver.DefaultHosts = tree
}

View File

@ -3,6 +3,7 @@ package tunnel
import (
"context"
"fmt"
"math/rand"
"net"
"net/netip"
"path/filepath"
@ -203,7 +204,11 @@ func preHandleMetadata(metadata *C.Metadata) error {
metadata.DNSMode = C.DNSFakeIP
} else if node := resolver.DefaultHosts.Search(host); node != nil {
// redir-host should lookup the hosts
metadata.DstIP = node.Data()
if !node.Data().IsDomain {
metadata.DstIP,_ = node.Data().RandIP()
} else {
metadata.Host = node.Data().Domain
}
}
} else if resolver.IsFakeIP(metadata.DstIP) {
return fmt.Errorf("fake DNS record %s missing", metadata.DstIP)
@ -393,10 +398,13 @@ func handleTCPConn(connCtx C.ConnContext) {
dialMetadata := metadata
if len(metadata.Host) > 0 {
if node := resolver.DefaultHosts.Search(metadata.Host); node != nil {
if dstIp := node.Data(); !FakeIPRange().Contains(dstIp) {
dialMetadata.DstIP = dstIp
dialMetadata.DNSMode = C.DNSHosts
dialMetadata = dialMetadata.Pure()
hostValue := node.Data()
if !hostValue.IsDomain {
if dstIp, _ := hostValue.RandIP(); !FakeIPRange().Contains(dstIp) {
dialMetadata.DstIP = dstIp
dialMetadata.DNSMode = C.DNSHosts
dialMetadata = dialMetadata.Pure()
}
}
}
}
@ -499,8 +507,11 @@ func match(metadata *C.Metadata) (C.Proxy, C.Rule, error) {
)
if node := resolver.DefaultHosts.Search(metadata.Host); node != nil {
metadata.DstIP = node.Data()
resolved = true
if !node.Data().IsDomain {
metadata.DstIP = node.Data().IPs[rand.Intn(len(node.Data().IPs)-1)]
resolved = true
}
}
for _, rule := range getRules(metadata) {