From cd2d1c6bb0e88c8833fed7e335f3874f78ced4d6 Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Fri, 27 Sep 2024 18:10:05 +0800 Subject: [PATCH] fix: `skip-auth-prefixes` not apply on listeners when `users` is unset --- component/auth/auth.go | 5 +++++ hub/executor/executor.go | 4 ++-- listener/auth/auth.go | 30 +++++++++++++++++++++++------- listener/http/proxy.go | 4 ++-- listener/http/server.go | 18 +++++++++--------- listener/inbound/auth.go | 6 +++--- listener/inbound/http.go | 2 +- listener/inbound/mixed.go | 2 +- listener/inbound/socks.go | 2 +- listener/mixed/mixed.go | 20 ++++++++++---------- listener/socks/tcp.go | 26 +++++++++++++------------- 11 files changed, 70 insertions(+), 49 deletions(-) diff --git a/component/auth/auth.go b/component/auth/auth.go index b52fa135..176b21d7 100644 --- a/component/auth/auth.go +++ b/component/auth/auth.go @@ -5,6 +5,11 @@ type Authenticator interface { Users() []string } +type AuthStore interface { + Authenticator() Authenticator + SetAuthenticator(Authenticator) +} + type AuthUser struct { User string Pass string diff --git a/hub/executor/executor.go b/hub/executor/executor.go index 66bbc89b..39bf28d2 100644 --- a/hub/executor/executor.go +++ b/hub/executor/executor.go @@ -127,7 +127,7 @@ func initInnerTcp() { func GetGeneral() *config.General { ports := listener.GetPorts() var authenticator []string - if auth := authStore.Authenticator(); auth != nil { + if auth := authStore.Default.Authenticator(); auth != nil { authenticator = auth.Users() } @@ -422,7 +422,7 @@ func updateGeneral(general *config.General) { func updateUsers(users []auth.AuthUser) { authenticator := auth.NewAuthenticator(users) - authStore.SetAuthenticator(authenticator) + authStore.Default.SetAuthenticator(authenticator) if authenticator != nil { log.Infoln("Authentication of local server updated") } diff --git a/listener/auth/auth.go b/listener/auth/auth.go index 772be3bd..9e7632e8 100644 --- a/listener/auth/auth.go +++ b/listener/auth/auth.go @@ -4,14 +4,30 @@ import ( "github.com/metacubex/mihomo/component/auth" ) -var authenticator auth.Authenticator - -func Authenticator() auth.Authenticator { - return authenticator +type authStore struct { + authenticator auth.Authenticator } -func SetAuthenticator(au auth.Authenticator) { - authenticator = au +func (a *authStore) Authenticator() auth.Authenticator { + return a.authenticator } -func Nil() auth.Authenticator { return nil } +func (a *authStore) SetAuthenticator(authenticator auth.Authenticator) { + a.authenticator = authenticator +} + +func NewAuthStore(authenticator auth.Authenticator) auth.AuthStore { + return &authStore{authenticator} +} + +var Default auth.AuthStore = NewAuthStore(nil) + +type nilAuthStore struct{} + +func (a *nilAuthStore) Authenticator() auth.Authenticator { + return nil +} + +func (a *nilAuthStore) SetAuthenticator(authenticator auth.Authenticator) {} + +var Nil auth.AuthStore = (*nilAuthStore)(nil) // always return nil, even call SetAuthenticator() with a non-nil authenticator diff --git a/listener/http/proxy.go b/listener/http/proxy.go index 04ab98eb..5c08cd45 100644 --- a/listener/http/proxy.go +++ b/listener/http/proxy.go @@ -30,7 +30,7 @@ func (b *bodyWrapper) Read(p []byte) (n int, err error) { return n, err } -func HandleConn(c net.Conn, tunnel C.Tunnel, getAuth func() auth.Authenticator, additions ...inbound.Addition) { +func HandleConn(c net.Conn, tunnel C.Tunnel, store auth.AuthStore, additions ...inbound.Addition) { additions = append(additions, inbound.Placeholder) // Add a placeholder for InUser inUserIdx := len(additions) - 1 client := newClient(c, tunnel, additions) @@ -41,7 +41,7 @@ func HandleConn(c net.Conn, tunnel C.Tunnel, getAuth func() auth.Authenticator, conn := N.NewBufferedConn(c) - authenticator := getAuth() + authenticator := store.Authenticator() keepAlive := true trusted := authenticator == nil // disable authenticate if lru is nil lastUser := "" diff --git a/listener/http/server.go b/listener/http/server.go index 04f32f4f..24f07e8b 100644 --- a/listener/http/server.go +++ b/listener/http/server.go @@ -32,20 +32,20 @@ func (l *Listener) Close() error { } func New(addr string, tunnel C.Tunnel, additions ...inbound.Addition) (*Listener, error) { - return NewWithAuthenticator(addr, tunnel, authStore.Authenticator, additions...) + return NewWithAuthenticator(addr, tunnel, authStore.Default, additions...) } // NewWithAuthenticate // never change type traits because it's used in CMFA func NewWithAuthenticate(addr string, tunnel C.Tunnel, authenticate bool, additions ...inbound.Addition) (*Listener, error) { - getAuth := authStore.Authenticator + store := authStore.Default if !authenticate { - getAuth = authStore.Nil + store = authStore.Default } - return NewWithAuthenticator(addr, tunnel, getAuth, additions...) + return NewWithAuthenticator(addr, tunnel, store, additions...) } -func NewWithAuthenticator(addr string, tunnel C.Tunnel, getAuth func() auth.Authenticator, additions ...inbound.Addition) (*Listener, error) { +func NewWithAuthenticator(addr string, tunnel C.Tunnel, store auth.AuthStore, additions ...inbound.Addition) (*Listener, error) { isDefault := false if len(additions) == 0 { isDefault = true @@ -74,17 +74,17 @@ func NewWithAuthenticator(addr string, tunnel C.Tunnel, getAuth func() auth.Auth continue } - getAuth := getAuth - if isDefault { // only apply on default listener + store := store + if isDefault || store == authStore.Default { // only apply on default listener if !inbound.IsRemoteAddrDisAllowed(conn.RemoteAddr()) { _ = conn.Close() continue } if inbound.SkipAuthRemoteAddr(conn.RemoteAddr()) { - getAuth = authStore.Nil + store = authStore.Nil } } - go HandleConn(conn, tunnel, getAuth, additions...) + go HandleConn(conn, tunnel, store, additions...) } }() diff --git a/listener/inbound/auth.go b/listener/inbound/auth.go index 41f18fc0..85e72494 100644 --- a/listener/inbound/auth.go +++ b/listener/inbound/auth.go @@ -12,7 +12,7 @@ type AuthUser struct { type AuthUsers []AuthUser -func (a AuthUsers) GetAuth() func() auth.Authenticator { +func (a AuthUsers) GetAuthStore() auth.AuthStore { if a != nil { // structure's Decode will ensure value not nil when input has value even it was set an empty array if len(a) == 0 { return authStore.Nil @@ -25,7 +25,7 @@ func (a AuthUsers) GetAuth() func() auth.Authenticator { } } authenticator := auth.NewAuthenticator(users) - return func() auth.Authenticator { return authenticator } + return authStore.NewAuthStore(authenticator) } - return authStore.Authenticator + return authStore.Default } diff --git a/listener/inbound/http.go b/listener/inbound/http.go index c78abefd..e20a9a23 100644 --- a/listener/inbound/http.go +++ b/listener/inbound/http.go @@ -45,7 +45,7 @@ func (h *HTTP) Address() string { // Listen implements constant.InboundListener func (h *HTTP) Listen(tunnel C.Tunnel) error { var err error - h.l, err = http.NewWithAuthenticator(h.RawAddress(), tunnel, h.config.Users.GetAuth(), h.Additions()...) + h.l, err = http.NewWithAuthenticator(h.RawAddress(), tunnel, h.config.Users.GetAuthStore(), h.Additions()...) if err != nil { return err } diff --git a/listener/inbound/mixed.go b/listener/inbound/mixed.go index 443a2564..1d79929a 100644 --- a/listener/inbound/mixed.go +++ b/listener/inbound/mixed.go @@ -53,7 +53,7 @@ func (m *Mixed) Address() string { // Listen implements constant.InboundListener func (m *Mixed) Listen(tunnel C.Tunnel) error { var err error - m.l, err = mixed.NewWithAuthenticator(m.RawAddress(), tunnel, m.config.Users.GetAuth(), m.Additions()...) + m.l, err = mixed.NewWithAuthenticator(m.RawAddress(), tunnel, m.config.Users.GetAuthStore(), m.Additions()...) if err != nil { return err } diff --git a/listener/inbound/socks.go b/listener/inbound/socks.go index cf6d1ce4..119eec82 100644 --- a/listener/inbound/socks.go +++ b/listener/inbound/socks.go @@ -71,7 +71,7 @@ func (s *Socks) Address() string { // Listen implements constant.InboundListener func (s *Socks) Listen(tunnel C.Tunnel) error { var err error - if s.stl, err = socks.NewWithAuthenticator(s.RawAddress(), tunnel, s.config.Users.GetAuth(), s.Additions()...); err != nil { + if s.stl, err = socks.NewWithAuthenticator(s.RawAddress(), tunnel, s.config.Users.GetAuthStore(), s.Additions()...); err != nil { return err } if s.udp { diff --git a/listener/mixed/mixed.go b/listener/mixed/mixed.go index ac3a0c58..5ac63011 100644 --- a/listener/mixed/mixed.go +++ b/listener/mixed/mixed.go @@ -37,10 +37,10 @@ func (l *Listener) Close() error { } func New(addr string, tunnel C.Tunnel, additions ...inbound.Addition) (*Listener, error) { - return NewWithAuthenticator(addr, tunnel, authStore.Authenticator, additions...) + return NewWithAuthenticator(addr, tunnel, authStore.Default, additions...) } -func NewWithAuthenticator(addr string, tunnel C.Tunnel, getAuth func() auth.Authenticator, additions ...inbound.Addition) (*Listener, error) { +func NewWithAuthenticator(addr string, tunnel C.Tunnel, store auth.AuthStore, additions ...inbound.Addition) (*Listener, error) { isDefault := false if len(additions) == 0 { isDefault = true @@ -68,24 +68,24 @@ func NewWithAuthenticator(addr string, tunnel C.Tunnel, getAuth func() auth.Auth } continue } - getAuth := getAuth - if isDefault { // only apply on default listener + store := store + if isDefault || store == authStore.Default { // only apply on default listener if !inbound.IsRemoteAddrDisAllowed(c.RemoteAddr()) { _ = c.Close() continue } if inbound.SkipAuthRemoteAddr(c.RemoteAddr()) { - getAuth = authStore.Nil + store = authStore.Nil } } - go handleConn(c, tunnel, getAuth, additions...) + go handleConn(c, tunnel, store, additions...) } }() return ml, nil } -func handleConn(conn net.Conn, tunnel C.Tunnel, getAuth func() auth.Authenticator, additions ...inbound.Addition) { +func handleConn(conn net.Conn, tunnel C.Tunnel, store auth.AuthStore, additions ...inbound.Addition) { bufConn := N.NewBufferedConn(conn) head, err := bufConn.Peek(1) if err != nil { @@ -94,10 +94,10 @@ func handleConn(conn net.Conn, tunnel C.Tunnel, getAuth func() auth.Authenticato switch head[0] { case socks4.Version: - socks.HandleSocks4(bufConn, tunnel, getAuth, additions...) + socks.HandleSocks4(bufConn, tunnel, store, additions...) case socks5.Version: - socks.HandleSocks5(bufConn, tunnel, getAuth, additions...) + socks.HandleSocks5(bufConn, tunnel, store, additions...) default: - http.HandleConn(bufConn, tunnel, getAuth, additions...) + http.HandleConn(bufConn, tunnel, store, additions...) } } diff --git a/listener/socks/tcp.go b/listener/socks/tcp.go index 950384c1..cc66613e 100644 --- a/listener/socks/tcp.go +++ b/listener/socks/tcp.go @@ -36,10 +36,10 @@ func (l *Listener) Close() error { } func New(addr string, tunnel C.Tunnel, additions ...inbound.Addition) (*Listener, error) { - return NewWithAuthenticator(addr, tunnel, authStore.Authenticator, additions...) + return NewWithAuthenticator(addr, tunnel, authStore.Default, additions...) } -func NewWithAuthenticator(addr string, tunnel C.Tunnel, getAuth func() auth.Authenticator, additions ...inbound.Addition) (*Listener, error) { +func NewWithAuthenticator(addr string, tunnel C.Tunnel, store auth.AuthStore, additions ...inbound.Addition) (*Listener, error) { isDefault := false if len(additions) == 0 { isDefault = true @@ -67,24 +67,24 @@ func NewWithAuthenticator(addr string, tunnel C.Tunnel, getAuth func() auth.Auth } continue } - getAuth := getAuth - if isDefault { // only apply on default listener + store := store + if isDefault || store == authStore.Default { // only apply on default listener if !inbound.IsRemoteAddrDisAllowed(c.RemoteAddr()) { _ = c.Close() continue } if inbound.SkipAuthRemoteAddr(c.RemoteAddr()) { - getAuth = authStore.Nil + store = authStore.Nil } } - go handleSocks(c, tunnel, getAuth, additions...) + go handleSocks(c, tunnel, store, additions...) } }() return sl, nil } -func handleSocks(conn net.Conn, tunnel C.Tunnel, getAuth func() auth.Authenticator, additions ...inbound.Addition) { +func handleSocks(conn net.Conn, tunnel C.Tunnel, store auth.AuthStore, additions ...inbound.Addition) { bufConn := N.NewBufferedConn(conn) head, err := bufConn.Peek(1) if err != nil { @@ -94,16 +94,16 @@ func handleSocks(conn net.Conn, tunnel C.Tunnel, getAuth func() auth.Authenticat switch head[0] { case socks4.Version: - HandleSocks4(bufConn, tunnel, getAuth, additions...) + HandleSocks4(bufConn, tunnel, store, additions...) case socks5.Version: - HandleSocks5(bufConn, tunnel, getAuth, additions...) + HandleSocks5(bufConn, tunnel, store, additions...) default: conn.Close() } } -func HandleSocks4(conn net.Conn, tunnel C.Tunnel, getAuth func() auth.Authenticator, additions ...inbound.Addition) { - authenticator := getAuth() +func HandleSocks4(conn net.Conn, tunnel C.Tunnel, store auth.AuthStore, additions ...inbound.Addition) { + authenticator := store.Authenticator() addr, _, user, err := socks4.ServerHandshake(conn, authenticator) if err != nil { conn.Close() @@ -113,8 +113,8 @@ func HandleSocks4(conn net.Conn, tunnel C.Tunnel, getAuth func() auth.Authentica tunnel.HandleTCPConn(inbound.NewSocket(socks5.ParseAddr(addr), conn, C.SOCKS4, additions...)) } -func HandleSocks5(conn net.Conn, tunnel C.Tunnel, getAuth func() auth.Authenticator, additions ...inbound.Addition) { - authenticator := getAuth() +func HandleSocks5(conn net.Conn, tunnel C.Tunnel, store auth.AuthStore, additions ...inbound.Addition) { + authenticator := store.Authenticator() target, command, user, err := socks5.ServerHandshake(conn, authenticator) if err != nil { conn.Close()