chore: add unit test and adjust logic error

This commit is contained in:
Skyxim 2023-03-30 13:36:14 +08:00
parent e0cf342672
commit 6c9c0bd755
2 changed files with 92 additions and 11 deletions

View File

@ -0,0 +1,49 @@
package trie_test
import (
"testing"
"github.com/Dreamacro/clash/component/trie"
"github.com/stretchr/testify/assert"
)
func TestDomain(t *testing.T) {
domainSet := []string{
"baidu.com",
"google.com",
"www.google.com",
}
set := trie.NewDomainSet(domainSet)
assert.NotNil(t, set)
assert.True(t, set.Has("google.com"))
assert.False(t, set.Has("www.baidu.com"))
}
func TestDomainComplexWildcard(t *testing.T) {
domainSet := []string{
"+.baidu.com",
"+.a.baidu.com",
"www.baidu.com",
"www.qq.com",
}
set := trie.NewDomainSet(domainSet)
assert.NotNil(t, set)
assert.False(t, set.Has("google.com"))
assert.True(t, set.Has("www.baidu.com"))
assert.True(t, set.Has("test.test.baidu.com"))
}
func TestDomainWildcard(t *testing.T) {
domainSet := []string{
"*.baidu.com",
"www.baidu.com",
"*.*.qq.com",
}
set := trie.NewDomainSet(domainSet)
assert.NotNil(t, set)
// assert.True(t, set.Has("www.baidu.com"))
// assert.False(t, set.Has("test.test.baidu.com"))
assert.True(t,set.Has("test.test.qq.com"))
assert.False(t,set.Has("test.qq.com"))
assert.False(t,set.Has("test.test.test.qq.com"))
}

View File

@ -27,11 +27,20 @@ type DomainSet struct {
func NewDomainSet(keys []string) *DomainSet { func NewDomainSet(keys []string) *DomainSet {
filter := make(map[string]struct{}, len(keys)) filter := make(map[string]struct{}, len(keys))
reserveDomains := make([]string, 0, len(keys)) reserveDomains := make([]string, 0, len(keys))
insert := func(reserveDomain string) { insert := func(domain string) {
reserveDomain = utils.Reverse(reserveDomain) reserveDomain := utils.Reverse(domain)
reserveDomain = strings.ToLower(reserveDomain) reserveDomain = strings.ToLower(reserveDomain)
if _, ok := filter[reserveDomain]; !ok { if _, ok := filter[reserveDomain]; !ok {
filter[reserveDomain] = struct{}{} filter[reserveDomain] = struct{}{}
domains := make([]string, 0, len(reserveDomains))
if strings.HasSuffix(reserveDomain, ".+") {
for _, domain := range reserveDomains {
if !strings.HasPrefix(domain, reserveDomain[0:len(reserveDomain)-2]) {
domains = append(domains, domain)
}
}
reserveDomains = domains
}
reserveDomains = append(reserveDomains, reserveDomain) reserveDomains = append(reserveDomains, reserveDomain)
} }
} }
@ -98,30 +107,53 @@ func (ss *DomainSet) Has(key string) bool {
// skip character matching // skip character matching
// go to next level // go to next level
nodeId, bmIdx := 0, 0 nodeId, bmIdx := 0, 0
type wildcardCursor struct {
index, bmIdx int
find bool
}
cursor := wildcardCursor{
find: false,
}
for i := 0; i < len(key); i++ { for i := 0; i < len(key); i++ {
c := key[i] c := key[i]
for ; ; bmIdx++ { for ; ; bmIdx++ {
if getBit(ss.labelBitmap, bmIdx) != 0 { if getBit(ss.labelBitmap, bmIdx) != 0 {
if cursor.find {
// gets the node next to the cursor
wildcardNextNodeId := countZeros(ss.labelBitmap, ss.ranks, cursor.bmIdx+1)
// next is a leaf, and the domain name has no next level
if getBit(ss.leaves, wildcardNextNodeId) != 0 && cursor.index == len(key) {
return true
}
// reset, and jump to the cursor location
cursor.find = false
i = cursor.index
bmIdx = cursor.bmIdx
break
}
return false return false
} }
// handle wildcard for domain // handle wildcard for domain
if ss.labels[bmIdx-nodeId] == complexWildcardByte { if ss.labels[bmIdx-nodeId] == complexWildcardByte {
return true return true
} else if ss.labels[bmIdx-nodeId] == wildcardByte { } else if ss.labels[bmIdx-nodeId] == wildcardByte {
j := i cursor.find = true
for ; j < len(key); j++ { cursor.bmIdx = bmIdx
if key[j] == domainStepByte { // gets the first domain step that follows
i = j // If not, it is the last domain level, which is represented by len(key)
goto END if index := strings.Index(key[i:], domainStep); index > 0 {
cursor.index = index + i - 1
} else {
cursor.index = len(key)
} }
} break
return true
} else if ss.labels[bmIdx-nodeId] == c { } else if ss.labels[bmIdx-nodeId] == c {
if ss.labels[bmIdx-nodeId] == domainStepByte {
cursor.find = false
}
break break
} }
} }
END:
nodeId = countZeros(ss.labelBitmap, ss.ranks, bmIdx+1) nodeId = countZeros(ss.labelBitmap, ss.ranks, bmIdx+1)
bmIdx = selectIthOne(ss.labelBitmap, ss.ranks, ss.selects, nodeId-1) + 1 bmIdx = selectIthOne(ss.labelBitmap, ss.ranks, ss.selects, nodeId-1) + 1
} }