package vmess import ( "bytes" "context" "crypto/sha1" "crypto/tls" "encoding/base64" "encoding/binary" "errors" "fmt" "io" "net" "net/http" "net/url" "strconv" "strings" "time" "github.com/Dreamacro/clash/common/buf" N "github.com/Dreamacro/clash/common/net" tlsC "github.com/Dreamacro/clash/component/tls" "github.com/Dreamacro/clash/log" "github.com/gobwas/ws" "github.com/gobwas/ws/wsutil" "github.com/zhangyunhao116/fastrand" ) type websocketConn struct { net.Conn state ws.State reader *wsutil.Reader controlHandler wsutil.FrameHandlerFunc rawWriter N.ExtendedWriter } type websocketWithEarlyDataConn struct { net.Conn wsWriter N.ExtendedWriter underlay net.Conn closed bool dialed chan bool cancel context.CancelFunc ctx context.Context config *WebsocketConfig } type WebsocketConfig struct { Host string Port string Path string Headers http.Header TLS bool TLSConfig *tls.Config MaxEarlyData int EarlyDataHeaderName string ClientFingerprint string V2rayHttpUpgrade bool } // Read implements net.Conn.Read() // modify from gobwas/ws/wsutil.readData func (wsc *websocketConn) Read(b []byte) (n int, err error) { var header ws.Header for { n, err = wsc.reader.Read(b) // in gobwas/ws: "The error is io.EOF only if all of message bytes were read." // but maybe next frame still have data, so drop it if errors.Is(err, io.EOF) { err = nil } if !errors.Is(err, wsutil.ErrNoFrameAdvance) { return } header, err = wsc.reader.NextFrame() if err != nil { return } if header.OpCode.IsControl() { err = wsc.controlHandler(header, wsc.reader) if err != nil { return } continue } if header.OpCode&(ws.OpBinary|ws.OpText) == 0 { err = wsc.reader.Discard() if err != nil { return } continue } } } // Write implements io.Writer. func (wsc *websocketConn) Write(b []byte) (n int, err error) { err = wsutil.WriteMessage(wsc.Conn, wsc.state, ws.OpBinary, b) if err != nil { return } n = len(b) return } func (wsc *websocketConn) WriteBuffer(buffer *buf.Buffer) error { var payloadBitLength int dataLen := buffer.Len() data := buffer.Bytes() if dataLen < 126 { payloadBitLength = 1 } else if dataLen < 65536 { payloadBitLength = 3 } else { payloadBitLength = 9 } var headerLen int headerLen += 1 // FIN / RSV / OPCODE headerLen += payloadBitLength if wsc.state.ClientSide() { headerLen += 4 // MASK KEY } header := buffer.ExtendHeader(headerLen) header[0] = byte(ws.OpBinary) | 0x80 if wsc.state.ClientSide() { header[1] = 1 << 7 } else { header[1] = 0 } if dataLen < 126 { header[1] |= byte(dataLen) } else if dataLen < 65536 { header[1] |= 126 binary.BigEndian.PutUint16(header[2:], uint16(dataLen)) } else { header[1] |= 127 binary.BigEndian.PutUint64(header[2:], uint64(dataLen)) } if wsc.state.ClientSide() { maskKey := fastrand.Uint32() binary.LittleEndian.PutUint32(header[1+payloadBitLength:], maskKey) N.MaskWebSocket(maskKey, data) } return wsc.rawWriter.WriteBuffer(buffer) } func (wsc *websocketConn) FrontHeadroom() int { return 14 } func (wsc *websocketConn) Upstream() any { return wsc.Conn } func (wsc *websocketConn) Close() error { _ = wsc.Conn.SetWriteDeadline(time.Now().Add(time.Second * 5)) _ = wsutil.WriteMessage(wsc.Conn, wsc.state, ws.OpClose, ws.NewCloseFrameBody(ws.StatusNormalClosure, "")) _ = wsc.Conn.Close() return nil } func (wsedc *websocketWithEarlyDataConn) Dial(earlyData []byte) error { base64DataBuf := &bytes.Buffer{} base64EarlyDataEncoder := base64.NewEncoder(base64.RawURLEncoding, base64DataBuf) earlyDataBuf := bytes.NewBuffer(earlyData) if _, err := base64EarlyDataEncoder.Write(earlyDataBuf.Next(wsedc.config.MaxEarlyData)); err != nil { return fmt.Errorf("failed to encode early data: %w", err) } if errc := base64EarlyDataEncoder.Close(); errc != nil { return fmt.Errorf("failed to encode early data tail: %w", errc) } var err error if wsedc.Conn, err = streamWebsocketConn(wsedc.ctx, wsedc.underlay, wsedc.config, base64DataBuf); err != nil { wsedc.Close() return fmt.Errorf("failed to dial WebSocket: %w", err) } wsedc.dialed <- true wsedc.wsWriter = N.NewExtendedWriter(wsedc.Conn) if earlyDataBuf.Len() != 0 { _, err = wsedc.Conn.Write(earlyDataBuf.Bytes()) } return err } func (wsedc *websocketWithEarlyDataConn) Write(b []byte) (int, error) { if wsedc.closed { return 0, io.ErrClosedPipe } if wsedc.Conn == nil { if err := wsedc.Dial(b); err != nil { return 0, err } return len(b), nil } return wsedc.Conn.Write(b) } func (wsedc *websocketWithEarlyDataConn) WriteBuffer(buffer *buf.Buffer) error { if wsedc.closed { return io.ErrClosedPipe } if wsedc.Conn == nil { if err := wsedc.Dial(buffer.Bytes()); err != nil { return err } return nil } return wsedc.wsWriter.WriteBuffer(buffer) } func (wsedc *websocketWithEarlyDataConn) Read(b []byte) (int, error) { if wsedc.closed { return 0, io.ErrClosedPipe } if wsedc.Conn == nil { select { case <-wsedc.ctx.Done(): return 0, io.ErrUnexpectedEOF case <-wsedc.dialed: } } return wsedc.Conn.Read(b) } func (wsedc *websocketWithEarlyDataConn) Close() error { wsedc.closed = true wsedc.cancel() if wsedc.Conn == nil { return nil } return wsedc.Conn.Close() } func (wsedc *websocketWithEarlyDataConn) LocalAddr() net.Addr { if wsedc.Conn == nil { return wsedc.underlay.LocalAddr() } return wsedc.Conn.LocalAddr() } func (wsedc *websocketWithEarlyDataConn) RemoteAddr() net.Addr { if wsedc.Conn == nil { return wsedc.underlay.RemoteAddr() } return wsedc.Conn.RemoteAddr() } func (wsedc *websocketWithEarlyDataConn) SetDeadline(t time.Time) error { if err := wsedc.SetReadDeadline(t); err != nil { return err } return wsedc.SetWriteDeadline(t) } func (wsedc *websocketWithEarlyDataConn) SetReadDeadline(t time.Time) error { if wsedc.Conn == nil { return nil } return wsedc.Conn.SetReadDeadline(t) } func (wsedc *websocketWithEarlyDataConn) SetWriteDeadline(t time.Time) error { if wsedc.Conn == nil { return nil } return wsedc.Conn.SetWriteDeadline(t) } func (wsedc *websocketWithEarlyDataConn) FrontHeadroom() int { return 14 } func (wsedc *websocketWithEarlyDataConn) Upstream() any { return wsedc.underlay } //func (wsedc *websocketWithEarlyDataConn) LazyHeadroom() bool { // return wsedc.Conn == nil //} // //func (wsedc *websocketWithEarlyDataConn) Upstream() any { // if wsedc.Conn == nil { // ensure return a nil interface not an interface with nil value // return nil // } // return wsedc.Conn //} func (wsedc *websocketWithEarlyDataConn) NeedHandshake() bool { return wsedc.Conn == nil } func streamWebsocketWithEarlyDataConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) { ctx, cancel := context.WithCancel(context.Background()) conn = &websocketWithEarlyDataConn{ dialed: make(chan bool, 1), cancel: cancel, ctx: ctx, underlay: conn, config: c, } // websocketWithEarlyDataConn can't correct handle Deadline // it will not apply the already set Deadline after Dial() // so call N.NewDeadlineConn to add a safe wrapper return N.NewDeadlineConn(conn), nil } func streamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig, earlyData *bytes.Buffer) (net.Conn, error) { u, err := url.Parse(c.Path) if err != nil { return nil, fmt.Errorf("parse url %s error: %w", c.Path, err) } uri := url.URL{ Scheme: "ws", Host: net.JoinHostPort(c.Host, c.Port), Path: u.Path, RawQuery: u.RawQuery, } if c.TLS { uri.Scheme = "wss" config := c.TLSConfig if config == nil { // The config cannot be nil config = &tls.Config{NextProtos: []string{"http/1.1"}} } if config.ServerName == "" && !config.InsecureSkipVerify { // users must set either ServerName or InsecureSkipVerify in the config. config = config.Clone() config.ServerName = uri.Host } if len(c.ClientFingerprint) != 0 { if fingerprint, exists := tlsC.GetFingerprint(c.ClientFingerprint); exists { utlsConn := tlsC.UClient(conn, config, fingerprint) if err = utlsConn.BuildWebsocketHandshakeState(); err != nil { return nil, fmt.Errorf("parse url %s error: %w", c.Path, err) } conn = utlsConn } } else { conn = tls.Client(conn, config) } if tlsConn, ok := conn.(interface { HandshakeContext(ctx context.Context) error }); ok { if err = tlsConn.HandshakeContext(ctx); err != nil { return nil, err } } } request := &http.Request{ Method: http.MethodGet, URL: &uri, Header: c.Headers.Clone(), Host: c.Host, } request.Header.Set("Connection", "Upgrade") request.Header.Set("Upgrade", "websocket") if host := request.Header.Get("Host"); host != "" { // For client requests, Host optionally overrides the Host // header to send. If empty, the Request.Write method uses // the value of URL.Host. Host may contain an international // domain name. request.Host = host } request.Header.Del("Host") var nonce string if !c.V2rayHttpUpgrade { const nonceKeySize = 16 // NOTE: bts does not escape. bts := make([]byte, nonceKeySize) if _, err = fastrand.Read(bts); err != nil { return nil, fmt.Errorf("rand read error: %w", err) } nonce = base64.StdEncoding.EncodeToString(bts) request.Header.Set("Sec-WebSocket-Version", "13") request.Header.Set("Sec-WebSocket-Key", nonce) } if earlyData != nil { earlyDataString := earlyData.String() if c.EarlyDataHeaderName == "" { uri.Path += earlyDataString } else { request.Header.Set(c.EarlyDataHeaderName, earlyDataString) } } if ctx.Done() != nil { done := N.SetupContextForConn(ctx, conn) defer done(&err) } err = request.Write(conn) if err != nil { return nil, err } bufferedConn := N.NewBufferedConn(conn) response, err := http.ReadResponse(bufferedConn.Reader(), request) if err != nil { return nil, err } if response.StatusCode != http.StatusSwitchingProtocols || !strings.EqualFold(response.Header.Get("Connection"), "upgrade") || !strings.EqualFold(response.Header.Get("Upgrade"), "websocket") { return nil, fmt.Errorf("unexpected status: %s", response.Status) } if c.V2rayHttpUpgrade { return bufferedConn, nil } if log.Level() == log.DEBUG { // we might not check this for performance secAccept := response.Header.Get("Sec-Websocket-Accept") const acceptSize = 28 // base64.StdEncoding.EncodedLen(sha1.Size) if lenSecAccept := len(secAccept); lenSecAccept != acceptSize { return nil, fmt.Errorf("unexpected Sec-Websocket-Accept length: %d", lenSecAccept) } const magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" const nonceSize = 24 // base64.StdEncoding.EncodedLen(nonceKeySize) p := make([]byte, nonceSize+len(magic)) copy(p[:nonceSize], nonce) copy(p[nonceSize:], magic) sum := sha1.Sum(p) if accept := base64.StdEncoding.EncodeToString(sum[:]); accept != secAccept { return nil, errors.New("unexpected Sec-Websocket-Accept") } } conn = newWebsocketConn(conn, ws.StateClientSide) // websocketConn can't correct handle ReadDeadline // so call N.NewDeadlineConn to add a safe wrapper return N.NewDeadlineConn(conn), nil } func StreamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig) (net.Conn, error) { if u, err := url.Parse(c.Path); err == nil { if q := u.Query(); q.Get("ed") != "" { if ed, err := strconv.Atoi(q.Get("ed")); err == nil { c.MaxEarlyData = ed c.EarlyDataHeaderName = "Sec-WebSocket-Protocol" q.Del("ed") u.RawQuery = q.Encode() c.Path = u.String() } } } if c.MaxEarlyData > 0 { return streamWebsocketWithEarlyDataConn(conn, c) } return streamWebsocketConn(ctx, conn, c, nil) } func newWebsocketConn(conn net.Conn, state ws.State) *websocketConn { controlHandler := wsutil.ControlFrameHandler(conn, state) return &websocketConn{ Conn: conn, state: state, reader: &wsutil.Reader{ Source: conn, State: state, SkipHeaderCheck: true, CheckUTF8: false, OnIntermediate: controlHandler, }, controlHandler: controlHandler, rawWriter: N.NewExtendedWriter(conn), } } var replacer = strings.NewReplacer("+", "-", "/", "_", "=", "") func decodeEd(s string) ([]byte, error) { return base64.RawURLEncoding.DecodeString(replacer.Replace(s)) } func decodeXray0rtt(requestHeader http.Header) []byte { // read inHeader's `Sec-WebSocket-Protocol` for Xray's 0rtt ws if secProtocol := requestHeader.Get("Sec-WebSocket-Protocol"); len(secProtocol) > 0 { if edBuf, err := decodeEd(secProtocol); err == nil { // sure could base64 decode return edBuf } } return nil } func StreamUpgradedWebsocketConn(w http.ResponseWriter, r *http.Request) (net.Conn, error) { wsConn, rw, _, err := ws.UpgradeHTTP(r, w) if err != nil { return nil, err } // gobwas/ws will flush rw.Writer, so we only need warp rw.Reader wsConn = N.WarpConnWithBioReader(wsConn, rw.Reader) conn := newWebsocketConn(wsConn, ws.StateServerSide) if edBuf := decodeXray0rtt(r.Header); len(edBuf) > 0 { return N.NewDeadlineConn(&websocketWithReaderConn{conn, io.MultiReader(bytes.NewReader(edBuf), conn)}), nil } return N.NewDeadlineConn(conn), nil } type websocketWithReaderConn struct { *websocketConn reader io.Reader } func (ws *websocketWithReaderConn) Read(b []byte) (n int, err error) { return ws.reader.Read(b) }