diff --git a/adapter/outbound/wireguard.go b/adapter/outbound/wireguard.go index f96f2d6f..67cd9092 100644 --- a/adapter/outbound/wireguard.go +++ b/adapter/outbound/wireguard.go @@ -40,6 +40,7 @@ type WireGuard struct { startOnce sync.Once startErr error resolver *dns.Resolver + refP *refProxyAdapter } type WireGuardOption struct { @@ -100,6 +101,20 @@ func (d *wgSingDialer) ListenPacket(ctx context.Context, destination M.Socksaddr return cDialer.ListenPacket(ctx, "udp", "", destination.AddrPort()) } +type wgSingErrorHandler struct { + name string +} + +var _ E.Handler = (*wgSingErrorHandler)(nil) + +func (w wgSingErrorHandler) NewError(ctx context.Context, err error) { + if E.IsClosedOrCanceled(err) { + log.SingLogger.Debug(fmt.Sprintf("[WG](%s) connection closed: %s", w.name, err)) + return + } + log.SingLogger.Error(fmt.Sprintf("[WG](%s) %s", w.name, err)) +} + type wgNetDialer struct { tunDevice wireguard.Device } @@ -174,7 +189,7 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) { connectAddr = option.Addr() } } - outbound.bind = wireguard.NewClientBind(context.Background(), outbound, outbound.dialer, isConnect, connectAddr, reserved) + outbound.bind = wireguard.NewClientBind(context.Background(), wgSingErrorHandler{outbound.Name()}, outbound.dialer, isConnect, connectAddr, reserved) var localPrefixes []netip.Prefix @@ -312,13 +327,15 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) { } } + refP := &refProxyAdapter{} + outbound.refP = refP if option.RemoteDnsResolve && len(option.Dns) > 0 { nss, err := dns.ParseNameServer(option.Dns) if err != nil { return nil, err } for i := range nss { - nss[i].ProxyAdapter = outbound + nss[i].ProxyAdapter = refP } outbound.resolver = dns.NewResolver(dns.Config{ Main: nss, @@ -329,14 +346,6 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) { return outbound, nil } -func (w *WireGuard) NewError(ctx context.Context, err error) { - if E.IsClosedOrCanceled(err) { - log.SingLogger.Debug(fmt.Sprintf("[WG](%s) connection closed: %s", w.Name(), err)) - return - } - log.SingLogger.Error(fmt.Sprintf("[WG](%s) %s", w.Name(), err)) -} - func closeWireGuard(w *WireGuard) { if w.device != nil { w.device.Close() @@ -357,6 +366,8 @@ func (w *WireGuard) DialContext(ctx context.Context, metadata *C.Metadata, opts if !metadata.Resolved() || w.resolver != nil { r := resolver.DefaultResolver if w.resolver != nil { + w.refP.SetProxyAdapter(w) + defer w.refP.ClearProxyAdapter() r = w.resolver } options = append(options, dialer.WithResolver(r)) @@ -391,6 +402,8 @@ func (w *WireGuard) ListenPacketContext(ctx context.Context, metadata *C.Metadat if (!metadata.Resolved() || w.resolver != nil) && metadata.Host != "" { r := resolver.DefaultResolver if w.resolver != nil { + w.refP.SetProxyAdapter(w) + defer w.refP.ClearProxyAdapter() r = w.resolver } ip, err := resolver.ResolveIPWithResolver(ctx, metadata.Host, r) @@ -414,3 +427,139 @@ func (w *WireGuard) ListenPacketContext(ctx context.Context, metadata *C.Metadat func (w *WireGuard) IsL3Protocol(metadata *C.Metadata) bool { return true } + +type refProxyAdapter struct { + proxyAdapter C.ProxyAdapter + count int + mutex sync.Mutex +} + +func (r *refProxyAdapter) SetProxyAdapter(proxyAdapter C.ProxyAdapter) { + r.mutex.Lock() + defer r.mutex.Unlock() + r.proxyAdapter = proxyAdapter + r.count++ +} + +func (r *refProxyAdapter) ClearProxyAdapter() { + r.mutex.Lock() + defer r.mutex.Unlock() + r.count-- + if r.count == 0 { + r.proxyAdapter = nil + } +} + +func (r *refProxyAdapter) Name() string { + if r.proxyAdapter != nil { + return r.proxyAdapter.Name() + } + return "" +} + +func (r *refProxyAdapter) Type() C.AdapterType { + if r.proxyAdapter != nil { + return r.proxyAdapter.Type() + } + return C.AdapterType(0) +} + +func (r *refProxyAdapter) Addr() string { + if r.proxyAdapter != nil { + return r.proxyAdapter.Addr() + } + return "" +} + +func (r *refProxyAdapter) SupportUDP() bool { + if r.proxyAdapter != nil { + return r.proxyAdapter.SupportUDP() + } + return false +} + +func (r *refProxyAdapter) SupportXUDP() bool { + if r.proxyAdapter != nil { + return r.proxyAdapter.SupportXUDP() + } + return false +} + +func (r *refProxyAdapter) SupportTFO() bool { + if r.proxyAdapter != nil { + return r.proxyAdapter.SupportTFO() + } + return false +} + +func (r *refProxyAdapter) MarshalJSON() ([]byte, error) { + if r.proxyAdapter != nil { + return r.proxyAdapter.MarshalJSON() + } + return nil, C.ErrNotSupport +} + +func (r *refProxyAdapter) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { + if r.proxyAdapter != nil { + return r.proxyAdapter.StreamConn(c, metadata) + } + return nil, C.ErrNotSupport +} + +func (r *refProxyAdapter) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) { + if r.proxyAdapter != nil { + return r.proxyAdapter.DialContext(ctx, metadata, opts...) + } + return nil, C.ErrNotSupport +} + +func (r *refProxyAdapter) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) { + if r.proxyAdapter != nil { + return r.proxyAdapter.ListenPacketContext(ctx, metadata, opts...) + } + return nil, C.ErrNotSupport +} + +func (r *refProxyAdapter) SupportUOT() bool { + if r.proxyAdapter != nil { + return r.proxyAdapter.SupportUOT() + } + return false +} + +func (r *refProxyAdapter) SupportWithDialer() C.NetWork { + if r.proxyAdapter != nil { + return r.proxyAdapter.SupportWithDialer() + } + return C.InvalidNet +} + +func (r *refProxyAdapter) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (C.Conn, error) { + if r.proxyAdapter != nil { + return r.proxyAdapter.DialContextWithDialer(ctx, dialer, metadata) + } + return nil, C.ErrNotSupport +} + +func (r *refProxyAdapter) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (C.PacketConn, error) { + if r.proxyAdapter != nil { + return r.proxyAdapter.ListenPacketWithDialer(ctx, dialer, metadata) + } + return nil, C.ErrNotSupport +} + +func (r *refProxyAdapter) IsL3Protocol(metadata *C.Metadata) bool { + if r.proxyAdapter != nil { + return r.proxyAdapter.IsL3Protocol(metadata) + } + return false +} + +func (r *refProxyAdapter) Unwrap(metadata *C.Metadata, touch bool) C.Proxy { + if r.proxyAdapter != nil { + return r.proxyAdapter.Unwrap(metadata, touch) + } + return nil +} + +var _ C.ProxyAdapter = (*refProxyAdapter)(nil)