Skip to content

Commit

Permalink
Instrument GitHub source with a UnitReporter (#3284)
Browse files Browse the repository at this point in the history
* Fix GitHub integration test

* Instrument GitHub source with a UnitReporter

The reporter is currently unused, but is the first step to support
scanning while enumerating.

* Update GitHub unit tests
  • Loading branch information
mcastorina authored Sep 12, 2024
1 parent 0cb8723 commit e89190f
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 47 deletions.
81 changes: 59 additions & 22 deletions pkg/sources/github/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,32 @@ type Source struct {
sources.CommonSourceUnitUnmarshaller
}

// --------------------------------------------------------------------------------
// RepoUnit and GistUnit are implementations of SourceUnit used during
// enumeration. The different types aren't strictly necessary, but are a bit
// more explicit and allow type checking/safety.

var _ sources.SourceUnit = (*RepoUnit)(nil)
var _ sources.SourceUnit = (*GistUnit)(nil)

type RepoUnit struct {
name string
url string
}

func (r RepoUnit) SourceUnitID() (string, sources.SourceUnitKind) { return r.url, "repo" }
func (r RepoUnit) Display() string { return r.name }

type GistUnit struct {
name string
url string
}

func (g GistUnit) SourceUnitID() (string, sources.SourceUnitKind) { return g.url, "gist" }
func (g GistUnit) Display() string { return g.name }

// --------------------------------------------------------------------------------

// WithCustomContentWriter sets the useCustomContentWriter flag on the source.
func (s *Source) WithCustomContentWriter() { s.useCustomContentWriter = true }

Expand Down Expand Up @@ -313,25 +339,30 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, tar
}

func (s *Source) enumerate(ctx context.Context) error {
// Create a reporter that does nothing for now.
noopReporter := sources.VisitorReporter{
VisitUnit: func(ctx context.Context, su sources.SourceUnit) error {
return nil
},
}
// I'm not wild about switching on the connector type here (as opposed to dispatching to the connector itself) but
// this felt like a compromise that allowed me to isolate connection logic without rewriting the entire source.
switch c := s.connector.(type) {
case *appConnector:
if err := s.enumerateWithApp(ctx, c.InstallationClient()); err != nil {
if err := s.enumerateWithApp(ctx, c.InstallationClient(), noopReporter); err != nil {
return err
}
case *basicAuthConnector:
if err := s.enumerateBasicAuth(ctx); err != nil {
if err := s.enumerateBasicAuth(ctx, noopReporter); err != nil {
return err
}
case *tokenConnector:
if err := s.enumerateWithToken(ctx, c.IsGithubEnterprise()); err != nil {
if err := s.enumerateWithToken(ctx, c.IsGithubEnterprise(), noopReporter); err != nil {
return err
}
case *unauthenticatedConnector:
s.enumerateUnauthenticated(ctx)
s.enumerateUnauthenticated(ctx, noopReporter)
}

s.repos = make([]string, 0, s.filteredRepoCache.Count())

RepoLoop:
Expand Down Expand Up @@ -393,15 +424,17 @@ RepoLoop:
return nil
}

func (s *Source) enumerateBasicAuth(ctx context.Context) error {
func (s *Source) enumerateBasicAuth(ctx context.Context, reporter sources.UnitReporter) error {
for _, org := range s.orgsCache.Keys() {
orgCtx := context.WithValue(ctx, "account", org)
userType, err := s.getReposByOrgOrUser(ctx, org)
userType, err := s.getReposByOrgOrUser(ctx, org, reporter)
if err != nil {
orgCtx.Logger().Error(err, "error fetching repos for org or user")
continue
}

// TODO: This modifies s.memberCache but it doesn't look like
// we do anything with it.
if userType == organization && s.conn.ScanUsers {
if err := s.addMembersByOrg(ctx, org); err != nil {
orgCtx.Logger().Error(err, "Unable to add members by org")
Expand All @@ -412,14 +445,14 @@ func (s *Source) enumerateBasicAuth(ctx context.Context) error {
return nil
}

func (s *Source) enumerateUnauthenticated(ctx context.Context) {
func (s *Source) enumerateUnauthenticated(ctx context.Context, reporter sources.UnitReporter) {
if s.orgsCache.Count() > unauthGithubOrgRateLimt {
ctx.Logger().Info("You may experience rate limiting when using the unauthenticated GitHub api. Consider using an authenticated scan instead.")
}

for _, org := range s.orgsCache.Keys() {
orgCtx := context.WithValue(ctx, "account", org)
userType, err := s.getReposByOrgOrUser(ctx, org)
userType, err := s.getReposByOrgOrUser(ctx, org, reporter)
if err != nil {
orgCtx.Logger().Error(err, "error fetching repos for org or user")
continue
Expand All @@ -431,7 +464,7 @@ func (s *Source) enumerateUnauthenticated(ctx context.Context) {
}
}

func (s *Source) enumerateWithToken(ctx context.Context, isGithubEnterprise bool) error {
func (s *Source) enumerateWithToken(ctx context.Context, isGithubEnterprise bool, reporter sources.UnitReporter) error {
ctx.Logger().V(1).Info("Enumerating with token")

var ghUser *github.User
Expand All @@ -450,10 +483,10 @@ func (s *Source) enumerateWithToken(ctx context.Context, isGithubEnterprise bool
specificScope := len(s.repos) > 0 || s.orgsCache.Count() > 0
if !specificScope {
// Enumerate the user's orgs and repos if none were specified.
if err := s.getReposByUser(ctx, ghUser.GetLogin()); err != nil {
if err := s.getReposByUser(ctx, ghUser.GetLogin(), reporter); err != nil {
ctx.Logger().Error(err, "Unable to fetch repos for the current user", "user", ghUser.GetLogin())
}
if err := s.addUserGistsToCache(ctx, ghUser.GetLogin()); err != nil {
if err := s.addUserGistsToCache(ctx, ghUser.GetLogin(), reporter); err != nil {
ctx.Logger().Error(err, "Unable to fetch gists for the current user", "user", ghUser.GetLogin())
}

Expand All @@ -469,7 +502,7 @@ func (s *Source) enumerateWithToken(ctx context.Context, isGithubEnterprise bool
if len(s.orgsCache.Keys()) > 0 {
for _, org := range s.orgsCache.Keys() {
orgCtx := context.WithValue(ctx, "account", org)
userType, err := s.getReposByOrgOrUser(ctx, org)
userType, err := s.getReposByOrgOrUser(ctx, org, reporter)
if err != nil {
orgCtx.Logger().Error(err, "Unable to fetch repos for org or user")
continue
Expand All @@ -484,17 +517,17 @@ func (s *Source) enumerateWithToken(ctx context.Context, isGithubEnterprise bool

if s.conn.ScanUsers && len(s.memberCache) > 0 {
ctx.Logger().Info("Fetching repos for org members", "org_count", s.orgsCache.Count(), "member_count", len(s.memberCache))
s.addReposForMembers(ctx)
s.addReposForMembers(ctx, reporter)
}
}

return nil
}

func (s *Source) enumerateWithApp(ctx context.Context, installationClient *github.Client) error {
func (s *Source) enumerateWithApp(ctx context.Context, installationClient *github.Client, reporter sources.UnitReporter) error {
// If no repos were provided, enumerate them.
if len(s.repos) == 0 {
if err := s.getReposByApp(ctx); err != nil {
if err := s.getReposByApp(ctx, reporter); err != nil {
return err
}

Expand All @@ -505,12 +538,13 @@ func (s *Source) enumerateWithApp(ctx context.Context, installationClient *githu
return err
}
ctx.Logger().Info("Scanning repos", "org_members", len(s.memberCache))
// TODO: Replace loop below with a call to s.addReposForMembers(ctx, reporter)
for member := range s.memberCache {
logger := ctx.Logger().WithValues("member", member)
if err := s.addUserGistsToCache(ctx, member); err != nil {
if err := s.addUserGistsToCache(ctx, member, reporter); err != nil {
logger.Error(err, "error fetching gists by user")
}
if err := s.getReposByUser(ctx, member); err != nil {
if err := s.getReposByUser(ctx, member, reporter); err != nil {
logger.Error(err, "error fetching repos by user")
}
}
Expand Down Expand Up @@ -721,21 +755,21 @@ func (s *Source) handleRateLimit(ctx context.Context, errIn error) bool {
return true
}

func (s *Source) addReposForMembers(ctx context.Context) {
func (s *Source) addReposForMembers(ctx context.Context, reporter sources.UnitReporter) {
ctx.Logger().Info("Fetching repos from members", "members", len(s.memberCache))
for member := range s.memberCache {
if err := s.addUserGistsToCache(ctx, member); err != nil {
if err := s.addUserGistsToCache(ctx, member, reporter); err != nil {
ctx.Logger().Info("Unable to fetch gists by user", "user", member, "error", err)
}
if err := s.getReposByUser(ctx, member); err != nil {
if err := s.getReposByUser(ctx, member, reporter); err != nil {
ctx.Logger().Info("Unable to fetch repos by user", "user", member, "error", err)
}
}
}

// addUserGistsToCache collects all the gist urls for a given user,
// and adds them to the filteredRepoCache.
func (s *Source) addUserGistsToCache(ctx context.Context, user string) error {
func (s *Source) addUserGistsToCache(ctx context.Context, user string, reporter sources.UnitReporter) error {
gistOpts := &github.GistListOptions{}
logger := ctx.Logger().WithValues("user", user)

Expand All @@ -751,6 +785,9 @@ func (s *Source) addUserGistsToCache(ctx context.Context, user string) error {
for _, gist := range gists {
s.filteredRepoCache.Set(gist.GetID(), gist.GetGitPullURL())
s.cacheGistInfo(gist)
if err := reporter.UnitOk(ctx, GistUnit{name: gist.GetID(), url: gist.GetGitPullURL()}); err != nil {
return err
}
}

if res == nil || res.NextPage == 0 {
Expand Down
4 changes: 1 addition & 3 deletions pkg/sources/github/github_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"testing"
"time"

"github.com/go-logr/logr"
"github.com/kylelemons/godebug/pretty"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/types/known/anypb"
Expand Down Expand Up @@ -58,12 +57,11 @@ func TestSource_Token(t *testing.T) {

s := Source{
conn: src,
log: logr.Discard(),
memberCache: map[string]struct{}{},
repoInfoCache: newRepoInfoCache(),
}
s.Init(ctx, "github integration test source", 0, 0, false, conn, 1)
s.filteredRepoCache = s.newFilteredRepoCache(memory.New[string](), nil, nil)
s.filteredRepoCache = s.newFilteredRepoCache(ctx, memory.New[string](), nil, nil)

err = s.enumerateWithApp(ctx, s.connector.(*appConnector).InstallationClient())
assert.NoError(t, err)
Expand Down
28 changes: 18 additions & 10 deletions pkg/sources/github/github_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func TestAddReposByOrg(t *testing.T) {
Repositories: nil,
IgnoreRepos: []string{"secret/super-*-repo2"},
})
err := s.getReposByOrg(context.Background(), "super-secret-org")
err := s.getReposByOrg(context.Background(), "super-secret-org", noopReporter())
assert.Nil(t, err)
assert.Equal(t, 1, s.filteredRepoCache.Count())
ok := s.filteredRepoCache.Exists("super-secret-repo")
Expand Down Expand Up @@ -127,7 +127,7 @@ func TestAddReposByOrg_IncludeRepos(t *testing.T) {
IncludeRepos: []string{"super-secret-org/super*"},
Organizations: []string{"super-secret-org"},
})
err := s.getReposByOrg(context.Background(), "super-secret-org")
err := s.getReposByOrg(context.Background(), "super-secret-org", noopReporter())
assert.Nil(t, err)
assert.Equal(t, 2, s.filteredRepoCache.Count())
ok := s.filteredRepoCache.Exists("super-secret-org/super-secret-repo")
Expand Down Expand Up @@ -155,7 +155,7 @@ func TestAddReposByUser(t *testing.T) {
},
IgnoreRepos: []string{"super-secret-user/super-secret-repo2"},
})
err := s.getReposByUser(context.Background(), "super-secret-user")
err := s.getReposByUser(context.Background(), "super-secret-user", noopReporter())
assert.Nil(t, err)
assert.Equal(t, 1, s.filteredRepoCache.Count())
ok := s.filteredRepoCache.Exists("super-secret-user/super-secret-repo")
Expand All @@ -173,7 +173,7 @@ func TestAddGistsByUser(t *testing.T) {
JSON([]map[string]string{{"id": "aa5a315d61ae9438b18d", "git_pull_url": "https://gist.github.com/aa5a315d61ae9438b18d.git"}})

s := initTestSource(&sourcespb.GitHub{Credential: &sourcespb.GitHub_Unauthenticated{}})
err := s.addUserGistsToCache(context.Background(), "super-secret-user")
err := s.addUserGistsToCache(context.Background(), "super-secret-user", noopReporter())
assert.Nil(t, err)
assert.Equal(t, 1, s.filteredRepoCache.Count())
ok := s.filteredRepoCache.Exists("aa5a315d61ae9438b18d")
Expand Down Expand Up @@ -265,7 +265,7 @@ func TestAddReposByApp(t *testing.T) {
})

s := initTestSource(&sourcespb.GitHub{Credential: &sourcespb.GitHub_Unauthenticated{}})
err := s.getReposByApp(context.Background())
err := s.getReposByApp(context.Background(), noopReporter())
assert.Nil(t, err)
assert.Equal(t, 2, s.filteredRepoCache.Count())
ok := s.filteredRepoCache.Exists("ssr1")
Expand Down Expand Up @@ -419,7 +419,7 @@ func TestEnumerateUnauthenticated(t *testing.T) {
s.orgsCache = memory.New[string]()
s.orgsCache.Set("super-secret-org", "super-secret-org")
//s.enumerateUnauthenticated(context.Background(), apiEndpoint)
s.enumerateUnauthenticated(context.Background())
s.enumerateUnauthenticated(context.Background(), noopReporter())
assert.Equal(t, 1, s.filteredRepoCache.Count())
ok := s.filteredRepoCache.Exists("super-secret-org/super-secret-repo")
assert.True(t, ok)
Expand Down Expand Up @@ -458,7 +458,7 @@ func TestEnumerateWithToken(t *testing.T) {
Token: "token",
},
})
err := s.enumerateWithToken(context.Background(), false)
err := s.enumerateWithToken(context.Background(), false, noopReporter())
assert.Nil(t, err)
assert.Equal(t, 2, s.filteredRepoCache.Count())
ok := s.filteredRepoCache.Exists("super-secret-user/super-secret-repo")
Expand Down Expand Up @@ -502,7 +502,7 @@ func BenchmarkEnumerateWithToken(b *testing.B) {

b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = s.enumerateWithToken(context.Background(), false)
_ = s.enumerateWithToken(context.Background(), false, noopReporter())
}
}

Expand Down Expand Up @@ -660,7 +660,7 @@ func TestEnumerateWithToken_IncludeRepos(t *testing.T) {
})
s.repos = []string{"some-special-repo"}

err := s.enumerateWithToken(context.Background(), false)
err := s.enumerateWithToken(context.Background(), false, noopReporter())
assert.Nil(t, err)
assert.Equal(t, 1, len(s.repos))
assert.Equal(t, []string{"some-special-repo"}, s.repos)
Expand Down Expand Up @@ -693,7 +693,7 @@ func TestEnumerateWithApp(t *testing.T) {
},
},
})
err := s.enumerateWithApp(context.Background(), s.connector.(*appConnector).InstallationClient())
err := s.enumerateWithApp(context.Background(), s.connector.(*appConnector).InstallationClient(), noopReporter())
assert.Nil(t, err)
assert.Equal(t, 0, len(s.repos))
assert.False(t, gock.HasUnmatchedRequest())
Expand Down Expand Up @@ -908,3 +908,11 @@ func Test_ScanMultipleTargets_MultipleErrors(t *testing.T) {
assert.ElementsMatch(t, got, want)
}
}

func noopReporter() sources.UnitReporter {
return sources.VisitorReporter{
VisitUnit: func(context.Context, sources.SourceUnit) error {
return nil
},
}
}
Loading

0 comments on commit e89190f

Please sign in to comment.