Skip to content

Commit

Permalink
Move readEnrollSecret into knapsack
Browse files Browse the repository at this point in the history
  • Loading branch information
RebeccaMahany committed Feb 15, 2024
1 parent 404cf34 commit 8faa706
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 30 deletions.
20 changes: 20 additions & 0 deletions ee/agent/knapsack/knapsack.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
package knapsack

import (
"bytes"
"context"
"errors"
"fmt"
"os"

"log/slog"

Expand Down Expand Up @@ -144,3 +148,19 @@ func (k *knapsack) LatestOsquerydPath(ctx context.Context) string {

return latestBin.Path
}

func (k *knapsack) ReadEnrollSecret() (string, error) {
if k.EnrollSecret() != "" {
return k.EnrollSecret(), nil
}

if k.EnrollSecretPath() != "" {
content, err := os.ReadFile(k.EnrollSecretPath())
if err != nil {
return "", fmt.Errorf("could not read enroll secret path %s: %w", k.EnrollSecretPath(), err)
}
return string(bytes.TrimSpace(content)), nil
}

return "", errors.New("enroll secret not set")
}
2 changes: 2 additions & 0 deletions ee/agent/types/knapsack.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,6 @@ type Knapsack interface {
Slogger
// LatestOsquerydPath finds the path to the latest osqueryd binary, after accounting for updates.
LatestOsquerydPath(ctx context.Context) string
// ReadEnrollSecret returns the enroll secret value, checking in various locations.
ReadEnrollSecret() (string, error)
}
24 changes: 24 additions & 0 deletions ee/agent/types/mocks/knapsack.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

22 changes: 1 addition & 21 deletions pkg/osquery/extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ func (e *Extension) Enroll(ctx context.Context) (string, bool, error) {
)
span.AddEvent("starting_enrollment")

enrollSecret, err := e.readEnrollSecret(ctx)
enrollSecret, err := e.knapsack.ReadEnrollSecret()
if err != nil {
return "", true, fmt.Errorf("could not read enroll secret: %w", err)
}
Expand Down Expand Up @@ -483,26 +483,6 @@ func (e *Extension) enrolled() bool {
return e.NodeKey != ""
}

// readEnrollSecret checks knapsack's flags to find the correct enroll secret location.
func (e *Extension) readEnrollSecret(ctx context.Context) (string, error) {
_, span := traces.StartSpan(ctx)
defer span.End()

if e.knapsack.EnrollSecret() != "" {
return e.knapsack.EnrollSecret(), nil
}

if e.knapsack.EnrollSecretPath() != "" {
content, err := os.ReadFile(e.knapsack.EnrollSecretPath())
if err != nil {
return "", fmt.Errorf("could not read enroll secret path %s: %w", e.knapsack.EnrollSecretPath(), err)
}
return string(bytes.TrimSpace(content)), nil
}

return "", errors.New("enroll secret not set")
}

// RequireReenroll clears the existing node key information, ensuring that the
// next call to Enroll will cause the enrollment process to take place.
func (e *Extension) RequireReenroll(ctx context.Context) {
Expand Down
17 changes: 8 additions & 9 deletions pkg/osquery/extension_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func makeKnapsack(t *testing.T, db *bbolt.DB) types.Knapsack {
m.On("LatestOsquerydPath", testifymock.Anything).Maybe().Return("")
m.On("ConfigStore").Return(storageci.NewStore(t, log.NewNopLogger(), storage.ConfigStore.String()))
m.On("Slogger").Return(multislogger.New().Logger)
m.On("EnrollSecret").Maybe().Return("enroll_secret")
m.On("ReadEnrollSecret").Maybe().Return("enroll_secret", nil)
return m
}

Expand All @@ -68,7 +68,7 @@ func TestNewExtensionEmptyEnrollSecret(t *testing.T) {
m.On("LatestOsquerydPath", testifymock.Anything).Maybe().Return("")
m.On("ConfigStore").Return(storageci.NewStore(t, log.NewNopLogger(), storage.ConfigStore.String()))
m.On("Slogger").Return(multislogger.New().Logger)
m.On("EnrollSecret").Maybe().Return("")
m.On("ReadEnrollSecret").Maybe().Return("", errors.New("test"))

// We should be able to make an extension despite an empty enroll secret
e, err := NewExtension(context.TODO(), &mock.KolideService{}, m, ExtensionOpts{})
Expand Down Expand Up @@ -217,7 +217,7 @@ func TestExtensionEnroll(t *testing.T) {
k.On("ConfigStore").Return(storageci.NewStore(t, log.NewNopLogger(), storage.ConfigStore.String()))
k.On("Slogger").Return(multislogger.New().Logger)
expectedEnrollSecret := "foo_secret"
k.On("EnrollSecret").Maybe().Return(expectedEnrollSecret)
k.On("ReadEnrollSecret").Maybe().Return(expectedEnrollSecret, nil)

e, err := NewExtension(context.TODO(), m, k, ExtensionOpts{})
require.Nil(t, err)
Expand Down Expand Up @@ -355,8 +355,7 @@ func TestGenerateConfigs_CannotEnrollYet(t *testing.T) {
k.On("LatestOsquerydPath", testifymock.Anything).Maybe().Return("")
k.On("ConfigStore").Return(storageci.NewStore(t, log.NewNopLogger(), storage.ConfigStore.String()))
k.On("Slogger").Return(multislogger.New().Logger)
k.On("EnrollSecret").Return("")
k.On("EnrollSecretPath").Return("")
k.On("ReadEnrollSecret").Maybe().Return("", errors.New("test"))

e, err := NewExtension(context.TODO(), s, k, ExtensionOpts{})
require.Nil(t, err)
Expand Down Expand Up @@ -533,7 +532,7 @@ func TestExtensionWriteBufferedLogsEmpty(t *testing.T) {
k.On("ConfigStore").Return(storageci.NewStore(t, log.NewNopLogger(), storage.ConfigStore.String()))
k.On("BboltDB").Return(db)
k.On("Slogger").Return(multislogger.New().Logger).Maybe()
k.On("EnrollSecret").Maybe().Return("enroll_secret")
k.On("ReadEnrollSecret").Maybe().Return("enroll_secret", nil)

e, err := NewExtension(context.TODO(), m, k, ExtensionOpts{})
require.Nil(t, err)
Expand Down Expand Up @@ -572,7 +571,7 @@ func TestExtensionWriteBufferedLogs(t *testing.T) {
k.On("ConfigStore").Return(storageci.NewStore(t, log.NewNopLogger(), storage.ConfigStore.String()))
k.On("BboltDB").Return(db)
k.On("Slogger").Return(multislogger.New().Logger).Maybe()
k.On("EnrollSecret").Maybe().Return("enroll_secret")
k.On("ReadEnrollSecret").Maybe().Return("enroll_secret", nil)

e, err := NewExtension(context.TODO(), m, k, ExtensionOpts{})
require.Nil(t, err)
Expand Down Expand Up @@ -642,7 +641,7 @@ func TestExtensionWriteBufferedLogsEnrollmentInvalid(t *testing.T) {
k.On("OsquerydPath").Maybe().Return("")
k.On("LatestOsquerydPath", testifymock.Anything).Maybe().Return("")
k.On("Slogger").Return(multislogger.New().Logger)
k.On("EnrollSecret").Maybe().Return("enroll_secret")
k.On("ReadEnrollSecret").Maybe().Return("enroll_secret", nil)

e, err := NewExtension(context.TODO(), m, k, ExtensionOpts{})
require.Nil(t, err)
Expand Down Expand Up @@ -1039,7 +1038,7 @@ func TestExtensionGetQueriesEnrollmentInvalid(t *testing.T) {
k.On("OsquerydPath").Maybe().Return("")
k.On("LatestOsquerydPath", testifymock.Anything).Maybe().Return("")
k.On("Slogger").Return(multislogger.New().Logger)
k.On("EnrollSecret").Return("enroll_secret")
k.On("ReadEnrollSecret").Return("enroll_secret", nil)

e, err := NewExtension(context.TODO(), m, k, ExtensionOpts{})
require.Nil(t, err)
Expand Down

0 comments on commit 8faa706

Please sign in to comment.