Skip to content

Commit 259b8f8

Browse files
authored
Merge pull request #844 from fluxcd/sc-hostname-overwrite
loader: allow overwrite of URL hostname again
2 parents 8a6e68b + 1e66201 commit 259b8f8

File tree

2 files changed

+84
-0
lines changed

2 files changed

+84
-0
lines changed

internal/loader/artifact_url.go

+29
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ import (
2424
"fmt"
2525
"io"
2626
"net/http"
27+
"net/url"
28+
"os"
2729

2830
"github.com/hashicorp/go-retryablehttp"
2931
digestlib "github.com/opencontainers/go-digest"
@@ -32,6 +34,13 @@ import (
3234
"helm.sh/helm/v3/pkg/chart/loader"
3335
)
3436

37+
const (
38+
// envSourceControllerLocalhost is the name of the environment variable
39+
// used to override the hostname of the source-controller from which
40+
// the chart is usually downloaded.
41+
envSourceControllerLocalhost = "SOURCE_CONTROLLER_LOCALHOST"
42+
)
43+
3544
var (
3645
// ErrFileNotFound is an error type used to signal 404 HTTP status code responses.
3746
ErrFileNotFound = errors.New("file not found")
@@ -45,6 +54,11 @@ var (
4554
// digest before loading the chart. It returns the loaded chart.Chart, or an
4655
// error. The error may be of type ErrIntegrity if the integrity check fails.
4756
func SecureLoadChartFromURL(client *retryablehttp.Client, URL, digest string) (*chart.Chart, error) {
57+
URL, err := overwriteHostname(URL, os.Getenv(envSourceControllerLocalhost))
58+
if err != nil {
59+
return nil, err
60+
}
61+
4862
req, err := retryablehttp.NewRequest(http.MethodGet, URL, nil)
4963
if err != nil {
5064
return nil, err
@@ -94,3 +108,18 @@ func copyAndVerify(digest string, reader io.Reader, writer io.Writer) error {
94108
}
95109
return nil
96110
}
111+
112+
// overwriteHostname overwrites the hostname of the given URL with the given
113+
// hostname. If the hostname is empty, the URL is returned unmodified.
114+
func overwriteHostname(URL, hostname string) (string, error) {
115+
if hostname == "" {
116+
return URL, nil
117+
}
118+
119+
u, err := url.Parse(URL)
120+
if err != nil {
121+
return "", fmt.Errorf("failed to parse URL to overwrite hostname: %w", err)
122+
}
123+
u.Host = hostname
124+
return u.String(), nil
125+
}

internal/loader/artifact_url_test.go

+55
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"net/http"
2424
"net/http/httptest"
2525
"os"
26+
"strings"
2627
"testing"
2728

2829
"github.com/hashicorp/go-retryablehttp"
@@ -72,6 +73,19 @@ func TestSecureLoadChartFromURL(t *testing.T) {
7273
g.Expect(got.Metadata.Version).To(Equal("0.1.0"))
7374
})
7475

76+
t.Run("overwrites hostname", func(t *testing.T) {
77+
g := NewWithT(t)
78+
79+
t.Setenv(envSourceControllerLocalhost, strings.TrimPrefix(server.URL, "http://"))
80+
wrongHostnameURL := "http://invalid.com" + chartPath
81+
82+
got, err := SecureLoadChartFromURL(client, wrongHostnameURL, digest.String())
83+
g.Expect(err).ToNot(HaveOccurred())
84+
g.Expect(got).ToNot(BeNil())
85+
g.Expect(got.Name()).To(Equal("chart"))
86+
g.Expect(got.Metadata.Version).To(Equal("0.1.0"))
87+
})
88+
7589
t.Run("error on chart data digest mismatch", func(t *testing.T) {
7690
g := NewWithT(t)
7791

@@ -162,3 +176,44 @@ func Test_copyAndVerify(t *testing.T) {
162176
})
163177
}
164178
}
179+
180+
func Test_overwriteHostname(t *testing.T) {
181+
tests := []struct {
182+
name string
183+
URL string
184+
hostname string
185+
want string
186+
wantErr bool
187+
}{
188+
{
189+
name: "overwrite hostname",
190+
URL: "http://example.com",
191+
hostname: "localhost",
192+
want: "http://localhost",
193+
},
194+
{
195+
name: "overwrite hostname with port",
196+
URL: "http://example.com",
197+
hostname: "localhost:9090",
198+
want: "http://localhost:9090",
199+
},
200+
{
201+
name: "no hostname",
202+
URL: "http://example.com",
203+
hostname: "",
204+
want: "http://example.com",
205+
},
206+
}
207+
for _, tt := range tests {
208+
t.Run(tt.name, func(t *testing.T) {
209+
got, err := overwriteHostname(tt.URL, tt.hostname)
210+
if (err != nil) != tt.wantErr {
211+
t.Errorf("overwriteHostname() error = %v, wantErr %v", err, tt.wantErr)
212+
return
213+
}
214+
if got != tt.want {
215+
t.Errorf("overwriteHostname() got = %v, want %v", got, tt.want)
216+
}
217+
})
218+
}
219+
}

0 commit comments

Comments
 (0)