diff --git a/component/cidr/ipcidr_set.go b/component/cidr/ipcidr_set.go new file mode 100644 index 00000000..b8dec0ee --- /dev/null +++ b/component/cidr/ipcidr_set.go @@ -0,0 +1,89 @@ +package cidr + +import ( + "math/big" + "net" + "sort" +) + +type Range struct { + Start *big.Int + End *big.Int +} + +type IpCidrSet struct { + Ranges []Range +} + +func NewIpCidrSet() *IpCidrSet { + return &IpCidrSet{} +} + +func ipToBigInt(ip net.IP) *big.Int { + ipBigInt := big.NewInt(0) + ipBigInt.SetBytes(ip.To16()) + return ipBigInt +} + +func cidrToRange(cidr string) (Range, error) { + _, ipNet, err := net.ParseCIDR(cidr) + if err != nil { + return Range{}, err + } + firstIP, lastIP := networkRange(ipNet) + return Range{Start: ipToBigInt(firstIP), End: ipToBigInt(lastIP)}, nil +} + +func networkRange(network *net.IPNet) (net.IP, net.IP) { + firstIP := network.IP + lastIP := make(net.IP, len(firstIP)) + copy(lastIP, firstIP) + for i := range firstIP { + lastIP[i] |= ^network.Mask[i] + } + return firstIP, lastIP +} + +func (set *IpCidrSet) AddIpCidrForString(ipCidr string) error { + ipRange, err := cidrToRange(ipCidr) + if err != nil { + return err + } + set.Ranges = append(set.Ranges, ipRange) + sort.Slice(set.Ranges, func(i, j int) bool { + return set.Ranges[i].Start.Cmp(set.Ranges[j].Start) < 0 + }) + return nil +} + +func (set *IpCidrSet) AddIpCidr(ipCidr *net.IPNet) error { + return set.AddIpCidrForString(ipCidr.String()) +} + +func (set *IpCidrSet) IsContainForString(ipString string) bool { + ip := ipToBigInt(net.ParseIP(ipString)) + idx := sort.Search(len(set.Ranges), func(i int) bool { + return set.Ranges[i].End.Cmp(ip) >= 0 + }) + if idx < len(set.Ranges) && set.Ranges[idx].Start.Cmp(ip) <= 0 && set.Ranges[idx].End.Cmp(ip) >= 0 { + return true + } + return false +} + +func (set *IpCidrSet) IsContain(ip net.IP) bool { + if ip == nil { + return false + } + return set.IsContainForString(ip.String()) +} + +func (set *IpCidrSet) Merge() { + for i := 0; i < len(set.Ranges)-1; i++ { + if set.Ranges[i].End.Cmp(set.Ranges[i+1].Start) >= 0 { + set.Ranges[i].End = set.Ranges[i+1].End + set.Ranges = append(set.Ranges[:i+1], set.Ranges[i+2:]...) + i-- + } + } +} diff --git a/component/cidr/ipcidr_set_test.go b/component/cidr/ipcidr_set_test.go new file mode 100644 index 00000000..a8cfef7f --- /dev/null +++ b/component/cidr/ipcidr_set_test.go @@ -0,0 +1,105 @@ +package cidr + +import ( + "testing" +) + +func TestIpv4(t *testing.T) { + tests := []struct { + name string + ipCidr string + ip string + expected bool + }{ + { + name: "Test Case 1", + ipCidr: "149.154.160.0/20", + ip: "149.154.160.0", + expected: true, + }, + { + name: "Test Case 2", + ipCidr: "192.168.0.0/16", + ip: "10.0.0.1", + expected: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + set := &IpCidrSet{} + set.AddIpCidrForString(test.ipCidr) + + result := set.IsContainForString(test.ip) + if result != test.expected { + t.Errorf("Expected result: %v, got: %v", test.expected, result) + } + }) + } +} + +func TestIpv6(t *testing.T) { + tests := []struct { + name string + ipCidr string + ip string + expected bool + }{ + { + name: "Test Case 1", + ipCidr: "2409:8000::/20", + ip: "2409:8087:1e03:21::27", + expected: true, + }, + { + name: "Test Case 2", + ipCidr: "240e::/16", + ip: "240e:964:ea02:100:1800::71", + expected: true, + }, + } + // Add more test cases as needed + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + set := &IpCidrSet{} + set.AddIpCidrForString(test.ipCidr) + + result := set.IsContainForString(test.ip) + if result != test.expected { + t.Errorf("Expected result: %v, got: %v", test.expected, result) + } + }) + } +} + +func TestMerge(t *testing.T) { + tests := []struct { + name string + ipCidr1 string + ipCidr2 string + ipCidr3 string + expectedLen int + }{ + { + name: "Test Case 1", + ipCidr1: "2409:8000::/20", + ipCidr2: "2409:8000::/21", + ipCidr3: "2409:8000::/48", + expectedLen: 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + set := &IpCidrSet{} + set.AddIpCidrForString(test.ipCidr1) + set.AddIpCidrForString(test.ipCidr2) + set.Merge() + + if len(set.Ranges) != test.expectedLen { + t.Errorf("Expected len: %v, got: %v", test.expectedLen, len(set.Ranges)) + } + }) + } +} diff --git a/rules/provider/ipcidr_strategy.go b/rules/provider/ipcidr_strategy.go index 321e901a..dad48305 100644 --- a/rules/provider/ipcidr_strategy.go +++ b/rules/provider/ipcidr_strategy.go @@ -1,7 +1,7 @@ package provider import ( - "github.com/metacubex/mihomo/component/trie" + "github.com/metacubex/mihomo/component/cidr" C "github.com/metacubex/mihomo/constant" "github.com/metacubex/mihomo/log" ) @@ -9,7 +9,8 @@ import ( type ipcidrStrategy struct { count int shouldResolveIP bool - trie *trie.IpCidrTrie + cidrSet *cidr.IpCidrSet + //trie *trie.IpCidrTrie } func (i *ipcidrStrategy) ShouldFindProcess() bool { @@ -17,7 +18,8 @@ func (i *ipcidrStrategy) ShouldFindProcess() bool { } func (i *ipcidrStrategy) Match(metadata *C.Metadata) bool { - return i.trie != nil && i.trie.IsContain(metadata.DstIP.AsSlice()) + // return i.trie != nil && i.trie.IsContain(metadata.DstIP.AsSlice()) + return i.cidrSet != nil && i.cidrSet.IsContain(metadata.DstIP.AsSlice()) } func (i *ipcidrStrategy) Count() int { @@ -29,13 +31,15 @@ func (i *ipcidrStrategy) ShouldResolveIP() bool { } func (i *ipcidrStrategy) Reset() { - i.trie = trie.NewIpCidrTrie() + // i.trie = trie.NewIpCidrTrie() + i.cidrSet = cidr.NewIpCidrSet() i.count = 0 i.shouldResolveIP = false } func (i *ipcidrStrategy) Insert(rule string) { - err := i.trie.AddIpCidrForString(rule) + //err := i.trie.AddIpCidrForString(rule) + err := i.cidrSet.AddIpCidrForString(rule) if err != nil { log.Warnln("invalid Ipcidr:[%s]", rule) } else { @@ -44,7 +48,9 @@ func (i *ipcidrStrategy) Insert(rule string) { } } -func (i *ipcidrStrategy) FinishInsert() {} +func (i *ipcidrStrategy) FinishInsert() { + i.cidrSet.Merge() +} func NewIPCidrStrategy() *ipcidrStrategy { return &ipcidrStrategy{}