mihomo/component/trie/domain.go

181 lines
3.6 KiB
Go

package trie
import (
"errors"
"strings"
"unicode"
"unicode/utf8"
)
const (
wildcard = "*"
dotWildcard = ""
complexWildcard = "+"
domainStep = "."
)
// ErrInvalidDomain means insert domain is invalid
var ErrInvalidDomain = errors.New("invalid domain")
// DomainTrie contains the main logic for adding and searching nodes for domain segments.
// support wildcard domain (e.g *.google.com)
type DomainTrie[T any] struct {
root *Node[T]
}
func ValidAndSplitDomain(domain string) ([]string, bool) {
if domain != "" && domain[len(domain)-1] == '.' {
return nil, false
}
if domain != "" {
if r, _ := utf8.DecodeRuneInString(domain); unicode.IsSpace(r) {
return nil, false
}
if r, _ := utf8.DecodeLastRuneInString(domain); unicode.IsSpace(r) {
return nil, false
}
}
domain = strings.ToLower(domain)
parts := strings.Split(domain, domainStep)
if len(parts) == 1 {
if parts[0] == "" {
return nil, false
}
return parts, true
}
for _, part := range parts[1:] {
if part == "" {
return nil, false
}
}
return parts, true
}
// Insert adds a node to the trie.
// Support
// 1. www.example.com
// 2. *.example.com
// 3. subdomain.*.example.com
// 4. .example.com
// 5. +.example.com
func (t *DomainTrie[T]) Insert(domain string, data T) error {
parts, valid := ValidAndSplitDomain(domain)
if !valid {
return ErrInvalidDomain
}
if parts[0] == complexWildcard {
t.insert(parts[1:], data)
parts[0] = dotWildcard
t.insert(parts, data)
} else {
t.insert(parts, data)
}
return nil
}
func (t *DomainTrie[T]) insert(parts []string, data T) {
node := t.root
// reverse storage domain part to save space
for i := len(parts) - 1; i >= 0; i-- {
part := parts[i]
node = node.getOrNewChild(part)
}
node.setData(data)
}
// Search is the most important part of the Trie.
// Priority as:
// 1. static part
// 2. wildcard domain
// 2. dot wildcard domain
func (t *DomainTrie[T]) Search(domain string) *Node[T] {
parts, valid := ValidAndSplitDomain(domain)
if !valid || parts[0] == "" {
return nil
}
n := t.search(t.root, parts)
if n.isEmpty() {
return nil
}
return n
}
func (t *DomainTrie[T]) search(node *Node[T], parts []string) *Node[T] {
if len(parts) == 0 {
return node
}
if c := node.getChild(parts[len(parts)-1]); c != nil {
if n := t.search(c, parts[:len(parts)-1]); !n.isEmpty() {
return n
}
}
if c := node.getChild(wildcard); c != nil {
if n := t.search(c, parts[:len(parts)-1]); !n.isEmpty() {
return n
}
}
return node.getChild(dotWildcard)
}
func (t *DomainTrie[T]) Optimize() {
t.root.optimize()
}
func (t *DomainTrie[T]) Foreach(fn func(domain string, data T) bool) {
for key, data := range t.root.getChildren() {
recursion([]string{key}, data, fn)
if !data.isEmpty() {
if !fn(joinDomain([]string{key}), data.data) {
return
}
}
}
}
func (t *DomainTrie[T]) IsEmpty() bool {
if t == nil || t.root == nil {
return true
}
return len(t.root.getChildren()) == 0
}
func recursion[T any](items []string, node *Node[T], fn func(domain string, data T) bool) bool {
for key, data := range node.getChildren() {
newItems := append([]string{key}, items...)
if !data.isEmpty() {
domain := joinDomain(newItems)
if domain[0] == domainStepByte {
domain = complexWildcard + domain
}
if !fn(domain, data.Data()) {
return false
}
}
if !recursion(newItems, data, fn) {
return false
}
}
return true
}
func joinDomain(items []string) string {
return strings.Join(items, domainStep)
}
// New returns a new, empty Trie.
func New[T any]() *DomainTrie[T] {
return &DomainTrie[T]{root: newNode[T]()}
}