From fb49b498aaba2f954833098ffda6d538800177f7 Mon Sep 17 00:00:00 2001 From: psrsingh Date: Sat, 16 May 2026 19:06:03 +0530 Subject: [PATCH 1/7] feat: add SSS client with Auth0 token caching Signed-off-by: psrsingh --- cla-backend-go/sss/auth.go | 20 ++ cla-backend-go/sss/client.go | 226 ++++++++++++++++++++ cla-backend-go/sss/client_test.go | 328 ++++++++++++++++++++++++++++++ cla-backend-go/sss/errors.go | 46 +++++ cla-backend-go/sss/types.go | 24 +++ 5 files changed, 644 insertions(+) create mode 100644 cla-backend-go/sss/auth.go create mode 100644 cla-backend-go/sss/client.go create mode 100644 cla-backend-go/sss/client_test.go create mode 100644 cla-backend-go/sss/errors.go create mode 100644 cla-backend-go/sss/types.go diff --git a/cla-backend-go/sss/auth.go b/cla-backend-go/sss/auth.go new file mode 100644 index 000000000..2b12d2019 --- /dev/null +++ b/cla-backend-go/sss/auth.go @@ -0,0 +1,20 @@ +// Copyright The Linux Foundation and each contributor to CommunityBridge. +// SPDX-License-Identifier: MIT + +package sss + +// authRequest is the payload used for the Auth0 client credentials request. +type authRequest struct { + GrantType string `json:"grant_type"` + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + Audience string `json:"audience"` +} + +// authResponse is the Auth0 token response payload. +type authResponse struct { + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` + Scope string `json:"scope"` + TokenType string `json:"token_type"` +} diff --git a/cla-backend-go/sss/client.go b/cla-backend-go/sss/client.go new file mode 100644 index 000000000..801d0dc7c --- /dev/null +++ b/cla-backend-go/sss/client.go @@ -0,0 +1,226 @@ +// Copyright The Linux Foundation and each contributor to CommunityBridge. +// SPDX-License-Identifier: MIT + +package sss + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "strconv" + "strings" + "sync" + "time" +) + +const defaultTimeout = 10 * time.Second + +// Client is a reusable HTTP client for the Sanctions Screening Service. +type Client struct { + cfg SSSConfig + httpClient *http.Client + token string + expiry time.Time + tokenMutex sync.RWMutex +} + +// NewClient creates a new SSS client configured for Auth0 client credentials. +func NewClient(cfg SSSConfig) (*Client, error) { + if strings.TrimSpace(cfg.BaseURL) == "" { + return nil, fmt.Errorf("base URL is required") + } + if strings.TrimSpace(cfg.Auth0Domain) == "" { + return nil, fmt.Errorf("Auth0 domain is required") + } + if strings.TrimSpace(cfg.Auth0ClientID) == "" { + return nil, fmt.Errorf("Auth0 client ID is required") + } + if strings.TrimSpace(cfg.Auth0ClientSecret) == "" { + return nil, fmt.Errorf("Auth0 client secret is required") + } + if strings.TrimSpace(cfg.Auth0Audience) == "" { + return nil, fmt.Errorf("Auth0 audience is required") + } + if cfg.Timeout <= 0 { + cfg.Timeout = defaultTimeout + } + + return &Client{ + cfg: cfg, + httpClient: &http.Client{Timeout: cfg.Timeout}, + }, nil +} + +// GetOrganizationStatus retrieves the sanctions screening result for an organization. +func (c *Client) GetOrganizationStatus(ctx context.Context, organizationID string) (*ScreeningResult, error) { + if strings.TrimSpace(organizationID) == "" { + return nil, &BadRequestError{Message: "organization id is required"} + } + + token, err := c.getToken(ctx) + if err != nil { + return nil, err + } + + endpoint := strings.TrimRight(c.cfg.BaseURL, "/") + "/api/v1/organizations/status" + reqURL, err := url.Parse(endpoint) + if err != nil { + return nil, fmt.Errorf("invalid base URL: %w", err) + } + + query := reqURL.Query() + query.Set("organization_id", strings.TrimSpace(organizationID)) + reqURL.RawQuery = query.Encode() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL.String(), nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer "+token) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, toClientError(err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + switch resp.StatusCode { + case http.StatusOK: + var result ScreeningResult + if err := json.NewDecoder(bytes.NewReader(body)).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to decode screening result: %w", err) + } + return &result, nil + case http.StatusBadRequest: + return nil, &BadRequestError{Message: strings.TrimSpace(string(body))} + case http.StatusUnauthorized, http.StatusForbidden: + return nil, &AuthError{Message: strings.TrimSpace(string(body))} + case http.StatusServiceUnavailable: + return nil, &RetryableError{Message: strings.TrimSpace(string(body)), RetryAfter: parseRetryAfter(resp.Header.Get("Retry-After"))} + default: + return nil, fmt.Errorf("unexpected status code %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + } +} + +func (c *Client) getToken(ctx context.Context) (string, error) { + c.tokenMutex.RLock() + currentToken := c.token + expiry := c.expiry + c.tokenMutex.RUnlock() + + if currentToken == "" || time.Until(expiry) <= time.Minute { + return c.fetchToken(ctx) + } + + return currentToken, nil +} + +func (c *Client) fetchToken(ctx context.Context) (string, error) { + c.tokenMutex.Lock() + defer c.tokenMutex.Unlock() + + if c.token != "" && time.Until(c.expiry) > time.Minute { + return c.token, nil + } + + requestPayload := authRequest{ + GrantType: "client_credentials", + ClientID: c.cfg.Auth0ClientID, + ClientSecret: c.cfg.Auth0ClientSecret, + Audience: c.cfg.Auth0Audience, + } + payload, err := json.Marshal(requestPayload) + if err != nil { + return "", fmt.Errorf("failed to marshal auth request: %w", err) + } + + authURL := c.authTokenURL() + req, err := http.NewRequestWithContext(ctx, http.MethodPost, authURL, bytes.NewReader(payload)) + if err != nil { + return "", fmt.Errorf("failed to create auth request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return "", toClientError(err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("failed to read auth response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return "", &AuthError{Message: fmt.Sprintf("authentication failed: %s", strings.TrimSpace(string(body)))} + } + + var authResponse authResponse + if err := json.Unmarshal(body, &authResponse); err != nil { + return "", fmt.Errorf("failed to decode auth response: %w", err) + } + + if authResponse.AccessToken == "" { + return "", &AuthError{Message: "empty access token from auth server"} + } + + expiresIn := time.Duration(authResponse.ExpiresIn) * time.Second + if expiresIn <= 0 { + expiresIn = defaultTimeout + } + c.token = authResponse.AccessToken + c.expiry = time.Now().Add(expiresIn) + + return c.token, nil +} + +func (c *Client) authTokenURL() string { + domain := strings.TrimSpace(c.cfg.Auth0Domain) + if strings.HasPrefix(domain, "http://") || strings.HasPrefix(domain, "https://") { + return strings.TrimRight(domain, "/") + "/oauth/token" + } + return "https://" + strings.TrimRight(domain, "/") + "/oauth/token" +} + +func parseRetryAfter(value string) time.Duration { + if strings.TrimSpace(value) == "" { + return 0 + } + + if seconds, err := strconv.Atoi(strings.TrimSpace(value)); err == nil { + return time.Duration(seconds) * time.Second + } + + if parsedTime, err := http.ParseTime(value); err == nil { + return time.Until(parsedTime) + } + + return 0 +} + +func toClientError(err error) error { + if errors.Is(err, context.DeadlineExceeded) { + return &TimeoutError{Message: err.Error()} + } + + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + return &TimeoutError{Message: err.Error()} + } + + return err +} diff --git a/cla-backend-go/sss/client_test.go b/cla-backend-go/sss/client_test.go new file mode 100644 index 000000000..3f1324e59 --- /dev/null +++ b/cla-backend-go/sss/client_test.go @@ -0,0 +1,328 @@ +// Copyright The Linux Foundation and each contributor to CommunityBridge. +// SPDX-License-Identifier: MIT + +package sss + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" +) + +func TestGetOrganizationStatus_Success(t *testing.T) { + authCalls := int32(0) + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != "/oauth/token" { + t.Fatalf("unexpected auth request %s %s", r.Method, r.URL.Path) + } + atomic.AddInt32(&authCalls, 1) + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"token-abc","expires_in":3600,"token_type":"Bearer"}`) + })) + defer authServer.Close() + + serviceServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet || r.URL.Path != "/api/v1/organizations/status" { + t.Fatalf("unexpected service request %s %s", r.Method, r.URL.Path) + } + if got := r.URL.Query().Get("organization_id"); got != "org-123" { + t.Fatalf("unexpected organization_id: %s", got) + } + if got := r.Header.Get("Authorization"); got != "Bearer token-abc" { + t.Fatalf("unexpected auth header: %s", got) + } + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"status":"clear","entity_id":"org-123","source":"ofac","screened_at":"2025-05-16T12:34:56Z"}`) + })) + defer serviceServer.Close() + + client, err := NewClient(SSSConfig{ + BaseURL: serviceServer.URL, + Auth0Domain: authServer.URL, + Auth0ClientID: "id", + Auth0ClientSecret: "secret", + Auth0Audience: "audience", + Timeout: 5 * time.Second, + }) + if err != nil { + t.Fatal(err) + } + + result, err := client.GetOrganizationStatus(context.Background(), "org-123") + if err != nil { + t.Fatal(err) + } + if result.Status != "clear" || result.EntityID != "org-123" || result.Source != "ofac" { + t.Fatalf("unexpected result: %+v", result) + } + if !result.ScreenedAt.Equal(time.Date(2025, 5, 16, 12, 34, 56, 0, time.UTC)) { + t.Fatalf("unexpected screened_at: %v", result.ScreenedAt) + } + if got := atomic.LoadInt32(&authCalls); got != 1 { + t.Fatalf("expected 1 auth call, got %d", got) + } +} + +func TestGetOrganizationStatus_FlaggedResponse(t *testing.T) { + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"token-flagged","expires_in":3600,"token_type":"Bearer"}`) + })) + defer authServer.Close() + + serviceServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"status":"flagged","entity_id":"org-flagged","source":"ofac","screened_at":"2025-05-16T12:34:56Z"}`) + })) + defer serviceServer.Close() + + client, err := NewClient(SSSConfig{ + BaseURL: serviceServer.URL, + Auth0Domain: authServer.URL, + Auth0ClientID: "id", + Auth0ClientSecret: "secret", + Auth0Audience: "audience", + Timeout: 5 * time.Second, + }) + if err != nil { + t.Fatal(err) + } + + result, err := client.GetOrganizationStatus(context.Background(), "org-flagged") + if err != nil { + t.Fatal(err) + } + if result.Status != "flagged" || result.EntityID != "org-flagged" { + t.Fatalf("unexpected result: %+v", result) + } +} + +func TestGetOrganizationStatus_400ReturnsBadRequestError(t *testing.T) { + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != "/oauth/token" { + t.Fatalf("unexpected auth request %s %s", r.Method, r.URL.Path) + } + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"token-400","expires_in":3600,"token_type":"Bearer"}`) + })) + defer authServer.Close() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + fmt.Fprint(w, `bad request from server`) + })) + defer server.Close() + + client, err := NewClient(SSSConfig{ + BaseURL: server.URL, + Auth0Domain: authServer.URL, + Auth0ClientID: "id", + Auth0ClientSecret: "secret", + Auth0Audience: "audience", + Timeout: 5 * time.Second, + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.GetOrganizationStatus(context.Background(), "org-400") + var badReq *BadRequestError + if !errors.As(err, &badReq) { + t.Fatalf("expected BadRequestError, got %T: %v", err, err) + } +} + +func TestGetOrganizationStatus_401ReturnsAuthError(t *testing.T) { + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != "/oauth/token" { + t.Fatalf("unexpected auth request %s %s", r.Method, r.URL.Path) + } + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"token-401","expires_in":3600,"token_type":"Bearer"}`) + })) + defer authServer.Close() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + fmt.Fprint(w, `unauthorized`) + })) + defer server.Close() + + client, err := NewClient(SSSConfig{ + BaseURL: server.URL, + Auth0Domain: authServer.URL, + Auth0ClientID: "id", + Auth0ClientSecret: "secret", + Auth0Audience: "audience", + Timeout: 5 * time.Second, + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.GetOrganizationStatus(context.Background(), "org-401") + var authErr *AuthError + if !errors.As(err, &authErr) { + t.Fatalf("expected AuthError, got %T: %v", err, err) + } +} + +func TestGetOrganizationStatus_503ReturnsRetryableError(t *testing.T) { + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != "/oauth/token" { + t.Fatalf("unexpected auth request %s %s", r.Method, r.URL.Path) + } + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"token-503","expires_in":3600,"token_type":"Bearer"}`) + })) + defer authServer.Close() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Retry-After", "3") + w.WriteHeader(http.StatusServiceUnavailable) + fmt.Fprint(w, `service unavailable`) + })) + defer server.Close() + + client, err := NewClient(SSSConfig{ + BaseURL: server.URL, + Auth0Domain: authServer.URL, + Auth0ClientID: "id", + Auth0ClientSecret: "secret", + Auth0Audience: "audience", + Timeout: 5 * time.Second, + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.GetOrganizationStatus(context.Background(), "org-503") + var retryErr *RetryableError + if !errors.As(err, &retryErr) { + t.Fatalf("expected RetryableError, got %T: %v", err, err) + } + if retryErr.RetryAfter != 3*time.Second { + t.Fatalf("expected retry after 3s, got %v", retryErr.RetryAfter) + } +} + +func TestGetOrganizationStatus_TimeoutReturnsTimeoutError(t *testing.T) { + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != "/oauth/token" { + t.Fatalf("unexpected auth request %s %s", r.Method, r.URL.Path) + } + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"token-timeout","expires_in":3600,"token_type":"Bearer"}`) + })) + defer authServer.Close() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(200 * time.Millisecond) + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, `{"status":"clear","entity_id":"org-timeout","source":"ofac","screened_at":"2025-05-16T12:34:56Z"}`) + })) + defer server.Close() + + client, err := NewClient(SSSConfig{ + BaseURL: server.URL, + Auth0Domain: authServer.URL, + Auth0ClientID: "id", + Auth0ClientSecret: "secret", + Auth0Audience: "audience", + Timeout: 50 * time.Millisecond, + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.GetOrganizationStatus(context.Background(), "org-timeout") + var timeoutErr *TimeoutError + if !errors.As(err, &timeoutErr) { + t.Fatalf("expected TimeoutError, got %T: %v", err, err) + } +} + +func TestGetOrganizationStatus_UsesCachedToken(t *testing.T) { + authCalls := int32(0) + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&authCalls, 1) + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"cached-token","expires_in":3600,"token_type":"Bearer"}`) + })) + defer authServer.Close() + + serviceServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"status":"clear","entity_id":"org-cache","source":"ofac","screened_at":"2025-05-16T12:34:56Z"}`) + })) + defer serviceServer.Close() + + client, err := NewClient(SSSConfig{ + BaseURL: serviceServer.URL, + Auth0Domain: authServer.URL, + Auth0ClientID: "id", + Auth0ClientSecret: "secret", + Auth0Audience: "audience", + Timeout: 5 * time.Second, + }) + if err != nil { + t.Fatal(err) + } + + for i := 0; i < 2; i++ { + if _, err := client.GetOrganizationStatus(context.Background(), "org-cache"); err != nil { + t.Fatal(err) + } + } + if got := atomic.LoadInt32(&authCalls); got != 1 { + t.Fatalf("expected 1 auth call, got %d", got) + } +} + +func TestGetOrganizationStatus_RefreshesExpiredToken(t *testing.T) { + tokenIndex := int32(0) + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + index := atomic.AddInt32(&tokenIndex, 1) + accessToken := fmt.Sprintf("token-%d", index) + w.Header().Set("Content-Type", "application/json") + fmt.Fprintf(w, `{"access_token":"%s","expires_in":1,"token_type":"Bearer"}`, + accessToken) + })) + defer authServer.Close() + + serviceServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("Authorization"); got == "" { + t.Fatalf("missing authorization header") + } + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"status":"clear","entity_id":"org-expire","source":"ofac","screened_at":"2025-05-16T12:34:56Z"}`) + })) + defer serviceServer.Close() + + client, err := NewClient(SSSConfig{ + BaseURL: serviceServer.URL, + Auth0Domain: authServer.URL, + Auth0ClientID: "id", + Auth0ClientSecret: "secret", + Auth0Audience: "audience", + Timeout: 5 * time.Second, + }) + if err != nil { + t.Fatal(err) + } + + if _, err := client.GetOrganizationStatus(context.Background(), "org-expire"); err != nil { + t.Fatal(err) + } + if _, err := client.GetOrganizationStatus(context.Background(), "org-expire"); err != nil { + t.Fatal(err) + } + if got := atomic.LoadInt32(&tokenIndex); got < 2 { + t.Fatalf("expected token refresh after expiry, got %d auth calls", got) + } +} diff --git a/cla-backend-go/sss/errors.go b/cla-backend-go/sss/errors.go new file mode 100644 index 000000000..d61fb235a --- /dev/null +++ b/cla-backend-go/sss/errors.go @@ -0,0 +1,46 @@ +// Copyright The Linux Foundation and each contributor to CommunityBridge. +// SPDX-License-Identifier: MIT + +package sss + +import ( + "fmt" + "time" +) + +// BadRequestError indicates a 400 response from the SSS API. +type BadRequestError struct { + Message string +} + +func (e *BadRequestError) Error() string { + return fmt.Sprintf("bad request: %s", e.Message) +} + +// AuthError indicates a 401 or 403 response from the SSS API. +type AuthError struct { + Message string +} + +func (e *AuthError) Error() string { + return fmt.Sprintf("authentication error: %s", e.Message) +} + +// RetryableError indicates a 503 response from the SSS API. +type RetryableError struct { + Message string + RetryAfter time.Duration +} + +func (e *RetryableError) Error() string { + return fmt.Sprintf("retryable error: %s", e.Message) +} + +// TimeoutError indicates the request timed out. +type TimeoutError struct { + Message string +} + +func (e *TimeoutError) Error() string { + return fmt.Sprintf("timeout: %s", e.Message) +} diff --git a/cla-backend-go/sss/types.go b/cla-backend-go/sss/types.go new file mode 100644 index 000000000..1dae50cfb --- /dev/null +++ b/cla-backend-go/sss/types.go @@ -0,0 +1,24 @@ +// Copyright The Linux Foundation and each contributor to CommunityBridge. +// SPDX-License-Identifier: MIT + +package sss + +import "time" + +// SSSConfig holds configuration values for the SSS client. +type SSSConfig struct { + BaseURL string + Auth0Domain string + Auth0ClientID string + Auth0ClientSecret string + Auth0Audience string + Timeout time.Duration +} + +// ScreeningResult is returned by the SSS organization status endpoint. +type ScreeningResult struct { + Status string `json:"status"` + EntityID string `json:"entity_id"` + Source string `json:"source"` + ScreenedAt time.Time `json:"screened_at"` +} From 6bd932c24176c483c8ba2c05d74909af20c66de5 Mon Sep 17 00:00:00 2001 From: prashant singh Date: Tue, 26 May 2026 13:42:27 +0530 Subject: [PATCH 2/7] Return zero for negative durations from parsed time Handle negative duration case when parsing time. Signed-off-by: prashant singh --- cla-backend-go/sss/client.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/cla-backend-go/sss/client.go b/cla-backend-go/sss/client.go index 801d0dc7c..c51339a4e 100644 --- a/cla-backend-go/sss/client.go +++ b/cla-backend-go/sss/client.go @@ -206,9 +206,12 @@ func parseRetryAfter(value string) time.Duration { } if parsedTime, err := http.ParseTime(value); err == nil { - return time.Until(parsedTime) - } - + d := time.Until(parsedTime) + if d < 0 { + return 0 + } + return d +} return 0 } From 0bdb94c1e5a769127a69167d2a099636a355447b Mon Sep 17 00:00:00 2001 From: psrsingh Date: Tue, 26 May 2026 19:08:26 +0530 Subject: [PATCH 3/7] fix: address API contract and active signature review feedback Signed-off-by: psrsingh --- cla-backend-go/sss/client.go | 64 ++++++++++--- cla-backend-go/sss/client_test.go | 146 ++++++++++++++++++++++++++---- cla-backend-go/sss/types.go | 31 ++++++- cla-backend-go/v2/sign/service.go | 5 +- 4 files changed, 208 insertions(+), 38 deletions(-) diff --git a/cla-backend-go/sss/client.go b/cla-backend-go/sss/client.go index c51339a4e..97c9a1960 100644 --- a/cla-backend-go/sss/client.go +++ b/cla-backend-go/sss/client.go @@ -58,9 +58,12 @@ func NewClient(cfg SSSConfig) (*Client, error) { } // GetOrganizationStatus retrieves the sanctions screening result for an organization. -func (c *Client) GetOrganizationStatus(ctx context.Context, organizationID string) (*ScreeningResult, error) { - if strings.TrimSpace(organizationID) == "" { - return nil, &BadRequestError{Message: "organization id is required"} +func (c *Client) GetOrganizationStatus(ctx context.Context, statusReq OrganizationStatusRequest) (*ScreeningResult, error) { + if strings.TrimSpace(statusReq.Domain) == "" { + return nil, &BadRequestError{Message: "domain is required"} + } + if strings.TrimSpace(statusReq.OrgName) == "" { + return nil, &BadRequestError{Message: "org_name is required"} } token, err := c.getToken(ctx) @@ -75,7 +78,26 @@ func (c *Client) GetOrganizationStatus(ctx context.Context, organizationID strin } query := reqURL.Query() - query.Set("organization_id", strings.TrimSpace(organizationID)) + query.Set("domain", strings.TrimSpace(statusReq.Domain)) + query.Set("org_name", strings.TrimSpace(statusReq.OrgName)) + if v := strings.TrimSpace(statusReq.Country); v != "" { + query.Set("country", v) + } + if v := strings.TrimSpace(statusReq.City); v != "" { + query.Set("city", v) + } + if v := strings.TrimSpace(statusReq.State); v != "" { + query.Set("state", v) + } + if v := strings.TrimSpace(statusReq.PostalCode); v != "" { + query.Set("postal_code", v) + } + if v := strings.TrimSpace(statusReq.SfdcID); v != "" { + query.Set("sfdc_id", v) + } + if v := strings.TrimSpace(statusReq.ClearbitID); v != "" { + query.Set("clearbit_id", v) + } reqURL.RawQuery = query.Encode() req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL.String(), nil) @@ -105,13 +127,13 @@ func (c *Client) GetOrganizationStatus(ctx context.Context, organizationID strin } return &result, nil case http.StatusBadRequest: - return nil, &BadRequestError{Message: strings.TrimSpace(string(body))} + return nil, &BadRequestError{Message: responseMessage(body)} case http.StatusUnauthorized, http.StatusForbidden: - return nil, &AuthError{Message: strings.TrimSpace(string(body))} - case http.StatusServiceUnavailable: - return nil, &RetryableError{Message: strings.TrimSpace(string(body)), RetryAfter: parseRetryAfter(resp.Header.Get("Retry-After"))} + return nil, &AuthError{Message: responseMessage(body)} + case http.StatusTooManyRequests, http.StatusServiceUnavailable: + return nil, &RetryableError{Message: responseMessage(body), RetryAfter: parseRetryAfter(resp.Header.Get("Retry-After"))} default: - return nil, fmt.Errorf("unexpected status code %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + return nil, fmt.Errorf("unexpected status code %d: %s", resp.StatusCode, responseMessage(body)) } } @@ -206,15 +228,27 @@ func parseRetryAfter(value string) time.Duration { } if parsedTime, err := http.ParseTime(value); err == nil { - d := time.Until(parsedTime) - if d < 0 { - return 0 - } - return d -} + d := time.Until(parsedTime) + if d < 0 { + return 0 + } + return d + } return 0 } +func responseMessage(body []byte) string { + var errPayload struct { + Error struct { + Message string `json:"message"` + } `json:"error"` + } + if err := json.Unmarshal(body, &errPayload); err == nil && strings.TrimSpace(errPayload.Error.Message) != "" { + return strings.TrimSpace(errPayload.Error.Message) + } + return strings.TrimSpace(string(body)) +} + func toClientError(err error) error { if errors.Is(err, context.DeadlineExceeded) { return &TimeoutError{Message: err.Error()} diff --git a/cla-backend-go/sss/client_test.go b/cla-backend-go/sss/client_test.go index 3f1324e59..982cfee88 100644 --- a/cla-backend-go/sss/client_test.go +++ b/cla-backend-go/sss/client_test.go @@ -30,14 +30,35 @@ func TestGetOrganizationStatus_Success(t *testing.T) { if r.Method != http.MethodGet || r.URL.Path != "/api/v1/organizations/status" { t.Fatalf("unexpected service request %s %s", r.Method, r.URL.Path) } - if got := r.URL.Query().Get("organization_id"); got != "org-123" { - t.Fatalf("unexpected organization_id: %s", got) + if got := r.URL.Query().Get("domain"); got != "example.com" { + t.Fatalf("unexpected domain: %s", got) + } + if got := r.URL.Query().Get("org_name"); got != "Example Corp" { + t.Fatalf("unexpected org_name: %s", got) + } + if got := r.URL.Query().Get("country"); got != "US" { + t.Fatalf("unexpected country: %s", got) + } + if got := r.URL.Query().Get("city"); got != "San Francisco" { + t.Fatalf("unexpected city: %s", got) + } + if got := r.URL.Query().Get("state"); got != "CA" { + t.Fatalf("unexpected state: %s", got) + } + if got := r.URL.Query().Get("postal_code"); got != "94105" { + t.Fatalf("unexpected postal_code: %s", got) + } + if got := r.URL.Query().Get("sfdc_id"); got != "SFDC-123" { + t.Fatalf("unexpected sfdc_id: %s", got) + } + if got := r.URL.Query().Get("clearbit_id"); got != "CLEARBIT-123" { + t.Fatalf("unexpected clearbit_id: %s", got) } if got := r.Header.Get("Authorization"); got != "Bearer token-abc" { t.Fatalf("unexpected auth header: %s", got) } w.Header().Set("Content-Type", "application/json") - fmt.Fprint(w, `{"status":"clear","entity_id":"org-123","source":"ofac","screened_at":"2025-05-16T12:34:56Z"}`) + fmt.Fprint(w, `{"status":"clean","entity_id":"org-123","source":"ofac","screened_at":"2025-05-16T12:34:56Z"}`) })) defer serviceServer.Close() @@ -53,11 +74,20 @@ func TestGetOrganizationStatus_Success(t *testing.T) { t.Fatal(err) } - result, err := client.GetOrganizationStatus(context.Background(), "org-123") + result, err := client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{ + Domain: "example.com", + OrgName: "Example Corp", + Country: "US", + City: "San Francisco", + State: "CA", + PostalCode: "94105", + SfdcID: "SFDC-123", + ClearbitID: "CLEARBIT-123", + }) if err != nil { t.Fatal(err) } - if result.Status != "clear" || result.EntityID != "org-123" || result.Source != "ofac" { + if result.Status != StatusClean || result.EntityID != "org-123" || result.Source != "ofac" { t.Fatalf("unexpected result: %+v", result) } if !result.ScreenedAt.Equal(time.Date(2025, 5, 16, 12, 34, 56, 0, time.UTC)) { @@ -68,6 +98,85 @@ func TestGetOrganizationStatus_Success(t *testing.T) { } } +func TestGetOrganizationStatus_MissingDomain(t *testing.T) { + client, err := NewClient(SSSConfig{ + BaseURL: "https://example.com", + Auth0Domain: "https://auth.example.com", + Auth0ClientID: "id", + Auth0ClientSecret: "secret", + Auth0Audience: "audience", + Timeout: 5 * time.Second, + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{OrgName: "Example Org"}) + var badReq *BadRequestError + if !errors.As(err, &badReq) { + t.Fatalf("expected BadRequestError, got %T: %v", err, err) + } +} + +func TestGetOrganizationStatus_MissingOrgName(t *testing.T) { + client, err := NewClient(SSSConfig{ + BaseURL: "https://example.com", + Auth0Domain: "https://auth.example.com", + Auth0ClientID: "id", + Auth0ClientSecret: "secret", + Auth0Audience: "audience", + Timeout: 5 * time.Second, + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: "example.com"}) + var badReq *BadRequestError + if !errors.As(err, &badReq) { + t.Fatalf("expected BadRequestError, got %T: %v", err, err) + } +} + +func TestGetOrganizationStatus_TooManyRequestsReturnsRetryableError(t *testing.T) { + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"token-429","expires_in":3600,"token_type":"Bearer"}`) + })) + defer authServer.Close() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Retry-After", "5") + w.WriteHeader(http.StatusTooManyRequests) + fmt.Fprint(w, `{"error":{"message":"rate limit exceeded"}}`) + })) + defer server.Close() + + client, err := NewClient(SSSConfig{ + BaseURL: server.URL, + Auth0Domain: authServer.URL, + Auth0ClientID: "id", + Auth0ClientSecret: "secret", + Auth0Audience: "audience", + Timeout: 5 * time.Second, + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: "example.com", OrgName: "RateOrg"}) + var retryErr *RetryableError + if !errors.As(err, &retryErr) { + t.Fatalf("expected RetryableError, got %T: %v", err, err) + } + if retryErr.RetryAfter != 5*time.Second { + t.Fatalf("expected retry after 5s, got %v", retryErr.RetryAfter) + } + if retryErr.Message != "rate limit exceeded" { + t.Fatalf("unexpected retry message: %s", retryErr.Message) + } +} + func TestGetOrganizationStatus_FlaggedResponse(t *testing.T) { authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") @@ -93,11 +202,14 @@ func TestGetOrganizationStatus_FlaggedResponse(t *testing.T) { t.Fatal(err) } - result, err := client.GetOrganizationStatus(context.Background(), "org-flagged") + result, err := client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{ + Domain: "example.org", + OrgName: "Flagged Org", + }) if err != nil { t.Fatal(err) } - if result.Status != "flagged" || result.EntityID != "org-flagged" { + if result.Status != StatusFlagged || result.EntityID != "org-flagged" { t.Fatalf("unexpected result: %+v", result) } } @@ -130,7 +242,7 @@ func TestGetOrganizationStatus_400ReturnsBadRequestError(t *testing.T) { t.Fatal(err) } - _, err = client.GetOrganizationStatus(context.Background(), "org-400") + _, err = client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: "example.com", OrgName: "BadOrg"}) var badReq *BadRequestError if !errors.As(err, &badReq) { t.Fatalf("expected BadRequestError, got %T: %v", err, err) @@ -165,7 +277,7 @@ func TestGetOrganizationStatus_401ReturnsAuthError(t *testing.T) { t.Fatal(err) } - _, err = client.GetOrganizationStatus(context.Background(), "org-401") + _, err = client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: "example.com", OrgName: "AuthOrg"}) var authErr *AuthError if !errors.As(err, &authErr) { t.Fatalf("expected AuthError, got %T: %v", err, err) @@ -201,7 +313,7 @@ func TestGetOrganizationStatus_503ReturnsRetryableError(t *testing.T) { t.Fatal(err) } - _, err = client.GetOrganizationStatus(context.Background(), "org-503") + _, err = client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: "example.com", OrgName: "RetryOrg"}) var retryErr *RetryableError if !errors.As(err, &retryErr) { t.Fatalf("expected RetryableError, got %T: %v", err, err) @@ -224,7 +336,7 @@ func TestGetOrganizationStatus_TimeoutReturnsTimeoutError(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { time.Sleep(200 * time.Millisecond) w.WriteHeader(http.StatusOK) - fmt.Fprint(w, `{"status":"clear","entity_id":"org-timeout","source":"ofac","screened_at":"2025-05-16T12:34:56Z"}`) + fmt.Fprint(w, `{"status":"clean","entity_id":"org-timeout","source":"ofac","screened_at":"2025-05-16T12:34:56Z"}`) })) defer server.Close() @@ -240,7 +352,7 @@ func TestGetOrganizationStatus_TimeoutReturnsTimeoutError(t *testing.T) { t.Fatal(err) } - _, err = client.GetOrganizationStatus(context.Background(), "org-timeout") + _, err = client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: "example.com", OrgName: "TimeoutOrg"}) var timeoutErr *TimeoutError if !errors.As(err, &timeoutErr) { t.Fatalf("expected TimeoutError, got %T: %v", err, err) @@ -258,7 +370,7 @@ func TestGetOrganizationStatus_UsesCachedToken(t *testing.T) { serviceServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - fmt.Fprint(w, `{"status":"clear","entity_id":"org-cache","source":"ofac","screened_at":"2025-05-16T12:34:56Z"}`) + fmt.Fprint(w, `{"status":"clean","entity_id":"org-cache","source":"ofac","screened_at":"2025-05-16T12:34:56Z"}`) })) defer serviceServer.Close() @@ -275,7 +387,7 @@ func TestGetOrganizationStatus_UsesCachedToken(t *testing.T) { } for i := 0; i < 2; i++ { - if _, err := client.GetOrganizationStatus(context.Background(), "org-cache"); err != nil { + if _, err := client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: "example.com", OrgName: "CacheOrg"}); err != nil { t.Fatal(err) } } @@ -300,7 +412,7 @@ func TestGetOrganizationStatus_RefreshesExpiredToken(t *testing.T) { t.Fatalf("missing authorization header") } w.Header().Set("Content-Type", "application/json") - fmt.Fprint(w, `{"status":"clear","entity_id":"org-expire","source":"ofac","screened_at":"2025-05-16T12:34:56Z"}`) + fmt.Fprint(w, `{"status":"clean","entity_id":"org-expire","source":"ofac","screened_at":"2025-05-16T12:34:56Z"}`) })) defer serviceServer.Close() @@ -316,10 +428,10 @@ func TestGetOrganizationStatus_RefreshesExpiredToken(t *testing.T) { t.Fatal(err) } - if _, err := client.GetOrganizationStatus(context.Background(), "org-expire"); err != nil { + if _, err := client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: "example.com", OrgName: "ExpireOrg"}); err != nil { t.Fatal(err) } - if _, err := client.GetOrganizationStatus(context.Background(), "org-expire"); err != nil { + if _, err := client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: "example.com", OrgName: "ExpireOrg"}); err != nil { t.Fatal(err) } if got := atomic.LoadInt32(&tokenIndex); got < 2 { diff --git a/cla-backend-go/sss/types.go b/cla-backend-go/sss/types.go index 1dae50cfb..e20d810ad 100644 --- a/cla-backend-go/sss/types.go +++ b/cla-backend-go/sss/types.go @@ -15,10 +15,33 @@ type SSSConfig struct { Timeout time.Duration } +// OrganizationStatusRequest holds parameters for querying organization screening status. +type OrganizationStatusRequest struct { + Domain string `json:"domain"` + OrgName string `json:"org_name"` + Country string `json:"country,omitempty"` + City string `json:"city,omitempty"` + State string `json:"state,omitempty"` + PostalCode string `json:"postal_code,omitempty"` + SfdcID string `json:"sfdc_id,omitempty"` + ClearbitID string `json:"clearbit_id,omitempty"` +} + +const ( + StatusClean = "clean" + StatusFlagged = "flagged" +) + // ScreeningResult is returned by the SSS organization status endpoint. type ScreeningResult struct { - Status string `json:"status"` - EntityID string `json:"entity_id"` - Source string `json:"source"` - ScreenedAt time.Time `json:"screened_at"` + Status string `json:"status"` + EntityID string `json:"entity_id"` + Source string `json:"source"` + ScreenedAt time.Time `json:"screened_at"` + ClearbitID string `json:"clearbit_id,omitempty"` + SfdcID string `json:"sfdc_id,omitempty"` + OrgName string `json:"org_name,omitempty"` + Domain string `json:"domain,omitempty"` + Vendor string `json:"vendor,omitempty"` + ClearbitEnriched bool `json:"clearbit_enriched,omitempty"` } diff --git a/cla-backend-go/v2/sign/service.go b/cla-backend-go/v2/sign/service.go index e8699aad3..a1cefe2b6 100644 --- a/cla-backend-go/v2/sign/service.go +++ b/cla-backend-go/v2/sign/service.go @@ -1460,7 +1460,7 @@ func (s *service) RequestIndividualSignature(ctx context.Context, input *models. var returnURL string if input.ReturnURL.String() == "" { log.WithFields(f).Warnf("signature return url is empty") - returnURL, err = getActiveSignatureReturnURL(*input.UserID, activeSignatureMetadata) + returnURL, err = s.getActiveSignatureReturnURL(ctx, *input.UserID, activeSignatureMetadata) if err != nil { log.WithFields(f).WithError(err).Warnf("unable to get active signature return url for user: %s", *input.UserID) return nil, err @@ -1682,7 +1682,8 @@ func (s *service) getInstallationIDFromRepositoryID(ctx context.Context, reposit installationId = githubOrg.OrganizationInstallationID if installationId == 0 { - log.WithFields(f).Warnf("unable to get installation ID for repository ID: %s", repositoryID) + err = fmt.Errorf("installation ID missing for repository ID: %s", repositoryID) + log.WithFields(f).WithError(err).Warnf("unable to get installation ID for repository ID: %s", repositoryID) return 0, err } From b9f9e01208b5dad875aa3f7b2d4e32d4c7b9ba34 Mon Sep 17 00:00:00 2001 From: psrsingh Date: Tue, 26 May 2026 21:08:19 +0530 Subject: [PATCH 4/7] fix: clamp negative Retry-After values Signed-off-by: psrsingh --- cla-backend-go/sss/client.go | 3 +++ cla-backend-go/sss/client_test.go | 39 +++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/cla-backend-go/sss/client.go b/cla-backend-go/sss/client.go index 97c9a1960..78242c676 100644 --- a/cla-backend-go/sss/client.go +++ b/cla-backend-go/sss/client.go @@ -224,6 +224,9 @@ func parseRetryAfter(value string) time.Duration { } if seconds, err := strconv.Atoi(strings.TrimSpace(value)); err == nil { + if seconds < 0 { + return 0 + } return time.Duration(seconds) * time.Second } diff --git a/cla-backend-go/sss/client_test.go b/cla-backend-go/sss/client_test.go index 982cfee88..98a7841af 100644 --- a/cla-backend-go/sss/client_test.go +++ b/cla-backend-go/sss/client_test.go @@ -177,6 +177,45 @@ func TestGetOrganizationStatus_TooManyRequestsReturnsRetryableError(t *testing.T } } +func TestGetOrganizationStatus_TooManyRequestsClampsNegativeRetryAfter(t *testing.T) { + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"token-negative-retry","expires_in":3600,"token_type":"Bearer"}`) + })) + defer authServer.Close() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Retry-After", "-5") + w.WriteHeader(http.StatusTooManyRequests) + fmt.Fprint(w, `{"error":{"message":"rate limit exceeded"}}`) + })) + defer server.Close() + + client, err := NewClient(SSSConfig{ + BaseURL: server.URL, + Auth0Domain: authServer.URL, + Auth0ClientID: "id", + Auth0ClientSecret: "secret", + Auth0Audience: "audience", + Timeout: 5 * time.Second, + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: "example.com", OrgName: "RateOrg"}) + var retryErr *RetryableError + if !errors.As(err, &retryErr) { + t.Fatalf("expected RetryableError, got %T: %v", err, err) + } + if retryErr.RetryAfter != 0 { + t.Fatalf("expected retry after 0s, got %v", retryErr.RetryAfter) + } + if retryErr.Message != "rate limit exceeded" { + t.Fatalf("unexpected retry message: %s", retryErr.Message) + } +} + func TestGetOrganizationStatus_FlaggedResponse(t *testing.T) { authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") From 1d783badde4adff2a84576afd66f3eb78b0b453b Mon Sep 17 00:00:00 2001 From: psrsingh Date: Wed, 27 May 2026 14:18:49 +0530 Subject: [PATCH 5/7] fix: align SSS client contract with API schema Signed-off-by: psrsingh --- cla-backend-go/sss/client.go | 2 +- cla-backend-go/sss/client_test.go | 26 ++++++++++++++----- cla-backend-go/sss/types.go | 8 ++++-- cla-backend-go/v2/sign/service.go | 43 ------------------------------- 4 files changed, 26 insertions(+), 53 deletions(-) diff --git a/cla-backend-go/sss/client.go b/cla-backend-go/sss/client.go index 78242c676..900d13a52 100644 --- a/cla-backend-go/sss/client.go +++ b/cla-backend-go/sss/client.go @@ -92,7 +92,7 @@ func (c *Client) GetOrganizationStatus(ctx context.Context, statusReq Organizati if v := strings.TrimSpace(statusReq.PostalCode); v != "" { query.Set("postal_code", v) } - if v := strings.TrimSpace(statusReq.SfdcID); v != "" { + if v := strings.TrimSpace(statusReq.SFDCID); v != "" { query.Set("sfdc_id", v) } if v := strings.TrimSpace(statusReq.ClearbitID); v != "" { diff --git a/cla-backend-go/sss/client_test.go b/cla-backend-go/sss/client_test.go index 98a7841af..850975576 100644 --- a/cla-backend-go/sss/client_test.go +++ b/cla-backend-go/sss/client_test.go @@ -58,7 +58,7 @@ func TestGetOrganizationStatus_Success(t *testing.T) { t.Fatalf("unexpected auth header: %s", got) } w.Header().Set("Content-Type", "application/json") - fmt.Fprint(w, `{"status":"clean","entity_id":"org-123","source":"ofac","screened_at":"2025-05-16T12:34:56Z"}`) + fmt.Fprint(w, `{"status":"clean","entity_id":"org-123","source":"screening_db","screened_at":"2025-05-16T12:34:56Z","vendor":"descartes","clearbit_enriched":true,"sfdc_id":null,"domain":"example.com","org_name":"Example Corp"}`) })) defer serviceServer.Close() @@ -81,15 +81,21 @@ func TestGetOrganizationStatus_Success(t *testing.T) { City: "San Francisco", State: "CA", PostalCode: "94105", - SfdcID: "SFDC-123", + SFDCID: "SFDC-123", ClearbitID: "CLEARBIT-123", }) if err != nil { t.Fatal(err) } - if result.Status != StatusClean || result.EntityID != "org-123" || result.Source != "ofac" { + if result.Status != StatusClean || result.EntityID != "org-123" || result.Source != SourceScreeningDB { t.Fatalf("unexpected result: %+v", result) } + if result.SFDCID != nil { + t.Fatalf("expected nullable sfdc_id to decode as nil, got %q", *result.SFDCID) + } + if result.Vendor != "descartes" || !result.ClearbitEnriched || result.Domain != "example.com" || result.OrgName != "Example Corp" { + t.Fatalf("unexpected enriched fields: %+v", result) + } if !result.ScreenedAt.Equal(time.Date(2025, 5, 16, 12, 34, 56, 0, time.UTC)) { t.Fatalf("unexpected screened_at: %v", result.ScreenedAt) } @@ -225,7 +231,7 @@ func TestGetOrganizationStatus_FlaggedResponse(t *testing.T) { serviceServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - fmt.Fprint(w, `{"status":"flagged","entity_id":"org-flagged","source":"ofac","screened_at":"2025-05-16T12:34:56Z"}`) + fmt.Fprint(w, `{"status":"flagged","entity_id":"org-flagged","source":"descartes_api","screened_at":"2025-05-16T12:34:56Z","vendor":"descartes","clearbit_enriched":false,"sfdc_id":"SFDC-456","domain":"example.org","org_name":"Flagged Org"}`) })) defer serviceServer.Close() @@ -251,6 +257,12 @@ func TestGetOrganizationStatus_FlaggedResponse(t *testing.T) { if result.Status != StatusFlagged || result.EntityID != "org-flagged" { t.Fatalf("unexpected result: %+v", result) } + if result.Source != SourceDescartesAPI { + t.Fatalf("unexpected source: %s", result.Source) + } + if result.SFDCID == nil || *result.SFDCID != "SFDC-456" { + t.Fatalf("unexpected sfdc_id: %+v", result.SFDCID) + } } func TestGetOrganizationStatus_400ReturnsBadRequestError(t *testing.T) { @@ -375,7 +387,7 @@ func TestGetOrganizationStatus_TimeoutReturnsTimeoutError(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { time.Sleep(200 * time.Millisecond) w.WriteHeader(http.StatusOK) - fmt.Fprint(w, `{"status":"clean","entity_id":"org-timeout","source":"ofac","screened_at":"2025-05-16T12:34:56Z"}`) + fmt.Fprint(w, `{"status":"clean","entity_id":"org-timeout","source":"sfdc","screened_at":"2025-05-16T12:34:56Z","vendor":"descartes","clearbit_enriched":true,"sfdc_id":null,"domain":"example.com","org_name":"TimeoutOrg"}`) })) defer server.Close() @@ -409,7 +421,7 @@ func TestGetOrganizationStatus_UsesCachedToken(t *testing.T) { serviceServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - fmt.Fprint(w, `{"status":"clean","entity_id":"org-cache","source":"ofac","screened_at":"2025-05-16T12:34:56Z"}`) + fmt.Fprint(w, `{"status":"clean","entity_id":"org-cache","source":"sfdc","screened_at":"2025-05-16T12:34:56Z","vendor":"descartes","clearbit_enriched":true,"sfdc_id":null,"domain":"example.com","org_name":"CacheOrg"}`) })) defer serviceServer.Close() @@ -451,7 +463,7 @@ func TestGetOrganizationStatus_RefreshesExpiredToken(t *testing.T) { t.Fatalf("missing authorization header") } w.Header().Set("Content-Type", "application/json") - fmt.Fprint(w, `{"status":"clean","entity_id":"org-expire","source":"ofac","screened_at":"2025-05-16T12:34:56Z"}`) + fmt.Fprint(w, `{"status":"clean","entity_id":"org-expire","source":"screening_db","screened_at":"2025-05-16T12:34:56Z","vendor":"descartes","clearbit_enriched":true,"sfdc_id":null,"domain":"example.com","org_name":"ExpireOrg"}`) })) defer serviceServer.Close() diff --git a/cla-backend-go/sss/types.go b/cla-backend-go/sss/types.go index e20d810ad..c3e0bb985 100644 --- a/cla-backend-go/sss/types.go +++ b/cla-backend-go/sss/types.go @@ -23,13 +23,17 @@ type OrganizationStatusRequest struct { City string `json:"city,omitempty"` State string `json:"state,omitempty"` PostalCode string `json:"postal_code,omitempty"` - SfdcID string `json:"sfdc_id,omitempty"` + SFDCID string `json:"sfdc_id,omitempty"` ClearbitID string `json:"clearbit_id,omitempty"` } const ( StatusClean = "clean" StatusFlagged = "flagged" + + SourceScreeningDB = "screening_db" + SourceSFDC = "sfdc" + SourceDescartesAPI = "descartes_api" ) // ScreeningResult is returned by the SSS organization status endpoint. @@ -39,7 +43,7 @@ type ScreeningResult struct { Source string `json:"source"` ScreenedAt time.Time `json:"screened_at"` ClearbitID string `json:"clearbit_id,omitempty"` - SfdcID string `json:"sfdc_id,omitempty"` + SFDCID *string `json:"sfdc_id"` OrgName string `json:"org_name,omitempty"` Domain string `json:"domain,omitempty"` Vendor string `json:"vendor,omitempty"` diff --git a/cla-backend-go/v2/sign/service.go b/cla-backend-go/v2/sign/service.go index a1cefe2b6..238f26b6f 100644 --- a/cla-backend-go/v2/sign/service.go +++ b/cla-backend-go/v2/sign/service.go @@ -2223,49 +2223,6 @@ func getUserEmail(user *v1Models.User, preferredEmail string, providerType strin return "" } -func getActiveSignatureReturnURL(userID string, metadata map[string]interface{}) (string, error) { - - f := logrus.Fields{ - "functionName": "sign.getActiveSignatureReturnURL", - } - - var returnURL string - var err error - var pullRequestID int - var installationID int64 - var repositoryID int64 - - if found, ok := metadata["pull_request_id"].(int); ok { - pullRequestID = found - } else { - log.WithFields(f).WithError(err).Warnf("unable to get pull request ID for user: %s", userID) - return "", err - } - - if found, ok := metadata["installation_id"].(int64); ok { - installationID = found - } else { - log.WithFields(f).WithError(err).Warnf("unable to get installation ID for user: %s", userID) - return "", err - } - - if found, ok := metadata["repository_id"].(int64); ok { - repositoryID = found - } else { - log.WithFields(f).WithError(err).Warnf("unable to get repository ID for user: %s", userID) - return "", err - } - - returnURL, err = github.GetReturnURL(context.Background(), installationID, repositoryID, pullRequestID) - - if err != nil { - return "", err - } - - return returnURL, nil - -} - func (s *service) createDefaultIndividualValues(user *v1Models.User, preferredEmail string, allowLFEmail bool) map[string]interface{} { f := logrus.Fields{ "functionName": "sign.createDefaultIndiviualValues", From e9ee6ac85aec5f411da6899fdec473493737c1cc Mon Sep 17 00:00:00 2001 From: psrsingh Date: Thu, 28 May 2026 14:36:30 +0530 Subject: [PATCH 6/7] fix(sss): address PR review feedback Signed-off-by: psrsingh --- cla-backend-go/sss/client.go | 67 +++++++++-- cla-backend-go/sss/client_test.go | 193 +++++++++++++++++++++++++++++- cla-backend-go/sss/errors.go | 21 +++- cla-backend-go/sss/types.go | 10 +- 4 files changed, 272 insertions(+), 19 deletions(-) diff --git a/cla-backend-go/sss/client.go b/cla-backend-go/sss/client.go index 900d13a52..9c8f023ae 100644 --- a/cla-backend-go/sss/client.go +++ b/cla-backend-go/sss/client.go @@ -19,7 +19,11 @@ import ( "time" ) -const defaultTimeout = 10 * time.Second +const ( + defaultTimeout = 30 * time.Second + defaultTokenTTL = time.Hour + userAgent = "easycla-cla-backend-go/sss-client" +) // Client is a reusable HTTP client for the Sanctions Screening Service. type Client struct { @@ -107,6 +111,7 @@ func (c *Client) GetOrganizationStatus(ctx context.Context, statusReq Organizati req.Header.Set("Accept", "application/json") req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("User-Agent", userAgent) resp, err := c.httpClient.Do(req) if err != nil { @@ -127,11 +132,18 @@ func (c *Client) GetOrganizationStatus(ctx context.Context, statusReq Organizati } return &result, nil case http.StatusBadRequest: - return nil, &BadRequestError{Message: responseMessage(body)} + details := responseErrorDetails(body) + return nil, &BadRequestError{Message: details.Message, Code: details.Code, RequestID: details.RequestID} + case http.StatusNotFound: + details := responseErrorDetails(body) + return nil, &NotFoundError{Message: details.Message, Code: details.Code, RequestID: details.RequestID} case http.StatusUnauthorized, http.StatusForbidden: - return nil, &AuthError{Message: responseMessage(body)} + c.invalidateToken(token) + details := responseErrorDetails(body) + return nil, &AuthError{Message: details.Message, Code: details.Code, RequestID: details.RequestID} case http.StatusTooManyRequests, http.StatusServiceUnavailable: - return nil, &RetryableError{Message: responseMessage(body), RetryAfter: parseRetryAfter(resp.Header.Get("Retry-After"))} + details := responseErrorDetails(body) + return nil, &RetryableError{Message: details.Message, Code: details.Code, RequestID: details.RequestID, RetryAfter: parseRetryAfter(resp.Header.Get("Retry-After"))} default: return nil, fmt.Errorf("unexpected status code %d: %s", resp.StatusCode, responseMessage(body)) } @@ -175,6 +187,7 @@ func (c *Client) fetchToken(ctx context.Context) (string, error) { return "", fmt.Errorf("failed to create auth request: %w", err) } req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", userAgent) resp, err := c.httpClient.Do(req) if err != nil { @@ -188,7 +201,11 @@ func (c *Client) fetchToken(ctx context.Context) (string, error) { } if resp.StatusCode != http.StatusOK { - return "", &AuthError{Message: fmt.Sprintf("authentication failed: %s", strings.TrimSpace(string(body)))} + details := responseErrorDetails(body) + if details.Message == "" { + details.Message = strings.TrimSpace(string(body)) + } + return "", &AuthError{Message: fmt.Sprintf("authentication failed: %s", details.Message), Code: details.Code, RequestID: details.RequestID} } var authResponse authResponse @@ -202,7 +219,7 @@ func (c *Client) fetchToken(ctx context.Context) (string, error) { expiresIn := time.Duration(authResponse.ExpiresIn) * time.Second if expiresIn <= 0 { - expiresIn = defaultTimeout + expiresIn = defaultTokenTTL } c.token = authResponse.AccessToken c.expiry = time.Now().Add(expiresIn) @@ -210,6 +227,16 @@ func (c *Client) fetchToken(ctx context.Context) (string, error) { return c.token, nil } +func (c *Client) invalidateToken(token string) { + c.tokenMutex.Lock() + defer c.tokenMutex.Unlock() + + if c.token == token { + c.token = "" + c.expiry = time.Time{} + } +} + func (c *Client) authTokenURL() string { domain := strings.TrimSpace(c.cfg.Auth0Domain) if strings.HasPrefix(domain, "http://") || strings.HasPrefix(domain, "https://") { @@ -240,16 +267,36 @@ func parseRetryAfter(value string) time.Duration { return 0 } -func responseMessage(body []byte) string { +type upstreamErrorDetails struct { + Message string + Code string + RequestID string +} + +func responseErrorDetails(body []byte) upstreamErrorDetails { var errPayload struct { Error struct { + Code string `json:"code"` Message string `json:"message"` } `json:"error"` + RequestID string `json:"request_id"` } - if err := json.Unmarshal(body, &errPayload); err == nil && strings.TrimSpace(errPayload.Error.Message) != "" { - return strings.TrimSpace(errPayload.Error.Message) + if err := json.Unmarshal(body, &errPayload); err == nil { + details := upstreamErrorDetails{ + Message: strings.TrimSpace(errPayload.Error.Message), + Code: strings.TrimSpace(errPayload.Error.Code), + RequestID: strings.TrimSpace(errPayload.RequestID), + } + if details.Message != "" || details.Code != "" || details.RequestID != "" { + return details + } } - return strings.TrimSpace(string(body)) + + return upstreamErrorDetails{Message: strings.TrimSpace(string(body))} +} + +func responseMessage(body []byte) string { + return responseErrorDetails(body).Message } func toClientError(err error) error { diff --git a/cla-backend-go/sss/client_test.go b/cla-backend-go/sss/client_test.go index 850975576..ce65f4f21 100644 --- a/cla-backend-go/sss/client_test.go +++ b/cla-backend-go/sss/client_test.go @@ -20,6 +20,9 @@ func TestGetOrganizationStatus_Success(t *testing.T) { if r.Method != http.MethodPost || r.URL.Path != "/oauth/token" { t.Fatalf("unexpected auth request %s %s", r.Method, r.URL.Path) } + if got := r.Header.Get("User-Agent"); got != userAgent { + t.Fatalf("unexpected auth user-agent: %s", got) + } atomic.AddInt32(&authCalls, 1) w.Header().Set("Content-Type", "application/json") fmt.Fprint(w, `{"access_token":"token-abc","expires_in":3600,"token_type":"Bearer"}`) @@ -57,6 +60,9 @@ func TestGetOrganizationStatus_Success(t *testing.T) { if got := r.Header.Get("Authorization"); got != "Bearer token-abc" { t.Fatalf("unexpected auth header: %s", got) } + if got := r.Header.Get("User-Agent"); got != userAgent { + t.Fatalf("unexpected service user-agent: %s", got) + } w.Header().Set("Content-Type", "application/json") fmt.Fprint(w, `{"status":"clean","entity_id":"org-123","source":"screening_db","screened_at":"2025-05-16T12:34:56Z","vendor":"descartes","clearbit_enriched":true,"sfdc_id":null,"domain":"example.com","org_name":"Example Corp"}`) })) @@ -154,7 +160,7 @@ func TestGetOrganizationStatus_TooManyRequestsReturnsRetryableError(t *testing.T server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Retry-After", "5") w.WriteHeader(http.StatusTooManyRequests) - fmt.Fprint(w, `{"error":{"message":"rate limit exceeded"}}`) + fmt.Fprint(w, `{"error":{"code":"RATE_LIMITED","message":"rate limit exceeded"},"request_id":"req-429"}`) })) defer server.Close() @@ -181,6 +187,9 @@ func TestGetOrganizationStatus_TooManyRequestsReturnsRetryableError(t *testing.T if retryErr.Message != "rate limit exceeded" { t.Fatalf("unexpected retry message: %s", retryErr.Message) } + if retryErr.Code != "RATE_LIMITED" || retryErr.RequestID != "req-429" { + t.Fatalf("unexpected retry details: %+v", retryErr) + } } func TestGetOrganizationStatus_TooManyRequestsClampsNegativeRetryAfter(t *testing.T) { @@ -300,6 +309,76 @@ func TestGetOrganizationStatus_400ReturnsBadRequestError(t *testing.T) { } } +func TestGetOrganizationStatus_400PreservesStructuredErrorDetails(t *testing.T) { + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"token-400-structured","expires_in":3600,"token_type":"Bearer"}`) + })) + defer authServer.Close() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + fmt.Fprint(w, `{"error":{"code":"INVALID_DOMAIN","message":"invalid domain"},"request_id":"req-400"}`) + })) + defer server.Close() + + client, err := NewClient(SSSConfig{ + BaseURL: server.URL, + Auth0Domain: authServer.URL, + Auth0ClientID: "id", + Auth0ClientSecret: "secret", + Auth0Audience: "audience", + Timeout: 5 * time.Second, + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: "bad domain", OrgName: "BadOrg"}) + var badReq *BadRequestError + if !errors.As(err, &badReq) { + t.Fatalf("expected BadRequestError, got %T: %v", err, err) + } + if badReq.Message != "invalid domain" || badReq.Code != "INVALID_DOMAIN" || badReq.RequestID != "req-400" { + t.Fatalf("unexpected bad request details: %+v", badReq) + } +} + +func TestGetOrganizationStatus_404ReturnsNotFoundError(t *testing.T) { + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"token-404","expires_in":3600,"token_type":"Bearer"}`) + })) + defer authServer.Close() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + fmt.Fprint(w, `{"error":{"code":"ORG_NOT_FOUND","message":"Organization not found in any tier"},"request_id":"req-404"}`) + })) + defer server.Close() + + client, err := NewClient(SSSConfig{ + BaseURL: server.URL, + Auth0Domain: authServer.URL, + Auth0ClientID: "id", + Auth0ClientSecret: "secret", + Auth0Audience: "audience", + Timeout: 5 * time.Second, + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: "missing.example", OrgName: "MissingOrg"}) + var notFound *NotFoundError + if !errors.As(err, ¬Found) { + t.Fatalf("expected NotFoundError, got %T: %v", err, err) + } + if notFound.Message != "Organization not found in any tier" || notFound.Code != "ORG_NOT_FOUND" || notFound.RequestID != "req-404" { + t.Fatalf("unexpected not found details: %+v", notFound) + } +} + func TestGetOrganizationStatus_401ReturnsAuthError(t *testing.T) { authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost || r.URL.Path != "/oauth/token" { @@ -312,7 +391,7 @@ func TestGetOrganizationStatus_401ReturnsAuthError(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusUnauthorized) - fmt.Fprint(w, `unauthorized`) + fmt.Fprint(w, `{"error":{"code":"TOKEN_EXPIRED","message":"unauthorized"},"request_id":"req-401"}`) })) defer server.Close() @@ -333,6 +412,63 @@ func TestGetOrganizationStatus_401ReturnsAuthError(t *testing.T) { if !errors.As(err, &authErr) { t.Fatalf("expected AuthError, got %T: %v", err, err) } + if authErr.Message != "unauthorized" || authErr.Code != "TOKEN_EXPIRED" || authErr.RequestID != "req-401" { + t.Fatalf("unexpected auth details: %+v", authErr) + } +} + +func TestGetOrganizationStatus_401InvalidatesCachedToken(t *testing.T) { + authCalls := int32(0) + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + index := atomic.AddInt32(&authCalls, 1) + w.Header().Set("Content-Type", "application/json") + fmt.Fprintf(w, `{"access_token":"token-%d","expires_in":3600,"token_type":"Bearer"}`, index) + })) + defer authServer.Close() + + serviceCalls := int32(0) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + index := atomic.AddInt32(&serviceCalls, 1) + if index == 1 { + if got := r.Header.Get("Authorization"); got != "Bearer token-1" { + t.Fatalf("unexpected first auth header: %s", got) + } + w.WriteHeader(http.StatusUnauthorized) + fmt.Fprint(w, `{"error":{"code":"TOKEN_EXPIRED","message":"token expired"},"request_id":"req-expired"}`) + return + } + if got := r.Header.Get("Authorization"); got != "Bearer token-2" { + t.Fatalf("expected refreshed token on second request, got %s", got) + } + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"status":"clean","entity_id":"org-refresh","source":"sfdc","screened_at":"2025-05-16T12:34:56Z","vendor":"descartes","clearbit_enriched":true,"sfdc_id":null,"domain":"example.com","org_name":"RefreshOrg"}`) + })) + defer server.Close() + + client, err := NewClient(SSSConfig{ + BaseURL: server.URL, + Auth0Domain: authServer.URL, + Auth0ClientID: "id", + Auth0ClientSecret: "secret", + Auth0Audience: "audience", + Timeout: 5 * time.Second, + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: "example.com", OrgName: "RefreshOrg"}) + var authErr *AuthError + if !errors.As(err, &authErr) { + t.Fatalf("expected AuthError, got %T: %v", err, err) + } + + if _, err := client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: "example.com", OrgName: "RefreshOrg"}); err != nil { + t.Fatal(err) + } + if got := atomic.LoadInt32(&authCalls); got != 2 { + t.Fatalf("expected token refetch after auth failure, got %d auth calls", got) + } } func TestGetOrganizationStatus_503ReturnsRetryableError(t *testing.T) { @@ -447,6 +583,43 @@ func TestGetOrganizationStatus_UsesCachedToken(t *testing.T) { } } +func TestGetOrganizationStatus_CachesTokenWhenExpiresInMissing(t *testing.T) { + authCalls := int32(0) + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&authCalls, 1) + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"fallback-token","expires_in":0,"token_type":"Bearer"}`) + })) + defer authServer.Close() + + serviceServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"status":"clean","entity_id":"org-fallback","source":"sfdc","screened_at":"2025-05-16T12:34:56Z","vendor":"descartes","clearbit_enriched":true,"sfdc_id":null,"domain":"example.com","org_name":"FallbackOrg"}`) + })) + defer serviceServer.Close() + + client, err := NewClient(SSSConfig{ + BaseURL: serviceServer.URL, + Auth0Domain: authServer.URL, + Auth0ClientID: "id", + Auth0ClientSecret: "secret", + Auth0Audience: "audience", + Timeout: 5 * time.Second, + }) + if err != nil { + t.Fatal(err) + } + + for i := 0; i < 2; i++ { + if _, err := client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: "example.com", OrgName: "FallbackOrg"}); err != nil { + t.Fatal(err) + } + } + if got := atomic.LoadInt32(&authCalls); got != 1 { + t.Fatalf("expected fallback token ttl to cache token, got %d auth calls", got) + } +} + func TestGetOrganizationStatus_RefreshesExpiredToken(t *testing.T) { tokenIndex := int32(0) authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -489,3 +662,19 @@ func TestGetOrganizationStatus_RefreshesExpiredToken(t *testing.T) { t.Fatalf("expected token refresh after expiry, got %d auth calls", got) } } + +func TestNewClient_DefaultTimeout(t *testing.T) { + client, err := NewClient(SSSConfig{ + BaseURL: "https://sss.example.com", + Auth0Domain: "https://auth.example.com", + Auth0ClientID: "id", + Auth0ClientSecret: "secret", + Auth0Audience: "audience", + }) + if err != nil { + t.Fatal(err) + } + if client.httpClient.Timeout != defaultTimeout { + t.Fatalf("expected default timeout %v, got %v", defaultTimeout, client.httpClient.Timeout) + } +} diff --git a/cla-backend-go/sss/errors.go b/cla-backend-go/sss/errors.go index d61fb235a..60c130747 100644 --- a/cla-backend-go/sss/errors.go +++ b/cla-backend-go/sss/errors.go @@ -10,7 +10,9 @@ import ( // BadRequestError indicates a 400 response from the SSS API. type BadRequestError struct { - Message string + Message string + Code string + RequestID string } func (e *BadRequestError) Error() string { @@ -19,7 +21,9 @@ func (e *BadRequestError) Error() string { // AuthError indicates a 401 or 403 response from the SSS API. type AuthError struct { - Message string + Message string + Code string + RequestID string } func (e *AuthError) Error() string { @@ -29,6 +33,8 @@ func (e *AuthError) Error() string { // RetryableError indicates a 503 response from the SSS API. type RetryableError struct { Message string + Code string + RequestID string RetryAfter time.Duration } @@ -36,6 +42,17 @@ func (e *RetryableError) Error() string { return fmt.Sprintf("retryable error: %s", e.Message) } +// NotFoundError indicates a 404 response from the SSS API. +type NotFoundError struct { + Message string + Code string + RequestID string +} + +func (e *NotFoundError) Error() string { + return fmt.Sprintf("not found: %s", e.Message) +} + // TimeoutError indicates the request timed out. type TimeoutError struct { Message string diff --git a/cla-backend-go/sss/types.go b/cla-backend-go/sss/types.go index c3e0bb985..8f9617508 100644 --- a/cla-backend-go/sss/types.go +++ b/cla-backend-go/sss/types.go @@ -42,10 +42,10 @@ type ScreeningResult struct { EntityID string `json:"entity_id"` Source string `json:"source"` ScreenedAt time.Time `json:"screened_at"` - ClearbitID string `json:"clearbit_id,omitempty"` + ClearbitID string `json:"clearbit_id"` SFDCID *string `json:"sfdc_id"` - OrgName string `json:"org_name,omitempty"` - Domain string `json:"domain,omitempty"` - Vendor string `json:"vendor,omitempty"` - ClearbitEnriched bool `json:"clearbit_enriched,omitempty"` + OrgName string `json:"org_name"` + Domain string `json:"domain"` + Vendor string `json:"vendor"` + ClearbitEnriched bool `json:"clearbit_enriched"` } From d520dd3400a80c7dc0a23e66d545d54b7e9af07b Mon Sep 17 00:00:00 2001 From: psrsingh Date: Thu, 28 May 2026 15:56:50 +0530 Subject: [PATCH 7/7] fix(sss): address review feedback and improve auth error handling Signed-off-by: psrsingh --- cla-backend-go/sss/client.go | 30 +++++++-- cla-backend-go/sss/client_test.go | 103 +++++++++++++++++++++--------- cla-backend-go/sss/errors.go | 21 ++++-- cla-backend-go/sss/types.go | 8 ++- 4 files changed, 118 insertions(+), 44 deletions(-) diff --git a/cla-backend-go/sss/client.go b/cla-backend-go/sss/client.go index 9c8f023ae..04273ed66 100644 --- a/cla-backend-go/sss/client.go +++ b/cla-backend-go/sss/client.go @@ -201,27 +201,43 @@ func (c *Client) fetchToken(ctx context.Context) (string, error) { } if resp.StatusCode != http.StatusOK { - details := responseErrorDetails(body) - if details.Message == "" { + var auth0Err struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description"` + } + details := upstreamErrorDetails{} + if err := json.Unmarshal(body, &auth0Err); err == nil { + details.Code = strings.TrimSpace(auth0Err.Error) + details.Message = strings.TrimSpace(auth0Err.ErrorDescription) + if details.Message == "" && details.Code != "" { + details.Message = details.Code + } + if details.Message == "" { + details.Message = resp.Status + } + } else { details.Message = strings.TrimSpace(string(body)) + if details.Message == "" { + details.Message = resp.Status + } } return "", &AuthError{Message: fmt.Sprintf("authentication failed: %s", details.Message), Code: details.Code, RequestID: details.RequestID} } - var authResponse authResponse - if err := json.Unmarshal(body, &authResponse); err != nil { + var tokenResp authResponse + if err := json.Unmarshal(body, &tokenResp); err != nil { return "", fmt.Errorf("failed to decode auth response: %w", err) } - if authResponse.AccessToken == "" { + if tokenResp.AccessToken == "" { return "", &AuthError{Message: "empty access token from auth server"} } - expiresIn := time.Duration(authResponse.ExpiresIn) * time.Second + expiresIn := time.Duration(tokenResp.ExpiresIn) * time.Second if expiresIn <= 0 { expiresIn = defaultTokenTTL } - c.token = authResponse.AccessToken + c.token = tokenResp.AccessToken c.expiry = time.Now().Add(expiresIn) return c.token, nil diff --git a/cla-backend-go/sss/client_test.go b/cla-backend-go/sss/client_test.go index ce65f4f21..d277cdbe1 100644 --- a/cla-backend-go/sss/client_test.go +++ b/cla-backend-go/sss/client_test.go @@ -14,10 +14,17 @@ import ( "time" ) +const ( + testAuthTokenPath = "/oauth/token" + testOrgDomain = "example.com" + testOrgName = "Example Corp" + testRateLimitExceeded = "rate limit exceeded" +) + func TestGetOrganizationStatus_Success(t *testing.T) { authCalls := int32(0) authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost || r.URL.Path != "/oauth/token" { + if r.Method != http.MethodPost || r.URL.Path != testAuthTokenPath { t.Fatalf("unexpected auth request %s %s", r.Method, r.URL.Path) } if got := r.Header.Get("User-Agent"); got != userAgent { @@ -33,10 +40,10 @@ func TestGetOrganizationStatus_Success(t *testing.T) { if r.Method != http.MethodGet || r.URL.Path != "/api/v1/organizations/status" { t.Fatalf("unexpected service request %s %s", r.Method, r.URL.Path) } - if got := r.URL.Query().Get("domain"); got != "example.com" { + if got := r.URL.Query().Get("domain"); got != testOrgDomain { t.Fatalf("unexpected domain: %s", got) } - if got := r.URL.Query().Get("org_name"); got != "Example Corp" { + if got := r.URL.Query().Get("org_name"); got != testOrgName { t.Fatalf("unexpected org_name: %s", got) } if got := r.URL.Query().Get("country"); got != "US" { @@ -81,8 +88,8 @@ func TestGetOrganizationStatus_Success(t *testing.T) { } result, err := client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{ - Domain: "example.com", - OrgName: "Example Corp", + Domain: testOrgDomain, + OrgName: testOrgName, Country: "US", City: "San Francisco", State: "CA", @@ -99,7 +106,7 @@ func TestGetOrganizationStatus_Success(t *testing.T) { if result.SFDCID != nil { t.Fatalf("expected nullable sfdc_id to decode as nil, got %q", *result.SFDCID) } - if result.Vendor != "descartes" || !result.ClearbitEnriched || result.Domain != "example.com" || result.OrgName != "Example Corp" { + if result.Vendor != "descartes" || !result.ClearbitEnriched || result.Domain != testOrgDomain || result.OrgName != testOrgName { t.Fatalf("unexpected enriched fields: %+v", result) } if !result.ScreenedAt.Equal(time.Date(2025, 5, 16, 12, 34, 56, 0, time.UTC)) { @@ -112,7 +119,7 @@ func TestGetOrganizationStatus_Success(t *testing.T) { func TestGetOrganizationStatus_MissingDomain(t *testing.T) { client, err := NewClient(SSSConfig{ - BaseURL: "https://example.com", + BaseURL: "https://" + testOrgDomain, Auth0Domain: "https://auth.example.com", Auth0ClientID: "id", Auth0ClientSecret: "secret", @@ -132,7 +139,7 @@ func TestGetOrganizationStatus_MissingDomain(t *testing.T) { func TestGetOrganizationStatus_MissingOrgName(t *testing.T) { client, err := NewClient(SSSConfig{ - BaseURL: "https://example.com", + BaseURL: "https://" + testOrgDomain, Auth0Domain: "https://auth.example.com", Auth0ClientID: "id", Auth0ClientSecret: "secret", @@ -143,7 +150,7 @@ func TestGetOrganizationStatus_MissingOrgName(t *testing.T) { t.Fatal(err) } - _, err = client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: "example.com"}) + _, err = client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: testOrgDomain}) var badReq *BadRequestError if !errors.As(err, &badReq) { t.Fatalf("expected BadRequestError, got %T: %v", err, err) @@ -160,7 +167,7 @@ func TestGetOrganizationStatus_TooManyRequestsReturnsRetryableError(t *testing.T server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Retry-After", "5") w.WriteHeader(http.StatusTooManyRequests) - fmt.Fprint(w, `{"error":{"code":"RATE_LIMITED","message":"rate limit exceeded"},"request_id":"req-429"}`) + fmt.Fprintf(w, `{"error":{"code":"RATE_LIMITED","message":"%s"},"request_id":"req-429"}`, testRateLimitExceeded) })) defer server.Close() @@ -176,7 +183,7 @@ func TestGetOrganizationStatus_TooManyRequestsReturnsRetryableError(t *testing.T t.Fatal(err) } - _, err = client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: "example.com", OrgName: "RateOrg"}) + _, err = client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: testOrgDomain, OrgName: "RateOrg"}) var retryErr *RetryableError if !errors.As(err, &retryErr) { t.Fatalf("expected RetryableError, got %T: %v", err, err) @@ -184,7 +191,7 @@ func TestGetOrganizationStatus_TooManyRequestsReturnsRetryableError(t *testing.T if retryErr.RetryAfter != 5*time.Second { t.Fatalf("expected retry after 5s, got %v", retryErr.RetryAfter) } - if retryErr.Message != "rate limit exceeded" { + if retryErr.Message != testRateLimitExceeded { t.Fatalf("unexpected retry message: %s", retryErr.Message) } if retryErr.Code != "RATE_LIMITED" || retryErr.RequestID != "req-429" { @@ -202,7 +209,7 @@ func TestGetOrganizationStatus_TooManyRequestsClampsNegativeRetryAfter(t *testin server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Retry-After", "-5") w.WriteHeader(http.StatusTooManyRequests) - fmt.Fprint(w, `{"error":{"message":"rate limit exceeded"}}`) + fmt.Fprintf(w, `{"error":{"message":"%s"}}`, testRateLimitExceeded) })) defer server.Close() @@ -218,7 +225,7 @@ func TestGetOrganizationStatus_TooManyRequestsClampsNegativeRetryAfter(t *testin t.Fatal(err) } - _, err = client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: "example.com", OrgName: "RateOrg"}) + _, err = client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: testOrgDomain, OrgName: "RateOrg"}) var retryErr *RetryableError if !errors.As(err, &retryErr) { t.Fatalf("expected RetryableError, got %T: %v", err, err) @@ -226,7 +233,7 @@ func TestGetOrganizationStatus_TooManyRequestsClampsNegativeRetryAfter(t *testin if retryErr.RetryAfter != 0 { t.Fatalf("expected retry after 0s, got %v", retryErr.RetryAfter) } - if retryErr.Message != "rate limit exceeded" { + if retryErr.Message != testRateLimitExceeded { t.Fatalf("unexpected retry message: %s", retryErr.Message) } } @@ -276,7 +283,7 @@ func TestGetOrganizationStatus_FlaggedResponse(t *testing.T) { func TestGetOrganizationStatus_400ReturnsBadRequestError(t *testing.T) { authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost || r.URL.Path != "/oauth/token" { + if r.Method != http.MethodPost || r.URL.Path != testAuthTokenPath { t.Fatalf("unexpected auth request %s %s", r.Method, r.URL.Path) } w.Header().Set("Content-Type", "application/json") @@ -302,7 +309,7 @@ func TestGetOrganizationStatus_400ReturnsBadRequestError(t *testing.T) { t.Fatal(err) } - _, err = client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: "example.com", OrgName: "BadOrg"}) + _, err = client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: testOrgDomain, OrgName: "BadOrg"}) var badReq *BadRequestError if !errors.As(err, &badReq) { t.Fatalf("expected BadRequestError, got %T: %v", err, err) @@ -381,7 +388,7 @@ func TestGetOrganizationStatus_404ReturnsNotFoundError(t *testing.T) { func TestGetOrganizationStatus_401ReturnsAuthError(t *testing.T) { authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost || r.URL.Path != "/oauth/token" { + if r.Method != http.MethodPost || r.URL.Path != testAuthTokenPath { t.Fatalf("unexpected auth request %s %s", r.Method, r.URL.Path) } w.Header().Set("Content-Type", "application/json") @@ -407,7 +414,7 @@ func TestGetOrganizationStatus_401ReturnsAuthError(t *testing.T) { t.Fatal(err) } - _, err = client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: "example.com", OrgName: "AuthOrg"}) + _, err = client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: testOrgDomain, OrgName: "AuthOrg"}) var authErr *AuthError if !errors.As(err, &authErr) { t.Fatalf("expected AuthError, got %T: %v", err, err) @@ -457,13 +464,13 @@ func TestGetOrganizationStatus_401InvalidatesCachedToken(t *testing.T) { t.Fatal(err) } - _, err = client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: "example.com", OrgName: "RefreshOrg"}) + _, err = client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: testOrgDomain, OrgName: "RefreshOrg"}) var authErr *AuthError if !errors.As(err, &authErr) { t.Fatalf("expected AuthError, got %T: %v", err, err) } - if _, err := client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: "example.com", OrgName: "RefreshOrg"}); err != nil { + if _, err := client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: testOrgDomain, OrgName: "RefreshOrg"}); err != nil { t.Fatal(err) } if got := atomic.LoadInt32(&authCalls); got != 2 { @@ -471,9 +478,47 @@ func TestGetOrganizationStatus_401InvalidatesCachedToken(t *testing.T) { } } +func TestGetOrganizationStatus_Auth0ErrorUsesAuth0Payload(t *testing.T) { + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != testAuthTokenPath { + t.Fatalf("unexpected auth request %s %s", r.Method, r.URL.Path) + } + w.WriteHeader(http.StatusUnauthorized) + fmt.Fprint(w, `{"error":"invalid_client","error_description":"Client authentication failed"}`) + })) + defer authServer.Close() + + client, err := NewClient(SSSConfig{ + BaseURL: "https://sss.example.com", + Auth0Domain: authServer.URL, + Auth0ClientID: "id", + Auth0ClientSecret: "secret", + Auth0Audience: "audience", + Timeout: 5 * time.Second, + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: testOrgDomain, OrgName: "Auth0Org"}) + var authErr *AuthError + if !errors.As(err, &authErr) { + t.Fatalf("expected AuthError, got %T: %v", err, err) + } + if authErr.Code != "invalid_client" { + t.Fatalf("unexpected auth code: %s", authErr.Code) + } + if authErr.Message != "authentication failed: Client authentication failed" { + t.Fatalf("unexpected auth message: %s", authErr.Message) + } + if got := authErr.Error(); got != "authentication error: authentication failed: Client authentication failed (code=invalid_client request_id=)" { + t.Fatalf("unexpected auth error string: %s", got) + } +} + func TestGetOrganizationStatus_503ReturnsRetryableError(t *testing.T) { authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost || r.URL.Path != "/oauth/token" { + if r.Method != http.MethodPost || r.URL.Path != testAuthTokenPath { t.Fatalf("unexpected auth request %s %s", r.Method, r.URL.Path) } w.Header().Set("Content-Type", "application/json") @@ -500,7 +545,7 @@ func TestGetOrganizationStatus_503ReturnsRetryableError(t *testing.T) { t.Fatal(err) } - _, err = client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: "example.com", OrgName: "RetryOrg"}) + _, err = client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: testOrgDomain, OrgName: "RetryOrg"}) var retryErr *RetryableError if !errors.As(err, &retryErr) { t.Fatalf("expected RetryableError, got %T: %v", err, err) @@ -512,7 +557,7 @@ func TestGetOrganizationStatus_503ReturnsRetryableError(t *testing.T) { func TestGetOrganizationStatus_TimeoutReturnsTimeoutError(t *testing.T) { authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost || r.URL.Path != "/oauth/token" { + if r.Method != http.MethodPost || r.URL.Path != testAuthTokenPath { t.Fatalf("unexpected auth request %s %s", r.Method, r.URL.Path) } w.Header().Set("Content-Type", "application/json") @@ -539,7 +584,7 @@ func TestGetOrganizationStatus_TimeoutReturnsTimeoutError(t *testing.T) { t.Fatal(err) } - _, err = client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: "example.com", OrgName: "TimeoutOrg"}) + _, err = client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: testOrgDomain, OrgName: "TimeoutOrg"}) var timeoutErr *TimeoutError if !errors.As(err, &timeoutErr) { t.Fatalf("expected TimeoutError, got %T: %v", err, err) @@ -574,7 +619,7 @@ func TestGetOrganizationStatus_UsesCachedToken(t *testing.T) { } for i := 0; i < 2; i++ { - if _, err := client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: "example.com", OrgName: "CacheOrg"}); err != nil { + if _, err := client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: testOrgDomain, OrgName: "CacheOrg"}); err != nil { t.Fatal(err) } } @@ -611,7 +656,7 @@ func TestGetOrganizationStatus_CachesTokenWhenExpiresInMissing(t *testing.T) { } for i := 0; i < 2; i++ { - if _, err := client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: "example.com", OrgName: "FallbackOrg"}); err != nil { + if _, err := client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: testOrgDomain, OrgName: "FallbackOrg"}); err != nil { t.Fatal(err) } } @@ -652,10 +697,10 @@ func TestGetOrganizationStatus_RefreshesExpiredToken(t *testing.T) { t.Fatal(err) } - if _, err := client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: "example.com", OrgName: "ExpireOrg"}); err != nil { + if _, err := client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: testOrgDomain, OrgName: "ExpireOrg"}); err != nil { t.Fatal(err) } - if _, err := client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: "example.com", OrgName: "ExpireOrg"}); err != nil { + if _, err := client.GetOrganizationStatus(context.Background(), OrganizationStatusRequest{Domain: testOrgDomain, OrgName: "ExpireOrg"}); err != nil { t.Fatal(err) } if got := atomic.LoadInt32(&tokenIndex); got < 2 { diff --git a/cla-backend-go/sss/errors.go b/cla-backend-go/sss/errors.go index 60c130747..bece4b45d 100644 --- a/cla-backend-go/sss/errors.go +++ b/cla-backend-go/sss/errors.go @@ -16,7 +16,7 @@ type BadRequestError struct { } func (e *BadRequestError) Error() string { - return fmt.Sprintf("bad request: %s", e.Message) + return formatError("bad request", e.Message, e.Code, e.RequestID) } // AuthError indicates a 401 or 403 response from the SSS API. @@ -27,7 +27,7 @@ type AuthError struct { } func (e *AuthError) Error() string { - return fmt.Sprintf("authentication error: %s", e.Message) + return formatError("authentication error", e.Message, e.Code, e.RequestID) } // RetryableError indicates a 503 response from the SSS API. @@ -39,7 +39,7 @@ type RetryableError struct { } func (e *RetryableError) Error() string { - return fmt.Sprintf("retryable error: %s", e.Message) + return formatError("retryable error", e.Message, e.Code, e.RequestID) } // NotFoundError indicates a 404 response from the SSS API. @@ -50,14 +50,23 @@ type NotFoundError struct { } func (e *NotFoundError) Error() string { - return fmt.Sprintf("not found: %s", e.Message) + return formatError("not found", e.Message, e.Code, e.RequestID) } // TimeoutError indicates the request timed out. type TimeoutError struct { - Message string + Message string + Code string + RequestID string } func (e *TimeoutError) Error() string { - return fmt.Sprintf("timeout: %s", e.Message) + return formatError("timeout", e.Message, e.Code, e.RequestID) +} + +func formatError(prefix, message, code, requestID string) string { + if code != "" || requestID != "" { + return fmt.Sprintf("%s: %s (code=%s request_id=%s)", prefix, message, code, requestID) + } + return fmt.Sprintf("%s: %s", prefix, message) } diff --git a/cla-backend-go/sss/types.go b/cla-backend-go/sss/types.go index 8f9617508..351b847de 100644 --- a/cla-backend-go/sss/types.go +++ b/cla-backend-go/sss/types.go @@ -11,8 +11,12 @@ type SSSConfig struct { Auth0Domain string Auth0ClientID string Auth0ClientSecret string - Auth0Audience string - Timeout time.Duration + // Auth0Audience is the Auth0 API audience/resource server identifier. + // Production values may require the exact identifier configured in Auth0, + // including a trailing slash when the resource server uses one. + Auth0Audience string + // Timeout is shared by SSS API requests and Auth0 token acquisition requests. + Timeout time.Duration } // OrganizationStatusRequest holds parameters for querying organization screening status.