mihomo/common/structure/structure.go

605 lines
16 KiB
Go
Raw Normal View History

2018-09-30 16:30:11 +08:00
package structure
// references: https://github.com/mitchellh/mapstructure
2018-09-30 16:30:11 +08:00
import (
"encoding"
"encoding/base64"
2018-09-30 16:30:11 +08:00
"fmt"
"reflect"
"strconv"
"strings"
)
// Option is the configuration that is used to create a new decoder.
type Option struct {
	// TagName is the struct tag key read to find a field's source key and
	// its flags. NewDecoder substitutes "structure" when it is empty.
	TagName string
	// WeaklyTypedInput enables loose conversions between scalar kinds,
	// such as numeric strings into integers or numbers into strings.
	WeaklyTypedInput bool
	// KeyReplacer, when non-nil, is applied to both the expected key and
	// the source keys before the case-insensitive fallback comparison.
	KeyReplacer *strings.Replacer
}
2022-12-04 21:53:13 +08:00
// DefaultKeyReplacer is a ready-made replacer that turns every "_" in a
// key into "-".
var DefaultKeyReplacer = strings.NewReplacer("_", "-")
2018-09-30 16:30:11 +08:00
// Decoder is the core of structure. It carries the Option it was created
// with and decodes map data into Go values via reflection.
type Decoder struct {
	// option is the configuration consulted by every decode step.
	option *Option
}
// NewDecoder builds a Decoder from the given Option. The Option is taken
// by value, so the defaulting below never touches the caller's copy. An
// empty TagName falls back to "structure".
func NewDecoder(option Option) *Decoder {
	const fallbackTag = "structure"
	if len(option.TagName) == 0 {
		option.TagName = fallbackTag
	}
	decoder := new(Decoder)
	decoder.option = &option
	return decoder
}
2022-03-16 12:10:13 +08:00
// Decode transform a map[string]any to a struct
func (d *Decoder) Decode(src map[string]any, dst any) error {
2018-09-30 16:30:11 +08:00
if reflect.TypeOf(dst).Kind() != reflect.Ptr {
return fmt.Errorf("decode must recive a ptr struct")
2018-09-30 16:30:11 +08:00
}
t := reflect.TypeOf(dst).Elem()
v := reflect.ValueOf(dst).Elem()
for idx := 0; idx < v.NumField(); idx++ {
field := t.Field(idx)
2021-11-08 13:29:37 +08:00
if field.Anonymous {
if err := d.decodeStruct(field.Name, src, v.Field(idx)); err != nil {
return err
}
continue
}
2018-09-30 16:30:11 +08:00
tag := field.Tag.Get(d.option.TagName)
2022-03-16 12:10:13 +08:00
key, omitKey, found := strings.Cut(tag, ",")
omitempty := found && omitKey == "omitempty"
2018-09-30 16:30:11 +08:00
value, ok := src[key]
if !ok {
if d.option.KeyReplacer != nil {
key = d.option.KeyReplacer.Replace(key)
}
for _strKey := range src {
strKey := _strKey
if d.option.KeyReplacer != nil {
strKey = d.option.KeyReplacer.Replace(strKey)
}
if strings.EqualFold(key, strKey) {
value = src[_strKey]
ok = true
break
}
}
}
2019-03-25 20:42:20 +08:00
if !ok || value == nil {
2018-09-30 16:30:11 +08:00
if omitempty {
continue
}
2019-03-25 20:42:20 +08:00
return fmt.Errorf("key '%s' missing", key)
2018-09-30 16:30:11 +08:00
}
err := d.decode(key, value, v.Field(idx))
if err != nil {
return err
}
}
return nil
}
// isNil reports whether input is an untyped nil or an interface wrapping a
// nil pointer. Other nil-able kinds (maps, slices, ...) are not considered.
func isNil(input any) bool {
	if input != nil {
		rv := reflect.ValueOf(input)
		if rv.Kind() != reflect.Pointer {
			return false
		}
		return rv.IsNil()
	}
	return true
}
2022-03-16 12:10:13 +08:00
func (d *Decoder) decode(name string, data any, val reflect.Value) error {
if isNil(data) {
// If the data is nil, then we don't set anything
// Maybe we should set to zero value?
return nil
}
if !reflect.ValueOf(data).IsValid() {
// If the input value is invalid, then we just set the value
// to be the zero value.
val.Set(reflect.Zero(val.Type()))
return nil
}
for {
kind := val.Kind()
if kind == reflect.Pointer && val.IsNil() {
val.Set(reflect.New(val.Type().Elem()))
}
if ok, err := d.decodeTextUnmarshaller(name, data, val); ok {
return err
}
switch {
case isInt(kind):
return d.decodeInt(name, data, val)
case isUint(kind):
return d.decodeUint(name, data, val)
case isFloat(kind):
return d.decodeFloat(name, data, val)
}
switch kind {
case reflect.Pointer:
val = val.Elem()
continue
case reflect.String:
return d.decodeString(name, data, val)
case reflect.Bool:
return d.decodeBool(name, data, val)
case reflect.Slice:
return d.decodeSlice(name, data, val)
case reflect.Map:
return d.decodeMap(name, data, val)
case reflect.Interface:
return d.setInterface(name, data, val)
case reflect.Struct:
return d.decodeStruct(name, data, val)
default:
return fmt.Errorf("type %s not support", val.Kind().String())
}
2018-09-30 16:30:11 +08:00
}
}
2022-12-13 21:13:31 +08:00
// isInt reports whether kind is one of the signed integer kinds
// (Int, Int8, Int16, Int32, Int64).
func isInt(kind reflect.Kind) bool {
	return kind == reflect.Int ||
		kind == reflect.Int8 ||
		kind == reflect.Int16 ||
		kind == reflect.Int32 ||
		kind == reflect.Int64
}
// isUint reports whether kind is one of the unsigned integer kinds
// (Uint, Uint8, Uint16, Uint32, Uint64). Uintptr is not included.
func isUint(kind reflect.Kind) bool {
	return kind == reflect.Uint ||
		kind == reflect.Uint8 ||
		kind == reflect.Uint16 ||
		kind == reflect.Uint32 ||
		kind == reflect.Uint64
}
// isFloat reports whether kind is Float32 or Float64.
func isFloat(kind reflect.Kind) bool {
	return kind == reflect.Float32 || kind == reflect.Float64
}
2022-03-16 12:10:13 +08:00
func (d *Decoder) decodeInt(name string, data any, val reflect.Value) (err error) {
2018-09-30 16:30:11 +08:00
dataVal := reflect.ValueOf(data)
kind := dataVal.Kind()
switch {
2022-12-13 21:13:31 +08:00
case isInt(kind):
2018-09-30 16:30:11 +08:00
val.SetInt(dataVal.Int())
2022-12-13 21:13:31 +08:00
case isUint(kind) && d.option.WeaklyTypedInput:
val.SetInt(int64(dataVal.Uint()))
2022-12-13 21:13:31 +08:00
case isFloat(kind) && d.option.WeaklyTypedInput:
val.SetInt(int64(dataVal.Float()))
2018-09-30 16:30:11 +08:00
case kind == reflect.String && d.option.WeaklyTypedInput:
var i int64
i, err = strconv.ParseInt(dataVal.String(), 0, val.Type().Bits())
if err == nil {
val.SetInt(i)
} else {
err = fmt.Errorf("cannot parse '%s' as int: %s", name, err)
}
default:
err = fmt.Errorf(
"'%s' expected type '%s', got unconvertible type '%s'",
name, val.Type(), dataVal.Type(),
)
}
return err
}
func (d *Decoder) decodeUint(name string, data any, val reflect.Value) (err error) {
dataVal := reflect.ValueOf(data)
kind := dataVal.Kind()
switch {
2022-12-13 21:13:31 +08:00
case isUint(kind):
val.SetUint(dataVal.Uint())
2022-12-13 21:13:31 +08:00
case isInt(kind) && d.option.WeaklyTypedInput:
val.SetUint(uint64(dataVal.Int()))
2022-12-13 21:13:31 +08:00
case isFloat(kind) && d.option.WeaklyTypedInput:
val.SetUint(uint64(dataVal.Float()))
case kind == reflect.String && d.option.WeaklyTypedInput:
var i uint64
i, err = strconv.ParseUint(dataVal.String(), 0, val.Type().Bits())
if err == nil {
val.SetUint(i)
} else {
err = fmt.Errorf("cannot parse '%s' as int: %s", name, err)
}
default:
err = fmt.Errorf(
"'%s' expected type '%s', got unconvertible type '%s'",
name, val.Type(), dataVal.Type(),
)
}
return err
}
2022-12-13 21:13:31 +08:00
// decodeFloat stores data into the floating-point value val.
//
// Floats and unsigned integers are always accepted. With
// Option.WeaklyTypedInput set, signed integers are converted and strings are
// parsed with strconv.ParseFloat using the bit size of val's type. Anything
// else is an "unconvertible type" error. A parse failure wraps the strconv
// error so callers can inspect it.
func (d *Decoder) decodeFloat(name string, data any, val reflect.Value) (err error) {
	dataVal := reflect.ValueOf(data)
	kind := dataVal.Kind()
	weak := d.option.WeaklyTypedInput

	switch {
	case isFloat(kind):
		val.SetFloat(dataVal.Float())
	case isUint(kind):
		val.SetFloat(float64(dataVal.Uint()))
	case isInt(kind) && weak:
		val.SetFloat(float64(dataVal.Int()))
	case kind == reflect.String && weak:
		f, parseErr := strconv.ParseFloat(dataVal.String(), val.Type().Bits())
		if parseErr != nil {
			// The target is a float, so say so in the message.
			return fmt.Errorf("cannot parse '%s' as float: %w", name, parseErr)
		}
		val.SetFloat(f)
	default:
		return fmt.Errorf(
			"'%s' expected type '%s', got unconvertible type '%s'",
			name, val.Type(), dataVal.Type(),
		)
	}
	return nil
}
2022-03-16 12:10:13 +08:00
func (d *Decoder) decodeString(name string, data any, val reflect.Value) (err error) {
2018-09-30 16:30:11 +08:00
dataVal := reflect.ValueOf(data)
kind := dataVal.Kind()
switch {
case kind == reflect.String:
val.SetString(dataVal.String())
2022-12-13 21:13:31 +08:00
case isInt(kind) && d.option.WeaklyTypedInput:
2018-09-30 16:30:11 +08:00
val.SetString(strconv.FormatInt(dataVal.Int(), 10))
2022-12-13 21:13:31 +08:00
case isUint(kind) && d.option.WeaklyTypedInput:
val.SetString(strconv.FormatUint(dataVal.Uint(), 10))
case isFloat(kind) && d.option.WeaklyTypedInput:
val.SetString(strconv.FormatFloat(dataVal.Float(), 'E', -1, dataVal.Type().Bits()))
2018-09-30 16:30:11 +08:00
default:
err = fmt.Errorf(
2019-03-25 20:42:20 +08:00
"'%s' expected type '%s', got unconvertible type '%s'",
2018-09-30 16:30:11 +08:00
name, val.Type(), dataVal.Type(),
)
}
return err
}
2022-03-16 12:10:13 +08:00
func (d *Decoder) decodeBool(name string, data any, val reflect.Value) (err error) {
2018-09-30 16:30:11 +08:00
dataVal := reflect.ValueOf(data)
kind := dataVal.Kind()
switch {
case kind == reflect.Bool:
val.SetBool(dataVal.Bool())
2022-12-13 21:13:31 +08:00
case isInt(kind) && d.option.WeaklyTypedInput:
2018-09-30 16:30:11 +08:00
val.SetBool(dataVal.Int() != 0)
2022-12-13 21:13:31 +08:00
case isUint(kind) && d.option.WeaklyTypedInput:
val.SetString(strconv.FormatUint(dataVal.Uint(), 10))
2018-09-30 16:30:11 +08:00
default:
err = fmt.Errorf(
2019-03-25 20:42:20 +08:00
"'%s' expected type '%s', got unconvertible type '%s'",
2018-09-30 16:30:11 +08:00
name, val.Type(), dataVal.Type(),
)
}
return err
}
2022-03-16 12:10:13 +08:00
func (d *Decoder) decodeSlice(name string, data any, val reflect.Value) error {
2018-09-30 16:30:11 +08:00
dataVal := reflect.Indirect(reflect.ValueOf(data))
valType := val.Type()
valElemType := valType.Elem()
2022-12-13 21:13:31 +08:00
if dataVal.Kind() == reflect.String && valElemType.Kind() == reflect.Uint8 { // from encoding/json
s := []byte(dataVal.String())
b := make([]byte, base64.StdEncoding.DecodedLen(len(s)))
n, err := base64.StdEncoding.Decode(b, s)
if err != nil {
return fmt.Errorf("try decode '%s' by base64 error: %w", name, err)
}
val.SetBytes(b[:n])
return nil
}
2018-09-30 16:30:11 +08:00
if dataVal.Kind() != reflect.Slice {
return fmt.Errorf("'%s' is not a slice", name)
}
valSlice := val
// make a new slice with cap(val)==cap(dataVal)
// the caller can determine whether the original configuration contains this item by judging whether the value is nil.
valSlice = reflect.MakeSlice(valType, 0, dataVal.Len())
2018-09-30 16:30:11 +08:00
for i := 0; i < dataVal.Len(); i++ {
currentData := dataVal.Index(i).Interface()
for valSlice.Len() <= i {
valSlice = reflect.Append(valSlice, reflect.Zero(valElemType))
}
fieldName := fmt.Sprintf("%s[%d]", name, i)
2022-05-06 11:43:53 +08:00
if currentData == nil {
// in weakly type mode, null will convert to zero value
if d.option.WeaklyTypedInput {
continue
}
// in non-weakly type mode, null will convert to nil if element's zero value is nil, otherwise return an error
if elemKind := valElemType.Kind(); elemKind == reflect.Map || elemKind == reflect.Slice {
continue
}
return fmt.Errorf("'%s' can not be null", fieldName)
}
currentField := valSlice.Index(i)
2018-09-30 16:30:11 +08:00
if err := d.decode(fieldName, currentData, currentField); err != nil {
return err
}
}
val.Set(valSlice)
return nil
}
2022-03-16 12:10:13 +08:00
// decodeMap prepares the destination map for val and hands the actual
// copying to decodeMapFromMap. An existing destination map is reused; a nil
// one is replaced by a freshly made map of the same key and element types.
// data must be a map or a pointer to one, otherwise an error is returned.
func (d *Decoder) decodeMap(name string, data any, val reflect.Value) error {
	targetType := val.Type()
	keyType, elemType := targetType.Key(), targetType.Elem()

	target := val
	if target.IsNil() {
		target = reflect.MakeMap(reflect.MapOf(keyType, elemType))
	}

	source := reflect.Indirect(reflect.ValueOf(data))
	if kind := source.Kind(); kind != reflect.Map {
		return fmt.Errorf("'%s' expected a map, got '%s'", name, kind)
	}
	return d.decodeMapFromMap(name, source, val, target)
}
func (d *Decoder) decodeMapFromMap(name string, dataVal reflect.Value, val reflect.Value, valMap reflect.Value) error {
valType := val.Type()
valKeyType := valType.Key()
valElemType := valType.Elem()
errors := make([]string, 0)
if dataVal.Len() == 0 {
if dataVal.IsNil() {
if !val.IsNil() {
val.Set(dataVal)
}
} else {
val.Set(valMap)
}
return nil
}
for _, k := range dataVal.MapKeys() {
fieldName := fmt.Sprintf("%s[%s]", name, k)
currentKey := reflect.Indirect(reflect.New(valKeyType))
if err := d.decode(fieldName, k.Interface(), currentKey); err != nil {
errors = append(errors, err.Error())
continue
}
v := dataVal.MapIndex(k).Interface()
2020-03-19 11:04:56 +08:00
if v == nil {
errors = append(errors, fmt.Sprintf("filed %s invalid", fieldName))
continue
}
currentVal := reflect.Indirect(reflect.New(valElemType))
if err := d.decode(fieldName, v, currentVal); err != nil {
errors = append(errors, err.Error())
continue
}
valMap.SetMapIndex(currentKey, currentVal)
}
val.Set(valMap)
if len(errors) > 0 {
return fmt.Errorf(strings.Join(errors, ","))
}
return nil
}
2019-02-11 15:25:10 +08:00
2022-03-16 12:10:13 +08:00
func (d *Decoder) decodeStruct(name string, data any, val reflect.Value) error {
2019-12-08 12:17:24 +08:00
dataVal := reflect.Indirect(reflect.ValueOf(data))
// If the type of the value to write to and the data match directly,
// then we just set it directly instead of recursing into the structure.
if dataVal.Type() == val.Type() {
val.Set(dataVal)
return nil
}
dataValKind := dataVal.Kind()
switch dataValKind {
case reflect.Map:
return d.decodeStructFromMap(name, dataVal, val)
default:
return fmt.Errorf("'%s' expected a map, got '%s'", name, dataVal.Kind())
}
}
func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) error {
dataValType := dataVal.Type()
if kind := dataValType.Key().Kind(); kind != reflect.String && kind != reflect.Interface {
return fmt.Errorf(
"'%s' needs a map with string keys, has '%s' keys",
name, dataValType.Key().Kind())
}
dataValKeys := make(map[reflect.Value]struct{})
2022-03-16 12:10:13 +08:00
dataValKeysUnused := make(map[any]struct{})
2019-12-08 12:17:24 +08:00
for _, dataValKey := range dataVal.MapKeys() {
dataValKeys[dataValKey] = struct{}{}
dataValKeysUnused[dataValKey.Interface()] = struct{}{}
}
errors := make([]string, 0)
// This slice will keep track of all the structs we'll be decoding.
// There can be more than one struct if there are embedded structs
// that are squashed.
structs := make([]reflect.Value, 1, 5)
structs[0] = val
// Compile the list of all the fields that we're going to be decoding
// from all the structs.
type field struct {
field reflect.StructField
val reflect.Value
}
var fields []field
2019-12-08 12:17:24 +08:00
for len(structs) > 0 {
structVal := structs[0]
structs = structs[1:]
structType := structVal.Type()
for i := 0; i < structType.NumField(); i++ {
fieldType := structType.Field(i)
fieldKind := fieldType.Type.Kind()
// If "squash" is specified in the tag, we squash the field down.
squash := false
tagParts := strings.Split(fieldType.Tag.Get(d.option.TagName), ",")
for _, tag := range tagParts[1:] {
if tag == "squash" {
squash = true
break
}
}
if squash {
if fieldKind != reflect.Struct {
errors = append(errors,
fmt.Errorf("%s: unsupported type for squash: %s", fieldType.Name, fieldKind).Error())
} else {
structs = append(structs, structVal.FieldByName(fieldType.Name))
}
continue
}
// Normal struct field, store it away
fields = append(fields, field{fieldType, structVal.Field(i)})
}
}
// for fieldType, field := range fields {
for _, f := range fields {
field, fieldValue := f.field, f.val
fieldName := field.Name
tagValue := field.Tag.Get(d.option.TagName)
tagValue = strings.SplitN(tagValue, ",", 2)[0]
if tagValue != "" {
fieldName = tagValue
}
rawMapKey := reflect.ValueOf(fieldName)
rawMapVal := dataVal.MapIndex(rawMapKey)
if !rawMapVal.IsValid() {
// Do a slower search by iterating over each key and
// doing case-insensitive search.
if d.option.KeyReplacer != nil {
fieldName = d.option.KeyReplacer.Replace(fieldName)
}
2019-12-08 12:17:24 +08:00
for dataValKey := range dataValKeys {
mK, ok := dataValKey.Interface().(string)
if !ok {
// Not a string key
continue
}
if d.option.KeyReplacer != nil {
mK = d.option.KeyReplacer.Replace(mK)
}
2019-12-08 12:17:24 +08:00
if strings.EqualFold(mK, fieldName) {
rawMapKey = dataValKey
rawMapVal = dataVal.MapIndex(dataValKey)
break
}
}
if !rawMapVal.IsValid() {
// There was no matching key in the map for the value in
// the struct. Just ignore.
continue
}
}
// Delete the key we're using from the unused map so we stop tracking
delete(dataValKeysUnused, rawMapKey.Interface())
if !fieldValue.IsValid() {
// This should never happen
panic("field is not valid")
}
// If we can't set the field, then it is unexported or something,
// and we just continue onwards.
if !fieldValue.CanSet() {
continue
}
// If the name is empty string, then we're at the root, and we
// don't dot-join the fields.
if name != "" {
fieldName = fmt.Sprintf("%s.%s", name, fieldName)
}
if err := d.decode(fieldName, rawMapVal.Interface(), fieldValue); err != nil {
errors = append(errors, err.Error())
}
}
if len(errors) > 0 {
return fmt.Errorf(strings.Join(errors, ","))
}
return nil
}
2022-03-16 12:10:13 +08:00
func (d *Decoder) setInterface(name string, data any, val reflect.Value) (err error) {
2019-02-11 15:25:10 +08:00
dataVal := reflect.ValueOf(data)
val.Set(dataVal)
return nil
}
// decodeTextUnmarshaller tries to decode data through the
// encoding.TextUnmarshaler implemented by val's address.
//
// The boolean result reports whether the text path was taken. It is false
// when val is not addressable, its address cannot be turned into an
// interface, the address does not implement encoding.TextUnmarshaler, or
// data cannot be decoded as a string (that decode error is returned
// alongside false). When it is true, the error is the outcome of
// UnmarshalText, annotated with the field name and val's type.
func (d *Decoder) decodeTextUnmarshaller(name string, data any, val reflect.Value) (bool, error) {
	if !val.CanAddr() {
		return false, nil
	}
	ptr := val.Addr()
	if !ptr.CanInterface() {
		return false, nil
	}
	target, implements := ptr.Interface().(encoding.TextUnmarshaler)
	if !implements {
		return false, nil
	}

	// Reuse decodeString so the same weak-typing rules apply to the text.
	var text string
	if err := d.decodeString(name, data, reflect.ValueOf(&text).Elem()); err != nil {
		return false, err
	}

	err := target.UnmarshalText([]byte(text))
	if err == nil {
		return true, nil
	}
	return true, fmt.Errorf("cannot parse '%s' as %s: %s", name, val.Type(), err)
}