Optimization: refactor picker

This commit is contained in:
Dreamacro 2019-07-02 19:18:03 +08:00
parent 0eff8516c0
commit 7c6c147a18
9 changed files with 123 additions and 104 deletions

View File

@ -1,6 +1,7 @@
package adapters package adapters
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"net" "net"
@ -99,7 +100,7 @@ func (p *Proxy) MarshalJSON() ([]byte, error) {
} }
// URLTest get the delay for the specified URL // URLTest get the delay for the specified URL
func (p *Proxy) URLTest(url string) (t uint16, err error) { func (p *Proxy) URLTest(ctx context.Context, url string) (t uint16, err error) {
defer func() { defer func() {
p.alive = err == nil p.alive = err == nil
record := C.DelayHistory{Time: time.Now()} record := C.DelayHistory{Time: time.Now()}
@ -123,6 +124,13 @@ func (p *Proxy) URLTest(url string) (t uint16, err error) {
return return
} }
defer instance.Close() defer instance.Close()
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return
}
req = req.WithContext(ctx)
transport := &http.Transport{ transport := &http.Transport{
Dial: func(string, string) (net.Conn, error) { Dial: func(string, string) (net.Conn, error) {
return instance, nil return instance, nil
@ -133,8 +141,9 @@ func (p *Proxy) URLTest(url string) (t uint16, err error) {
TLSHandshakeTimeout: 10 * time.Second, TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second, ExpectContinueTimeout: 1 * time.Second,
} }
client := http.Client{Transport: transport} client := http.Client{Transport: transport}
resp, err := client.Get(url) resp, err := client.Do(req)
if err != nil { if err != nil {
return return
} }

View File

@ -1,6 +1,7 @@
package adapters package adapters
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"net" "net"
@ -90,7 +91,7 @@ func (f *Fallback) validTest() {
for _, p := range f.proxies { for _, p := range f.proxies {
go func(p C.Proxy) { go func(p C.Proxy) {
p.URLTest(f.rawURL) p.URLTest(context.Background(), f.rawURL)
wg.Done() wg.Done()
}(p) }(p)
} }

View File

@ -1,6 +1,7 @@
package adapters package adapters
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"net" "net"
@ -95,7 +96,7 @@ func (lb *LoadBalance) validTest() {
for _, p := range lb.proxies { for _, p := range lb.proxies {
go func(p C.Proxy) { go func(p C.Proxy) {
p.URLTest(lb.rawURL) p.URLTest(context.Background(), lb.rawURL)
wg.Done() wg.Done()
}(p) }(p)
} }

View File

@ -5,7 +5,6 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"net" "net"
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -103,35 +102,22 @@ func (u *URLTest) speedTest() {
} }
defer atomic.StoreInt32(&u.once, 0) defer atomic.StoreInt32(&u.once, 0)
wg := sync.WaitGroup{} ctx, cancel := context.WithTimeout(context.Background(), u.interval)
wg.Add(len(u.proxies)) defer cancel()
c := make(chan interface{}) picker, ctx := picker.WithContext(ctx)
fast := picker.SelectFast(context.Background(), c)
timer := time.NewTimer(u.interval)
for _, p := range u.proxies { for _, p := range u.proxies {
go func(p C.Proxy) { picker.Go(func() (interface{}, error) {
_, err := p.URLTest(u.rawURL) _, err := p.URLTest(ctx, u.rawURL)
if err == nil { if err != nil {
c <- p return nil, err
} }
wg.Done() return p, nil
}(p) })
} }
go func() { fast := picker.Wait()
wg.Wait() if fast != nil {
close(c) u.fast = fast.(C.Proxy)
}()
select {
case <-timer.C:
// Wait for fast to return or close.
<-fast
case p, open := <-fast:
if open {
u.fast = p.(C.Proxy)
}
} }
} }

View File

@ -1,22 +1,53 @@
package picker package picker
import "context" import (
"context"
"sync"
)
// Picker provides synchronization, and Context cancelation
// for groups of goroutines working on subtasks of a common task.
// Inspired by errGroup
type Picker struct {
cancel func()
wg sync.WaitGroup
once sync.Once
result interface{}
}
// WithContext returns a new Picker and an associated Context derived from ctx.
func WithContext(ctx context.Context) (*Picker, context.Context) {
ctx, cancel := context.WithCancel(ctx)
return &Picker{cancel: cancel}, ctx
}
// Wait blocks until all function calls from the Go method have returned,
// then returns the first nil error result (if any) from them.
func (p *Picker) Wait() interface{} {
p.wg.Wait()
if p.cancel != nil {
p.cancel()
}
return p.result
}
// Go calls the given function in a new goroutine.
// The first call to return a nil error cancels the group; its result will be returned by Wait.
func (p *Picker) Go(f func() (interface{}, error)) {
p.wg.Add(1)
func SelectFast(ctx context.Context, in <-chan interface{}) <-chan interface{} {
out := make(chan interface{})
go func() { go func() {
select { defer p.wg.Done()
case p, open := <-in:
if open {
out <- p
}
case <-ctx.Done():
}
close(out) if ret, err := f(); err == nil {
for range in { p.once.Do(func() {
p.result = ret
if p.cancel != nil {
p.cancel()
}
})
} }
}() }()
return out
} }

View File

@ -6,39 +6,37 @@ import (
"time" "time"
) )
func sleepAndSend(delay int, in chan<- interface{}, input interface{}) { func sleepAndSend(ctx context.Context, delay int, input interface{}) func() (interface{}, error) {
time.Sleep(time.Millisecond * time.Duration(delay)) return func() (interface{}, error) {
in <- input timer := time.NewTimer(time.Millisecond * time.Duration(delay))
} select {
case <-timer.C:
func sleepAndClose(delay int, in chan interface{}) { return input, nil
time.Sleep(time.Millisecond * time.Duration(delay)) case <-ctx.Done():
close(in) return nil, ctx.Err()
}
}
} }
func TestPicker_Basic(t *testing.T) { func TestPicker_Basic(t *testing.T) {
in := make(chan interface{}) picker, ctx := WithContext(context.Background())
fast := SelectFast(context.Background(), in) picker.Go(sleepAndSend(ctx, 30, 2))
go sleepAndSend(20, in, 1) picker.Go(sleepAndSend(ctx, 20, 1))
go sleepAndSend(30, in, 2)
go sleepAndClose(40, in)
number, exist := <-fast number := picker.Wait()
if !exist || number != 1 { if number != nil && number.(int) != 1 {
t.Error("should recv 1", exist, number) t.Error("should recv 1", number)
} }
} }
func TestPicker_Timeout(t *testing.T) { func TestPicker_Timeout(t *testing.T) {
in := make(chan interface{})
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*5) ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*5)
defer cancel() defer cancel()
fast := SelectFast(ctx, in) picker, ctx := WithContext(ctx)
go sleepAndSend(20, in, 1) picker.Go(sleepAndSend(ctx, 20, 1))
go sleepAndClose(30, in)
_, exist := <-fast number := picker.Wait()
if exist { if number != nil {
t.Error("should recv false") t.Error("should recv nil")
} }
} }

View File

@ -1,6 +1,7 @@
package constant package constant
import ( import (
"context"
"net" "net"
"time" "time"
) )
@ -44,7 +45,7 @@ type Proxy interface {
Alive() bool Alive() bool
DelayHistory() []DelayHistory DelayHistory() []DelayHistory
LastDelay() uint16 LastDelay() uint16
URLTest(url string) (uint16, error) URLTest(ctx context.Context, url string) (uint16, error)
} }
// AdapterType is enum of adapter type // AdapterType is enum of adapter type

View File

@ -163,32 +163,22 @@ func (r *Resolver) IsFakeIP() bool {
} }
func (r *Resolver) batchExchange(clients []resolver, m *D.Msg) (msg *D.Msg, err error) { func (r *Resolver) batchExchange(clients []resolver, m *D.Msg) (msg *D.Msg, err error) {
in := make(chan interface{})
ctx, cancel := context.WithTimeout(context.Background(), time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel() defer cancel()
fast := picker.SelectFast(ctx, in) fast, ctx := picker.WithContext(ctx)
wg := sync.WaitGroup{}
wg.Add(len(clients))
for _, r := range clients { for _, r := range clients {
go func(r resolver) { fast.Go(func() (interface{}, error) {
defer wg.Done()
msg, err := r.ExchangeContext(ctx, m) msg, err := r.ExchangeContext(ctx, m)
if err != nil || msg.Rcode != D.RcodeSuccess { if err != nil || msg.Rcode != D.RcodeSuccess {
return return nil, errors.New("resolve error")
} }
in <- msg return msg, nil
}(r) })
} }
// release in channel elm := fast.Wait()
go func() { if elm == nil {
wg.Wait()
close(in)
}()
elm, exist := <-fast
if !exist {
return nil, errors.New("All DNS requests failed") return nil, errors.New("All DNS requests failed")
} }

View File

@ -9,6 +9,7 @@ import (
"time" "time"
A "github.com/Dreamacro/clash/adapters/outbound" A "github.com/Dreamacro/clash/adapters/outbound"
"github.com/Dreamacro/clash/common/picker"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
T "github.com/Dreamacro/clash/tunnel" T "github.com/Dreamacro/clash/tunnel"
@ -110,27 +111,28 @@ func getProxyDelay(w http.ResponseWriter, r *http.Request) {
proxy := r.Context().Value(CtxKeyProxy).(C.Proxy) proxy := r.Context().Value(CtxKeyProxy).(C.Proxy)
sigCh := make(chan uint16) ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*time.Duration(timeout))
go func() { defer cancel()
t, err := proxy.URLTest(url) picker, ctx := picker.WithContext(ctx)
if err != nil { picker.Go(func() (interface{}, error) {
sigCh <- 0 return proxy.URLTest(ctx, url)
} })
sigCh <- t
}()
select { elm := picker.Wait()
case <-time.After(time.Millisecond * time.Duration(timeout)): if elm == nil {
render.Status(r, http.StatusRequestTimeout) render.Status(r, http.StatusRequestTimeout)
render.JSON(w, r, ErrRequestTimeout) render.JSON(w, r, ErrRequestTimeout)
case t := <-sigCh: return
if t == 0 {
render.Status(r, http.StatusServiceUnavailable)
render.JSON(w, r, newError("An error occurred in the delay test"))
} else {
render.JSON(w, r, render.M{
"delay": t,
})
}
} }
delay := elm.(uint16)
if delay == 0 {
render.Status(r, http.StatusServiceUnavailable)
render.JSON(w, r, newError("An error occurred in the delay test"))
return
}
render.JSON(w, r, render.M{
"delay": delay,
})
} }