From 7c4a359a2b611175542004fb46c81a1489e75dc6 Mon Sep 17 00:00:00 2001 From: Dreamacro <305009791@qq.com> Date: Sat, 12 Oct 2019 23:55:39 +0800 Subject: [PATCH] Fix: dial tcp with context to avoid margin of error --- adapters/outbound/base.go | 10 ++++++++-- adapters/outbound/direct.go | 5 +++-- adapters/outbound/fallback.go | 4 ++-- adapters/outbound/http.go | 5 +++-- adapters/outbound/loadbalance.go | 6 +++--- adapters/outbound/reject.go | 3 ++- adapters/outbound/selector.go | 5 +++-- adapters/outbound/shadowsocks.go | 5 +++-- adapters/outbound/snell.go | 5 +++-- adapters/outbound/socks5.go | 9 ++++++--- adapters/outbound/urltest.go | 4 ++-- adapters/outbound/util.go | 4 +--- adapters/outbound/vmess.go | 9 ++++++--- constant/adapters.go | 3 ++- 14 files changed, 47 insertions(+), 30 deletions(-) diff --git a/adapters/outbound/base.go b/adapters/outbound/base.go index b560fc71..505489b4 100644 --- a/adapters/outbound/base.go +++ b/adapters/outbound/base.go @@ -91,7 +91,13 @@ func (p *Proxy) Alive() bool { } func (p *Proxy) Dial(metadata *C.Metadata) (C.Conn, error) { - conn, err := p.ProxyAdapter.Dial(metadata) + ctx, cancel := context.WithTimeout(context.Background(), tcpTimeout) + defer cancel() + return p.DialContext(ctx, metadata) +} + +func (p *Proxy) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { + conn, err := p.ProxyAdapter.DialContext(ctx, metadata) if err != nil { p.alive = false } @@ -157,7 +163,7 @@ func (p *Proxy) URLTest(ctx context.Context, url string) (t uint16, err error) { } start := time.Now() - instance, err := p.Dial(&addr) + instance, err := p.DialContext(ctx, &addr) if err != nil { return } diff --git a/adapters/outbound/direct.go b/adapters/outbound/direct.go index 2b0a2a47..22a4171c 100644 --- a/adapters/outbound/direct.go +++ b/adapters/outbound/direct.go @@ -1,6 +1,7 @@ package adapters import ( + "context" "net" C "github.com/Dreamacro/clash/constant" @@ -10,13 +11,13 @@ type Direct struct { *Base } -func (d *Direct) Dial(metadata *C.Metadata) (C.Conn, error) { +func (d *Direct) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { address := net.JoinHostPort(metadata.Host, metadata.DstPort) if metadata.DstIP != nil { address = net.JoinHostPort(metadata.DstIP.String(), metadata.DstPort) } - c, err := dialTimeout("tcp", address, tcpTimeout) + c, err := dialContext(ctx, "tcp", address) if err != nil { return nil, err } diff --git a/adapters/outbound/fallback.go b/adapters/outbound/fallback.go index 9b44edeb..3e43e63c 100644 --- a/adapters/outbound/fallback.go +++ b/adapters/outbound/fallback.go @@ -31,9 +31,9 @@ func (f *Fallback) Now() string { return proxy.Name() } -func (f *Fallback) Dial(metadata *C.Metadata) (C.Conn, error) { +func (f *Fallback) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { proxy := f.findAliveProxy() - c, err := proxy.Dial(metadata) + c, err := proxy.DialContext(ctx, metadata) if err == nil { c.AppendToChains(f) } diff --git a/adapters/outbound/http.go b/adapters/outbound/http.go index 357ed5df..77b835b6 100644 --- a/adapters/outbound/http.go +++ b/adapters/outbound/http.go @@ -3,6 +3,7 @@ package adapters import ( "bufio" "bytes" + "context" "crypto/tls" "encoding/base64" "errors" @@ -35,8 +36,8 @@ type HttpOption struct { SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"` } -func (h *Http) Dial(metadata *C.Metadata) (C.Conn, error) { - c, err := dialTimeout("tcp", h.addr, tcpTimeout) +func (h *Http) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { + c, err := dialContext(ctx, "tcp", h.addr) if err == nil && h.tls { cc := tls.Client(c, h.tlsConfig) err = cc.Handshake() diff --git a/adapters/outbound/loadbalance.go b/adapters/outbound/loadbalance.go index c719e8b8..d27beece 100644 --- a/adapters/outbound/loadbalance.go +++ b/adapters/outbound/loadbalance.go @@ -54,7 +54,7 @@ func jumpHash(key uint64, buckets int32) int32 { return int32(b) } -func (lb *LoadBalance) Dial(metadata *C.Metadata) (c C.Conn, err error) { +func (lb *LoadBalance) DialContext(ctx context.Context, metadata *C.Metadata) (c C.Conn, err error) { defer func() { if err == nil { c.AppendToChains(lb) @@ -67,11 +67,11 @@ func (lb *LoadBalance) Dial(metadata *C.Metadata) (c C.Conn, err error) { idx := jumpHash(key, buckets) proxy := lb.proxies[idx] if proxy.Alive() { - c, err = proxy.Dial(metadata) + c, err = proxy.DialContext(ctx, metadata) return } } - c, err = lb.proxies[0].Dial(metadata) + c, err = lb.proxies[0].DialContext(ctx, metadata) return } diff --git a/adapters/outbound/reject.go b/adapters/outbound/reject.go index de395d58..65ab1192 100644 --- a/adapters/outbound/reject.go +++ b/adapters/outbound/reject.go @@ -1,6 +1,7 @@ package adapters import ( + "context" "io" "net" "time" @@ -12,7 +13,7 @@ type Reject struct { *Base } -func (r *Reject) Dial(metadata *C.Metadata) (C.Conn, error) { +func (r *Reject) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { return newConn(&NopConn{}, r), nil } diff --git a/adapters/outbound/selector.go b/adapters/outbound/selector.go index 31d5a0a5..b7ed661c 100644 --- a/adapters/outbound/selector.go +++ b/adapters/outbound/selector.go @@ -1,6 +1,7 @@ package adapters import ( + "context" "encoding/json" "errors" "net" @@ -20,8 +21,8 @@ type SelectorOption struct { Proxies []string `proxy:"proxies"` } -func (s *Selector) Dial(metadata *C.Metadata) (C.Conn, error) { - c, err := s.selected.Dial(metadata) +func (s *Selector) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { + c, err := s.selected.DialContext(ctx, metadata) if err == nil { c.AppendToChains(s) } diff --git a/adapters/outbound/shadowsocks.go b/adapters/outbound/shadowsocks.go index c46f8fc9..22d160ba 100644 --- a/adapters/outbound/shadowsocks.go +++ b/adapters/outbound/shadowsocks.go @@ -1,6 +1,7 @@ package adapters import ( + "context" "crypto/tls" "encoding/json" "fmt" @@ -57,8 +58,8 @@ type v2rayObfsOption struct { Mux bool `obfs:"mux,omitempty"` } -func (ss *ShadowSocks) Dial(metadata *C.Metadata) (C.Conn, error) { - c, err := dialTimeout("tcp", ss.server, tcpTimeout) +func (ss *ShadowSocks) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { + c, err := dialContext(ctx, "tcp", ss.server) if err != nil { return nil, fmt.Errorf("%s connect error: %s", ss.server, err.Error()) } diff --git a/adapters/outbound/snell.go b/adapters/outbound/snell.go index b4131199..6b95aace 100644 --- a/adapters/outbound/snell.go +++ b/adapters/outbound/snell.go @@ -1,6 +1,7 @@ package adapters import ( + "context" "fmt" "net" "strconv" @@ -26,8 +27,8 @@ type SnellOption struct { ObfsOpts map[string]interface{} `proxy:"obfs-opts,omitempty"` } -func (s *Snell) Dial(metadata *C.Metadata) (C.Conn, error) { - c, err := dialTimeout("tcp", s.server, tcpTimeout) +func (s *Snell) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { + c, err := dialContext(ctx, "tcp", s.server) if err != nil { return nil, fmt.Errorf("%s connect error: %s", s.server, err.Error()) } diff --git a/adapters/outbound/socks5.go b/adapters/outbound/socks5.go index d99daa93..9355cff1 100644 --- a/adapters/outbound/socks5.go +++ b/adapters/outbound/socks5.go @@ -1,6 +1,7 @@ package adapters import ( + "context" "crypto/tls" "fmt" "io" @@ -33,8 +34,8 @@ type Socks5Option struct { SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"` } -func (ss *Socks5) Dial(metadata *C.Metadata) (C.Conn, error) { - c, err := dialTimeout("tcp", ss.addr, tcpTimeout) +func (ss *Socks5) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { + c, err := dialContext(ctx, "tcp", ss.addr) if err == nil && ss.tls { cc := tls.Client(c, ss.tlsConfig) @@ -60,7 +61,9 @@ func (ss *Socks5) Dial(metadata *C.Metadata) (C.Conn, error) { } func (ss *Socks5) DialUDP(metadata *C.Metadata) (_ C.PacketConn, _ net.Addr, err error) { - c, err := dialTimeout("tcp", ss.addr, tcpTimeout) + ctx, cancel := context.WithTimeout(context.Background(), tcpTimeout) + defer cancel() + c, err := dialContext(ctx, "tcp", ss.addr) if err != nil { err = fmt.Errorf("%s connect error", ss.addr) return diff --git a/adapters/outbound/urltest.go b/adapters/outbound/urltest.go index 60b4beda..2bdb872d 100644 --- a/adapters/outbound/urltest.go +++ b/adapters/outbound/urltest.go @@ -33,9 +33,9 @@ func (u *URLTest) Now() string { return u.fast.Name() } -func (u *URLTest) Dial(metadata *C.Metadata) (c C.Conn, err error) { +func (u *URLTest) DialContext(ctx context.Context, metadata *C.Metadata) (c C.Conn, err error) { for i := 0; i < 3; i++ { - c, err = u.fast.Dial(metadata) + c, err = u.fast.DialContext(ctx, metadata) if err == nil { c.AppendToChains(u) return diff --git a/adapters/outbound/util.go b/adapters/outbound/util.go index 46c4581a..22b2d953 100644 --- a/adapters/outbound/util.go +++ b/adapters/outbound/util.go @@ -86,15 +86,13 @@ func serializesSocksAddr(metadata *C.Metadata) []byte { return bytes.Join(buf, nil) } -func dialTimeout(network, address string, timeout time.Duration) (net.Conn, error) { +func dialContext(ctx context.Context, network, address string) (net.Conn, error) { host, port, err := net.SplitHostPort(address) if err != nil { return nil, err } dialer := net.Dialer{} - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() returned := make(chan struct{}) defer close(returned) diff --git a/adapters/outbound/vmess.go b/adapters/outbound/vmess.go index 5b5337a0..d61172e5 100644 --- a/adapters/outbound/vmess.go +++ b/adapters/outbound/vmess.go @@ -1,6 +1,7 @@ package adapters import ( + "context" "fmt" "net" "strconv" @@ -31,8 +32,8 @@ type VmessOption struct { SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"` } -func (v *Vmess) Dial(metadata *C.Metadata) (C.Conn, error) { - c, err := dialTimeout("tcp", v.server, tcpTimeout) +func (v *Vmess) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { + c, err := dialContext(ctx, "tcp", v.server) if err != nil { return nil, fmt.Errorf("%s connect error", v.server) } @@ -42,7 +43,9 @@ func (v *Vmess) Dial(metadata *C.Metadata) (C.Conn, error) { } func (v *Vmess) DialUDP(metadata *C.Metadata) (C.PacketConn, net.Addr, error) { - c, err := dialTimeout("tcp", v.server, tcpTimeout) + ctx, cancel := context.WithTimeout(context.Background(), tcpTimeout) + defer cancel() + c, err := dialContext(ctx, "tcp", v.server) if err != nil { return nil, nil, fmt.Errorf("%s connect error", v.server) } diff --git a/constant/adapters.go b/constant/adapters.go index 2e155ac8..97d65a50 100644 --- a/constant/adapters.go +++ b/constant/adapters.go @@ -58,7 +58,7 @@ type PacketConn interface { type ProxyAdapter interface { Name() string Type() AdapterType - Dial(metadata *Metadata) (Conn, error) + DialContext(ctx context.Context, metadata *Metadata) (Conn, error) DialUDP(metadata *Metadata) (PacketConn, net.Addr, error) SupportUDP() bool Destroy() @@ -74,6 +74,7 @@ type Proxy interface { ProxyAdapter Alive() bool DelayHistory() []DelayHistory + Dial(metadata *Metadata) (Conn, error) LastDelay() uint16 URLTest(ctx context.Context, url string) (uint16, error) }