chore: ensures packets can be sent without blocking the tunnel

This commit is contained in:
wwqgtxx 2024-09-26 11:21:07 +08:00
parent 5812a7bdeb
commit 4fa15c6334
4 changed files with 218 additions and 147 deletions

View File

@ -10,47 +10,30 @@ import (
) )
type Table struct { type Table struct {
mapping *xsync.MapOf[string, *Entry] mapping *xsync.MapOf[string, *entry]
lockMap *xsync.MapOf[string, *sync.Cond]
} }
type Entry struct { type entry struct {
PacketConn C.PacketConn PacketSender C.PacketSender
WriteBackProxy C.WriteBackProxy
LocalUDPConnMap *xsync.MapOf[string, *net.UDPConn] LocalUDPConnMap *xsync.MapOf[string, *net.UDPConn]
LocalLockMap *xsync.MapOf[string, *sync.Cond] LocalLockMap *xsync.MapOf[string, *sync.Cond]
} }
func (t *Table) Set(key string, e C.PacketConn, w C.WriteBackProxy) { func (t *Table) GetOrCreate(key string, maker func() C.PacketSender) (C.PacketSender, bool) {
t.mapping.Store(key, &Entry{ item, loaded := t.mapping.LoadOrCompute(key, func() *entry {
PacketConn: e, return &entry{
WriteBackProxy: w, PacketSender: maker(),
LocalUDPConnMap: xsync.NewMapOf[string, *net.UDPConn](), LocalUDPConnMap: xsync.NewMapOf[string, *net.UDPConn](),
LocalLockMap: xsync.NewMapOf[string, *sync.Cond](), LocalLockMap: xsync.NewMapOf[string, *sync.Cond](),
})
}
func (t *Table) Get(key string) (C.PacketConn, C.WriteBackProxy) {
entry, exist := t.getEntry(key)
if !exist {
return nil, nil
} }
return entry.PacketConn, entry.WriteBackProxy })
} return item.PacketSender, loaded
func (t *Table) GetOrCreateLock(key string) (*sync.Cond, bool) {
item, loaded := t.lockMap.LoadOrCompute(key, makeLock)
return item, loaded
} }
func (t *Table) Delete(key string) { func (t *Table) Delete(key string) {
t.mapping.Delete(key) t.mapping.Delete(key)
} }
func (t *Table) DeleteLock(lockKey string) {
t.lockMap.Delete(lockKey)
}
func (t *Table) GetForLocalConn(lAddr, rAddr string) *net.UDPConn { func (t *Table) GetForLocalConn(lAddr, rAddr string) *net.UDPConn {
entry, exist := t.getEntry(lAddr) entry, exist := t.getEntry(lAddr)
if !exist { if !exist {
@ -105,7 +88,7 @@ func (t *Table) DeleteLockForLocalConn(lAddr, key string) {
entry.LocalLockMap.Delete(key) entry.LocalLockMap.Delete(key)
} }
func (t *Table) getEntry(key string) (*Entry, bool) { func (t *Table) getEntry(key string) (*entry, bool) {
return t.mapping.Load(key) return t.mapping.Load(key)
} }
@ -116,7 +99,6 @@ func makeLock() *sync.Cond {
// New return *Cache // New return *Cache
func New() *Table { func New() *Table {
return &Table{ return &Table{
mapping: xsync.NewMapOf[string, *Entry](), mapping: xsync.NewMapOf[string, *entry](),
lockMap: xsync.NewMapOf[string, *sync.Cond](),
} }
} }

View File

@ -255,12 +255,16 @@ type UDPPacketInAddr interface {
// PacketAdapter is a UDP Packet adapter for socks/redir/tun // PacketAdapter is a UDP Packet adapter for socks/redir/tun
type PacketAdapter interface { type PacketAdapter interface {
UDPPacket UDPPacket
// Metadata returns destination metadata
Metadata() *Metadata Metadata() *Metadata
// Key is a SNAT key
Key() string
} }
type packetAdapter struct { type packetAdapter struct {
UDPPacket UDPPacket
metadata *Metadata metadata *Metadata
key string
} }
// Metadata returns destination metadata // Metadata returns destination metadata
@ -268,10 +272,16 @@ func (s *packetAdapter) Metadata() *Metadata {
return s.metadata return s.metadata
} }
// Key is a SNAT key
func (s *packetAdapter) Key() string {
return s.key
}
func NewPacketAdapter(packet UDPPacket, metadata *Metadata) PacketAdapter { func NewPacketAdapter(packet UDPPacket, metadata *Metadata) PacketAdapter {
return &packetAdapter{ return &packetAdapter{
packet, packet,
metadata, metadata,
packet.LocalAddr().String(),
} }
} }
@ -284,17 +294,19 @@ type WriteBackProxy interface {
UpdateWriteBack(wb WriteBack) UpdateWriteBack(wb WriteBack)
} }
type PacketSender interface {
// Send will send PacketAdapter nonblocking
// the implement must call UDPPacket.Drop() inside Send
Send(PacketAdapter)
Process(PacketConn, WriteBackProxy)
Close()
}
type NatTable interface { type NatTable interface {
Set(key string, e PacketConn, w WriteBackProxy) GetOrCreate(key string, maker func() PacketSender) (PacketSender, bool)
Get(key string) (PacketConn, WriteBackProxy)
GetOrCreateLock(key string) (*sync.Cond, bool)
Delete(key string) Delete(key string)
DeleteLock(key string)
GetForLocalConn(lAddr, rAddr string) *net.UDPConn GetForLocalConn(lAddr, rAddr string) *net.UDPConn
AddForLocalConn(lAddr, rAddr string, conn *net.UDPConn) bool AddForLocalConn(lAddr, rAddr string, conn *net.UDPConn) bool

View File

@ -1,6 +1,7 @@
package tunnel package tunnel
import ( import (
"context"
"errors" "errors"
"net" "net"
"net/netip" "net/netip"
@ -11,7 +12,78 @@ import (
"github.com/metacubex/mihomo/log" "github.com/metacubex/mihomo/log"
) )
type packetSender struct {
ctx context.Context
cancel context.CancelFunc
ch chan C.PacketAdapter
}
// newPacketSender return a chan based C.PacketSender
// It ensures that packets can be sent sequentially and without blocking
func newPacketSender() C.PacketSender {
ctx, cancel := context.WithCancel(context.Background())
ch := make(chan C.PacketAdapter, senderCapacity)
return &packetSender{
ctx: ctx,
cancel: cancel,
ch: ch,
}
}
func (s *packetSender) Process(pc C.PacketConn, proxy C.WriteBackProxy) {
for {
select {
case <-s.ctx.Done():
return // sender closed
case packet := <-s.ch:
if proxy != nil {
proxy.UpdateWriteBack(packet)
}
_ = handleUDPToRemote(packet, pc, packet.Metadata())
packet.Drop()
}
}
}
func (s *packetSender) dropAll() {
for {
select {
case data := <-s.ch:
data.Drop() // drop all data still in chan
default:
return // no data, exit goroutine
}
}
}
func (s *packetSender) Send(packet C.PacketAdapter) {
select {
case <-s.ctx.Done():
packet.Drop() // sender closed before Send()
return
default:
}
select {
case s.ch <- packet:
// put ok, so don't drop packet, will process by other side of chan
case <-s.ctx.Done():
packet.Drop() // sender closed when putting data to chan
default:
packet.Drop() // chan is full
}
}
func (s *packetSender) Close() {
s.cancel()
s.dropAll()
}
func handleUDPToRemote(packet C.UDPPacket, pc C.PacketConn, metadata *C.Metadata) error { func handleUDPToRemote(packet C.UDPPacket, pc C.PacketConn, metadata *C.Metadata) error {
if err := resolveUDP(metadata); err != nil {
return err
}
addr := metadata.UDPAddr() addr := metadata.UDPAddr()
if addr == nil { if addr == nil {
return errors.New("udp addr invalid") return errors.New("udp addr invalid")
@ -26,8 +98,9 @@ func handleUDPToRemote(packet C.UDPPacket, pc C.PacketConn, metadata *C.Metadata
return nil return nil
} }
func handleUDPToLocal(writeBack C.WriteBack, pc N.EnhancePacketConn, key string, oAddrPort netip.AddrPort, fAddr netip.Addr) { func handleUDPToLocal(writeBack C.WriteBack, pc N.EnhancePacketConn, sender C.PacketSender, key string, oAddrPort netip.AddrPort, fAddr netip.Addr) {
defer func() { defer func() {
sender.Close()
_ = pc.Close() _ = pc.Close()
closeAllLocalCoon(key) closeAllLocalCoon(key)
natTable.Delete(key) natTable.Delete(key)

View File

@ -28,11 +28,14 @@ import (
"github.com/metacubex/mihomo/tunnel/statistic" "github.com/metacubex/mihomo/tunnel/statistic"
) )
const queueSize = 200 const (
queueCapacity = 64 // chan capacity tcpQueue and udpQueue
senderCapacity = 128 // chan capacity of PacketSender
)
var ( var (
status = newAtomicStatus(Suspend) status = newAtomicStatus(Suspend)
tcpQueue = make(chan C.ConnContext, queueSize) udpInit sync.Once
udpQueues []chan C.PacketAdapter udpQueues []chan C.PacketAdapter
natTable = nat.New() natTable = nat.New()
rules []C.Rule rules []C.Rule
@ -43,6 +46,12 @@ var (
ruleProviders map[string]provider.RuleProvider ruleProviders map[string]provider.RuleProvider
configMux sync.RWMutex configMux sync.RWMutex
// for compatibility, lazy init
tcpQueue chan C.ConnContext
tcpInOnce sync.Once
udpQueue chan C.PacketAdapter
udpInOnce sync.Once
// Outbound Rule // Outbound Rule
mode = Rule mode = Rule
@ -70,15 +79,33 @@ func (t tunnel) HandleTCPConn(conn net.Conn, metadata *C.Metadata) {
handleTCPConn(connCtx) handleTCPConn(connCtx)
} }
func (t tunnel) HandleUDPPacket(packet C.UDPPacket, metadata *C.Metadata) { func initUDP() {
packetAdapter := C.NewPacketAdapter(packet, metadata) numUDPWorkers := 4
if num := runtime.GOMAXPROCS(0); num > numUDPWorkers {
numUDPWorkers = num
}
hash := utils.MapHash(metadata.SourceAddress() + "-" + metadata.RemoteAddress()) udpQueues = make([]chan C.PacketAdapter, numUDPWorkers)
for i := 0; i < numUDPWorkers; i++ {
queue := make(chan C.PacketAdapter, queueCapacity)
udpQueues[i] = queue
go processUDP(queue)
}
}
func (t tunnel) HandleUDPPacket(packet C.UDPPacket, metadata *C.Metadata) {
udpInit.Do(initUDP)
packetAdapter := C.NewPacketAdapter(packet, metadata)
key := packetAdapter.Key()
hash := utils.MapHash(key)
queueNo := uint(hash) % uint(len(udpQueues)) queueNo := uint(hash) % uint(len(udpQueues))
select { select {
case udpQueues[queueNo] <- packetAdapter: case udpQueues[queueNo] <- packetAdapter:
default: default:
packet.Drop()
} }
} }
@ -134,21 +161,32 @@ func IsSniffing() bool {
return sniffingEnable return sniffingEnable
} }
func init() {
go process()
}
// TCPIn return fan-in queue // TCPIn return fan-in queue
// Deprecated: using Tunnel instead // Deprecated: using Tunnel instead
func TCPIn() chan<- C.ConnContext { func TCPIn() chan<- C.ConnContext {
tcpInOnce.Do(func() {
tcpQueue = make(chan C.ConnContext, queueCapacity)
go func() {
for connCtx := range tcpQueue {
go handleTCPConn(connCtx)
}
}()
})
return tcpQueue return tcpQueue
} }
// UDPIn return fan-in udp queue // UDPIn return fan-in udp queue
// Deprecated: using Tunnel instead // Deprecated: using Tunnel instead
func UDPIn() chan<- C.PacketAdapter { func UDPIn() chan<- C.PacketAdapter {
// compatibility: first queue is always available for external callers udpInOnce.Do(func() {
return udpQueues[0] udpQueue = make(chan C.PacketAdapter, queueCapacity)
go func() {
for packet := range udpQueue {
Tunnel.HandleUDPPacket(packet, packet.Metadata())
}
}()
})
return udpQueue
} }
// NatTable return nat table // NatTable return nat table
@ -249,32 +287,6 @@ func isHandle(t C.Type) bool {
return status == Running || (status == Inner && t == C.INNER) return status == Running || (status == Inner && t == C.INNER)
} }
// processUDP starts a loop to handle udp packet
func processUDP(queue chan C.PacketAdapter) {
for conn := range queue {
handleUDPConn(conn)
}
}
func process() {
numUDPWorkers := 4
if num := runtime.GOMAXPROCS(0); num > numUDPWorkers {
numUDPWorkers = num
}
udpQueues = make([]chan C.PacketAdapter, numUDPWorkers)
for i := 0; i < numUDPWorkers; i++ {
queue := make(chan C.PacketAdapter, queueSize)
udpQueues[i] = queue
go processUDP(queue)
}
queue := tcpQueue
for conn := range queue {
go handleTCPConn(conn)
}
}
func needLookupIP(metadata *C.Metadata) bool { func needLookupIP(metadata *C.Metadata) bool {
return resolver.MappingEnabled() && metadata.Host == "" && metadata.DstIP.IsValid() return resolver.MappingEnabled() && metadata.Host == "" && metadata.DstIP.IsValid()
} }
@ -334,6 +346,25 @@ func resolveMetadata(metadata *C.Metadata) (proxy C.Proxy, rule C.Rule, err erro
return return
} }
func resolveUDP(metadata *C.Metadata) error {
// local resolve UDP dns
if !metadata.Resolved() {
ip, err := resolver.ResolveIP(context.Background(), metadata.Host)
if err != nil {
return err
}
metadata.DstIP = ip
}
return nil
}
// processUDP starts a loop to handle udp packet
func processUDP(queue chan C.PacketAdapter) {
for conn := range queue {
handleUDPConn(conn)
}
}
func handleUDPConn(packet C.PacketAdapter) { func handleUDPConn(packet C.PacketAdapter) {
if !isHandle(packet.Metadata().Type) { if !isHandle(packet.Metadata().Type) {
packet.Drop() packet.Drop()
@ -363,56 +394,19 @@ func handleUDPConn(packet C.PacketAdapter) {
snifferDispatcher.UDPSniff(packet) snifferDispatcher.UDPSniff(packet)
} }
// local resolve UDP dns key := packet.Key()
if !metadata.Resolved() { sender, loaded := natTable.GetOrCreate(key, newPacketSender)
ip, err := resolver.ResolveIP(context.Background(), metadata.Host) if !loaded {
if err != nil { dial := func() (C.PacketConn, C.WriteBackProxy, error) {
return if err := resolveUDP(metadata); err != nil {
log.Warnln("[UDP] Resolve Ip error: %s", err)
return nil, nil, err
} }
metadata.DstIP = ip
}
key := packet.LocalAddr().String()
handle := func() bool {
pc, proxy := natTable.Get(key)
if pc != nil {
if proxy != nil {
proxy.UpdateWriteBack(packet)
}
_ = handleUDPToRemote(packet, pc, metadata)
return true
}
return false
}
if handle() {
packet.Drop()
return
}
cond, loaded := natTable.GetOrCreateLock(key)
go func() {
defer packet.Drop()
if loaded {
cond.L.Lock()
cond.Wait()
handle()
cond.L.Unlock()
return
}
defer func() {
natTable.DeleteLock(key)
cond.Broadcast()
}()
proxy, rule, err := resolveMetadata(metadata) proxy, rule, err := resolveMetadata(metadata)
if err != nil { if err != nil {
log.Warnln("[UDP] Parse metadata failed: %s", err.Error()) log.Warnln("[UDP] Parse metadata failed: %s", err.Error())
return return nil, nil, err
} }
ctx, cancel := context.WithTimeout(context.Background(), C.DefaultUDPTimeout) ctx, cancel := context.WithTimeout(context.Background(), C.DefaultUDPTimeout)
@ -423,25 +417,35 @@ func handleUDPConn(packet C.PacketAdapter) {
logMetadataErr(metadata, rule, proxy, err) logMetadataErr(metadata, rule, proxy, err)
}) })
if err != nil { if err != nil {
return return nil, nil, err
} }
logMetadata(metadata, rule, rawPc) logMetadata(metadata, rule, rawPc)
pc := statistic.NewUDPTracker(rawPc, statistic.DefaultManager, metadata, rule, 0, 0, true) pc := statistic.NewUDPTracker(rawPc, statistic.DefaultManager, metadata, rule, 0, 0, true)
if rawPc.Chains().Last() == "REJECT-DROP" { if rawPc.Chains().Last() == "REJECT-DROP" {
pc.Close() _ = pc.Close()
return return nil, nil, errors.New("rejected drop packet")
} }
oAddrPort := metadata.AddrPort() oAddrPort := metadata.AddrPort()
writeBackProxy := nat.NewWriteBackProxy(packet) writeBackProxy := nat.NewWriteBackProxy(packet)
natTable.Set(key, pc, writeBackProxy)
go handleUDPToLocal(writeBackProxy, pc, key, oAddrPort, fAddr) go handleUDPToLocal(writeBackProxy, pc, sender, key, oAddrPort, fAddr)
return pc, writeBackProxy, nil
}
handle() go func() {
pc, proxy, err := dial()
if err != nil {
sender.Close()
natTable.Delete(key)
return
}
sender.Process(pc, proxy)
}() }()
}
sender.Send(packet) // nonblocking
} }
func handleTCPConn(connCtx C.ConnContext) { func handleTCPConn(connCtx C.ConnContext) {