From 83fd92759c7b2f2ab770d01f36f7ad97509ee083 Mon Sep 17 00:00:00 2001 From: Liam Galvin Date: Wed, 24 Jul 2024 17:13:14 +0100 Subject: [PATCH 1/2] Fix panic when s3 URL is invalid Gracefully handle when S3 URLs have an unexpected number of path segments. Currently we expect `s3.amazonaws.com/bucket/path`, but something like `s3.amazonaws.com/bucket` will cause a panic, e.g. ``` panic: runtime error: index out of range [2] with length 2 github.com/hashicorp/go-getter.(*S3Getter).parseUrl(,) /go/pkg/mod/github.com/hashicorp/go-getter@v1.7.5/get_s3.go:272 github.com/hashicorp/go-getter.(*S3Getter).Get(, {,},) /go/pkg/mod/git... ``` --- get_s3.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/get_s3.go b/get_s3.go index 94291947c..346b98a0b 100644 --- a/get_s3.go +++ b/get_s3.go @@ -268,6 +268,10 @@ func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path, version string, c region = "us-east-1" } pathParts := strings.SplitN(u.Path, "/", 3) + if len(pathParts) < 3 { + err = fmt.Errorf("URL is not a valid S3 URL") + return + } bucket = pathParts[1] path = pathParts[2] // vhost-style, dash region indication @@ -279,12 +283,20 @@ func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path, version string, c return } pathParts := strings.SplitN(u.Path, "/", 2) + if len(pathParts) < 2 { + err = fmt.Errorf("URL is not a valid S3 URL") + return + } bucket = hostParts[0] path = pathParts[1] //vhost-style, dot region indication case 5: region = hostParts[2] pathParts := strings.SplitN(u.Path, "/", 2) + if len(pathParts) < 3 { + err = fmt.Errorf("URL is not a valid S3 URL") + return + } bucket = hostParts[0] path = pathParts[1] From 8339301726cc9144b0a1d8a54d32de993c6af03b Mon Sep 17 00:00:00 2001 From: Liam Galvin Date: Thu, 1 Aug 2024 20:14:17 +0100 Subject: [PATCH 2/2] add tests --- get_s3.go | 4 ++-- get_s3_test.go | 65 ++++++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 60 insertions(+), 9 deletions(-) diff --git a/get_s3.go b/get_s3.go index 346b98a0b..b478bde4e 100644 --- a/get_s3.go +++ b/get_s3.go @@ -276,7 +276,7 @@ func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path, version string, c path = pathParts[2] // vhost-style, dash region indication case 4: - // Parse the region out of the first part of the host + // Parse the region out of the second part of the host region = strings.TrimPrefix(strings.TrimPrefix(hostParts[1], "s3-"), "s3") if region == "" { err = fmt.Errorf("URL is not a valid S3 URL") @@ -293,7 +293,7 @@ func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path, version string, c case 5: region = hostParts[2] pathParts := strings.SplitN(u.Path, "/", 2) - if len(pathParts) < 3 { + if len(pathParts) < 2 { err = fmt.Errorf("URL is not a valid S3 URL") return } diff --git a/get_s3_test.go b/get_s3_test.go index 7b2425404..18187f149 100644 --- a/get_s3_test.go +++ b/get_s3_test.go @@ -165,12 +165,13 @@ func TestS3Getter_ClientMode_collision(t *testing.T) { func TestS3Getter_Url(t *testing.T) { var s3tests = []struct { - name string - url string - region string - bucket string - path string - version string + name string + url string + region string + bucket string + path string + version string + expectedErr string }{ { name: "AWSv1234", @@ -220,6 +221,11 @@ func TestS3Getter_Url(t *testing.T) { path: "hello.txt", version: "", }, + { + name: "malformed s3 url", + url: "s3::https://s3.amazonaws.com/bucket", + expectedErr: "URL is not a valid S3 URL", + }, } for i, pt := range s3tests { @@ -238,7 +244,15 @@ func TestS3Getter_Url(t *testing.T) { region, bucket, path, version, creds, err := g.parseUrl(u) if err != nil { - t.Fatalf("err: %s", err) + if pt.expectedErr == "" { + t.Fatalf("err: %s", err) + } + if err.Error() != pt.expectedErr { + t.Fatalf("expected %s, got %s", pt.expectedErr, err.Error()) + } + return + } else if pt.expectedErr != "" { + t.Fatalf("expected error, got none") } if region != pt.region { t.Fatalf("expected %s, got %s", pt.region, region) @@ -258,3 +272,40 @@ func TestS3Getter_Url(t *testing.T) { }) } } + +func Test_S3Getter_ParseUrl_Malformed(t *testing.T) { + tests := []struct { + name string + url string + }{ + { + name: "path style", + url: "https://s3.amazonaws.com/bucket", + }, + { + name: "vhost-style, dash region indication", + url: "https://bucket.s3-us-east-1.amazonaws.com", + }, + { + name: "vhost-style, dot region indication", + url: "https://bucket.s3.us-east-1.amazonaws.com", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := new(S3Getter) + u, err := url.Parse(tt.url) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + _, _, _, _, _, err = g.parseUrl(u) + if err == nil { + t.Fatalf("expected error, got none") + } + if err.Error() != "URL is not a valid S3 URL" { + t.Fatalf("expected error 'URL is not a valid S3 URL', got %s", err.Error()) + } + }) + } + +}