mihomo/transport/tuic/v4/client.go

484 lines
11 KiB
Go
Raw Normal View History

2023-06-12 17:44:22 +08:00
package v4
2022-11-25 08:08:14 +08:00
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"errors"
"net"
2022-11-26 23:53:59 +08:00
"runtime"
2022-11-25 08:08:14 +08:00
"sync"
2022-11-25 18:32:30 +08:00
"sync/atomic"
2022-11-25 08:08:14 +08:00
"time"
2023-05-10 09:36:06 +08:00
"unsafe"
2022-11-25 08:08:14 +08:00
2023-11-03 21:01:45 +08:00
atomic2 "github.com/metacubex/mihomo/common/atomic"
"github.com/metacubex/mihomo/common/buf"
N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/common/pool"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/log"
"github.com/metacubex/mihomo/transport/tuic/common"
"github.com/metacubex/quic-go"
2023-11-23 10:24:01 +08:00
"github.com/puzpuzpuz/xsync/v3"
"github.com/zhangyunhao116/fastrand"
2022-11-25 08:08:14 +08:00
)
type ClientOption struct {
TlsConfig *tls.Config
QuicConfig *quic.Config
Token [32]byte
UdpRelayMode common.UdpRelayMode
CongestionController string
ReduceRtt bool
2022-11-28 17:09:25 +08:00
RequestTimeout time.Duration
MaxUdpRelayPacketSize int
2022-11-26 21:14:56 +08:00
FastOpen bool
MaxOpenStreams int64
2023-06-18 00:47:26 +08:00
CWND int
2022-11-26 23:53:59 +08:00
}
type clientImpl struct {
2022-11-26 23:53:59 +08:00
*ClientOption
udp bool
2022-11-25 08:08:14 +08:00
quicConn quic.Connection
connMutex sync.Mutex
openStreams atomic.Int64
2022-11-28 17:09:25 +08:00
closed atomic.Bool
2022-11-25 18:32:30 +08:00
udpInputMap *xsync.MapOf[uint32, net.Conn]
2022-11-26 23:53:59 +08:00
// only ready for PoolClient
2022-12-22 09:53:11 +08:00
dialerRef C.Dialer
2023-06-12 17:44:22 +08:00
lastVisited atomic2.TypedValue[time.Time]
}
func (t *clientImpl) OpenStreams() int64 {
return t.openStreams.Load()
}
func (t *clientImpl) DialerRef() C.Dialer {
return t.dialerRef
}
func (t *clientImpl) LastVisited() time.Time {
return t.lastVisited.Load()
2022-11-25 08:08:14 +08:00
}
2023-06-12 17:44:22 +08:00
func (t *clientImpl) SetLastVisited(last time.Time) {
t.lastVisited.Store(last)
}
// getQuicConn returns the cached QUIC connection, dialling a new one through
// dialFn when none exists. connMutex is held for the whole call, so
// concurrent callers share a single dial.
//
// A freshly dialled connection gets the configured congestion controller.
// Authentication is sent on a background goroutine. When udp is enabled, a
// second background goroutine reads relayed UDP packets: uni-streams in
// common.QUIC mode, datagrams otherwise. openStreams is reset to 0 for the
// new connection.
func (t *clientImpl) getQuicConn(ctx context.Context, dialer C.Dialer, dialFn common.DialFunc) (quic.Connection, error) {
	t.connMutex.Lock()
	defer t.connMutex.Unlock()
	if t.quicConn != nil {
		return t.quicConn, nil
	}
	transport, addr, err := dialFn(ctx, dialer)
	if err != nil {
		return nil, err
	}
	var quicConn quic.Connection
	if t.ReduceRtt {
		quicConn, err = transport.DialEarly(ctx, addr, t.TlsConfig, t.QuicConfig)
	} else {
		quicConn, err = transport.Dial(ctx, addr, t.TlsConfig, t.QuicConfig)
	}
	if err != nil {
		// transport is only referenced from this function, so nothing else
		// will ever close it; release it (and its socket) instead of leaking
		// it on a failed handshake. Assumes dialFn hands out a transport per
		// call rather than a shared one — confirm against callers.
		_ = transport.Close()
		return nil, err
	}
	common.SetCongestionController(quicConn, t.CongestionController, t.CWND)
	go func() {
		_ = t.sendAuthentication(quicConn)
	}()
	if t.udp {
		go func() {
			switch t.UdpRelayMode {
			case common.QUIC:
				_ = t.handleUniStream(quicConn)
			default: // native
				_ = t.handleMessage(quicConn)
			}
		}()
	}
	t.quicConn = quicConn
	t.openStreams.Store(0)
	return quicConn, nil
}
func (t *clientImpl) sendAuthentication(quicConn quic.Connection) (err error) {
2022-11-28 17:09:25 +08:00
defer func() {
t.deferQuicConn(quicConn, err)
}()
stream, err := quicConn.OpenUniStream()
if err != nil {
return err
}
buf := pool.GetBuffer()
defer pool.PutBuffer(buf)
err = NewAuthenticate(t.Token).WriteTo(buf)
if err != nil {
return err
}
_, err = buf.WriteTo(stream)
if err != nil {
return err
}
err = stream.Close()
if err != nil {
return
}
return nil
}
// handleUniStream accepts server-initiated unidirectional streams on quicConn
// and decodes each one on its own goroutine. It returns only when
// AcceptUniStream fails; that error goes to deferQuicConn.
func (t *clientImpl) handleUniStream(quicConn quic.Connection) (err error) {
	defer func() {
		t.deferQuicConn(quicConn, err)
	}()

	// serve reads one command from stream. In QUIC relay mode, a Packet is
	// re-encoded into the pipe registered for its ASSOC_ID.
	serve := func(stream quic.ReceiveStream) (serveErr error) {
		var assocID uint32
		defer func() {
			t.deferQuicConn(quicConn, serveErr)
			if serveErr != nil && assocID != 0 {
				if conn, found := t.udpInputMap.LoadAndDelete(assocID); found && conn != nil {
					_ = conn.Close()
				}
			}
			stream.CancelRead(0)
		}()

		reader := bufio.NewReader(stream)
		head, serveErr := ReadCommandHead(reader)
		if serveErr != nil || head.TYPE != PacketType {
			return serveErr
		}
		packet, serveErr := ReadPacketWithHead(head, reader)
		if serveErr != nil {
			return serveErr
		}
		if !t.udp || t.UdpRelayMode != common.QUIC {
			return nil
		}
		assocID = packet.ASSOC_ID
		if conn, found := t.udpInputMap.Load(assocID); found && conn != nil {
			w := bufio.NewWriterSize(conn, packet.BytesLen())
			_ = packet.WriteTo(w)
			_ = w.Flush()
		}
		return nil
	}

	for {
		stream, acceptErr := quicConn.AcceptUniStream(context.Background())
		if acceptErr != nil {
			return acceptErr
		}
		go serve(stream)
	}
}
// handleMessage receives QUIC datagrams on quicConn and decodes each one on
// its own goroutine. It returns only when ReceiveMessage fails; that error
// goes to deferQuicConn.
func (t *clientImpl) handleMessage(quicConn quic.Connection) (err error) {
	defer func() {
		t.deferQuicConn(quicConn, err)
	}()

	// serve decodes one datagram. In native relay mode, a Packet's raw bytes
	// (header included) are forwarded to the pipe registered for its ASSOC_ID.
	serve := func(message []byte) (serveErr error) {
		var assocID uint32
		defer func() {
			t.deferQuicConn(quicConn, serveErr)
			if serveErr != nil && assocID != 0 {
				if conn, found := t.udpInputMap.LoadAndDelete(assocID); found && conn != nil {
					_ = conn.Close()
				}
			}
		}()

		reader := bytes.NewBuffer(message)
		head, serveErr := ReadCommandHead(reader)
		if serveErr != nil || head.TYPE != PacketType {
			return serveErr
		}
		packet, serveErr := ReadPacketWithHead(head, reader)
		if serveErr != nil {
			return serveErr
		}
		if !t.udp || t.UdpRelayMode != common.NATIVE {
			return nil
		}
		assocID = packet.ASSOC_ID
		if conn, found := t.udpInputMap.Load(assocID); found && conn != nil {
			_, _ = conn.Write(message)
		}
		return nil
	}

	for {
		message, recvErr := quicConn.ReceiveMessage(context.Background())
		if recvErr != nil {
			return recvErr
		}
		go serve(message)
	}
}
// deferQuicConn force-closes quicConn when err wraps a net.Error. Any other
// error, or nil, is ignored.
func (t *clientImpl) deferQuicConn(quicConn quic.Connection, err error) {
	if err == nil {
		return
	}
	var netErr net.Error
	if !errors.As(err, &netErr) {
		return
	}
	t.forceClose(quicConn, err)
}
// forceClose closes quicConn (or the cached connection when quicConn is nil)
// with ProtocolError and err's text as the reason. If that connection is the
// cached one, the cache is cleared. Every registered UDP input pipe is then
// closed and removed. Runs under connMutex.
func (t *clientImpl) forceClose(quicConn quic.Connection, err error) {
	t.connMutex.Lock()
	defer t.connMutex.Unlock()

	target := quicConn
	if target == nil {
		target = t.quicConn
	}
	if target != nil && target == t.quicConn {
		t.quicConn = nil
	}

	var reason string
	if err != nil {
		reason = err.Error()
	}
	if target != nil {
		_ = target.CloseWithError(ProtocolError, reason)
	}

	inputs := t.udpInputMap
	inputs.Range(func(id uint32, conn net.Conn) bool {
		_ = conn.Close()
		inputs.Delete(id)
		return true
	})
}
func (t *clientImpl) Close() {
2022-11-28 17:09:25 +08:00
t.closed.Store(true)
if t.openStreams.Load() == 0 {
2023-06-12 17:44:22 +08:00
t.forceClose(nil, common.ClientClosed)
2022-11-28 17:09:25 +08:00
}
}
2023-06-12 17:44:22 +08:00
func (t *clientImpl) DialContextWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn common.DialFunc) (net.Conn, error) {
2022-12-22 09:53:11 +08:00
quicConn, err := t.getQuicConn(ctx, dialer, dialFn)
2022-11-25 08:08:14 +08:00
if err != nil {
return nil, err
}
2022-11-26 23:53:59 +08:00
openStreams := t.openStreams.Add(1)
if openStreams >= t.MaxOpenStreams {
2022-11-26 23:53:59 +08:00
t.openStreams.Add(-1)
2023-06-12 17:44:22 +08:00
return nil, common.TooManyOpenStreams
2022-11-25 18:32:30 +08:00
}
2023-06-12 17:44:22 +08:00
stream, err := func() (stream net.Conn, err error) {
defer func() {
t.deferQuicConn(quicConn, err)
}()
buf := pool.GetBuffer()
defer pool.PutBuffer(buf)
err = NewConnect(NewAddress(metadata)).WriteTo(buf)
if err != nil {
return nil, err
}
2022-11-25 17:15:45 +08:00
quicStream, err := quicConn.OpenStream()
if err != nil {
return nil, err
}
2023-06-12 17:44:22 +08:00
stream = common.NewQuicStreamConn(
quicStream,
quicConn.LocalAddr(),
quicConn.RemoteAddr(),
func() {
2022-11-28 17:09:25 +08:00
time.AfterFunc(C.DefaultTCPTimeout, func() {
openStreams := t.openStreams.Add(-1)
if openStreams == 0 && t.closed.Load() {
2023-06-12 17:44:22 +08:00
t.forceClose(quicConn, common.ClientClosed)
2022-11-28 17:09:25 +08:00
}
})
},
2023-06-12 17:44:22 +08:00
)
_, err = buf.WriteTo(stream)
if err != nil {
_ = stream.Close()
return nil, err
}
return stream, err
2022-11-25 08:08:14 +08:00
}()
2022-11-25 12:10:33 +08:00
if err != nil {
return nil, err
}
2023-05-10 09:36:06 +08:00
bufConn := N.NewBufferedConn(stream)
conn := &earlyConn{ExtendedConn: bufConn, bufConn: bufConn, RequestTimeout: t.RequestTimeout}
2022-11-26 21:14:56 +08:00
if !t.FastOpen {
err = conn.Response()
if err != nil {
return nil, err
}
}
return conn, nil
}
type earlyConn struct {
2023-05-10 09:36:06 +08:00
N.ExtendedConn // only expose standard N.ExtendedConn function to outside
bufConn *N.BufferedConn
resOnce sync.Once
resErr error
2022-11-27 16:38:41 +08:00
2022-11-28 17:09:25 +08:00
RequestTimeout time.Duration
2022-11-26 21:14:56 +08:00
}
func (conn *earlyConn) response() error {
2022-11-27 16:38:41 +08:00
if conn.RequestTimeout > 0 {
2022-11-28 17:09:25 +08:00
_ = conn.SetReadDeadline(time.Now().Add(conn.RequestTimeout))
2022-11-27 16:38:41 +08:00
}
2023-05-10 09:36:06 +08:00
response, err := ReadResponse(conn.bufConn)
2022-11-25 08:08:14 +08:00
if err != nil {
2022-11-25 16:06:56 +08:00
_ = conn.Close()
2022-11-26 21:14:56 +08:00
return err
2022-11-25 08:08:14 +08:00
}
if response.IsFailed() {
2022-11-25 16:06:56 +08:00
_ = conn.Close()
2022-11-26 21:14:56 +08:00
return errors.New("connect failed")
2022-11-25 08:08:14 +08:00
}
2022-11-26 21:14:56 +08:00
_ = conn.SetReadDeadline(time.Time{})
return nil
}
func (conn *earlyConn) Response() error {
conn.resOnce.Do(func() {
conn.resErr = conn.response()
})
return conn.resErr
}
func (conn *earlyConn) Read(b []byte) (n int, err error) {
err = conn.Response()
if err != nil {
return 0, err
}
2023-05-10 09:36:06 +08:00
return conn.bufConn.Read(b)
2022-11-25 08:08:14 +08:00
}
2023-04-03 21:07:52 +08:00
func (conn *earlyConn) ReadBuffer(buffer *buf.Buffer) (err error) {
err = conn.Response()
if err != nil {
return err
}
2023-05-10 09:36:06 +08:00
return conn.bufConn.ReadBuffer(buffer)
}
func (conn *earlyConn) Upstream() any {
return conn.bufConn
}
func (conn *earlyConn) ReaderReplaceable() bool {
return atomic.LoadUint32((*uint32)(unsafe.Pointer(&conn.resOnce))) == 1 && conn.resErr == nil
}
func (conn *earlyConn) WriterReplaceable() bool {
return true
2023-04-03 21:07:52 +08:00
}
2023-06-12 17:44:22 +08:00
// ListenPacketWithDialer creates a UDP relay session over the shared QUIC
// connection. Inbound packets for the session are delivered through an
// in-memory pipe registered in udpInputMap under a fresh random id. The
// session occupies one openStreams slot, released C.DefaultUDPTimeout after
// the packet conn's close callback runs.
func (t *clientImpl) ListenPacketWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn common.DialFunc) (net.PacketConn, error) {
	quicConn, err := t.getQuicConn(ctx, dialer, dialFn)
	if err != nil {
		return nil, err
	}
	if t.openStreams.Add(1) >= t.MaxOpenStreams {
		t.openStreams.Add(-1)
		return nil, common.TooManyOpenStreams
	}

	inbound, reader := net.Pipe()

	// Draw random ids until one is not already registered.
	var connID uint32
	for stored := false; !stored; {
		connID = fastrand.Uint32()
		_, loaded := t.udpInputMap.LoadOrStore(connID, inbound)
		stored = !loaded
	}

	return &quicStreamPacketConn{
		connId:                connID,
		quicConn:              quicConn,
		inputConn:             N.NewBufferedConn(reader),
		udpRelayMode:          t.UdpRelayMode,
		maxUdpRelayPacketSize: t.MaxUdpRelayPacketSize,
		deferQuicConnFn:       t.deferQuicConn,
		closeDeferFn: func() {
			t.udpInputMap.Delete(connID)
			time.AfterFunc(C.DefaultUDPTimeout, func() {
				if t.openStreams.Add(-1) == 0 && t.closed.Load() {
					t.forceClose(quicConn, common.ClientClosed)
				}
			})
		},
	}, nil
}
type Client struct {
*clientImpl // use an independent pointer to let Finalizer can work no matter somewhere handle an influence in clientImpl inner
}
2023-06-12 17:44:22 +08:00
func (t *Client) DialContextWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn common.DialFunc) (net.Conn, error) {
2022-12-22 09:53:11 +08:00
conn, err := t.clientImpl.DialContextWithDialer(ctx, metadata, dialer, dialFn)
if err != nil {
return nil, err
}
return N.NewRefConn(conn, t), err
}
2023-06-12 17:44:22 +08:00
func (t *Client) ListenPacketWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn common.DialFunc) (net.PacketConn, error) {
2022-12-22 09:53:11 +08:00
pc, err := t.clientImpl.ListenPacketWithDialer(ctx, metadata, dialer, dialFn)
if err != nil {
return nil, err
}
return N.NewRefPacketConn(pc, t), nil
}
func (t *Client) forceClose() {
2023-06-12 17:44:22 +08:00
t.clientImpl.forceClose(nil, common.ClientClosed)
}
2023-06-12 17:44:22 +08:00
func NewClient(clientOption *ClientOption, udp bool, dialerRef C.Dialer) *Client {
ci := &clientImpl{
2022-11-26 23:53:59 +08:00
ClientOption: clientOption,
udp: udp,
2023-06-12 17:44:22 +08:00
dialerRef: dialerRef,
2023-11-23 10:24:01 +08:00
udpInputMap: xsync.NewMapOf[uint32, net.Conn](),
2022-11-26 23:53:59 +08:00
}
c := &Client{ci}
2022-11-26 23:53:59 +08:00
runtime.SetFinalizer(c, closeClient)
2023-06-12 17:44:22 +08:00
log.Debugln("New TuicV4 Client at %p", c)
2022-11-26 23:53:59 +08:00
return c
}
func closeClient(client *Client) {
2023-06-12 17:44:22 +08:00
log.Debugln("Close TuicV4 Client at %p", client)
client.forceClose()
2022-11-26 23:53:59 +08:00
}