From 3cb5e32e49c71d94d4e2812cfc6f156dca7036d4 Mon Sep 17 00:00:00 2001
From: Alexey Volkov
Date: Tue, 13 Jan 2026 03:58:22 -0800
Subject: [PATCH] feat: Secrets

---
 cloud_pipelines_backend/api_router.py        |  24 +++
 cloud_pipelines_backend/api_server_sql.py    | 151 ++++++++++++++++--
 cloud_pipelines_backend/backend_types_sql.py |  11 ++
 .../component_structures.py                  |  18 ++-
 .../launchers/interfaces.py                  |   1 +
 cloud_pipelines_backend/orchestrator_sql.py  |  29 ++++
 6 files changed, 223 insertions(+), 11 deletions(-)

diff --git a/cloud_pipelines_backend/api_router.py b/cloud_pipelines_backend/api_router.py
index 6652637..f2c0dac 100644
--- a/cloud_pipelines_backend/api_router.py
+++ b/cloud_pipelines_backend/api_router.py
@@ -390,6 +390,30 @@ def get_current_user(
         permissions=permissions,
     )
 
+    ### Secrets routes
+    secrets_service = api_server_sql.SecretsApiService()
+
+    router.get("/api/secrets/", tags=["secrets"], **default_config)(
+        inject_session_dependency(
+            inject_user_name(secrets_service.list_secrets, parameter_name="user_id")
+        )
+    )
+    router.post("/api/secrets/", tags=["secrets"], **default_config)(
+        inject_session_dependency(
+            inject_user_name(secrets_service.create_secret, parameter_name="user_id")
+        )
+    )
+    router.put("/api/secrets/{secret_id}", tags=["secrets"], **default_config)(
+        inject_session_dependency(
+            inject_user_name(secrets_service.update_secret, parameter_name="user_id")
+        )
+    )
+    router.delete("/api/secrets/{secret_id}", tags=["secrets"], **default_config)(
+        inject_session_dependency(
+            inject_user_name(secrets_service.delete_secret, parameter_name="user_id")
+        )
+    )
+
     ### Component library routes
     component_service = components_api.ComponentService()
 
diff --git a/cloud_pipelines_backend/api_server_sql.py b/cloud_pipelines_backend/api_server_sql.py
index e8e0624..9fb8307 100644
--- a/cloud_pipelines_backend/api_server_sql.py
+++ b/cloud_pipelines_backend/api_server_sql.py
@@ -994,6 +994,116 @@ def get_signed_artifact_url(
     return GetArtifactSignedUrlResponse(signed_url=signed_url)
 
 
+# === Secrets Service
+@dataclasses.dataclass(kw_only=True)
+class SecretInfoResponse:
+    secret_id: str
+
+
+@dataclasses.dataclass(kw_only=True)
+class ListSecretsResponse:
+    secrets: list[SecretInfoResponse]
+
+
+class SecretsApiService:
+
+    def create_secret(
+        self,
+        *,
+        session: orm.Session,
+        user_id: str,
+        secret_id: str,
+        secret_value: str,
+    ):
+        return self._set_secret_value(
+            session=session,
+            user_id=user_id,
+            secret_id=secret_id,
+            secret_value=secret_value,
+            raise_if_exists=True,
+        )
+
+    def update_secret(
+        self,
+        *,
+        session: orm.Session,
+        user_id: str,
+        secret_id: str,
+        secret_value: str,
+    ):
+        return self._set_secret_value(
+            session=session,
+            user_id=user_id,
+            secret_id=secret_id,
+            secret_value=secret_value,
+            raise_if_not_exists=True,
+        )
+
+    def _set_secret_value(
+        self,
+        *,
+        session: orm.Session,
+        user_id: str,
+        secret_id: str,
+        secret_value: str,
+        raise_if_not_exists: bool = False,
+        raise_if_exists: bool = False,
+    ):
+        current_time = _get_current_time()
+        secret = session.get(bts.Secret, (user_id, secret_id))
+        if secret:
+            if raise_if_exists:
+                raise errors.ItemAlreadyExistsError(
+                    f"Secret with id '{secret_id}' already exists."
+                )
+            secret.secret_value = secret_value
+            secret.updated_at = current_time
+        else:
+            if raise_if_not_exists:
+                raise errors.ItemNotFoundError(
+                    f"Secret with id '{secret_id}' does not exist."
+                )
+            secret = bts.Secret(
+                user_id=user_id,
+                secret_id=secret_id,
+                secret_value=secret_value,
+                created_at=current_time,
+                updated_at=current_time,
+            )
+            session.add(secret)
+        session.commit()
+
+    def delete_secret(
+        self,
+        *,
+        session: orm.Session,
+        user_id: str,
+        secret_id: str,
+    ):
+        secret = session.get(bts.Secret, (user_id, secret_id))
+        if not secret:
+            raise errors.ItemNotFoundError(
+                f"Secret with id '{secret_id}' does not exist."
+            )
+        session.delete(secret)
+        session.commit()
+
+    def list_secrets(
+        self,
+        *,
+        session: orm.Session,
+        user_id: str,
+    ) -> ListSecretsResponse:
+        secrets = session.scalars(
+            sql.select(bts.Secret).where(bts.Secret.user_id == user_id)
+        ).all()
+        return ListSecretsResponse(
+            secrets=[
+                SecretInfoResponse(secret_id=secret.secret_id) for secret in secrets
+            ]
+        )
+
+
 # ============
 # Idea for how to add deep nested graph:
@@ -1005,11 +1115,14 @@
 # No. Decided to first do topological sort and then 1-stage generation.
 
 
+_ArtifactNodeOrSecretType = typing.Union[bts.ArtifactNode, structures.SecretReference]
+
+
 def _recursively_create_all_executions_and_artifacts_root(
     session: orm.Session,
     root_task_spec: structures.TaskSpec,
 ) -> bts.ExecutionNode:
-    input_artifact_nodes: dict[str, bts.ArtifactNode] = {}
+    input_artifact_nodes: dict[str, _ArtifactNodeOrSecretType] = {}
     root_component_spec = root_task_spec.component_ref.spec
     if not root_component_spec:
@@ -1035,12 +1148,8 @@ def _recursively_create_all_executions_and_artifacts_root(
             raise ApiServiceError(
                 f"root task arguments can only be constants, but got {input_name}={input_argument}. {root_task_spec=}"
             )
-        elif not isinstance(input_argument, str):
-            raise ApiServiceError(
-                f"root task constant argument must be a string, but got {input_name}={input_argument}. {root_task_spec=}"
-            )
         # TODO: Support constant input artifacts (artifact IDs)
-        if input_argument is not None:
+        elif isinstance(input_argument, str):
             input_artifact_nodes[input_name] = (
                 # _construct_constant_artifact_node_and_add_to_session(
                 #     session=session, value=input_argument, artifact_type=input_spec.type
@@ -1052,6 +1161,12 @@ def _recursively_create_all_executions_and_artifacts_root(
             # This constant artifact won't be added to the DB
             # TODO: Actually, they will be added...
            # We don't need to link this input artifact here. It will be handled downstream.
+        elif isinstance(input_argument, structures.SecretArgument):
+            input_artifact_nodes[input_name] = input_argument.secret
+        else:
+            raise ApiServiceError(
+                f"root task constant argument must be a string, but got {input_name}={input_argument}. {root_task_spec=}"
+            )
 
     root_execution_node = _recursively_create_all_executions_and_artifacts(
         session=session,
@@ -1065,7 +1180,7 @@
 def _recursively_create_all_executions_and_artifacts(
     session: orm.Session,
     root_task_spec: structures.TaskSpec,
-    input_artifact_nodes: dict[str, bts.ArtifactNode],
+    input_artifact_nodes: dict[str, _ArtifactNodeOrSecretType],
     ancestors: list[bts.ExecutionNode],
 ) -> bts.ExecutionNode:
     root_component_spec = root_task_spec.component_ref.spec
@@ -1098,6 +1213,18 @@ def _recursively_create_all_executions_and_artifacts(
     input_artifact_nodes = dict(input_artifact_nodes)
     for input_spec in root_component_spec.inputs or []:
         input_artifact_node = input_artifact_nodes.get(input_spec.name)
+        if isinstance(input_artifact_node, structures.SecretReference):
+            # We don't use these secret arguments, but adding them just in case.
+            extra_data = root_execution_node.extra_data or {}
+            secret_reference_arguments = extra_data.setdefault(
+                bts.EXECUTION_NODE_EXTRA_DATA_SECRET_REFERENCE_ARGUMENTS_KEY, {}
+            )
+            secret_reference_arguments[input_spec.name] = (
+                input_artifact_node.to_json_dict()
+            )
+            root_execution_node.extra_data = extra_data
+            # Not adding any artifact link for secret inputs
+            continue
         if input_artifact_node is None and not input_spec.optional:
             if input_spec.default:
                 input_artifact_node = (
@@ -1163,7 +1290,8 @@ def _recursively_create_all_executions_and_artifacts(
         root_execution_node.container_execution_status = (
             bts.ContainerExecutionStatus.QUEUED
             if all(
-                artifact_node.artifact_data
+                not isinstance(artifact_node, bts.ArtifactNode)
+                or artifact_node.artifact_data
                 for artifact_node in input_artifact_nodes.values()
             )
             else bts.ContainerExecutionStatus.WAITING_FOR_UPSTREAM
@@ -1190,10 +1318,10 @@ def _recursively_create_all_executions_and_artifacts(
             raise ApiServiceError(
                 f"child_task_spec.component_ref.spec is empty. {child_task_spec=}"
             )
-        child_task_input_artifact_nodes: dict[str, bts.ArtifactNode] = {}
+        child_task_input_artifact_nodes: dict[str, _ArtifactNodeOrSecretType] = {}
         for input_spec in child_component_spec.inputs or []:
             input_argument = (child_task_spec.arguments or {}).get(input_spec.name)
-            input_artifact_node: bts.ArtifactNode | None = None
+            input_artifact_node: _ArtifactNodeOrSecretType | None = None
             if input_argument is None and not input_spec.optional:
                 # Not failing on unconnected required input if there is a default value
                 if input_spec.default is None:
@@ -1233,6 +1361,9 @@ def _recursively_create_all_executions_and_artifacts(
                 #         artifact_type=input_spec.type,
                 #     )
                 # )
+            elif isinstance(input_argument, structures.SecretArgument):
+                # We'll deal with secrets when launching the container.
+                input_artifact_node = input_argument.secret
             else:
                 raise ApiServiceError(
                     f"Unexpected task argument: {input_spec.name}={input_argument}. {child_task_spec=}"
{child_task_spec=}" diff --git a/cloud_pipelines_backend/backend_types_sql.py b/cloud_pipelines_backend/backend_types_sql.py index af16b3c..65f4a53 100644 --- a/cloud_pipelines_backend/backend_types_sql.py +++ b/cloud_pipelines_backend/backend_types_sql.py @@ -406,6 +406,7 @@ class ExecutionNode(_TableBase): EXECUTION_NODE_EXTRA_DATA_ORCHESTRATION_ERROR_MESSAGE_KEY = ( "orchestration_error_message" ) +EXECUTION_NODE_EXTRA_DATA_SECRET_REFERENCE_ARGUMENTS_KEY = "secret_reference_arguments" CONTAINER_EXECUTION_EXTRA_DATA_ORCHESTRATION_ERROR_MESSAGE_KEY = ( "orchestration_error_message" ) @@ -476,3 +477,13 @@ class PipelineRunAnnotation(_TableBase): pipeline_run: orm.Mapped[PipelineRun] = orm.relationship(repr=False, init=False) key: orm.Mapped[str] = orm.mapped_column(default=None, primary_key=True) value: orm.Mapped[str | None] = orm.mapped_column(default=None) + + +class Secret(_TableBase): + __tablename__ = "secret" + user_id: orm.Mapped[str] = orm.mapped_column(primary_key=True, index=True) + secret_id: orm.Mapped[str] = orm.mapped_column(primary_key=True) + secret_value: orm.Mapped[str] + created_at: orm.Mapped[datetime.datetime | None] = orm.mapped_column(default=None) + updated_at: orm.Mapped[datetime.datetime | None] = orm.mapped_column(default=None) + extra_data: orm.Mapped[dict[str, Any] | None] = orm.mapped_column(default=None) diff --git a/cloud_pipelines_backend/component_structures.py b/cloud_pipelines_backend/component_structures.py index 12a26e0..d7c64ed 100644 --- a/cloud_pipelines_backend/component_structures.py +++ b/cloud_pipelines_backend/component_structures.py @@ -317,7 +317,23 @@ class TaskOutputArgument(_BaseModel): # Has additional constructor for convenie task_output: TaskOutputReference -ArgumentType = Union[PrimitiveTypes, GraphInputArgument, TaskOutputArgument] +@dataclasses.dataclass +class SecretReference(_BaseModel): + """References a secret""" + + id: str + + +@dataclasses.dataclass +class SecretArgument(_BaseModel): + """Argument that references a secret""" + + secret: SecretReference + + +ArgumentType = Union[ + PrimitiveTypes, GraphInputArgument, TaskOutputArgument, SecretArgument +] @dataclasses.dataclass diff --git a/cloud_pipelines_backend/launchers/interfaces.py b/cloud_pipelines_backend/launchers/interfaces.py index 1272e60..f5807f9 100644 --- a/cloud_pipelines_backend/launchers/interfaces.py +++ b/cloud_pipelines_backend/launchers/interfaces.py @@ -36,6 +36,7 @@ class InputArgument: value: str | None = None uri: str | None = None staging_uri: str + is_secret: bool = False class ContainerTaskLauncher(typing.Generic[_TLaunchedContainer], abc.ABC): diff --git a/cloud_pipelines_backend/orchestrator_sql.py b/cloud_pipelines_backend/orchestrator_sql.py index e81f09a..e958e8d 100644 --- a/cloud_pipelines_backend/orchestrator_sql.py +++ b/cloud_pipelines_backend/orchestrator_sql.py @@ -436,6 +436,34 @@ def generate_execution_log_uri( for output_spec in component_spec.outputs or [] } + # Handling secrets. + # We read secrets from execution_node.extra_data rather than from task_spec.arguments, + # because some secrets might have been passed from upstream graph inputs. 
+    secret_reference_arguments = (execution.extra_data or {}).get(
+        bts.EXECUTION_NODE_EXTRA_DATA_SECRET_REFERENCE_ARGUMENTS_KEY, {}
+    )
+    secret_hash = ""
+    for input_name, secret_reference_dict in secret_reference_arguments.items():
+        user_id = pipeline_run.created_by
+        secret_id = secret_reference_dict["id"]
+        secret = session.get(bts.Secret, (user_id, secret_id))
+        if not secret:
+            raise OrchestratorError(
+                f"{execution.id=}: User error: Error resolving a secret argument for {input_name=}: User {user_id} does not have secret {secret_id}."
+            )
+        secret_value = secret.secret_value
+        input_artifact_data[input_name] = bts.ArtifactData(
+            total_size=len(secret_value.encode("utf-8")),
+            is_dir=False,
+            value=secret_value,
+            uri=None,
+            # This hash is not used, so we're using a dummy value here that makes it possible to identify the secret arguments in the following code.
+            hash=secret_hash,
+        )
+    session.rollback()
+
+    # Preparing the launcher input arguments
+
     input_arguments = {
         input_name: launcher_interfaces.InputArgument(
             total_size=artifact_data.total_size,
@@ -447,6 +475,7 @@
                 execution_id=container_execution_uuid,
                 input_name=input_name,
             ),
+            is_secret=(artifact_data.hash == secret_hash),
         )
         for input_name, artifact_data in input_artifact_data.items()
     }
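
Usage sketch for reviewers (not part of the patch). The names below are
illustrative: `my_component_spec` and the secret id "github-token" are made up,
and the exact shape of the HTTP routes depends on how `inject_session_dependency`
and `inject_user_name` wrap the service methods. Assuming the dataclasses in
component_structures.py are constructed with plain keyword arguments, a task
that consumes a secret stored via the new `POST /api/secrets/` route could be
specified roughly like this:

    # Hypothetical example; only SecretArgument / SecretReference come from this patch.
    from cloud_pipelines_backend import component_structures as structures

    task_spec = structures.TaskSpec(
        component_ref=structures.ComponentReference(spec=my_component_spec),
        arguments={
            # The argument references a stored secret by id instead of carrying a value.
            "token": structures.SecretArgument(
                secret=structures.SecretReference(id="github-token")
            ),
        },
    )

At run creation the secret reference is recorded under
EXECUTION_NODE_EXTRA_DATA_SECRET_REFERENCE_ARGUMENTS_KEY in the execution's
extra_data (no artifact node is linked for it); at launch time the orchestrator
resolves the reference against the Secret table for `pipeline_run.created_by`
and hands the value to the launcher with `is_secret=True`.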