diff --git a/callbacks.go b/callbacks.go index 2d6507a2..c84f215c 100644 --- a/callbacks.go +++ b/callbacks.go @@ -6,7 +6,10 @@ import ( "strings" ) -type bindings map[reflect.Type]func() (reflect.Value, error) +// A map of type to function that returns a value of that type. +// +// The function should have the signature func(...) (T, error). Arguments are recursively resolved. +type bindings map[reflect.Type]any func (b bindings) String() string { out := []string{} @@ -19,32 +22,23 @@ func (b bindings) String() string { func (b bindings) add(values ...interface{}) bindings { for _, v := range values { v := v - b[reflect.TypeOf(v)] = func() (reflect.Value, error) { return reflect.ValueOf(v), nil } + b[reflect.TypeOf(v)] = func() (any, error) { return v, nil } } return b } func (b bindings) addTo(impl, iface interface{}) { - valueOf := reflect.ValueOf(impl) - b[reflect.TypeOf(iface).Elem()] = func() (reflect.Value, error) { return valueOf, nil } + b[reflect.TypeOf(iface).Elem()] = func() (any, error) { return impl, nil } } func (b bindings) addProvider(provider interface{}) error { pv := reflect.ValueOf(provider) t := pv.Type() - if t.Kind() != reflect.Func || t.NumIn() != 0 || t.NumOut() != 2 || t.Out(1) != reflect.TypeOf((*error)(nil)).Elem() { - return fmt.Errorf("%T must be a function with the signature func()(T, error)", provider) + if t.Kind() != reflect.Func || t.NumOut() != 2 || t.Out(1) != reflect.TypeOf((*error)(nil)).Elem() { + return fmt.Errorf("%T must be a function with the signature func(...)(T, error)", provider) } rt := pv.Type().Out(0) - b[rt] = func() (reflect.Value, error) { - out := pv.Call(nil) - errv := out[1] - var err error - if !errv.IsNil() { - err = errv.Interface().(error) //nolint - } - return out[0], err - } + b[rt] = provider return nil } @@ -101,15 +95,19 @@ func callAnyFunction(f reflect.Value, bindings bindings) (out []any, err error) t := f.Type() for i := 0; i < t.NumIn(); i++ { pt := t.In(i) - if argf, ok := bindings[pt]; ok { - argv, err := argf() - if err != nil { - return nil, err - } - in = append(in, argv) - } else { + argf, ok := bindings[pt] + if !ok { return nil, fmt.Errorf("couldn't find binding of type %s for parameter %d of %s(), use kong.Bind(%s)", pt, i, t, pt) } + // Recursively resolve binding functions. + argv, err := callAnyFunction(reflect.ValueOf(argf), bindings) + if err != nil { + return nil, fmt.Errorf("%s: %w", pt, err) + } + if ferrv := reflect.ValueOf(argv[len(argv)-1]); ferrv.IsValid() && !ferrv.IsNil() { + return nil, ferrv.Interface().(error) //nolint:forcetypeassert + } + in = append(in, reflect.ValueOf(argv[0])) } outv := f.Call(in) out = make([]any, len(outv)) diff --git a/options.go b/options.go index 3ab383b5..3bc991b9 100644 --- a/options.go +++ b/options.go @@ -208,7 +208,11 @@ func BindTo(impl, iface interface{}) Option { }) } -// BindToProvider allows binding of provider functions. +// BindToProvider binds an injected value to a provider function. +// +// The provider function must have the signature: +// +// func() (interface{}, error) // // This is useful when the Run() function of different commands require different values that may // not all be initialisable from the main() function. diff --git a/options_test.go b/options_test.go index ca41720a..f6971f8e 100644 --- a/options_test.go +++ b/options_test.go @@ -89,11 +89,13 @@ func TestCallbackCustomError(t *testing.T) { } type bindToProviderCLI struct { + Filled bool `default:"true"` Called bool Cmd bindToProviderCmd `cmd:""` } type boundThing struct { + Filled bool } type bindToProviderCmd struct{} @@ -105,7 +107,10 @@ func (*bindToProviderCmd) Run(cli *bindToProviderCLI, b *boundThing) error { func TestBindToProvider(t *testing.T) { var cli bindToProviderCLI - app, err := New(&cli, BindToProvider(func() (*boundThing, error) { return &boundThing{}, nil })) + app, err := New(&cli, BindToProvider(func(cli *bindToProviderCLI) (*boundThing, error) { + assert.True(t, cli.Filled, "CLI struct should have already been populated by Kong") + return &boundThing{Filled: cli.Filled}, nil + })) assert.NoError(t, err) ctx, err := app.Parse([]string{"cmd"}) assert.NoError(t, err)