Skip to content

Commit

Permalink
Merge pull request #88 from vimeo/anonymous_field_flatten_mangler
Browse files Browse the repository at this point in the history
anonymous field flatten mangler
  • Loading branch information
dfinkel authored Mar 20, 2024
2 parents 4b8cbbd + 83263c7 commit 6d7d1f9
Show file tree
Hide file tree
Showing 6 changed files with 405 additions and 6 deletions.
10 changes: 8 additions & 2 deletions decoders/yaml/yaml.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ const YAMLTagName = "yaml"

// Decoder is a decoder that knows how to work with text encoded in YAML.
type Decoder struct {
// Flatten any anonymous struct fields into the parent
FlattenAnonymous bool
}

// Decode reads from `r` and decodes what is read as YAML depositing the
Expand All @@ -29,9 +31,13 @@ func (d *Decoder) Decode(r io.Reader, t *dials.Type) (reflect.Value, error) {
return reflect.Value{}, fmt.Errorf("error reading YAML: %s", err)
}

manglers := []transform.Mangler{&tagformat.TagCopyingMangler{
SrcTag: common.DialsTagName, NewTag: YAMLTagName}}
if d.FlattenAnonymous {
manglers = append(manglers, transform.AnonymousFlattenMangler{})
}
tfmr := transform.NewTransformer(t.Type(),
&tagformat.TagCopyingMangler{
SrcTag: common.DialsTagName, NewTag: YAMLTagName},
manglers...,
)
val, tfmErr := tfmr.Translate()
if tfmErr != nil {
Expand Down
49 changes: 49 additions & 0 deletions decoders/yaml/yaml_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/vimeo/dials"
"github.com/vimeo/dials/sources/static"
)
Expand Down Expand Up @@ -145,6 +146,54 @@ func TestMoreDeeplyNestedYAML(t *testing.T) {
assert.Equal(t, c.DatabaseUser.SliceThing[1].Zizzle, "fizzlebat")
}

func TestAnonymousNestedYAML(t *testing.T) {
type OtherStuff struct {
Something string `dials:"something"`
IPAddress net.IP `dials:"ip_address"`
Timeout time.Duration `dials:"timeout"`
}
type testConfig struct {
DatabaseName string `dials:"database_name"`
DatabaseAddress string `dials:"database_address"`
DatabaseUser struct {
Username string `dials:"username"`
Password string `dials:"password"`
OtherStuff
} `dials:"database_user"`
}

yamlData := `{
"database_name": "something",
"database_address": "127.0.0.1",
"database_user": {
"username": "test",
"password": "password",
"something": "asdf",
"ip_address": "123.10.11.121",
"timeout": "10s",
}
}`

myConfig := &testConfig{}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
d, err := dials.Config(
ctx,
myConfig,
&static.StringSource{Data: yamlData, Decoder: &Decoder{FlattenAnonymous: true}},
)

require.NoError(t, err)
c := d.View()

assert.Equal(t, "something", c.DatabaseName)
assert.Equal(t, "test", c.DatabaseUser.Username)
assert.Equal(t, "password", c.DatabaseUser.Password)
assert.Equal(t, "asdf", c.DatabaseUser.Something)
assert.Equal(t, net.IPv4(123, 10, 11, 121), c.DatabaseUser.IPAddress)
assert.Equal(t, time.Duration(10*time.Second), c.DatabaseUser.Timeout)
}

func TestDecoderBadMarkup(t *testing.T) {
type testConfig struct {
Val1 string
Expand Down
45 changes: 41 additions & 4 deletions ez/ez.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,23 @@ type Params[T any] struct {
// Note that this does not affect the flags or environment variable
// naming. To manipulate flag naming, see [Params.FlagConfig].
FileFieldNameEncoder caseconversion.EncodeCasingFunc

// FlattenAnonymousFields inserts the AnonymousFlattenMangler into the
// chain so decoders that do not handle anonymous fields never see such
// things.
// (Currently only affects the yaml decoder)
FlattenAnonymousFields bool
}

// DecoderFactory should return the appropriate decoder based on the config file
// path that is passed as the string argument to DecoderFactory
type DecoderFactory func(string) dials.Decoder

// DecoderFactoryWithParams should return the appropriate decoder based on the config file
// path that is passed as the string argument to DecoderFactory
// Params may provide useful context/arguments
type DecoderFactoryWithParams[T any] func(string, Params[T]) dials.Decoder

// ConfigWithConfigPath is an interface config struct that supplies a
// ConfigPath() method to indicate which file to read as the config file once
// populated.
Expand Down Expand Up @@ -125,6 +136,25 @@ func fileSource(cfgPath string, decoder dials.Decoder, watch bool) (dials.Source
// The contents of cfg for the defaults
// cfg.ConfigPath() is evaluated on the stacked config with the file-contents omitted (using a "blank" source)
func ConfigFileEnvFlag[T any, TP ConfigWithConfigPath[T]](ctx context.Context, cfg TP, df DecoderFactory, params Params[T]) (*dials.Dials[T], error) {
dfp := func(path string, _ Params[T]) dials.Decoder {
return df(path)
}
return ConfigFileEnvFlagDecoderFactoryParams(ctx, cfg, dfp, params)

}

// ConfigFileEnvFlagDecoderFactoryParams takes advantage of the ConfigWithConfigPath cfg to indicate
// what file to read and uses the passed decoder.
// Configuration values provided by the returned Dials are the result of
// stacking the sources in the following order:
// - configuration file
// - environment variables
// - flags it registers with the standard library flags package
//
// The contents of cfg for the defaults
// cfg.ConfigPath() is evaluated on the stacked config with the file-contents omitted (using a "blank" source)
// It differs from ConfigFileEnvFlag by the signature of the decoder factory, (which requires a params struct in this function)
func ConfigFileEnvFlagDecoderFactoryParams[T any, TP ConfigWithConfigPath[T]](ctx context.Context, cfg TP, df DecoderFactoryWithParams[T], params Params[T]) (*dials.Dials[T], error) {
blank := sourcewrap.Blank{}

flagSrc := params.FlagSource
Expand Down Expand Up @@ -184,7 +214,7 @@ func ConfigFileEnvFlag[T any, TP ConfigWithConfigPath[T]](ctx context.Context, c
return d, nil
}

decoder := df(cfgPath)
decoder := df(cfgPath, params)
if decoder == nil {
return nil, fmt.Errorf("decoderFactory provided a nil decoder for path: %s", cfgPath)
}
Expand Down Expand Up @@ -243,7 +273,7 @@ func ConfigFileEnvFlag[T any, TP ConfigWithConfigPath[T]](ctx context.Context, c
// YAMLConfigEnvFlag takes advantage of the ConfigWithConfigPath cfg, thinly
// wraping ConfigFileEnvFlag with the decoder statically set to YAML.
func YAMLConfigEnvFlag[T any, TP ConfigWithConfigPath[T]](ctx context.Context, cfg TP, params Params[T]) (*dials.Dials[T], error) {
return ConfigFileEnvFlag(ctx, cfg, func(string) dials.Decoder { return &yaml.Decoder{} }, params)
return ConfigFileEnvFlag(ctx, cfg, func(string) dials.Decoder { return &yaml.Decoder{FlattenAnonymous: params.FlattenAnonymousFields} }, params)
}

// JSONConfigEnvFlag takes advantage of the ConfigWithConfigPath cfg, thinly
Expand All @@ -268,10 +298,17 @@ func TOMLConfigEnvFlag[T any, TP ConfigWithConfigPath[T]](ctx context.Context, c
// based on the extension of the filename or nil if there is not an appropriate
// mapping.
func DecoderFromExtension(path string) dials.Decoder {
return DecoderFromExtensionWithParams(path, Params[struct{}]{})
}

// DecoderFromExtension is a DecoderFactory that returns an appropriate decoder
// based on the extension of the filename or nil if there is not an appropriate
// mapping.
func DecoderFromExtensionWithParams[T any](path string, p Params[T]) dials.Decoder {
ext := filepath.Ext(path)
switch strings.ToLower(ext) {
case ".yaml", ".yml":
return &yaml.Decoder{}
return &yaml.Decoder{FlattenAnonymous: p.FlattenAnonymousFields}
case ".json":
return &json.Decoder{}
case ".toml":
Expand All @@ -289,5 +326,5 @@ func DecoderFromExtension(path string) dials.Decoder {
// file contents based on the file extension (from the limited set of JSON,
// Cue, YAML and TOML).
func FileExtensionDecoderConfigEnvFlag[T any, TP ConfigWithConfigPath[T]](ctx context.Context, cfg TP, params Params[T]) (*dials.Dials[T], error) {
return ConfigFileEnvFlag(ctx, cfg, DecoderFromExtension, params)
return ConfigFileEnvFlagDecoderFactoryParams(ctx, cfg, DecoderFromExtensionWithParams[T], params)
}
117 changes: 117 additions & 0 deletions transform/anonymous_flatten_mangler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package transform

import (
"reflect"
)

// AnonymousFlattenMangler hoists the fields from the types of anonymous
// struct-fields into the parent type. (working around decoders/sources that
// are unaware of anonymous fields)
// Note: this mangler is unaware of TextUnmarshaler implementations (it's tricky to do right when flattening).
// It should be combined with the TextUnmarshalerMangler if the prefered
// handling is to mask the other fields in that struct with the TextUnmarshaler
// implementation.
type AnonymousFlattenMangler struct{}

// Mangle is called for every field in a struct, and maps that to one or more output fields.
// Implementations that desire to leave fields unchanged should return
// the argument unchanged. (particularly useful if taking advantage of
// recursive evaluation)
func (a AnonymousFlattenMangler) Mangle(sf reflect.StructField) ([]reflect.StructField, error) {
// If it's not an anonymous field, return it as-is
if !sf.Anonymous {
return []reflect.StructField{sf}, nil
}
// Note: TranslateType already skips unexported fields

// anonymous/embedded fields can only be interfaces, pointers and structs
switch sf.Type.Kind() {
case reflect.Pointer:
// recurse with the pointer stripped off
sfInner := sf
sfInner.Type = sf.Type.Elem()
return a.Mangle(sfInner)
case reflect.Struct:
out := make([]reflect.StructField, 0, sf.Type.NumField())
for i := 0; i < sf.Type.NumField(); i++ {
innerField := sf.Type.Field(i)
if !innerField.IsExported() {
// skip unexported fields
continue
}
out = append(out, innerField)
}

return out, nil
default:
// leave everything else alone (there's nothing to promote)
// this includes interfaces and all other non-struct and
// non-pointer-to-struct types.
return []reflect.StructField{sf}, nil
}
}

// bool return value indicates whether all fields are nil (and as such, a nil value should be returned for pointer-types)
func (a AnonymousFlattenMangler) unmangleStruct(sf reflect.StructField, fvs []FieldValueTuple) (reflect.Value, bool) {
out := reflect.New(sf.Type).Elem()
if len(fvs) == 0 {
// no fields made it, just return out.
return out, true
}
fvsIdx := 0
allNil := true
for i := 0; i < sf.Type.NumField(); i++ {
oft := sf.Type.Field(i)
if oft.Name == fvs[fvsIdx].Field.Name {
out.Field(i).Set(fvs[fvsIdx].Value)
switch fvs[fvsIdx].Value.Kind() {
// check for nil-able types
case reflect.Pointer, reflect.Slice, reflect.Map, reflect.Interface, reflect.Chan:
if !fvs[fvsIdx].Value.IsZero() {
allNil = false
}
default:
// non-nilable field, just assume it's non-nil
// pointerification shold have made this nilable, though.
allNil = false
}
fvsIdx++
}
}
return out, allNil
}

// Unmangle is called for every source-field->mangled-field
// mapping-set, with the mangled-field and its populated value set. The
// implementation of Unmangle should return a reflect.Value that will
// be used for the next mangler or final struct value)
// Returned reflect.Value should be convertible to the field's type.
func (a AnonymousFlattenMangler) Unmangle(sf reflect.StructField, fvs []FieldValueTuple) (reflect.Value, error) {
if !sf.Anonymous {
// not anonymous, just forward the single value
return fvs[0].Value, nil
}
switch sf.Type.Kind() {
case reflect.Pointer:
// It's a pointer. check for nil; strip off the pointer and recurse
msf := sf
msf.Type = sf.Type.Elem()
v, allNil := a.unmangleStruct(msf, fvs)
if allNil {
return reflect.Zero(sf.Type), nil
}
return v.Addr(), nil
case reflect.Struct:
out, _ := a.unmangleStruct(sf, fvs)
return out, nil
default:
// not a struct-typed anonymous field, just forward up the chain
return fvs[0].Value, nil
}
}

// ShouldRecurse is called after Mangle for each field so nested struct
// fields get iterated over after any transformation done by Mangle().
func (a AnonymousFlattenMangler) ShouldRecurse(_ reflect.StructField) bool {
return true
}
Loading

0 comments on commit 6d7d1f9

Please sign in to comment.