gVisor bot 97ea1e66ad Fix: handle parse socks5 udp address properly (#2220)
(cherry picked from commit bec4df7b122e6a4db0f831ed6176732b2a09fb63)
2022-10-01 23:45:06 +08:00

281 lines
5.5 KiB
Go

package snell
import (
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"sync"
"github.com/Dreamacro/clash/common/pool"
"github.com/Dreamacro/clash/transport/shadowsocks/shadowaead"
"github.com/Dreamacro/clash/transport/socks5"
)
// Snell protocol versions accepted by the version parameters of this package.
const (
	Version1 = 1
	Version2 = 2
	Version3 = 3
)

// DefaultSnellVersion is the protocol version this package designates as the
// default.
const DefaultSnellVersion = Version1

// maxLength is the largest payload WritePacket sends as one packet; longer
// payloads are split into chunks of at most this many bytes.
const maxLength = 0x3FFF
const (
	// Request commands: written as the second byte of a request header,
	// right after Version (see WriteHeader and WriteUDPHeader).
	// CommandPing is not referenced by the code visible in this section.
	CommandPing      byte = 0
	CommandConnect   byte = 1
	CommandConnectV2 byte = 5
	CommandUDP       byte = 6

	// CommandUDPForward is the first byte of every packet emitted by
	// writePacket.
	CommandUDPForward byte = 1
	// CommondUDPForward is the original, misspelled name of
	// CommandUDPForward. It is kept so existing callers keep compiling.
	//
	// Deprecated: use CommandUDPForward instead.
	CommondUDPForward = CommandUDPForward

	// Server reply codes: the first byte of the server's response, as
	// interpreted by (*Snell).Read.
	// CommandPong is not referenced by the code visible in this section.
	CommandTunnel byte = 0
	CommandPong   byte = 1
	CommandError  byte = 2

	// Version is the first byte of every request header.
	Version byte = 1
)
// endSignal is the zero-length payload that HalfClose writes to the
// connection.
var endSignal = []byte{}
// Snell wraps a net.Conn and strips the server's reply header from the
// stream on the first Read.
type Snell struct {
	net.Conn
	// buffer is scratch space for reading the reply header one byte at a
	// time.
	buffer [1]byte
	// reply is set once Read has started processing the reply header;
	// HalfClose clears it so the next Read expects a header again.
	reply bool
}
// Read reads payload bytes into b, consuming the server's reply header first
// when it has not been handled yet.
//
// The header starts with a single command byte:
//   - CommandTunnel: the rest of the stream is payload and is read into b.
//   - CommandError: a one-byte error code, a one-byte message length and the
//     message follow; they are reported as an error.
//   - any other value: a "command not support" error is returned.
//
// The header is marked as handled before it is read, so a failed header read
// is not retried; later calls read straight from the underlying connection.
func (s *Snell) Read(b []byte) (int, error) {
	if s.reply {
		return s.Conn.Read(b)
	}
	s.reply = true

	// nextByte pulls exactly one byte off the connection via the scratch
	// buffer.
	nextByte := func() (byte, error) {
		_, err := io.ReadFull(s.Conn, s.buffer[:])
		return s.buffer[0], err
	}

	command, err := nextByte()
	if err != nil {
		return 0, err
	}
	switch command {
	case CommandTunnel:
		return s.Conn.Read(b)
	case CommandError:
		// The error details are decoded below.
	default:
		return 0, errors.New("command not support")
	}

	// CommandError: 1 byte error code.
	code, err := nextByte()
	if err != nil {
		return 0, err
	}

	// 1 byte error message length, then the message itself.
	msgLen, err := nextByte()
	if err != nil {
		return 0, err
	}
	msg := make([]byte, int(msgLen))
	if _, err = io.ReadFull(s.Conn, msg); err != nil {
		return 0, err
	}

	return 0, fmt.Errorf("server reported code: %d, message: %s", int(code), string(msg))
}
// WriteHeader sends the connect request header for host:port over conn.
//
// Wire layout, in order: Version, the command byte (CommandConnectV2 when
// version is Version2, CommandConnect otherwise), a zero client-ID length,
// a one-byte host length, the host bytes and the port as a big-endian
// 16-bit integer.
//
// It returns an error, without writing anything, when host is longer than
// the one-byte length prefix can express or when port does not fit in 16
// bits. Otherwise it returns the error of the underlying conn.Write, if any.
func WriteHeader(conn net.Conn, host string, port uint, version int) error {
	const (
		maxHostLen = 0xFF   // host length is encoded in a single byte
		maxPort    = 0xFFFF // port is encoded in two bytes
	)
	// Reject values that would otherwise be silently truncated by the
	// narrowing conversions below and produce a corrupt header.
	if len(host) > maxHostLen {
		return fmt.Errorf("host too long: %d bytes, at most %d allowed", len(host), maxHostLen)
	}
	if port > maxPort {
		return fmt.Errorf("port out of range: %d", port)
	}

	buf := pool.GetBuffer()
	defer pool.PutBuffer(buf)

	buf.WriteByte(Version)
	if version == Version2 {
		buf.WriteByte(CommandConnectV2)
	} else {
		buf.WriteByte(CommandConnect)
	}

	// clientID length & id
	buf.WriteByte(0)

	// host & port
	buf.WriteByte(uint8(len(host)))
	buf.WriteString(host)
	var portBytes [2]byte
	binary.BigEndian.PutUint16(portBytes[:], uint16(port))
	buf.Write(portBytes[:])

	_, err := conn.Write(buf.Bytes())
	return err
}
// WriteUDPHeader sends the three-byte UDP request header over conn: Version,
// CommandUDP and a zero client-ID length.
//
// Versions below Version3 are rejected with an error before anything is
// written.
func WriteUDPHeader(conn net.Conn, version int) error {
	if version < Version3 {
		return errors.New("unsupport UDP version")
	}

	// version, command, clientID length
	header := [3]byte{Version, CommandUDP, 0x00}
	if _, err := conn.Write(header[:]); err != nil {
		return err
	}
	return nil
}
// HalfClose writes the empty endSignal payload to conn and, when conn is a
// *Snell, clears its reply flag so that the next Read expects a reply header
// again. It works only on version2.
//
// The write error, if any, is returned and the reply flag is left untouched.
func HalfClose(conn net.Conn) error {
	_, err := conn.Write(endSignal)
	if err != nil {
		return err
	}

	snellConn, isSnell := conn.(*Snell)
	if isSnell {
		snellConn.reply = false
	}
	return nil
}
// StreamConn wraps conn in a shadowaead connection keyed with psk and returns
// it as a *Snell.
//
// Version1 uses the cipher built by NewChacha20Poly1305; every other version
// uses the one built by NewAES128GCM.
func StreamConn(conn net.Conn, psk []byte, version int) *Snell {
	var cipher shadowaead.Cipher
	switch version {
	case Version1:
		cipher = NewChacha20Poly1305(psk)
	default:
		cipher = NewAES128GCM(psk)
	}

	encrypted := shadowaead.NewConn(conn, cipher)
	return &Snell{Conn: encrypted}
}
// PacketConn adapts conn to the net.PacketConn interface by wrapping it in a
// packetConn, whose WriteTo and ReadFrom frame packets with WritePacket and
// ReadPacket.
func PacketConn(conn net.Conn) net.PacketConn {
	pc := new(packetConn)
	pc.Conn = conn
	return pc
}
func writePacket(w io.Writer, socks5Addr, payload []byte) (int, error) {
buf := pool.GetBuffer()
defer pool.PutBuffer(buf)
// compose snell UDP address format (refer: icpz/snell-server-reversed)
// a brand new wheel to replace socks5 address format, well done Yachen
buf.WriteByte(CommondUDPForward)
switch socks5Addr[0] {
case socks5.AtypDomainName:
hostLen := socks5Addr[1]
buf.Write(socks5Addr[1 : 1+1+hostLen+2])
case socks5.AtypIPv4:
buf.Write([]byte{0x00, 0x04})
buf.Write(socks5Addr[1 : 1+net.IPv4len+2])
case socks5.AtypIPv6:
buf.Write([]byte{0x00, 0x06})
buf.Write(socks5Addr[1 : 1+net.IPv6len+2])
}
buf.Write(payload)
_, err := w.Write(buf.Bytes())
if err != nil {
return 0, err
}
return len(payload), nil
}
// WritePacket sends payload to socks5Addr through w.
//
// A payload of at most maxLength bytes goes out as one packet. A longer one
// is cut into consecutive chunks of at most maxLength bytes, each sent as its
// own packet carrying the same address.
//
// It returns the number of payload bytes sent. When a chunk fails, the count
// covers the chunks completed before it and the error of the failed write is
// returned.
func WritePacket(w io.Writer, socks5Addr, payload []byte) (int, error) {
	total := len(payload)
	if total <= maxLength {
		return writePacket(w, socks5Addr, payload)
	}

	for sent := 0; sent < total; {
		end := sent + maxLength
		if end > total {
			end = total
		}

		n, err := writePacket(w, socks5Addr, payload[sent:end])
		if err != nil {
			return sent + n, err
		}
		sent = end
	}

	return total, nil
}
// ReadPacket performs one Read on r, decodes a snell UDP response and copies
// its data into payload.
//
// The response starts with an IP-version byte (0x04 or 0x06), followed by the
// IP address, a two-byte port and the data. The version byte is rewritten to
// the matching socks5 address type so the header can be parsed as a socks5
// address.
//
// It returns the sender address and the number of bytes copied into payload;
// data that does not fit in payload is dropped. An error is returned when the
// read fails, the packet is shorter than its header, the IP version is
// unknown or the address cannot be parsed.
func ReadPacket(r io.Reader, payload []byte) (net.Addr, int, error) {
	buf := pool.Get(pool.UDPBufferSize)
	defer pool.Put(buf)

	n, err := r.Read(buf)
	if err != nil {
		return nil, 0, err
	}
	if n < 1 {
		return nil, 0, errors.New("insufficient UDP length")
	}

	// parse snell UDP response address format
	var (
		atyp  byte
		ipLen int
	)
	switch buf[0] {
	case 0x04:
		atyp, ipLen = socks5.AtypIPv4, net.IPv4len
	case 0x06:
		atyp, ipLen = socks5.AtypIPv6, net.IPv6len
	default:
		return nil, 0, errors.New("ip version invalid")
	}

	// version byte + IP + port
	headLen := 1 + ipLen + 2
	if n < headLen {
		return nil, 0, errors.New("insufficient UDP length")
	}
	buf[0] = atyp

	parsed := socks5.SplitAddr(buf)
	if parsed == nil {
		return nil, 0, errors.New("remote address invalid")
	}
	udpAddr := parsed.UDPAddr()
	if udpAddr == nil {
		return nil, 0, errors.New("parse addr error")
	}

	// copy stops at the shorter of the two slices, which is exactly the
	// number of data bytes handed back to the caller.
	copied := copy(payload, buf[headLen:n])
	return udpAddr, copied, nil
}
// packetConn layers packet-oriented ReadFrom/WriteTo methods on top of a
// net.Conn.
type packetConn struct {
	net.Conn
	// rMux is held for the duration of each ReadFrom call.
	rMux sync.Mutex
	// wMux is held for the duration of each WriteTo call.
	wMux sync.Mutex
}
// WriteTo sends b to addr as one or more snell UDP packets (see WritePacket)
// and returns the number of bytes of b that were sent.
//
// addr is converted to a socks5 address from its string form. A nil addr, or
// one whose string form cannot be converted, yields an error and nothing is
// written. Calls are serialized by wMux.
func (pc *packetConn) WriteTo(b []byte, addr net.Addr) (int, error) {
	if addr == nil {
		return 0, errors.New("remote address invalid")
	}

	// Resolve the target before taking the lock. An empty result means the
	// address could not be parsed and must not reach WritePacket, which
	// indexes into it.
	target := socks5.ParseAddr(addr.String())
	if len(target) == 0 {
		return 0, fmt.Errorf("parse addr %q error", addr.String())
	}

	pc.wMux.Lock()
	defer pc.wMux.Unlock()

	return WritePacket(pc, target, b)
}
// ReadFrom reads one snell UDP response from the underlying connection into
// b (see ReadPacket) and returns the number of bytes stored together with the
// sender address. On failure it returns 0, a nil address and the error.
// Calls are serialized by rMux.
func (pc *packetConn) ReadFrom(b []byte) (int, net.Addr, error) {
	pc.rMux.Lock()
	defer pc.rMux.Unlock()

	from, size, err := ReadPacket(pc.Conn, b)
	if err == nil {
		return size, from, nil
	}
	return 0, nil, err
}