diff --git a/component/trie/set_test.go b/component/trie/set_test.go index 6672d885..62854a14 100644 --- a/component/trie/set_test.go +++ b/component/trie/set_test.go @@ -27,6 +27,7 @@ func TestDomainComplexWildcard(t *testing.T) { "+.baidu.com", "+.a.baidu.com", "www.baidu.com", + "+.bb.baidu.com", "test.a.net", "test.a.oc", "www.qq.com", @@ -40,8 +41,8 @@ func TestDomainComplexWildcard(t *testing.T) { func TestDomainWildcard(t *testing.T) { domainSet := []string{ - "*.baidu.com", - "www.baidu.com", + "*.*.*.baidu.com", + "www.baidu.*", "*.*.qq.com", "test.*.baidu.com", } diff --git a/component/trie/sskv.go b/component/trie/sskv.go index 268f6c18..6a661a85 100644 --- a/component/trie/sskv.go +++ b/component/trie/sskv.go @@ -77,7 +77,6 @@ func (ss *DomainSet) Has(key string) bool { if ss == nil { return false } - key = strings.TrimSpace(key) key = utils.Reverse(key) key = strings.ToLower(key) // no more labels in this node @@ -86,15 +85,16 @@ func (ss *DomainSet) Has(key string) bool { nodeId, bmIdx := 0, 0 type wildcardCursor struct { bmIdx, index int - find bool } - cursor := wildcardCursor{} + stack := make([]wildcardCursor, 0) for i := 0; i < len(key); i++ { RESTART: c := key[i] for ; ; bmIdx++ { if getBit(ss.labelBitmap, bmIdx) != 0 { - if cursor.find { + if len(stack) > 0 { + cursor := stack[len(stack)-1] + stack = stack[0 : len(stack)-1] // back wildcard and find next node nextNodeId := countZeros(ss.labelBitmap, ss.ranks, cursor.bmIdx+1) nextBmIdx := selectIthOne(ss.labelBitmap, ss.ranks, ss.selects, nextNodeId-1) + 1 @@ -102,14 +102,17 @@ func (ss *DomainSet) Has(key string) bool { for ; j < len(key) && key[j] != domainStepByte; j++ { } if j == len(key) { - return getBit(ss.leaves, nextNodeId) != 0 + if getBit(ss.leaves, nextNodeId) != 0 { + return true + }else { + goto RESTART + } } for ; ; nextBmIdx++ { if ss.labels[nextBmIdx-nextNodeId] == domainStepByte { bmIdx = nextBmIdx nodeId = nextNodeId i = j - cursor.find = false goto RESTART } } @@ -120,13 +123,11 @@ func (ss *DomainSet) Has(key string) bool { if ss.labels[bmIdx-nodeId] == complexWildcardByte { return true } else if ss.labels[bmIdx-nodeId] == wildcardByte { - cursor.find = true + cursor := wildcardCursor{} cursor.bmIdx = bmIdx cursor.index = i + stack = append(stack, cursor) } else if ss.labels[bmIdx-nodeId] == c { - if ss.labels[bmIdx-nodeId] == domainStepByte { - cursor.find = false - } break } }