Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions cloud_pipelines_backend/api_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,30 @@ def get_current_user(
permissions=permissions,
)

### Secrets routes
secrets_service = api_server_sql.SecretsApiService()

router.get("/api/secrets/", tags=["secrets"], **default_config)(
inject_session_dependency(
inject_user_name(secrets_service.list_secrets, parameter_name="user_id")
)
)
router.post("/api/secrets/", tags=["secrets"], **default_config)(
inject_session_dependency(
inject_user_name(secrets_service.create_secret, parameter_name="user_id")
)
)
router.put("/api/secrets/{secret_id}", tags=["secrets"], **default_config)(
inject_session_dependency(
inject_user_name(secrets_service.update_secret, parameter_name="user_id")
)
)
router.delete("/api/secrets/{secret_id}", tags=["secrets"], **default_config)(
inject_session_dependency(
inject_user_name(secrets_service.delete_secret, parameter_name="user_id")
)
)

### Component library routes

component_service = components_api.ComponentService()
Expand Down
151 changes: 141 additions & 10 deletions cloud_pipelines_backend/api_server_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,6 +994,116 @@ def get_signed_artifact_url(
return GetArtifactSignedUrlResponse(signed_url=signed_url)


# === Secrets Service
@dataclasses.dataclass(kw_only=True)
class SecretInfoResponse:
secret_id: str


@dataclasses.dataclass(kw_only=True)
class ListSecretsResponse:
secrets: list[SecretInfoResponse]


class SecretsApiService:

def create_secret(
self,
*,
session: orm.Session,
user_id: str,
secret_id: str,
secret_value: str,
):
return self._set_secret_value(
Comment on lines +1017 to +1018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
):
return self._set_secret_value(
) -> None:
self._set_secret_value(

_set_secret_value doesn't return anything, so a return is not necessary here?

session=session,
user_id=user_id,
secret_id=secret_id,
secret_value=secret_value,
raise_if_exists=True,
)

def update_secret(
self,
*,
session: orm.Session,
user_id: str,
secret_id: str,
secret_value: str,
):
return self._set_secret_value(
Comment on lines +1033 to +1034

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
):
return self._set_secret_value(
) -> None:
self._set_secret_value(

session=session,
user_id=user_id,
secret_id=secret_id,
secret_value=secret_value,
raise_if_not_exists=True,
)

def _set_secret_value(
self,
*,
session: orm.Session,
user_id: str,
secret_id: str,
secret_value: str,
raise_if_not_exists: bool = False,
raise_if_exists: bool = False,
):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
):
) -> None:

current_time = _get_current_time()
secret = session.get(bts.Secret, (user_id, secret_id))
if secret:
if not raise_if_exists:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if not raise_if_exists:
if raise_if_exists:

I believe you don't want the not here?

raise errors.ItemAlreadyExistsError(
f"Secret with id '{secret_id}' already exists."
)
secret.secret_value = secret_value
secret.updated_at = current_time
else:
if raise_if_not_exists:
raise errors.ItemNotFoundError(
f"Secret with id '{secret_id}' does not exist."
)
secret = bts.Secret(
user_id=user_id,
secret_id=secret_id,
secret_value=secret_value,
Comment on lines +1068 to +1069

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious is there any validation needed for secret id/value? For example, what if they were empty strings?

created_at=current_time,
updated_at=current_time,
)
session.add(secret)
session.commit()

def delete_secret(
self,
*,
session: orm.Session,
user_id: str,
secret_id: str,
):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
):
) -> None:

secret = session.get(bts.Secret, (user_id, secret_id))
if not secret:
raise errors.ItemNotFoundError(
f"Secret with id '{secret_id}' does not exist."
)
session.delete(secret)
session.commit()

def list_secrets(
self,
*,
session: orm.Session,
user_id: str,
) -> ListSecretsResponse:
secrets = session.scalars(
sql.select(bts.Secret).where(bts.Secret.user_id == user_id)
).all()
return ListSecretsResponse(
secrets=[
SecretInfoResponse(secret_id=secret.secret_id) for secret in secrets
]
)


# ============

# Idea for how to add deep nested graph:
Expand All @@ -1005,11 +1115,14 @@ def get_signed_artifact_url(
# No. Decided to first do topological sort and then 1-stage generation.


_ArtifactNodeOrSecretType = typing.Union[bts.ArtifactNode, structures.SecretReference]


def _recursively_create_all_executions_and_artifacts_root(
session: orm.Session,
root_task_spec: structures.TaskSpec,
) -> bts.ExecutionNode:
input_artifact_nodes: dict[str, bts.ArtifactNode] = {}
input_artifact_nodes: dict[str, _ArtifactNodeOrSecretType] = {}

root_component_spec = root_task_spec.component_ref.spec
if not root_component_spec:
Expand All @@ -1035,12 +1148,8 @@ def _recursively_create_all_executions_and_artifacts_root(
raise ApiServiceError(
f"root task arguments can only be constants, but got {input_name}={input_argument}. {root_task_spec=}"
)
elif not isinstance(input_argument, str):
raise ApiServiceError(
f"root task constant argument must be a string, but got {input_name}={input_argument}. {root_task_spec=}"
)
# TODO: Support constant input artifacts (artifact IDs)
if input_argument is not None:
elif isinstance(input_argument, str):
input_artifact_nodes[input_name] = (
# _construct_constant_artifact_node_and_add_to_session(
# session=session, value=input_argument, artifact_type=input_spec.type
Expand All @@ -1052,6 +1161,12 @@ def _recursively_create_all_executions_and_artifacts_root(
# This constant artifact won't be added to the DB
# TODO: Actually, they will be added...
# We don't need to link this input artifact here. It will be handled downstream.
elif isinstance(input_argument, structures.SecretArgument):
input_artifact_nodes[input_name] = input_argument.secret
else:
raise ApiServiceError(
f"root task constant argument must be a string, but got {input_name}={input_argument}. {root_task_spec=}"
)

root_execution_node = _recursively_create_all_executions_and_artifacts(
session=session,
Expand All @@ -1065,7 +1180,7 @@ def _recursively_create_all_executions_and_artifacts_root(
def _recursively_create_all_executions_and_artifacts(
session: orm.Session,
root_task_spec: structures.TaskSpec,
input_artifact_nodes: dict[str, bts.ArtifactNode],
input_artifact_nodes: dict[str, _ArtifactNodeOrSecretType],
ancestors: list[bts.ExecutionNode],
) -> bts.ExecutionNode:
root_component_spec = root_task_spec.component_ref.spec
Expand Down Expand Up @@ -1098,6 +1213,18 @@ def _recursively_create_all_executions_and_artifacts(
input_artifact_nodes = dict(input_artifact_nodes)
for input_spec in root_component_spec.inputs or []:
input_artifact_node = input_artifact_nodes.get(input_spec.name)
if isinstance(input_artifact_node, structures.SecretReference):
# We don't use these secret arguments, but adding them just in case.
extra_data = root_execution_node.extra_data or {}
secret_reference_arguments = extra_data.setdefault(
bts.EXECUTION_NODE_EXTRA_DATA_SECRET_REFERENCE_ARGUMENTS_KEY, {}
)
secret_reference_arguments[input_spec.name] = (
input_artifact_node.to_json_dict()
)
root_execution_node.extra_data = extra_data
# Not adding any artifact link for secret inputs
continue
if input_artifact_node is None and not input_spec.optional:
if input_spec.default:
input_artifact_node = (
Expand Down Expand Up @@ -1163,7 +1290,8 @@ def _recursively_create_all_executions_and_artifacts(
root_execution_node.container_execution_status = (
bts.ContainerExecutionStatus.QUEUED
if all(
artifact_node.artifact_data
not isinstance(artifact_node, bts.ArtifactNode)
or artifact_node.artifact_data
for artifact_node in input_artifact_nodes.values()
)
else bts.ContainerExecutionStatus.WAITING_FOR_UPSTREAM
Expand All @@ -1190,10 +1318,10 @@ def _recursively_create_all_executions_and_artifacts(
raise ApiServiceError(
f"child_task_spec.component_ref.spec is empty. {child_task_spec=}"
)
child_task_input_artifact_nodes: dict[str, bts.ArtifactNode] = {}
child_task_input_artifact_nodes: dict[str, _ArtifactNodeOrSecretType] = {}
for input_spec in child_component_spec.inputs or []:
input_argument = (child_task_spec.arguments or {}).get(input_spec.name)
input_artifact_node: bts.ArtifactNode | None = None
input_artifact_node: _ArtifactNodeOrSecretType | None = None
if input_argument is None and not input_spec.optional:
# Not failing on unconnected required input if there is a default value
if input_spec.default is None:
Expand Down Expand Up @@ -1233,6 +1361,9 @@ def _recursively_create_all_executions_and_artifacts(
# artifact_type=input_spec.type,
# )
# )
elif isinstance(input_argument, structures.SecretArgument):
# We'll deal with secrets when launching the container.
input_artifact_node = input_argument.secret
else:
raise ApiServiceError(
f"Unexpected task argument: {input_spec.name}={input_argument}. {child_task_spec=}"
Expand Down
11 changes: 11 additions & 0 deletions cloud_pipelines_backend/backend_types_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,7 @@ class ExecutionNode(_TableBase):
EXECUTION_NODE_EXTRA_DATA_ORCHESTRATION_ERROR_MESSAGE_KEY = (
"orchestration_error_message"
)
EXECUTION_NODE_EXTRA_DATA_SECRET_REFERENCE_ARGUMENTS_KEY = "secret_reference_arguments"
CONTAINER_EXECUTION_EXTRA_DATA_ORCHESTRATION_ERROR_MESSAGE_KEY = (
"orchestration_error_message"
)
Expand Down Expand Up @@ -476,3 +477,13 @@ class PipelineRunAnnotation(_TableBase):
pipeline_run: orm.Mapped[PipelineRun] = orm.relationship(repr=False, init=False)
key: orm.Mapped[str] = orm.mapped_column(default=None, primary_key=True)
value: orm.Mapped[str | None] = orm.mapped_column(default=None)


class Secret(_TableBase):
__tablename__ = "secret"
user_id: orm.Mapped[str] = orm.mapped_column(primary_key=True, index=True)
secret_id: orm.Mapped[str] = orm.mapped_column(primary_key=True)
secret_value: orm.Mapped[str]
created_at: orm.Mapped[datetime.datetime | None] = orm.mapped_column(default=None)
updated_at: orm.Mapped[datetime.datetime | None] = orm.mapped_column(default=None)
extra_data: orm.Mapped[dict[str, Any] | None] = orm.mapped_column(default=None)
18 changes: 17 additions & 1 deletion cloud_pipelines_backend/component_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,23 @@ class TaskOutputArgument(_BaseModel): # Has additional constructor for convenie
task_output: TaskOutputReference


ArgumentType = Union[PrimitiveTypes, GraphInputArgument, TaskOutputArgument]
@dataclasses.dataclass
class SecretReference(_BaseModel):
"""References a secret"""

id: str


@dataclasses.dataclass
class SecretArgument(_BaseModel):
"""Argument that references a secret"""

secret: SecretReference


ArgumentType = Union[
PrimitiveTypes, GraphInputArgument, TaskOutputArgument, SecretArgument
]


@dataclasses.dataclass
Expand Down
1 change: 1 addition & 0 deletions cloud_pipelines_backend/launchers/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class InputArgument:
value: str | None = None
uri: str | None = None
staging_uri: str
is_secret: bool = False


class ContainerTaskLauncher(typing.Generic[_TLaunchedContainer], abc.ABC):
Expand Down
29 changes: 29 additions & 0 deletions cloud_pipelines_backend/orchestrator_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,34 @@ def generate_execution_log_uri(
for output_spec in component_spec.outputs or []
}

# Handling secrets.
# We read secrets from execution_node.extra_data rather than from task_spec.arguments,
# because some secrets might have been passed from upstream graph inputs.
secret_reference_arguments = (execution.extra_data or {}).get(
bts.EXECUTION_NODE_EXTRA_DATA_SECRET_REFERENCE_ARGUMENTS_KEY, {}
)
secret_hash = "<DUMMY_HASH_FOR_SECRET>"
for input_name, secret_reference_dict in secret_reference_arguments.items():
user_id = pipeline_run.created_by
secret_id = secret_reference_dict["id"]
secret = session.get(bts.Secret, (user_id, secret_id))
if not secret:
raise OrchestratorError(
f"{execution.id=}: User error: Error resolving a secret argument for {input_name=}: User {user_id} does not have secret {secret_id}."
)
secret_value = secret.secret_value
input_artifact_data[input_name] = bts.ArtifactData(
total_size=len(secret_value.encode("utf-8")),
is_dir=False,
value=secret_value,
uri=None,
# This hash is not used, so we're using a dummy value here that makes it possible to identify the secret arguments in the following code.
hash=secret_hash,
)
session.rollback()

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious why is there a rollback here? I didn't notice any session changes.


# Preparing the launcher input arguments

input_arguments = {
input_name: launcher_interfaces.InputArgument(
total_size=artifact_data.total_size,
Expand All @@ -447,6 +475,7 @@ def generate_execution_log_uri(
execution_id=container_execution_uuid,
input_name=input_name,
),
is_secret=(artifact_data.hash == secret_hash),
)
for input_name, artifact_data in input_artifact_data.items()
}
Expand Down