diff --git a/internal/handlers/azdo_api.go b/internal/handlers/azdo_api.go index bef305f..801e0a2 100644 --- a/internal/handlers/azdo_api.go +++ b/internal/handlers/azdo_api.go @@ -71,7 +71,7 @@ func (h *AzureDevOpsAPIHandler) HandleRequest(req *http.Request, ctx *goproxy.Pr } logging.RequestLogf(ctx, "* authenticating azure devops api request with token for %s", host) - req.SetBasicAuth(creds[0].username, creds[0].password) + helpers.SetBasicAuthorization(req, creds[0].username, creds[0].password) // Azure DevOps requires an api-version to be set for requests. Add it if it is not present. var queryParams = req.URL.Query() diff --git a/internal/handlers/cargo_registry.go b/internal/handlers/cargo_registry.go index b875d84..2795464 100644 --- a/internal/handlers/cargo_registry.go +++ b/internal/handlers/cargo_registry.go @@ -101,7 +101,7 @@ func (h *CargoRegistryHandler) HandleRequest(req *http.Request, ctx *goproxy.Pro } logging.RequestLogf(ctx, "* authenticating cargo registry request (url: %s)", cred.url) - req.Header.Set("Authorization", cred.authorization) + helpers.SetRawAuthorization(req, cred.authorization) return req, nil } diff --git a/internal/handlers/composer.go b/internal/handlers/composer.go index 2865fad..d2f5bd9 100644 --- a/internal/handlers/composer.go +++ b/internal/handlers/composer.go @@ -81,10 +81,10 @@ func (h *ComposerHandler) HandleRequest(req *http.Request, ctx *goproxy.ProxyCtx if cred.token != "" { logging.RequestLogf(ctx, "* authenticating composer registry request (host: %s, token auth)", req.URL.Hostname()) - req.Header.Set("Authorization", "Bearer "+cred.token) + helpers.SetBearerAuthorization(req, cred.token) } else { logging.RequestLogf(ctx, "* authenticating composer registry request (host: %s, basic auth)", req.URL.Hostname()) - req.SetBasicAuth(cred.username, cred.password) + helpers.SetBasicAuthorization(req, cred.username, cred.password) } return req, nil diff --git a/internal/handlers/docker_registry.go b/internal/handlers/docker_registry.go index 4d6ce1a..aafe5f7 100644 --- a/internal/handlers/docker_registry.go +++ b/internal/handlers/docker_registry.go @@ -120,7 +120,7 @@ func (h *DockerRegistryHandler) HandleRequest(req *http.Request, ctx *goproxy.Pr if cred.getECRCredentials(ctx) { logging.RequestLogf(ctx, "* authenticating docker ecr request (host: %s)", req.URL.Hostname()) - req.SetBasicAuth(cred.ecrUsername, cred.ecrPassword) + helpers.SetBasicAuthorization(req, cred.ecrUsername, cred.ecrPassword) } else { logging.RequestLogf(ctx, "* authenticating docker registry request (host: %s)", req.URL.Hostname()) transport := ®istry.BasicTransport{ diff --git a/internal/handlers/git_server.go b/internal/handlers/git_server.go index 784aeb9..cb3b50b 100644 --- a/internal/handlers/git_server.go +++ b/internal/handlers/git_server.go @@ -282,7 +282,7 @@ func (h *GitServerHandler) HandleRequest(req *http.Request, ctx *goproxy.ProxyCt logging.RequestLogf(ctx, "* authenticating git server request (host: %s)", helpers.GetHost(req)) credsToUse := creds[0] - req.SetBasicAuth(credsToUse.username, credsToUse.password) + helpers.SetBasicAuthorization(req, credsToUse.username, credsToUse.password) if ctx != nil { ctxdata.SetValue(ctx, addedAuthCtxKey, credsToUse) } @@ -472,7 +472,7 @@ func (h *GitServerHandler) requestWithAlternativeAuth(ctx *goproxy.ProxyCtx, bod newReq.Body = io.NopCloser(bytes.NewReader(body)) } - newReq.SetBasicAuth(creds.username, creds.password) + helpers.SetBasicAuthorization(newReq, creds.username, creds.password) newRsp, err := ctx.RoundTrip(newReq) if err != nil { return nil diff --git a/internal/handlers/goproxy_server_handler.go b/internal/handlers/goproxy_server_handler.go index 64e6eff..b083b37 100644 --- a/internal/handlers/goproxy_server_handler.go +++ b/internal/handlers/goproxy_server_handler.go @@ -77,7 +77,7 @@ func (h *GoProxyServerHandler) HandleRequest(req *http.Request, ctx *goproxy.Pro } logging.RequestLogf(ctx, "* authenticating goproxy request (host: %s)", req.URL.Hostname()) - req.SetBasicAuth(cred.username, cred.password) + helpers.SetBasicAuthorization(req, cred.username, cred.password) return req, nil } diff --git a/internal/handlers/helm_registry.go b/internal/handlers/helm_registry.go index 981c6a6..9e36b73 100644 --- a/internal/handlers/helm_registry.go +++ b/internal/handlers/helm_registry.go @@ -74,7 +74,7 @@ func (h *HelmRegistryHandler) HandleRequest(req *http.Request, ctx *goproxy.Prox } logging.RequestLogf(ctx, "* authenticating helm registry request (host: %s)", req.URL.Hostname()) - req.SetBasicAuth(cred.username, cred.password) + helpers.SetBasicAuthorization(req, cred.username, cred.password) return req, nil } diff --git a/internal/handlers/hex_organization.go b/internal/handlers/hex_organization.go index 74a8d18..495509e 100644 --- a/internal/handlers/hex_organization.go +++ b/internal/handlers/hex_organization.go @@ -69,7 +69,7 @@ func (h *HexOrganizationHandler) HandleRequest(req *http.Request, ctx *goproxy.P for _, cred := range h.credentials { if cred.organization == reqOrg { logging.RequestLogf(ctx, "* authenticating hex request (org: %s)", reqOrg) - req.Header.Set("authorization", cred.key) + helpers.SetRawAuthorization(req, cred.key) return req, nil } } diff --git a/internal/handlers/hex_repository.go b/internal/handlers/hex_repository.go index 94c0f80..0c3c79f 100644 --- a/internal/handlers/hex_repository.go +++ b/internal/handlers/hex_repository.go @@ -85,7 +85,7 @@ func (h *HexRepositoryHandler) HandleRequest(req *http.Request, ctx *goproxy.Pro } logging.RequestLogf(ctx, "* authenticating hex repository request (host: %s)", req.URL.Hostname()) - req.Header.Set("authorization", cred.authKey) + helpers.SetRawAuthorization(req, cred.authKey) return req, nil } diff --git a/internal/handlers/maven_repository.go b/internal/handlers/maven_repository.go index 40f59fe..fc12ebb 100644 --- a/internal/handlers/maven_repository.go +++ b/internal/handlers/maven_repository.go @@ -79,7 +79,7 @@ func (h *MavenRepositoryHandler) HandleRequest(req *http.Request, ctx *goproxy.P } logging.RequestLogf(ctx, "* authenticating maven repository request (host: %s)", req.URL.Hostname()) - req.SetBasicAuth(cred.username, cred.password) + helpers.SetBasicAuthorization(req, cred.username, cred.password) return req, nil } diff --git a/internal/handlers/npm_registry.go b/internal/handlers/npm_registry.go index 04b04ff..02c7ad8 100644 --- a/internal/handlers/npm_registry.go +++ b/internal/handlers/npm_registry.go @@ -108,10 +108,10 @@ func (h *NPMRegistryHandler) HandleRequest(req *http.Request, ctx *goproxy.Proxy username, password, found := strings.Cut(cred.token, ":") if found { logging.RequestLogf(ctx, "* authenticating npm registry request (host: %s, basic auth)", reqHost) - req.SetBasicAuth(username, password) + helpers.SetBasicAuthorization(req, username, password) } else { logging.RequestLogf(ctx, "* authenticating npm registry request (host: %s, token auth)", reqHost) - req.Header.Set("authorization", "Bearer "+cred.token) + helpers.SetBearerAuthorization(req, cred.token) } return req, nil } diff --git a/internal/handlers/nuget_feed.go b/internal/handlers/nuget_feed.go index 4191dc0..e7ded8e 100644 --- a/internal/handlers/nuget_feed.go +++ b/internal/handlers/nuget_feed.go @@ -300,14 +300,14 @@ func authenticateNugetRequest(req *http.Request, cred nugetFeedCredentials, ctx username, password, found := strings.Cut(token, ":") if found { logging.RequestLogf(ctx, "* authenticating nuget feed request (host: %s, basic auth)", req.URL.Hostname()) - req.SetBasicAuth(username, password) + helpers.SetBasicAuthorization(req, username, password) } else if token != "" { if shouldTreatTokenAsPassword(req.URL) { logging.RequestLogf(ctx, "* authenticating nuget feed request (host: %s, basic auth for Azure DevOps)", req.URL.Hostname()) - req.SetBasicAuth("", token) + helpers.SetBasicAuthorization(req, "", token) } else { logging.RequestLogf(ctx, "* authenticating nuget feed request (host: %s, bearer auth)", req.URL.Hostname()) - req.Header.Set("authorization", "Bearer "+token) + helpers.SetBearerAuthorization(req, token) } } } diff --git a/internal/handlers/pub_repository.go b/internal/handlers/pub_repository.go index a3bae5a..d51f02b 100644 --- a/internal/handlers/pub_repository.go +++ b/internal/handlers/pub_repository.go @@ -83,7 +83,7 @@ func (h *PubRepositoryHandler) HandleRequest(req *http.Request, ctx *goproxy.Pro } logging.RequestLogf(ctx, "* authenticating pub repository request (url: %s)", cred.url) - req.Header.Set("Authorization", "Bearer "+cred.token) + helpers.SetBearerAuthorization(req, cred.token) return req, nil } diff --git a/internal/handlers/python_index.go b/internal/handlers/python_index.go index ef69f9a..d7c0098 100644 --- a/internal/handlers/python_index.go +++ b/internal/handlers/python_index.go @@ -107,7 +107,7 @@ func (h *PythonIndexHandler) HandleRequest(req *http.Request, ctx *goproxy.Proxy } // ignore `found` because it's okay for the password to be an empty string username, password, _ := strings.Cut(token, ":") - req.SetBasicAuth(username, password) + helpers.SetBasicAuthorization(req, username, password) return req, nil } diff --git a/internal/handlers/rubygems_server.go b/internal/handlers/rubygems_server.go index cb9829a..8e4aa0d 100644 --- a/internal/handlers/rubygems_server.go +++ b/internal/handlers/rubygems_server.go @@ -80,7 +80,7 @@ func (h *RubyGemsServerHandler) HandleRequest(req *http.Request, ctx *goproxy.Pr // ignore `found` because it's okay for the password to be an empty string username, password, _ := strings.Cut(cred.token, ":") - req.SetBasicAuth(username, password) + helpers.SetBasicAuthorization(req, username, password) return req, nil } diff --git a/internal/handlers/terraform_registry.go b/internal/handlers/terraform_registry.go index df6b09c..22b3c8d 100644 --- a/internal/handlers/terraform_registry.go +++ b/internal/handlers/terraform_registry.go @@ -88,7 +88,7 @@ func (h *TerraformRegistryHandler) HandleRequest(request *http.Request, context } logging.RequestLogf(context, "* authenticating terraform registry request (host: %s)", request.URL.Hostname()) - request.Header.Set("Authorization", "Bearer "+cred.token) + helpers.SetBearerAuthorization(request, cred.token) return request, nil } diff --git a/internal/helpers/helpers.go b/internal/helpers/helpers.go index cf529ae..0573af3 100644 --- a/internal/helpers/helpers.go +++ b/internal/helpers/helpers.go @@ -1,6 +1,7 @@ package helpers import ( + "encoding/base64" "io" "net/http" "net/url" @@ -10,6 +11,33 @@ import ( "golang.org/x/net/idna" ) +// ReplaceAuthorization replaces the authorization configured on req with the given key and value. +// +// Note: "Authorization"-header is always cleared to avoid multiple auth headers being set on the request. +func ReplaceAuthorization(req *http.Request, key string, value string) { + req.Header.Del("Authorization") + req.Header.Set(key, value) +} + +// SetRawAuthorization sets the authorization header on req to the given value +func SetRawAuthorization(req *http.Request, authorization string) { + ReplaceAuthorization(req, "Authorization", authorization) +} + +func SetBasicAuthorization(req *http.Request, username, password string) { + credentials := username + ":" + password + encoded := base64.StdEncoding.EncodeToString([]byte(credentials)) + SetRawAuthorization(req, "Basic "+encoded) +} + +func SetBearerAuthorization(req *http.Request, token string) { + SetRawAuthorization(req, "Bearer "+token) +} + +func SetGitHubAPITokenAuthorization(req *http.Request, token string) { + SetRawAuthorization(req, "token "+token) +} + func CheckGitHubAPIHost(r *http.Request) bool { hostname := GetHost(r) // Check if the hostname is a GitHub API hostname and will return true diff --git a/internal/helpers/helpers_test.go b/internal/helpers/helpers_test.go index 74d59e8..41b8178 100644 --- a/internal/helpers/helpers_test.go +++ b/internal/helpers/helpers_test.go @@ -1,11 +1,159 @@ package helpers import ( + "encoding/base64" "net/http" + "net/http/httptest" "net/url" "testing" ) +// newRequest builds a GET request to the given raw URL for use in tests. +func newRequest(t *testing.T, rawURL string) *http.Request { + t.Helper() + return httptest.NewRequest(http.MethodGet, rawURL, nil) +} + +// newRequestWithAuth builds a request that already carries an Authorization header, +// simulating a client that sent credentials which should be replaced. +func newRequestWithAuth(t *testing.T, rawURL, existing string) *http.Request { + t.Helper() + req := newRequest(t, rawURL) + req.Header.Set("Authorization", existing) + return req +} + +func TestSetBasicAuthorization(t *testing.T) { + t.Run("sets correct Basic header", func(t *testing.T) { + req := newRequest(t, "https://example.com") + SetBasicAuthorization(req, "user", "pass") + + want := "Basic " + base64.StdEncoding.EncodeToString([]byte("user:pass")) + if got := req.Header.Get("Authorization"); got != want { + t.Errorf("Authorization = %q, want %q", got, want) + } + }) + + t.Run("clears pre-existing Authorization header", func(t *testing.T) { + req := newRequestWithAuth(t, "https://example.com", "Bearer old-token") + SetBasicAuthorization(req, "user", "pass") + + want := "Basic " + base64.StdEncoding.EncodeToString([]byte("user:pass")) + if got := req.Header.Get("Authorization"); got != want { + t.Errorf("Authorization = %q, want %q", got, want) + } + if vals := req.Header["Authorization"]; len(vals) != 1 { + t.Errorf("expected exactly 1 Authorization value, got %d: %v", len(vals), vals) + } + }) + + t.Run("encodes empty username correctly", func(t *testing.T) { + req := newRequest(t, "https://example.com") + SetBasicAuthorization(req, "", "token") + + want := "Basic " + base64.StdEncoding.EncodeToString([]byte(":token")) + if got := req.Header.Get("Authorization"); got != want { + t.Errorf("Authorization = %q, want %q", got, want) + } + }) +} + +func TestSetBearerAuthorization(t *testing.T) { + t.Run("sets correct Bearer header", func(t *testing.T) { + req := newRequest(t, "https://example.com") + SetBearerAuthorization(req, "my-token") + + if got := req.Header.Get("Authorization"); got != "Bearer my-token" { + t.Errorf("Authorization = %q, want %q", got, "Bearer my-token") + } + }) + + t.Run("clears pre-existing Authorization header", func(t *testing.T) { + req := newRequestWithAuth(t, "https://example.com", "Basic dXNlcjpwYXNz") + SetBearerAuthorization(req, "new-token") + + if got := req.Header.Get("Authorization"); got != "Bearer new-token" { + t.Errorf("Authorization = %q, want %q", got, "Bearer new-token") + } + if vals := req.Header["Authorization"]; len(vals) != 1 { + t.Errorf("expected exactly 1 Authorization value, got %d: %v", len(vals), vals) + } + }) +} + +func TestSetGitHubAPITokenAuthorization(t *testing.T) { + t.Run("sets correct token header", func(t *testing.T) { + req := newRequest(t, "https://api.github.com") + SetGitHubAPITokenAuthorization(req, "ghp_abc123") + + if got := req.Header.Get("Authorization"); got != "token ghp_abc123" { + t.Errorf("Authorization = %q, want %q", got, "token ghp_abc123") + } + }) + + t.Run("clears pre-existing Authorization header", func(t *testing.T) { + req := newRequestWithAuth(t, "https://api.github.com", "token old-token") + SetGitHubAPITokenAuthorization(req, "new-token") + + if got := req.Header.Get("Authorization"); got != "token new-token" { + t.Errorf("Authorization = %q, want %q", got, "token new-token") + } + if vals := req.Header["Authorization"]; len(vals) != 1 { + t.Errorf("expected exactly 1 Authorization value, got %d: %v", len(vals), vals) + } + }) +} + +func TestSetRawAuthorization(t *testing.T) { + t.Run("sets pre-formatted value as-is", func(t *testing.T) { + req := newRequest(t, "https://example.com") + SetRawAuthorization(req, "Bearer already-formatted") + + if got := req.Header.Get("Authorization"); got != "Bearer already-formatted" { + t.Errorf("Authorization = %q, want %q", got, "Bearer already-formatted") + } + }) + + t.Run("clears pre-existing Authorization header", func(t *testing.T) { + req := newRequestWithAuth(t, "https://example.com", "Bearer stale") + SetRawAuthorization(req, "token new-raw") + + if got := req.Header.Get("Authorization"); got != "token new-raw" { + t.Errorf("Authorization = %q, want %q", got, "token new-raw") + } + if vals := req.Header["Authorization"]; len(vals) != 1 { + t.Errorf("expected exactly 1 Authorization value, got %d: %v", len(vals), vals) + } + }) +} + +func TestReplaceAuthorization_CustomKey(t *testing.T) { + t.Run("sets value on custom header key", func(t *testing.T) { + req := newRequest(t, "https://cloudsmith.example.com") + ReplaceAuthorization(req, "X-Api-Key", "my-api-key") + + if got := req.Header.Get("X-Api-Key"); got != "my-api-key" { + t.Errorf("X-Api-Key = %q, want %q", got, "my-api-key") + } + if got := req.Header.Get("Authorization"); got != "" { + t.Errorf("Authorization should be empty, got %q", got) + } + }) + + t.Run("clears pre-existing Authorization header before setting custom key", func(t *testing.T) { + req := newRequest(t, "https://cloudsmith.example.com") + req.Header.Set("X-Api-Key", "old-key") + ReplaceAuthorization(req, "X-Api-Key", "new-key") + + if got := req.Header.Get("X-Api-Key"); got != "new-key" { + t.Errorf("X-Api-Key = %q, want %q", got, "new-key") + } + if vals := req.Header["X-Api-Key"]; len(vals) != 1 { + t.Errorf("expected exactly 1 X-Api-Key value, got %d: %v", len(vals), vals) + } + }) +} + func TestUrlMatchesRequest(t *testing.T) { tests := []struct { name string diff --git a/internal/oidc/oidc_registry.go b/internal/oidc/oidc_registry.go index ce8f702..6d72f96 100644 --- a/internal/oidc/oidc_registry.go +++ b/internal/oidc/oidc_registry.go @@ -1,7 +1,6 @@ package oidc import ( - "fmt" "net/http" "strings" "sync" @@ -143,18 +142,18 @@ func (r *OIDCRegistry) TryAuth(req *http.Request, ctx *goproxy.ProxyCtx) bool { switch matched.parameters.(type) { case *CloudsmithOIDCParameters: logging.RequestLogf(ctx, "* authenticating request with OIDC API key (host: %s)", host) - req.Header.Set("X-Api-Key", token) + helpers.ReplaceAuthorization(req, "X-Api-Key", token) case *GCPOIDCParameters: if strings.HasSuffix(host, "-docker.pkg.dev") { logging.RequestLogf(ctx, "* authenticating request with OIDC oauth2accesstoken (host: %s)", host) - req.SetBasicAuth("oauth2accesstoken", token) + helpers.SetBasicAuthorization(req, "oauth2accesstoken", token) } else { logging.RequestLogf(ctx, "* authenticating request with OIDC token (host: %s)", host) - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + helpers.SetBearerAuthorization(req, token) } default: logging.RequestLogf(ctx, "* authenticating request with OIDC token (host: %s)", host) - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + helpers.SetBearerAuthorization(req, token) } return true