diff --git a/decode.go b/decode.go index 79117cae..b6e22a5e 100644 --- a/decode.go +++ b/decode.go @@ -15,11 +15,12 @@ import ( "strconv" "time" + "golang.org/x/xerrors" + "github.com/goccy/go-yaml/ast" "github.com/goccy/go-yaml/internal/errors" "github.com/goccy/go-yaml/parser" "github.com/goccy/go-yaml/token" - "golang.org/x/xerrors" ) // Decoder reads and decodes YAML values from an input stream. @@ -1500,10 +1501,19 @@ func (d *Decoder) decodeMap(ctx context.Context, dst reflect.Value, src ast.Node } continue } - k := reflect.ValueOf(d.nodeToValue(key)) - if k.IsValid() && k.Type().ConvertibleTo(keyType) { - k = k.Convert(keyType) + + k := d.createDecodableValue(keyType) + if d.canDecodeByUnmarshaler(k) { + if err := d.decodeByUnmarshaler(ctx, k, key); err != nil { + return errors.Wrapf(err, "failed to decode by unmarshaler") + } + } else { + k = reflect.ValueOf(d.nodeToValue(key)) + if k.IsValid() && k.Type().ConvertibleTo(keyType) { + k = k.Convert(keyType) + } } + if k.IsValid() { if err := d.validateDuplicateKey(keyMap, k.Interface(), key); err != nil { return errors.Wrapf(err, "invalid map key") diff --git a/decode_test.go b/decode_test.go index cabfd33c..7e9b6635 100644 --- a/decode_test.go +++ b/decode_test.go @@ -14,11 +14,12 @@ import ( "testing" "time" + "golang.org/x/xerrors" + "github.com/goccy/go-yaml" "github.com/goccy/go-yaml/ast" "github.com/goccy/go-yaml/internal/errors" "github.com/goccy/go-yaml/parser" - "golang.org/x/xerrors" ) type Child struct { @@ -2906,3 +2907,29 @@ func TestSameNameInineStruct(t *testing.T) { t.Fatalf("failed to decode") } } + +type unmarshableMapKey struct { + Key string +} + +func (mk *unmarshableMapKey) UnmarshalYAML(b []byte) error { + mk.Key = string(b) + return nil +} + +func TestMapKeyCustomUnmarshaler(t *testing.T) { + var m map[unmarshableMapKey]string + if err := yaml.Unmarshal([]byte(`key: value`), &m); err != nil { + t.Fatalf("failed to unmarshal %v", err) + } + if len(m) != 1 { + t.Fatalf("expected 1 element in map, but got %d", len(m)) + } + val, ok := m[unmarshableMapKey{Key: "key"}] + if !ok { + t.Fatal("expected to have element 'key' in map") + } + if val != "value" { + t.Fatalf("expected to have value \"value\", but got %q", val) + } +} \ No newline at end of file