mihomo/component/dialer/dialer.go

352 lines
9.0 KiB
Go
Raw Normal View History

2020-02-09 17:02:48 +08:00
package dialer
import (
"context"
"fmt"
2020-02-09 17:02:48 +08:00
"net"
2022-04-20 01:52:51 +08:00
"net/netip"
"os"
2022-08-28 13:41:19 +08:00
"strings"
2022-04-27 21:37:20 +08:00
"sync"
"time"
"github.com/Dreamacro/clash/component/resolver"
2020-02-09 17:02:48 +08:00
)
2023-03-06 23:23:05 +08:00
type dialFunc func(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error)
2022-04-27 21:37:20 +08:00
var (
dialMux sync.Mutex
actualSingleStackDialContext = serialSingleStackDialContext
actualDualStackDialContext = serialDualStackDialContext
tcpConcurrent = false
fallbackTimeout = 300 * time.Millisecond
2022-04-27 21:37:20 +08:00
)
2022-12-22 09:53:11 +08:00
func applyOptions(options ...Option) *option {
opt := &option{
interfaceName: DefaultInterface.Load(),
routingMark: int(DefaultRoutingMark.Load()),
}
for _, o := range DefaultOptions {
o(opt)
}
for _, o := range options {
o(opt)
}
2022-11-25 08:08:14 +08:00
return opt
}
func DialContext(ctx context.Context, network, address string, options ...Option) (net.Conn, error) {
2022-12-22 09:53:11 +08:00
opt := applyOptions(options...)
2022-11-25 08:08:14 +08:00
2022-08-28 13:41:19 +08:00
if opt.network == 4 || opt.network == 6 {
if strings.Contains(network, "tcp") {
network = "tcp"
} else {
network = "udp"
}
network = fmt.Sprintf("%s%d", network, opt.network)
}
2023-03-06 23:23:05 +08:00
ips, port, err := parseAddr(ctx, network, address, opt.resolver)
if err != nil {
return nil, err
}
switch network {
case "tcp4", "tcp6", "udp4", "udp6":
2023-03-06 23:23:05 +08:00
return actualSingleStackDialContext(ctx, network, ips, port, opt)
case "tcp", "udp":
2023-03-06 23:23:05 +08:00
return actualDualStackDialContext(ctx, network, ips, port, opt)
default:
2022-08-28 13:41:19 +08:00
return nil, ErrorInvalidedNetworkStack
}
2020-02-09 17:02:48 +08:00
}
func ListenPacket(ctx context.Context, network, address string, options ...Option) (net.PacketConn, error) {
cfg := applyOptions(options...)
lc := &net.ListenConfig{}
if cfg.interfaceName != "" {
addr, err := bindIfaceToListenConfig(cfg.interfaceName, lc, network, address)
if err != nil {
return nil, err
}
address = addr
}
if cfg.addrReuse {
addrReuseToListenConfig(lc)
}
2021-11-08 16:59:48 +08:00
if cfg.routingMark != 0 {
bindMarkToListenConfig(cfg.routingMark, lc, network, address)
}
return lc.ListenPacket(ctx, network, address)
2020-02-09 17:02:48 +08:00
}
2023-03-06 23:23:05 +08:00
func SetTcpConcurrent(concurrent bool) {
2022-04-27 21:37:20 +08:00
dialMux.Lock()
2023-03-06 23:23:05 +08:00
defer dialMux.Unlock()
tcpConcurrent = concurrent
2022-04-27 21:37:20 +08:00
if concurrent {
actualSingleStackDialContext = concurrentSingleStackDialContext
2022-04-27 21:37:20 +08:00
actualDualStackDialContext = concurrentDualStackDialContext
} else {
actualSingleStackDialContext = serialSingleStackDialContext
actualDualStackDialContext = serialDualStackDialContext
2022-04-27 21:37:20 +08:00
}
}
2023-03-06 23:23:05 +08:00
// GetTcpConcurrent reports the value last set by SetTcpConcurrent.
func GetTcpConcurrent() bool {
	dialMux.Lock()
	concurrent := tcpConcurrent
	dialMux.Unlock()
	return concurrent
}
2022-04-20 01:52:51 +08:00
func dialContext(ctx context.Context, network string, destination netip.Addr, port string, opt *option) (net.Conn, error) {
2023-03-07 09:30:51 +08:00
address := net.JoinHostPort(destination.String(), port)
netDialer := opt.netDialer
switch netDialer.(type) {
case nil:
netDialer = &net.Dialer{}
case *net.Dialer:
netDialer = &*netDialer.(*net.Dialer) // make a copy
default:
return netDialer.DialContext(ctx, network, address)
}
dialer := netDialer.(*net.Dialer)
if opt.interfaceName != "" {
if err := bindIfaceToDialer(opt.interfaceName, dialer, network, destination); err != nil {
return nil, err
}
}
2021-11-08 16:59:48 +08:00
if opt.routingMark != 0 {
bindMarkToDialer(opt.routingMark, dialer, network, destination)
}
2023-02-24 13:53:44 +08:00
if opt.tfo {
return dialTFO(ctx, *dialer, network, address)
}
return dialer.DialContext(ctx, network, address)
}
2023-03-06 23:23:05 +08:00
// serialSingleStackDialContext tries ips one at a time, in order.
func serialSingleStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
	conn, err := serialDialContext(ctx, network, ips, port, opt)
	return conn, err
}
2023-03-06 23:23:05 +08:00
// serialDualStackDialContext races the IPv4 and IPv6 address sets against
// each other. Within each family, addresses are tried serially.
func serialDualStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
	var perFamily dialFunc = serialDialContext
	return dualStackDialContext(ctx, perFamily, network, ips, port, opt)
}
2023-03-06 23:23:05 +08:00
// concurrentSingleStackDialContext dials every address in ips at once and
// keeps the first connection that succeeds.
func concurrentSingleStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
	conn, err := parallelDialContext(ctx, network, ips, port, opt)
	return conn, err
}
2023-03-06 23:23:05 +08:00
func concurrentDualStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
if opt.prefer != 4 && opt.prefer != 6 {
return parallelDialContext(ctx, network, ips, port, opt)
}
2023-03-06 23:23:05 +08:00
return dualStackDialContext(ctx, parallelDialContext, network, ips, port, opt)
}
2023-03-06 23:23:05 +08:00
func dualStackDialContext(ctx context.Context, dialFn dialFunc, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
ipv4s, ipv6s := sortationAddr(ips)
preferIPVersion := opt.prefer
fallbackTicker := time.NewTicker(fallbackTimeout)
defer fallbackTicker.Stop()
results := make(chan dialResult)
returned := make(chan struct{})
defer close(returned)
2023-03-06 23:23:05 +08:00
racer := func(ips []netip.Addr, isPrimary bool) {
result := dialResult{isPrimary: isPrimary}
defer func() {
select {
case results <- result:
case <-returned:
if result.Conn != nil && result.error == nil {
2022-04-20 01:52:51 +08:00
_ = result.Conn.Close()
}
}
}()
2023-03-06 23:23:05 +08:00
result.Conn, result.error = dialFn(ctx, network, ips, port, opt)
}
2023-03-06 23:23:05 +08:00
go racer(ipv4s, preferIPVersion != 6)
go racer(ipv6s, preferIPVersion != 4)
var fallback dialResult
2023-03-10 14:12:18 +08:00
var errs []error
for {
select {
case <-ctx.Done():
if fallback.error == nil && fallback.Conn != nil {
return fallback.Conn, nil
}
2023-03-10 14:12:18 +08:00
if res, ok := <-results; ok && res.error == nil {
return res.Conn, nil
}
return nil, errorsJoin(errs...)
case <-fallbackTicker.C:
if fallback.error == nil && fallback.Conn != nil {
return fallback.Conn, nil
}
case res := <-results:
if res.error == nil {
if res.isPrimary {
return res.Conn, nil
}
fallback = res
2023-02-26 21:01:44 +08:00
} else {
2023-03-10 14:12:18 +08:00
if res.isPrimary {
errs = append([]error{fmt.Errorf("connect failed: %w", res.error)}, errs...)
} else {
errs = append(errs, fmt.Errorf("connect failed: %w", res.error))
}
}
}
}
}
func parallelDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
2023-02-26 21:01:44 +08:00
if len(ips) == 0 {
2023-02-26 22:20:25 +08:00
return nil, ErrorNoIpAddress
2023-02-26 21:01:44 +08:00
}
results := make(chan dialResult)
returned := make(chan struct{})
defer close(returned)
2023-03-06 23:23:05 +08:00
racer := func(ctx context.Context, ip netip.Addr) {
result := dialResult{isPrimary: true}
defer func() {
select {
case results <- result:
case <-returned:
if result.Conn != nil && result.error == nil {
2022-08-28 13:41:19 +08:00
_ = result.Conn.Close()
}
}
}()
result.ip = ip
2022-08-28 13:41:19 +08:00
result.Conn, result.error = dialContext(ctx, network, ip, port, opt)
}
for _, ip := range ips {
2023-03-06 23:23:05 +08:00
go racer(ctx, ip)
}
2023-03-10 14:12:18 +08:00
var errs []error
for {
select {
case <-ctx.Done():
2023-03-10 14:12:18 +08:00
if len(errs) > 0 {
return nil, errorsJoin(errs...)
}
if ctx.Err() == context.DeadlineExceeded {
return nil, os.ErrDeadlineExceeded
}
return nil, ctx.Err()
case res := <-results:
if res.error == nil {
return res.Conn, nil
2022-08-28 13:41:19 +08:00
}
2023-03-10 14:12:18 +08:00
errs = append(errs, res.error)
}
}
}
2022-04-27 21:37:20 +08:00
func serialDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
2023-02-26 21:01:44 +08:00
if len(ips) == 0 {
2023-02-26 22:20:25 +08:00
return nil, ErrorNoIpAddress
2023-02-26 21:01:44 +08:00
}
2023-03-06 23:23:05 +08:00
var errs []error
for _, ip := range ips {
2023-03-06 23:23:05 +08:00
if conn, err := dialContext(ctx, network, ip, port, opt); err == nil {
return conn, nil
} else {
errs = append(errs, err)
}
2022-11-19 10:50:13 +08:00
}
2023-02-26 22:20:25 +08:00
return nil, errorsJoin(errs...)
}
2022-11-19 10:50:13 +08:00
// dialResult is what a racer goroutine reports back to its coordinator.
type dialResult struct {
	// ip is the address that was dialed (set by parallelDialContext's racers).
	ip netip.Addr

	// Embedded outcome of the dial: a connection on success, an error otherwise.
	net.Conn
	error

	// isPrimary marks results from the preferred address family.
	isPrimary bool
}
func parseAddr(ctx context.Context, network, address string, preferResolver resolver.Resolver) ([]netip.Addr, string, error) {
2022-04-27 21:37:20 +08:00
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, "-1", err
2022-04-27 21:37:20 +08:00
}
var ips []netip.Addr
2022-04-27 21:37:20 +08:00
switch network {
case "tcp4", "udp4":
if preferResolver == nil {
ips, err = resolver.LookupIPv4ProxyServerHost(ctx, host)
2022-04-27 21:37:20 +08:00
} else {
ips, err = resolver.LookupIPv4WithResolver(ctx, host, preferResolver)
2022-04-27 21:37:20 +08:00
}
case "tcp6", "udp6":
if preferResolver == nil {
ips, err = resolver.LookupIPv6ProxyServerHost(ctx, host)
2022-04-27 21:37:20 +08:00
} else {
ips, err = resolver.LookupIPv6WithResolver(ctx, host, preferResolver)
}
default:
if preferResolver == nil {
2023-02-26 13:52:10 +08:00
ips, err = resolver.LookupIPProxyServerHost(ctx, host)
} else {
ips, err = resolver.LookupIPWithResolver(ctx, host, preferResolver)
2022-04-27 21:37:20 +08:00
}
}
2022-08-28 13:41:19 +08:00
if err != nil {
return nil, "-1", fmt.Errorf("dns resolve failed: %w", err)
2022-08-28 13:41:19 +08:00
}
for i, ip := range ips {
if ip.Is4In6() {
ips[i] = ip.Unmap()
}
}
return ips, port, nil
2022-08-28 13:41:19 +08:00
}
// sortationAddr partitions ips by family, preserving their relative order.
// Only plain IPv4 addresses (Is4) go to ipv4s; everything else, including any
// IPv4-mapped IPv6 address, goes to ipv6s. parseAddr unmaps those beforehand.
// Both results are nil when there is nothing to put in them.
func sortationAddr(ips []netip.Addr) (ipv4s, ipv6s []netip.Addr) {
	for i := range ips {
		switch addr := ips[i]; {
		case addr.Is4():
			ipv4s = append(ipv4s, addr)
		default:
			ipv6s = append(ipv6s, addr)
		}
	}
	return ipv4s, ipv6s
}
2022-12-19 21:34:07 +08:00
2022-12-22 09:53:11 +08:00
type Dialer struct {
2023-03-07 09:30:51 +08:00
opt option
2022-12-19 21:34:07 +08:00
}
2022-12-22 09:53:11 +08:00
func (d Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
2023-03-07 09:30:51 +08:00
return DialContext(ctx, network, address, WithOption(d.opt))
2022-12-19 21:34:07 +08:00
}
2022-12-22 09:53:11 +08:00
func (d Dialer) ListenPacket(ctx context.Context, network, address string, rAddrPort netip.AddrPort) (net.PacketConn, error) {
2023-03-07 09:30:51 +08:00
opt := WithOption(d.opt)
if rAddrPort.Addr().Unmap().IsLoopback() {
// avoid "The requested address is not valid in its context."
opt = WithInterface("")
}
return ListenPacket(ctx, ParseNetwork(network, rAddrPort.Addr()), address, opt)
2022-12-20 00:11:02 +08:00
}
2022-12-22 09:53:11 +08:00
func NewDialer(options ...Option) Dialer {
opt := applyOptions(options...)
2023-03-07 09:30:51 +08:00
return Dialer{opt: *opt}
2022-12-19 21:34:07 +08:00
}