chore: adjust logic

This commit is contained in:
Skyxim 2023-03-12 14:14:38 +08:00
parent 0273a6bb9d
commit a510909228
5 changed files with 82 additions and 81 deletions

View File

@ -7,8 +7,42 @@ import (
"strings" "strings"
"github.com/Dreamacro/clash/common/utils" "github.com/Dreamacro/clash/common/utils"
"github.com/Dreamacro/clash/component/trie"
) )
type Hosts struct {
*trie.DomainTrie[HostValue]
}
func NewHosts(hosts *trie.DomainTrie[HostValue]) Hosts {
return Hosts{
hosts,
}
}
func (h *Hosts) Search(domain string, isDomain bool) (*HostValue, bool) {
value := h.DomainTrie.Search(domain)
if value == nil {
return nil, false
}
hostValue := value.Data()
for {
if isDomain && hostValue.IsDomain {
return &hostValue, true
} else {
if node := h.DomainTrie.Search(hostValue.Domain); node != nil {
hostValue = node.Data()
} else {
break
}
}
}
if isDomain == hostValue.IsDomain {
return &hostValue, true
}
return &hostValue, false
}
type HostValue struct { type HostValue struct {
IsDomain bool IsDomain bool
IPs []netip.Addr IPs []netip.Addr

View File

@ -28,7 +28,7 @@ var (
DisableIPv6 = true DisableIPv6 = true
// DefaultHosts aim to resolve hosts // DefaultHosts aim to resolve hosts
DefaultHosts = trie.New[HostValue]() DefaultHosts = NewHosts(trie.New[HostValue]())
// DefaultDNSTimeout defined the default dns request timeout // DefaultDNSTimeout defined the default dns request timeout
DefaultDNSTimeout = time.Second * 5 DefaultDNSTimeout = time.Second * 5
@ -52,16 +52,12 @@ type Resolver interface {
// LookupIPv4WithResolver same as LookupIPv4, but with a resolver // LookupIPv4WithResolver same as LookupIPv4, but with a resolver
func LookupIPv4WithResolver(ctx context.Context, host string, r Resolver) ([]netip.Addr, error) { func LookupIPv4WithResolver(ctx context.Context, host string, r Resolver) ([]netip.Addr, error) {
if node := DefaultHosts.Search(host); node != nil { if node, ok := DefaultHosts.Search(host, false); ok {
if value := node.Data(); !value.IsDomain { if addrs := utils.Filter(node.IPs, func(ip netip.Addr) bool {
if addrs := utils.Filter(value.IPs, func(ip netip.Addr) bool {
return ip.Is4() return ip.Is4()
}); len(addrs) > 0 { }); len(addrs) > 0 {
return addrs, nil return addrs, nil
} }
}else{
return LookupIPv4WithResolver(ctx,value.Domain,r)
}
} }
ip, err := netip.ParseAddr(host) ip, err := netip.ParseAddr(host)
@ -113,16 +109,12 @@ func LookupIPv6WithResolver(ctx context.Context, host string, r Resolver) ([]net
return nil, ErrIPv6Disabled return nil, ErrIPv6Disabled
} }
if node := DefaultHosts.Search(host); node != nil { if node, ok := DefaultHosts.Search(host, false); ok {
if value := node.Data(); !value.IsDomain { if addrs := utils.Filter(node.IPs, func(ip netip.Addr) bool {
if addrs := utils.Filter(value.IPs, func(ip netip.Addr) bool {
return ip.Is6() return ip.Is6()
}); len(addrs) > 0 { }); len(addrs) > 0 {
return addrs, nil return addrs, nil
} }
}else{
return LookupIPv6WithResolver(ctx,value.Domain,r)
}
} }
if ip, err := netip.ParseAddr(host); err == nil { if ip, err := netip.ParseAddr(host); err == nil {
@ -168,12 +160,8 @@ func ResolveIPv6(ctx context.Context, host string) (netip.Addr, error) {
// LookupIPWithResolver same as LookupIP, but with a resolver // LookupIPWithResolver same as LookupIP, but with a resolver
func LookupIPWithResolver(ctx context.Context, host string, r Resolver) ([]netip.Addr, error) { func LookupIPWithResolver(ctx context.Context, host string, r Resolver) ([]netip.Addr, error) {
if node := DefaultHosts.Search(host); node != nil { if node, ok := DefaultHosts.Search(host, false); ok {
if !node.Data().IsDomain{ return node.IPs, nil
return node.Data().IPs, nil
}else{
return LookupIPWithResolver(ctx,node.Data().Domain,r)
}
} }
if r != nil { if r != nil {

View File

@ -8,8 +8,7 @@ import (
"github.com/Dreamacro/clash/common/cache" "github.com/Dreamacro/clash/common/cache"
"github.com/Dreamacro/clash/common/nnip" "github.com/Dreamacro/clash/common/nnip"
"github.com/Dreamacro/clash/component/fakeip" "github.com/Dreamacro/clash/component/fakeip"
"github.com/Dreamacro/clash/component/resolver" R "github.com/Dreamacro/clash/component/resolver"
"github.com/Dreamacro/clash/component/trie"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/context" "github.com/Dreamacro/clash/context"
"github.com/Dreamacro/clash/log" "github.com/Dreamacro/clash/log"
@ -22,7 +21,7 @@ type (
middleware func(next handler) handler middleware func(next handler) handler
) )
func withHosts(hosts *trie.DomainTrie[resolver.HostValue], mapping *cache.LruCache[netip.Addr, string]) middleware { func withHosts(hosts R.Hosts, mapping *cache.LruCache[netip.Addr, string]) middleware {
return func(next handler) handler { return func(next handler) handler {
return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) { return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
q := r.Question[0] q := r.Question[0]
@ -33,15 +32,25 @@ func withHosts(hosts *trie.DomainTrie[resolver.HostValue], mapping *cache.LruCac
host := strings.TrimRight(q.Name, ".") host := strings.TrimRight(q.Name, ".")
record := hosts.Search(host) record, ok := hosts.Search(host, q.Qtype != D.TypeA && q.Qtype != D.TypeAAAA)
if record == nil { if !ok {
if record != nil && record.IsDomain {
// replace request domain
newR := r.Copy()
newR.Question[0].Name = record.Domain+"."
resp, err := next(ctx, newR)
if err==nil{
resp.Id=r.Id
resp.Question=r.Question
}
return resp,err
}
return next(ctx,r) return next(ctx,r)
} }
hostValue := record.Data()
msg := r.Copy() msg := r.Copy()
handleIPs := func() { handleIPs := func() {
for _, ipAddr := range hostValue.IPs { for _, ipAddr := range record.IPs {
if ipAddr.Is4() && q.Qtype == D.TypeA { if ipAddr.Is4() && q.Qtype == D.TypeA {
rr := &D.A{} rr := &D.A{}
rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: 10} rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: 10}
@ -62,35 +71,16 @@ func withHosts(hosts *trie.DomainTrie[resolver.HostValue], mapping *cache.LruCac
} }
} }
} }
fillMsg := func() {
if !hostValue.IsDomain {
handleIPs()
} else {
for {
if hostValue.IsDomain {
if node := hosts.Search(hostValue.Domain); node != nil {
hostValue = node.Data()
} else {
break
}
}else{
break
}
}
if !hostValue.IsDomain {
handleIPs()
}
}
}
switch q.Qtype { switch q.Qtype {
case D.TypeA: case D.TypeA:
fillMsg() handleIPs()
case D.TypeAAAA: case D.TypeAAAA:
fillMsg() handleIPs()
case D.TypeCNAME: case D.TypeCNAME:
rr := &D.CNAME{} rr := &D.CNAME{}
rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeCNAME, Class: D.ClassINET, Ttl: 10} rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeCNAME, Class: D.ClassINET, Ttl: 10}
rr.Target = hostValue.Domain + "." rr.Target = record.Domain + "."
msg.Answer = append(msg.Answer, rr) msg.Answer = append(msg.Answer, rr)
default: default:
return next(ctx, r) return next(ctx, r)
@ -100,7 +90,6 @@ func withHosts(hosts *trie.DomainTrie[resolver.HostValue], mapping *cache.LruCac
msg.SetRcode(r, D.RcodeSuccess) msg.SetRcode(r, D.RcodeSuccess)
msg.Authoritative = true msg.Authoritative = true
msg.RecursionAvailable = true msg.RecursionAvailable = true
return msg, nil return msg, nil
} }
} }
@ -185,6 +174,7 @@ func withFakeIP(fakePool *fakeip.Pool) middleware {
func withResolver(resolver *Resolver) handler { func withResolver(resolver *Resolver) handler {
return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) { return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
ctx.SetType(context.DNSTypeRaw) ctx.SetType(context.DNSTypeRaw)
q := r.Question[0] q := r.Question[0]
// return a empty AAAA msg when ipv6 disabled // return a empty AAAA msg when ipv6 disabled
@ -219,7 +209,7 @@ func NewHandler(resolver *Resolver, mapper *ResolverEnhancer) handler {
middlewares := []middleware{} middlewares := []middleware{}
if resolver.hosts != nil { if resolver.hosts != nil {
middlewares = append(middlewares, withHosts(resolver.hosts, mapper.mapping)) middlewares = append(middlewares, withHosts(R.NewHosts(resolver.hosts), mapper.mapping))
} }
if mapper.mode == C.DNSFakeIP { if mapper.mode == C.DNSFakeIP {

View File

@ -227,7 +227,7 @@ func updateDNS(c *config.DNS, generalIPv6 bool) {
} }
func updateHosts(tree *trie.DomainTrie[resolver.HostValue]) { func updateHosts(tree *trie.DomainTrie[resolver.HostValue]) {
resolver.DefaultHosts = tree resolver.DefaultHosts = resolver.NewHosts(tree)
} }
func updateProxies(proxies map[string]C.Proxy, providers map[string]provider.ProxyProvider) { func updateProxies(proxies map[string]C.Proxy, providers map[string]provider.ProxyProvider) {

View File

@ -3,7 +3,6 @@ package tunnel
import ( import (
"context" "context"
"fmt" "fmt"
"math/rand"
"net" "net"
"net/netip" "net/netip"
"path/filepath" "path/filepath"
@ -202,13 +201,9 @@ func preHandleMetadata(metadata *C.Metadata) error {
if resolver.FakeIPEnabled() { if resolver.FakeIPEnabled() {
metadata.DstIP = netip.Addr{} metadata.DstIP = netip.Addr{}
metadata.DNSMode = C.DNSFakeIP metadata.DNSMode = C.DNSFakeIP
} else if node := resolver.DefaultHosts.Search(host); node != nil { } else if node, ok := resolver.DefaultHosts.Search(host, false); ok {
// redir-host should lookup the hosts // redir-host should lookup the hosts
if !node.Data().IsDomain { metadata.DstIP, _ = node.RandIP()
metadata.DstIP,_ = node.Data().RandIP()
} else {
metadata.Host = node.Data().Domain
}
} }
} else if resolver.IsFakeIP(metadata.DstIP) { } else if resolver.IsFakeIP(metadata.DstIP) {
return fmt.Errorf("fake DNS record %s missing", metadata.DstIP) return fmt.Errorf("fake DNS record %s missing", metadata.DstIP)
@ -397,17 +392,14 @@ func handleTCPConn(connCtx C.ConnContext) {
dialMetadata := metadata dialMetadata := metadata
if len(metadata.Host) > 0 { if len(metadata.Host) > 0 {
if node := resolver.DefaultHosts.Search(metadata.Host); node != nil { if node, ok := resolver.DefaultHosts.Search(metadata.Host, false); ok {
hostValue := node.Data() if dstIp, _ := node.RandIP(); !FakeIPRange().Contains(dstIp) {
if !hostValue.IsDomain {
if dstIp, _ := hostValue.RandIP(); !FakeIPRange().Contains(dstIp) {
dialMetadata.DstIP = dstIp dialMetadata.DstIP = dstIp
dialMetadata.DNSMode = C.DNSHosts dialMetadata.DNSMode = C.DNSHosts
dialMetadata = dialMetadata.Pure() dialMetadata = dialMetadata.Pure()
} }
} }
} }
}
var peekBytes []byte var peekBytes []byte
var peekLen int var peekLen int
@ -506,14 +498,11 @@ func match(metadata *C.Metadata) (C.Proxy, C.Rule, error) {
processFound bool processFound bool
) )
if node := resolver.DefaultHosts.Search(metadata.Host); node != nil { if node, ok := resolver.DefaultHosts.Search(metadata.Host, false); ok {
if !node.Data().IsDomain { metadata.DstIP, _ = node.RandIP()
metadata.DstIP = node.Data().IPs[rand.Intn(len(node.Data().IPs)-1)]
resolved = true resolved = true
} }
}
for _, rule := range getRules(metadata) { for _, rule := range getRules(metadata) {
if !resolved && shouldResolveIP(rule, metadata) { if !resolved && shouldResolveIP(rule, metadata) {
func() { func() {