Wintun: use new swdevice-based API for upcoming Wintun 0.14

This commit is contained in:
yaling888 2021-11-03 15:10:31 +08:00
parent 78cef7df59
commit c824ace2d7
37 changed files with 632 additions and 5770 deletions

View File

@ -25,12 +25,13 @@ jobs:
restore-keys: |
${{ runner.os }}-go-
- name: Setup Python
uses: actions/setup-python@v2
with:
python-version: '3.9'
- name: Get dependencies, run test
run: |
# install python3.9
sudo add-apt-repository -y ppa:deadsnakes/ppa
sudo apt install -y python3.9 python3.9-dev
# fetch python cross compile source files
mkdir -p bin/python/
cd bin/python/

11
go.mod
View File

@ -20,13 +20,14 @@ require (
go.etcd.io/bbolt v1.3.6
go.uber.org/atomic v1.9.0
go.uber.org/automaxprocs v1.4.0
golang.org/x/crypto v0.0.0-20210817164053-32db794688a5
golang.org/x/net v0.0.0-20210903162142-ad29c8ab022f
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519
golang.org/x/net v0.0.0-20211029160332-540bb53d3b2e
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
golang.org/x/sys v0.0.0-20210906170528-6f6e22806c34
golang.org/x/sys v0.0.0-20211029165221-6e7872819dc8
golang.zx2c4.com/wireguard/windows v0.5.1
google.golang.org/protobuf v1.27.1
gopkg.in/yaml.v2 v2.4.0
gvisor.dev/gvisor v0.0.0-20210922003438-b39716d116fd
gvisor.dev/gvisor v0.0.0-20211102011804-04d474e716e0
)
require (
@ -35,7 +36,7 @@ require (
github.com/oschwald/maxminddb-golang v1.8.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/u-root/uio v0.0.0-20210528114334-82958018845c // indirect
golang.org/x/text v0.3.6 // indirect
golang.org/x/text v0.3.8-0.20211004125949-5bd84dd9b33b // indirect
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect
)

936
go.sum

File diff suppressed because it is too large Load Diff

View File

@ -191,9 +191,8 @@ func updateGeneral(general *config.General, force bool) {
if err == nil {
if autoDetectInterfaceName != "" && autoDetectInterfaceName != "<nil>" {
general.Interface = autoDetectInterfaceName
log.Infoln("Use auto detect interface: %s", general.Interface)
} else {
log.Debugln("Auto detect interface is empty.")
log.Debugln("Auto detect interface name is empty.")
}
} else {
log.Debugln("Can not find auto detect interface. %s", err.Error())
@ -202,6 +201,7 @@ func updateGeneral(general *config.General, force bool) {
if general.Interface != "" {
dialer.DefaultOptions = []dialer.Option{dialer.WithInterface(general.Interface)}
log.Infoln("Use interface name: %s", general.Interface)
} else {
dialer.DefaultOptions = nil
}

View File

@ -114,7 +114,6 @@ var sockaddrCtlSize uintptr = 32
// OpenTunDevice return a TunDevice according a URL
func OpenTunDevice(tunAddress string, autoRoute bool) (TunDevice, error) {
name := "utun"
// TODO: configure the MTU
mtu := 9000
ifIndex := -1
@ -212,8 +211,7 @@ func CreateTUNFromFile(file *os.File, mtu int, tunAddress string, autoRoute bool
}
// This address doesn't mean anything here. NIC just net an IP address to set route upon.
// TODO: maybe let user config it. And I'm doubt whether we really need it.
p2pAddress := net.ParseIP("198.18.0.1")
p2pAddress := net.ParseIP(tunAddress)
err = tun.setTunAddress(p2pAddress)
if err != nil {
tun.Close()

View File

@ -7,20 +7,14 @@
package dev
import (
"bytes"
"errors"
"fmt"
"net"
"os"
"sort"
"sync"
"sync/atomic"
"time"
_ "unsafe"
"github.com/Dreamacro/clash/listener/tun/dev/winipcfg"
"github.com/Dreamacro/clash/listener/tun/dev/wintun"
"github.com/Dreamacro/clash/log"
"golang.org/x/sys/windows"
)
@ -28,8 +22,6 @@ const (
rateMeasurementGranularity = uint64((time.Second / 2) / time.Nanosecond)
spinloopRateThreshold = 800000000 / 8 // 800mbps
spinloopDuration = uint64(time.Millisecond / 80 / time.Nanosecond) // ~1gbit/s
messageTransportHeaderSize = 0 // size of data preceding content in transport message
)
type rateJuggler struct {
@ -39,24 +31,7 @@ type rateJuggler struct {
changing int32
}
type tunWindows struct {
wt *wintun.Adapter
handle windows.Handle
close int32
running sync.WaitGroup
forcedMTU int
rate rateJuggler
session wintun.Session
readWait windows.Handle
stopOnce sync.Once
url string
name string
tunAddress string
autoRoute bool
}
var WintunPool, _ = wintun.MakePool("Clash")
var WintunTunnelType = "Clash"
var WintunStaticRequestedGUID *windows.GUID
//go:linkname procyield runtime.procyield
@ -65,128 +40,18 @@ func procyield(cycles uint32)
//go:linkname nanotime runtime.nanotime
func nanotime() int64
// OpenTunDevice return a TunDevice according a URL
func OpenTunDevice(tunAddress string, autoRoute bool) (TunDevice, error) {
requestedGUID, err := windows.GUIDFromString("{330EAEF8-7578-5DF2-D97B-8DADC0EA85CB}")
if err == nil {
WintunStaticRequestedGUID = &requestedGUID
log.Debugln("Generate GUID: %s", WintunStaticRequestedGUID.String())
} else {
log.Warnln("Error parese GUID from string: %v", err)
}
interfaceName := "Clash"
mtu := 9000
tun, err := CreateTUN(interfaceName, mtu, tunAddress, autoRoute)
if err != nil {
return nil, err
}
return tun, nil
}
//
// CreateTUN creates a Wintun interface with the given name. Should a Wintun
// interface with the same name exist, it is reused.
//
func CreateTUN(ifname string, mtu int, tunAddress string, autoRoute bool) (TunDevice, error) {
return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu, tunAddress, autoRoute)
}
//
// CreateTUNWithRequestedGUID creates a Wintun interface with the given name and
// a requested GUID. Should a Wintun interface with the same name exist, it is reused.
//
func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int, tunAddress string, autoRoute bool) (TunDevice, error) {
var err error
var wt *wintun.Adapter
// Does an interface with this name already exist?
wt, err = WintunPool.OpenAdapter(ifname)
if err == nil {
// If so, we delete it, in case it has weird residual configuration.
_, err = wt.Delete(false)
if err != nil {
return nil, fmt.Errorf("Error deleting already existing interface: %w", err)
}
}
wt, rebootRequired, err := WintunPool.CreateAdapter(ifname, requestedGUID)
if err != nil {
return nil, fmt.Errorf("Error creating interface: %w", err)
}
if rebootRequired {
log.Infoln("Windows indicated a reboot is required.")
}
forcedMTU := 1420
if mtu > 0 {
forcedMTU = mtu
}
tun := &tunWindows{
wt: wt,
handle: windows.InvalidHandle,
forcedMTU: forcedMTU,
tunAddress: tunAddress,
autoRoute: autoRoute,
}
// config tun ip
err = tun.configureInterface()
if err != nil {
tun.wt.Delete(false)
return nil, fmt.Errorf("Error configure interface: %w", err)
}
realInterfaceName, err2 := wt.Name()
if err2 == nil {
ifname = realInterfaceName
tun.name = realInterfaceName
}
tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB
if err != nil {
tun.wt.Delete(false)
return nil, fmt.Errorf("Error starting session: %w", err)
}
tun.readWait = tun.session.ReadWaitEvent()
return tun, nil
}
func (tun *tunWindows) getName() (string, error) {
tun.running.Add(1)
defer tun.running.Done()
if atomic.LoadInt32(&tun.close) == 1 {
return "", os.ErrClosed
}
return tun.wt.Name()
}
func (tun *tunWindows) IsClose() bool {
return atomic.LoadInt32(&tun.close) == 1
}
func (tun *tunWindows) Close() error {
tun.stopOnce.Do(func() {
var err error
tun.closeOnce.Do(func() {
atomic.StoreInt32(&tun.close, 1)
windows.SetEvent(tun.readWait)
//tun.running.Wait()
tun.session.End()
if tun.wt != nil {
forceCloseSessions := false
rebootRequired, err := tun.wt.Delete(forceCloseSessions)
if rebootRequired {
log.Infoln("Remove Wintun failure, Windows indicated a reboot is required.")
} else {
log.Infoln("Remove Wintun adapter success.")
}
if err != nil {
log.Errorln("Close Wintun Sessions failure: %v", err)
}
err = tun.wt.Close()
}
})
return nil
return err
}
func (tun *tunWindows) MTU() (int, error) {
@ -198,13 +63,9 @@ func (tun *tunWindows) ForceMTU(mtu int) {
tun.forcedMTU = mtu
}
func (tun *tunWindows) Read(buff []byte) (int, error) {
return tun.ReadO(buff, messageTransportHeaderSize)
}
// Note: Read() and Write() assume the caller comes only from a single thread; there's no locking.
func (tun *tunWindows) ReadO(buff []byte, offset int) (int, error) {
func (tun *tunWindows) Read0(buff []byte, offset int) (int, error) {
tun.running.Add(1)
defer tun.running.Done()
retry:
@ -245,20 +106,13 @@ func (tun *tunWindows) Flush() error {
return nil
}
func (tun *tunWindows) Write(buff []byte) (int, error) {
return tun.WriteO(buff, messageTransportHeaderSize)
}
func (tun *tunWindows) WriteO(buff []byte, offset int) (int, error) {
func (tun *tunWindows) Write0(buff []byte, offset int) (int, error) {
tun.running.Add(1)
defer tun.running.Done()
if atomic.LoadInt32(&tun.close) == 1 {
return 0, os.ErrClosed
}
if len(buff) == 0 {
return 0, nil
}
packetSize := len(buff) - offset
tun.rate.update(uint64(packetSize))
@ -306,257 +160,3 @@ func (rate *rateJuggler) update(packetLen uint64) {
atomic.StoreInt32(&rate.changing, 0)
}
}
func (tun *tunWindows) Name() string {
return tun.name
}
func (t *tunWindows) URL() string {
return fmt.Sprintf("dev://%s", t.Name())
}
func (tun *tunWindows) configureInterface() error {
luid := winipcfg.LUID(tun.LUID())
log.Infoln("[wintun]: tun adapter LUID: %d", luid)
mtu, err := tun.MTU()
if err != nil {
return errors.New("unable to get device mtu")
}
family := winipcfg.AddressFamily(windows.AF_INET)
familyV6 := winipcfg.AddressFamily(windows.AF_INET6)
tunAddress := winipcfg.ParseIPCidr("198.18.0.1/16")
addresses := []net.IPNet{tunAddress.IPNet()}
err = luid.FlushIPAddresses(familyV6)
if err != nil {
return err
}
err = luid.FlushDNS(family)
if err != nil {
return err
}
err = luid.FlushDNS(familyV6)
if err != nil {
return err
}
err = luid.FlushRoutes(familyV6)
//if err != nil {
// return err
//}
err = luid.SetIPAddressesForFamily(family, addresses)
if err == windows.ERROR_OBJECT_ALREADY_EXISTS {
cleanupAddressesOnDisconnectedInterfaces(family, addresses)
err = luid.SetIPAddressesForFamily(family, addresses)
}
if err != nil {
return fmt.Errorf("unable to set ips %+v: %w", addresses, err)
}
foundDefault4 := false
foundDefault6 := false
if tun.autoRoute {
allowedIPs := []*winipcfg.IPCidr{
//winipcfg.ParseIPCidr("0.0.0.0/0"),
winipcfg.ParseIPCidr("1.0.0.0/8"),
winipcfg.ParseIPCidr("2.0.0.0/7"),
winipcfg.ParseIPCidr("4.0.0.0/6"),
winipcfg.ParseIPCidr("8.0.0.0/5"),
winipcfg.ParseIPCidr("16.0.0.0/4"),
winipcfg.ParseIPCidr("32.0.0.0/3"),
winipcfg.ParseIPCidr("64.0.0.0/2"),
winipcfg.ParseIPCidr("128.0.0.0/1"),
//winipcfg.ParseIPCidr("198.18.0.0/16"),
//winipcfg.ParseIPCidr("198.18.0.1/32"),
//winipcfg.ParseIPCidr("198.18.255.255/32"),
winipcfg.ParseIPCidr("224.0.0.0/4"),
winipcfg.ParseIPCidr("255.255.255.255/32"),
}
estimatedRouteCount := len(allowedIPs)
routes := make([]winipcfg.RouteData, 0, estimatedRouteCount)
var haveV4Address, haveV6Address bool = true, false
for _, allowedip := range allowedIPs {
allowedip.MaskSelf()
if (allowedip.Bits() == 32 && !haveV4Address) || (allowedip.Bits() == 128 && !haveV6Address) {
continue
}
route := winipcfg.RouteData{
Destination: allowedip.IPNet(),
Metric: 0,
}
if allowedip.Bits() == 32 {
if allowedip.Cidr == 0 {
foundDefault4 = true
}
route.NextHop = net.IPv4zero
} else if allowedip.Bits() == 128 {
if allowedip.Cidr == 0 {
foundDefault6 = true
}
route.NextHop = net.IPv6zero
}
routes = append(routes, route)
}
deduplicatedRoutes := make([]*winipcfg.RouteData, 0, len(routes))
sort.Slice(routes, func(i, j int) bool {
if routes[i].Metric != routes[j].Metric {
return routes[i].Metric < routes[j].Metric
}
if c := bytes.Compare(routes[i].NextHop, routes[j].NextHop); c != 0 {
return c < 0
}
if c := bytes.Compare(routes[i].Destination.IP, routes[j].Destination.IP); c != 0 {
return c < 0
}
if c := bytes.Compare(routes[i].Destination.Mask, routes[j].Destination.Mask); c != 0 {
return c < 0
}
return false
})
for i := 0; i < len(routes); i++ {
if i > 0 && routes[i].Metric == routes[i-1].Metric &&
bytes.Equal(routes[i].NextHop, routes[i-1].NextHop) &&
bytes.Equal(routes[i].Destination.IP, routes[i-1].Destination.IP) &&
bytes.Equal(routes[i].Destination.Mask, routes[i-1].Destination.Mask) {
continue
}
deduplicatedRoutes = append(deduplicatedRoutes, &routes[i])
}
err = luid.SetRoutesForFamily(family, deduplicatedRoutes)
if err != nil {
return fmt.Errorf("unable to set routes %+v: %w", deduplicatedRoutes, err)
}
}
ipif, err := luid.IPInterface(family)
if err != nil {
return err
}
ipif.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled
ipif.DadTransmits = 0
ipif.ManagedAddressConfigurationSupported = false
ipif.OtherStatefulConfigurationSupported = false
ipif.NLMTU = uint32(mtu)
if (family == windows.AF_INET && foundDefault4) || (family == windows.AF_INET6 && foundDefault6) {
ipif.UseAutomaticMetric = false
ipif.Metric = 0
}
err = ipif.Set()
if err != nil {
return fmt.Errorf("unable to set metric and MTU: %w", err)
}
ipif6, err := luid.IPInterface(familyV6)
ipif6.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled
ipif6.DadTransmits = 0
ipif6.ManagedAddressConfigurationSupported = false
ipif6.OtherStatefulConfigurationSupported = false
if err != nil {
return err
}
err = ipif6.Set()
if err != nil {
return fmt.Errorf("unable to set v6 metric and MTU: %w", err)
}
err = luid.SetDNS(family, []net.IP{net.ParseIP("198.18.0.2")}, nil)
if err != nil {
return fmt.Errorf("unable to set DNS %s %s: %w", "198.18.0.2", "nil", err)
}
return nil
}
func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, addresses []net.IPNet) {
if len(addresses) == 0 {
return
}
includedInAddresses := func(a net.IPNet) bool {
// TODO: this makes the whole algorithm O(n^2). But we can't stick net.IPNet in a Go hashmap. Bummer!
for _, addr := range addresses {
ip := addr.IP
if ip4 := ip.To4(); ip4 != nil {
ip = ip4
}
mA, _ := addr.Mask.Size()
mB, _ := a.Mask.Size()
if bytes.Equal(ip, a.IP) && mA == mB {
return true
}
}
return false
}
interfaces, err := winipcfg.GetAdaptersAddresses(family, winipcfg.GAAFlagDefault)
if err != nil {
return
}
for _, iface := range interfaces {
if iface.OperStatus == winipcfg.IfOperStatusUp {
continue
}
for address := iface.FirstUnicastAddress; address != nil; address = address.Next {
ip := address.Address.IP()
ipnet := net.IPNet{IP: ip, Mask: net.CIDRMask(int(address.OnLinkPrefixLength), 8*len(ip))}
if includedInAddresses(ipnet) {
log.Infoln("[Wintun] Cleaning up stale address %s from interface %s", ipnet.String(), iface.FriendlyName())
iface.LUID.DeleteIPAddress(ipnet)
}
}
}
}
// GetAutoDetectInterface get ethernet interface
func GetAutoDetectInterface() (string, error) {
ifname, err := getAutoDetectInterfaceByFamily(winipcfg.AddressFamily(windows.AF_INET))
if err == nil {
return ifname, err
}
return getAutoDetectInterfaceByFamily(winipcfg.AddressFamily(windows.AF_INET6))
}
func getAutoDetectInterfaceByFamily(family winipcfg.AddressFamily) (string, error) {
interfaces, err := winipcfg.GetAdaptersAddresses(family, winipcfg.GAAFlagIncludeGateways)
if err != nil {
return "", fmt.Errorf("find ethernet interface failure. %w", err)
}
for _, iface := range interfaces {
if iface.OperStatus != winipcfg.IfOperStatusUp {
continue
}
ifname := iface.FriendlyName()
if ifname == "Clash" {
continue
}
for gatewayAddress := iface.FirstGatewayAddress; gatewayAddress != nil; gatewayAddress = gatewayAddress.Next {
nextHop := gatewayAddress.Address.IP()
var ipnet net.IPNet
if family == windows.AF_INET {
ipnet = net.IPNet{IP: net.IPv4zero, Mask: net.CIDRMask(0, 32)}
} else {
ipnet = net.IPNet{IP: net.IPv6zero, Mask: net.CIDRMask(0, 128)}
}
if _, err = iface.LUID.Route(ipnet, nextHop); err == nil {
return ifname, nil
}
}
}
return "", errors.New("ethernet interface not found")
}

View File

@ -0,0 +1,398 @@
//go:build windows
// +build windows
package dev
import (
"bytes"
"errors"
"fmt"
"net"
"sort"
"sync"
"sync/atomic"
"time"
"github.com/Dreamacro/clash/listener/tun/dev/wintun"
"github.com/Dreamacro/clash/log"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
)
const messageTransportHeaderSize = 0 // size of data preceding content in transport message
type tunWindows struct {
wt *wintun.Adapter
handle windows.Handle
close int32
running sync.WaitGroup
forcedMTU int
rate rateJuggler
session wintun.Session
readWait windows.Handle
closeOnce sync.Once
url string
name string
tunAddress string
autoRoute bool
}
// OpenTunDevice return a TunDevice according a URL
func OpenTunDevice(tunAddress string, autoRoute bool) (TunDevice, error) {
requestedGUID, err := windows.GUIDFromString("{330EAEF8-7578-5DF2-D97B-8DADC0EA85CB}")
if err == nil {
WintunStaticRequestedGUID = &requestedGUID
log.Debugln("Generate GUID: %s", WintunStaticRequestedGUID.String())
} else {
log.Warnln("Error parese GUID from string: %v", err)
}
interfaceName := "Clash"
mtu := 9000
tun, err := CreateTUN(interfaceName, mtu, tunAddress, autoRoute)
if err != nil {
return nil, err
}
return tun, nil
}
//
// CreateTUN creates a Wintun interface with the given name. Should a Wintun
// interface with the same name exist, it is reused.
//
func CreateTUN(ifname string, mtu int, tunAddress string, autoRoute bool) (TunDevice, error) {
return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu, tunAddress, autoRoute)
}
//
// CreateTUNWithRequestedGUID creates a Wintun interface with the given name and
// a requested GUID. Should a Wintun interface with the same name exist, it is reused.
//
func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int, tunAddress string, autoRoute bool) (TunDevice, error) {
wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID)
if err != nil {
return nil, fmt.Errorf("Error creating interface: %w", err)
}
forcedMTU := 1420
if mtu > 0 {
forcedMTU = mtu
}
tun := &tunWindows{
name: ifname,
wt: wt,
handle: windows.InvalidHandle,
forcedMTU: forcedMTU,
tunAddress: tunAddress,
autoRoute: autoRoute,
}
// config tun ip
err = tun.configureInterface()
if err != nil {
tun.wt.Close()
return nil, fmt.Errorf("error configure interface: %w", err)
}
tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB
if err != nil {
tun.wt.Close()
return nil, fmt.Errorf("error starting session: %w", err)
}
tun.readWait = tun.session.ReadWaitEvent()
return tun, nil
}
func (tun *tunWindows) Name() string {
return tun.name
}
func (tun *tunWindows) IsClose() bool {
return atomic.LoadInt32(&tun.close) == 1
}
func (tun *tunWindows) Read(buff []byte) (int, error) {
return tun.Read0(buff, messageTransportHeaderSize)
}
func (tun *tunWindows) Write(buff []byte) (int, error) {
return tun.Write0(buff, messageTransportHeaderSize)
}
func (tun *tunWindows) URL() string {
return fmt.Sprintf("dev://%s", tun.Name())
}
func (tun *tunWindows) configureInterface() error {
retryOnFailure := wintun.StartedAtBoot()
tryTimes := 0
startOver:
var err error
if tryTimes > 0 {
log.Infoln("Retrying interface configuration after failure because system just booted (T+%v): %v", windows.DurationSinceBoot(), err)
time.Sleep(time.Second)
retryOnFailure = retryOnFailure && tryTimes < 15
}
tryTimes++
luid := winipcfg.LUID(tun.LUID())
log.Infoln("[wintun]: tun adapter LUID: %d", luid)
mtu, err := tun.MTU()
if err != nil {
return errors.New("unable to get device mtu")
}
family := winipcfg.AddressFamily(windows.AF_INET)
familyV6 := winipcfg.AddressFamily(windows.AF_INET6)
tunAddress := wintun.ParseIPCidr(tun.tunAddress + "/16")
addresses := []net.IPNet{tunAddress.IPNet()}
err = luid.FlushIPAddresses(familyV6)
if err == windows.ERROR_NOT_FOUND && retryOnFailure {
goto startOver
} else if err != nil {
return err
}
err = luid.FlushDNS(family)
if err == windows.ERROR_NOT_FOUND && retryOnFailure {
goto startOver
} else if err != nil {
return err
}
err = luid.FlushDNS(familyV6)
if err == windows.ERROR_NOT_FOUND && retryOnFailure {
goto startOver
} else if err != nil {
return err
}
err = luid.FlushRoutes(familyV6)
if err == windows.ERROR_NOT_FOUND && retryOnFailure {
goto startOver
} else if err != nil {
return err
}
foundDefault4 := false
foundDefault6 := false
if tun.autoRoute {
allowedIPs := []*wintun.IPCidr{
//wintun.ParseIPCidr("0.0.0.0/0"),
wintun.ParseIPCidr("1.0.0.0/8"),
wintun.ParseIPCidr("2.0.0.0/7"),
wintun.ParseIPCidr("4.0.0.0/6"),
wintun.ParseIPCidr("8.0.0.0/5"),
wintun.ParseIPCidr("16.0.0.0/4"),
wintun.ParseIPCidr("32.0.0.0/3"),
wintun.ParseIPCidr("64.0.0.0/2"),
wintun.ParseIPCidr("128.0.0.0/1"),
wintun.ParseIPCidr("224.0.0.0/4"),
wintun.ParseIPCidr("255.255.255.255/32"),
}
estimatedRouteCount := len(allowedIPs)
routes := make([]winipcfg.RouteData, 0, estimatedRouteCount)
var haveV4Address, haveV6Address bool = true, false
for _, allowedip := range allowedIPs {
allowedip.MaskSelf()
if (allowedip.Bits() == 32 && !haveV4Address) || (allowedip.Bits() == 128 && !haveV6Address) {
continue
}
route := winipcfg.RouteData{
Destination: allowedip.IPNet(),
Metric: 0,
}
if allowedip.Bits() == 32 {
if allowedip.Cidr == 0 {
foundDefault4 = true
}
route.NextHop = net.IPv4zero
} else if allowedip.Bits() == 128 {
if allowedip.Cidr == 0 {
foundDefault6 = true
}
route.NextHop = net.IPv6zero
}
routes = append(routes, route)
}
deduplicatedRoutes := make([]*winipcfg.RouteData, 0, len(routes))
sort.Slice(routes, func(i, j int) bool {
if routes[i].Metric != routes[j].Metric {
return routes[i].Metric < routes[j].Metric
}
if c := bytes.Compare(routes[i].NextHop, routes[j].NextHop); c != 0 {
return c < 0
}
if c := bytes.Compare(routes[i].Destination.IP, routes[j].Destination.IP); c != 0 {
return c < 0
}
if c := bytes.Compare(routes[i].Destination.Mask, routes[j].Destination.Mask); c != 0 {
return c < 0
}
return false
})
for i := 0; i < len(routes); i++ {
if i > 0 && routes[i].Metric == routes[i-1].Metric &&
bytes.Equal(routes[i].NextHop, routes[i-1].NextHop) &&
bytes.Equal(routes[i].Destination.IP, routes[i-1].Destination.IP) &&
bytes.Equal(routes[i].Destination.Mask, routes[i-1].Destination.Mask) {
continue
}
deduplicatedRoutes = append(deduplicatedRoutes, &routes[i])
}
err = luid.SetRoutesForFamily(family, deduplicatedRoutes)
if err == windows.ERROR_NOT_FOUND && retryOnFailure {
goto startOver
} else if err != nil {
return fmt.Errorf("unable to set routes: %w", err)
}
}
err = luid.SetIPAddressesForFamily(family, addresses)
if err == windows.ERROR_OBJECT_ALREADY_EXISTS {
cleanupAddressesOnDisconnectedInterfaces(family, addresses)
err = luid.SetIPAddressesForFamily(family, addresses)
}
if err == windows.ERROR_NOT_FOUND && retryOnFailure {
goto startOver
} else if err != nil {
return fmt.Errorf("unable to set ips: %w", err)
}
var ipif *winipcfg.MibIPInterfaceRow
ipif, err = luid.IPInterface(family)
if err != nil {
return err
}
ipif.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled
ipif.DadTransmits = 0
ipif.ManagedAddressConfigurationSupported = false
ipif.OtherStatefulConfigurationSupported = false
if mtu > 0 {
ipif.NLMTU = uint32(mtu)
}
if (family == windows.AF_INET && foundDefault4) || (family == windows.AF_INET6 && foundDefault6) {
ipif.UseAutomaticMetric = false
ipif.Metric = 0
}
err = ipif.Set()
if err == windows.ERROR_NOT_FOUND && retryOnFailure {
goto startOver
} else if err != nil {
return fmt.Errorf("unable to set metric and MTU: %w", err)
}
var ipif6 *winipcfg.MibIPInterfaceRow
ipif6, err = luid.IPInterface(familyV6)
if err != nil {
return err
}
ipif6.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled
ipif6.DadTransmits = 0
ipif6.ManagedAddressConfigurationSupported = false
ipif6.OtherStatefulConfigurationSupported = false
err = ipif6.Set()
if err == windows.ERROR_NOT_FOUND && retryOnFailure {
goto startOver
} else if err != nil {
return fmt.Errorf("unable to set v6 metric and MTU: %w", err)
}
err = luid.SetDNS(family, []net.IP{net.ParseIP("198.18.0.2")}, nil)
if err == windows.ERROR_NOT_FOUND && retryOnFailure {
goto startOver
} else if err != nil {
return fmt.Errorf("unable to set DNS %s %s: %w", "198.18.0.2", "nil", err)
}
return nil
}
func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, addresses []net.IPNet) {
if len(addresses) == 0 {
return
}
addrToStr := func(ip *net.IP) string {
if ip4 := ip.To4(); ip4 != nil {
return string(ip4)
}
return string(*ip)
}
addrHash := make(map[string]bool, len(addresses))
for i := range addresses {
addrHash[addrToStr(&addresses[i].IP)] = true
}
interfaces, err := winipcfg.GetAdaptersAddresses(family, winipcfg.GAAFlagDefault)
if err != nil {
return
}
for _, iface := range interfaces {
if iface.OperStatus == winipcfg.IfOperStatusUp {
continue
}
for address := iface.FirstUnicastAddress; address != nil; address = address.Next {
ip := address.Address.IP()
if addrHash[addrToStr(&ip)] {
ipnet := net.IPNet{IP: ip, Mask: net.CIDRMask(int(address.OnLinkPrefixLength), 8*len(ip))}
log.Infoln("Cleaning up stale address %s from interface %s", ipnet.String(), iface.FriendlyName())
iface.LUID.DeleteIPAddress(ipnet)
}
}
}
}
// GetAutoDetectInterface get ethernet interface
func GetAutoDetectInterface() (string, error) {
ifname, err := getAutoDetectInterfaceByFamily(winipcfg.AddressFamily(windows.AF_INET))
if err == nil {
return ifname, err
}
return getAutoDetectInterfaceByFamily(winipcfg.AddressFamily(windows.AF_INET6))
}
func getAutoDetectInterfaceByFamily(family winipcfg.AddressFamily) (string, error) {
interfaces, err := winipcfg.GetAdaptersAddresses(family, winipcfg.GAAFlagIncludeGateways)
if err != nil {
return "", fmt.Errorf("get ethernet interface failure. %w", err)
}
for _, iface := range interfaces {
if iface.OperStatus != winipcfg.IfOperStatusUp {
continue
}
ifname := iface.FriendlyName()
if ifname == "Clash" {
continue
}
for gatewayAddress := iface.FirstGatewayAddress; gatewayAddress != nil; gatewayAddress = gatewayAddress.Next {
nextHop := gatewayAddress.Address.IP()
var ipnet net.IPNet
if family == windows.AF_INET {
ipnet = net.IPNet{IP: net.IPv4zero, Mask: net.CIDRMask(0, 32)}
} else {
ipnet = net.IPNet{IP: net.IPv6zero, Mask: net.CIDRMask(0, 128)}
}
if _, err = iface.LUID.Route(ipnet, nextHop); err == nil {
return ifname, nil
}
}
}
return "", errors.New("ethernet interface not found")
}

View File

@ -1,91 +0,0 @@
//go:build windows
// +build windows
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*/
package winipcfg
import (
"sync"
"golang.org/x/sys/windows"
)
// InterfaceChangeCallback structure allows interface change callback handling.
type InterfaceChangeCallback struct {
cb func(notificationType MibNotificationType, iface *MibIPInterfaceRow)
wait sync.WaitGroup
}
var (
interfaceChangeAddRemoveMutex = sync.Mutex{}
interfaceChangeMutex = sync.Mutex{}
interfaceChangeCallbacks = make(map[*InterfaceChangeCallback]bool)
interfaceChangeHandle = windows.Handle(0)
)
// RegisterInterfaceChangeCallback registers a new InterfaceChangeCallback. If this particular callback is already
// registered, the function will silently return. Returned InterfaceChangeCallback.Unregister method should be used
// to unregister.
func RegisterInterfaceChangeCallback(callback func(notificationType MibNotificationType, iface *MibIPInterfaceRow)) (*InterfaceChangeCallback, error) {
s := &InterfaceChangeCallback{cb: callback}
interfaceChangeAddRemoveMutex.Lock()
defer interfaceChangeAddRemoveMutex.Unlock()
interfaceChangeMutex.Lock()
defer interfaceChangeMutex.Unlock()
interfaceChangeCallbacks[s] = true
if interfaceChangeHandle == 0 {
err := notifyIPInterfaceChange(windows.AF_UNSPEC, windows.NewCallback(interfaceChanged), 0, false, &interfaceChangeHandle)
if err != nil {
delete(interfaceChangeCallbacks, s)
interfaceChangeHandle = 0
return nil, err
}
}
return s, nil
}
// Unregister unregisters the callback.
func (callback *InterfaceChangeCallback) Unregister() error {
interfaceChangeAddRemoveMutex.Lock()
defer interfaceChangeAddRemoveMutex.Unlock()
interfaceChangeMutex.Lock()
delete(interfaceChangeCallbacks, callback)
removeIt := len(interfaceChangeCallbacks) == 0 && interfaceChangeHandle != 0
interfaceChangeMutex.Unlock()
callback.wait.Wait()
if removeIt {
err := cancelMibChangeNotify2(interfaceChangeHandle)
if err != nil {
return err
}
interfaceChangeHandle = 0
}
return nil
}
func interfaceChanged(callerContext uintptr, row *MibIPInterfaceRow, notificationType MibNotificationType) uintptr {
rowCopy := *row
interfaceChangeMutex.Lock()
for cb := range interfaceChangeCallbacks {
cb.wait.Add(1)
go func(cb *InterfaceChangeCallback) {
cb.cb(notificationType, &rowCopy)
cb.wait.Done()
}(cb)
}
interfaceChangeMutex.Unlock()
return 0
}

View File

@ -1,404 +0,0 @@
//go:build windows
// +build windows
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*/
package winipcfg
import (
"errors"
"fmt"
"net"
"strings"
"golang.org/x/sys/windows"
)
// LUID represents a network interface.
type LUID uint64
// IPInterface method retrieves IP information for the specified interface on the local computer.
func (luid LUID) IPInterface(family AddressFamily) (*MibIPInterfaceRow, error) {
row := &MibIPInterfaceRow{}
row.Init()
row.InterfaceLUID = luid
row.Family = family
err := row.get()
if err != nil {
return nil, err
}
return row, nil
}
// Interface method retrieves information for the specified adapter on the local computer.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getifentry2
func (luid LUID) Interface() (*MibIfRow2, error) {
row := &MibIfRow2{}
row.InterfaceLUID = luid
err := row.get()
if err != nil {
return nil, err
}
return row, nil
}
// GUID method converts a locally unique identifier (LUID) for a network interface to a globally unique identifier (GUID) for the interface.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-convertinterfaceluidtoguid
func (luid LUID) GUID() (*windows.GUID, error) {
guid := &windows.GUID{}
err := convertInterfaceLUIDToGUID(&luid, guid)
if err != nil {
return nil, err
}
return guid, nil
}
// LUIDFromGUID function converts a globally unique identifier (GUID) for a network interface to the locally unique identifier (LUID) for the interface.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-convertinterfaceguidtoluid
func LUIDFromGUID(guid *windows.GUID) (LUID, error) {
var luid LUID
err := convertInterfaceGUIDToLUID(guid, &luid)
if err != nil {
return 0, err
}
return luid, nil
}
// LUIDFromIndex function converts a local index for a network interface to the locally unique identifier (LUID) for the interface.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-convertinterfaceindextoluid
func LUIDFromIndex(index uint32) (LUID, error) {
var luid LUID
err := convertInterfaceIndexToLUID(index, &luid)
if err != nil {
return 0, err
}
return luid, nil
}
// IPAddress method returns MibUnicastIPAddressRow struct that matches to provided 'ip' argument. Corresponds to GetUnicastIpAddressEntry
// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getunicastipaddressentry)
func (luid LUID) IPAddress(ip net.IP) (*MibUnicastIPAddressRow, error) {
row := &MibUnicastIPAddressRow{InterfaceLUID: luid}
err := row.Address.SetIP(ip, 0)
if err != nil {
return nil, err
}
err = row.get()
if err != nil {
return nil, err
}
return row, nil
}
// AddIPAddress method adds new unicast IP address to the interface. Corresponds to CreateUnicastIpAddressEntry function
// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-createunicastipaddressentry).
func (luid LUID) AddIPAddress(address net.IPNet) error {
row := &MibUnicastIPAddressRow{}
row.Init()
row.InterfaceLUID = luid
row.DadState = DadStatePreferred
row.ValidLifetime = 0xffffffff
row.PreferredLifetime = 0xffffffff
err := row.Address.SetIP(address.IP, 0)
if err != nil {
return err
}
ones, _ := address.Mask.Size()
row.OnLinkPrefixLength = uint8(ones)
return row.Create()
}
// AddIPAddresses method adds multiple new unicast IP addresses to the interface. Corresponds to CreateUnicastIpAddressEntry function
// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-createunicastipaddressentry).
func (luid LUID) AddIPAddresses(addresses []net.IPNet) error {
for i := range addresses {
err := luid.AddIPAddress(addresses[i])
if err != nil {
return err
}
}
return nil
}
// SetIPAddresses method sets new unicast IP addresses to the interface.
func (luid LUID) SetIPAddresses(addresses []net.IPNet) error {
err := luid.FlushIPAddresses(windows.AF_UNSPEC)
if err != nil {
return err
}
return luid.AddIPAddresses(addresses)
}
// SetIPAddressesForFamily method sets new unicast IP addresses for a specific family to the interface.
func (luid LUID) SetIPAddressesForFamily(family AddressFamily, addresses []net.IPNet) error {
err := luid.FlushIPAddresses(family)
if err != nil {
return err
}
for i := range addresses {
asV4 := addresses[i].IP.To4()
if asV4 == nil && family == windows.AF_INET {
continue
} else if asV4 != nil && family == windows.AF_INET6 {
continue
}
err := luid.AddIPAddress(addresses[i])
if err != nil {
return err
}
}
return nil
}
// DeleteIPAddress method deletes interface's unicast IP address. Corresponds to DeleteUnicastIpAddressEntry function
// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-deleteunicastipaddressentry).
func (luid LUID) DeleteIPAddress(address net.IPNet) error {
row := &MibUnicastIPAddressRow{}
row.Init()
row.InterfaceLUID = luid
err := row.Address.SetIP(address.IP, 0)
if err != nil {
return err
}
// Note: OnLinkPrefixLength member is ignored by DeleteUnicastIpAddressEntry().
ones, _ := address.Mask.Size()
row.OnLinkPrefixLength = uint8(ones)
return row.Delete()
}
// FlushIPAddresses method deletes all interface's unicast IP addresses.
func (luid LUID) FlushIPAddresses(family AddressFamily) error {
var tab *mibUnicastIPAddressTable
err := getUnicastIPAddressTable(family, &tab)
if err != nil {
return err
}
t := tab.get()
for i := range t {
if t[i].InterfaceLUID == luid {
t[i].Delete()
}
}
tab.free()
return nil
}
// Route method returns route determined with the input arguments. Corresponds to GetIpForwardEntry2 function
// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getipforwardentry2).
// NOTE: If the corresponding route isn't found, the method will return error.
func (luid LUID) Route(destination net.IPNet, nextHop net.IP) (*MibIPforwardRow2, error) {
row := &MibIPforwardRow2{}
row.Init()
row.InterfaceLUID = luid
row.ValidLifetime = 0xffffffff
row.PreferredLifetime = 0xffffffff
err := row.DestinationPrefix.SetIPNet(destination)
if err != nil {
return nil, err
}
err = row.NextHop.SetIP(nextHop, 0)
if err != nil {
return nil, err
}
err = row.get()
if err != nil {
return nil, err
}
return row, nil
}
// AddRoute method adds a route to the interface. Corresponds to CreateIpForwardEntry2 function, with added splitDefault feature.
// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-createipforwardentry2)
func (luid LUID) AddRoute(destination net.IPNet, nextHop net.IP, metric uint32) error {
row := &MibIPforwardRow2{}
row.Init()
row.InterfaceLUID = luid
err := row.DestinationPrefix.SetIPNet(destination)
if err != nil {
return fmt.Errorf("AddRoute1: %w", err)
}
err = row.NextHop.SetIP(nextHop, 0)
if err != nil {
return fmt.Errorf("AddRoute2: %w", err)
}
row.Metric = metric
err = row.Create()
if err != nil {
return fmt.Errorf("AddRoute3: %w", err)
}
return nil
}
// AddRoutes method adds multiple routes to the interface.
func (luid LUID) AddRoutes(routesData []*RouteData) error {
for _, rd := range routesData {
err := luid.AddRoute(rd.Destination, rd.NextHop, rd.Metric)
if err != nil {
return err
}
}
return nil
}
// SetRoutes method sets (flush than add) multiple routes to the interface.
func (luid LUID) SetRoutes(routesData []*RouteData) error {
err := luid.FlushRoutes(windows.AF_UNSPEC)
if err != nil {
return err
}
return luid.AddRoutes(routesData)
}
// SetRoutesForFamily method sets (flush than add) multiple routes for a specific family to the interface.
func (luid LUID) SetRoutesForFamily(family AddressFamily, routesData []*RouteData) error {
//err := luid.FlushRoutes(family)
//if err != nil {
// return err
//}
_ = luid.FlushRoutes(family)
for _, rd := range routesData {
asV4 := rd.Destination.IP.To4()
if asV4 == nil && family == windows.AF_INET {
continue
} else if asV4 != nil && family == windows.AF_INET6 {
continue
}
err := luid.AddRoute(rd.Destination, rd.NextHop, rd.Metric)
if err != nil {
return err
}
}
return nil
}
// DeleteRoute method deletes a route that matches the criteria. Corresponds to DeleteIpForwardEntry2 function
// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-deleteipforwardentry2).
func (luid LUID) DeleteRoute(destination net.IPNet, nextHop net.IP) error {
row := &MibIPforwardRow2{}
row.Init()
row.InterfaceLUID = luid
err := row.DestinationPrefix.SetIPNet(destination)
if err != nil {
return err
}
err = row.NextHop.SetIP(nextHop, 0)
if err != nil {
return err
}
err = row.get()
if err != nil {
return err
}
return row.Delete()
}
// FlushRoutes method deletes all interface's routes.
// It continues on failures, and returns the last error afterwards.
func (luid LUID) FlushRoutes(family AddressFamily) error {
var tab *mibIPforwardTable2
err := getIPForwardTable2(family, &tab)
if err != nil {
return fmt.Errorf("FlushRoutes1: %w", err)
}
t := tab.get()
for i := range t {
if t[i].InterfaceLUID == luid {
err2 := t[i].Delete()
if err2 != nil {
err = err2
}
}
}
tab.free()
if err != nil {
return fmt.Errorf("FlushRoutes2: %w", err)
}
return nil
}
// DNS method returns all DNS server addresses associated with the adapter.
func (luid LUID) DNS() ([]net.IP, error) {
addresses, err := GetAdaptersAddresses(windows.AF_UNSPEC, GAAFlagDefault)
if err != nil {
return nil, err
}
r := make([]net.IP, 0, len(addresses))
for _, addr := range addresses {
if addr.LUID == luid {
for dns := addr.FirstDNSServerAddress; dns != nil; dns = dns.Next {
if ip := dns.Address.IP(); ip != nil {
r = append(r, ip)
} else {
return nil, windows.ERROR_INVALID_PARAMETER
}
}
}
}
return r, nil
}
// SetDNS method clears previous and associates new DNS servers and search domains with the adapter for a specific family.
func (luid LUID) SetDNS(family AddressFamily, servers []net.IP, domains []string) error {
if family != windows.AF_INET && family != windows.AF_INET6 {
return windows.ERROR_PROTOCOL_UNREACHABLE
}
var filteredServers []string
for _, server := range servers {
if v4 := server.To4(); v4 != nil && family == windows.AF_INET {
filteredServers = append(filteredServers, v4.String())
} else if v6 := server.To16(); v4 == nil && v6 != nil && family == windows.AF_INET6 {
filteredServers = append(filteredServers, v6.String())
}
}
servers16, err := windows.UTF16PtrFromString(strings.Join(filteredServers, ","))
if err != nil {
return err
}
domains16, err := windows.UTF16PtrFromString(strings.Join(domains, ","))
if err != nil {
return err
}
guid, err := luid.GUID()
if err != nil {
return err
}
dnsInterfaceSettings := &DnsInterfaceSettings{
Version: DnsInterfaceSettingsVersion1,
Flags: DnsInterfaceSettingsFlagNameserver | DnsInterfaceSettingsFlagSearchList,
NameServer: servers16,
SearchList: domains16,
}
if family == windows.AF_INET6 {
dnsInterfaceSettings.Flags |= DnsInterfaceSettingsFlagIPv6
}
// For >= Windows 10 1809
err = SetInterfaceDnsSettings(*guid, dnsInterfaceSettings)
if err == nil || !errors.Is(err, windows.ERROR_PROC_NOT_FOUND) {
return err
}
// For < Windows 10 1809
err = luid.fallbackSetDNSForFamily(family, servers)
if err != nil {
return err
}
if len(domains) > 0 {
return luid.fallbackSetDNSDomain(domains[0])
} else {
return luid.fallbackSetDNSDomain("")
}
}
// FlushDNS method clears all DNS servers associated with the adapter.
func (luid LUID) FlushDNS(family AddressFamily) error {
return luid.SetDNS(family, nil, nil)
}

View File

@ -1,11 +0,0 @@
//go:build windows
// +build windows
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*/
package winipcfg
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zwinipcfg_windows.go winipcfg.go

View File

@ -1,111 +0,0 @@
//go:build windows
// +build windows
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*/
package winipcfg
import (
"bytes"
"errors"
"fmt"
"io"
"net"
"os/exec"
"path/filepath"
"strings"
"syscall"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/registry"
)
func runNetsh(cmds []string) error {
system32, err := windows.GetSystemDirectory()
if err != nil {
return err
}
cmd := exec.Command(filepath.Join(system32, "netsh.exe"))
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
stdin, err := cmd.StdinPipe()
if err != nil {
return fmt.Errorf("runNetsh stdin pipe - %w", err)
}
go func() {
defer stdin.Close()
io.WriteString(stdin, strings.Join(append(cmds, "exit\r\n"), "\r\n"))
}()
output, err := cmd.CombinedOutput()
// Horrible kludges, sorry.
cleaned := bytes.ReplaceAll(output, []byte{'\r', '\n'}, []byte{'\n'})
cleaned = bytes.ReplaceAll(cleaned, []byte("netsh>"), []byte{})
cleaned = bytes.ReplaceAll(cleaned, []byte("There are no Domain Name Servers (DNS) configured on this computer."), []byte{})
cleaned = bytes.TrimSpace(cleaned)
if len(cleaned) != 0 && err == nil {
return fmt.Errorf("netsh: %#q", string(cleaned))
} else if err != nil {
return fmt.Errorf("netsh: %v: %#q", err, string(cleaned))
}
return nil
}
const (
netshCmdTemplateFlush4 = "interface ipv4 set dnsservers name=%d source=static address=none validate=no register=both"
netshCmdTemplateFlush6 = "interface ipv6 set dnsservers name=%d source=static address=none validate=no register=both"
netshCmdTemplateAdd4 = "interface ipv4 add dnsservers name=%d address=%s validate=no"
netshCmdTemplateAdd6 = "interface ipv6 add dnsservers name=%d address=%s validate=no"
)
func (luid LUID) fallbackSetDNSForFamily(family AddressFamily, dnses []net.IP) error {
var templateFlush string
if family == windows.AF_INET {
templateFlush = netshCmdTemplateFlush4
} else if family == windows.AF_INET6 {
templateFlush = netshCmdTemplateFlush6
}
cmds := make([]string, 0, 1+len(dnses))
ipif, err := luid.IPInterface(family)
if err != nil {
return err
}
cmds = append(cmds, fmt.Sprintf(templateFlush, ipif.InterfaceIndex))
for i := 0; i < len(dnses); i++ {
if v4 := dnses[i].To4(); v4 != nil && family == windows.AF_INET {
cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd4, ipif.InterfaceIndex, v4.String()))
} else if v6 := dnses[i].To16(); v4 == nil && v6 != nil && family == windows.AF_INET6 {
cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd6, ipif.InterfaceIndex, v6.String()))
}
}
return runNetsh(cmds)
}
func (luid LUID) fallbackSetDNSDomain(domain string) error {
guid, err := luid.GUID()
if err != nil {
return fmt.Errorf("Error converting luid to guid: %w", err)
}
key, err := registry.OpenKey(registry.LOCAL_MACHINE, fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters\\Adapters\\%v", guid), registry.QUERY_VALUE)
if err != nil {
return fmt.Errorf("Error opening adapter-specific TCP/IP network registry key: %w", err)
}
paths, _, err := key.GetStringsValue("IpConfig")
key.Close()
if err != nil {
return fmt.Errorf("Error reading IpConfig registry key: %w", err)
}
if len(paths) == 0 {
return errors.New("No TCP/IP interfaces found on adapter")
}
key, err = registry.OpenKey(registry.LOCAL_MACHINE, fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\%s", paths[0]), registry.SET_VALUE)
if err != nil {
return fmt.Errorf("Unable to open TCP/IP network registry key: %w", err)
}
err = key.SetStringValue("Domain", domain)
key.Close()
return err
}

View File

@ -1,7 +0,0 @@
//go:build windows
// +build windows
// Modified from: https://git.zx2c4.com/wireguard-windows/tree/tunnel/winipcfg
// License: MIT
package winipcfg

View File

@ -1,91 +0,0 @@
//go:build windows
// +build windows
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*/
package winipcfg
import (
"sync"
"golang.org/x/sys/windows"
)
// RouteChangeCallback structure allows route change callback handling.
type RouteChangeCallback struct {
cb func(notificationType MibNotificationType, route *MibIPforwardRow2)
wait sync.WaitGroup
}
var (
routeChangeAddRemoveMutex = sync.Mutex{}
routeChangeMutex = sync.Mutex{}
routeChangeCallbacks = make(map[*RouteChangeCallback]bool)
routeChangeHandle = windows.Handle(0)
)
// RegisterRouteChangeCallback registers a new RouteChangeCallback. If this particular callback is already
// registered, the function will silently return. Returned RouteChangeCallback.Unregister method should be used
// to unregister.
func RegisterRouteChangeCallback(callback func(notificationType MibNotificationType, route *MibIPforwardRow2)) (*RouteChangeCallback, error) {
s := &RouteChangeCallback{cb: callback}
routeChangeAddRemoveMutex.Lock()
defer routeChangeAddRemoveMutex.Unlock()
routeChangeMutex.Lock()
defer routeChangeMutex.Unlock()
routeChangeCallbacks[s] = true
if routeChangeHandle == 0 {
err := notifyRouteChange2(windows.AF_UNSPEC, windows.NewCallback(routeChanged), 0, false, &routeChangeHandle)
if err != nil {
delete(routeChangeCallbacks, s)
routeChangeHandle = 0
return nil, err
}
}
return s, nil
}
// Unregister unregisters the callback.
func (callback *RouteChangeCallback) Unregister() error {
routeChangeAddRemoveMutex.Lock()
defer routeChangeAddRemoveMutex.Unlock()
routeChangeMutex.Lock()
delete(routeChangeCallbacks, callback)
removeIt := len(routeChangeCallbacks) == 0 && routeChangeHandle != 0
routeChangeMutex.Unlock()
callback.wait.Wait()
if removeIt {
err := cancelMibChangeNotify2(routeChangeHandle)
if err != nil {
return err
}
routeChangeHandle = 0
}
return nil
}
func routeChanged(callerContext uintptr, row *MibIPforwardRow2, notificationType MibNotificationType) uintptr {
rowCopy := *row
routeChangeMutex.Lock()
for cb := range routeChangeCallbacks {
cb.wait.Add(1)
go func(cb *RouteChangeCallback) {
cb.cb(notificationType, &rowCopy)
cb.wait.Done()
}(cb)
}
routeChangeMutex.Unlock()
return 0
}

File diff suppressed because it is too large Load Diff

View File

@ -1,234 +0,0 @@
//go:build windows && (386 || arm)
// +build windows
// +build 386 arm
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*/
package winipcfg
import (
"golang.org/x/sys/windows"
)
// IPAdapterWINSServerAddress structure stores a single Windows Internet Name Service (WINS) server address in a linked list of WINS server addresses for a particular adapter.
// https://docs.microsoft.com/en-us/windows/desktop/api/iptypes/ns-iptypes-_ip_adapter_wins_server_address_lh
type IPAdapterWINSServerAddress struct {
Length uint32
_ uint32
Next *IPAdapterWINSServerAddress
Address windows.SocketAddress
_ [4]byte
}
// IPAdapterGatewayAddress structure stores a single gateway address in a linked list of gateway addresses for a particular adapter.
// https://docs.microsoft.com/en-us/windows/desktop/api/iptypes/ns-iptypes-_ip_adapter_gateway_address_lh
type IPAdapterGatewayAddress struct {
Length uint32
_ uint32
Next *IPAdapterGatewayAddress
Address windows.SocketAddress
_ [4]byte
}
// IPAdapterAddresses structure is the header node for a linked list of addresses for a particular adapter. This structure can simultaneously be used as part of a linked list of IP_ADAPTER_ADDRESSES structures.
// https://docs.microsoft.com/en-us/windows/desktop/api/iptypes/ns-iptypes-_ip_adapter_addresses_lh
// This is a modified and extended version of windows.IpAdapterAddresses.
type IPAdapterAddresses struct {
Length uint32
IfIndex uint32
Next *IPAdapterAddresses
adapterName *byte
FirstUnicastAddress *windows.IpAdapterUnicastAddress
FirstAnycastAddress *windows.IpAdapterAnycastAddress
FirstMulticastAddress *windows.IpAdapterMulticastAddress
FirstDNSServerAddress *windows.IpAdapterDnsServerAdapter
dnsSuffix *uint16
description *uint16
friendlyName *uint16
physicalAddress [windows.MAX_ADAPTER_ADDRESS_LENGTH]byte
physicalAddressLength uint32
Flags IPAAFlags
MTU uint32
IfType IfType
OperStatus IfOperStatus
IPv6IfIndex uint32
ZoneIndices [16]uint32
FirstPrefix *windows.IpAdapterPrefix
TransmitLinkSpeed uint64
ReceiveLinkSpeed uint64
FirstWINSServerAddress *IPAdapterWINSServerAddress
FirstGatewayAddress *IPAdapterGatewayAddress
Ipv4Metric uint32
Ipv6Metric uint32
LUID LUID
DHCPv4Server windows.SocketAddress
CompartmentID uint32
NetworkGUID windows.GUID
ConnectionType NetIfConnectionType
TunnelType TunnelType
DHCPv6Server windows.SocketAddress
dhcpv6ClientDUID [maxDHCPv6DUIDLength]byte
dhcpv6ClientDUIDLength uint32
DHCPv6IAID uint32
FirstDNSSuffix *IPAdapterDNSSuffix
_ [4]byte
}
// MibIPInterfaceRow structure stores interface management information for a particular IP address family on a network interface.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_ipinterface_row
type MibIPInterfaceRow struct {
Family AddressFamily
_ [4]byte
InterfaceLUID LUID
InterfaceIndex uint32
MaxReassemblySize uint32
InterfaceIdentifier uint64
MinRouterAdvertisementInterval uint32
MaxRouterAdvertisementInterval uint32
AdvertisingEnabled bool
ForwardingEnabled bool
WeakHostSend bool
WeakHostReceive bool
UseAutomaticMetric bool
UseNeighborUnreachabilityDetection bool
ManagedAddressConfigurationSupported bool
OtherStatefulConfigurationSupported bool
AdvertiseDefaultRoute bool
RouterDiscoveryBehavior RouterDiscoveryBehavior
DadTransmits uint32
BaseReachableTime uint32
RetransmitTime uint32
PathMTUDiscoveryTimeout uint32
LinkLocalAddressBehavior LinkLocalAddressBehavior
LinkLocalAddressTimeout uint32
ZoneIndices [ScopeLevelCount]uint32
SitePrefixLength uint32
Metric uint32
NLMTU uint32
Connected bool
SupportsWakeUpPatterns bool
SupportsNeighborDiscovery bool
SupportsRouterDiscovery bool
ReachableTime uint32
TransmitOffload OffloadRod
ReceiveOffload OffloadRod
DisableDefaultRoutes bool
}
// mibIPInterfaceTable structure contains a table of IP interface entries.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_ipinterface_table
type mibIPInterfaceTable struct {
numEntries uint32
_ [4]byte
table [anySize]MibIPInterfaceRow
}
// MibIfRow2 structure stores information about a particular interface.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_if_row2
type MibIfRow2 struct {
InterfaceLUID LUID
InterfaceIndex uint32
InterfaceGUID windows.GUID
alias [ifMaxStringSize + 1]uint16
description [ifMaxStringSize + 1]uint16
physicalAddressLength uint32
physicalAddress [ifMaxPhysAddressLength]byte
permanentPhysicalAddress [ifMaxPhysAddressLength]byte
MTU uint32
Type IfType
TunnelType TunnelType
MediaType NdisMedium
PhysicalMediumType NdisPhysicalMedium
AccessType NetIfAccessType
DirectionType NetIfDirectionType
InterfaceAndOperStatusFlags InterfaceAndOperStatusFlags
OperStatus IfOperStatus
AdminStatus NetIfAdminStatus
MediaConnectState NetIfMediaConnectState
NetworkGUID windows.GUID
ConnectionType NetIfConnectionType
_ [4]byte
TransmitLinkSpeed uint64
ReceiveLinkSpeed uint64
InOctets uint64
InUcastPkts uint64
InNUcastPkts uint64
InDiscards uint64
InErrors uint64
InUnknownProtos uint64
InUcastOctets uint64
InMulticastOctets uint64
InBroadcastOctets uint64
OutOctets uint64
OutUcastPkts uint64
OutNUcastPkts uint64
OutDiscards uint64
OutErrors uint64
OutUcastOctets uint64
OutMulticastOctets uint64
OutBroadcastOctets uint64
OutQLen uint64
}
// mibIfTable2 structure contains a table of logical and physical interface entries.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_if_table2
type mibIfTable2 struct {
numEntries uint32
_ [4]byte
table [anySize]MibIfRow2
}
// MibUnicastIPAddressRow structure stores information about a unicast IP address.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_unicastipaddress_row
type MibUnicastIPAddressRow struct {
Address RawSockaddrInet
_ [4]byte
InterfaceLUID LUID
InterfaceIndex uint32
PrefixOrigin PrefixOrigin
SuffixOrigin SuffixOrigin
ValidLifetime uint32
PreferredLifetime uint32
OnLinkPrefixLength uint8
SkipAsSource bool
DadState DadState
ScopeID uint32
CreationTimeStamp int64
}
// mibUnicastIPAddressTable structure contains a table of unicast IP address entries.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_unicastipaddress_table
type mibUnicastIPAddressTable struct {
numEntries uint32
_ [4]byte
table [anySize]MibUnicastIPAddressRow
}
// MibAnycastIPAddressRow structure stores information about an anycast IP address.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_anycastipaddress_row
type MibAnycastIPAddressRow struct {
Address RawSockaddrInet
_ [4]byte
InterfaceLUID LUID
InterfaceIndex uint32
ScopeID uint32
}
// mibAnycastIPAddressTable structure contains a table of anycast IP address entries.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-mib_anycastipaddress_table
type mibAnycastIPAddressTable struct {
numEntries uint32
_ [4]byte
table [anySize]MibAnycastIPAddressRow
}
// mibIPforwardTable2 structure contains a table of IP route entries.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_ipforward_table2
type mibIPforwardTable2 struct {
numEntries uint32
_ [4]byte
table [anySize]MibIPforwardRow2
}

View File

@ -1,222 +0,0 @@
//go:build windows && (amd64 || arm64)
// +build windows
// +build amd64 arm64
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*/
package winipcfg
import (
"golang.org/x/sys/windows"
)
// IPAdapterWINSServerAddress structure stores a single Windows Internet Name Service (WINS) server address in a linked list of WINS server addresses for a particular adapter.
// https://docs.microsoft.com/en-us/windows/desktop/api/iptypes/ns-iptypes-_ip_adapter_wins_server_address_lh
type IPAdapterWINSServerAddress struct {
Length uint32
_ uint32
Next *IPAdapterWINSServerAddress
Address windows.SocketAddress
}
// IPAdapterGatewayAddress structure stores a single gateway address in a linked list of gateway addresses for a particular adapter.
// https://docs.microsoft.com/en-us/windows/desktop/api/iptypes/ns-iptypes-_ip_adapter_gateway_address_lh
type IPAdapterGatewayAddress struct {
Length uint32
_ uint32
Next *IPAdapterGatewayAddress
Address windows.SocketAddress
}
// IPAdapterAddresses structure is the header node for a linked list of addresses for a particular adapter. This structure can simultaneously be used as part of a linked list of IP_ADAPTER_ADDRESSES structures.
// https://docs.microsoft.com/en-us/windows/desktop/api/iptypes/ns-iptypes-_ip_adapter_addresses_lh
// This is a modified and extended version of windows.IpAdapterAddresses.
type IPAdapterAddresses struct {
Length uint32
IfIndex uint32
Next *IPAdapterAddresses
adapterName *byte
FirstUnicastAddress *windows.IpAdapterUnicastAddress
FirstAnycastAddress *windows.IpAdapterAnycastAddress
FirstMulticastAddress *windows.IpAdapterMulticastAddress
FirstDNSServerAddress *windows.IpAdapterDnsServerAdapter
dnsSuffix *uint16
description *uint16
friendlyName *uint16
physicalAddress [windows.MAX_ADAPTER_ADDRESS_LENGTH]byte
physicalAddressLength uint32
Flags IPAAFlags
MTU uint32
IfType IfType
OperStatus IfOperStatus
IPv6IfIndex uint32
ZoneIndices [16]uint32
FirstPrefix *windows.IpAdapterPrefix
TransmitLinkSpeed uint64
ReceiveLinkSpeed uint64
FirstWINSServerAddress *IPAdapterWINSServerAddress
FirstGatewayAddress *IPAdapterGatewayAddress
Ipv4Metric uint32
Ipv6Metric uint32
LUID LUID
DHCPv4Server windows.SocketAddress
CompartmentID uint32
NetworkGUID windows.GUID
ConnectionType NetIfConnectionType
TunnelType TunnelType
DHCPv6Server windows.SocketAddress
dhcpv6ClientDUID [maxDHCPv6DUIDLength]byte
dhcpv6ClientDUIDLength uint32
DHCPv6IAID uint32
FirstDNSSuffix *IPAdapterDNSSuffix
}
// MibIPInterfaceRow structure stores interface management information for a particular IP address family on a network interface.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_ipinterface_row
type MibIPInterfaceRow struct {
Family AddressFamily
InterfaceLUID LUID
InterfaceIndex uint32
MaxReassemblySize uint32
InterfaceIdentifier uint64
MinRouterAdvertisementInterval uint32
MaxRouterAdvertisementInterval uint32
AdvertisingEnabled bool
ForwardingEnabled bool
WeakHostSend bool
WeakHostReceive bool
UseAutomaticMetric bool
UseNeighborUnreachabilityDetection bool
ManagedAddressConfigurationSupported bool
OtherStatefulConfigurationSupported bool
AdvertiseDefaultRoute bool
RouterDiscoveryBehavior RouterDiscoveryBehavior
DadTransmits uint32
BaseReachableTime uint32
RetransmitTime uint32
PathMTUDiscoveryTimeout uint32
LinkLocalAddressBehavior LinkLocalAddressBehavior
LinkLocalAddressTimeout uint32
ZoneIndices [ScopeLevelCount]uint32
SitePrefixLength uint32
Metric uint32
NLMTU uint32
Connected bool
SupportsWakeUpPatterns bool
SupportsNeighborDiscovery bool
SupportsRouterDiscovery bool
ReachableTime uint32
TransmitOffload OffloadRod
ReceiveOffload OffloadRod
DisableDefaultRoutes bool
}
// mibIPInterfaceTable structure contains a table of IP interface entries.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_ipinterface_table
type mibIPInterfaceTable struct {
numEntries uint32
table [anySize]MibIPInterfaceRow
}
// MibIfRow2 structure stores information about a particular interface.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_if_row2
type MibIfRow2 struct {
InterfaceLUID LUID
InterfaceIndex uint32
InterfaceGUID windows.GUID
alias [ifMaxStringSize + 1]uint16
description [ifMaxStringSize + 1]uint16
physicalAddressLength uint32
physicalAddress [ifMaxPhysAddressLength]byte
permanentPhysicalAddress [ifMaxPhysAddressLength]byte
MTU uint32
Type IfType
TunnelType TunnelType
MediaType NdisMedium
PhysicalMediumType NdisPhysicalMedium
AccessType NetIfAccessType
DirectionType NetIfDirectionType
InterfaceAndOperStatusFlags InterfaceAndOperStatusFlags
OperStatus IfOperStatus
AdminStatus NetIfAdminStatus
MediaConnectState NetIfMediaConnectState
NetworkGUID windows.GUID
ConnectionType NetIfConnectionType
TransmitLinkSpeed uint64
ReceiveLinkSpeed uint64
InOctets uint64
InUcastPkts uint64
InNUcastPkts uint64
InDiscards uint64
InErrors uint64
InUnknownProtos uint64
InUcastOctets uint64
InMulticastOctets uint64
InBroadcastOctets uint64
OutOctets uint64
OutUcastPkts uint64
OutNUcastPkts uint64
OutDiscards uint64
OutErrors uint64
OutUcastOctets uint64
OutMulticastOctets uint64
OutBroadcastOctets uint64
OutQLen uint64
}
// mibIfTable2 structure contains a table of logical and physical interface entries.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_if_table2
type mibIfTable2 struct {
numEntries uint32
table [anySize]MibIfRow2
}
// MibUnicastIPAddressRow structure stores information about a unicast IP address.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_unicastipaddress_row
type MibUnicastIPAddressRow struct {
Address RawSockaddrInet
InterfaceLUID LUID
InterfaceIndex uint32
PrefixOrigin PrefixOrigin
SuffixOrigin SuffixOrigin
ValidLifetime uint32
PreferredLifetime uint32
OnLinkPrefixLength uint8
SkipAsSource bool
DadState DadState
ScopeID uint32
CreationTimeStamp int64
}
// mibUnicastIPAddressTable structure contains a table of unicast IP address entries.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_unicastipaddress_table
type mibUnicastIPAddressTable struct {
numEntries uint32
table [anySize]MibUnicastIPAddressRow
}
// MibAnycastIPAddressRow structure stores information about an anycast IP address.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_anycastipaddress_row
type MibAnycastIPAddressRow struct {
Address RawSockaddrInet
InterfaceLUID LUID
InterfaceIndex uint32
ScopeID uint32
}
// mibAnycastIPAddressTable structure contains a table of anycast IP address entries.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-mib_anycastipaddress_table
type mibAnycastIPAddressTable struct {
numEntries uint32
table [anySize]MibAnycastIPAddressRow
}
// mibIPforwardTable2 structure contains a table of IP route entries.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_mib_ipforward_table2
type mibIPforwardTable2 struct {
numEntries uint32
table [anySize]MibIPforwardRow2
}

View File

@ -1,91 +0,0 @@
//go:build windows
// +build windows
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*/
package winipcfg
import (
"sync"
"golang.org/x/sys/windows"
)
// UnicastAddressChangeCallback structure allows unicast address change callback handling.
type UnicastAddressChangeCallback struct {
cb func(notificationType MibNotificationType, unicastAddress *MibUnicastIPAddressRow)
wait sync.WaitGroup
}
var (
unicastAddressChangeAddRemoveMutex = sync.Mutex{}
unicastAddressChangeMutex = sync.Mutex{}
unicastAddressChangeCallbacks = make(map[*UnicastAddressChangeCallback]bool)
unicastAddressChangeHandle = windows.Handle(0)
)
// RegisterUnicastAddressChangeCallback registers a new UnicastAddressChangeCallback. If this particular callback is already
// registered, the function will silently return. Returned UnicastAddressChangeCallback.Unregister method should be used
// to unregister.
func RegisterUnicastAddressChangeCallback(callback func(notificationType MibNotificationType, unicastAddress *MibUnicastIPAddressRow)) (*UnicastAddressChangeCallback, error) {
s := &UnicastAddressChangeCallback{cb: callback}
unicastAddressChangeAddRemoveMutex.Lock()
defer unicastAddressChangeAddRemoveMutex.Unlock()
unicastAddressChangeMutex.Lock()
defer unicastAddressChangeMutex.Unlock()
unicastAddressChangeCallbacks[s] = true
if unicastAddressChangeHandle == 0 {
err := notifyUnicastIPAddressChange(windows.AF_UNSPEC, windows.NewCallback(unicastAddressChanged), 0, false, &unicastAddressChangeHandle)
if err != nil {
delete(unicastAddressChangeCallbacks, s)
unicastAddressChangeHandle = 0
return nil, err
}
}
return s, nil
}
// Unregister unregisters the callback.
func (callback *UnicastAddressChangeCallback) Unregister() error {
unicastAddressChangeAddRemoveMutex.Lock()
defer unicastAddressChangeAddRemoveMutex.Unlock()
unicastAddressChangeMutex.Lock()
delete(unicastAddressChangeCallbacks, callback)
removeIt := len(unicastAddressChangeCallbacks) == 0 && unicastAddressChangeHandle != 0
unicastAddressChangeMutex.Unlock()
callback.wait.Wait()
if removeIt {
err := cancelMibChangeNotify2(unicastAddressChangeHandle)
if err != nil {
return err
}
unicastAddressChangeHandle = 0
}
return nil
}
func unicastAddressChanged(callerContext uintptr, row *MibUnicastIPAddressRow, notificationType MibNotificationType) uintptr {
rowCopy := *row
unicastAddressChangeMutex.Lock()
for cb := range unicastAddressChangeCallbacks {
cb.wait.Add(1)
go func(cb *UnicastAddressChangeCallback) {
cb.cb(notificationType, &rowCopy)
cb.wait.Done()
}(cb)
}
unicastAddressChangeMutex.Unlock()
return 0
}

View File

@ -1,199 +0,0 @@
//go:build windows
// +build windows
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*/
package winipcfg
import (
"runtime"
"unsafe"
"golang.org/x/sys/windows"
)
//
// Common functions
//
//sys freeMibTable(memory unsafe.Pointer) = iphlpapi.FreeMibTable
//
// Interface-related functions
//
//sys initializeIPInterfaceEntry(row *MibIPInterfaceRow) = iphlpapi.InitializeIpInterfaceEntry
//sys getIPInterfaceTable(family AddressFamily, table **mibIPInterfaceTable) (ret error) = iphlpapi.GetIpInterfaceTable
//sys getIPInterfaceEntry(row *MibIPInterfaceRow) (ret error) = iphlpapi.GetIpInterfaceEntry
//sys setIPInterfaceEntry(row *MibIPInterfaceRow) (ret error) = iphlpapi.SetIpInterfaceEntry
//sys getIfEntry2(row *MibIfRow2) (ret error) = iphlpapi.GetIfEntry2
//sys getIfTable2Ex(level MibIfEntryLevel, table **mibIfTable2) (ret error) = iphlpapi.GetIfTable2Ex
//sys convertInterfaceLUIDToGUID(interfaceLUID *LUID, interfaceGUID *windows.GUID) (ret error) = iphlpapi.ConvertInterfaceLuidToGuid
//sys convertInterfaceGUIDToLUID(interfaceGUID *windows.GUID, interfaceLUID *LUID) (ret error) = iphlpapi.ConvertInterfaceGuidToLuid
//sys convertInterfaceIndexToLUID(interfaceIndex uint32, interfaceLUID *LUID) (ret error) = iphlpapi.ConvertInterfaceIndexToLuid
// GetAdaptersAddresses function retrieves the addresses associated with the adapters on the local computer.
// https://docs.microsoft.com/en-us/windows/desktop/api/iphlpapi/nf-iphlpapi-getadaptersaddresses
func GetAdaptersAddresses(family AddressFamily, flags GAAFlags) ([]*IPAdapterAddresses, error) {
var b []byte
size := uint32(15000)
for {
b = make([]byte, size)
err := windows.GetAdaptersAddresses(uint32(family), uint32(flags), 0, (*windows.IpAdapterAddresses)(unsafe.Pointer(&b[0])), &size)
if err == nil {
break
}
if err != windows.ERROR_BUFFER_OVERFLOW || size <= uint32(len(b)) {
return nil, err
}
}
result := make([]*IPAdapterAddresses, 0, uintptr(size)/unsafe.Sizeof(IPAdapterAddresses{}))
for wtiaa := (*IPAdapterAddresses)(unsafe.Pointer(&b[0])); wtiaa != nil; wtiaa = wtiaa.Next {
result = append(result, wtiaa)
}
return result, nil
}
// GetIPInterfaceTable function retrieves the IP interface entries on the local computer.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getipinterfacetable
func GetIPInterfaceTable(family AddressFamily) ([]MibIPInterfaceRow, error) {
var tab *mibIPInterfaceTable
err := getIPInterfaceTable(family, &tab)
if err != nil {
return nil, err
}
t := append(make([]MibIPInterfaceRow, 0, tab.numEntries), tab.get()...)
tab.free()
return t, nil
}
// GetIfTable2Ex function retrieves the MIB-II interface table.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getiftable2ex
func GetIfTable2Ex(level MibIfEntryLevel) ([]MibIfRow2, error) {
var tab *mibIfTable2
err := getIfTable2Ex(level, &tab)
if err != nil {
return nil, err
}
t := append(make([]MibIfRow2, 0, tab.numEntries), tab.get()...)
tab.free()
return t, nil
}
//
// Unicast IP address-related functions
//
//sys getUnicastIPAddressTable(family AddressFamily, table **mibUnicastIPAddressTable) (ret error) = iphlpapi.GetUnicastIpAddressTable
//sys initializeUnicastIPAddressEntry(row *MibUnicastIPAddressRow) = iphlpapi.InitializeUnicastIpAddressEntry
//sys getUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) = iphlpapi.GetUnicastIpAddressEntry
//sys setUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) = iphlpapi.SetUnicastIpAddressEntry
//sys createUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) = iphlpapi.CreateUnicastIpAddressEntry
//sys deleteUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) = iphlpapi.DeleteUnicastIpAddressEntry
// GetUnicastIPAddressTable function retrieves the unicast IP address table on the local computer.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getunicastipaddresstable
func GetUnicastIPAddressTable(family AddressFamily) ([]MibUnicastIPAddressRow, error) {
var tab *mibUnicastIPAddressTable
err := getUnicastIPAddressTable(family, &tab)
if err != nil {
return nil, err
}
t := append(make([]MibUnicastIPAddressRow, 0, tab.numEntries), tab.get()...)
tab.free()
return t, nil
}
//
// Anycast IP address-related functions
//
//sys getAnycastIPAddressTable(family AddressFamily, table **mibAnycastIPAddressTable) (ret error) = iphlpapi.GetAnycastIpAddressTable
//sys getAnycastIPAddressEntry(row *MibAnycastIPAddressRow) (ret error) = iphlpapi.GetAnycastIpAddressEntry
//sys createAnycastIPAddressEntry(row *MibAnycastIPAddressRow) (ret error) = iphlpapi.CreateAnycastIpAddressEntry
//sys deleteAnycastIPAddressEntry(row *MibAnycastIPAddressRow) (ret error) = iphlpapi.DeleteAnycastIpAddressEntry
// GetAnycastIPAddressTable function retrieves the anycast IP address table on the local computer.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getanycastipaddresstable
func GetAnycastIPAddressTable(family AddressFamily) ([]MibAnycastIPAddressRow, error) {
var tab *mibAnycastIPAddressTable
err := getAnycastIPAddressTable(family, &tab)
if err != nil {
return nil, err
}
t := append(make([]MibAnycastIPAddressRow, 0, tab.numEntries), tab.get()...)
tab.free()
return t, nil
}
//
// Routing-related functions
//
//sys getIPForwardTable2(family AddressFamily, table **mibIPforwardTable2) (ret error) = iphlpapi.GetIpForwardTable2
//sys initializeIPForwardEntry(route *MibIPforwardRow2) = iphlpapi.InitializeIpForwardEntry
//sys getIPForwardEntry2(route *MibIPforwardRow2) (ret error) = iphlpapi.GetIpForwardEntry2
//sys setIPForwardEntry2(route *MibIPforwardRow2) (ret error) = iphlpapi.SetIpForwardEntry2
//sys createIPForwardEntry2(route *MibIPforwardRow2) (ret error) = iphlpapi.CreateIpForwardEntry2
//sys deleteIPForwardEntry2(route *MibIPforwardRow2) (ret error) = iphlpapi.DeleteIpForwardEntry2
// GetIPForwardTable2 function retrieves the IP route entries on the local computer.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getipforwardtable2
func GetIPForwardTable2(family AddressFamily) ([]MibIPforwardRow2, error) {
var tab *mibIPforwardTable2
err := getIPForwardTable2(family, &tab)
if err != nil {
return nil, err
}
t := append(make([]MibIPforwardRow2, 0, tab.numEntries), tab.get()...)
tab.free()
return t, nil
}
//
// Notifications-related functions
//
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-notifyipinterfacechange
//sys notifyIPInterfaceChange(family AddressFamily, callback uintptr, callerContext uintptr, initialNotification bool, notificationHandle *windows.Handle) (ret error) = iphlpapi.NotifyIpInterfaceChange
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-notifyunicastipaddresschange
//sys notifyUnicastIPAddressChange(family AddressFamily, callback uintptr, callerContext uintptr, initialNotification bool, notificationHandle *windows.Handle) (ret error) = iphlpapi.NotifyUnicastIpAddressChange
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-notifyroutechange2
//sys notifyRouteChange2(family AddressFamily, callback uintptr, callerContext uintptr, initialNotification bool, notificationHandle *windows.Handle) (ret error) = iphlpapi.NotifyRouteChange2
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-cancelmibchangenotify2
//sys cancelMibChangeNotify2(notificationHandle windows.Handle) (ret error) = iphlpapi.CancelMibChangeNotify2
//
// DNS-related functions
//
//sys setInterfaceDnsSettingsByPtr(guid *windows.GUID, settings *DnsInterfaceSettings) (ret error) = iphlpapi.SetInterfaceDnsSettings?
//sys setInterfaceDnsSettingsByQwords(guid1 uintptr, guid2 uintptr, settings *DnsInterfaceSettings) (ret error) = iphlpapi.SetInterfaceDnsSettings?
//sys setInterfaceDnsSettingsByDwords(guid1 uintptr, guid2 uintptr, guid3 uintptr, guid4 uintptr, settings *DnsInterfaceSettings) (ret error) = iphlpapi.SetInterfaceDnsSettings?
// The GUID is passed by value, not by reference, which means different
// things on different calling conventions. On amd64, this means it's
// passed by reference anyway, while on arm, arm64, and 386, it's split
// into words.
func SetInterfaceDnsSettings(guid windows.GUID, settings *DnsInterfaceSettings) error {
words := (*[4]uintptr)(unsafe.Pointer(&guid))
switch runtime.GOARCH {
case "amd64":
return setInterfaceDnsSettingsByPtr(&guid, settings)
case "arm64":
return setInterfaceDnsSettingsByQwords(words[0], words[1], settings)
case "arm", "386":
return setInterfaceDnsSettingsByDwords(words[0], words[1], words[2], words[3], settings)
default:
panic("unknown calling convention")
}
}

View File

@ -1,350 +0,0 @@
// Code generated by 'go generate'; DO NOT EDIT.
package winipcfg
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
var _ unsafe.Pointer
// Do the interface allocations only once for common
// Errno values.
const (
errnoERROR_IO_PENDING = 997
)
var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
errERROR_EINVAL error = syscall.EINVAL
)
// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
return errERROR_EINVAL
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
// TODO: add more here, after collecting data on the common
// error values see on Windows. (perhaps when running
// all.bat?)
return e
}
var (
modiphlpapi = windows.NewLazySystemDLL("iphlpapi.dll")
procCancelMibChangeNotify2 = modiphlpapi.NewProc("CancelMibChangeNotify2")
procConvertInterfaceGuidToLuid = modiphlpapi.NewProc("ConvertInterfaceGuidToLuid")
procConvertInterfaceIndexToLuid = modiphlpapi.NewProc("ConvertInterfaceIndexToLuid")
procConvertInterfaceLuidToGuid = modiphlpapi.NewProc("ConvertInterfaceLuidToGuid")
procCreateAnycastIpAddressEntry = modiphlpapi.NewProc("CreateAnycastIpAddressEntry")
procCreateIpForwardEntry2 = modiphlpapi.NewProc("CreateIpForwardEntry2")
procCreateUnicastIpAddressEntry = modiphlpapi.NewProc("CreateUnicastIpAddressEntry")
procDeleteAnycastIpAddressEntry = modiphlpapi.NewProc("DeleteAnycastIpAddressEntry")
procDeleteIpForwardEntry2 = modiphlpapi.NewProc("DeleteIpForwardEntry2")
procDeleteUnicastIpAddressEntry = modiphlpapi.NewProc("DeleteUnicastIpAddressEntry")
procFreeMibTable = modiphlpapi.NewProc("FreeMibTable")
procGetAnycastIpAddressEntry = modiphlpapi.NewProc("GetAnycastIpAddressEntry")
procGetAnycastIpAddressTable = modiphlpapi.NewProc("GetAnycastIpAddressTable")
procGetIfEntry2 = modiphlpapi.NewProc("GetIfEntry2")
procGetIfTable2Ex = modiphlpapi.NewProc("GetIfTable2Ex")
procGetIpForwardEntry2 = modiphlpapi.NewProc("GetIpForwardEntry2")
procGetIpForwardTable2 = modiphlpapi.NewProc("GetIpForwardTable2")
procGetIpInterfaceEntry = modiphlpapi.NewProc("GetIpInterfaceEntry")
procGetIpInterfaceTable = modiphlpapi.NewProc("GetIpInterfaceTable")
procGetUnicastIpAddressEntry = modiphlpapi.NewProc("GetUnicastIpAddressEntry")
procGetUnicastIpAddressTable = modiphlpapi.NewProc("GetUnicastIpAddressTable")
procInitializeIpForwardEntry = modiphlpapi.NewProc("InitializeIpForwardEntry")
procInitializeIpInterfaceEntry = modiphlpapi.NewProc("InitializeIpInterfaceEntry")
procInitializeUnicastIpAddressEntry = modiphlpapi.NewProc("InitializeUnicastIpAddressEntry")
procNotifyIpInterfaceChange = modiphlpapi.NewProc("NotifyIpInterfaceChange")
procNotifyRouteChange2 = modiphlpapi.NewProc("NotifyRouteChange2")
procNotifyUnicastIpAddressChange = modiphlpapi.NewProc("NotifyUnicastIpAddressChange")
procSetInterfaceDnsSettings = modiphlpapi.NewProc("SetInterfaceDnsSettings")
procSetIpForwardEntry2 = modiphlpapi.NewProc("SetIpForwardEntry2")
procSetIpInterfaceEntry = modiphlpapi.NewProc("SetIpInterfaceEntry")
procSetUnicastIpAddressEntry = modiphlpapi.NewProc("SetUnicastIpAddressEntry")
)
func cancelMibChangeNotify2(notificationHandle windows.Handle) (ret error) {
r0, _, _ := syscall.Syscall(procCancelMibChangeNotify2.Addr(), 1, uintptr(notificationHandle), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func convertInterfaceGUIDToLUID(interfaceGUID *windows.GUID, interfaceLUID *LUID) (ret error) {
r0, _, _ := syscall.Syscall(procConvertInterfaceGuidToLuid.Addr(), 2, uintptr(unsafe.Pointer(interfaceGUID)), uintptr(unsafe.Pointer(interfaceLUID)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func convertInterfaceIndexToLUID(interfaceIndex uint32, interfaceLUID *LUID) (ret error) {
r0, _, _ := syscall.Syscall(procConvertInterfaceIndexToLuid.Addr(), 2, uintptr(interfaceIndex), uintptr(unsafe.Pointer(interfaceLUID)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func convertInterfaceLUIDToGUID(interfaceLUID *LUID, interfaceGUID *windows.GUID) (ret error) {
r0, _, _ := syscall.Syscall(procConvertInterfaceLuidToGuid.Addr(), 2, uintptr(unsafe.Pointer(interfaceLUID)), uintptr(unsafe.Pointer(interfaceGUID)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func createAnycastIPAddressEntry(row *MibAnycastIPAddressRow) (ret error) {
r0, _, _ := syscall.Syscall(procCreateAnycastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func createIPForwardEntry2(route *MibIPforwardRow2) (ret error) {
r0, _, _ := syscall.Syscall(procCreateIpForwardEntry2.Addr(), 1, uintptr(unsafe.Pointer(route)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func createUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) {
r0, _, _ := syscall.Syscall(procCreateUnicastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func deleteAnycastIPAddressEntry(row *MibAnycastIPAddressRow) (ret error) {
r0, _, _ := syscall.Syscall(procDeleteAnycastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func deleteIPForwardEntry2(route *MibIPforwardRow2) (ret error) {
r0, _, _ := syscall.Syscall(procDeleteIpForwardEntry2.Addr(), 1, uintptr(unsafe.Pointer(route)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func deleteUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) {
r0, _, _ := syscall.Syscall(procDeleteUnicastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func freeMibTable(memory unsafe.Pointer) {
syscall.Syscall(procFreeMibTable.Addr(), 1, uintptr(memory), 0, 0)
return
}
func getAnycastIPAddressEntry(row *MibAnycastIPAddressRow) (ret error) {
r0, _, _ := syscall.Syscall(procGetAnycastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func getAnycastIPAddressTable(family AddressFamily, table **mibAnycastIPAddressTable) (ret error) {
r0, _, _ := syscall.Syscall(procGetAnycastIpAddressTable.Addr(), 2, uintptr(family), uintptr(unsafe.Pointer(table)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func getIfEntry2(row *MibIfRow2) (ret error) {
r0, _, _ := syscall.Syscall(procGetIfEntry2.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func getIfTable2Ex(level MibIfEntryLevel, table **mibIfTable2) (ret error) {
r0, _, _ := syscall.Syscall(procGetIfTable2Ex.Addr(), 2, uintptr(level), uintptr(unsafe.Pointer(table)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func getIPForwardEntry2(route *MibIPforwardRow2) (ret error) {
r0, _, _ := syscall.Syscall(procGetIpForwardEntry2.Addr(), 1, uintptr(unsafe.Pointer(route)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func getIPForwardTable2(family AddressFamily, table **mibIPforwardTable2) (ret error) {
r0, _, _ := syscall.Syscall(procGetIpForwardTable2.Addr(), 2, uintptr(family), uintptr(unsafe.Pointer(table)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func getIPInterfaceEntry(row *MibIPInterfaceRow) (ret error) {
r0, _, _ := syscall.Syscall(procGetIpInterfaceEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func getIPInterfaceTable(family AddressFamily, table **mibIPInterfaceTable) (ret error) {
r0, _, _ := syscall.Syscall(procGetIpInterfaceTable.Addr(), 2, uintptr(family), uintptr(unsafe.Pointer(table)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func getUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) {
r0, _, _ := syscall.Syscall(procGetUnicastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func getUnicastIPAddressTable(family AddressFamily, table **mibUnicastIPAddressTable) (ret error) {
r0, _, _ := syscall.Syscall(procGetUnicastIpAddressTable.Addr(), 2, uintptr(family), uintptr(unsafe.Pointer(table)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func initializeIPForwardEntry(route *MibIPforwardRow2) {
syscall.Syscall(procInitializeIpForwardEntry.Addr(), 1, uintptr(unsafe.Pointer(route)), 0, 0)
return
}
func initializeIPInterfaceEntry(row *MibIPInterfaceRow) {
syscall.Syscall(procInitializeIpInterfaceEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
return
}
func initializeUnicastIPAddressEntry(row *MibUnicastIPAddressRow) {
syscall.Syscall(procInitializeUnicastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
return
}
func notifyIPInterfaceChange(family AddressFamily, callback uintptr, callerContext uintptr, initialNotification bool, notificationHandle *windows.Handle) (ret error) {
var _p0 uint32
if initialNotification {
_p0 = 1
}
r0, _, _ := syscall.Syscall6(procNotifyIpInterfaceChange.Addr(), 5, uintptr(family), uintptr(callback), uintptr(callerContext), uintptr(_p0), uintptr(unsafe.Pointer(notificationHandle)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func notifyRouteChange2(family AddressFamily, callback uintptr, callerContext uintptr, initialNotification bool, notificationHandle *windows.Handle) (ret error) {
var _p0 uint32
if initialNotification {
_p0 = 1
}
r0, _, _ := syscall.Syscall6(procNotifyRouteChange2.Addr(), 5, uintptr(family), uintptr(callback), uintptr(callerContext), uintptr(_p0), uintptr(unsafe.Pointer(notificationHandle)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func notifyUnicastIPAddressChange(family AddressFamily, callback uintptr, callerContext uintptr, initialNotification bool, notificationHandle *windows.Handle) (ret error) {
var _p0 uint32
if initialNotification {
_p0 = 1
}
r0, _, _ := syscall.Syscall6(procNotifyUnicastIpAddressChange.Addr(), 5, uintptr(family), uintptr(callback), uintptr(callerContext), uintptr(_p0), uintptr(unsafe.Pointer(notificationHandle)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func setInterfaceDnsSettingsByDwords(guid1 uintptr, guid2 uintptr, guid3 uintptr, guid4 uintptr, settings *DnsInterfaceSettings) (ret error) {
ret = procSetInterfaceDnsSettings.Find()
if ret != nil {
return
}
r0, _, _ := syscall.Syscall6(procSetInterfaceDnsSettings.Addr(), 5, uintptr(guid1), uintptr(guid2), uintptr(guid3), uintptr(guid4), uintptr(unsafe.Pointer(settings)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func setInterfaceDnsSettingsByQwords(guid1 uintptr, guid2 uintptr, settings *DnsInterfaceSettings) (ret error) {
ret = procSetInterfaceDnsSettings.Find()
if ret != nil {
return
}
r0, _, _ := syscall.Syscall(procSetInterfaceDnsSettings.Addr(), 3, uintptr(guid1), uintptr(guid2), uintptr(unsafe.Pointer(settings)))
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func setInterfaceDnsSettingsByPtr(guid *windows.GUID, settings *DnsInterfaceSettings) (ret error) {
ret = procSetInterfaceDnsSettings.Find()
if ret != nil {
return
}
r0, _, _ := syscall.Syscall(procSetInterfaceDnsSettings.Addr(), 2, uintptr(unsafe.Pointer(guid)), uintptr(unsafe.Pointer(settings)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func setIPForwardEntry2(route *MibIPforwardRow2) (ret error) {
r0, _, _ := syscall.Syscall(procSetIpForwardEntry2.Addr(), 1, uintptr(unsafe.Pointer(route)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func setIPInterfaceEntry(row *MibIPInterfaceRow) (ret error) {
r0, _, _ := syscall.Syscall(procSetIpInterfaceEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func setUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) {
r0, _, _ := syscall.Syscall(procSetUnicastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}

View File

@ -0,0 +1,41 @@
//go:build windows
// +build windows
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*/
package wintun
import (
"errors"
"log"
"sync"
"time"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/svc"
)
var (
startedAtBoot bool
startedAtBootOnce sync.Once
)
func StartedAtBoot() bool {
startedAtBootOnce.Do(func() {
if isService, err := svc.IsWindowsService(); err == nil && !isService {
return
}
if reason, err := svc.DynamicStartReason(); err == nil {
startedAtBoot = (reason&svc.StartReasonAuto) != 0 || (reason&svc.StartReasonDelayedAuto) != 0
} else if errors.Is(err, windows.ERROR_PROC_NOT_FOUND) {
// TODO: Below this line is Windows 7 compatibility code, which hopefully we can delete at some point.
startedAtBoot = windows.DurationSinceBoot() < time.Minute*10
} else {
log.Printf("Unable to determine service start reason: %v", err)
}
})
return startedAtBoot
}

View File

@ -1,7 +1,7 @@
//go:build windows
// +build windows
package winipcfg
package wintun
import (
"fmt"

View File

@ -1,52 +0,0 @@
//go:build !load_wintun_from_rsrc
// +build !load_wintun_from_rsrc
package wintun
import (
"fmt"
"sync"
"sync/atomic"
"unsafe"
C "github.com/Dreamacro/clash/constant"
"golang.org/x/sys/windows"
)
type lazyDLL struct {
Name string
mu sync.Mutex
module windows.Handle
onLoad func(d *lazyDLL)
}
func (d *lazyDLL) Load() error {
if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&d.module))) != nil {
return nil
}
d.mu.Lock()
defer d.mu.Unlock()
if d.module != 0 {
return nil
}
//const (
// LOAD_LIBRARY_SEARCH_APPLICATION_DIR = 0x00000200
// LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800
//)
//module, err := windows.LoadLibraryEx(d.Name, 0, LOAD_LIBRARY_SEARCH_APPLICATION_DIR|LOAD_LIBRARY_SEARCH_SYSTEM32)
module, err := windows.LoadLibraryEx(C.Path.GetAssetLocation(d.Name), 0, windows.LOAD_WITH_ALTERED_SEARCH_PATH)
if err != nil {
return fmt.Errorf("Unable to load library: %w", err)
}
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.module)), unsafe.Pointer(module))
if d.onLoad != nil {
d.onLoad(d)
}
return nil
}
func (p *lazyProc) nameToAddr() (uintptr, error) {
return windows.GetProcAddress(p.dll.module, p.Name)
}

View File

@ -1,57 +0,0 @@
//go:build load_wintun_from_rsrc
// +build load_wintun_from_rsrc
package wintun
import (
"fmt"
"sync"
"sync/atomic"
"unsafe"
"golang.org/x/sys/windows"
"github.com/Dreamacro/clash/listener/tun/dev/wintun/memmod"
)
type lazyDLL struct {
Name string
mu sync.Mutex
module *memmod.Module
onLoad func(d *lazyDLL)
}
func (d *lazyDLL) Load() error {
if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&d.module))) != nil {
return nil
}
d.mu.Lock()
defer d.mu.Unlock()
if d.module != nil {
return nil
}
const ourModule windows.Handle = 0
resInfo, err := windows.FindResource(ourModule, d.Name, windows.RT_RCDATA)
if err != nil {
return fmt.Errorf("Unable to find \"%v\" RCDATA resource: %w", d.Name, err)
}
data, err := windows.LoadResourceData(ourModule, resInfo)
if err != nil {
return fmt.Errorf("Unable to load resource: %w", err)
}
module, err := memmod.LoadLibrary(data)
if err != nil {
return fmt.Errorf("Unable to load library: %w", err)
}
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.module)), unsafe.Pointer(module))
if d.onLoad != nil {
d.onLoad(d)
}
return nil
}
func (p *lazyProc) nameToAddr() (uintptr, error) {
return p.dll.module.ProcAddressByName(p.Name)
}

View File

@ -1,3 +1,8 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package wintun
import (
@ -5,6 +10,9 @@ import (
"sync"
"sync/atomic"
"unsafe"
C "github.com/Dreamacro/clash/constant"
"golang.org/x/sys/windows"
)
func newLazyDLL(name string, onLoad func(d *lazyDLL)) *lazyDLL {
@ -52,3 +60,71 @@ func (p *lazyProc) Addr() uintptr {
}
return p.addr
}
type lazyDLL struct {
Name string
mu sync.Mutex
module windows.Handle
onLoad func(d *lazyDLL)
}
func (d *lazyDLL) Load() error {
if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&d.module))) != nil {
return nil
}
d.mu.Lock()
defer d.mu.Unlock()
if d.module != 0 {
return nil
}
//const (
// LOAD_LIBRARY_SEARCH_APPLICATION_DIR = 0x00000200
// LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800
//)
//module, err := windows.LoadLibraryEx(d.Name, 0, LOAD_LIBRARY_SEARCH_APPLICATION_DIR|LOAD_LIBRARY_SEARCH_SYSTEM32)
module, err := windows.LoadLibraryEx(C.Path.GetAssetLocation(d.Name), 0, windows.LOAD_WITH_ALTERED_SEARCH_PATH)
if err != nil {
return fmt.Errorf("Unable to load library: %w", err)
}
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.module)), unsafe.Pointer(module))
if d.onLoad != nil {
d.onLoad(d)
}
return nil
}
func (p *lazyProc) nameToAddr() (uintptr, error) {
return windows.GetProcAddress(p.dll.module, p.Name)
}
// Version returns the version of the Wintun DLL.
func Version() string {
if modwintun.Load() != nil {
return "unknown"
}
resInfo, err := windows.FindResource(modwintun.module, windows.ResourceID(1), windows.RT_VERSION)
if err != nil {
return "unknown"
}
data, err := windows.LoadResourceData(modwintun.module, resInfo)
if err != nil {
return "unknown"
}
var fixedInfo *windows.VS_FIXEDFILEINFO
fixedInfoLen := uint32(unsafe.Sizeof(*fixedInfo))
err = windows.VerQueryValue(unsafe.Pointer(&data[0]), `\`, unsafe.Pointer(&fixedInfo), &fixedInfoLen)
if err != nil {
return "unknown"
}
version := fmt.Sprintf("%d.%d", (fixedInfo.FileVersionMS>>16)&0xff, (fixedInfo.FileVersionMS>>0)&0xff)
if nextNibble := (fixedInfo.FileVersionLS >> 16) & 0xff; nextNibble != 0 {
version += fmt.Sprintf(".%d", nextNibble)
}
if nextNibble := (fixedInfo.FileVersionLS >> 0) & 0xff; nextNibble != 0 {
version += fmt.Sprintf(".%d", nextNibble)
}
return version
}

View File

@ -1,635 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package memmod
import (
"errors"
"fmt"
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
type addressList struct {
next *addressList
address uintptr
}
func (head *addressList) free() {
for node := head; node != nil; node = node.next {
windows.VirtualFree(node.address, 0, windows.MEM_RELEASE)
}
}
type Module struct {
headers *IMAGE_NT_HEADERS
codeBase uintptr
modules []windows.Handle
initialized bool
isDLL bool
isRelocated bool
nameExports map[string]uint16
entry uintptr
blockedMemory *addressList
}
func (module *Module) headerDirectory(idx int) *IMAGE_DATA_DIRECTORY {
return &module.headers.OptionalHeader.DataDirectory[idx]
}
func (module *Module) copySections(address uintptr, size uintptr, oldHeaders *IMAGE_NT_HEADERS) error {
sections := module.headers.Sections()
for i := range sections {
if sections[i].SizeOfRawData == 0 {
// Section doesn't contain data in the dll itself, but may define uninitialized data.
sectionSize := oldHeaders.OptionalHeader.SectionAlignment
if sectionSize == 0 {
continue
}
dest, err := windows.VirtualAlloc(module.codeBase+uintptr(sections[i].VirtualAddress),
uintptr(sectionSize),
windows.MEM_COMMIT,
windows.PAGE_READWRITE)
if err != nil {
return fmt.Errorf("Error allocating section: %w", err)
}
// Always use position from file to support alignments smaller than page size (allocation above will align to page size).
dest = module.codeBase + uintptr(sections[i].VirtualAddress)
// NOTE: On 64bit systems we truncate to 32bit here but expand again later when "PhysicalAddress" is used.
sections[i].SetPhysicalAddress((uint32)(dest & 0xffffffff))
var dst []byte
unsafeSlice(unsafe.Pointer(&dst), a2p(dest), int(sectionSize))
for j := range dst {
dst[j] = 0
}
continue
}
if size < uintptr(sections[i].PointerToRawData+sections[i].SizeOfRawData) {
return errors.New("Incomplete section")
}
// Commit memory block and copy data from dll.
dest, err := windows.VirtualAlloc(module.codeBase+uintptr(sections[i].VirtualAddress),
uintptr(sections[i].SizeOfRawData),
windows.MEM_COMMIT,
windows.PAGE_READWRITE)
if err != nil {
return fmt.Errorf("Error allocating memory block: %w", err)
}
// Always use position from file to support alignments smaller than page size (allocation above will align to page size).
memcpy(
module.codeBase+uintptr(sections[i].VirtualAddress),
address+uintptr(sections[i].PointerToRawData),
uintptr(sections[i].SizeOfRawData))
// NOTE: On 64bit systems we truncate to 32bit here but expand again later when "PhysicalAddress" is used.
sections[i].SetPhysicalAddress((uint32)(dest & 0xffffffff))
}
return nil
}
func (module *Module) realSectionSize(section *IMAGE_SECTION_HEADER) uintptr {
size := section.SizeOfRawData
if size != 0 {
return uintptr(size)
}
if (section.Characteristics & IMAGE_SCN_CNT_INITIALIZED_DATA) != 0 {
return uintptr(module.headers.OptionalHeader.SizeOfInitializedData)
}
if (section.Characteristics & IMAGE_SCN_CNT_UNINITIALIZED_DATA) != 0 {
return uintptr(module.headers.OptionalHeader.SizeOfUninitializedData)
}
return 0
}
type sectionFinalizeData struct {
address uintptr
alignedAddress uintptr
size uintptr
characteristics uint32
last bool
}
func (module *Module) finalizeSection(sectionData *sectionFinalizeData) error {
if sectionData.size == 0 {
return nil
}
if (sectionData.characteristics & IMAGE_SCN_MEM_DISCARDABLE) != 0 {
// Section is not needed any more and can safely be freed.
if sectionData.address == sectionData.alignedAddress &&
(sectionData.last ||
(sectionData.size%uintptr(module.headers.OptionalHeader.SectionAlignment)) == 0) {
// Only allowed to decommit whole pages.
windows.VirtualFree(sectionData.address, sectionData.size, windows.MEM_DECOMMIT)
}
return nil
}
// determine protection flags based on characteristics
var ProtectionFlags = [8]uint32{
windows.PAGE_NOACCESS, // not writeable, not readable, not executable
windows.PAGE_EXECUTE, // not writeable, not readable, executable
windows.PAGE_READONLY, // not writeable, readable, not executable
windows.PAGE_EXECUTE_READ, // not writeable, readable, executable
windows.PAGE_WRITECOPY, // writeable, not readable, not executable
windows.PAGE_EXECUTE_WRITECOPY, // writeable, not readable, executable
windows.PAGE_READWRITE, // writeable, readable, not executable
windows.PAGE_EXECUTE_READWRITE, // writeable, readable, executable
}
protect := ProtectionFlags[sectionData.characteristics>>29]
if (sectionData.characteristics & IMAGE_SCN_MEM_NOT_CACHED) != 0 {
protect |= windows.PAGE_NOCACHE
}
// Change memory access flags.
var oldProtect uint32
err := windows.VirtualProtect(sectionData.address, sectionData.size, protect, &oldProtect)
if err != nil {
return fmt.Errorf("Error protecting memory page: %w", err)
}
return nil
}
var rtlAddFunctionTable = windows.NewLazySystemDLL("ntdll.dll").NewProc("RtlAddFunctionTable")
func (module *Module) registerExceptionHandlers() {
directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_EXCEPTION)
if directory.Size == 0 || directory.VirtualAddress == 0 {
return
}
rtlAddFunctionTable.Call(module.codeBase+uintptr(directory.VirtualAddress), uintptr(directory.Size)/unsafe.Sizeof(IMAGE_RUNTIME_FUNCTION_ENTRY{}), module.codeBase)
}
func (module *Module) finalizeSections() error {
sections := module.headers.Sections()
imageOffset := module.headers.OptionalHeader.imageOffset()
sectionData := sectionFinalizeData{}
sectionData.address = uintptr(sections[0].PhysicalAddress()) | imageOffset
sectionData.alignedAddress = alignDown(sectionData.address, uintptr(module.headers.OptionalHeader.SectionAlignment))
sectionData.size = module.realSectionSize(&sections[0])
sections[0].SetVirtualSize(uint32(sectionData.size))
sectionData.characteristics = sections[0].Characteristics
// Loop through all sections and change access flags.
for i := uint16(1); i < module.headers.FileHeader.NumberOfSections; i++ {
sectionAddress := uintptr(sections[i].PhysicalAddress()) | imageOffset
alignedAddress := alignDown(sectionAddress, uintptr(module.headers.OptionalHeader.SectionAlignment))
sectionSize := module.realSectionSize(&sections[i])
sections[i].SetVirtualSize(uint32(sectionSize))
// Combine access flags of all sections that share a page.
// TODO: We currently share flags of a trailing large section with the page of a first small section. This should be optimized.
if sectionData.alignedAddress == alignedAddress || sectionData.address+sectionData.size > alignedAddress {
// Section shares page with previous.
if (sections[i].Characteristics&IMAGE_SCN_MEM_DISCARDABLE) == 0 || (sectionData.characteristics&IMAGE_SCN_MEM_DISCARDABLE) == 0 {
sectionData.characteristics = (sectionData.characteristics | sections[i].Characteristics) &^ IMAGE_SCN_MEM_DISCARDABLE
} else {
sectionData.characteristics |= sections[i].Characteristics
}
sectionData.size = sectionAddress + sectionSize - sectionData.address
continue
}
err := module.finalizeSection(&sectionData)
if err != nil {
return fmt.Errorf("Error finalizing section: %w", err)
}
sectionData.address = sectionAddress
sectionData.alignedAddress = alignedAddress
sectionData.size = sectionSize
sectionData.characteristics = sections[i].Characteristics
}
sectionData.last = true
err := module.finalizeSection(&sectionData)
if err != nil {
return fmt.Errorf("Error finalizing section: %w", err)
}
return nil
}
func (module *Module) executeTLS() {
directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_TLS)
if directory.VirtualAddress == 0 {
return
}
tls := (*IMAGE_TLS_DIRECTORY)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
callback := tls.AddressOfCallbacks
if callback != 0 {
for {
f := *(*uintptr)(a2p(callback))
if f == 0 {
break
}
syscall.Syscall(f, 3, module.codeBase, uintptr(DLL_PROCESS_ATTACH), uintptr(0))
callback += unsafe.Sizeof(f)
}
}
}
func (module *Module) performBaseRelocation(delta uintptr) (relocated bool, err error) {
directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_BASERELOC)
if directory.Size == 0 {
return delta == 0, nil
}
relocationHdr := (*IMAGE_BASE_RELOCATION)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
for relocationHdr.VirtualAddress > 0 {
dest := module.codeBase + uintptr(relocationHdr.VirtualAddress)
var relInfos []uint16
unsafeSlice(
unsafe.Pointer(&relInfos),
a2p(uintptr(unsafe.Pointer(relocationHdr))+unsafe.Sizeof(*relocationHdr)),
int((uintptr(relocationHdr.SizeOfBlock)-unsafe.Sizeof(*relocationHdr))/unsafe.Sizeof(relInfos[0])))
for _, relInfo := range relInfos {
// The upper 4 bits define the type of relocation.
relType := relInfo >> 12
// The lower 12 bits define the offset.
relOffset := uintptr(relInfo & 0xfff)
switch relType {
case IMAGE_REL_BASED_ABSOLUTE:
// Skip relocation.
case IMAGE_REL_BASED_LOW:
*(*uint16)(a2p(dest + relOffset)) += uint16(delta & 0xffff)
break
case IMAGE_REL_BASED_HIGH:
*(*uint16)(a2p(dest + relOffset)) += uint16(uint32(delta) >> 16)
break
case IMAGE_REL_BASED_HIGHLOW:
*(*uint32)(a2p(dest + relOffset)) += uint32(delta)
case IMAGE_REL_BASED_DIR64:
*(*uint64)(a2p(dest + relOffset)) += uint64(delta)
case IMAGE_REL_BASED_THUMB_MOV32:
inst := *(*uint32)(a2p(dest + relOffset))
imm16 := ((inst << 1) & 0x0800) + ((inst << 12) & 0xf000) +
((inst >> 20) & 0x0700) + ((inst >> 16) & 0x00ff)
if (inst & 0x8000fbf0) != 0x0000f240 {
return false, fmt.Errorf("Wrong Thumb2 instruction %08x, expected MOVW", inst)
}
imm16 += uint32(delta) & 0xffff
hiDelta := (uint32(delta&0xffff0000) >> 16) + ((imm16 & 0xffff0000) >> 16)
*(*uint32)(a2p(dest + relOffset)) = (inst & 0x8f00fbf0) + ((imm16 >> 1) & 0x0400) +
((imm16 >> 12) & 0x000f) +
((imm16 << 20) & 0x70000000) +
((imm16 << 16) & 0xff0000)
if hiDelta != 0 {
inst = *(*uint32)(a2p(dest + relOffset + 4))
imm16 = ((inst << 1) & 0x0800) + ((inst << 12) & 0xf000) +
((inst >> 20) & 0x0700) + ((inst >> 16) & 0x00ff)
if (inst & 0x8000fbf0) != 0x0000f2c0 {
return false, fmt.Errorf("Wrong Thumb2 instruction %08x, expected MOVT", inst)
}
imm16 += hiDelta
if imm16 > 0xffff {
return false, fmt.Errorf("Resulting immediate value won't fit: %08x", imm16)
}
*(*uint32)(a2p(dest + relOffset + 4)) = (inst & 0x8f00fbf0) +
((imm16 >> 1) & 0x0400) +
((imm16 >> 12) & 0x000f) +
((imm16 << 20) & 0x70000000) +
((imm16 << 16) & 0xff0000)
}
default:
return false, fmt.Errorf("Unsupported relocation: %v", relType)
}
}
// Advance to next relocation block.
relocationHdr = (*IMAGE_BASE_RELOCATION)(a2p(uintptr(unsafe.Pointer(relocationHdr)) + uintptr(relocationHdr.SizeOfBlock)))
}
return true, nil
}
func (module *Module) buildImportTable() error {
directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_IMPORT)
if directory.Size == 0 {
return nil
}
module.modules = make([]windows.Handle, 0, 16)
importDesc := (*IMAGE_IMPORT_DESCRIPTOR)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
for importDesc.Name != 0 {
handle, err := windows.LoadLibraryEx(windows.BytePtrToString((*byte)(a2p(module.codeBase+uintptr(importDesc.Name)))), 0, windows.LOAD_LIBRARY_SEARCH_SYSTEM32)
if err != nil {
return fmt.Errorf("Error loading module: %w", err)
}
var thunkRef, funcRef *uintptr
if importDesc.OriginalFirstThunk() != 0 {
thunkRef = (*uintptr)(a2p(module.codeBase + uintptr(importDesc.OriginalFirstThunk())))
funcRef = (*uintptr)(a2p(module.codeBase + uintptr(importDesc.FirstThunk)))
} else {
// No hint table.
thunkRef = (*uintptr)(a2p(module.codeBase + uintptr(importDesc.FirstThunk)))
funcRef = (*uintptr)(a2p(module.codeBase + uintptr(importDesc.FirstThunk)))
}
for *thunkRef != 0 {
if IMAGE_SNAP_BY_ORDINAL(*thunkRef) {
*funcRef, err = windows.GetProcAddressByOrdinal(handle, IMAGE_ORDINAL(*thunkRef))
} else {
thunkData := (*IMAGE_IMPORT_BY_NAME)(a2p(module.codeBase + *thunkRef))
*funcRef, err = windows.GetProcAddress(handle, windows.BytePtrToString(&thunkData.Name[0]))
}
if err != nil {
windows.FreeLibrary(handle)
return fmt.Errorf("Error getting function address: %w", err)
}
thunkRef = (*uintptr)(a2p(uintptr(unsafe.Pointer(thunkRef)) + unsafe.Sizeof(*thunkRef)))
funcRef = (*uintptr)(a2p(uintptr(unsafe.Pointer(funcRef)) + unsafe.Sizeof(*funcRef)))
}
module.modules = append(module.modules, handle)
importDesc = (*IMAGE_IMPORT_DESCRIPTOR)(a2p(uintptr(unsafe.Pointer(importDesc)) + unsafe.Sizeof(*importDesc)))
}
return nil
}
func (module *Module) buildNameExports() error {
directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_EXPORT)
if directory.Size == 0 {
return errors.New("No export table found")
}
exports := (*IMAGE_EXPORT_DIRECTORY)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
if exports.NumberOfNames == 0 || exports.NumberOfFunctions == 0 {
return errors.New("No functions exported")
}
if exports.NumberOfNames == 0 {
return errors.New("No functions exported by name")
}
var nameRefs []uint32
unsafeSlice(unsafe.Pointer(&nameRefs), a2p(module.codeBase+uintptr(exports.AddressOfNames)), int(exports.NumberOfNames))
var ordinals []uint16
unsafeSlice(unsafe.Pointer(&ordinals), a2p(module.codeBase+uintptr(exports.AddressOfNameOrdinals)), int(exports.NumberOfNames))
module.nameExports = make(map[string]uint16)
for i := range nameRefs {
nameArray := windows.BytePtrToString((*byte)(a2p(module.codeBase + uintptr(nameRefs[i]))))
module.nameExports[nameArray] = ordinals[i]
}
return nil
}
// LoadLibrary loads module image to memory.
func LoadLibrary(data []byte) (module *Module, err error) {
addr := uintptr(unsafe.Pointer(&data[0]))
size := uintptr(len(data))
if size < unsafe.Sizeof(IMAGE_DOS_HEADER{}) {
return nil, errors.New("Incomplete IMAGE_DOS_HEADER")
}
dosHeader := (*IMAGE_DOS_HEADER)(a2p(addr))
if dosHeader.E_magic != IMAGE_DOS_SIGNATURE {
return nil, fmt.Errorf("Not an MS-DOS binary (provided: %x, expected: %x)", dosHeader.E_magic, IMAGE_DOS_SIGNATURE)
}
if (size < uintptr(dosHeader.E_lfanew)+unsafe.Sizeof(IMAGE_NT_HEADERS{})) {
return nil, errors.New("Incomplete IMAGE_NT_HEADERS")
}
oldHeader := (*IMAGE_NT_HEADERS)(a2p(addr + uintptr(dosHeader.E_lfanew)))
if oldHeader.Signature != IMAGE_NT_SIGNATURE {
return nil, fmt.Errorf("Not an NT binary (provided: %x, expected: %x)", oldHeader.Signature, IMAGE_NT_SIGNATURE)
}
if oldHeader.FileHeader.Machine != imageFileProcess {
return nil, fmt.Errorf("Foreign platform (provided: %x, expected: %x)", oldHeader.FileHeader.Machine, imageFileProcess)
}
if (oldHeader.OptionalHeader.SectionAlignment & 1) != 0 {
return nil, errors.New("Unaligned section")
}
lastSectionEnd := uintptr(0)
sections := oldHeader.Sections()
optionalSectionSize := oldHeader.OptionalHeader.SectionAlignment
for i := range sections {
var endOfSection uintptr
if sections[i].SizeOfRawData == 0 {
// Section without data in the DLL
endOfSection = uintptr(sections[i].VirtualAddress) + uintptr(optionalSectionSize)
} else {
endOfSection = uintptr(sections[i].VirtualAddress) + uintptr(sections[i].SizeOfRawData)
}
if endOfSection > lastSectionEnd {
lastSectionEnd = endOfSection
}
}
alignedImageSize := alignUp(uintptr(oldHeader.OptionalHeader.SizeOfImage), uintptr(oldHeader.OptionalHeader.SectionAlignment))
if alignedImageSize != alignUp(lastSectionEnd, uintptr(oldHeader.OptionalHeader.SectionAlignment)) {
return nil, errors.New("Section is not page-aligned")
}
module = &Module{isDLL: (oldHeader.FileHeader.Characteristics & IMAGE_FILE_DLL) != 0}
defer func() {
if err != nil {
module.Free()
module = nil
}
}()
// Reserve memory for image of library.
// TODO: Is it correct to commit the complete memory region at once? Calling DllEntry raises an exception if we don't.
module.codeBase, err = windows.VirtualAlloc(oldHeader.OptionalHeader.ImageBase,
alignedImageSize,
windows.MEM_RESERVE|windows.MEM_COMMIT,
windows.PAGE_READWRITE)
if err != nil {
// Try to allocate memory at arbitrary position.
module.codeBase, err = windows.VirtualAlloc(0,
alignedImageSize,
windows.MEM_RESERVE|windows.MEM_COMMIT,
windows.PAGE_READWRITE)
if err != nil {
err = fmt.Errorf("Error allocating code: %w", err)
return
}
}
err = module.check4GBBoundaries(alignedImageSize)
if err != nil {
err = fmt.Errorf("Error reallocating code: %w", err)
return
}
if size < uintptr(oldHeader.OptionalHeader.SizeOfHeaders) {
err = errors.New("Incomplete headers")
return
}
// Commit memory for headers.
headers, err := windows.VirtualAlloc(module.codeBase,
uintptr(oldHeader.OptionalHeader.SizeOfHeaders),
windows.MEM_COMMIT,
windows.PAGE_READWRITE)
if err != nil {
err = fmt.Errorf("Error allocating headers: %w", err)
return
}
// Copy PE header to code.
memcpy(headers, addr, uintptr(oldHeader.OptionalHeader.SizeOfHeaders))
module.headers = (*IMAGE_NT_HEADERS)(a2p(headers + uintptr(dosHeader.E_lfanew)))
// Update position.
module.headers.OptionalHeader.ImageBase = module.codeBase
// Copy sections from DLL file block to new memory location.
err = module.copySections(addr, size, oldHeader)
if err != nil {
err = fmt.Errorf("Error copying sections: %w", err)
return
}
// Adjust base address of imported data.
locationDelta := module.headers.OptionalHeader.ImageBase - oldHeader.OptionalHeader.ImageBase
if locationDelta != 0 {
module.isRelocated, err = module.performBaseRelocation(locationDelta)
if err != nil {
err = fmt.Errorf("Error relocating module: %w", err)
return
}
} else {
module.isRelocated = true
}
// Load required dlls and adjust function table of imports.
err = module.buildImportTable()
if err != nil {
err = fmt.Errorf("Error building import table: %w", err)
return
}
// Mark memory pages depending on section headers and release sections that are marked as "discardable".
err = module.finalizeSections()
if err != nil {
err = fmt.Errorf("Error finalizing sections: %w", err)
return
}
// Register exception tables, if they exist.
module.registerExceptionHandlers()
// TLS callbacks are executed BEFORE the main loading.
module.executeTLS()
// Get entry point of loaded module.
if module.headers.OptionalHeader.AddressOfEntryPoint != 0 {
module.entry = module.codeBase + uintptr(module.headers.OptionalHeader.AddressOfEntryPoint)
if module.isDLL {
// Notify library about attaching to process.
r0, _, _ := syscall.Syscall(module.entry, 3, module.codeBase, uintptr(DLL_PROCESS_ATTACH), 0)
successful := r0 != 0
if !successful {
err = windows.ERROR_DLL_INIT_FAILED
return
}
module.initialized = true
}
}
module.buildNameExports()
return
}
// Free releases module resources and unloads it.
func (module *Module) Free() {
if module.initialized {
// Notify library about detaching from process.
syscall.Syscall(module.entry, 3, module.codeBase, uintptr(DLL_PROCESS_DETACH), 0)
module.initialized = false
}
if module.modules != nil {
// Free previously opened libraries.
for _, handle := range module.modules {
windows.FreeLibrary(handle)
}
module.modules = nil
}
if module.codeBase != 0 {
windows.VirtualFree(module.codeBase, 0, windows.MEM_RELEASE)
module.codeBase = 0
}
if module.blockedMemory != nil {
module.blockedMemory.free()
module.blockedMemory = nil
}
}
// ProcAddressByName returns function address by exported name.
func (module *Module) ProcAddressByName(name string) (uintptr, error) {
directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_EXPORT)
if directory.Size == 0 {
return 0, errors.New("No export table found")
}
exports := (*IMAGE_EXPORT_DIRECTORY)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
if module.nameExports == nil {
return 0, errors.New("No functions exported by name")
}
if idx, ok := module.nameExports[name]; ok {
if uint32(idx) > exports.NumberOfFunctions {
return 0, errors.New("Ordinal number too high")
}
// AddressOfFunctions contains the RVAs to the "real" functions.
return module.codeBase + uintptr(*(*uint32)(a2p(module.codeBase + uintptr(exports.AddressOfFunctions) + uintptr(idx)*4))), nil
}
return 0, errors.New("Function not found by name")
}
// ProcAddressByOrdinal returns function address by exported ordinal.
func (module *Module) ProcAddressByOrdinal(ordinal uint16) (uintptr, error) {
directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_EXPORT)
if directory.Size == 0 {
return 0, errors.New("No export table found")
}
exports := (*IMAGE_EXPORT_DIRECTORY)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
if uint32(ordinal) < exports.Base {
return 0, errors.New("Ordinal number too low")
}
idx := ordinal - uint16(exports.Base)
if uint32(idx) > exports.NumberOfFunctions {
return 0, errors.New("Ordinal number too high")
}
// AddressOfFunctions contains the RVAs to the "real" functions.
return module.codeBase + uintptr(*(*uint32)(a2p(module.codeBase + uintptr(exports.AddressOfFunctions) + uintptr(idx)*4))), nil
}
func alignDown(value, alignment uintptr) uintptr {
return value & ^(alignment - 1)
}
func alignUp(value, alignment uintptr) uintptr {
return (value + alignment - 1) & ^(alignment - 1)
}
func a2p(addr uintptr) unsafe.Pointer {
return unsafe.Pointer(addr)
}
func memcpy(dst, src, size uintptr) {
var d, s []byte
unsafeSlice(unsafe.Pointer(&d), a2p(dst), int(size))
unsafeSlice(unsafe.Pointer(&s), a2p(src), int(size))
copy(d, s)
}
// unsafeSlice updates the slice slicePtr to be a slice
// referencing the provided data with its length & capacity set to
// lenCap.
//
// TODO: when Go 1.16 or Go 1.17 is the minimum supported version,
// update callers to use unsafe.Slice instead of this.
func unsafeSlice(slicePtr, data unsafe.Pointer, lenCap int) {
type sliceHeader struct {
Data unsafe.Pointer
Len int
Cap int
}
h := (*sliceHeader)(slicePtr)
h.Data = data
h.Len = lenCap
h.Cap = lenCap
}

View File

@ -1,18 +0,0 @@
//go:build windows && (386 || arm)
// +build windows
// +build 386 arm
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package memmod
func (opthdr *IMAGE_OPTIONAL_HEADER) imageOffset() uintptr {
return 0
}
func (module *Module) check4GBBoundaries(alignedImageSize uintptr) (err error) {
return
}

View File

@ -1,8 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package memmod
const imageFileProcess = IMAGE_FILE_MACHINE_I386

View File

@ -1,38 +0,0 @@
//go:build windows && (amd64 || arm64)
// +build windows
// +build amd64 arm64
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package memmod
import (
"fmt"
"golang.org/x/sys/windows"
)
func (opthdr *IMAGE_OPTIONAL_HEADER) imageOffset() uintptr {
return uintptr(opthdr.ImageBase & 0xffffffff00000000)
}
func (module *Module) check4GBBoundaries(alignedImageSize uintptr) (err error) {
for (module.codeBase >> 32) < ((module.codeBase + alignedImageSize) >> 32) {
node := &addressList{
next: module.blockedMemory,
address: module.codeBase,
}
module.blockedMemory = node
module.codeBase, err = windows.VirtualAlloc(0,
alignedImageSize,
windows.MEM_RESERVE|windows.MEM_COMMIT,
windows.PAGE_READWRITE)
if err != nil {
return fmt.Errorf("Error allocating memory block: %w", err)
}
}
return
}

View File

@ -1,8 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package memmod
const imageFileProcess = IMAGE_FILE_MACHINE_AMD64

View File

@ -1,8 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package memmod
const imageFileProcess = IMAGE_FILE_MACHINE_ARMNT

View File

@ -1,8 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package memmod
const imageFileProcess = IMAGE_FILE_MACHINE_ARM64

View File

@ -1,398 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package memmod
import "unsafe"
const (
IMAGE_DOS_SIGNATURE = 0x5A4D // MZ
IMAGE_OS2_SIGNATURE = 0x454E // NE
IMAGE_OS2_SIGNATURE_LE = 0x454C // LE
IMAGE_VXD_SIGNATURE = 0x454C // LE
IMAGE_NT_SIGNATURE = 0x00004550 // PE00
)
// DOS .EXE header
type IMAGE_DOS_HEADER struct {
E_magic uint16 // Magic number
E_cblp uint16 // Bytes on last page of file
E_cp uint16 // Pages in file
E_crlc uint16 // Relocations
E_cparhdr uint16 // Size of header in paragraphs
E_minalloc uint16 // Minimum extra paragraphs needed
E_maxalloc uint16 // Maximum extra paragraphs needed
E_ss uint16 // Initial (relative) SS value
E_sp uint16 // Initial SP value
E_csum uint16 // Checksum
E_ip uint16 // Initial IP value
E_cs uint16 // Initial (relative) CS value
E_lfarlc uint16 // File address of relocation table
E_ovno uint16 // Overlay number
E_res [4]uint16 // Reserved words
E_oemid uint16 // OEM identifier (for e_oeminfo)
E_oeminfo uint16 // OEM information; e_oemid specific
E_res2 [10]uint16 // Reserved words
E_lfanew int32 // File address of new exe header
}
// File header format
type IMAGE_FILE_HEADER struct {
Machine uint16
NumberOfSections uint16
TimeDateStamp uint32
PointerToSymbolTable uint32
NumberOfSymbols uint32
SizeOfOptionalHeader uint16
Characteristics uint16
}
const (
IMAGE_SIZEOF_FILE_HEADER = 20
IMAGE_FILE_RELOCS_STRIPPED = 0x0001 // Relocation info stripped from file.
IMAGE_FILE_EXECUTABLE_IMAGE = 0x0002 // File is executable (i.e. no unresolved external references).
IMAGE_FILE_LINE_NUMS_STRIPPED = 0x0004 // Line nunbers stripped from file.
IMAGE_FILE_LOCAL_SYMS_STRIPPED = 0x0008 // Local symbols stripped from file.
IMAGE_FILE_AGGRESIVE_WS_TRIM = 0x0010 // Aggressively trim working set
IMAGE_FILE_LARGE_ADDRESS_AWARE = 0x0020 // App can handle >2gb addresses
IMAGE_FILE_BYTES_REVERSED_LO = 0x0080 // Bytes of machine word are reversed.
IMAGE_FILE_32BIT_MACHINE = 0x0100 // 32 bit word machine.
IMAGE_FILE_DEBUG_STRIPPED = 0x0200 // Debugging info stripped from file in .DBG file
IMAGE_FILE_REMOVABLE_RUN_FROM_SWAP = 0x0400 // If Image is on removable media, copy and run from the swap file.
IMAGE_FILE_NET_RUN_FROM_SWAP = 0x0800 // If Image is on Net, copy and run from the swap file.
IMAGE_FILE_SYSTEM = 0x1000 // System File.
IMAGE_FILE_DLL = 0x2000 // File is a DLL.
IMAGE_FILE_UP_SYSTEM_ONLY = 0x4000 // File should only be run on a UP machine
IMAGE_FILE_BYTES_REVERSED_HI = 0x8000 // Bytes of machine word are reversed.
IMAGE_FILE_MACHINE_UNKNOWN = 0
IMAGE_FILE_MACHINE_TARGET_HOST = 0x0001 // Useful for indicating we want to interact with the host and not a WoW guest.
IMAGE_FILE_MACHINE_I386 = 0x014c // Intel 386.
IMAGE_FILE_MACHINE_R3000 = 0x0162 // MIPS little-endian, 0x160 big-endian
IMAGE_FILE_MACHINE_R4000 = 0x0166 // MIPS little-endian
IMAGE_FILE_MACHINE_R10000 = 0x0168 // MIPS little-endian
IMAGE_FILE_MACHINE_WCEMIPSV2 = 0x0169 // MIPS little-endian WCE v2
IMAGE_FILE_MACHINE_ALPHA = 0x0184 // Alpha_AXP
IMAGE_FILE_MACHINE_SH3 = 0x01a2 // SH3 little-endian
IMAGE_FILE_MACHINE_SH3DSP = 0x01a3
IMAGE_FILE_MACHINE_SH3E = 0x01a4 // SH3E little-endian
IMAGE_FILE_MACHINE_SH4 = 0x01a6 // SH4 little-endian
IMAGE_FILE_MACHINE_SH5 = 0x01a8 // SH5
IMAGE_FILE_MACHINE_ARM = 0x01c0 // ARM Little-Endian
IMAGE_FILE_MACHINE_THUMB = 0x01c2 // ARM Thumb/Thumb-2 Little-Endian
IMAGE_FILE_MACHINE_ARMNT = 0x01c4 // ARM Thumb-2 Little-Endian
IMAGE_FILE_MACHINE_AM33 = 0x01d3
IMAGE_FILE_MACHINE_POWERPC = 0x01F0 // IBM PowerPC Little-Endian
IMAGE_FILE_MACHINE_POWERPCFP = 0x01f1
IMAGE_FILE_MACHINE_IA64 = 0x0200 // Intel 64
IMAGE_FILE_MACHINE_MIPS16 = 0x0266 // MIPS
IMAGE_FILE_MACHINE_ALPHA64 = 0x0284 // ALPHA64
IMAGE_FILE_MACHINE_MIPSFPU = 0x0366 // MIPS
IMAGE_FILE_MACHINE_MIPSFPU16 = 0x0466 // MIPS
IMAGE_FILE_MACHINE_AXP64 = IMAGE_FILE_MACHINE_ALPHA64
IMAGE_FILE_MACHINE_TRICORE = 0x0520 // Infineon
IMAGE_FILE_MACHINE_CEF = 0x0CEF
IMAGE_FILE_MACHINE_EBC = 0x0EBC // EFI Byte Code
IMAGE_FILE_MACHINE_AMD64 = 0x8664 // AMD64 (K8)
IMAGE_FILE_MACHINE_M32R = 0x9041 // M32R little-endian
IMAGE_FILE_MACHINE_ARM64 = 0xAA64 // ARM64 Little-Endian
IMAGE_FILE_MACHINE_CEE = 0xC0EE
)
// Directory format
type IMAGE_DATA_DIRECTORY struct {
VirtualAddress uint32
Size uint32
}
const IMAGE_NUMBEROF_DIRECTORY_ENTRIES = 16
type IMAGE_NT_HEADERS struct {
Signature uint32
FileHeader IMAGE_FILE_HEADER
OptionalHeader IMAGE_OPTIONAL_HEADER
}
func (ntheader *IMAGE_NT_HEADERS) Sections() []IMAGE_SECTION_HEADER {
return (*[0xffff]IMAGE_SECTION_HEADER)(unsafe.Pointer(
(uintptr)(unsafe.Pointer(ntheader)) +
unsafe.Offsetof(ntheader.OptionalHeader) +
uintptr(ntheader.FileHeader.SizeOfOptionalHeader)))[:ntheader.FileHeader.NumberOfSections]
}
const (
IMAGE_DIRECTORY_ENTRY_EXPORT = 0 // Export Directory
IMAGE_DIRECTORY_ENTRY_IMPORT = 1 // Import Directory
IMAGE_DIRECTORY_ENTRY_RESOURCE = 2 // Resource Directory
IMAGE_DIRECTORY_ENTRY_EXCEPTION = 3 // Exception Directory
IMAGE_DIRECTORY_ENTRY_SECURITY = 4 // Security Directory
IMAGE_DIRECTORY_ENTRY_BASERELOC = 5 // Base Relocation Table
IMAGE_DIRECTORY_ENTRY_DEBUG = 6 // Debug Directory
IMAGE_DIRECTORY_ENTRY_COPYRIGHT = 7 // (X86 usage)
IMAGE_DIRECTORY_ENTRY_ARCHITECTURE = 7 // Architecture Specific Data
IMAGE_DIRECTORY_ENTRY_GLOBALPTR = 8 // RVA of GP
IMAGE_DIRECTORY_ENTRY_TLS = 9 // TLS Directory
IMAGE_DIRECTORY_ENTRY_LOAD_CONFIG = 10 // Load Configuration Directory
IMAGE_DIRECTORY_ENTRY_BOUND_IMPORT = 11 // Bound Import Directory in headers
IMAGE_DIRECTORY_ENTRY_IAT = 12 // Import Address Table
IMAGE_DIRECTORY_ENTRY_DELAY_IMPORT = 13 // Delay Load Import Descriptors
IMAGE_DIRECTORY_ENTRY_COM_DESCRIPTOR = 14 // COM Runtime descriptor
)
const IMAGE_SIZEOF_SHORT_NAME = 8
// Section header format
type IMAGE_SECTION_HEADER struct {
Name [IMAGE_SIZEOF_SHORT_NAME]byte
physicalAddressOrVirtualSize uint32
VirtualAddress uint32
SizeOfRawData uint32
PointerToRawData uint32
PointerToRelocations uint32
PointerToLinenumbers uint32
NumberOfRelocations uint16
NumberOfLinenumbers uint16
Characteristics uint32
}
func (ishdr *IMAGE_SECTION_HEADER) PhysicalAddress() uint32 {
return ishdr.physicalAddressOrVirtualSize
}
func (ishdr *IMAGE_SECTION_HEADER) SetPhysicalAddress(addr uint32) {
ishdr.physicalAddressOrVirtualSize = addr
}
func (ishdr *IMAGE_SECTION_HEADER) VirtualSize() uint32 {
return ishdr.physicalAddressOrVirtualSize
}
func (ishdr *IMAGE_SECTION_HEADER) SetVirtualSize(addr uint32) {
ishdr.physicalAddressOrVirtualSize = addr
}
const (
// Dll characteristics.
IMAGE_DLL_CHARACTERISTICS_HIGH_ENTROPY_VA = 0x0020
IMAGE_DLL_CHARACTERISTICS_DYNAMIC_BASE = 0x0040
IMAGE_DLL_CHARACTERISTICS_FORCE_INTEGRITY = 0x0080
IMAGE_DLL_CHARACTERISTICS_NX_COMPAT = 0x0100
IMAGE_DLL_CHARACTERISTICS_NO_ISOLATION = 0x0200
IMAGE_DLL_CHARACTERISTICS_NO_SEH = 0x0400
IMAGE_DLL_CHARACTERISTICS_NO_BIND = 0x0800
IMAGE_DLL_CHARACTERISTICS_APPCONTAINER = 0x1000
IMAGE_DLL_CHARACTERISTICS_WDM_DRIVER = 0x2000
IMAGE_DLL_CHARACTERISTICS_GUARD_CF = 0x4000
IMAGE_DLL_CHARACTERISTICS_TERMINAL_SERVER_AWARE = 0x8000
)
const (
// Section characteristics.
IMAGE_SCN_TYPE_REG = 0x00000000 // Reserved.
IMAGE_SCN_TYPE_DSECT = 0x00000001 // Reserved.
IMAGE_SCN_TYPE_NOLOAD = 0x00000002 // Reserved.
IMAGE_SCN_TYPE_GROUP = 0x00000004 // Reserved.
IMAGE_SCN_TYPE_NO_PAD = 0x00000008 // Reserved.
IMAGE_SCN_TYPE_COPY = 0x00000010 // Reserved.
IMAGE_SCN_CNT_CODE = 0x00000020 // Section contains code.
IMAGE_SCN_CNT_INITIALIZED_DATA = 0x00000040 // Section contains initialized data.
IMAGE_SCN_CNT_UNINITIALIZED_DATA = 0x00000080 // Section contains uninitialized data.
IMAGE_SCN_LNK_OTHER = 0x00000100 // Reserved.
IMAGE_SCN_LNK_INFO = 0x00000200 // Section contains comments or some other type of information.
IMAGE_SCN_TYPE_OVER = 0x00000400 // Reserved.
IMAGE_SCN_LNK_REMOVE = 0x00000800 // Section contents will not become part of image.
IMAGE_SCN_LNK_COMDAT = 0x00001000 // Section contents comdat.
IMAGE_SCN_MEM_PROTECTED = 0x00004000 // Obsolete.
IMAGE_SCN_NO_DEFER_SPEC_EXC = 0x00004000 // Reset speculative exceptions handling bits in the TLB entries for this section.
IMAGE_SCN_GPREL = 0x00008000 // Section content can be accessed relative to GP
IMAGE_SCN_MEM_FARDATA = 0x00008000
IMAGE_SCN_MEM_SYSHEAP = 0x00010000 // Obsolete.
IMAGE_SCN_MEM_PURGEABLE = 0x00020000
IMAGE_SCN_MEM_16BIT = 0x00020000
IMAGE_SCN_MEM_LOCKED = 0x00040000
IMAGE_SCN_MEM_PRELOAD = 0x00080000
IMAGE_SCN_ALIGN_1BYTES = 0x00100000 //
IMAGE_SCN_ALIGN_2BYTES = 0x00200000 //
IMAGE_SCN_ALIGN_4BYTES = 0x00300000 //
IMAGE_SCN_ALIGN_8BYTES = 0x00400000 //
IMAGE_SCN_ALIGN_16BYTES = 0x00500000 // Default alignment if no others are specified.
IMAGE_SCN_ALIGN_32BYTES = 0x00600000 //
IMAGE_SCN_ALIGN_64BYTES = 0x00700000 //
IMAGE_SCN_ALIGN_128BYTES = 0x00800000 //
IMAGE_SCN_ALIGN_256BYTES = 0x00900000 //
IMAGE_SCN_ALIGN_512BYTES = 0x00A00000 //
IMAGE_SCN_ALIGN_1024BYTES = 0x00B00000 //
IMAGE_SCN_ALIGN_2048BYTES = 0x00C00000 //
IMAGE_SCN_ALIGN_4096BYTES = 0x00D00000 //
IMAGE_SCN_ALIGN_8192BYTES = 0x00E00000 //
IMAGE_SCN_ALIGN_MASK = 0x00F00000
IMAGE_SCN_LNK_NRELOC_OVFL = 0x01000000 // Section contains extended relocations.
IMAGE_SCN_MEM_DISCARDABLE = 0x02000000 // Section can be discarded.
IMAGE_SCN_MEM_NOT_CACHED = 0x04000000 // Section is not cachable.
IMAGE_SCN_MEM_NOT_PAGED = 0x08000000 // Section is not pageable.
IMAGE_SCN_MEM_SHARED = 0x10000000 // Section is shareable.
IMAGE_SCN_MEM_EXECUTE = 0x20000000 // Section is executable.
IMAGE_SCN_MEM_READ = 0x40000000 // Section is readable.
IMAGE_SCN_MEM_WRITE = 0x80000000 // Section is writeable.
// TLS Characteristic Flags
IMAGE_SCN_SCALE_INDEX = 0x00000001 // Tls index is scaled.
)
// Based relocation format
type IMAGE_BASE_RELOCATION struct {
VirtualAddress uint32
SizeOfBlock uint32
}
const (
IMAGE_REL_BASED_ABSOLUTE = 0
IMAGE_REL_BASED_HIGH = 1
IMAGE_REL_BASED_LOW = 2
IMAGE_REL_BASED_HIGHLOW = 3
IMAGE_REL_BASED_HIGHADJ = 4
IMAGE_REL_BASED_MACHINE_SPECIFIC_5 = 5
IMAGE_REL_BASED_RESERVED = 6
IMAGE_REL_BASED_MACHINE_SPECIFIC_7 = 7
IMAGE_REL_BASED_MACHINE_SPECIFIC_8 = 8
IMAGE_REL_BASED_MACHINE_SPECIFIC_9 = 9
IMAGE_REL_BASED_DIR64 = 10
IMAGE_REL_BASED_IA64_IMM64 = 9
IMAGE_REL_BASED_MIPS_JMPADDR = 5
IMAGE_REL_BASED_MIPS_JMPADDR16 = 9
IMAGE_REL_BASED_ARM_MOV32 = 5
IMAGE_REL_BASED_THUMB_MOV32 = 7
)
// Export Format
type IMAGE_EXPORT_DIRECTORY struct {
Characteristics uint32
TimeDateStamp uint32
MajorVersion uint16
MinorVersion uint16
Name uint32
Base uint32
NumberOfFunctions uint32
NumberOfNames uint32
AddressOfFunctions uint32 // RVA from base of image
AddressOfNames uint32 // RVA from base of image
AddressOfNameOrdinals uint32 // RVA from base of image
}
type IMAGE_IMPORT_BY_NAME struct {
Hint uint16
Name [1]byte
}
func IMAGE_ORDINAL(ordinal uintptr) uintptr {
return ordinal & 0xffff
}
func IMAGE_SNAP_BY_ORDINAL(ordinal uintptr) bool {
return (ordinal & IMAGE_ORDINAL_FLAG) != 0
}
// Thread Local Storage
type IMAGE_TLS_DIRECTORY struct {
StartAddressOfRawData uintptr
EndAddressOfRawData uintptr
AddressOfIndex uintptr // PDWORD
AddressOfCallbacks uintptr // PIMAGE_TLS_CALLBACK *;
SizeOfZeroFill uint32
Characteristics uint32
}
type IMAGE_IMPORT_DESCRIPTOR struct {
characteristicsOrOriginalFirstThunk uint32 // 0 for terminating null import descriptor
// RVA to original unbound IAT (PIMAGE_THUNK_DATA)
TimeDateStamp uint32 // 0 if not bound,
// -1 if bound, and real date\time stamp
// in IMAGE_DIRECTORY_ENTRY_BOUND_IMPORT (new BIND)
// O.W. date/time stamp of DLL bound to (Old BIND)
ForwarderChain uint32 // -1 if no forwarders
Name uint32
FirstThunk uint32 // RVA to IAT (if bound this IAT has actual addresses)
}
func (imgimpdesc *IMAGE_IMPORT_DESCRIPTOR) Characteristics() uint32 {
return imgimpdesc.characteristicsOrOriginalFirstThunk
}
func (imgimpdesc *IMAGE_IMPORT_DESCRIPTOR) OriginalFirstThunk() uint32 {
return imgimpdesc.characteristicsOrOriginalFirstThunk
}
type IMAGE_DELAYLOAD_DESCRIPTOR struct {
Attributes uint32
DllNameRVA uint32
ModuleHandleRVA uint32
ImportAddressTableRVA uint32
ImportNameTableRVA uint32
BoundImportAddressTableRVA uint32
UnloadInformationTableRVA uint32
TimeDateStamp uint32
}
type IMAGE_LOAD_CONFIG_CODE_INTEGRITY struct {
Flags uint16
Catalog uint16
CatalogOffset uint32
Reserved uint32
}
const (
IMAGE_GUARD_CF_INSTRUMENTED = 0x00000100
IMAGE_GUARD_CFW_INSTRUMENTED = 0x00000200
IMAGE_GUARD_CF_FUNCTION_TABLE_PRESENT = 0x00000400
IMAGE_GUARD_SECURITY_COOKIE_UNUSED = 0x00000800
IMAGE_GUARD_PROTECT_DELAYLOAD_IAT = 0x00001000
IMAGE_GUARD_DELAYLOAD_IAT_IN_ITS_OWN_SECTION = 0x00002000
IMAGE_GUARD_CF_EXPORT_SUPPRESSION_INFO_PRESENT = 0x00004000
IMAGE_GUARD_CF_ENABLE_EXPORT_SUPPRESSION = 0x00008000
IMAGE_GUARD_CF_LONGJUMP_TABLE_PRESENT = 0x00010000
IMAGE_GUARD_RF_INSTRUMENTED = 0x00020000
IMAGE_GUARD_RF_ENABLE = 0x00040000
IMAGE_GUARD_RF_STRICT = 0x00080000
IMAGE_GUARD_RETPOLINE_PRESENT = 0x00100000
IMAGE_GUARD_EH_CONTINUATION_TABLE_PRESENT = 0x00400000
IMAGE_GUARD_XFG_ENABLED = 0x00800000
IMAGE_GUARD_CF_FUNCTION_TABLE_SIZE_MASK = 0xF0000000
IMAGE_GUARD_CF_FUNCTION_TABLE_SIZE_SHIFT = 28
)
type IMAGE_RUNTIME_FUNCTION_ENTRY struct {
BeginAddress uint32
EndAddress uint32
UnwindInfoAddress uint32
}
const (
DLL_PROCESS_ATTACH = 1
DLL_THREAD_ATTACH = 2
DLL_THREAD_DETACH = 3
DLL_PROCESS_DETACH = 0
)
type SYSTEM_INFO struct {
ProcessorArchitecture uint16
Reserved uint16
PageSize uint32
MinimumApplicationAddress uintptr
MaximumApplicationAddress uintptr
ActiveProcessorMask uintptr
NumberOfProcessors uint32
ProcessorType uint32
AllocationGranularity uint32
ProcessorLevel uint16
ProcessorRevision uint16
}

View File

@ -1,98 +0,0 @@
//go:build windows && (386 || arm)
// +build windows
// +build 386 arm
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package memmod
// Optional header format
type IMAGE_OPTIONAL_HEADER struct {
Magic uint16
MajorLinkerVersion uint8
MinorLinkerVersion uint8
SizeOfCode uint32
SizeOfInitializedData uint32
SizeOfUninitializedData uint32
AddressOfEntryPoint uint32
BaseOfCode uint32
BaseOfData uint32
ImageBase uintptr
SectionAlignment uint32
FileAlignment uint32
MajorOperatingSystemVersion uint16
MinorOperatingSystemVersion uint16
MajorImageVersion uint16
MinorImageVersion uint16
MajorSubsystemVersion uint16
MinorSubsystemVersion uint16
Win32VersionValue uint32
SizeOfImage uint32
SizeOfHeaders uint32
CheckSum uint32
Subsystem uint16
DllCharacteristics uint16
SizeOfStackReserve uintptr
SizeOfStackCommit uintptr
SizeOfHeapReserve uintptr
SizeOfHeapCommit uintptr
LoaderFlags uint32
NumberOfRvaAndSizes uint32
DataDirectory [IMAGE_NUMBEROF_DIRECTORY_ENTRIES]IMAGE_DATA_DIRECTORY
}
const IMAGE_ORDINAL_FLAG uintptr = 0x80000000
type IMAGE_LOAD_CONFIG_DIRECTORY struct {
Size uint32
TimeDateStamp uint32
MajorVersion uint16
MinorVersion uint16
GlobalFlagsClear uint32
GlobalFlagsSet uint32
CriticalSectionDefaultTimeout uint32
DeCommitFreeBlockThreshold uint32
DeCommitTotalFreeThreshold uint32
LockPrefixTable uint32
MaximumAllocationSize uint32
VirtualMemoryThreshold uint32
ProcessHeapFlags uint32
ProcessAffinityMask uint32
CSDVersion uint16
DependentLoadFlags uint16
EditList uint32
SecurityCookie uint32
SEHandlerTable uint32
SEHandlerCount uint32
GuardCFCheckFunctionPointer uint32
GuardCFDispatchFunctionPointer uint32
GuardCFFunctionTable uint32
GuardCFFunctionCount uint32
GuardFlags uint32
CodeIntegrity IMAGE_LOAD_CONFIG_CODE_INTEGRITY
GuardAddressTakenIatEntryTable uint32
GuardAddressTakenIatEntryCount uint32
GuardLongJumpTargetTable uint32
GuardLongJumpTargetCount uint32
DynamicValueRelocTable uint32
CHPEMetadataPointer uint32
GuardRFFailureRoutine uint32
GuardRFFailureRoutineFunctionPointer uint32
DynamicValueRelocTableOffset uint32
DynamicValueRelocTableSection uint16
Reserved2 uint16
GuardRFVerifyStackPointerFunctionPointer uint32
HotPatchTableOffset uint32
Reserved3 uint32
EnclaveConfigurationPointer uint32
VolatileMetadataPointer uint32
GuardEHContinuationTable uint32
GuardEHContinuationCount uint32
GuardXFGCheckFunctionPointer uint32
GuardXFGDispatchFunctionPointer uint32
GuardXFGTableDispatchFunctionPointer uint32
CastGuardOsDeterminedFailureMode uint32
}

View File

@ -1,97 +0,0 @@
//go:build windows && (amd64 || arm64)
// +build windows
// +build amd64 arm64
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package memmod
// Optional header format
type IMAGE_OPTIONAL_HEADER struct {
Magic uint16
MajorLinkerVersion uint8
MinorLinkerVersion uint8
SizeOfCode uint32
SizeOfInitializedData uint32
SizeOfUninitializedData uint32
AddressOfEntryPoint uint32
BaseOfCode uint32
ImageBase uintptr
SectionAlignment uint32
FileAlignment uint32
MajorOperatingSystemVersion uint16
MinorOperatingSystemVersion uint16
MajorImageVersion uint16
MinorImageVersion uint16
MajorSubsystemVersion uint16
MinorSubsystemVersion uint16
Win32VersionValue uint32
SizeOfImage uint32
SizeOfHeaders uint32
CheckSum uint32
Subsystem uint16
DllCharacteristics uint16
SizeOfStackReserve uintptr
SizeOfStackCommit uintptr
SizeOfHeapReserve uintptr
SizeOfHeapCommit uintptr
LoaderFlags uint32
NumberOfRvaAndSizes uint32
DataDirectory [IMAGE_NUMBEROF_DIRECTORY_ENTRIES]IMAGE_DATA_DIRECTORY
}
const IMAGE_ORDINAL_FLAG uintptr = 0x8000000000000000
type IMAGE_LOAD_CONFIG_DIRECTORY struct {
Size uint32
TimeDateStamp uint32
MajorVersion uint16
MinorVersion uint16
GlobalFlagsClear uint32
GlobalFlagsSet uint32
CriticalSectionDefaultTimeout uint32
DeCommitFreeBlockThreshold uint64
DeCommitTotalFreeThreshold uint64
LockPrefixTable uint64
MaximumAllocationSize uint64
VirtualMemoryThreshold uint64
ProcessAffinityMask uint64
ProcessHeapFlags uint32
CSDVersion uint16
DependentLoadFlags uint16
EditList uint64
SecurityCookie uint64
SEHandlerTable uint64
SEHandlerCount uint64
GuardCFCheckFunctionPointer uint64
GuardCFDispatchFunctionPointer uint64
GuardCFFunctionTable uint64
GuardCFFunctionCount uint64
GuardFlags uint32
CodeIntegrity IMAGE_LOAD_CONFIG_CODE_INTEGRITY
GuardAddressTakenIatEntryTable uint64
GuardAddressTakenIatEntryCount uint64
GuardLongJumpTargetTable uint64
GuardLongJumpTargetCount uint64
DynamicValueRelocTable uint64
CHPEMetadataPointer uint64
GuardRFFailureRoutine uint64
GuardRFFailureRoutineFunctionPointer uint64
DynamicValueRelocTableOffset uint32
DynamicValueRelocTableSection uint16
Reserved2 uint16
GuardRFVerifyStackPointerFunctionPointer uint64
HotPatchTableOffset uint32
Reserved3 uint32
EnclaveConfigurationPointer uint64
VolatileMetadataPointer uint64
GuardEHContinuationTable uint64
GuardEHContinuationCount uint64
GuardXFGCheckFunctionPointer uint64
GuardXFGDispatchFunctionPointer uint64
GuardXFGTableDispatchFunctionPointer uint64
CastGuardOsDeterminedFailureMode uint64
}

View File

@ -1,3 +1,8 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package wintun
import (
@ -62,7 +67,7 @@ func (session Session) ReceivePacket() (packet []byte, err error) {
err = e1
return
}
unsafeSlice(unsafe.Pointer(&packet), unsafe.Pointer(r0), int(packetSize))
packet = unsafe.Slice((*byte)(unsafe.Pointer(r0)), packetSize)
return
}
@ -76,28 +81,10 @@ func (session Session) AllocateSendPacket(packetSize int) (packet []byte, err er
err = e1
return
}
unsafeSlice(unsafe.Pointer(&packet), unsafe.Pointer(r0), int(packetSize))
packet = unsafe.Slice((*byte)(unsafe.Pointer(r0)), packetSize)
return
}
func (session Session) SendPacket(packet []byte) {
syscall.Syscall(procWintunSendPacket.Addr(), 2, session.handle, uintptr(unsafe.Pointer(&packet[0])), 0)
}
// unsafeSlice updates the slice slicePtr to be a slice
// referencing the provided data with its length & capacity set to
// lenCap.
//
// TODO: when Go 1.16 or Go 1.17 is the minimum supported version,
// update callers to use unsafe.Slice instead of this.
func unsafeSlice(slicePtr, data unsafe.Pointer, lenCap int) {
type sliceHeader struct {
Data unsafe.Pointer
Len int
Cap int
}
h := (*sliceHeader)(slicePtr)
h.Data = data
h.Len = lenCap
h.Cap = lenCap
}

View File

@ -1,7 +1,11 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package wintun
import (
"errors"
"runtime"
"syscall"
"unsafe"
@ -18,193 +22,121 @@ const (
logErr
)
const (
PoolNameMax = 256
AdapterNameMax = 128
)
const AdapterNameMax = 128
type Pool [PoolNameMax]uint16
type Adapter struct {
handle uintptr
}
var (
modwintun = newLazyDLL("wintun.dll", setupLogger)
modwintun = newLazyDLL("wintun.dll", setupLogger)
procWintunCreateAdapter = modwintun.NewProc("WintunCreateAdapter")
procWintunDeleteAdapter = modwintun.NewProc("WintunDeleteAdapter")
procWintunDeletePoolDriver = modwintun.NewProc("WintunDeletePoolDriver")
procWintunEnumAdapters = modwintun.NewProc("WintunEnumAdapters")
procWintunFreeAdapter = modwintun.NewProc("WintunFreeAdapter")
procWintunOpenAdapter = modwintun.NewProc("WintunOpenAdapter")
procWintunCloseAdapter = modwintun.NewProc("WintunCloseAdapter")
procWintunDeleteDriver = modwintun.NewProc("WintunDeleteDriver")
procWintunGetAdapterLUID = modwintun.NewProc("WintunGetAdapterLUID")
procWintunGetAdapterName = modwintun.NewProc("WintunGetAdapterName")
procWintunGetRunningDriverVersion = modwintun.NewProc("WintunGetRunningDriverVersion")
procWintunSetAdapterName = modwintun.NewProc("WintunSetAdapterName")
)
func logMessage(level loggerLevel, _ uint64, msg *uint16) int {
var lv log.LogLevel
switch level {
case logInfo:
lv = log.INFO
case logWarn:
lv = log.WARNING
case logErr:
lv = log.ERROR
default:
lv = log.INFO
}
log.PrintLog(lv, "[Wintun] %s", windows.UTF16PtrToString(msg))
return 0
}
func setupLogger(dll *lazyDLL) {
syscall.Syscall(dll.NewProc("WintunSetLogger").Addr(), 1, windows.NewCallback(func(level loggerLevel, msg *uint16) int {
var lv log.LogLevel
switch level {
case logInfo:
lv = log.INFO
case logWarn:
lv = log.WARNING
case logErr:
lv = log.ERROR
default:
lv = log.INFO
}
log.PrintLog(lv, "[Wintun] %s", windows.UTF16PtrToString(msg))
return 0
}), 0, 0)
var callback uintptr
if runtime.GOARCH == "386" {
callback = windows.NewCallback(func(level loggerLevel, timestampLow, timestampHigh uint32, msg *uint16) int {
return logMessage(level, uint64(timestampHigh)<<32|uint64(timestampLow), msg)
})
} else if runtime.GOARCH == "arm" {
callback = windows.NewCallback(func(level loggerLevel, _, timestampLow, timestampHigh uint32, msg *uint16) int {
return logMessage(level, uint64(timestampHigh)<<32|uint64(timestampLow), msg)
})
} else if runtime.GOARCH == "amd64" || runtime.GOARCH == "arm64" {
callback = windows.NewCallback(logMessage)
}
syscall.Syscall(dll.NewProc("WintunSetLogger").Addr(), 1, callback, 0, 0)
}
func MakePool(poolName string) (pool *Pool, err error) {
poolName16, err := windows.UTF16FromString(poolName)
func closeAdapter(wintun *Adapter) {
syscall.Syscall(procWintunCloseAdapter.Addr(), 1, wintun.handle, 0, 0)
}
// CreateAdapter creates a Wintun adapter. name is the cosmetic name of the adapter.
// tunnelType represents the type of adapter and should be "Wintun". requestedGUID is
// the GUID of the created network adapter, which then influences NLA generation
// deterministically. If it is set to nil, the GUID is chosen by the system at random,
// and hence a new NLA entry is created for each new adapter.
func CreateAdapter(name string, tunnelType string, requestedGUID *windows.GUID) (wintun *Adapter, err error) {
var name16 *uint16
name16, err = windows.UTF16PtrFromString(name)
if err != nil {
return
}
if len(poolName16) > PoolNameMax {
err = errors.New("Pool name too long")
var tunnelType16 *uint16
tunnelType16, err = windows.UTF16PtrFromString(tunnelType)
if err != nil {
return
}
pool = &Pool{}
copy(pool[:], poolName16)
return
}
func (pool *Pool) String() string {
return windows.UTF16ToString(pool[:])
}
func freeAdapter(wintun *Adapter) {
syscall.Syscall(procWintunFreeAdapter.Addr(), 1, uintptr(wintun.handle), 0, 0)
}
// OpenAdapter finds a Wintun adapter by its name. This function returns the adapter if found, or
// windows.ERROR_FILE_NOT_FOUND otherwise. If the adapter is found but not a Wintun-class or a
// member of the pool, this function returns windows.ERROR_ALREADY_EXISTS. The adapter must be
// released after use.
func (pool *Pool) OpenAdapter(ifname string) (wintun *Adapter, err error) {
ifname16, err := windows.UTF16PtrFromString(ifname)
if err != nil {
return nil, err
}
r0, _, e1 := syscall.Syscall(procWintunOpenAdapter.Addr(), 2, uintptr(unsafe.Pointer(pool)), uintptr(unsafe.Pointer(ifname16)), 0)
r0, _, e1 := syscall.Syscall(procWintunCreateAdapter.Addr(), 3, uintptr(unsafe.Pointer(name16)), uintptr(unsafe.Pointer(tunnelType16)), uintptr(unsafe.Pointer(requestedGUID)))
if r0 == 0 {
err = e1
return
}
wintun = &Adapter{r0}
runtime.SetFinalizer(wintun, freeAdapter)
wintun = &Adapter{handle: r0}
runtime.SetFinalizer(wintun, closeAdapter)
return
}
// CreateAdapter creates a Wintun adapter. ifname is the requested name of the adapter, while
// requestedGUID is the GUID of the created network adapter, which then influences NLA generation
// deterministically. If it is set to nil, the GUID is chosen by the system at random, and hence a
// new NLA entry is created for each new adapter. It is called "requested" GUID because the API it
// uses is completely undocumented, and so there could be minor interesting complications with its
// usage. This function returns the network adapter ID and a flag if reboot is required.
func (pool *Pool) CreateAdapter(ifname string, requestedGUID *windows.GUID) (wintun *Adapter, rebootRequired bool, err error) {
var ifname16 *uint16
ifname16, err = windows.UTF16PtrFromString(ifname)
// OpenAdapter opens an existing Wintun adapter by name.
func OpenAdapter(name string) (wintun *Adapter, err error) {
var name16 *uint16
name16, err = windows.UTF16PtrFromString(name)
if err != nil {
return
}
var _p0 uint32
r0, _, e1 := syscall.Syscall6(procWintunCreateAdapter.Addr(), 4, uintptr(unsafe.Pointer(pool)), uintptr(unsafe.Pointer(ifname16)), uintptr(unsafe.Pointer(requestedGUID)), uintptr(unsafe.Pointer(&_p0)), 0, 0)
rebootRequired = _p0 != 0
r0, _, e1 := syscall.Syscall(procWintunOpenAdapter.Addr(), 1, uintptr(unsafe.Pointer(name16)), 0, 0)
if r0 == 0 {
err = e1
return
}
wintun = &Adapter{r0}
runtime.SetFinalizer(wintun, freeAdapter)
wintun = &Adapter{handle: r0}
runtime.SetFinalizer(wintun, closeAdapter)
return
}
// Delete deletes a Wintun adapter. This function succeeds if the adapter was not found. It returns
// a bool indicating whether a reboot is required.
func (wintun *Adapter) Delete(forceCloseSessions bool) (rebootRequired bool, err error) {
var _p0 uint32
if forceCloseSessions {
_p0 = 1
}
var _p1 uint32
r1, _, e1 := syscall.Syscall(procWintunDeleteAdapter.Addr(), 3, uintptr(wintun.handle), uintptr(_p0), uintptr(unsafe.Pointer(&_p1)))
rebootRequired = _p1 != 0
// Close closes a Wintun adapter.
func (wintun *Adapter) Close() (err error) {
runtime.SetFinalizer(wintun, nil)
r1, _, e1 := syscall.Syscall(procWintunCloseAdapter.Addr(), 1, wintun.handle, 0, 0)
if r1 == 0 {
err = e1
}
return
}
// DeleteMatchingAdapters deletes all Wintun adapters, which match
// given criteria, and returns which ones it deleted, whether a reboot
// is required after, and which errors occurred during the process.
func (pool *Pool) DeleteMatchingAdapters(matches func(adapter *Adapter) bool, forceCloseSessions bool) (rebootRequired bool, errors []error) {
cb := func(handle uintptr, _ uintptr) int {
adapter := &Adapter{handle}
if !matches(adapter) {
return 1
}
rebootRequired2, err := adapter.Delete(forceCloseSessions)
if err != nil {
errors = append(errors, err)
return 1
}
rebootRequired = rebootRequired || rebootRequired2
return 1
}
r1, _, e1 := syscall.Syscall(procWintunEnumAdapters.Addr(), 3, uintptr(unsafe.Pointer(pool)), uintptr(windows.NewCallback(cb)), 0)
if r1 == 0 {
errors = append(errors, e1)
}
return
}
// Name returns the name of the Wintun adapter.
func (wintun *Adapter) Name() (ifname string, err error) {
var ifname16 [AdapterNameMax]uint16
r1, _, e1 := syscall.Syscall(procWintunGetAdapterName.Addr(), 2, uintptr(wintun.handle), uintptr(unsafe.Pointer(&ifname16[0])), 0)
if r1 == 0 {
err = e1
return
}
ifname = windows.UTF16ToString(ifname16[:])
return
}
// DeleteDriver deletes all Wintun adapters in a pool and if there are no more adapters in any other
// pools, also removes Wintun from the driver store, usually called by uninstallers.
func (pool *Pool) DeleteDriver() (rebootRequired bool, err error) {
var _p0 uint32
r1, _, e1 := syscall.Syscall(procWintunDeletePoolDriver.Addr(), 2, uintptr(unsafe.Pointer(pool)), uintptr(unsafe.Pointer(&_p0)), 0)
rebootRequired = _p0 != 0
if r1 == 0 {
err = e1
}
return
}
// SetName sets name of the Wintun adapter.
func (wintun *Adapter) SetName(ifname string) (err error) {
ifname16, err := windows.UTF16FromString(ifname)
if err != nil {
return err
}
r1, _, e1 := syscall.Syscall(procWintunSetAdapterName.Addr(), 2, uintptr(wintun.handle), uintptr(unsafe.Pointer(&ifname16[0])), 0)
// Uninstall removes the driver from the system if no drivers are currently in use.
func Uninstall() (err error) {
r1, _, e1 := syscall.Syscall(procWintunDeleteDriver.Addr(), 0, 0, 0, 0)
if r1 == 0 {
err = e1
}
return
}
// RunningVersion returns the version of the running Wintun driver.
// RunningVersion returns the version of the loaded driver.
func RunningVersion() (version uint32, err error) {
r0, _, e1 := syscall.Syscall(procWintunGetRunningDriverVersion.Addr(), 0, 0, 0, 0)
version = uint32(r0)

View File

@ -41,7 +41,7 @@ func New(conf config.Tun, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.Pack
} else if strings.EqualFold(stack, "gvisor") {
tunAdapter, err = gvisor.NewAdapter(device, conf, tunAddress, tcpIn, udpIn)
} else {
err = fmt.Errorf("can not support tun ip stack: %s, only support \"system\" and \"gvisor\"", stack)
err = fmt.Errorf("can not support tun ip stack: %s, only support \"lwip\" \"system\" and \"gvisor\"", stack)
}
if err != nil {