Skip to content

Commit

Permalink
fix sequence value for merge key (#603)
Browse files Browse the repository at this point in the history
  • Loading branch information
goccy authored Dec 22, 2024
1 parent aeed806 commit e2e4400
Show file tree
Hide file tree
Showing 6 changed files with 167 additions and 64 deletions.
27 changes: 27 additions & 0 deletions ast/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -1157,6 +1157,11 @@ func (m *MapNodeIter) Value() Node {
return m.values[m.idx].Value
}

// KeyValue returns the MappingValueNode of the iterator's current map node entry.
func (m *MapNodeIter) KeyValue() *MappingValueNode {
return m.values[m.idx]
}

// MappingNode type of mapping node
type MappingNode struct {
*BaseNode
Expand Down Expand Up @@ -1653,6 +1658,28 @@ func (n *SequenceNode) MarshalYAML() ([]byte, error) {
return []byte(n.String()), nil
}

// SequenceMergeValue creates SequenceMergeValueNode instance.
func SequenceMergeValue(values ...MapNode) *SequenceMergeValueNode {
return &SequenceMergeValueNode{
values: values,
}
}

// SequenceMergeValueNode is used to convert the Sequence node specified for the merge key into a MapNode format.
type SequenceMergeValueNode struct {
values []MapNode
}

// MapRange returns MapNodeIter instance.
func (n *SequenceMergeValueNode) MapRange() *MapNodeIter {
ret := &MapNodeIter{idx: startRangeIndex}
for _, value := range n.values {
iter := value.MapRange()
ret.values = append(ret.values, iter.values...)
}
return ret
}

// AnchorNode type of anchor node
type AnchorNode struct {
*BaseNode
Expand Down
17 changes: 17 additions & 0 deletions context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package yaml

import "context"

type ctxMergeKey struct{}

func withMerge(ctx context.Context) context.Context {
return context.WithValue(ctx, ctxMergeKey{}, true)
}

func isMerge(ctx context.Context) bool {
v, ok := ctx.Value(ctxMergeKey{}).(bool)
if !ok {
return false
}
return v
}
124 changes: 62 additions & 62 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,15 +114,6 @@ func (d *Decoder) castToFloat(v interface{}) interface{} {
return 0
}

func (d *Decoder) mergeValueNode(value ast.Node) ast.Node {
if value.Type() == ast.AliasType {
aliasNode, _ := value.(*ast.AliasNode)
aliasName := aliasNode.Value.GetToken().Value
return d.anchorNodeMap[aliasName]
}
return value
}

func (d *Decoder) mapKeyNodeToString(node ast.MapKeyNode) (string, error) {
key, err := d.nodeToValue(node)
if err != nil {
Expand All @@ -148,9 +139,16 @@ func (d *Decoder) setToMapValue(node ast.Node, m map[string]interface{}) error {
switch n := node.(type) {
case *ast.MappingValueNode:
if n.Key.IsMergeKey() {
if err := d.setToMapValue(d.mergeValueNode(n.Value), m); err != nil {
value, err := d.getMapNode(n.Value, true)
if err != nil {
return err
}
iter := value.MapRange()
for iter.Next() {
if err := d.setToMapValue(iter.KeyValue(), m); err != nil {
return err
}
}
} else {
key, err := d.mapKeyNodeToString(n.Key)
if err != nil {
Expand Down Expand Up @@ -186,9 +184,16 @@ func (d *Decoder) setToOrderedMapValue(node ast.Node, m *MapSlice) error {
switch n := node.(type) {
case *ast.MappingValueNode:
if n.Key.IsMergeKey() {
if err := d.setToOrderedMapValue(d.mergeValueNode(n.Value), m); err != nil {
value, err := d.getMapNode(n.Value, true)
if err != nil {
return err
}
iter := value.MapRange()
for iter.Next() {
if err := d.setToOrderedMapValue(iter.KeyValue(), m); err != nil {
return err
}
}
} else {
key, err := d.mapKeyNodeToString(n.Key)
if err != nil {
Expand Down Expand Up @@ -468,17 +473,25 @@ func (d *Decoder) nodeToValue(node ast.Node) (any, error) {
return d.nodeToValue(n.Value)
case *ast.MappingValueNode:
if n.Key.IsMergeKey() {
value := d.mergeValueNode(n.Value)
value, err := d.getMapNode(n.Value, true)
if err != nil {
return nil, err
}
iter := value.MapRange()
if d.useOrderedMap {
m := MapSlice{}
if err := d.setToOrderedMapValue(value, &m); err != nil {
return nil, err
for iter.Next() {
if err := d.setToOrderedMapValue(iter.KeyValue(), &m); err != nil {
return nil, err
}
}
return m, nil
}
m := map[string]interface{}{}
if err := d.setToMapValue(value, m); err != nil {
return nil, err
m := make(map[string]any)
for iter.Next() {
if err := d.setToMapValue(iter.KeyValue(), m); err != nil {
return nil, err
}
}
return m, nil
}
Expand Down Expand Up @@ -598,40 +611,42 @@ func (d *Decoder) resolveAlias(node ast.Node) (ast.Node, error) {
return node, nil
}

func (d *Decoder) getMapNode(node ast.Node) (ast.MapNode, error) {
func (d *Decoder) getMapNode(node ast.Node, isMerge bool) (ast.MapNode, error) {
d.stepIn()
defer d.stepOut()
if d.isExceededMaxDepth() {
return nil, ErrExceededMaxDepth
}

if _, ok := node.(*ast.NullNode); ok {
return nil, nil
}
if anchor, ok := node.(*ast.AnchorNode); ok {
mapNode, ok := anchor.Value.(ast.MapNode)
if ok {
return mapNode, nil
}
return nil, errors.ErrUnexpectedNodeType(anchor.Value.Type(), ast.MappingType, node.GetToken())
}
if alias, ok := node.(*ast.AliasNode); ok {
aliasName := alias.Value.GetToken().Value
switch n := node.(type) {
case ast.MapNode:
return n, nil
case *ast.AnchorNode:
anchorName := n.Name.GetToken().Value
d.anchorNodeMap[anchorName] = n.Value
return d.getMapNode(n.Value, isMerge)
case *ast.AliasNode:
aliasName := n.Value.GetToken().Value
node := d.anchorNodeMap[aliasName]
if node == nil {
return nil, fmt.Errorf("cannot find anchor by alias name %s", aliasName)
}
mapNode, ok := node.(ast.MapNode)
if ok {
return mapNode, nil
return d.getMapNode(node, isMerge)
case *ast.SequenceNode:
if !isMerge {
return nil, errors.ErrUnexpectedNodeType(node.Type(), ast.MappingType, node.GetToken())
}
return nil, errors.ErrUnexpectedNodeType(node.Type(), ast.MappingType, node.GetToken())
}
mapNode, ok := node.(ast.MapNode)
if !ok {
return nil, errors.ErrUnexpectedNodeType(node.Type(), ast.MappingType, node.GetToken())
var mapNodes []ast.MapNode
for _, value := range n.Values {
mapNode, err := d.getMapNode(value, false)
if err != nil {
return nil, err
}
mapNodes = append(mapNodes, mapNode)
}
return ast.SequenceMergeValue(mapNodes...), nil
}
return mapNode, nil
return nil, errors.ErrUnexpectedNodeType(node.Type(), ast.MappingType, node.GetToken())
}

func (d *Decoder) getArrayNode(node ast.Node) (ast.ArrayNode, error) {
Expand Down Expand Up @@ -1191,15 +1206,12 @@ func (d *Decoder) keyToNodeMap(node ast.Node, ignoreMergeKey bool, getKeyOrValue
return nil, ErrExceededMaxDepth
}

mapNode, err := d.getMapNode(node)
mapNode, err := d.getMapNode(node, false)
if err != nil {
return nil, err
}
keyMap := map[string]struct{}{}
keyToNodeMap := map[string]ast.Node{}
if mapNode == nil {
return keyToNodeMap, nil
}
mapIter := mapNode.MapRange()
for mapIter.Next() {
keyNode := mapIter.Key()
Expand Down Expand Up @@ -1358,13 +1370,10 @@ func (d *Decoder) decodeDuration(ctx context.Context, dst reflect.Value, src ast

// getMergeAliasName support single alias only
func (d *Decoder) getMergeAliasName(src ast.Node) string {
mapNode, err := d.getMapNode(src)
mapNode, err := d.getMapNode(src, true)
if err != nil {
return ""
}
if mapNode == nil {
return ""
}
mapIter := mapNode.MapRange()
for mapIter.Next() {
key := mapIter.Key()
Expand Down Expand Up @@ -1649,21 +1658,18 @@ func (d *Decoder) decodeMapItem(ctx context.Context, dst *MapItem, src ast.Node)
return ErrExceededMaxDepth
}

mapNode, err := d.getMapNode(src)
mapNode, err := d.getMapNode(src, isMerge(ctx))
if err != nil {
return err
}
if mapNode == nil {
return nil
}
mapIter := mapNode.MapRange()
if !mapIter.Next() {
return nil
}
key := mapIter.Key()
value := mapIter.Value()
if key.IsMergeKey() {
if err := d.decodeMapItem(ctx, dst, value); err != nil {
if err := d.decodeMapItem(withMerge(ctx), dst, value); err != nil {
return err
}
return nil
Expand Down Expand Up @@ -1701,13 +1707,10 @@ func (d *Decoder) decodeMapSlice(ctx context.Context, dst *MapSlice, src ast.Nod
return ErrExceededMaxDepth
}

mapNode, err := d.getMapNode(src)
mapNode, err := d.getMapNode(src, isMerge(ctx))
if err != nil {
return err
}
if mapNode == nil {
return nil
}
mapSlice := MapSlice{}
mapIter := mapNode.MapRange()
keyMap := map[string]struct{}{}
Expand All @@ -1716,7 +1719,7 @@ func (d *Decoder) decodeMapSlice(ctx context.Context, dst *MapSlice, src ast.Nod
value := mapIter.Value()
if key.IsMergeKey() {
var m MapSlice
if err := d.decodeMapSlice(ctx, &m, value); err != nil {
if err := d.decodeMapSlice(withMerge(ctx), &m, value); err != nil {
return err
}
for _, v := range m {
Expand Down Expand Up @@ -1751,13 +1754,10 @@ func (d *Decoder) decodeMap(ctx context.Context, dst reflect.Value, src ast.Node
return ErrExceededMaxDepth
}

mapNode, err := d.getMapNode(src)
mapNode, err := d.getMapNode(src, isMerge(ctx))
if err != nil {
return err
}
if mapNode == nil {
return nil
}
mapType := dst.Type()
mapValue := reflect.MakeMap(mapType)
keyType := mapValue.Type().Key()
Expand All @@ -1769,7 +1769,7 @@ func (d *Decoder) decodeMap(ctx context.Context, dst reflect.Value, src ast.Node
key := mapIter.Key()
value := mapIter.Value()
if key.IsMergeKey() {
if err := d.decodeMap(ctx, dst, value); err != nil {
if err := d.decodeMap(withMerge(ctx), dst, value); err != nil {
return err
}
iter := dst.MapRange()
Expand Down
32 changes: 32 additions & 0 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,38 @@ c: 3
map[string]any{"a": 1, "b": 2, "c": 3},
},

// merge
{
`
a: &a
foo: 1
b: &b
bar: 2
merge:
<<: [*a, *b]
`,
map[string]map[string]any{
"a": {"foo": 1},
"b": {"bar": 2},
"merge": {"foo": 1, "bar": 2},
},
},
{
`
a: &a
foo: 1
b: &b
bar: 2
merge:
<<: [*a, *b]
`,
map[string]yaml.MapSlice{
"a": {{Key: "foo", Value: 1}},
"b": {{Key: "bar", Value: 2}},
"merge": {{Key: "foo", Value: 1}, {Key: "bar", Value: 2}},
},
},

// Flow sequence
{
"v: [A,B]",
Expand Down
27 changes: 27 additions & 0 deletions parser/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,20 @@ func (c *context) insertNullToken(tk *Token) *Token {
return nullToken
}

func (c *context) addNullValueToken(tk *Token) *Token {
nullToken := c.createNullToken(tk)
rawTk := nullToken.RawToken()

// add space for map or sequence value.
rawTk.Origin = " null"
rawTk.Position.Column++

c.addToken(nullToken)
c.goNext()

return nullToken
}

func (c *context) createNullToken(base *Token) *Token {
pos := *(base.RawToken().Position)
pos.Column++
Expand Down Expand Up @@ -157,3 +171,16 @@ func (c *context) insertToken(tk *Token) {
ref.tokens[idx] = tk
ref.size = len(ref.tokens)
}

func (c *context) addToken(tk *Token) {
ref := c.tokenRef
lastTk := ref.tokens[ref.size-1]
if lastTk.Group != nil {
lastTk = lastTk.Group.Last()
}
lastTk.RawToken().Next = tk.RawToken()
tk.RawToken().Prev = lastTk.RawToken()

ref.tokens = append(ref.tokens, tk)
ref.size = len(ref.tokens)
}
Loading

0 comments on commit e2e4400

Please sign in to comment.