From 4fa15c633494f6cf2fac2bef282667b4b0ee9db2 Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Thu, 26 Sep 2024 11:21:07 +0800 Subject: [PATCH] chore: ensures packets can be sent without blocking the tunnel --- component/nat/table.go | 44 +++------ constant/adapters.go | 26 +++-- tunnel/connection.go | 75 +++++++++++++- tunnel/tunnel.go | 220 +++++++++++++++++++++-------------------- 4 files changed, 218 insertions(+), 147 deletions(-) diff --git a/component/nat/table.go b/component/nat/table.go index bb5ab755..66241fb4 100644 --- a/component/nat/table.go +++ b/component/nat/table.go @@ -10,47 +10,30 @@ import ( ) type Table struct { - mapping *xsync.MapOf[string, *Entry] - lockMap *xsync.MapOf[string, *sync.Cond] + mapping *xsync.MapOf[string, *entry] } -type Entry struct { - PacketConn C.PacketConn - WriteBackProxy C.WriteBackProxy +type entry struct { + PacketSender C.PacketSender LocalUDPConnMap *xsync.MapOf[string, *net.UDPConn] LocalLockMap *xsync.MapOf[string, *sync.Cond] } -func (t *Table) Set(key string, e C.PacketConn, w C.WriteBackProxy) { - t.mapping.Store(key, &Entry{ - PacketConn: e, - WriteBackProxy: w, - LocalUDPConnMap: xsync.NewMapOf[string, *net.UDPConn](), - LocalLockMap: xsync.NewMapOf[string, *sync.Cond](), +func (t *Table) GetOrCreate(key string, maker func() C.PacketSender) (C.PacketSender, bool) { + item, loaded := t.mapping.LoadOrCompute(key, func() *entry { + return &entry{ + PacketSender: maker(), + LocalUDPConnMap: xsync.NewMapOf[string, *net.UDPConn](), + LocalLockMap: xsync.NewMapOf[string, *sync.Cond](), + } }) -} - -func (t *Table) Get(key string) (C.PacketConn, C.WriteBackProxy) { - entry, exist := t.getEntry(key) - if !exist { - return nil, nil - } - return entry.PacketConn, entry.WriteBackProxy -} - -func (t *Table) GetOrCreateLock(key string) (*sync.Cond, bool) { - item, loaded := t.lockMap.LoadOrCompute(key, makeLock) - return item, loaded + return item.PacketSender, loaded } func (t *Table) Delete(key string) { t.mapping.Delete(key) } -func (t *Table) DeleteLock(lockKey string) { - t.lockMap.Delete(lockKey) -} - func (t *Table) GetForLocalConn(lAddr, rAddr string) *net.UDPConn { entry, exist := t.getEntry(lAddr) if !exist { @@ -105,7 +88,7 @@ func (t *Table) DeleteLockForLocalConn(lAddr, key string) { entry.LocalLockMap.Delete(key) } -func (t *Table) getEntry(key string) (*Entry, bool) { +func (t *Table) getEntry(key string) (*entry, bool) { return t.mapping.Load(key) } @@ -116,7 +99,6 @@ func makeLock() *sync.Cond { // New return *Cache func New() *Table { return &Table{ - mapping: xsync.NewMapOf[string, *Entry](), - lockMap: xsync.NewMapOf[string, *sync.Cond](), + mapping: xsync.NewMapOf[string, *entry](), } } diff --git a/constant/adapters.go b/constant/adapters.go index cb213b3c..3731cd60 100644 --- a/constant/adapters.go +++ b/constant/adapters.go @@ -255,12 +255,16 @@ type UDPPacketInAddr interface { // PacketAdapter is a UDP Packet adapter for socks/redir/tun type PacketAdapter interface { UDPPacket + // Metadata returns destination metadata Metadata() *Metadata + // Key is a SNAT key + Key() string } type packetAdapter struct { UDPPacket metadata *Metadata + key string } // Metadata returns destination metadata @@ -268,10 +272,16 @@ func (s *packetAdapter) Metadata() *Metadata { return s.metadata } +// Key is a SNAT key +func (s *packetAdapter) Key() string { + return s.key +} + func NewPacketAdapter(packet UDPPacket, metadata *Metadata) PacketAdapter { return &packetAdapter{ packet, metadata, + packet.LocalAddr().String(), } } @@ -284,17 +294,19 @@ type WriteBackProxy interface { UpdateWriteBack(wb WriteBack) } +type PacketSender interface { + // Send will send PacketAdapter nonblocking + // the implement must call UDPPacket.Drop() inside Send + Send(PacketAdapter) + Process(PacketConn, WriteBackProxy) + Close() +} + type NatTable interface { - Set(key string, e PacketConn, w WriteBackProxy) - - Get(key string) (PacketConn, WriteBackProxy) - - GetOrCreateLock(key string) (*sync.Cond, bool) + GetOrCreate(key string, maker func() PacketSender) (PacketSender, bool) Delete(key string) - DeleteLock(key string) - GetForLocalConn(lAddr, rAddr string) *net.UDPConn AddForLocalConn(lAddr, rAddr string, conn *net.UDPConn) bool diff --git a/tunnel/connection.go b/tunnel/connection.go index e96545e8..17e4efd0 100644 --- a/tunnel/connection.go +++ b/tunnel/connection.go @@ -1,6 +1,7 @@ package tunnel import ( + "context" "errors" "net" "net/netip" @@ -11,7 +12,78 @@ import ( "github.com/metacubex/mihomo/log" ) +type packetSender struct { + ctx context.Context + cancel context.CancelFunc + ch chan C.PacketAdapter +} + +// newPacketSender return a chan based C.PacketSender +// It ensures that packets can be sent sequentially and without blocking +func newPacketSender() C.PacketSender { + ctx, cancel := context.WithCancel(context.Background()) + ch := make(chan C.PacketAdapter, senderCapacity) + return &packetSender{ + ctx: ctx, + cancel: cancel, + ch: ch, + } +} + +func (s *packetSender) Process(pc C.PacketConn, proxy C.WriteBackProxy) { + for { + select { + case <-s.ctx.Done(): + return // sender closed + case packet := <-s.ch: + if proxy != nil { + proxy.UpdateWriteBack(packet) + } + _ = handleUDPToRemote(packet, pc, packet.Metadata()) + packet.Drop() + } + } +} + +func (s *packetSender) dropAll() { + for { + select { + case data := <-s.ch: + data.Drop() // drop all data still in chan + default: + return // no data, exit goroutine + } + } +} + +func (s *packetSender) Send(packet C.PacketAdapter) { + select { + case <-s.ctx.Done(): + packet.Drop() // sender closed before Send() + return + default: + } + + select { + case s.ch <- packet: + // put ok, so don't drop packet, will process by other side of chan + case <-s.ctx.Done(): + packet.Drop() // sender closed when putting data to chan + default: + packet.Drop() // chan is full + } +} + +func (s *packetSender) Close() { + s.cancel() + s.dropAll() +} + func handleUDPToRemote(packet C.UDPPacket, pc C.PacketConn, metadata *C.Metadata) error { + if err := resolveUDP(metadata); err != nil { + return err + } + addr := metadata.UDPAddr() if addr == nil { return errors.New("udp addr invalid") @@ -26,8 +98,9 @@ func handleUDPToRemote(packet C.UDPPacket, pc C.PacketConn, metadata *C.Metadata return nil } -func handleUDPToLocal(writeBack C.WriteBack, pc N.EnhancePacketConn, key string, oAddrPort netip.AddrPort, fAddr netip.Addr) { +func handleUDPToLocal(writeBack C.WriteBack, pc N.EnhancePacketConn, sender C.PacketSender, key string, oAddrPort netip.AddrPort, fAddr netip.Addr) { defer func() { + sender.Close() _ = pc.Close() closeAllLocalCoon(key) natTable.Delete(key) diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index b6c61d76..af16e4ae 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -28,11 +28,14 @@ import ( "github.com/metacubex/mihomo/tunnel/statistic" ) -const queueSize = 200 +const ( + queueCapacity = 64 // chan capacity tcpQueue and udpQueue + senderCapacity = 128 // chan capacity of PacketSender +) var ( status = newAtomicStatus(Suspend) - tcpQueue = make(chan C.ConnContext, queueSize) + udpInit sync.Once udpQueues []chan C.PacketAdapter natTable = nat.New() rules []C.Rule @@ -43,6 +46,12 @@ var ( ruleProviders map[string]provider.RuleProvider configMux sync.RWMutex + // for compatibility, lazy init + tcpQueue chan C.ConnContext + tcpInOnce sync.Once + udpQueue chan C.PacketAdapter + udpInOnce sync.Once + // Outbound Rule mode = Rule @@ -70,15 +79,33 @@ func (t tunnel) HandleTCPConn(conn net.Conn, metadata *C.Metadata) { handleTCPConn(connCtx) } -func (t tunnel) HandleUDPPacket(packet C.UDPPacket, metadata *C.Metadata) { - packetAdapter := C.NewPacketAdapter(packet, metadata) +func initUDP() { + numUDPWorkers := 4 + if num := runtime.GOMAXPROCS(0); num > numUDPWorkers { + numUDPWorkers = num + } - hash := utils.MapHash(metadata.SourceAddress() + "-" + metadata.RemoteAddress()) + udpQueues = make([]chan C.PacketAdapter, numUDPWorkers) + for i := 0; i < numUDPWorkers; i++ { + queue := make(chan C.PacketAdapter, queueCapacity) + udpQueues[i] = queue + go processUDP(queue) + } +} + +func (t tunnel) HandleUDPPacket(packet C.UDPPacket, metadata *C.Metadata) { + udpInit.Do(initUDP) + + packetAdapter := C.NewPacketAdapter(packet, metadata) + key := packetAdapter.Key() + + hash := utils.MapHash(key) queueNo := uint(hash) % uint(len(udpQueues)) select { case udpQueues[queueNo] <- packetAdapter: default: + packet.Drop() } } @@ -134,21 +161,32 @@ func IsSniffing() bool { return sniffingEnable } -func init() { - go process() -} - // TCPIn return fan-in queue // Deprecated: using Tunnel instead func TCPIn() chan<- C.ConnContext { + tcpInOnce.Do(func() { + tcpQueue = make(chan C.ConnContext, queueCapacity) + go func() { + for connCtx := range tcpQueue { + go handleTCPConn(connCtx) + } + }() + }) return tcpQueue } // UDPIn return fan-in udp queue // Deprecated: using Tunnel instead func UDPIn() chan<- C.PacketAdapter { - // compatibility: first queue is always available for external callers - return udpQueues[0] + udpInOnce.Do(func() { + udpQueue = make(chan C.PacketAdapter, queueCapacity) + go func() { + for packet := range udpQueue { + Tunnel.HandleUDPPacket(packet, packet.Metadata()) + } + }() + }) + return udpQueue } // NatTable return nat table @@ -249,32 +287,6 @@ func isHandle(t C.Type) bool { return status == Running || (status == Inner && t == C.INNER) } -// processUDP starts a loop to handle udp packet -func processUDP(queue chan C.PacketAdapter) { - for conn := range queue { - handleUDPConn(conn) - } -} - -func process() { - numUDPWorkers := 4 - if num := runtime.GOMAXPROCS(0); num > numUDPWorkers { - numUDPWorkers = num - } - - udpQueues = make([]chan C.PacketAdapter, numUDPWorkers) - for i := 0; i < numUDPWorkers; i++ { - queue := make(chan C.PacketAdapter, queueSize) - udpQueues[i] = queue - go processUDP(queue) - } - - queue := tcpQueue - for conn := range queue { - go handleTCPConn(conn) - } -} - func needLookupIP(metadata *C.Metadata) bool { return resolver.MappingEnabled() && metadata.Host == "" && metadata.DstIP.IsValid() } @@ -334,6 +346,25 @@ func resolveMetadata(metadata *C.Metadata) (proxy C.Proxy, rule C.Rule, err erro return } +func resolveUDP(metadata *C.Metadata) error { + // local resolve UDP dns + if !metadata.Resolved() { + ip, err := resolver.ResolveIP(context.Background(), metadata.Host) + if err != nil { + return err + } + metadata.DstIP = ip + } + return nil +} + +// processUDP starts a loop to handle udp packet +func processUDP(queue chan C.PacketAdapter) { + for conn := range queue { + handleUDPConn(conn) + } +} + func handleUDPConn(packet C.PacketAdapter) { if !isHandle(packet.Metadata().Type) { packet.Drop() @@ -363,85 +394,58 @@ func handleUDPConn(packet C.PacketAdapter) { snifferDispatcher.UDPSniff(packet) } - // local resolve UDP dns - if !metadata.Resolved() { - ip, err := resolver.ResolveIP(context.Background(), metadata.Host) - if err != nil { - return - } - metadata.DstIP = ip - } - - key := packet.LocalAddr().String() - - handle := func() bool { - pc, proxy := natTable.Get(key) - if pc != nil { - if proxy != nil { - proxy.UpdateWriteBack(packet) + key := packet.Key() + sender, loaded := natTable.GetOrCreate(key, newPacketSender) + if !loaded { + dial := func() (C.PacketConn, C.WriteBackProxy, error) { + if err := resolveUDP(metadata); err != nil { + log.Warnln("[UDP] Resolve Ip error: %s", err) + return nil, nil, err } - _ = handleUDPToRemote(packet, pc, metadata) - return true - } - return false - } - if handle() { - packet.Drop() - return - } + proxy, rule, err := resolveMetadata(metadata) + if err != nil { + log.Warnln("[UDP] Parse metadata failed: %s", err.Error()) + return nil, nil, err + } - cond, loaded := natTable.GetOrCreateLock(key) + ctx, cancel := context.WithTimeout(context.Background(), C.DefaultUDPTimeout) + defer cancel() + rawPc, err := retry(ctx, func(ctx context.Context) (C.PacketConn, error) { + return proxy.ListenPacketContext(ctx, metadata.Pure()) + }, func(err error) { + logMetadataErr(metadata, rule, proxy, err) + }) + if err != nil { + return nil, nil, err + } + logMetadata(metadata, rule, rawPc) - go func() { - defer packet.Drop() + pc := statistic.NewUDPTracker(rawPc, statistic.DefaultManager, metadata, rule, 0, 0, true) - if loaded { - cond.L.Lock() - cond.Wait() - handle() - cond.L.Unlock() - return + if rawPc.Chains().Last() == "REJECT-DROP" { + _ = pc.Close() + return nil, nil, errors.New("rejected drop packet") + } + + oAddrPort := metadata.AddrPort() + writeBackProxy := nat.NewWriteBackProxy(packet) + + go handleUDPToLocal(writeBackProxy, pc, sender, key, oAddrPort, fAddr) + return pc, writeBackProxy, nil } - defer func() { - natTable.DeleteLock(key) - cond.Broadcast() + go func() { + pc, proxy, err := dial() + if err != nil { + sender.Close() + natTable.Delete(key) + return + } + sender.Process(pc, proxy) }() - - proxy, rule, err := resolveMetadata(metadata) - if err != nil { - log.Warnln("[UDP] Parse metadata failed: %s", err.Error()) - return - } - - ctx, cancel := context.WithTimeout(context.Background(), C.DefaultUDPTimeout) - defer cancel() - rawPc, err := retry(ctx, func(ctx context.Context) (C.PacketConn, error) { - return proxy.ListenPacketContext(ctx, metadata.Pure()) - }, func(err error) { - logMetadataErr(metadata, rule, proxy, err) - }) - if err != nil { - return - } - logMetadata(metadata, rule, rawPc) - - pc := statistic.NewUDPTracker(rawPc, statistic.DefaultManager, metadata, rule, 0, 0, true) - - if rawPc.Chains().Last() == "REJECT-DROP" { - pc.Close() - return - } - - oAddrPort := metadata.AddrPort() - writeBackProxy := nat.NewWriteBackProxy(packet) - natTable.Set(key, pc, writeBackProxy) - - go handleUDPToLocal(writeBackProxy, pc, key, oAddrPort, fAddr) - - handle() - }() + } + sender.Send(packet) // nonblocking } func handleTCPConn(connCtx C.ConnContext) {