chore: read waiter for pipe

This commit is contained in:
wwqgtxx 2024-01-02 18:26:45 +08:00
parent 0404e35be8
commit 33bc7914e9
9 changed files with 229 additions and 8 deletions

View File

@ -0,0 +1,217 @@
package deadline
import (
"io"
"net"
"os"
"sync"
"time"
"github.com/sagernet/sing/common/buf"
N "github.com/sagernet/sing/common/network"
)
type pipeAddr struct{}
func (pipeAddr) Network() string { return "pipe" }
func (pipeAddr) String() string { return "pipe" }
type pipe struct {
wrMu sync.Mutex // Serialize Write operations
// Used by local Read to interact with remote Write.
// Successful receive on rdRx is always followed by send on rdTx.
rdRx <-chan []byte
rdTx chan<- int
// Used by local Write to interact with remote Read.
// Successful send on wrTx is always followed by receive on wrRx.
wrTx chan<- []byte
wrRx <-chan int
once sync.Once // Protects closing localDone
localDone chan struct{}
remoteDone <-chan struct{}
readDeadline pipeDeadline
writeDeadline pipeDeadline
readWaitOptions N.ReadWaitOptions
}
// Pipe creates a synchronous, in-memory, full duplex
// network connection; both ends implement the Conn interface.
// Reads on one end are matched with writes on the other,
// copying data directly between the two; there is no internal
// buffering.
func Pipe() (net.Conn, net.Conn) {
cb1 := make(chan []byte)
cb2 := make(chan []byte)
cn1 := make(chan int)
cn2 := make(chan int)
done1 := make(chan struct{})
done2 := make(chan struct{})
p1 := &pipe{
rdRx: cb1, rdTx: cn1,
wrTx: cb2, wrRx: cn2,
localDone: done1, remoteDone: done2,
readDeadline: makePipeDeadline(),
writeDeadline: makePipeDeadline(),
}
p2 := &pipe{
rdRx: cb2, rdTx: cn2,
wrTx: cb1, wrRx: cn1,
localDone: done2, remoteDone: done1,
readDeadline: makePipeDeadline(),
writeDeadline: makePipeDeadline(),
}
return p1, p2
}
func (*pipe) LocalAddr() net.Addr { return pipeAddr{} }
func (*pipe) RemoteAddr() net.Addr { return pipeAddr{} }
func (p *pipe) Read(b []byte) (int, error) {
n, err := p.read(b)
if err != nil && err != io.EOF && err != io.ErrClosedPipe {
err = &net.OpError{Op: "read", Net: "pipe", Err: err}
}
return n, err
}
func (p *pipe) read(b []byte) (n int, err error) {
switch {
case isClosedChan(p.localDone):
return 0, io.ErrClosedPipe
case isClosedChan(p.remoteDone):
return 0, io.EOF
case isClosedChan(p.readDeadline.wait()):
return 0, os.ErrDeadlineExceeded
}
select {
case bw := <-p.rdRx:
nr := copy(b, bw)
p.rdTx <- nr
return nr, nil
case <-p.localDone:
return 0, io.ErrClosedPipe
case <-p.remoteDone:
return 0, io.EOF
case <-p.readDeadline.wait():
return 0, os.ErrDeadlineExceeded
}
}
func (p *pipe) Write(b []byte) (int, error) {
n, err := p.write(b)
if err != nil && err != io.ErrClosedPipe {
err = &net.OpError{Op: "write", Net: "pipe", Err: err}
}
return n, err
}
func (p *pipe) write(b []byte) (n int, err error) {
switch {
case isClosedChan(p.localDone):
return 0, io.ErrClosedPipe
case isClosedChan(p.remoteDone):
return 0, io.ErrClosedPipe
case isClosedChan(p.writeDeadline.wait()):
return 0, os.ErrDeadlineExceeded
}
p.wrMu.Lock() // Ensure entirety of b is written together
defer p.wrMu.Unlock()
for once := true; once || len(b) > 0; once = false {
select {
case p.wrTx <- b:
nw := <-p.wrRx
b = b[nw:]
n += nw
case <-p.localDone:
return n, io.ErrClosedPipe
case <-p.remoteDone:
return n, io.ErrClosedPipe
case <-p.writeDeadline.wait():
return n, os.ErrDeadlineExceeded
}
}
return n, nil
}
func (p *pipe) SetDeadline(t time.Time) error {
if isClosedChan(p.localDone) || isClosedChan(p.remoteDone) {
return io.ErrClosedPipe
}
p.readDeadline.set(t)
p.writeDeadline.set(t)
return nil
}
func (p *pipe) SetReadDeadline(t time.Time) error {
if isClosedChan(p.localDone) || isClosedChan(p.remoteDone) {
return io.ErrClosedPipe
}
p.readDeadline.set(t)
return nil
}
func (p *pipe) SetWriteDeadline(t time.Time) error {
if isClosedChan(p.localDone) || isClosedChan(p.remoteDone) {
return io.ErrClosedPipe
}
p.writeDeadline.set(t)
return nil
}
func (p *pipe) Close() error {
p.once.Do(func() { close(p.localDone) })
return nil
}
var _ N.ReadWaiter = (*pipe)(nil)
func (p *pipe) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
p.readWaitOptions = options
return false
}
func (p *pipe) WaitReadBuffer() (buffer *buf.Buffer, err error) {
buffer, err = p.waitReadBuffer()
if err != nil && err != io.EOF && err != io.ErrClosedPipe {
err = &net.OpError{Op: "read", Net: "pipe", Err: err}
}
return
}
func (p *pipe) waitReadBuffer() (buffer *buf.Buffer, err error) {
switch {
case isClosedChan(p.localDone):
return nil, io.ErrClosedPipe
case isClosedChan(p.remoteDone):
return nil, io.EOF
case isClosedChan(p.readDeadline.wait()):
return nil, os.ErrDeadlineExceeded
}
select {
case bw := <-p.rdRx:
buffer = p.readWaitOptions.NewBuffer()
var nr int
nr, err = buffer.Write(bw)
if err != nil {
buffer.Release()
return
}
p.readWaitOptions.PostReturn(buffer)
p.rdTx <- nr
return
case <-p.localDone:
return nil, io.ErrClosedPipe
case <-p.remoteDone:
return nil, io.EOF
case <-p.readDeadline.wait():
return nil, os.ErrDeadlineExceeded
}
}

View File

@ -35,6 +35,8 @@ func NeedHandshake(conn any) bool {
type CountFunc = network.CountFunc type CountFunc = network.CountFunc
var Pipe = deadline.Pipe
// Relay copies between left and right bidirectionally. // Relay copies between left and right bidirectionally.
func Relay(leftConn, rightConn net.Conn) { func Relay(leftConn, rightConn net.Conn) {
defer runtime.KeepAlive(leftConn) defer runtime.KeepAlive(leftConn)

View File

@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/metacubex/mihomo/adapter/inbound" "github.com/metacubex/mihomo/adapter/inbound"
N "github.com/metacubex/mihomo/common/net"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/transport/socks5" "github.com/metacubex/mihomo/transport/socks5"
) )
@ -30,7 +31,7 @@ func newClient(srcConn net.Conn, tunnel C.Tunnel, additions ...inbound.Addition)
return nil, socks5.ErrAddressNotSupported return nil, socks5.ErrAddressNotSupported
} }
left, right := net.Pipe() left, right := N.Pipe()
go tunnel.HandleTCPConn(inbound.NewHTTP(dstAddr, srcConn, right, additions...)) go tunnel.HandleTCPConn(inbound.NewHTTP(dstAddr, srcConn, right, additions...))

View File

@ -41,7 +41,7 @@ func handleUpgrade(conn net.Conn, request *http.Request, tunnel C.Tunnel, additi
return return
} }
left, right := net.Pipe() left, right := N.Pipe()
go tunnel.HandleTCPConn(inbound.NewHTTP(dstAddr, conn, right, additions...)) go tunnel.HandleTCPConn(inbound.NewHTTP(dstAddr, conn, right, additions...))

View File

@ -6,6 +6,7 @@ import (
"net/netip" "net/netip"
"strconv" "strconv"
N "github.com/metacubex/mihomo/common/net"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
) )
@ -20,7 +21,7 @@ func HandleTcp(address string) (conn net.Conn, err error) {
return nil, errors.New("tcp uninitialized") return nil, errors.New("tcp uninitialized")
} }
// executor Parsed // executor Parsed
conn1, conn2 := net.Pipe() conn1, conn2 := N.Pipe()
metadata := &C.Metadata{} metadata := &C.Metadata{}
metadata.NetWork = C.TCP metadata.NetWork = C.TCP

View File

@ -364,7 +364,7 @@ func (t *clientImpl) ListenPacketWithDialer(ctx context.Context, metadata *C.Met
return nil, common.TooManyOpenStreams return nil, common.TooManyOpenStreams
} }
pipe1, pipe2 := net.Pipe() pipe1, pipe2 := N.Pipe()
var connId uint32 var connId uint32
for { for {
connId = fastrand.Uint32() connId = fastrand.Uint32()

View File

@ -348,7 +348,7 @@ func (t *clientImpl) ListenPacketWithDialer(ctx context.Context, metadata *C.Met
return nil, common.TooManyOpenStreams return nil, common.TooManyOpenStreams
} }
pipe1, pipe2 := net.Pipe() pipe1, pipe2 := N.Pipe()
var connId uint16 var connId uint16
for { for {
connId = uint16(fastrand.Intn(0xFFFF)) connId = uint16(fastrand.Intn(0xFFFF))

View File

@ -82,6 +82,6 @@ func closeAllLocalCoon(lAddr string) {
}) })
} }
func handleSocket(ctx C.ConnContext, outbound net.Conn) { func handleSocket(inbound, outbound net.Conn) {
N.Relay(ctx.Conn(), outbound) N.Relay(inbound, outbound)
} }

View File

@ -584,7 +584,7 @@ func handleTCPConn(connCtx C.ConnContext) {
peekMutex.Lock() peekMutex.Lock()
defer peekMutex.Unlock() defer peekMutex.Unlock()
_ = conn.SetReadDeadline(time.Time{}) // reset _ = conn.SetReadDeadline(time.Time{}) // reset
handleSocket(connCtx, remoteConn) handleSocket(conn, remoteConn)
} }
func shouldResolveIP(rule C.Rule, metadata *C.Metadata) bool { func shouldResolveIP(rule C.Rule, metadata *C.Metadata) bool {