Skip to content

Commit f5626a8

Browse files
committed
fixes context override
A customer reported flags being ignored even when they were being provided via CLI arguments. The behavior of overriding current context via CLI arguments was inconsistent across the codebase. This centralizes the logic so all the various commands use the same logic.
1 parent c03fcf6 commit f5626a8

File tree

9 files changed

+286
-78
lines changed

9 files changed

+286
-78
lines changed

internal/client/client.go

+100-16
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import (
1212
"github.com/mitchellh/go-homedir"
1313
"github.com/rs/zerolog/log"
1414
"github.com/spf13/cobra"
15-
grpc "google.golang.org/grpc"
15+
"google.golang.org/grpc"
1616
"google.golang.org/grpc/credentials/insecure"
1717

1818
zgrpcutil "github.com/authzed/zed/internal/grpcutil"
@@ -28,20 +28,17 @@ type Client interface {
2828
}
2929

3030
// NewClient defines an (overridable) means of creating a new client.
31-
var NewClient = newGRPCClient
31+
var (
32+
NewClient = newClientForCurrentContext
33+
NewClientForContext = newClientForContext
34+
)
3235

33-
func newGRPCClient(cmd *cobra.Command) (Client, error) {
36+
func newClientForCurrentContext(cmd *cobra.Command) (Client, error) {
3437
configStore, secretStore := DefaultStorage()
35-
token, err := storage.DefaultToken(
36-
cobrautil.MustGetString(cmd, "endpoint"),
37-
cobrautil.MustGetString(cmd, "token"),
38-
configStore,
39-
secretStore,
40-
)
38+
token, err := GetCurrentTokenWithCLIOverride(cmd, configStore, secretStore)
4139
if err != nil {
4240
return nil, err
4341
}
44-
log.Trace().Interface("token", token).Send()
4542

4643
dialOpts, err := DialOptsFromFlags(cmd, token)
4744
if err != nil {
@@ -56,28 +53,115 @@ func newGRPCClient(cmd *cobra.Command) (Client, error) {
5653
return client, err
5754
}
5855

56+
func newClientForContext(cmd *cobra.Command, contextName string, secretStore storage.SecretStore) (*authzed.Client, error) {
57+
currentToken, err := storage.GetToken(contextName, secretStore)
58+
if err != nil {
59+
return nil, err
60+
}
61+
62+
token, err := GetTokenWithCLIOverride(cmd, currentToken)
63+
if err != nil {
64+
return nil, err
65+
}
66+
67+
dialOpts, err := DialOptsFromFlags(cmd, token)
68+
if err != nil {
69+
return nil, err
70+
}
71+
72+
return authzed.NewClient(token.Endpoint, dialOpts...)
73+
}
74+
75+
// GetCurrentTokenWithCLIOverride returns the current token, but overridden by any parameter specified via CLI args
76+
func GetCurrentTokenWithCLIOverride(cmd *cobra.Command, configStore storage.ConfigStore, secretStore storage.SecretStore) (storage.Token, error) {
77+
token, err := storage.CurrentToken(
78+
configStore,
79+
secretStore,
80+
)
81+
if err != nil {
82+
return storage.Token{}, err
83+
}
84+
85+
return GetTokenWithCLIOverride(cmd, token)
86+
}
87+
88+
// GetTokenWithCLIOverride returns the provided token, but overridden by any parameter specified explicitly via command
89+
// flags
90+
func GetTokenWithCLIOverride(cmd *cobra.Command, token storage.Token) (storage.Token, error) {
91+
overrideToken, err := tokenFromCli(cmd)
92+
if err != nil {
93+
return storage.Token{}, err
94+
}
95+
96+
result, err := storage.TokenWithOverride(
97+
overrideToken,
98+
token,
99+
)
100+
if err != nil {
101+
return storage.Token{}, err
102+
}
103+
104+
log.Trace().Bool("context-override-via-cli", overrideToken.AnyValue()).Interface("context", result).Send()
105+
return result, nil
106+
}
107+
108+
func tokenFromCli(cmd *cobra.Command) (storage.Token, error) {
109+
certPath := cobrautil.MustGetStringExpanded(cmd, "certificate-path")
110+
var certBytes []byte
111+
var err error
112+
if certPath != "" {
113+
certBytes, err = os.ReadFile(certPath)
114+
if err != nil {
115+
return storage.Token{}, fmt.Errorf("failed to read ceritficate: %w", err)
116+
}
117+
}
118+
119+
explicitInsecure := cmd.Flags().Changed("insecure")
120+
var notSecure *bool
121+
if explicitInsecure {
122+
i := cobrautil.MustGetBool(cmd, "insecure")
123+
notSecure = &i
124+
}
125+
126+
explicitNoVerifyCA := cmd.Flags().Changed("no-verify-ca")
127+
var notVerifyCA *bool
128+
if explicitNoVerifyCA {
129+
nvc := cobrautil.MustGetBool(cmd, "no-verify-ca")
130+
notVerifyCA = &nvc
131+
}
132+
overrideToken := storage.Token{
133+
APIToken: cobrautil.MustGetString(cmd, "token"),
134+
Endpoint: cobrautil.MustGetString(cmd, "endpoint"),
135+
Insecure: notSecure,
136+
NoVerifyCA: notVerifyCA,
137+
CACert: certBytes,
138+
}
139+
return overrideToken, nil
140+
}
141+
59142
// DefaultStorage returns the default configured config store and secret store.
60143
func DefaultStorage() (storage.ConfigStore, storage.SecretStore) {
61144
var home string
62145
if xdg := os.Getenv("XDG_CONFIG_HOME"); xdg != "" {
63146
home = filepath.Join(xdg, "zed")
64147
} else {
65-
homedir, _ := homedir.Dir()
66-
home = filepath.Join(homedir, ".zed")
148+
hmdir, _ := homedir.Dir()
149+
home = filepath.Join(hmdir, ".zed")
67150
}
68151
return &storage.JSONConfigStore{ConfigPath: home},
69152
&storage.KeychainSecretStore{ConfigPath: home}
70153
}
71154

72-
func certOption(cmd *cobra.Command, token storage.Token) (opt grpc.DialOption, err error) {
155+
func certOption(token storage.Token) (opt grpc.DialOption, err error) {
73156
verification := grpcutil.VerifyCA
74-
if cobrautil.MustGetBool(cmd, "no-verify-ca") || token.HasNoVerifyCA() {
157+
if token.HasNoVerifyCA() {
75158
verification = grpcutil.SkipVerifyCA
76159
}
77160

78161
if certBytes, ok := token.Certificate(); ok {
79162
return grpcutil.WithCustomCertBytes(verification, certBytes)
80163
}
164+
81165
return grpcutil.WithSystemCerts(verification)
82166
}
83167

@@ -96,12 +180,12 @@ func DialOptsFromFlags(cmd *cobra.Command, token storage.Token) ([]grpc.DialOpti
96180
grpc.WithChainStreamInterceptor(zgrpcutil.StreamLogDispatchTrailers),
97181
}
98182

99-
if cobrautil.MustGetBool(cmd, "insecure") || (token.IsInsecure()) {
183+
if token.IsInsecure() {
100184
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
101185
opts = append(opts, grpcutil.WithInsecureBearerToken(token.APIToken))
102186
} else {
103187
opts = append(opts, grpcutil.WithBearerToken(token.APIToken))
104-
certOpt, err := certOption(cmd, token)
188+
certOpt, err := certOption(token)
105189
if err != nil {
106190
return nil, fmt.Errorf("failed to configure TLS cert: %w", err)
107191
}

internal/client/client_test.go

+62
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
package client_test
2+
3+
import (
4+
"os"
5+
"testing"
6+
7+
"github.com/authzed/zed/internal/client"
8+
"github.com/authzed/zed/internal/storage"
9+
zedtesting "github.com/authzed/zed/internal/testing"
10+
11+
"github.com/stretchr/testify/require"
12+
)
13+
14+
func TestGetTokenWithCLIOverride(t *testing.T) {
15+
testCert, err := os.CreateTemp("", "")
16+
require.NoError(t, err)
17+
_, err = testCert.Write([]byte("hi"))
18+
require.NoError(t, err)
19+
cmd := zedtesting.CreateTestCobraCommandWithFlagValue(t,
20+
zedtesting.StringFlag{FlagName: "token", FlagValue: "t1", Changed: true},
21+
zedtesting.StringFlag{FlagName: "certificate-path", FlagValue: testCert.Name(), Changed: true},
22+
zedtesting.StringFlag{FlagName: "endpoint", FlagValue: "e1", Changed: true},
23+
zedtesting.BoolFlag{FlagName: "insecure", FlagValue: true, Changed: true},
24+
zedtesting.BoolFlag{FlagName: "no-verify-ca", FlagValue: true, Changed: true},
25+
)
26+
27+
bTrue := true
28+
bFalse := false
29+
30+
// cli args take precedence when defined
31+
to, err := client.GetTokenWithCLIOverride(cmd, storage.Token{})
32+
require.NoError(t, err)
33+
require.True(t, to.AnyValue())
34+
require.Equal(t, "t1", to.APIToken)
35+
require.Equal(t, "e1", to.Endpoint)
36+
require.Equal(t, []byte("hi"), to.CACert)
37+
require.Equal(t, &bTrue, to.Insecure)
38+
require.Equal(t, &bTrue, to.NoVerifyCA)
39+
40+
// storage token takes precedence when defined
41+
cmd = zedtesting.CreateTestCobraCommandWithFlagValue(t,
42+
zedtesting.StringFlag{FlagName: "token", FlagValue: "", Changed: false},
43+
zedtesting.StringFlag{FlagName: "certificate-path", FlagValue: "", Changed: false},
44+
zedtesting.StringFlag{FlagName: "endpoint", FlagValue: "", Changed: false},
45+
zedtesting.BoolFlag{FlagName: "insecure", FlagValue: true, Changed: false},
46+
zedtesting.BoolFlag{FlagName: "no-verify-ca", FlagValue: true, Changed: false},
47+
)
48+
to, err = client.GetTokenWithCLIOverride(cmd, storage.Token{
49+
APIToken: "t2",
50+
Endpoint: "e2",
51+
CACert: []byte("bye"),
52+
Insecure: &bFalse,
53+
NoVerifyCA: &bFalse,
54+
})
55+
require.NoError(t, err)
56+
require.True(t, to.AnyValue())
57+
require.Equal(t, "t2", to.APIToken)
58+
require.Equal(t, "e2", to.Endpoint)
59+
require.Equal(t, []byte("bye"), to.CACert)
60+
require.Equal(t, &bFalse, to.Insecure)
61+
require.Equal(t, &bFalse, to.NoVerifyCA)
62+
}

internal/cmd/schema.go

+3-19
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import (
99
"strings"
1010

1111
v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
12-
"github.com/authzed/authzed-go/v1"
1312
"github.com/authzed/spicedb/pkg/schemadsl/compiler"
1413
"github.com/authzed/spicedb/pkg/schemadsl/generator"
1514
"github.com/authzed/spicedb/pkg/schemadsl/input"
@@ -23,7 +22,6 @@ import (
2322
"github.com/authzed/zed/internal/client"
2423
"github.com/authzed/zed/internal/commands"
2524
"github.com/authzed/zed/internal/console"
26-
"github.com/authzed/zed/internal/storage"
2725
)
2826

2927
func registerAdditionalSchemaCmds(schemaCmd *cobra.Command) {
@@ -52,28 +50,14 @@ var schemaCopyCmd = &cobra.Command{
5250
RunE: schemaCopyCmdFunc,
5351
}
5452

55-
// TODO(jschorr): support this in the client package
56-
func clientForContext(cmd *cobra.Command, contextName string, secretStore storage.SecretStore) (*authzed.Client, error) {
57-
token, err := storage.GetToken(contextName, secretStore)
58-
if err != nil {
59-
return nil, err
60-
}
61-
log.Trace().Interface("token", token).Send()
62-
63-
dialOpts, err := client.DialOptsFromFlags(cmd, token)
64-
if err != nil {
65-
return nil, err
66-
}
67-
return authzed.NewClient(token.Endpoint, dialOpts...)
68-
}
69-
7053
func schemaCopyCmdFunc(cmd *cobra.Command, args []string) error {
7154
_, secretStore := client.DefaultStorage()
72-
srcClient, err := clientForContext(cmd, args[0], secretStore)
55+
srcClient, err := client.NewClientForContext(cmd, args[0], secretStore)
7356
if err != nil {
7457
return err
7558
}
76-
destClient, err := clientForContext(cmd, args[1], secretStore)
59+
60+
destClient, err := client.NewClientForContext(cmd, args[1], secretStore)
7761
if err != nil {
7862
return err
7963
}

internal/cmd/version.go

+2-20
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,12 @@ import (
99
"github.com/gookit/color"
1010
"github.com/jzelinskie/cobrautil/v2"
1111
"github.com/mattn/go-isatty"
12-
"github.com/rs/zerolog/log"
1312
"github.com/spf13/cobra"
1413
"google.golang.org/grpc"
1514
"google.golang.org/grpc/metadata"
1615

1716
"github.com/authzed/zed/internal/client"
1817
"github.com/authzed/zed/internal/console"
19-
"github.com/authzed/zed/internal/storage"
2018
)
2119

2220
func versionCmdFunc(cmd *cobra.Command, _ []string) error {
@@ -26,14 +24,9 @@ func versionCmdFunc(cmd *cobra.Command, _ []string) error {
2624

2725
includeRemoteVersion := cobrautil.MustGetBool(cmd, "include-remote-version")
2826
hasContext := false
29-
configStore, secretStore := client.DefaultStorage()
3027
if includeRemoteVersion {
31-
_, err := storage.DefaultToken(
32-
cobrautil.MustGetString(cmd, "endpoint"),
33-
cobrautil.MustGetString(cmd, "token"),
34-
configStore,
35-
secretStore,
36-
)
28+
configStore, secretStore := client.DefaultStorage()
29+
_, err := client.GetCurrentTokenWithCLIOverride(cmd, configStore, secretStore)
3730
hasContext = err == nil
3831
}
3932

@@ -45,17 +38,6 @@ func versionCmdFunc(cmd *cobra.Command, _ []string) error {
4538
console.Println(cobrautil.UsageVersion("zed", cobrautil.MustGetBool(cmd, "include-deps")))
4639

4740
if hasContext && includeRemoteVersion {
48-
token, err := storage.DefaultToken(
49-
cobrautil.MustGetString(cmd, "endpoint"),
50-
cobrautil.MustGetString(cmd, "token"),
51-
configStore,
52-
secretStore,
53-
)
54-
if err != nil {
55-
return err
56-
}
57-
log.Trace().Interface("token", token).Send()
58-
5941
client, err := client.NewClient(cmd)
6042
if err != nil {
6143
return err

0 commit comments

Comments
 (0)