mihomo/tunnel/tunnel.go

740 lines
18 KiB
Go

package tunnel
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"path/filepath"
"runtime"
"strings"
"sync"
"time"
N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/common/utils"
"github.com/metacubex/mihomo/component/loopback"
"github.com/metacubex/mihomo/component/nat"
P "github.com/metacubex/mihomo/component/process"
"github.com/metacubex/mihomo/component/resolver"
"github.com/metacubex/mihomo/component/slowdown"
"github.com/metacubex/mihomo/component/sniffer"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/constant/features"
"github.com/metacubex/mihomo/constant/provider"
icontext "github.com/metacubex/mihomo/context"
"github.com/metacubex/mihomo/log"
"github.com/metacubex/mihomo/tunnel/statistic"
)
const (
queueCapacity = 64 // chan capacity tcpQueue and udpQueue
senderCapacity = 128 // chan capacity of PacketSender
)
var (
status = newAtomicStatus(Suspend)
udpInit sync.Once
udpQueues []chan C.PacketAdapter
natTable = nat.New()
rules []C.Rule
listeners = make(map[string]C.InboundListener)
subRules map[string][]C.Rule
proxies = make(map[string]C.Proxy)
providers map[string]provider.ProxyProvider
ruleProviders map[string]provider.RuleProvider
configMux sync.RWMutex
// for compatibility, lazy init
tcpQueue chan C.ConnContext
tcpInOnce sync.Once
udpQueue chan C.PacketAdapter
udpInOnce sync.Once
// Outbound Rule
mode = Rule
// default timeout for UDP session
udpTimeout = 60 * time.Second
findProcessMode P.FindProcessMode
fakeIPRange netip.Prefix
snifferDispatcher *sniffer.Dispatcher
sniffingEnable = false
ruleUpdateCallback = utils.NewCallback[provider.RuleProvider]()
)
type tunnel struct{}
var Tunnel = tunnel{}
var _ C.Tunnel = Tunnel
var _ provider.Tunnel = Tunnel
func (t tunnel) HandleTCPConn(conn net.Conn, metadata *C.Metadata) {
connCtx := icontext.NewConnContext(conn, metadata)
handleTCPConn(connCtx)
}
func initUDP() {
numUDPWorkers := 4
if num := runtime.GOMAXPROCS(0); num > numUDPWorkers {
numUDPWorkers = num
}
udpQueues = make([]chan C.PacketAdapter, numUDPWorkers)
for i := 0; i < numUDPWorkers; i++ {
queue := make(chan C.PacketAdapter, queueCapacity)
udpQueues[i] = queue
go processUDP(queue)
}
}
func (t tunnel) HandleUDPPacket(packet C.UDPPacket, metadata *C.Metadata) {
udpInit.Do(initUDP)
packetAdapter := C.NewPacketAdapter(packet, metadata)
key := packetAdapter.Key()
hash := utils.MapHash(key)
queueNo := uint(hash) % uint(len(udpQueues))
select {
case udpQueues[queueNo] <- packetAdapter:
default:
packet.Drop()
}
}
func (t tunnel) NatTable() C.NatTable {
return natTable
}
func (t tunnel) Providers() map[string]provider.ProxyProvider {
return providers
}
func (t tunnel) RuleProviders() map[string]provider.RuleProvider {
return ruleProviders
}
func (t tunnel) RuleUpdateCallback() *utils.Callback[provider.RuleProvider] {
return ruleUpdateCallback
}
func OnSuspend() {
status.Store(Suspend)
}
func OnInnerLoading() {
status.Store(Inner)
}
func OnRunning() {
status.Store(Running)
}
func Status() TunnelStatus {
return status.Load()
}
func SetFakeIPRange(p netip.Prefix) {
fakeIPRange = p
}
func FakeIPRange() netip.Prefix {
return fakeIPRange
}
func SetSniffing(b bool) {
if snifferDispatcher.Enable() {
configMux.Lock()
sniffingEnable = b
configMux.Unlock()
}
}
func IsSniffing() bool {
return sniffingEnable
}
// TCPIn return fan-in queue
// Deprecated: using Tunnel instead
func TCPIn() chan<- C.ConnContext {
tcpInOnce.Do(func() {
tcpQueue = make(chan C.ConnContext, queueCapacity)
go func() {
for connCtx := range tcpQueue {
go handleTCPConn(connCtx)
}
}()
})
return tcpQueue
}
// UDPIn return fan-in udp queue
// Deprecated: using Tunnel instead
func UDPIn() chan<- C.PacketAdapter {
udpInOnce.Do(func() {
udpQueue = make(chan C.PacketAdapter, queueCapacity)
go func() {
for packet := range udpQueue {
Tunnel.HandleUDPPacket(packet, packet.Metadata())
}
}()
})
return udpQueue
}
// NatTable return nat table
func NatTable() C.NatTable {
return natTable
}
// Rules return all rules
func Rules() []C.Rule {
return rules
}
func Listeners() map[string]C.InboundListener {
return listeners
}
// UpdateRules handle update rules
func UpdateRules(newRules []C.Rule, newSubRule map[string][]C.Rule, rp map[string]provider.RuleProvider) {
configMux.Lock()
rules = newRules
ruleProviders = rp
subRules = newSubRule
configMux.Unlock()
}
// Proxies return all proxies
func Proxies() map[string]C.Proxy {
return proxies
}
func ProxiesWithProviders() map[string]C.Proxy {
allProxies := make(map[string]C.Proxy)
for name, proxy := range proxies {
allProxies[name] = proxy
}
for _, p := range providers {
for _, proxy := range p.Proxies() {
name := proxy.Name()
allProxies[name] = proxy
}
}
return allProxies
}
// Providers return all compatible providers
func Providers() map[string]provider.ProxyProvider {
return providers
}
// RuleProviders return all loaded rule providers
func RuleProviders() map[string]provider.RuleProvider {
return ruleProviders
}
// UpdateProxies handle update proxies
func UpdateProxies(newProxies map[string]C.Proxy, newProviders map[string]provider.ProxyProvider) {
configMux.Lock()
proxies = newProxies
providers = newProviders
configMux.Unlock()
}
func UpdateListeners(newListeners map[string]C.InboundListener) {
configMux.Lock()
defer configMux.Unlock()
listeners = newListeners
}
func UpdateSniffer(dispatcher *sniffer.Dispatcher) {
configMux.Lock()
snifferDispatcher = dispatcher
sniffingEnable = dispatcher.Enable()
configMux.Unlock()
}
// Mode return current mode
func Mode() TunnelMode {
return mode
}
// SetMode change the mode of tunnel
func SetMode(m TunnelMode) {
mode = m
}
func FindProcessMode() P.FindProcessMode {
return findProcessMode
}
// SetFindProcessMode replace SetAlwaysFindProcess
// always find process info if legacyAlways = true or mode.Always() = true, may be increase many memory
func SetFindProcessMode(mode P.FindProcessMode) {
findProcessMode = mode
}
func isHandle(t C.Type) bool {
status := status.Load()
return status == Running || (status == Inner && t == C.INNER)
}
func needLookupIP(metadata *C.Metadata) bool {
return resolver.MappingEnabled() && metadata.Host == "" && metadata.DstIP.IsValid()
}
func preHandleMetadata(metadata *C.Metadata) error {
// handle IP string on host
if ip, err := netip.ParseAddr(metadata.Host); err == nil {
metadata.DstIP = ip
metadata.Host = ""
}
// preprocess enhanced-mode metadata
if needLookupIP(metadata) {
host, exist := resolver.FindHostByIP(metadata.DstIP)
if exist {
metadata.Host = host
metadata.DNSMode = C.DNSMapping
if resolver.FakeIPEnabled() {
metadata.DstIP = netip.Addr{}
metadata.DNSMode = C.DNSFakeIP
} else if node, ok := resolver.DefaultHosts.Search(host, false); ok {
// redir-host should lookup the hosts
metadata.DstIP, _ = node.RandIP()
} else if node != nil && node.IsDomain {
metadata.Host = node.Domain
}
} else if resolver.IsFakeIP(metadata.DstIP) {
return fmt.Errorf("fake DNS record %s missing", metadata.DstIP)
}
} else if node, ok := resolver.DefaultHosts.Search(metadata.Host, true); ok {
// try use domain mapping
metadata.Host = node.Domain
}
return nil
}
func resolveMetadata(metadata *C.Metadata) (proxy C.Proxy, rule C.Rule, err error) {
if metadata.SpecialProxy != "" {
var exist bool
proxy, exist = proxies[metadata.SpecialProxy]
if !exist {
err = fmt.Errorf("proxy %s not found", metadata.SpecialProxy)
}
return
}
switch mode {
case Direct:
proxy = proxies["DIRECT"]
case Global:
proxy = proxies["GLOBAL"]
// Rule
default:
proxy, rule, err = match(metadata)
}
return
}
func resolveUDP(metadata *C.Metadata) error {
// local resolve UDP dns
if !metadata.Resolved() {
ip, err := resolver.ResolveIP(context.Background(), metadata.Host)
if err != nil {
return err
}
metadata.DstIP = ip
}
return nil
}
// processUDP starts a loop to handle udp packet
func processUDP(queue chan C.PacketAdapter) {
for conn := range queue {
handleUDPConn(conn)
}
}
func handleUDPConn(packet C.PacketAdapter) {
if !isHandle(packet.Metadata().Type) {
packet.Drop()
return
}
metadata := packet.Metadata()
if !metadata.Valid() {
packet.Drop()
log.Warnln("[Metadata] not valid: %#v", metadata)
return
}
// make a fAddr if request ip is fakeip
var fAddr netip.Addr
if resolver.IsExistFakeIP(metadata.DstIP) {
fAddr = metadata.DstIP
}
if err := preHandleMetadata(metadata); err != nil {
packet.Drop()
log.Debugln("[Metadata PreHandle] error: %s", err)
return
}
if sniffingEnable && snifferDispatcher.Enable() {
snifferDispatcher.UDPSniff(packet)
}
key := packet.Key()
sender, loaded := natTable.GetOrCreate(key, newPacketSender)
if !loaded {
dial := func() (C.PacketConn, C.WriteBackProxy, error) {
if err := resolveUDP(metadata); err != nil {
log.Warnln("[UDP] Resolve Ip error: %s", err)
return nil, nil, err
}
proxy, rule, err := resolveMetadata(metadata)
if err != nil {
log.Warnln("[UDP] Parse metadata failed: %s", err.Error())
return nil, nil, err
}
ctx, cancel := context.WithTimeout(context.Background(), C.DefaultUDPTimeout)
defer cancel()
rawPc, err := retry(ctx, func(ctx context.Context) (C.PacketConn, error) {
return proxy.ListenPacketContext(ctx, metadata.Pure())
}, func(err error) {
logMetadataErr(metadata, rule, proxy, err)
})
if err != nil {
return nil, nil, err
}
logMetadata(metadata, rule, rawPc)
pc := statistic.NewUDPTracker(rawPc, statistic.DefaultManager, metadata, rule, 0, 0, true)
if rawPc.Chains().Last() == "REJECT-DROP" {
_ = pc.Close()
return nil, nil, errors.New("rejected drop packet")
}
oAddrPort := metadata.AddrPort()
writeBackProxy := nat.NewWriteBackProxy(packet)
go handleUDPToLocal(writeBackProxy, pc, sender, key, oAddrPort, fAddr)
return pc, writeBackProxy, nil
}
go func() {
pc, proxy, err := dial()
if err != nil {
sender.Close()
natTable.Delete(key)
return
}
sender.Process(pc, proxy)
}()
}
sender.Send(packet) // nonblocking
}
func handleTCPConn(connCtx C.ConnContext) {
if !isHandle(connCtx.Metadata().Type) {
_ = connCtx.Conn().Close()
return
}
defer func(conn net.Conn) {
_ = conn.Close()
}(connCtx.Conn())
metadata := connCtx.Metadata()
if !metadata.Valid() {
log.Warnln("[Metadata] not valid: %#v", metadata)
return
}
preHandleFailed := false
if err := preHandleMetadata(metadata); err != nil {
log.Debugln("[Metadata PreHandle] error: %s", err)
preHandleFailed = true
}
conn := connCtx.Conn()
conn.ResetPeeked() // reset before sniffer
if sniffingEnable && snifferDispatcher.Enable() {
// Try to sniff a domain when `preHandleMetadata` failed, this is usually
// caused by a "Fake DNS record missing" error when enhanced-mode is fake-ip.
if snifferDispatcher.TCPSniff(conn, metadata) {
// we now have a domain name
preHandleFailed = false
}
}
// If both trials have failed, we can do nothing but give up
if preHandleFailed {
log.Debugln("[Metadata PreHandle] failed to sniff a domain for connection %s --> %s, give up",
metadata.SourceDetail(), metadata.RemoteAddress())
return
}
peekMutex := sync.Mutex{}
if !conn.Peeked() {
peekMutex.Lock()
go func() {
defer peekMutex.Unlock()
_ = conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond))
_, _ = conn.Peek(1)
_ = conn.SetReadDeadline(time.Time{})
}()
}
proxy, rule, err := resolveMetadata(metadata)
if err != nil {
log.Warnln("[Metadata] parse failed: %s", err.Error())
return
}
dialMetadata := metadata
if len(metadata.Host) > 0 {
if node, ok := resolver.DefaultHosts.Search(metadata.Host, false); ok {
if dstIp, _ := node.RandIP(); !FakeIPRange().Contains(dstIp) {
dialMetadata.DstIP = dstIp
dialMetadata.DNSMode = C.DNSHosts
dialMetadata = dialMetadata.Pure()
}
}
}
var peekBytes []byte
var peekLen int
ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTCPTimeout)
defer cancel()
remoteConn, err := retry(ctx, func(ctx context.Context) (remoteConn C.Conn, err error) {
remoteConn, err = proxy.DialContext(ctx, dialMetadata)
if err != nil {
return
}
if N.NeedHandshake(remoteConn) {
defer func() {
for _, chain := range remoteConn.Chains() {
if chain == "REJECT" {
err = nil
return
}
}
if err != nil {
remoteConn = nil
}
}()
peekMutex.Lock()
defer peekMutex.Unlock()
peekBytes, _ = conn.Peek(conn.Buffered())
_, err = remoteConn.Write(peekBytes)
if err != nil {
return
}
if peekLen = len(peekBytes); peekLen > 0 {
_, _ = conn.Discard(peekLen)
}
}
return
}, func(err error) {
logMetadataErr(metadata, rule, proxy, err)
})
if err != nil {
return
}
logMetadata(metadata, rule, remoteConn)
remoteConn = statistic.NewTCPTracker(remoteConn, statistic.DefaultManager, metadata, rule, 0, int64(peekLen), true)
defer func(remoteConn C.Conn) {
_ = remoteConn.Close()
}(remoteConn)
_ = conn.SetReadDeadline(time.Now()) // stop unfinished peek
peekMutex.Lock()
defer peekMutex.Unlock()
_ = conn.SetReadDeadline(time.Time{}) // reset
handleSocket(conn, remoteConn)
}
func logMetadataErr(metadata *C.Metadata, rule C.Rule, proxy C.ProxyAdapter, err error) {
if rule == nil {
log.Warnln("[%s] dial %s %s --> %s error: %s", strings.ToUpper(metadata.NetWork.String()), proxy.Name(), metadata.SourceDetail(), metadata.RemoteAddress(), err.Error())
} else {
log.Warnln("[%s] dial %s (match %s/%s) %s --> %s error: %s", strings.ToUpper(metadata.NetWork.String()), proxy.Name(), rule.RuleType().String(), rule.Payload(), metadata.SourceDetail(), metadata.RemoteAddress(), err.Error())
}
}
func logMetadata(metadata *C.Metadata, rule C.Rule, remoteConn C.Connection) {
switch {
case metadata.SpecialProxy != "":
log.Infoln("[%s] %s --> %s using %s", strings.ToUpper(metadata.NetWork.String()), metadata.SourceDetail(), metadata.RemoteAddress(), metadata.SpecialProxy)
case rule != nil:
if rule.Payload() != "" {
log.Infoln("[%s] %s --> %s match %s using %s", strings.ToUpper(metadata.NetWork.String()), metadata.SourceDetail(), metadata.RemoteAddress(), fmt.Sprintf("%s(%s)", rule.RuleType().String(), rule.Payload()), remoteConn.Chains().String())
} else {
log.Infoln("[%s] %s --> %s match %s using %s", strings.ToUpper(metadata.NetWork.String()), metadata.SourceDetail(), metadata.RemoteAddress(), rule.RuleType().String(), remoteConn.Chains().String())
}
case mode == Global:
log.Infoln("[%s] %s --> %s using GLOBAL", strings.ToUpper(metadata.NetWork.String()), metadata.SourceDetail(), metadata.RemoteAddress())
case mode == Direct:
log.Infoln("[%s] %s --> %s using DIRECT", strings.ToUpper(metadata.NetWork.String()), metadata.SourceDetail(), metadata.RemoteAddress())
default:
log.Infoln("[%s] %s --> %s doesn't match any rule using %s", strings.ToUpper(metadata.NetWork.String()), metadata.SourceDetail(), metadata.RemoteAddress(), remoteConn.Chains().Last())
}
}
func shouldResolveIP(rule C.Rule, metadata *C.Metadata) bool {
return rule.ShouldResolveIP() && metadata.Host != "" && !metadata.DstIP.IsValid()
}
func match(metadata *C.Metadata) (C.Proxy, C.Rule, error) {
configMux.RLock()
defer configMux.RUnlock()
var (
resolved bool
attemptProcessLookup = metadata.Type != C.INNER
)
if node, ok := resolver.DefaultHosts.Search(metadata.Host, false); ok {
metadata.DstIP, _ = node.RandIP()
resolved = true
}
for _, rule := range getRules(metadata) {
if !resolved && shouldResolveIP(rule, metadata) {
func() {
ctx, cancel := context.WithTimeout(context.Background(), resolver.DefaultDNSTimeout)
defer cancel()
ip, err := resolver.ResolveIP(ctx, metadata.Host)
if err != nil {
log.Debugln("[DNS] resolve %s error: %s", metadata.Host, err.Error())
} else {
log.Debugln("[DNS] %s --> %s", metadata.Host, ip.String())
metadata.DstIP = ip
}
resolved = true
}()
}
if attemptProcessLookup && !findProcessMode.Off() && (findProcessMode.Always() || rule.ShouldFindProcess()) {
attemptProcessLookup = false
if !features.CMFA {
// normal check for process
uid, path, err := P.FindProcessName(metadata.NetWork.String(), metadata.SrcIP, int(metadata.SrcPort))
if err != nil {
log.Debugln("[Process] find process %s error: %v", metadata.String(), err)
} else {
metadata.Process = filepath.Base(path)
metadata.ProcessPath = path
metadata.Uid = uid
if pkg, err := P.FindPackageName(metadata); err == nil { // for android (not CMFA) package names
metadata.Process = pkg
}
}
} else {
// check package names
pkg, err := P.FindPackageName(metadata)
if err != nil {
log.Debugln("[Process] find process %s error: %v", metadata.String(), err)
} else {
metadata.Process = pkg
}
}
}
if matched, ada := rule.Match(metadata); matched {
adapter, ok := proxies[ada]
if !ok {
continue
}
// parse multi-layer nesting
passed := false
for adapter := adapter; adapter != nil; adapter = adapter.Unwrap(metadata, false) {
if adapter.Type() == C.Pass {
passed = true
break
}
}
if passed {
log.Debugln("%s match Pass rule", adapter.Name())
continue
}
if metadata.NetWork == C.UDP && !adapter.SupportUDP() {
log.Debugln("%s UDP is not supported", adapter.Name())
continue
}
return adapter, rule, nil
}
}
return proxies["DIRECT"], nil, nil
}
func getRules(metadata *C.Metadata) []C.Rule {
if sr, ok := subRules[metadata.SpecialRules]; ok {
log.Debugln("[Rule] use %s rules", metadata.SpecialRules)
return sr
} else {
log.Debugln("[Rule] use default rules")
return rules
}
}
func shouldStopRetry(err error) bool {
if errors.Is(err, resolver.ErrIPNotFound) {
return true
}
if errors.Is(err, resolver.ErrIPVersion) {
return true
}
if errors.Is(err, resolver.ErrIPv6Disabled) {
return true
}
if errors.Is(err, loopback.ErrReject) {
return true
}
return false
}
func retry[T any](ctx context.Context, ft func(context.Context) (T, error), fe func(err error)) (t T, err error) {
s := slowdown.New()
for i := 0; i < 10; i++ {
t, err = ft(ctx)
if err != nil {
if fe != nil {
fe(err)
}
if shouldStopRetry(err) {
return
}
if s.Wait(ctx) == nil {
continue
} else {
return
}
} else {
break
}
}
return
}