chore: support golang1.20's dialer.ControlContext

This commit is contained in:
wwqgtxx 2023-02-13 11:14:19 +08:00
parent ce8929d153
commit ae42d35184
9 changed files with 96 additions and 65 deletions

View File

@ -1,6 +1,7 @@
package dialer package dialer
import ( import (
"context"
"net" "net"
"net/netip" "net/netip"
"syscall" "syscall"
@ -10,16 +11,8 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
type controlFn = func(network, address string, c syscall.RawConn) error func bindControl(ifaceIdx int) controlFn {
return func(ctx context.Context, network, address string, c syscall.RawConn) (err error) {
func bindControl(ifaceIdx int, chain controlFn) controlFn {
return func(network, address string, c syscall.RawConn) (err error) {
defer func() {
if err == nil && chain != nil {
err = chain(network, address, c)
}
}()
addrPort, err := netip.ParseAddrPort(address) addrPort, err := netip.ParseAddrPort(address)
if err == nil && !addrPort.Addr().IsGlobalUnicast() { if err == nil && !addrPort.Addr().IsGlobalUnicast() {
return return
@ -49,7 +42,7 @@ func bindIfaceToDialer(ifaceName string, dialer *net.Dialer, _ string, _ netip.A
return err return err
} }
dialer.Control = bindControl(ifaceObj.Index, dialer.Control) addControlToDialer(dialer, bindControl(ifaceObj.Index))
return nil return nil
} }
@ -59,7 +52,7 @@ func bindIfaceToListenConfig(ifaceName string, lc *net.ListenConfig, _, address
return "", err return "", err
} }
lc.Control = bindControl(ifaceObj.Index, lc.Control) addControlToListenConfig(lc, bindControl(ifaceObj.Index))
return address, nil return address, nil
} }

View File

@ -1,6 +1,7 @@
package dialer package dialer
import ( import (
"context"
"net" "net"
"net/netip" "net/netip"
"syscall" "syscall"
@ -8,16 +9,8 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
type controlFn = func(network, address string, c syscall.RawConn) error func bindControl(ifaceName string) controlFn {
return func(ctx context.Context, network, address string, c syscall.RawConn) (err error) {
func bindControl(ifaceName string, chain controlFn) controlFn {
return func(network, address string, c syscall.RawConn) (err error) {
defer func() {
if err == nil && chain != nil {
err = chain(network, address, c)
}
}()
addrPort, err := netip.ParseAddrPort(address) addrPort, err := netip.ParseAddrPort(address)
if err == nil && !addrPort.Addr().IsGlobalUnicast() { if err == nil && !addrPort.Addr().IsGlobalUnicast() {
return return
@ -37,13 +30,13 @@ func bindControl(ifaceName string, chain controlFn) controlFn {
} }
func bindIfaceToDialer(ifaceName string, dialer *net.Dialer, _ string, _ netip.Addr) error { func bindIfaceToDialer(ifaceName string, dialer *net.Dialer, _ string, _ netip.Addr) error {
dialer.Control = bindControl(ifaceName, dialer.Control) addControlToDialer(dialer, bindControl(ifaceName))
return nil return nil
} }
func bindIfaceToListenConfig(ifaceName string, lc *net.ListenConfig, _, address string) (string, error) { func bindIfaceToListenConfig(ifaceName string, lc *net.ListenConfig, _, address string) (string, error) {
lc.Control = bindControl(ifaceName, lc.Control) addControlToListenConfig(lc, bindControl(ifaceName))
return address, nil return address, nil
} }

View File

@ -1,6 +1,7 @@
package dialer package dialer
import ( import (
"context"
"encoding/binary" "encoding/binary"
"net" "net"
"net/netip" "net/netip"
@ -26,16 +27,8 @@ func bind6(handle syscall.Handle, ifaceIdx int) error {
return syscall.SetsockoptInt(handle, syscall.IPPROTO_IPV6, IPV6_UNICAST_IF, ifaceIdx) return syscall.SetsockoptInt(handle, syscall.IPPROTO_IPV6, IPV6_UNICAST_IF, ifaceIdx)
} }
type controlFn = func(network, address string, c syscall.RawConn) error func bindControl(ifaceIdx int) controlFn {
return func(ctx context.Context, network, address string, c syscall.RawConn) (err error) {
func bindControl(ifaceIdx int, chain controlFn) controlFn {
return func(network, address string, c syscall.RawConn) (err error) {
defer func() {
if err == nil && chain != nil {
err = chain(network, address, c)
}
}()
addrPort, err := netip.ParseAddrPort(address) addrPort, err := netip.ParseAddrPort(address)
if err == nil && !addrPort.Addr().IsGlobalUnicast() { if err == nil && !addrPort.Addr().IsGlobalUnicast() {
return return
@ -69,7 +62,7 @@ func bindIfaceToDialer(ifaceName string, dialer *net.Dialer, _ string, _ netip.A
return err return err
} }
dialer.Control = bindControl(ifaceObj.Index, dialer.Control) addControlToDialer(dialer, bindControl(ifaceObj.Index))
return nil return nil
} }
@ -79,7 +72,7 @@ func bindIfaceToListenConfig(ifaceName string, lc *net.ListenConfig, _, address
return "", err return "", err
} }
lc.Control = bindControl(ifaceObj.Index, lc.Control) addControlToListenConfig(lc, bindControl(ifaceObj.Index))
return address, nil return address, nil
} }

View File

@ -0,0 +1,22 @@
package dialer
import (
"context"
"net"
"syscall"
)
type controlFn = func(ctx context.Context, network, address string, c syscall.RawConn) error
func addControlToListenConfig(lc *net.ListenConfig, fn controlFn) {
llc := *lc
lc.Control = func(network, address string, c syscall.RawConn) (err error) {
switch {
case llc.Control != nil:
if err = llc.Control(network, address, c); err != nil {
return
}
}
return fn(context.Background(), network, address, c)
}
}

View File

@ -0,0 +1,22 @@
//go:build !go1.20
package dialer
import (
"context"
"net"
"syscall"
)
func addControlToDialer(d *net.Dialer, fn controlFn) {
ld := *d
d.Control = func(network, address string, c syscall.RawConn) (err error) {
switch {
case ld.Control != nil:
if err = ld.Control(network, address, c); err != nil {
return
}
}
return fn(context.Background(), network, address, c)
}
}

View File

@ -0,0 +1,26 @@
//go:build go1.20
package dialer
import (
"context"
"net"
"syscall"
)
func addControlToDialer(d *net.Dialer, fn controlFn) {
ld := *d
d.ControlContext = func(ctx context.Context, network, address string, c syscall.RawConn) (err error) {
switch {
case ld.ControlContext != nil:
if err = ld.ControlContext(ctx, network, address, c); err != nil {
return
}
case ld.Control != nil:
if err = ld.Control(network, address, c); err != nil {
return
}
}
return fn(ctx, network, address, c)
}
}

View File

@ -3,26 +3,22 @@
package dialer package dialer
import ( import (
"context"
"net" "net"
"net/netip" "net/netip"
"syscall" "syscall"
) )
func bindMarkToDialer(mark int, dialer *net.Dialer, _ string, _ netip.Addr) { func bindMarkToDialer(mark int, dialer *net.Dialer, _ string, _ netip.Addr) {
dialer.Control = bindMarkToControl(mark, dialer.Control) addControlToDialer(dialer, bindMarkToControl(mark))
} }
func bindMarkToListenConfig(mark int, lc *net.ListenConfig, _, _ string) { func bindMarkToListenConfig(mark int, lc *net.ListenConfig, _, _ string) {
lc.Control = bindMarkToControl(mark, lc.Control) addControlToListenConfig(lc, bindMarkToControl(mark))
} }
func bindMarkToControl(mark int, chain controlFn) controlFn { func bindMarkToControl(mark int) controlFn {
return func(network, address string, c syscall.RawConn) (err error) { return func(ctx context.Context, network, address string, c syscall.RawConn) (err error) {
defer func() {
if err == nil && chain != nil {
err = chain(network, address, c)
}
}()
addrPort, err := netip.ParseAddrPort(address) addrPort, err := netip.ParseAddrPort(address)
if err == nil && !addrPort.Addr().IsGlobalUnicast() { if err == nil && !addrPort.Addr().IsGlobalUnicast() {

View File

@ -3,6 +3,7 @@
package dialer package dialer
import ( import (
"context"
"net" "net"
"syscall" "syscall"
@ -10,18 +11,10 @@ import (
) )
func addrReuseToListenConfig(lc *net.ListenConfig) { func addrReuseToListenConfig(lc *net.ListenConfig) {
chain := lc.Control addControlToListenConfig(lc, func(ctx context.Context, network, address string, c syscall.RawConn) error {
lc.Control = func(network, address string, c syscall.RawConn) (err error) {
defer func() {
if err == nil && chain != nil {
err = chain(network, address, c)
}
}()
return c.Control(func(fd uintptr) { return c.Control(func(fd uintptr) {
unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEADDR, 1) unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEADDR, 1)
unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1) unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)
}) })
} })
} }

View File

@ -1,6 +1,7 @@
package dialer package dialer
import ( import (
"context"
"net" "net"
"syscall" "syscall"
@ -8,17 +9,9 @@ import (
) )
func addrReuseToListenConfig(lc *net.ListenConfig) { func addrReuseToListenConfig(lc *net.ListenConfig) {
chain := lc.Control addControlToListenConfig(lc, func(ctx context.Context, network, address string, c syscall.RawConn) error {
lc.Control = func(network, address string, c syscall.RawConn) (err error) {
defer func() {
if err == nil && chain != nil {
err = chain(network, address, c)
}
}()
return c.Control(func(fd uintptr) { return c.Control(func(fd uintptr) {
windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_REUSEADDR, 1) windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_REUSEADDR, 1)
}) })
} })
} }