From 05f877a324a31b21a6e4a30fd75d64dfb98aaff5 Mon Sep 17 00:00:00 2001 From: Skyxim Date: Sun, 26 Feb 2023 10:07:03 +0800 Subject: [PATCH] refactor: tcp dial --- component/dialer/dialer.go | 469 +++++++++++++++---------------------- 1 file changed, 191 insertions(+), 278 deletions(-) diff --git a/component/dialer/dialer.go b/component/dialer/dialer.go index 9ac9d719..2a54536b 100644 --- a/component/dialer/dialer.go +++ b/component/dialer/dialer.go @@ -10,18 +10,17 @@ import ( "sync" "github.com/Dreamacro/clash/component/resolver" - - "go.uber.org/atomic" ) var ( - dialMux sync.Mutex - actualSingleDialContext = singleDialContext - actualDualStackDialContext = dualStackDialContext - tcpConcurrent = false - DisableIPv6 = false - ErrorInvalidedNetworkStack = errors.New("invalided network stack") - ErrorDisableIPv6 = errors.New("IPv6 is disabled, dialer cancel") + dialMux sync.Mutex + actualSingleStackDialContext = serialSingleStackDialContext + actualDualStackDialContext = serialDualStackDialContext + tcpConcurrent = false + DisableIPv6 = false + ErrorInvalidedNetworkStack = errors.New("invalided network stack") + ErrorDisableIPv6 = errors.New("IPv6 is disabled, dialer cancel") + ErrorConnTimeout = errors.New("connect timeout") ) func applyOptions(options ...Option) *option { @@ -56,7 +55,7 @@ func DialContext(ctx context.Context, network, address string, options ...Option switch network { case "tcp4", "tcp6", "udp4", "udp6": - return actualSingleDialContext(ctx, network, address, opt) + return actualSingleStackDialContext(ctx, network, address, opt) case "tcp", "udp": return actualDualStackDialContext(ctx, network, address, opt) default: @@ -89,11 +88,11 @@ func SetDial(concurrent bool) { dialMux.Lock() tcpConcurrent = concurrent if concurrent { - actualSingleDialContext = concurrentSingleDialContext + actualSingleStackDialContext = concurrentSingleStackDialContext actualDualStackDialContext = concurrentDualStackDialContext } else { - actualSingleDialContext = singleDialContext - actualDualStackDialContext = dualStackDialContext + actualSingleStackDialContext = serialSingleStackDialContext + actualDualStackDialContext = serialDualStackDialContext } dialMux.Unlock() @@ -125,289 +124,51 @@ func dialContext(ctx context.Context, network string, destination netip.Addr, po return dialer.DialContext(ctx, network, address) } -func singleDialContext(ctx context.Context, network string, address string, opt *option) (net.Conn, error) { - host, port, err := net.SplitHostPort(address) +func serialSingleStackDialContext(ctx context.Context, network string, address string, opt *option) (net.Conn, error) { + ips, port, err := parseAddr(ctx, network, address, opt.resolver) if err != nil { return nil, err } - - var ip netip.Addr - switch network { - case "tcp4", "udp4": - if opt.resolver == nil { - ip, err = resolver.ResolveIPv4ProxyServerHost(ctx, host) - } else { - ip, err = resolver.ResolveIPv4WithResolver(ctx, host, opt.resolver) - } - default: - if opt.resolver == nil { - ip, err = resolver.ResolveIPv6ProxyServerHost(ctx, host) - } else { - ip, err = resolver.ResolveIPv6WithResolver(ctx, host, opt.resolver) - } - } - if err != nil { - err = fmt.Errorf("dns resolve failed:%w", err) - return nil, err - } - - return dialContext(ctx, network, ip, port, opt) + return serialDialContext(ctx, network, ips, port, opt) } -func dualStackDialContext(ctx context.Context, network, address string, opt *option) (net.Conn, error) { - host, port, err := net.SplitHostPort(address) +func serialDualStackDialContext(ctx context.Context, network, address string, opt *option) (net.Conn, error) { + ips, port, err := parseAddr(ctx, network, address, opt.resolver) + if err != nil { + return nil, err + } + return dualStackDial( + func() (net.Conn, error) { return serialDialContext(ctx, network, ips, port, opt) }, + func() (net.Conn, error) { return serialDialContext(ctx, network, ips, port, opt) }, + opt.prefer == 4) +} + +func concurrentSingleStackDialContext(ctx context.Context, network string, address string, opt *option) (net.Conn, error) { + ips, port, err := parseAddr(ctx, network, address, opt.resolver) if err != nil { return nil, err } - returned := make(chan struct{}) - defer close(returned) - - type dialResult struct { - net.Conn - error - resolved bool - ipv6 bool - done bool - } - results := make(chan dialResult) - var primary, fallback dialResult - - startRacer := func(ctx context.Context, network, host string, r resolver.Resolver, ipv6 bool) { - result := dialResult{ipv6: ipv6, done: true} - defer func() { - select { - case results <- result: - case <-returned: - if result.Conn != nil { - _ = result.Conn.Close() - } - } - }() - - var ip netip.Addr - if ipv6 { - if r == nil { - ip, result.error = resolver.ResolveIPv6ProxyServerHost(ctx, host) - } else { - ip, result.error = resolver.ResolveIPv6WithResolver(ctx, host, r) - } - } else { - if r == nil { - ip, result.error = resolver.ResolveIPv4ProxyServerHost(ctx, host) - } else { - ip, result.error = resolver.ResolveIPv4WithResolver(ctx, host, r) - } - } - if result.error != nil { - result.error = fmt.Errorf("dns resolve failed:%w", result.error) - return - } - result.resolved = true - - result.Conn, result.error = dialContext(ctx, network, ip, port, opt) - } - - go startRacer(ctx, network+"4", host, opt.resolver, false) - go startRacer(ctx, network+"6", host, opt.resolver, true) - - count := 2 - for i := 0; i < count; i++ { - select { - case res := <-results: - if res.error == nil { - return res.Conn, nil - } - - if !res.ipv6 { - primary = res - } else { - fallback = res - } - - if primary.done && fallback.done { - if primary.resolved { - return nil, primary.error - } else if fallback.resolved { - return nil, fallback.error - } else { - return nil, primary.error - } - } - case <-ctx.Done(): - err = ctx.Err() - break - } - } - - if err == nil { - err = fmt.Errorf("dual stack dial failed") + if conn, err := parallelDialContext(ctx, network, ips, port, opt); err != nil { + return nil, err } else { - err = fmt.Errorf("dual stack dial failed:%w", err) + return conn, nil } - return nil, err -} - -func concurrentDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) { - returned := make(chan struct{}) - defer close(returned) - - type dialResult struct { - ip netip.Addr - net.Conn - error - isPrimary bool - done bool - } - - preferCount := atomic.NewInt32(0) - results := make(chan dialResult) - tcpRacer := func(ctx context.Context, ip netip.Addr) { - result := dialResult{ip: ip, done: true} - - defer func() { - select { - case results <- result: - case <-returned: - if result.Conn != nil { - _ = result.Conn.Close() - } - } - }() - if strings.Contains(network, "tcp") { - network = "tcp" - } else { - network = "udp" - } - - if ip.Is6() { - network += "6" - if opt.prefer != 4 { - result.isPrimary = true - } - } - - if ip.Is4() { - network += "4" - if opt.prefer != 6 { - result.isPrimary = true - } - } - - if result.isPrimary { - preferCount.Add(1) - } - - result.Conn, result.error = dialContext(ctx, network, ip, port, opt) - } - - for _, ip := range ips { - go tcpRacer(ctx, ip) - } - - connCount := len(ips) - var fallback dialResult - var primaryError error - var finalError error - for i := 0; i < connCount; i++ { - select { - case res := <-results: - if res.error == nil { - if res.isPrimary { - return res.Conn, nil - } else { - if !fallback.done || fallback.error != nil { - fallback = res - } - } - } else { - if res.isPrimary { - primaryError = res.error - preferCount.Add(-1) - if preferCount.Load() == 0 && fallback.done && fallback.error == nil { - return fallback.Conn, nil - } - } - } - case <-ctx.Done(): - if fallback.done && fallback.error == nil { - return fallback.Conn, nil - } - finalError = ctx.Err() - break - } - } - - if fallback.done && fallback.error == nil { - return fallback.Conn, nil - } - - if primaryError != nil { - return nil, primaryError - } - - if fallback.error != nil { - return nil, fallback.error - } - - if finalError == nil { - finalError = fmt.Errorf("all ips %v tcp shake hands failed", ips) - } else { - finalError = fmt.Errorf("concurrent dial failed:%w", finalError) - } - - return nil, finalError -} - -func concurrentSingleDialContext(ctx context.Context, network string, address string, opt *option) (net.Conn, error) { - host, port, err := net.SplitHostPort(address) - if err != nil { - return nil, err - } - - var ips []netip.Addr - switch network { - case "tcp4", "udp4": - if opt.resolver == nil { - ips, err = resolver.LookupIPv4ProxyServerHost(ctx, host) - } else { - ips, err = resolver.LookupIPv4WithResolver(ctx, host, opt.resolver) - } - default: - if opt.resolver == nil { - ips, err = resolver.LookupIPv6ProxyServerHost(ctx, host) - } else { - ips, err = resolver.LookupIPv6WithResolver(ctx, host, opt.resolver) - } - } - - if err != nil { - err = fmt.Errorf("dns resolve failed:%w", err) - return nil, err - } - - return concurrentDialContext(ctx, network, ips, port, opt) } func concurrentDualStackDialContext(ctx context.Context, network, address string, opt *option) (net.Conn, error) { - host, port, err := net.SplitHostPort(address) + ips, port, err := parseAddr(ctx, network, address, opt.resolver) if err != nil { return nil, err } - - var ips []netip.Addr - if opt.resolver != nil { - ips, err = resolver.LookupIPWithResolver(ctx, host, opt.resolver) - } else { - ips, err = resolver.LookupIPProxyServerHost(ctx, host) + if opt.prefer != 4 && opt.prefer != 6 { + return parallelDialContext(ctx, network, ips, port, opt) } - - if err != nil { - err = fmt.Errorf("dns resolve failed:%w", err) - return nil, err - } - - return concurrentDialContext(ctx, network, ips, port, opt) + ipv4s, ipv6s := sortationAddr(ips) + return dualStackDial( + func() (net.Conn, error) { return parallelDialContext(ctx, network, ipv4s, port, opt) }, + func() (net.Conn, error) { return parallelDialContext(ctx, network, ipv6s, port, opt) }, + opt.prefer == 4) } type Dialer struct { @@ -426,3 +187,155 @@ func NewDialer(options ...Option) Dialer { opt := applyOptions(options...) return Dialer{Opt: *opt} } + +func dualStackDial( + ipv4DialFn func() (net.Conn, error), + ipv6DialFn func() (net.Conn, error), + preferIPv4 bool) (net.Conn, error) { + results := make(chan dialResult) + returned := make(chan struct{}) + defer close(returned) + racer := func(dial func() (net.Conn, error), isPrimary bool) { + result := dialResult{isPrimary: isPrimary} + defer func() { + select { + case results <- result: + case <-returned: + if result.Conn != nil { + _ = result.Conn.Close() + } + } + }() + result.Conn, result.error = dial() + } + go racer(ipv4DialFn, preferIPv4) + go racer(ipv6DialFn, !preferIPv4) + var fallbackErr dialResult + var primaryErr dialResult + for res := range results { + if res.error == nil { + if res.isPrimary { + return res.Conn, nil + } + fallbackErr = res + } + if res.isPrimary { + primaryErr = res + } else { + fallbackErr = res + } + } + if primaryErr.error != nil { + return nil, primaryErr.error + } + return nil, fallbackErr.error +} + +func parallelDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) { + results := make(chan dialResult) + returned := make(chan struct{}) + defer close(returned) + tcpRacer := func(ctx context.Context, ip netip.Addr, port string) { + result := dialResult{isPrimary: true} + defer func() { + select { + case results <- result: + case <-returned: + if result.Conn != nil { + _ = result.Conn.Close() + } + } + }() + result.ip = ip + result.Conn, result.error = dialContext(ctx, network, ip, port, opt) + } + + for _, ip := range ips { + go tcpRacer(ctx, ip, port) + } + var err error + for { + select { + case <-ctx.Done(): + if err != nil { + return nil, err + } + if ctx.Err() == context.DeadlineExceeded { + return nil, ErrorConnTimeout + } + return nil, ctx.Err() + case res := <-results: + if res.error == nil { + return res.Conn, nil + } + err = res.error + } + } +} + +func serialDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) { + var ( + conn net.Conn + err error + errs []error + ) + for _, ip := range ips { + if conn, err = dialContext(ctx, network, ip, port, opt); err == nil { + return conn, nil + } else { + errs = append(errs, err) + } + } + return nil, errors.Join(errs...) +} + +type dialResult struct { + ip netip.Addr + net.Conn + error + isPrimary bool +} + +func parseAddr(ctx context.Context, network,address string, preferResolver resolver.Resolver) ([]netip.Addr, string, error) { + host, port, err := net.SplitHostPort(address) + if err != nil { + return nil, "-1", err + } + + var ips []netip.Addr + switch network { + case "tcp4", "udp4": + if preferResolver == nil { + ips, err = resolver.LookupIPv4ProxyServerHost(ctx, host) + } else { + ips, err = resolver.LookupIPv4WithResolver(ctx, host, preferResolver) + } + case "tcp6", "udp6": + if preferResolver == nil { + ips, err = resolver.LookupIPv6ProxyServerHost(ctx, host) + } else { + ips, err = resolver.LookupIPv6WithResolver(ctx, host, preferResolver) + } + default: + if preferResolver == nil { + ips, err = resolver.LookupIP(ctx, host) + } else { + ips, err = resolver.LookupIPWithResolver(ctx, host, preferResolver) + } + } + if err != nil { + return nil, "-1", fmt.Errorf("dns resolve failed: %w", err) + } + return ips, port, nil +} + +func sortationAddr(ips []netip.Addr) (ipv4s, ipv6s []netip.Addr) { + for _, v := range ips { + if v.Is4() || v.Is4In6() { + ipv4s = append(ipv4s, v) + } else { + ipv6s = append(ipv6s, v) + } + } + return +}