diff --git a/cla-backend-go/sss/auth.go b/cla-backend-go/sss/auth.go new file mode 100644 index 000000000..2b12d2019 --- /dev/null +++ b/cla-backend-go/sss/auth.go @@ -0,0 +1,20 @@ +// Copyright The Linux Foundation and each contributor to CommunityBridge. +// SPDX-License-Identifier: MIT + +package sss + +// authRequest is the payload used for the Auth0 client credentials request. +type authRequest struct { + GrantType string `json:"grant_type"` + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + Audience string `json:"audience"` +} + +// authResponse is the Auth0 token response payload. +type authResponse struct { + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` + Scope string `json:"scope"` + TokenType string `json:"token_type"` +} diff --git a/cla-backend-go/sss/client.go b/cla-backend-go/sss/client.go new file mode 100644 index 000000000..04273ed66 --- /dev/null +++ b/cla-backend-go/sss/client.go @@ -0,0 +1,329 @@ +// Copyright The Linux Foundation and each contributor to CommunityBridge. +// SPDX-License-Identifier: MIT + +package sss + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "strconv" + "strings" + "sync" + "time" +) + +const ( + defaultTimeout = 30 * time.Second + defaultTokenTTL = time.Hour + userAgent = "easycla-cla-backend-go/sss-client" +) + +// Client is a reusable HTTP client for the Sanctions Screening Service. +type Client struct { + cfg SSSConfig + httpClient *http.Client + token string + expiry time.Time + tokenMutex sync.RWMutex +} + +// NewClient creates a new SSS client configured for Auth0 client credentials. +func NewClient(cfg SSSConfig) (*Client, error) { + if strings.TrimSpace(cfg.BaseURL) == "" { + return nil, fmt.Errorf("base URL is required") + } + if strings.TrimSpace(cfg.Auth0Domain) == "" { + return nil, fmt.Errorf("Auth0 domain is required") + } + if strings.TrimSpace(cfg.Auth0ClientID) == "" { + return nil, fmt.Errorf("Auth0 client ID is required") + } + if strings.TrimSpace(cfg.Auth0ClientSecret) == "" { + return nil, fmt.Errorf("Auth0 client secret is required") + } + if strings.TrimSpace(cfg.Auth0Audience) == "" { + return nil, fmt.Errorf("Auth0 audience is required") + } + if cfg.Timeout <= 0 { + cfg.Timeout = defaultTimeout + } + + return &Client{ + cfg: cfg, + httpClient: &http.Client{Timeout: cfg.Timeout}, + }, nil +} + +// GetOrganizationStatus retrieves the sanctions screening result for an organization. +func (c *Client) GetOrganizationStatus(ctx context.Context, statusReq OrganizationStatusRequest) (*ScreeningResult, error) { + if strings.TrimSpace(statusReq.Domain) == "" { + return nil, &BadRequestError{Message: "domain is required"} + } + if strings.TrimSpace(statusReq.OrgName) == "" { + return nil, &BadRequestError{Message: "org_name is required"} + } + + token, err := c.getToken(ctx) + if err != nil { + return nil, err + } + + endpoint := strings.TrimRight(c.cfg.BaseURL, "/") + "/api/v1/organizations/status" + reqURL, err := url.Parse(endpoint) + if err != nil { + return nil, fmt.Errorf("invalid base URL: %w", err) + } + + query := reqURL.Query() + query.Set("domain", strings.TrimSpace(statusReq.Domain)) + query.Set("org_name", strings.TrimSpace(statusReq.OrgName)) + if v := strings.TrimSpace(statusReq.Country); v != "" { + query.Set("country", v) + } + if v := strings.TrimSpace(statusReq.City); v != "" { + query.Set("city", v) + } + if v := strings.TrimSpace(statusReq.State); v != "" { + query.Set("state", v) + } + if v := strings.TrimSpace(statusReq.PostalCode); v != "" { + query.Set("postal_code", v) + } + if v := strings.TrimSpace(statusReq.SFDCID); v != "" { + query.Set("sfdc_id", v) + } + if v := strings.TrimSpace(statusReq.ClearbitID); v != "" { + query.Set("clearbit_id", v) + } + reqURL.RawQuery = query.Encode() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL.String(), nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("User-Agent", userAgent) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, toClientError(err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + switch resp.StatusCode { + case http.StatusOK: + var result ScreeningResult + if err := json.NewDecoder(bytes.NewReader(body)).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to decode screening result: %w", err) + } + return &result, nil + case http.StatusBadRequest: + details := responseErrorDetails(body) + return nil, &BadRequestError{Message: details.Message, Code: details.Code, RequestID: details.RequestID} + case http.StatusNotFound: + details := responseErrorDetails(body) + return nil, &NotFoundError{Message: details.Message, Code: details.Code, RequestID: details.RequestID} + case http.StatusUnauthorized, http.StatusForbidden: + c.invalidateToken(token) + details := responseErrorDetails(body) + return nil, &AuthError{Message: details.Message, Code: details.Code, RequestID: details.RequestID} + case http.StatusTooManyRequests, http.StatusServiceUnavailable: + details := responseErrorDetails(body) + return nil, &RetryableError{Message: details.Message, Code: details.Code, RequestID: details.RequestID, RetryAfter: parseRetryAfter(resp.Header.Get("Retry-After"))} + default: + return nil, fmt.Errorf("unexpected status code %d: %s", resp.StatusCode, responseMessage(body)) + } +} + +func (c *Client) getToken(ctx context.Context) (string, error) { + c.tokenMutex.RLock() + currentToken := c.token + expiry := c.expiry + c.tokenMutex.RUnlock() + + if currentToken == "" || time.Until(expiry) <= time.Minute { + return c.fetchToken(ctx) + } + + return currentToken, nil +} + +func (c *Client) fetchToken(ctx context.Context) (string, error) { + c.tokenMutex.Lock() + defer c.tokenMutex.Unlock() + + if c.token != "" && time.Until(c.expiry) > time.Minute { + return c.token, nil + } + + requestPayload := authRequest{ + GrantType: "client_credentials", + ClientID: c.cfg.Auth0ClientID, + ClientSecret: c.cfg.Auth0ClientSecret, + Audience: c.cfg.Auth0Audience, + } + payload, err := json.Marshal(requestPayload) + if err != nil { + return "", fmt.Errorf("failed to marshal auth request: %w", err) + } + + authURL := c.authTokenURL() + req, err := http.NewRequestWithContext(ctx, http.MethodPost, authURL, bytes.NewReader(payload)) + if err != nil { + return "", fmt.Errorf("failed to create auth request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", userAgent) + + resp, err := c.httpClient.Do(req) + if err != nil { + return "", toClientError(err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("failed to read auth response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + var auth0Err struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description"` + } + details := upstreamErrorDetails{} + if err := json.Unmarshal(body, &auth0Err); err == nil { + details.Code = strings.TrimSpace(auth0Err.Error) + details.Message = strings.TrimSpace(auth0Err.ErrorDescription) + if details.Message == "" && details.Code != "" { + details.Message = details.Code + } + if details.Message == "" { + details.Message = resp.Status + } + } else { + details.Message = strings.TrimSpace(string(body)) + if details.Message == "" { + details.Message = resp.Status + } + } + return "", &AuthError{Message: fmt.Sprintf("authentication failed: %s", details.Message), Code: details.Code, RequestID: details.RequestID} + } + + var tokenResp authResponse + if err := json.Unmarshal(body, &tokenResp); err != nil { + return "", fmt.Errorf("failed to decode auth response: %w", err) + } + + if tokenResp.AccessToken == "" { + return "", &AuthError{Message: "empty access token from auth server"} + } + + expiresIn := time.Duration(tokenResp.ExpiresIn) * time.Second + if expiresIn <= 0 { + expiresIn = defaultTokenTTL + } + c.token = tokenResp.AccessToken + c.expiry = time.Now().Add(expiresIn) + + return c.token, nil +} + +func (c *Client) invalidateToken(token string) { + c.tokenMutex.Lock() + defer c.tokenMutex.Unlock() + + if c.token == token { + c.token = "" + c.expiry = time.Time{} + } +} + +func (c *Client) authTokenURL() string { + domain := strings.TrimSpace(c.cfg.Auth0Domain) + if strings.HasPrefix(domain, "http://") || strings.HasPrefix(domain, "https://") { + return strings.TrimRight(domain, "/") + "/oauth/token" + } + return "https://" + strings.TrimRight(domain, "/") + "/oauth/token" +} + +func parseRetryAfter(value string) time.Duration { + if strings.TrimSpace(value) == "" { + return 0 + } + + if seconds, err := strconv.Atoi(strings.TrimSpace(value)); err == nil { + if seconds < 0 { + return 0 + } + return time.Duration(seconds) * time.Second + } + + if parsedTime, err := http.ParseTime(value); err == nil { + d := time.Until(parsedTime) + if d < 0 { + return 0 + } + return d + } + return 0 +} + +type upstreamErrorDetails struct { + Message string + Code string + RequestID string +} + +func responseErrorDetails(body []byte) upstreamErrorDetails { + var errPayload struct { + Error struct { + Code string `json:"code"` + Message string `json:"message"` + } `json:"error"` + RequestID string `json:"request_id"` + } + if err := json.Unmarshal(body, &errPayload); err == nil { + details := upstreamErrorDetails{ + Message: strings.TrimSpace(errPayload.Error.Message), + Code: strings.TrimSpace(errPayload.Error.Code), + RequestID: strings.TrimSpace(errPayload.RequestID), + } + if details.Message != "" || details.Code != "" || details.RequestID != "" { + return details + } + } + + return upstreamErrorDetails{Message: strings.TrimSpace(string(body))} +} + +func responseMessage(body []byte) string { + return responseErrorDetails(body).Message +} + +func toClientError(err error) error { + if errors.Is(err, context.DeadlineExceeded) { + return &TimeoutError{Message: err.Error()} + } + + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + return &TimeoutError{Message: err.Error()} + } + + return err +} diff --git a/cla-backend-go/sss/client_test.go b/cla-backend-go/sss/client_test.go new file mode 100644 index 000000000..d277cdbe1 --- /dev/null +++ b/cla-backend-go/sss/client_test.go @@ -0,0 +1,725 @@ +// Copyright The Linux Foundation and each contributor to CommunityBridge. +// SPDX-License-Identifier: MIT + +package sss + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" +) + +const ( + testAuthTokenPath = "/oauth/token" + testOrgDomain = "example.com" + testOrgName = "Example Corp" + testRateLimitExceeded = "rate limit exceeded" +) + +func TestGetOrganizationStatus_Success(t *testing.T) { + authCalls := int32(0) + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != testAuthTokenPath { + t.Fatalf("unexpected auth request %s %s", r.Method, r.URL.Path) + } + if got := r.Header.Get("User-Agent"); got != userAgent { + t.Fatalf("unexpected auth user-agent: %s", got) + } + atomic.AddInt32(&authCalls, 1) + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"token-abc","expires_in":3600,"token_type":"Bearer"}`) + })) + defer authServer.Close() + + serviceServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet || r.URL.Path != "/api/v1/organizations/status" { + t.Fatalf("unexpected service request %s %s", r.Method, r.URL.Path) + } + if got := r.URL.Query().Get("domain"); got != testOrgDomain { + t.Fatalf("unexpected domain: %s", got) + } + if got := r.URL.Query().Get("org_name"); got != testOrgName { + t.Fatalf("unexpected org_name: %s", got) + } + if got := r.URL.Query().Get("country"); got != "US" { + t.Fatalf("unexpected country: %s", got) + } + if got := r.URL.Query().Get("city"); got != "San Francisco" { + t.Fatalf("unexpected city: %s", got) + } + if got := r.URL.Query().Get("state"); got != "CA" { + t.Fatalf("unexpected state: %s", got) + } + if got := r.URL.Query().Get("postal_code"); got != "94105" { + t.Fatalf("unexpected postal_code: %s", got) + } + if got := r.URL.Query().Get("sfdc_id"); got != "SFDC-123" { + t.Fatalf("unexpected sfdc_id: %s", got) + } + if got := r.URL.Query().Get("clearbit_id"); got != "CLEARBIT-123" { + t.Fatalf("unexpected clearbit_id: %s", got) + } + if got := r.Header.Get("Authorization"); got != "Bearer token-abc" { + t.Fatalf("unexpected auth header: %s", got) + } + if got := r.Header.Get("User-Agent"); got != userAgent { + t.Fatalf("unexpected service user-agent: %s", got) + } + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"status":"clean","entity_id":"org-123","source":"screening_db","screened_at":"2025-05-16T12:34:56Z","vendor":"descartes","clearbit_enriched":true,"sfdc_id":null,"domain":"example.com","org_name":"Example Corp"}`) + })) + defer serviceServer.Close() + + client, err := NewClient(SSSConfig{ + BaseURL: serviceServer.URL, + Auth0Domain: authServer.URL, + Auth0ClientID: "id", + Auth0ClientSecret: "secret", + Auth0Audience: "audience", + Timeout: 5 * time.Second, + }) + if err != nil { + t.Fatal(err) + } + + result, err := client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{ + Domain: testOrgDomain, + OrgName: testOrgName, + Country: "US", + City: "San Francisco", + State: "CA", + PostalCode: "94105", + SFDCID: "SFDC-123", + ClearbitID: "CLEARBIT-123", + }) + if err != nil { + t.Fatal(err) + } + if result.Status != StatusClean || result.EntityID != "org-123" || result.Source != SourceScreeningDB { + t.Fatalf("unexpected result: %+v", result) + } + if result.SFDCID != nil { + t.Fatalf("expected nullable sfdc_id to decode as nil, got %q", *result.SFDCID) + } + if result.Vendor != "descartes" || !result.ClearbitEnriched || result.Domain != testOrgDomain || result.OrgName != testOrgName { + t.Fatalf("unexpected enriched fields: %+v", result) + } + if !result.ScreenedAt.Equal(time.Date(2025, 5, 16, 12, 34, 56, 0, time.UTC)) { + t.Fatalf("unexpected screened_at: %v", result.ScreenedAt) + } + if got := atomic.LoadInt32(&authCalls); got != 1 { + t.Fatalf("expected 1 auth call, got %d", got) + } +} + +func TestGetOrganizationStatus_MissingDomain(t *testing.T) { + client, err := NewClient(SSSConfig{ + BaseURL: "https://" + testOrgDomain, + Auth0Domain: "https://auth.example.com", + Auth0ClientID: "id", + Auth0ClientSecret: "secret", + Auth0Audience: "audience", + Timeout: 5 * time.Second, + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{OrgName: "Example Org"}) + var badReq *BadRequestError + if !errors.As(err, &badReq) { + t.Fatalf("expected BadRequestError, got %T: %v", err, err) + } +} + +func TestGetOrganizationStatus_MissingOrgName(t *testing.T) { + client, err := NewClient(SSSConfig{ + BaseURL: "https://" + testOrgDomain, + Auth0Domain: "https://auth.example.com", + Auth0ClientID: "id", + Auth0ClientSecret: "secret", + Auth0Audience: "audience", + Timeout: 5 * time.Second, + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: testOrgDomain}) + var badReq *BadRequestError + if !errors.As(err, &badReq) { + t.Fatalf("expected BadRequestError, got %T: %v", err, err) + } +} + +func TestGetOrganizationStatus_TooManyRequestsReturnsRetryableError(t *testing.T) { + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"token-429","expires_in":3600,"token_type":"Bearer"}`) + })) + defer authServer.Close() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Retry-After", "5") + w.WriteHeader(http.StatusTooManyRequests) + fmt.Fprintf(w, `{"error":{"code":"RATE_LIMITED","message":"%s"},"request_id":"req-429"}`, testRateLimitExceeded) + })) + defer server.Close() + + client, err := NewClient(SSSConfig{ + BaseURL: server.URL, + Auth0Domain: authServer.URL, + Auth0ClientID: "id", + Auth0ClientSecret: "secret", + Auth0Audience: "audience", + Timeout: 5 * time.Second, + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: testOrgDomain, OrgName: "RateOrg"}) + var retryErr *RetryableError + if !errors.As(err, &retryErr) { + t.Fatalf("expected RetryableError, got %T: %v", err, err) + } + if retryErr.RetryAfter != 5*time.Second { + t.Fatalf("expected retry after 5s, got %v", retryErr.RetryAfter) + } + if retryErr.Message != testRateLimitExceeded { + t.Fatalf("unexpected retry message: %s", retryErr.Message) + } + if retryErr.Code != "RATE_LIMITED" || retryErr.RequestID != "req-429" { + t.Fatalf("unexpected retry details: %+v", retryErr) + } +} + +func TestGetOrganizationStatus_TooManyRequestsClampsNegativeRetryAfter(t *testing.T) { + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"token-negative-retry","expires_in":3600,"token_type":"Bearer"}`) + })) + defer authServer.Close() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Retry-After", "-5") + w.WriteHeader(http.StatusTooManyRequests) + fmt.Fprintf(w, `{"error":{"message":"%s"}}`, testRateLimitExceeded) + })) + defer server.Close() + + client, err := NewClient(SSSConfig{ + BaseURL: server.URL, + Auth0Domain: authServer.URL, + Auth0ClientID: "id", + Auth0ClientSecret: "secret", + Auth0Audience: "audience", + Timeout: 5 * time.Second, + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: testOrgDomain, OrgName: "RateOrg"}) + var retryErr *RetryableError + if !errors.As(err, &retryErr) { + t.Fatalf("expected RetryableError, got %T: %v", err, err) + } + if retryErr.RetryAfter != 0 { + t.Fatalf("expected retry after 0s, got %v", retryErr.RetryAfter) + } + if retryErr.Message != testRateLimitExceeded { + t.Fatalf("unexpected retry message: %s", retryErr.Message) + } +} + +func TestGetOrganizationStatus_FlaggedResponse(t *testing.T) { + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"token-flagged","expires_in":3600,"token_type":"Bearer"}`) + })) + defer authServer.Close() + + serviceServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"status":"flagged","entity_id":"org-flagged","source":"descartes_api","screened_at":"2025-05-16T12:34:56Z","vendor":"descartes","clearbit_enriched":false,"sfdc_id":"SFDC-456","domain":"example.org","org_name":"Flagged Org"}`) + })) + defer serviceServer.Close() + + client, err := NewClient(SSSConfig{ + BaseURL: serviceServer.URL, + Auth0Domain: authServer.URL, + Auth0ClientID: "id", + Auth0ClientSecret: "secret", + Auth0Audience: "audience", + Timeout: 5 * time.Second, + }) + if err != nil { + t.Fatal(err) + } + + result, err := client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{ + Domain: "example.org", + OrgName: "Flagged Org", + }) + if err != nil { + t.Fatal(err) + } + if result.Status != StatusFlagged || result.EntityID != "org-flagged" { + t.Fatalf("unexpected result: %+v", result) + } + if result.Source != SourceDescartesAPI { + t.Fatalf("unexpected source: %s", result.Source) + } + if result.SFDCID == nil || *result.SFDCID != "SFDC-456" { + t.Fatalf("unexpected sfdc_id: %+v", result.SFDCID) + } +} + +func TestGetOrganizationStatus_400ReturnsBadRequestError(t *testing.T) { + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != testAuthTokenPath { + t.Fatalf("unexpected auth request %s %s", r.Method, r.URL.Path) + } + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"token-400","expires_in":3600,"token_type":"Bearer"}`) + })) + defer authServer.Close() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + fmt.Fprint(w, `bad request from server`) + })) + defer server.Close() + + client, err := NewClient(SSSConfig{ + BaseURL: server.URL, + Auth0Domain: authServer.URL, + Auth0ClientID: "id", + Auth0ClientSecret: "secret", + Auth0Audience: "audience", + Timeout: 5 * time.Second, + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: testOrgDomain, OrgName: "BadOrg"}) + var badReq *BadRequestError + if !errors.As(err, &badReq) { + t.Fatalf("expected BadRequestError, got %T: %v", err, err) + } +} + +func TestGetOrganizationStatus_400PreservesStructuredErrorDetails(t *testing.T) { + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"token-400-structured","expires_in":3600,"token_type":"Bearer"}`) + })) + defer authServer.Close() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + fmt.Fprint(w, `{"error":{"code":"INVALID_DOMAIN","message":"invalid domain"},"request_id":"req-400"}`) + })) + defer server.Close() + + client, err := NewClient(SSSConfig{ + BaseURL: server.URL, + Auth0Domain: authServer.URL, + Auth0ClientID: "id", + Auth0ClientSecret: "secret", + Auth0Audience: "audience", + Timeout: 5 * time.Second, + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: "bad domain", OrgName: "BadOrg"}) + var badReq *BadRequestError + if !errors.As(err, &badReq) { + t.Fatalf("expected BadRequestError, got %T: %v", err, err) + } + if badReq.Message != "invalid domain" || badReq.Code != "INVALID_DOMAIN" || badReq.RequestID != "req-400" { + t.Fatalf("unexpected bad request details: %+v", badReq) + } +} + +func TestGetOrganizationStatus_404ReturnsNotFoundError(t *testing.T) { + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"token-404","expires_in":3600,"token_type":"Bearer"}`) + })) + defer authServer.Close() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + fmt.Fprint(w, `{"error":{"code":"ORG_NOT_FOUND","message":"Organization not found in any tier"},"request_id":"req-404"}`) + })) + defer server.Close() + + client, err := NewClient(SSSConfig{ + BaseURL: server.URL, + Auth0Domain: authServer.URL, + Auth0ClientID: "id", + Auth0ClientSecret: "secret", + Auth0Audience: "audience", + Timeout: 5 * time.Second, + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: "missing.example", OrgName: "MissingOrg"}) + var notFound *NotFoundError + if !errors.As(err, ¬Found) { + t.Fatalf("expected NotFoundError, got %T: %v", err, err) + } + if notFound.Message != "Organization not found in any tier" || notFound.Code != "ORG_NOT_FOUND" || notFound.RequestID != "req-404" { + t.Fatalf("unexpected not found details: %+v", notFound) + } +} + +func TestGetOrganizationStatus_401ReturnsAuthError(t *testing.T) { + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != testAuthTokenPath { + t.Fatalf("unexpected auth request %s %s", r.Method, r.URL.Path) + } + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"token-401","expires_in":3600,"token_type":"Bearer"}`) + })) + defer authServer.Close() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + fmt.Fprint(w, `{"error":{"code":"TOKEN_EXPIRED","message":"unauthorized"},"request_id":"req-401"}`) + })) + defer server.Close() + + client, err := NewClient(SSSConfig{ + BaseURL: server.URL, + Auth0Domain: authServer.URL, + Auth0ClientID: "id", + Auth0ClientSecret: "secret", + Auth0Audience: "audience", + Timeout: 5 * time.Second, + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: testOrgDomain, OrgName: "AuthOrg"}) + var authErr *AuthError + if !errors.As(err, &authErr) { + t.Fatalf("expected AuthError, got %T: %v", err, err) + } + if authErr.Message != "unauthorized" || authErr.Code != "TOKEN_EXPIRED" || authErr.RequestID != "req-401" { + t.Fatalf("unexpected auth details: %+v", authErr) + } +} + +func TestGetOrganizationStatus_401InvalidatesCachedToken(t *testing.T) { + authCalls := int32(0) + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + index := atomic.AddInt32(&authCalls, 1) + w.Header().Set("Content-Type", "application/json") + fmt.Fprintf(w, `{"access_token":"token-%d","expires_in":3600,"token_type":"Bearer"}`, index) + })) + defer authServer.Close() + + serviceCalls := int32(0) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + index := atomic.AddInt32(&serviceCalls, 1) + if index == 1 { + if got := r.Header.Get("Authorization"); got != "Bearer token-1" { + t.Fatalf("unexpected first auth header: %s", got) + } + w.WriteHeader(http.StatusUnauthorized) + fmt.Fprint(w, `{"error":{"code":"TOKEN_EXPIRED","message":"token expired"},"request_id":"req-expired"}`) + return + } + if got := r.Header.Get("Authorization"); got != "Bearer token-2" { + t.Fatalf("expected refreshed token on second request, got %s", got) + } + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"status":"clean","entity_id":"org-refresh","source":"sfdc","screened_at":"2025-05-16T12:34:56Z","vendor":"descartes","clearbit_enriched":true,"sfdc_id":null,"domain":"example.com","org_name":"RefreshOrg"}`) + })) + defer server.Close() + + client, err := NewClient(SSSConfig{ + BaseURL: server.URL, + Auth0Domain: authServer.URL, + Auth0ClientID: "id", + Auth0ClientSecret: "secret", + Auth0Audience: "audience", + Timeout: 5 * time.Second, + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: testOrgDomain, OrgName: "RefreshOrg"}) + var authErr *AuthError + if !errors.As(err, &authErr) { + t.Fatalf("expected AuthError, got %T: %v", err, err) + } + + if _, err := client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: testOrgDomain, OrgName: "RefreshOrg"}); err != nil { + t.Fatal(err) + } + if got := atomic.LoadInt32(&authCalls); got != 2 { + t.Fatalf("expected token refetch after auth failure, got %d auth calls", got) + } +} + +func TestGetOrganizationStatus_Auth0ErrorUsesAuth0Payload(t *testing.T) { + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != testAuthTokenPath { + t.Fatalf("unexpected auth request %s %s", r.Method, r.URL.Path) + } + w.WriteHeader(http.StatusUnauthorized) + fmt.Fprint(w, `{"error":"invalid_client","error_description":"Client authentication failed"}`) + })) + defer authServer.Close() + + client, err := NewClient(SSSConfig{ + BaseURL: "https://sss.example.com", + Auth0Domain: authServer.URL, + Auth0ClientID: "id", + Auth0ClientSecret: "secret", + Auth0Audience: "audience", + Timeout: 5 * time.Second, + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: testOrgDomain, OrgName: "Auth0Org"}) + var authErr *AuthError + if !errors.As(err, &authErr) { + t.Fatalf("expected AuthError, got %T: %v", err, err) + } + if authErr.Code != "invalid_client" { + t.Fatalf("unexpected auth code: %s", authErr.Code) + } + if authErr.Message != "authentication failed: Client authentication failed" { + t.Fatalf("unexpected auth message: %s", authErr.Message) + } + if got := authErr.Error(); got != "authentication error: authentication failed: Client authentication failed (code=invalid_client request_id=)" { + t.Fatalf("unexpected auth error string: %s", got) + } +} + +func TestGetOrganizationStatus_503ReturnsRetryableError(t *testing.T) { + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != testAuthTokenPath { + t.Fatalf("unexpected auth request %s %s", r.Method, r.URL.Path) + } + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"token-503","expires_in":3600,"token_type":"Bearer"}`) + })) + defer authServer.Close() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Retry-After", "3") + w.WriteHeader(http.StatusServiceUnavailable) + fmt.Fprint(w, `service unavailable`) + })) + defer server.Close() + + client, err := NewClient(SSSConfig{ + BaseURL: server.URL, + Auth0Domain: authServer.URL, + Auth0ClientID: "id", + Auth0ClientSecret: "secret", + Auth0Audience: "audience", + Timeout: 5 * time.Second, + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: testOrgDomain, OrgName: "RetryOrg"}) + var retryErr *RetryableError + if !errors.As(err, &retryErr) { + t.Fatalf("expected RetryableError, got %T: %v", err, err) + } + if retryErr.RetryAfter != 3*time.Second { + t.Fatalf("expected retry after 3s, got %v", retryErr.RetryAfter) + } +} + +func TestGetOrganizationStatus_TimeoutReturnsTimeoutError(t *testing.T) { + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != testAuthTokenPath { + t.Fatalf("unexpected auth request %s %s", r.Method, r.URL.Path) + } + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"token-timeout","expires_in":3600,"token_type":"Bearer"}`) + })) + defer authServer.Close() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(200 * time.Millisecond) + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, `{"status":"clean","entity_id":"org-timeout","source":"sfdc","screened_at":"2025-05-16T12:34:56Z","vendor":"descartes","clearbit_enriched":true,"sfdc_id":null,"domain":"example.com","org_name":"TimeoutOrg"}`) + })) + defer server.Close() + + client, err := NewClient(SSSConfig{ + BaseURL: server.URL, + Auth0Domain: authServer.URL, + Auth0ClientID: "id", + Auth0ClientSecret: "secret", + Auth0Audience: "audience", + Timeout: 50 * time.Millisecond, + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: testOrgDomain, OrgName: "TimeoutOrg"}) + var timeoutErr *TimeoutError + if !errors.As(err, &timeoutErr) { + t.Fatalf("expected TimeoutError, got %T: %v", err, err) + } +} + +func TestGetOrganizationStatus_UsesCachedToken(t *testing.T) { + authCalls := int32(0) + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&authCalls, 1) + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"cached-token","expires_in":3600,"token_type":"Bearer"}`) + })) + defer authServer.Close() + + serviceServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"status":"clean","entity_id":"org-cache","source":"sfdc","screened_at":"2025-05-16T12:34:56Z","vendor":"descartes","clearbit_enriched":true,"sfdc_id":null,"domain":"example.com","org_name":"CacheOrg"}`) + })) + defer serviceServer.Close() + + client, err := NewClient(SSSConfig{ + BaseURL: serviceServer.URL, + Auth0Domain: authServer.URL, + Auth0ClientID: "id", + Auth0ClientSecret: "secret", + Auth0Audience: "audience", + Timeout: 5 * time.Second, + }) + if err != nil { + t.Fatal(err) + } + + for i := 0; i < 2; i++ { + if _, err := client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: testOrgDomain, OrgName: "CacheOrg"}); err != nil { + t.Fatal(err) + } + } + if got := atomic.LoadInt32(&authCalls); got != 1 { + t.Fatalf("expected 1 auth call, got %d", got) + } +} + +func TestGetOrganizationStatus_CachesTokenWhenExpiresInMissing(t *testing.T) { + authCalls := int32(0) + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&authCalls, 1) + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"fallback-token","expires_in":0,"token_type":"Bearer"}`) + })) + defer authServer.Close() + + serviceServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"status":"clean","entity_id":"org-fallback","source":"sfdc","screened_at":"2025-05-16T12:34:56Z","vendor":"descartes","clearbit_enriched":true,"sfdc_id":null,"domain":"example.com","org_name":"FallbackOrg"}`) + })) + defer serviceServer.Close() + + client, err := NewClient(SSSConfig{ + BaseURL: serviceServer.URL, + Auth0Domain: authServer.URL, + Auth0ClientID: "id", + Auth0ClientSecret: "secret", + Auth0Audience: "audience", + Timeout: 5 * time.Second, + }) + if err != nil { + t.Fatal(err) + } + + for i := 0; i < 2; i++ { + if _, err := client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: testOrgDomain, OrgName: "FallbackOrg"}); err != nil { + t.Fatal(err) + } + } + if got := atomic.LoadInt32(&authCalls); got != 1 { + t.Fatalf("expected fallback token ttl to cache token, got %d auth calls", got) + } +} + +func TestGetOrganizationStatus_RefreshesExpiredToken(t *testing.T) { + tokenIndex := int32(0) + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + index := atomic.AddInt32(&tokenIndex, 1) + accessToken := fmt.Sprintf("token-%d", index) + w.Header().Set("Content-Type", "application/json") + fmt.Fprintf(w, `{"access_token":"%s","expires_in":1,"token_type":"Bearer"}`, + accessToken) + })) + defer authServer.Close() + + serviceServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("Authorization"); got == "" { + t.Fatalf("missing authorization header") + } + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"status":"clean","entity_id":"org-expire","source":"screening_db","screened_at":"2025-05-16T12:34:56Z","vendor":"descartes","clearbit_enriched":true,"sfdc_id":null,"domain":"example.com","org_name":"ExpireOrg"}`) + })) + defer serviceServer.Close() + + client, err := NewClient(SSSConfig{ + BaseURL: serviceServer.URL, + Auth0Domain: authServer.URL, + Auth0ClientID: "id", + Auth0ClientSecret: "secret", + Auth0Audience: "audience", + Timeout: 5 * time.Second, + }) + if err != nil { + t.Fatal(err) + } + + if _, err := client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: testOrgDomain, OrgName: "ExpireOrg"}); err != nil { + t.Fatal(err) + } + if _, err := client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: testOrgDomain, OrgName: "ExpireOrg"}); err != nil { + t.Fatal(err) + } + if got := atomic.LoadInt32(&tokenIndex); got < 2 { + t.Fatalf("expected token refresh after expiry, got %d auth calls", got) + } +} + +func TestNewClient_DefaultTimeout(t *testing.T) { + client, err := NewClient(SSSConfig{ + BaseURL: "https://sss.example.com", + Auth0Domain: "https://auth.example.com", + Auth0ClientID: "id", + Auth0ClientSecret: "secret", + Auth0Audience: "audience", + }) + if err != nil { + t.Fatal(err) + } + if client.httpClient.Timeout != defaultTimeout { + t.Fatalf("expected default timeout %v, got %v", defaultTimeout, client.httpClient.Timeout) + } +} diff --git a/cla-backend-go/sss/errors.go b/cla-backend-go/sss/errors.go new file mode 100644 index 000000000..bece4b45d --- /dev/null +++ b/cla-backend-go/sss/errors.go @@ -0,0 +1,72 @@ +// Copyright The Linux Foundation and each contributor to CommunityBridge. +// SPDX-License-Identifier: MIT + +package sss + +import ( + "fmt" + "time" +) + +// BadRequestError indicates a 400 response from the SSS API. +type BadRequestError struct { + Message string + Code string + RequestID string +} + +func (e *BadRequestError) Error() string { + return formatError("bad request", e.Message, e.Code, e.RequestID) +} + +// AuthError indicates a 401 or 403 response from the SSS API. +type AuthError struct { + Message string + Code string + RequestID string +} + +func (e *AuthError) Error() string { + return formatError("authentication error", e.Message, e.Code, e.RequestID) +} + +// RetryableError indicates a 503 response from the SSS API. +type RetryableError struct { + Message string + Code string + RequestID string + RetryAfter time.Duration +} + +func (e *RetryableError) Error() string { + return formatError("retryable error", e.Message, e.Code, e.RequestID) +} + +// NotFoundError indicates a 404 response from the SSS API. +type NotFoundError struct { + Message string + Code string + RequestID string +} + +func (e *NotFoundError) Error() string { + return formatError("not found", e.Message, e.Code, e.RequestID) +} + +// TimeoutError indicates the request timed out. +type TimeoutError struct { + Message string + Code string + RequestID string +} + +func (e *TimeoutError) Error() string { + return formatError("timeout", e.Message, e.Code, e.RequestID) +} + +func formatError(prefix, message, code, requestID string) string { + if code != "" || requestID != "" { + return fmt.Sprintf("%s: %s (code=%s request_id=%s)", prefix, message, code, requestID) + } + return fmt.Sprintf("%s: %s", prefix, message) +} diff --git a/cla-backend-go/sss/types.go b/cla-backend-go/sss/types.go new file mode 100644 index 000000000..351b847de --- /dev/null +++ b/cla-backend-go/sss/types.go @@ -0,0 +1,55 @@ +// Copyright The Linux Foundation and each contributor to CommunityBridge. +// SPDX-License-Identifier: MIT + +package sss + +import "time" + +// SSSConfig holds configuration values for the SSS client. +type SSSConfig struct { + BaseURL string + Auth0Domain string + Auth0ClientID string + Auth0ClientSecret string + // Auth0Audience is the Auth0 API audience/resource server identifier. + // Production values may require the exact identifier configured in Auth0, + // including a trailing slash when the resource server uses one. + Auth0Audience string + // Timeout is shared by SSS API requests and Auth0 token acquisition requests. + Timeout time.Duration +} + +// OrganizationStatusRequest holds parameters for querying organization screening status. +type OrganizationStatusRequest struct { + Domain string `json:"domain"` + OrgName string `json:"org_name"` + Country string `json:"country,omitempty"` + City string `json:"city,omitempty"` + State string `json:"state,omitempty"` + PostalCode string `json:"postal_code,omitempty"` + SFDCID string `json:"sfdc_id,omitempty"` + ClearbitID string `json:"clearbit_id,omitempty"` +} + +const ( + StatusClean = "clean" + StatusFlagged = "flagged" + + SourceScreeningDB = "screening_db" + SourceSFDC = "sfdc" + SourceDescartesAPI = "descartes_api" +) + +// ScreeningResult is returned by the SSS organization status endpoint. +type ScreeningResult struct { + Status string `json:"status"` + EntityID string `json:"entity_id"` + Source string `json:"source"` + ScreenedAt time.Time `json:"screened_at"` + ClearbitID string `json:"clearbit_id"` + SFDCID *string `json:"sfdc_id"` + OrgName string `json:"org_name"` + Domain string `json:"domain"` + Vendor string `json:"vendor"` + ClearbitEnriched bool `json:"clearbit_enriched"` +} diff --git a/cla-backend-go/v2/sign/service.go b/cla-backend-go/v2/sign/service.go index e8699aad3..238f26b6f 100644 --- a/cla-backend-go/v2/sign/service.go +++ b/cla-backend-go/v2/sign/service.go @@ -1460,7 +1460,7 @@ func (s *service) RequestIndividualSignature(ctx context.Context, input *models. var returnURL string if input.ReturnURL.String() == "" { log.WithFields(f).Warnf("signature return url is empty") - returnURL, err = getActiveSignatureReturnURL(*input.UserID, activeSignatureMetadata) + returnURL, err = s.getActiveSignatureReturnURL(ctx, *input.UserID, activeSignatureMetadata) if err != nil { log.WithFields(f).WithError(err).Warnf("unable to get active signature return url for user: %s", *input.UserID) return nil, err @@ -1682,7 +1682,8 @@ func (s *service) getInstallationIDFromRepositoryID(ctx context.Context, reposit installationId = githubOrg.OrganizationInstallationID if installationId == 0 { - log.WithFields(f).Warnf("unable to get installation ID for repository ID: %s", repositoryID) + err = fmt.Errorf("installation ID missing for repository ID: %s", repositoryID) + log.WithFields(f).WithError(err).Warnf("unable to get installation ID for repository ID: %s", repositoryID) return 0, err } @@ -2222,49 +2223,6 @@ func getUserEmail(user *v1Models.User, preferredEmail string, providerType strin return "" } -func getActiveSignatureReturnURL(userID string, metadata map[string]interface{}) (string, error) { - - f := logrus.Fields{ - "functionName": "sign.getActiveSignatureReturnURL", - } - - var returnURL string - var err error - var pullRequestID int - var installationID int64 - var repositoryID int64 - - if found, ok := metadata["pull_request_id"].(int); ok { - pullRequestID = found - } else { - log.WithFields(f).WithError(err).Warnf("unable to get pull request ID for user: %s", userID) - return "", err - } - - if found, ok := metadata["installation_id"].(int64); ok { - installationID = found - } else { - log.WithFields(f).WithError(err).Warnf("unable to get installation ID for user: %s", userID) - return "", err - } - - if found, ok := metadata["repository_id"].(int64); ok { - repositoryID = found - } else { - log.WithFields(f).WithError(err).Warnf("unable to get repository ID for user: %s", userID) - return "", err - } - - returnURL, err = github.GetReturnURL(context.Background(), installationID, repositoryID, pullRequestID) - - if err != nil { - return "", err - } - - return returnURL, nil - -} - func (s *service) createDefaultIndividualValues(user *v1Models.User, preferredEmail string, allowLFEmail bool) map[string]interface{} { f := logrus.Fields{ "functionName": "sign.createDefaultIndiviualValues",