Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal/handlers/azdo_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func (h *AzureDevOpsAPIHandler) HandleRequest(req *http.Request, ctx *goproxy.Pr
}

logging.RequestLogf(ctx, "* authenticating azure devops api request with token for %s", host)
req.SetBasicAuth(creds[0].username, creds[0].password)
helpers.SetBasicAuthorization(req, creds[0].username, creds[0].password)

// Azure DevOps requires an api-version to be set for requests. Add it if it is not present.
var queryParams = req.URL.Query()
Expand Down
2 changes: 1 addition & 1 deletion internal/handlers/cargo_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func (h *CargoRegistryHandler) HandleRequest(req *http.Request, ctx *goproxy.Pro
}

logging.RequestLogf(ctx, "* authenticating cargo registry request (url: %s)", cred.url)
req.Header.Set("Authorization", cred.authorization)
helpers.SetRawAuthorization(req, cred.authorization)

return req, nil
}
Expand Down
4 changes: 2 additions & 2 deletions internal/handlers/composer.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,10 @@ func (h *ComposerHandler) HandleRequest(req *http.Request, ctx *goproxy.ProxyCtx

if cred.token != "" {
logging.RequestLogf(ctx, "* authenticating composer registry request (host: %s, token auth)", req.URL.Hostname())
req.Header.Set("Authorization", "Bearer "+cred.token)
helpers.SetBearerAuthorization(req, cred.token)
} else {
logging.RequestLogf(ctx, "* authenticating composer registry request (host: %s, basic auth)", req.URL.Hostname())
req.SetBasicAuth(cred.username, cred.password)
helpers.SetBasicAuthorization(req, cred.username, cred.password)
}

return req, nil
Expand Down
2 changes: 1 addition & 1 deletion internal/handlers/docker_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func (h *DockerRegistryHandler) HandleRequest(req *http.Request, ctx *goproxy.Pr

if cred.getECRCredentials(ctx) {
logging.RequestLogf(ctx, "* authenticating docker ecr request (host: %s)", req.URL.Hostname())
req.SetBasicAuth(cred.ecrUsername, cred.ecrPassword)
helpers.SetBasicAuthorization(req, cred.ecrUsername, cred.ecrPassword)
} else {
logging.RequestLogf(ctx, "* authenticating docker registry request (host: %s)", req.URL.Hostname())
transport := &registry.BasicTransport{
Expand Down
4 changes: 2 additions & 2 deletions internal/handlers/git_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ func (h *GitServerHandler) HandleRequest(req *http.Request, ctx *goproxy.ProxyCt

logging.RequestLogf(ctx, "* authenticating git server request (host: %s)", helpers.GetHost(req))
credsToUse := creds[0]
req.SetBasicAuth(credsToUse.username, credsToUse.password)
helpers.SetBasicAuthorization(req, credsToUse.username, credsToUse.password)
if ctx != nil {
ctxdata.SetValue(ctx, addedAuthCtxKey, credsToUse)
}
Expand Down Expand Up @@ -472,7 +472,7 @@ func (h *GitServerHandler) requestWithAlternativeAuth(ctx *goproxy.ProxyCtx, bod
newReq.Body = io.NopCloser(bytes.NewReader(body))
}

newReq.SetBasicAuth(creds.username, creds.password)
helpers.SetBasicAuthorization(newReq, creds.username, creds.password)
newRsp, err := ctx.RoundTrip(newReq)
if err != nil {
return nil
Expand Down
2 changes: 1 addition & 1 deletion internal/handlers/goproxy_server_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func (h *GoProxyServerHandler) HandleRequest(req *http.Request, ctx *goproxy.Pro
}

logging.RequestLogf(ctx, "* authenticating goproxy request (host: %s)", req.URL.Hostname())
req.SetBasicAuth(cred.username, cred.password)
helpers.SetBasicAuthorization(req, cred.username, cred.password)

return req, nil
}
Expand Down
2 changes: 1 addition & 1 deletion internal/handlers/helm_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func (h *HelmRegistryHandler) HandleRequest(req *http.Request, ctx *goproxy.Prox
}

logging.RequestLogf(ctx, "* authenticating helm registry request (host: %s)", req.URL.Hostname())
req.SetBasicAuth(cred.username, cred.password)
helpers.SetBasicAuthorization(req, cred.username, cred.password)

return req, nil
}
Expand Down
2 changes: 1 addition & 1 deletion internal/handlers/hex_organization.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func (h *HexOrganizationHandler) HandleRequest(req *http.Request, ctx *goproxy.P
for _, cred := range h.credentials {
if cred.organization == reqOrg {
logging.RequestLogf(ctx, "* authenticating hex request (org: %s)", reqOrg)
req.Header.Set("authorization", cred.key)
helpers.SetRawAuthorization(req, cred.key)
return req, nil
}
}
Expand Down
2 changes: 1 addition & 1 deletion internal/handlers/hex_repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func (h *HexRepositoryHandler) HandleRequest(req *http.Request, ctx *goproxy.Pro
}

logging.RequestLogf(ctx, "* authenticating hex repository request (host: %s)", req.URL.Hostname())
req.Header.Set("authorization", cred.authKey)
helpers.SetRawAuthorization(req, cred.authKey)

return req, nil
}
Expand Down
2 changes: 1 addition & 1 deletion internal/handlers/maven_repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func (h *MavenRepositoryHandler) HandleRequest(req *http.Request, ctx *goproxy.P
}

logging.RequestLogf(ctx, "* authenticating maven repository request (host: %s)", req.URL.Hostname())
req.SetBasicAuth(cred.username, cred.password)
helpers.SetBasicAuthorization(req, cred.username, cred.password)

return req, nil
}
Expand Down
4 changes: 2 additions & 2 deletions internal/handlers/npm_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,10 @@ func (h *NPMRegistryHandler) HandleRequest(req *http.Request, ctx *goproxy.Proxy
username, password, found := strings.Cut(cred.token, ":")
if found {
logging.RequestLogf(ctx, "* authenticating npm registry request (host: %s, basic auth)", reqHost)
req.SetBasicAuth(username, password)
helpers.SetBasicAuthorization(req, username, password)
} else {
logging.RequestLogf(ctx, "* authenticating npm registry request (host: %s, token auth)", reqHost)
req.Header.Set("authorization", "Bearer "+cred.token)
helpers.SetBearerAuthorization(req, cred.token)
}
return req, nil
}
Expand Down
6 changes: 3 additions & 3 deletions internal/handlers/nuget_feed.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,14 +300,14 @@ func authenticateNugetRequest(req *http.Request, cred nugetFeedCredentials, ctx
username, password, found := strings.Cut(token, ":")
if found {
logging.RequestLogf(ctx, "* authenticating nuget feed request (host: %s, basic auth)", req.URL.Hostname())
req.SetBasicAuth(username, password)
helpers.SetBasicAuthorization(req, username, password)
} else if token != "" {
if shouldTreatTokenAsPassword(req.URL) {
logging.RequestLogf(ctx, "* authenticating nuget feed request (host: %s, basic auth for Azure DevOps)", req.URL.Hostname())
req.SetBasicAuth("", token)
helpers.SetBasicAuthorization(req, "", token)
} else {
logging.RequestLogf(ctx, "* authenticating nuget feed request (host: %s, bearer auth)", req.URL.Hostname())
req.Header.Set("authorization", "Bearer "+token)
helpers.SetBearerAuthorization(req, token)
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion internal/handlers/pub_repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func (h *PubRepositoryHandler) HandleRequest(req *http.Request, ctx *goproxy.Pro
}

logging.RequestLogf(ctx, "* authenticating pub repository request (url: %s)", cred.url)
req.Header.Set("Authorization", "Bearer "+cred.token)
helpers.SetBearerAuthorization(req, cred.token)

return req, nil
}
Expand Down
2 changes: 1 addition & 1 deletion internal/handlers/python_index.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func (h *PythonIndexHandler) HandleRequest(req *http.Request, ctx *goproxy.Proxy
}
// ignore `found` because it's okay for the password to be an empty string
username, password, _ := strings.Cut(token, ":")
req.SetBasicAuth(username, password)
helpers.SetBasicAuthorization(req, username, password)

return req, nil
}
Expand Down
2 changes: 1 addition & 1 deletion internal/handlers/rubygems_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (h *RubyGemsServerHandler) HandleRequest(req *http.Request, ctx *goproxy.Pr

// ignore `found` because it's okay for the password to be an empty string
username, password, _ := strings.Cut(cred.token, ":")
req.SetBasicAuth(username, password)
helpers.SetBasicAuthorization(req, username, password)

return req, nil
}
Expand Down
2 changes: 1 addition & 1 deletion internal/handlers/terraform_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func (h *TerraformRegistryHandler) HandleRequest(request *http.Request, context
}

logging.RequestLogf(context, "* authenticating terraform registry request (host: %s)", request.URL.Hostname())
request.Header.Set("Authorization", "Bearer "+cred.token)
helpers.SetBearerAuthorization(request, cred.token)
return request, nil
}

Expand Down
33 changes: 33 additions & 0 deletions internal/helpers/helpers.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package helpers

import (
"encoding/base64"
"fmt"
"io"
"net/http"
"net/url"
Expand All @@ -10,6 +12,37 @@ import (
"golang.org/x/net/idna"
)

// ReplaceAuthorization replaces the authorization configured on req with the given key and value.
//
// Note: "Authorization"-header is always cleared to avoid multiple auth headers being set on the request.
func ReplaceAuthorization(req *http.Request, key string, value string) {
req.Header.Del("Authorization")
req.Header.Set(key, value)
}
Comment on lines +15 to +21

// SetRawAuthorization sets the authorization header on req to the given value
func SetRawAuthorization(req *http.Request, authorization string) {
ReplaceAuthorization(req, "Authorization", authorization)
}

func SetBasicAuthorization(req *http.Request, username, password string) {
SetRawAuthorization(
req,
fmt.Sprintf(
"Basic %s",
base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", username, password))),
),
)
Comment on lines +29 to +35
}

func SetBearerAuthorization(req *http.Request, token string) {
SetRawAuthorization(req, fmt.Sprintf("Bearer %s", token))
}

func SetGithubAPITokenAuthorization(req *http.Request, token string) {
SetRawAuthorization(req, fmt.Sprintf("token %s", token))
}

func CheckGitHubAPIHost(r *http.Request) bool {
hostname := GetHost(r)
// Check if the hostname is a GitHub API hostname and will return true
Expand Down
148 changes: 148 additions & 0 deletions internal/helpers/helpers_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,159 @@
package helpers

import (
"encoding/base64"
"net/http"
"net/http/httptest"
"net/url"
"testing"
)

// newRequest builds a GET request to the given raw URL for use in tests.
func newRequest(t *testing.T, rawURL string) *http.Request {
t.Helper()
return httptest.NewRequest(http.MethodGet, rawURL, nil)
}

// newRequestWithAuth builds a request that already carries an Authorization header,
// simulating a client that sent credentials which should be replaced.
func newRequestWithAuth(t *testing.T, rawURL, existing string) *http.Request {
t.Helper()
req := newRequest(t, rawURL)
req.Header.Set("Authorization", existing)
return req
}

func TestSetBasicAuthorization(t *testing.T) {
t.Run("sets correct Basic header", func(t *testing.T) {
req := newRequest(t, "https://example.com")
SetBasicAuthorization(req, "user", "pass")

want := "Basic " + base64.StdEncoding.EncodeToString([]byte("user:pass"))
if got := req.Header.Get("Authorization"); got != want {
t.Errorf("Authorization = %q, want %q", got, want)
}
})

t.Run("clears pre-existing Authorization header", func(t *testing.T) {
req := newRequestWithAuth(t, "https://example.com", "Bearer old-token")
SetBasicAuthorization(req, "user", "pass")

want := "Basic " + base64.StdEncoding.EncodeToString([]byte("user:pass"))
if got := req.Header.Get("Authorization"); got != want {
t.Errorf("Authorization = %q, want %q", got, want)
}
if vals := req.Header["Authorization"]; len(vals) != 1 {
t.Errorf("expected exactly 1 Authorization value, got %d: %v", len(vals), vals)
}
})

t.Run("encodes empty username correctly", func(t *testing.T) {
req := newRequest(t, "https://example.com")
SetBasicAuthorization(req, "", "token")

want := "Basic " + base64.StdEncoding.EncodeToString([]byte(":token"))
if got := req.Header.Get("Authorization"); got != want {
t.Errorf("Authorization = %q, want %q", got, want)
}
})
}

func TestSetBearerAuthorization(t *testing.T) {
t.Run("sets correct Bearer header", func(t *testing.T) {
req := newRequest(t, "https://example.com")
SetBearerAuthorization(req, "my-token")

if got := req.Header.Get("Authorization"); got != "Bearer my-token" {
t.Errorf("Authorization = %q, want %q", got, "Bearer my-token")
}
})

t.Run("clears pre-existing Authorization header", func(t *testing.T) {
req := newRequestWithAuth(t, "https://example.com", "Basic dXNlcjpwYXNz")
SetBearerAuthorization(req, "new-token")

if got := req.Header.Get("Authorization"); got != "Bearer new-token" {
t.Errorf("Authorization = %q, want %q", got, "Bearer new-token")
}
if vals := req.Header["Authorization"]; len(vals) != 1 {
t.Errorf("expected exactly 1 Authorization value, got %d: %v", len(vals), vals)
}
})
}

func TestSetGithubAPITokenAuthorization(t *testing.T) {
t.Run("sets correct token header", func(t *testing.T) {
req := newRequest(t, "https://api.github.com")
SetGithubAPITokenAuthorization(req, "ghp_abc123")

if got := req.Header.Get("Authorization"); got != "token ghp_abc123" {
t.Errorf("Authorization = %q, want %q", got, "token ghp_abc123")
}
})

t.Run("clears pre-existing Authorization header", func(t *testing.T) {
req := newRequestWithAuth(t, "https://api.github.com", "token old-token")
SetGithubAPITokenAuthorization(req, "new-token")

if got := req.Header.Get("Authorization"); got != "token new-token" {
t.Errorf("Authorization = %q, want %q", got, "token new-token")
}
if vals := req.Header["Authorization"]; len(vals) != 1 {
t.Errorf("expected exactly 1 Authorization value, got %d: %v", len(vals), vals)
}
})
}

func TestSetRawAuthorization(t *testing.T) {
t.Run("sets pre-formatted value as-is", func(t *testing.T) {
req := newRequest(t, "https://example.com")
SetRawAuthorization(req, "Bearer already-formatted")

if got := req.Header.Get("Authorization"); got != "Bearer already-formatted" {
t.Errorf("Authorization = %q, want %q", got, "Bearer already-formatted")
}
})

t.Run("clears pre-existing Authorization header", func(t *testing.T) {
req := newRequestWithAuth(t, "https://example.com", "Bearer stale")
SetRawAuthorization(req, "token new-raw")

if got := req.Header.Get("Authorization"); got != "token new-raw" {
t.Errorf("Authorization = %q, want %q", got, "token new-raw")
}
if vals := req.Header["Authorization"]; len(vals) != 1 {
t.Errorf("expected exactly 1 Authorization value, got %d: %v", len(vals), vals)
}
})
}

func TestReplaceAuthorization_CustomKey(t *testing.T) {
t.Run("sets value on custom header key", func(t *testing.T) {
req := newRequest(t, "https://cloudsmith.example.com")
ReplaceAuthorization(req, "X-Api-Key", "my-api-key")

if got := req.Header.Get("X-Api-Key"); got != "my-api-key" {
t.Errorf("X-Api-Key = %q, want %q", got, "my-api-key")
}
if got := req.Header.Get("Authorization"); got != "" {
t.Errorf("Authorization should be empty, got %q", got)
}
})

t.Run("clears pre-existing Authorization header before setting custom key", func(t *testing.T) {
req := newRequest(t, "https://cloudsmith.example.com")
req.Header.Set("X-Api-Key", "old-key")
ReplaceAuthorization(req, "X-Api-Key", "new-key")

if got := req.Header.Get("X-Api-Key"); got != "new-key" {
t.Errorf("X-Api-Key = %q, want %q", got, "new-key")
}
if vals := req.Header["X-Api-Key"]; len(vals) != 1 {
t.Errorf("expected exactly 1 X-Api-Key value, got %d: %v", len(vals), vals)
}
})
}

func TestUrlMatchesRequest(t *testing.T) {
tests := []struct {
name string
Expand Down
Loading
Loading