diff --git a/internal/pkg/cli/command/auth/login.go b/internal/pkg/cli/command/auth/login.go index 1b264f3..0916590 100644 --- a/internal/pkg/cli/command/auth/login.go +++ b/internal/pkg/cli/command/auth/login.go @@ -49,6 +49,7 @@ var ( func NewLoginCmd() *cobra.Command { var jsonOutput bool + var orgId string cmd := &cobra.Command{ Use: "login", @@ -58,6 +59,9 @@ func NewLoginCmd() *cobra.Command { # Interactive login (opens a browser) pc auth login + # Login scoped to a specific organization (enables SSO routing) + pc auth login --org "ORG_ID" + # Agentic login — first call returns a pending URL pc auth login --json @@ -66,11 +70,16 @@ func NewLoginCmd() *cobra.Command { `), GroupID: help.GROUP_AUTH.ID, Run: func(cmd *cobra.Command, args []string) { - login.Run(cmd.Context(), login.Options{Json: jsonOutput}) + opts := login.Options{Json: jsonOutput} + if cmd.Flags().Changed("org") { + opts.OrgId = &orgId + } + login.Run(cmd.Context(), opts) }, } cmd.Flags().BoolVarP(&jsonOutput, "json", "j", false, "emit JSON output") + cmd.Flags().StringVar(&orgId, "org", "", "Organization ID to authenticate into (enables SSO routing for organizations with SSO enforced)") return cmd } diff --git a/internal/pkg/cli/command/login/login.go b/internal/pkg/cli/command/login/login.go index 40a6802..efcc567 100644 --- a/internal/pkg/cli/command/login/login.go +++ b/internal/pkg/cli/command/login/login.go @@ -41,6 +41,7 @@ var ( func NewLoginCmd() *cobra.Command { var jsonOutput bool + var orgId string cmd := &cobra.Command{ Use: "login", @@ -50,6 +51,9 @@ func NewLoginCmd() *cobra.Command { # Interactive login (opens a browser) pc login + # Login scoped to a specific organization (enables SSO routing) + pc login --org "ORG_ID" + # Agentic login — first call returns a pending URL pc login --json @@ -58,11 +62,16 @@ func NewLoginCmd() *cobra.Command { `), GroupID: help.GROUP_AUTH.ID, Run: func(cmd *cobra.Command, args []string) { - login.Run(cmd.Context(), login.Options{Json: jsonOutput}) + opts := login.Options{Json: jsonOutput} + if cmd.Flags().Changed("org") { + opts.OrgId = &orgId + } + login.Run(cmd.Context(), opts) }, } cmd.Flags().BoolVarP(&jsonOutput, "json", "j", false, "emit JSON output") + cmd.Flags().StringVar(&orgId, "org", "", "Organization ID to authenticate into (enables SSO routing for organizations with SSO enforced)") return cmd } diff --git a/internal/pkg/cli/command/target/target.go b/internal/pkg/cli/command/target/target.go index 6f91402..bdb0204 100644 --- a/internal/pkg/cli/command/target/target.go +++ b/internal/pkg/cli/command/target/target.go @@ -199,8 +199,11 @@ func NewTargetCmd() *cobra.Command { // If the org chosen differs from the current orgId in the token, we need to login again if currentTokenOrgId != "" && currentTokenOrgId != targetOrg.Id { + // Fetch SSO connection while the current token is still valid, + // before logout clears it. + ssoConn := login.ResolveSSOConnection(ctx, targetOrg.Id) oauth.Logout() - err = login.GetAndSetAccessToken(ctx, &targetOrg.Id, login.Options{Json: options.json, Wait: true}) + err = login.GetAndSetAccessToken(ctx, &targetOrg.Id, login.Options{Json: options.json, Wait: true, SSOConnection: ssoConn}) if err != nil { msg.FailJSON(options.json, "Failed to get access token: %s", err) exit.Error(err, "Error getting access token") @@ -245,8 +248,11 @@ func NewTargetCmd() *cobra.Command { // If the org chosen differs from the current orgId in the token, we need to login again if currentTokenOrgId != org.Id { + // Fetch SSO connection while the current token is still valid, + // before logout clears it. + ssoConn := login.ResolveSSOConnection(ctx, org.Id) oauth.Logout() - err = login.GetAndSetAccessToken(ctx, &org.Id, login.Options{Json: options.json, Wait: true}) + err = login.GetAndSetAccessToken(ctx, &org.Id, login.Options{Json: options.json, Wait: true, SSOConnection: ssoConn}) if err != nil { msg.FailJSON(options.json, "Failed to get access token: %s", err) exit.Error(err, "Error getting access token") diff --git a/internal/pkg/utils/login/login.go b/internal/pkg/utils/login/login.go index 4237428..7fb9878 100644 --- a/internal/pkg/utils/login/login.go +++ b/internal/pkg/utils/login/login.go @@ -46,6 +46,13 @@ type Options struct { // RunPostAuthSetup is not called in Wait mode; the caller is responsible // for any post-auth state setup and output. Wait bool + // OrgId pins the login flow to a specific organization. + OrgId *string + // SSOConnection is the Auth0 connection name to pass as `connection=` in the + // authorization URL, routing the browser directly to the org's IdP. + // Callers that hold a valid token before clearing credentials (e.g. pc target) + // should resolve this with FetchSSOConnection before logout, then pass it here. + SSOConnection *string } func Run(ctx context.Context, opts Options) { @@ -62,7 +69,7 @@ func Run(ctx context.Context, opts Options) { exit.Error(err, "Error checking for existing auth session") } if sess != nil { - if err := getAndSetAccessTokenJSON(ctx, nil, false, sess, result); err != nil { + if err := getAndSetAccessTokenJSON(ctx, opts.OrgId, false, opts.SSOConnection, sess, result); err != nil { msg.FailMsg("Error acquiring access token while logging in: %s", err) exit.Error(err, "Error acquiring access token while logging in") } @@ -80,26 +87,43 @@ func Run(ctx context.Context, opts Options) { } if !expired && token != nil && token.AccessToken != "" { - if opts.Json { - claims, err := oauth.ParseClaimsUnverified(token) - if err == nil { - fmt.Fprintln(os.Stdout, text.IndentJSON(struct { - Status string `json:"status"` - Email string `json:"email"` - OrgId string `json:"org_id"` - }{Status: "already_authenticated", Email: claims.Email, OrgId: claims.OrgId})) - } else { - fmt.Fprintln(os.Stdout, text.IndentJSON(struct { - Status string `json:"status"` - }{Status: "already_authenticated"})) + // If --org targets a different organization, re-authenticate now while + // the token is still valid so we can look up the SSO connection before + // clearing credentials. + differentOrg := false + if opts.OrgId != nil && *opts.OrgId != "" { + if claims, claimsErr := oauth.ParseClaimsUnverified(token); claimsErr == nil { + differentOrg = claims.OrgId != *opts.OrgId } + } + + if differentOrg { + opts.SSOConnection = ResolveSSOConnection(ctx, *opts.OrgId) + oauth.Logout() + // Fall through to GetAndSetAccessToken. } else { - msg.WarnMsg("You are already logged in. Please log out first using %s.", style.Code("pc auth logout")) + // Same org (or no --org flag) — show "already logged in". + if opts.Json { + claims, err := oauth.ParseClaimsUnverified(token) + if err == nil { + fmt.Fprintln(os.Stdout, text.IndentJSON(struct { + Status string `json:"status"` + Email string `json:"email"` + OrgId string `json:"org_id"` + }{Status: "already_authenticated", Email: claims.Email, OrgId: claims.OrgId})) + } else { + fmt.Fprintln(os.Stdout, text.IndentJSON(struct { + Status string `json:"status"` + }{Status: "already_authenticated"})) + } + } else { + msg.WarnMsg("You are already logged in. Please log out first using %s.", style.Code("pc auth logout")) + } + return } - return } - err = GetAndSetAccessToken(ctx, nil, opts) + err = GetAndSetAccessToken(ctx, opts.OrgId, opts) if err != nil { msg.FailMsg("Error acquiring access token while logging in: %s", err) exit.Error(err, "Error acquiring access token while logging in") @@ -188,9 +212,9 @@ func GetAndSetAccessToken(ctx context.Context, orgId *string, opts Options) erro // a terminal (agentic context), always use the JSON/daemon path. opts.Json = opts.Json || !term.IsTerminal(int(os.Stdout.Fd())) if opts.Json { - return getAndSetAccessTokenJSON(ctx, orgId, opts.Wait, nil, nil) + return getAndSetAccessTokenJSON(ctx, orgId, opts.Wait, opts.SSOConnection, nil, nil) } - return getAndSetAccessTokenInteractive(ctx, orgId) + return getAndSetAccessTokenInteractive(ctx, orgId, opts.SSOConnection) } // getAndSetAccessTokenJSON is the agentic path: daemon-backed, non-blocking on stdin. @@ -203,7 +227,7 @@ func GetAndSetAccessToken(ctx context.Context, orgId *string, opts Options) erro // When wait is true (for callers like pc target that need a token on return): spawns // daemon, blocks until auth completes, and returns with the token stored. RunPostAuthSetup // is not called; the caller owns post-auth state and output. -func getAndSetAccessTokenJSON(ctx context.Context, orgId *string, wait bool, sess *SessionState, result *SessionResult) error { +func getAndSetAccessTokenJSON(ctx context.Context, orgId *string, wait bool, ssoConnection *string, sess *SessionState, result *SessionResult) error { if sess == nil { // No pre-fetched session — look one up now. var err error @@ -238,7 +262,7 @@ func getAndSetAccessTokenJSON(ctx context.Context, orgId *string, wait bool, ses return fmt.Errorf("error creating new auth verifier and challenge: %w", err) } - authURL, err := a.GetAuthURL(ctx, csrfState, challenge, orgId) + authURL, err := a.GetAuthURL(ctx, csrfState, challenge, orgId, ssoConnection) if err != nil { return fmt.Errorf("error getting auth URL: %w", err) } @@ -370,7 +394,7 @@ func printPendingJSON(authURL, sessionId string) { // getAndSetAccessTokenInteractive is the original interactive path: inline callback server, // optional [Enter]-to-open-browser prompt when stdin is a TTY. -func getAndSetAccessTokenInteractive(ctx context.Context, orgId *string) error { +func getAndSetAccessTokenInteractive(ctx context.Context, orgId *string, ssoConnection *string) error { // If a daemon from a prior JSON-mode login exists, check whether it has // already finished before deciding whether to block interactive login. sess, result, err := findResumableSession() @@ -398,7 +422,7 @@ func getAndSetAccessTokenInteractive(ctx context.Context, orgId *string) error { return fmt.Errorf("error creating new auth verifier and challenge: %w", err) } - authURL, err := a.GetAuthURL(ctx, csrfState, challenge, orgId) + authURL, err := a.GetAuthURL(ctx, csrfState, challenge, orgId, ssoConnection) if err != nil { return fmt.Errorf("error getting auth URL: %w", err) } diff --git a/internal/pkg/utils/login/sso.go b/internal/pkg/utils/login/sso.go new file mode 100644 index 0000000..80ce16c --- /dev/null +++ b/internal/pkg/utils/login/sso.go @@ -0,0 +1,105 @@ +package login + +import ( + "context" + "encoding/json" + "io" + "net/http" + + "github.com/pinecone-io/cli/internal/pkg/utils/configuration/config" + "github.com/pinecone-io/cli/internal/pkg/utils/environment" + "github.com/pinecone-io/cli/internal/pkg/utils/log" + "github.com/pinecone-io/cli/internal/pkg/utils/oauth" +) + +// dashboardOrg is the subset of the dashboard API org response needed for SSO lookup. +type dashboardOrg struct { + Id string `json:"id"` + SSOConnectionName string `json:"sso_connection_name"` + EnforceSSO bool `json:"enforce_sso_authentication"` +} + +type dashboardOrgsResponse struct { + NewOrgs []dashboardOrg `json:"newOrgs"` +} + +// ResolveSSOConnection is a convenience wrapper around FetchSSOConnection that +// returns a pointer to the connection name when SSO is enforced for the org, or +// nil otherwise. Errors are logged at debug level and treated as "no SSO". +func ResolveSSOConnection(ctx context.Context, orgId string) *string { + conn, err := FetchSSOConnection(ctx, orgId) + if err != nil { + log.Debug().Err(err).Str("orgId", orgId).Msg("SSO connection lookup failed, proceeding without connection param") + } + if conn == "" { + return nil + } + return &conn +} + +// FetchSSOConnection calls the private dashboard API to retrieve the Auth0 +// connection name for the given orgId. It returns ("", nil) when the org has +// no SSO configured, enforce_sso_authentication is false, or any error occurs. +// Errors are non-fatal: the caller should proceed with a normal login URL. +func FetchSSOConnection(ctx context.Context, orgId string) (string, error) { + token, err := oauth.Token(ctx) + if err != nil || token == nil || token.AccessToken == "" { + log.Debug().Str("orgId", orgId).Msg("SSO lookup skipped: no valid token available") + return "", nil + } + + envConfig, err := environment.GetEnvConfig(config.Environment.Get()) + if err != nil { + return "", nil + } + + return fetchSSOConnectionFromURL(ctx, orgId, token.AccessToken, http.DefaultClient, envConfig.DashboardUrl) +} + +// fetchSSOConnectionFromURL is the testable core: it takes an explicit HTTP +// client and dashboard base URL so tests can inject a local httptest.Server. +func fetchSSOConnectionFromURL(ctx context.Context, orgId string, accessToken string, client *http.Client, dashboardURL string) (string, error) { + url := dashboardURL + "/v2/dashboard/organizations" + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return "", nil + } + req.Header.Set("Authorization", "Bearer "+accessToken) + + resp, err := client.Do(req) + if err != nil { + log.Debug().Err(err).Str("orgId", orgId).Msg("SSO lookup: dashboard API request failed") + return "", nil + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + log.Debug().Int("status", resp.StatusCode).Str("orgId", orgId).Msg("SSO lookup: dashboard API returned non-2xx") + return "", nil + } + + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + log.Debug().Err(err).Str("orgId", orgId).Msg("SSO lookup: failed to read dashboard API response") + return "", nil + } + + var orgsResp dashboardOrgsResponse + if err := json.Unmarshal(body, &orgsResp); err != nil { + log.Debug().Err(err).Str("orgId", orgId).Msg("SSO lookup: failed to decode dashboard API response") + return "", nil + } + + for _, org := range orgsResp.NewOrgs { + if org.Id == orgId { + if org.EnforceSSO && org.SSOConnectionName != "" { + log.Debug().Str("orgId", orgId).Str("connection", org.SSOConnectionName).Msg("SSO lookup: found connection") + return org.SSOConnectionName, nil + } + return "", nil + } + } + + log.Debug().Str("orgId", orgId).Msg("SSO lookup: org not found in dashboard response") + return "", nil +} diff --git a/internal/pkg/utils/login/sso_test.go b/internal/pkg/utils/login/sso_test.go new file mode 100644 index 0000000..77dbe89 --- /dev/null +++ b/internal/pkg/utils/login/sso_test.go @@ -0,0 +1,115 @@ +package login + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +// newDashboardServer starts an httptest.Server that returns the given org list. +// Pass a non-zero statusCode to simulate an error response. +func newDashboardServer(t *testing.T, orgs []dashboardOrg, statusCode int) *httptest.Server { + t.Helper() + if statusCode == 0 { + statusCode = http.StatusOK + } + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if statusCode != http.StatusOK { + http.Error(w, "error", statusCode) + return + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(dashboardOrgsResponse{NewOrgs: orgs}) + })) +} + +func TestFetchSSOConnection_EnforcedWithConnection(t *testing.T) { + server := newDashboardServer(t, []dashboardOrg{ + {Id: "org-1", SSOConnectionName: "alby-saml", EnforceSSO: true}, + }, 0) + defer server.Close() + + conn, err := fetchSSOConnectionFromURL(context.Background(), "org-1", "fake-token", server.Client(), server.URL) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if conn != "alby-saml" { + t.Errorf("expected %q, got %q", "alby-saml", conn) + } +} + +func TestFetchSSOConnection_NotEnforced(t *testing.T) { + server := newDashboardServer(t, []dashboardOrg{ + {Id: "org-1", SSOConnectionName: "alby-saml", EnforceSSO: false}, + }, 0) + defer server.Close() + + conn, err := fetchSSOConnectionFromURL(context.Background(), "org-1", "fake-token", server.Client(), server.URL) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if conn != "" { + t.Errorf("expected empty connection when SSO not enforced, got %q", conn) + } +} + +func TestFetchSSOConnection_OrgNotFound(t *testing.T) { + server := newDashboardServer(t, []dashboardOrg{ + {Id: "org-other", SSOConnectionName: "other-saml", EnforceSSO: true}, + }, 0) + defer server.Close() + + conn, err := fetchSSOConnectionFromURL(context.Background(), "org-1", "fake-token", server.Client(), server.URL) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if conn != "" { + t.Errorf("expected empty connection when org not found, got %q", conn) + } +} + +func TestFetchSSOConnection_NonOKStatus(t *testing.T) { + server := newDashboardServer(t, nil, http.StatusUnauthorized) + defer server.Close() + + conn, err := fetchSSOConnectionFromURL(context.Background(), "org-1", "fake-token", server.Client(), server.URL) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if conn != "" { + t.Errorf("expected empty connection on non-2xx response, got %q", conn) + } +} + +func TestFetchSSOConnection_EmptyConnectionName(t *testing.T) { + server := newDashboardServer(t, []dashboardOrg{ + {Id: "org-1", SSOConnectionName: "", EnforceSSO: true}, + }, 0) + defer server.Close() + + conn, err := fetchSSOConnectionFromURL(context.Background(), "org-1", "fake-token", server.Client(), server.URL) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if conn != "" { + t.Errorf("expected empty connection when name is empty, got %q", conn) + } +} + +func TestFetchSSOConnection_MultipleOrgs(t *testing.T) { + server := newDashboardServer(t, []dashboardOrg{ + {Id: "org-1", SSOConnectionName: "org1-saml", EnforceSSO: true}, + {Id: "org-2", SSOConnectionName: "org2-saml", EnforceSSO: true}, + }, 0) + defer server.Close() + + conn, err := fetchSSOConnectionFromURL(context.Background(), "org-2", "fake-token", server.Client(), server.URL) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if conn != "org2-saml" { + t.Errorf("expected %q, got %q", "org2-saml", conn) + } +} diff --git a/internal/pkg/utils/oauth/auth.go b/internal/pkg/utils/oauth/auth.go index 0bb0ebb..fea55a6 100644 --- a/internal/pkg/utils/oauth/auth.go +++ b/internal/pkg/utils/oauth/auth.go @@ -15,7 +15,7 @@ const ( SourceTag = "pinecone_cli" ) -func (a *Auth) GetAuthURL(ctx context.Context, csrfState string, codeChallenge string, orgId *string) (string, error) { +func (a *Auth) GetAuthURL(ctx context.Context, csrfState string, codeChallenge string, orgId *string, ssoConnection *string) (string, error) { conf, err := newOauth2Config() if err != nil { return "", err @@ -34,6 +34,9 @@ func (a *Auth) GetAuthURL(ctx context.Context, csrfState string, codeChallenge s if orgId != nil && *orgId != "" { opts = append(opts, oauth2.SetAuthURLParam("orgId", *orgId)) } + if ssoConnection != nil && *ssoConnection != "" { + opts = append(opts, oauth2.SetAuthURLParam("connection", *ssoConnection)) + } return conf.AuthCodeURL(csrfState, opts...), nil } diff --git a/internal/pkg/utils/oauth/auth_test.go b/internal/pkg/utils/oauth/auth_test.go index e275cff..d5a40a8 100644 --- a/internal/pkg/utils/oauth/auth_test.go +++ b/internal/pkg/utils/oauth/auth_test.go @@ -21,7 +21,7 @@ func TestGetAuthURL_ContainsSourceTag(t *testing.T) { t.Fatalf("failed to create verifier/challenge: %v", err) } - rawURL, err := a.GetAuthURL(ctx, "test-csrf-state", challenge, nil) + rawURL, err := a.GetAuthURL(ctx, "test-csrf-state", challenge, nil, nil) if err != nil { t.Fatalf("GetAuthURL returned error: %v", err) } @@ -48,7 +48,7 @@ func TestGetAuthURL_RequiredParams(t *testing.T) { } csrfState := "test-state-123" - rawURL, err := a.GetAuthURL(ctx, csrfState, challenge, nil) + rawURL, err := a.GetAuthURL(ctx, csrfState, challenge, nil, nil) if err != nil { t.Fatalf("GetAuthURL returned error: %v", err) } @@ -84,7 +84,7 @@ func TestGetAuthURL_WithOrgId(t *testing.T) { } orgId := "test-org-456" - rawURL, err := a.GetAuthURL(ctx, "state", challenge, &orgId) + rawURL, err := a.GetAuthURL(ctx, "state", challenge, &orgId, nil) if err != nil { t.Fatalf("GetAuthURL returned error: %v", err) } @@ -109,7 +109,7 @@ func TestGetAuthURL_WithEmptyOrgId(t *testing.T) { } emptyOrgId := "" - rawURL, err := a.GetAuthURL(ctx, "state", challenge, &emptyOrgId) + rawURL, err := a.GetAuthURL(ctx, "state", challenge, &emptyOrgId, nil) if err != nil { t.Fatalf("GetAuthURL returned error: %v", err) } @@ -123,3 +123,77 @@ func TestGetAuthURL_WithEmptyOrgId(t *testing.T) { t.Errorf("expected orgId to be absent for empty string, got %q", got) } } + +func TestGetAuthURL_WithSSOConnection(t *testing.T) { + a := &Auth{} + ctx := context.Background() + + _, challenge, err := a.CreateNewVerifierAndChallenge() + if err != nil { + t.Fatalf("failed to create verifier/challenge: %v", err) + } + + connection := "alby-saml" + rawURL, err := a.GetAuthURL(ctx, "state", challenge, nil, &connection) + if err != nil { + t.Fatalf("GetAuthURL returned error: %v", err) + } + + parsed, err := url.Parse(rawURL) + if err != nil { + t.Fatalf("failed to parse auth URL: %v", err) + } + + if got := parsed.Query().Get("connection"); got != connection { + t.Errorf("expected connection=%q, got %q", connection, got) + } +} + +func TestGetAuthURL_WithNilSSOConnection(t *testing.T) { + a := &Auth{} + ctx := context.Background() + + _, challenge, err := a.CreateNewVerifierAndChallenge() + if err != nil { + t.Fatalf("failed to create verifier/challenge: %v", err) + } + + rawURL, err := a.GetAuthURL(ctx, "state", challenge, nil, nil) + if err != nil { + t.Fatalf("GetAuthURL returned error: %v", err) + } + + parsed, err := url.Parse(rawURL) + if err != nil { + t.Fatalf("failed to parse auth URL: %v", err) + } + + if got := parsed.Query().Get("connection"); got != "" { + t.Errorf("expected connection param to be absent, got %q", got) + } +} + +func TestGetAuthURL_WithEmptySSOConnection(t *testing.T) { + a := &Auth{} + ctx := context.Background() + + _, challenge, err := a.CreateNewVerifierAndChallenge() + if err != nil { + t.Fatalf("failed to create verifier/challenge: %v", err) + } + + emptyConnection := "" + rawURL, err := a.GetAuthURL(ctx, "state", challenge, nil, &emptyConnection) + if err != nil { + t.Fatalf("GetAuthURL returned error: %v", err) + } + + parsed, err := url.Parse(rawURL) + if err != nil { + t.Fatalf("failed to parse auth URL: %v", err) + } + + if got := parsed.Query().Get("connection"); got != "" { + t.Errorf("expected connection param to be absent for empty string, got %q", got) + } +}