diff --git a/adapter/outbound/reject.go b/adapter/outbound/reject.go index 43833238..d5a9c823 100644 --- a/adapter/outbound/reject.go +++ b/adapter/outbound/reject.go @@ -53,6 +53,9 @@ func (rw *nopConn) Read(b []byte) (int, error) { } func (rw *nopConn) Write(b []byte) (int, error) { + if len(b) == 0 { + return 0, nil + } return 0, io.EOF } diff --git a/adapter/outbound/shadowsocks.go b/adapter/outbound/shadowsocks.go index ae404eec..2ac1f234 100644 --- a/adapter/outbound/shadowsocks.go +++ b/adapter/outbound/shadowsocks.go @@ -103,9 +103,9 @@ func (ss *ShadowSocks) streamConn(c net.Conn, metadata *C.Metadata) (net.Conn, e } } if metadata.NetWork == C.UDP && ss.option.UDPOverTCP { - return ss.method.DialConn(c, M.ParseSocksaddr(uot.UOTMagicAddress+":443")) + return ss.method.DialEarlyConn(c, M.ParseSocksaddr(uot.UOTMagicAddress+":443")), nil } - return ss.method.DialConn(c, M.ParseSocksaddr(metadata.RemoteAddress())) + return ss.method.DialEarlyConn(c, M.ParseSocksaddr(metadata.RemoteAddress())), nil } // DialContext implements C.ProxyAdapter diff --git a/adapter/outbound/vmess.go b/adapter/outbound/vmess.go index 5da8c8b1..e8220767 100644 --- a/adapter/outbound/vmess.go +++ b/adapter/outbound/vmess.go @@ -213,12 +213,12 @@ func (v *Vmess) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { } if metadata.NetWork == C.UDP { if v.option.XUDP { - return v.client.DialXUDPPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress())) + return v.client.DialEarlyXUDPPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress())), nil } else { - return v.client.DialPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress())) + return v.client.DialEarlyPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress())), nil } } else { - return v.client.DialConn(c, M.ParseSocksaddr(metadata.RemoteAddress())) + return v.client.DialEarlyConn(c, M.ParseSocksaddr(metadata.RemoteAddress())), nil } } @@ -289,9 +289,9 @@ func (v *Vmess) ListenPacketContext(ctx context.Context, metadata *C.Metadata, o }(c) if v.option.XUDP { - c, err = v.client.DialXUDPPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress())) + c = v.client.DialEarlyXUDPPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress())) } else { - c, err = v.client.DialPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress())) + c = v.client.DialEarlyPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress())) } if err != nil { diff --git a/adapter/outboundgroup/fallback.go b/adapter/outboundgroup/fallback.go index 34365d0e..066e8a37 100644 --- a/adapter/outboundgroup/fallback.go +++ b/adapter/outboundgroup/fallback.go @@ -7,6 +7,7 @@ import ( "time" "github.com/Dreamacro/clash/adapter/outbound" + "github.com/Dreamacro/clash/common/callback" "github.com/Dreamacro/clash/component/dialer" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/constant/provider" @@ -30,11 +31,21 @@ func (f *Fallback) DialContext(ctx context.Context, metadata *C.Metadata, opts . c, err := proxy.DialContext(ctx, metadata, f.Base.DialOptions(opts...)...) if err == nil { c.AppendToChains(f) - f.onDialSuccess() } else { f.onDialFailed(proxy.Type(), err) } + c = &callback.FirstWriteCallBackConn{ + Conn: c, + Callback: func(err error) { + if err == nil { + f.onDialSuccess() + } else { + f.onDialFailed(proxy.Type(), err) + } + }, + } + return c, err } diff --git a/adapter/outboundgroup/loadbalance.go b/adapter/outboundgroup/loadbalance.go index 48bd4994..9a010cf9 100644 --- a/adapter/outboundgroup/loadbalance.go +++ b/adapter/outboundgroup/loadbalance.go @@ -10,6 +10,7 @@ import ( "github.com/Dreamacro/clash/adapter/outbound" "github.com/Dreamacro/clash/common/cache" + "github.com/Dreamacro/clash/common/callback" "github.com/Dreamacro/clash/common/murmur3" "github.com/Dreamacro/clash/component/dialer" C "github.com/Dreamacro/clash/constant" @@ -83,17 +84,24 @@ func jumpHash(key uint64, buckets int32) int32 { // DialContext implements C.ProxyAdapter func (lb *LoadBalance) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (c C.Conn, err error) { proxy := lb.Unwrap(metadata, true) - - defer func() { - if err == nil { - c.AppendToChains(lb) - lb.onDialSuccess() - } else { - lb.onDialFailed(proxy.Type(), err) - } - }() - c, err = proxy.DialContext(ctx, metadata, lb.Base.DialOptions(opts...)...) + + if err == nil { + c.AppendToChains(lb) + } else { + lb.onDialFailed(proxy.Type(), err) + } + + c = &callback.FirstWriteCallBackConn{ + Conn: c, + Callback: func(err error) { + if err == nil { + lb.onDialSuccess() + } else { + lb.onDialFailed(proxy.Type(), err) + } + }, + } return } diff --git a/adapter/outboundgroup/urltest.go b/adapter/outboundgroup/urltest.go index 27cef9c6..31eaf4a4 100644 --- a/adapter/outboundgroup/urltest.go +++ b/adapter/outboundgroup/urltest.go @@ -6,6 +6,7 @@ import ( "time" "github.com/Dreamacro/clash/adapter/outbound" + "github.com/Dreamacro/clash/common/callback" "github.com/Dreamacro/clash/common/singledo" "github.com/Dreamacro/clash/component/dialer" C "github.com/Dreamacro/clash/constant" @@ -38,10 +39,20 @@ func (u *URLTest) DialContext(ctx context.Context, metadata *C.Metadata, opts .. c, err = proxy.DialContext(ctx, metadata, u.Base.DialOptions(opts...)...) if err == nil { c.AppendToChains(u) - u.onDialSuccess() } else { u.onDialFailed(proxy.Type(), err) } + + c = &callback.FirstWriteCallBackConn{ + Conn: c, + Callback: func(err error) { + if err == nil { + u.onDialSuccess() + } else { + u.onDialFailed(proxy.Type(), err) + } + }, + } return c, err } diff --git a/common/callback/callback.go b/common/callback/callback.go new file mode 100644 index 00000000..a0f1e717 --- /dev/null +++ b/common/callback/callback.go @@ -0,0 +1,25 @@ +package callback + +import ( + C "github.com/Dreamacro/clash/constant" +) + +type FirstWriteCallBackConn struct { + C.Conn + Callback func(error) + written bool +} + +func (c *FirstWriteCallBackConn) Write(b []byte) (n int, err error) { + defer func() { + if !c.written { + c.written = true + c.Callback(err) + } + }() + return c.Conn.Write(b) +} + +func (c *FirstWriteCallBackConn) Upstream() any { + return c.Conn +} diff --git a/common/net/bufconn.go b/common/net/bufconn.go index ba0ca026..54326cf9 100644 --- a/common/net/bufconn.go +++ b/common/net/bufconn.go @@ -12,13 +12,14 @@ var _ ExtendedConn = (*BufferedConn)(nil) type BufferedConn struct { r *bufio.Reader ExtendedConn + peeked bool } func NewBufferedConn(c net.Conn) *BufferedConn { if bc, ok := c.(*BufferedConn); ok { return bc } - return &BufferedConn{bufio.NewReader(c), NewExtendedConn(c)} + return &BufferedConn{bufio.NewReader(c), NewExtendedConn(c), false} } // Reader returns the internal bufio.Reader. @@ -26,11 +27,20 @@ func (c *BufferedConn) Reader() *bufio.Reader { return c.r } +func (c *BufferedConn) Peeked() bool { + return c.peeked +} + // Peek returns the next n bytes without advancing the reader. func (c *BufferedConn) Peek(n int) ([]byte, error) { + c.peeked = true return c.r.Peek(n) } +func (c *BufferedConn) Discard(n int) (discarded int, err error) { + return c.r.Discard(n) +} + func (c *BufferedConn) Read(p []byte) (int, error) { return c.r.Read(p) } diff --git a/component/sniffer/dispatcher.go b/component/sniffer/dispatcher.go index f4511b97..97d448ce 100644 --- a/component/sniffer/dispatcher.go +++ b/component/sniffer/dispatcher.go @@ -36,12 +36,7 @@ type SnifferDispatcher struct { parsePureIp bool } -func (sd *SnifferDispatcher) TCPSniff(conn net.Conn, metadata *C.Metadata) { - bufConn, ok := conn.(*N.BufferedConn) - if !ok { - return - } - +func (sd *SnifferDispatcher) TCPSniff(conn *N.BufferedConn, metadata *C.Metadata) { if (metadata.Host == "" && sd.parsePureIp) || sd.forceDomain.Search(metadata.Host) != nil || (metadata.DNSMode == C.DNSMapping && sd.forceDnsMapping) { port, err := strconv.ParseUint(metadata.DstPort, 10, 16) if err != nil { @@ -74,7 +69,7 @@ func (sd *SnifferDispatcher) TCPSniff(conn net.Conn, metadata *C.Metadata) { } sd.rwMux.RUnlock() - if host, err := sd.sniffDomain(bufConn, metadata); err != nil { + if host, err := sd.sniffDomain(conn, metadata); err != nil { sd.cacheSniffFailed(metadata) log.Debugln("[Sniffer] All sniffing sniff failed with from [%s:%s] to [%s:%s]", metadata.SrcIP, metadata.SrcPort, metadata.String(), metadata.DstPort) return diff --git a/constant/context.go b/constant/context.go index e641ed14..da1e4155 100644 --- a/constant/context.go +++ b/constant/context.go @@ -3,6 +3,8 @@ package constant import ( "net" + N "github.com/Dreamacro/clash/common/net" + "github.com/gofrs/uuid" ) @@ -13,7 +15,7 @@ type PlainContext interface { type ConnContext interface { PlainContext Metadata() *Metadata - Conn() net.Conn + Conn() *N.BufferedConn } type PacketConnContext interface { diff --git a/context/conn.go b/context/conn.go index 08bbe3c7..b695ac4d 100644 --- a/context/conn.go +++ b/context/conn.go @@ -12,7 +12,7 @@ import ( type ConnContext struct { id uuid.UUID metadata *C.Metadata - conn net.Conn + conn *N.BufferedConn } func NewConnContext(conn net.Conn, metadata *C.Metadata) *ConnContext { @@ -36,6 +36,6 @@ func (c *ConnContext) Metadata() *C.Metadata { } // Conn implement C.ConnContext Conn -func (c *ConnContext) Conn() net.Conn { +func (c *ConnContext) Conn() *N.BufferedConn { return c.conn } diff --git a/transport/vless/conn.go b/transport/vless/conn.go index aceda463..e063d465 100644 --- a/transport/vless/conn.go +++ b/transport/vless/conn.go @@ -7,7 +7,6 @@ import ( "io" "net" "sync" - "time" "github.com/Dreamacro/clash/common/buf" N "github.com/Dreamacro/clash/common/net" @@ -208,12 +207,12 @@ func newConn(conn net.Conn, client *Client, dst *DstAddr) (*Conn, error) { } } - go func() { - select { - case <-c.handshake: - case <-time.After(200 * time.Millisecond): - c.sendRequest(nil) - } - }() + //go func() { + // select { + // case <-c.handshake: + // case <-time.After(200 * time.Millisecond): + // c.sendRequest(nil) + // } + //}() return c, nil } diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 695f2945..b9d0e594 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -366,8 +366,20 @@ func handleTCPConn(connCtx C.ConnContext) { return } + conn := connCtx.Conn() if sniffer.Dispatcher.Enable() && sniffingEnable { - sniffer.Dispatcher.TCPSniff(connCtx.Conn(), metadata) + sniffer.Dispatcher.TCPSniff(conn, metadata) + } + + peekMutex := sync.Mutex{} + if !conn.Peeked() { + peekMutex.Lock() + go func() { + defer peekMutex.Unlock() + _ = conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond)) + _, _ = conn.Peek(1) + _ = conn.SetReadDeadline(time.Time{}) + }() } proxy, rule, err := resolveMetadata(connCtx, metadata) @@ -387,10 +399,26 @@ func handleTCPConn(connCtx C.ConnContext) { } } + var peekBytes []byte + ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTCPTimeout) defer cancel() remoteConn, err := retry(ctx, func(ctx context.Context) (C.Conn, error) { - return proxy.DialContext(ctx, dialMetadata) + remoteConn, err := proxy.DialContext(ctx, dialMetadata) + if err != nil { + return nil, err + } + peekMutex.Lock() + defer peekMutex.Unlock() + peekBytes, _ = conn.Peek(conn.Buffered()) + _, err = remoteConn.Write(peekBytes) + if err != nil { + return nil, err + } + if peekLen := len(peekBytes); peekLen > 0 { + _, _ = conn.Discard(peekLen) + } + return remoteConn, err }, func(err error) { if rule == nil { log.Warnln(