diff --git a/component/trie/set_test.go b/component/trie/set_test.go new file mode 100644 index 00000000..b3bb46e0 --- /dev/null +++ b/component/trie/set_test.go @@ -0,0 +1,49 @@ +package trie_test + +import ( + "testing" + + "github.com/Dreamacro/clash/component/trie" + "github.com/stretchr/testify/assert" +) + +func TestDomain(t *testing.T) { + domainSet := []string{ + "baidu.com", + "google.com", + "www.google.com", + } + set := trie.NewDomainSet(domainSet) + assert.NotNil(t, set) + assert.True(t, set.Has("google.com")) + assert.False(t, set.Has("www.baidu.com")) +} + +func TestDomainComplexWildcard(t *testing.T) { + domainSet := []string{ + "+.baidu.com", + "+.a.baidu.com", + "www.baidu.com", + "www.qq.com", + } + set := trie.NewDomainSet(domainSet) + assert.NotNil(t, set) + assert.False(t, set.Has("google.com")) + assert.True(t, set.Has("www.baidu.com")) + assert.True(t, set.Has("test.test.baidu.com")) +} + +func TestDomainWildcard(t *testing.T) { + domainSet := []string{ + "*.baidu.com", + "www.baidu.com", + "*.*.qq.com", + } + set := trie.NewDomainSet(domainSet) + assert.NotNil(t, set) + // assert.True(t, set.Has("www.baidu.com")) + // assert.False(t, set.Has("test.test.baidu.com")) + assert.True(t,set.Has("test.test.qq.com")) + assert.False(t,set.Has("test.qq.com")) + assert.False(t,set.Has("test.test.test.qq.com")) +} diff --git a/component/trie/sskv.go b/component/trie/sskv.go index 231ff09a..410015a1 100644 --- a/component/trie/sskv.go +++ b/component/trie/sskv.go @@ -27,11 +27,20 @@ type DomainSet struct { func NewDomainSet(keys []string) *DomainSet { filter := make(map[string]struct{}, len(keys)) reserveDomains := make([]string, 0, len(keys)) - insert := func(reserveDomain string) { - reserveDomain = utils.Reverse(reserveDomain) + insert := func(domain string) { + reserveDomain := utils.Reverse(domain) reserveDomain = strings.ToLower(reserveDomain) if _, ok := filter[reserveDomain]; !ok { filter[reserveDomain] = struct{}{} + domains := make([]string, 0, len(reserveDomains)) + if strings.HasSuffix(reserveDomain, ".+") { + for _, domain := range reserveDomains { + if !strings.HasPrefix(domain, reserveDomain[0:len(reserveDomain)-2]) { + domains = append(domains, domain) + } + } + reserveDomains = domains + } reserveDomains = append(reserveDomains, reserveDomain) } } @@ -98,30 +107,53 @@ func (ss *DomainSet) Has(key string) bool { // skip character matching // go to next level nodeId, bmIdx := 0, 0 - + type wildcardCursor struct { + index, bmIdx int + find bool + } + cursor := wildcardCursor{ + find: false, + } for i := 0; i < len(key); i++ { c := key[i] for ; ; bmIdx++ { if getBit(ss.labelBitmap, bmIdx) != 0 { + if cursor.find { + // gets the node next to the cursor + wildcardNextNodeId := countZeros(ss.labelBitmap, ss.ranks, cursor.bmIdx+1) + // next is a leaf, and the domain name has no next level + if getBit(ss.leaves, wildcardNextNodeId) != 0 && cursor.index == len(key) { + return true + } + // reset, and jump to the cursor location + cursor.find = false + i = cursor.index + bmIdx = cursor.bmIdx + break + } return false } // handle wildcard for domain if ss.labels[bmIdx-nodeId] == complexWildcardByte { return true } else if ss.labels[bmIdx-nodeId] == wildcardByte { - j := i - for ; j < len(key); j++ { - if key[j] == domainStepByte { - i = j - goto END - } + cursor.find = true + cursor.bmIdx = bmIdx + // gets the first domain step that follows + // If not, it is the last domain level, which is represented by len(key) + if index := strings.Index(key[i:], domainStep); index > 0 { + cursor.index = index + i - 1 + } else { + cursor.index = len(key) } - return true + break } else if ss.labels[bmIdx-nodeId] == c { + if ss.labels[bmIdx-nodeId] == domainStepByte { + cursor.find = false + } break } } - END: nodeId = countZeros(ss.labelBitmap, ss.ranks, bmIdx+1) bmIdx = selectIthOne(ss.labelBitmap, ss.ranks, ss.selects, nodeId-1) + 1 }