def connect_catalog() -> Catalog:
    """Connect to the default Iceberg catalog.

    A fresh pyiceberg ``Config`` is built on every call and queried for the
    default catalog name and that catalog's properties, which are forwarded to
    ``load_catalog``. This exists because the default ``load_catalog`` only
    allows environment variables set before the first import of
    ``pyiceberg.catalog``.

    :return: the catalog returned by ``load_catalog`` for the default name
    """
    settings = IcebergCatalogConfig()
    default_name = settings.get_default_catalog_name()
    catalog_properties = settings.get_catalog_config(default_name)
    return load_catalog(default_name, **catalog_properties)  # type: ignore


def table_identifier(namespace: str, table_name: str) -> Identifier:
    """Return the ``(namespace, table_name)`` tuple that identifies a table.

    :param namespace: namespace the table lives in
    :param table_name: name of the table within that namespace
    """
    identifier = (namespace, table_name)
    return identifier
class IcebergIO(BaseIO):
    """Read/write arrow tables to/from Iceberg, handling table creation and schema evolution."""

    def __init__(self, catalog: Catalog) -> None:
        """Store the catalog that every operation is issued against.

        :param catalog: a connected pyiceberg catalog
        """
        self.catalog = catalog

    def ensure_namespace(self, namespace: str) -> None:
        """Create the namespace if it doesn't already exist.

        :param namespace: name of the namespace to check/create
        """
        if not self.catalog.namespace_exists(namespace):
            self.catalog.create_namespace(namespace)
            LOGGER.info("Created namespace '%s'", namespace)

    def read_property(self, table_id: Identifier, key: str) -> str:
        """Read a table property.

        :param table_id: namespaced identifier of the table to read from
        :param key: the key to read the value of
        :raises: KeyError if property does not exist
        """
        table = self.catalog.load_table(table_id)
        return table.properties[key]

    def write_table(
        self,
        table_id: Identifier,
        data: pa.Table,
        write_mode: WriteMode,
        *,
        merge_on: list[str] | None = None,
        partition: PartitionConfig | None = None,
        sort_order: SortOrderConfig | None = None,
        properties: dict[str, str] | None = None,
    ) -> None:
        """Write an Arrow table to an Iceberg table.

        The table is created if it does not exist; an existing table has its
        schema evolved to include any new columns found in ``data``. The data
        write and the property update happen inside a single transaction.

        :param table_id: namespaced identifier of the table to write to
        :param data: the new data to write to the table. Nothing is done when it has no rows
        :param write_mode: 'append' adds the data to the table,
                           'merge' adds new data and modifies existing rows,
                           'replace' completely overwrites the table with the new data
        :param merge_on: field(s) to determine if rows should be merged. Required if write_mode is 'merge'
        :param partition: mapping of column name to Iceberg transform. Only used when the table is created
        :param sort_order: mapping of column name to sort direction. Only used when the table is created
        :param properties: additional properties to set on the table upon completion. Useful for watermarking
        :raises ValueError: if ``write_mode`` is unsupported, or is 'merge' without ``merge_on``
        """
        if data.num_rows == 0:
            LOGGER.info("No data to write to %s, skipping.", table_id)
            return

        # Reject invalid requests before touching the catalog: _ensure_table may
        # create the table or evolve its schema, and neither should happen for a
        # write that can never succeed.
        if write_mode not in ("append", "merge", "replace"):
            raise ValueError(f"Unsupported write mode: {write_mode!r}")
        if write_mode == "merge" and merge_on is None:
            raise ValueError(
                f"Table '{table_id}': write mode 'merge' requires 'merge_on' property."
            )

        iceberg_table = self._ensure_table(table_id, data.schema, partition, sort_order)

        with iceberg_table.transaction() as txn:
            if write_mode == "append":
                txn.append(data)
            elif write_mode == "merge":
                txn.upsert(
                    df=data,
                    join_cols=merge_on,
                    when_matched_update_all=True,
                    when_not_matched_insert_all=True,
                    case_sensitive=True,
                )
            else:  # "replace" - the mode was validated above
                txn.overwrite(data, overwrite_filter=ALWAYS_TRUE, case_sensitive=True)

            if properties is not None:
                txn.set_properties(properties)

        LOGGER.debug("Wrote %s rows to %s (mode=%s)", data.num_rows, table_id, write_mode)

    # private
    def _ensure_table(
        self,
        table_id: Identifier,
        arrow_schema: pa.Schema,
        partition: PartitionConfig | None,
        sort_order: SortOrderConfig | None,
    ) -> IcebergTable:
        """Load an existing table or create a new one.

        For existing tables ensure the schema matches the incoming data.
        ``partition`` and ``sort_order`` are only consulted when the table is created.

        :param table_id: namespaced identifier of the table
        :param arrow_schema: schema of the incoming data
        :param partition: optional partition configuration for a new table
        :param sort_order: optional sort order configuration for a new table
        :return: the loaded or newly created table
        """
        if self.catalog.table_exists(table_id):
            return _ensure_table_schema(self.catalog.load_table(table_id), arrow_schema)

        iceberg_schema = create_schema(arrow_schema)
        LOGGER.debug("Created iceberg schema: %s", iceberg_schema)
        partition_spec = create_partition_spec(partition, iceberg_schema)
        LOGGER.debug("Created partition spec: %s", partition_spec)
        sort_order_spec = create_sort_order(sort_order, iceberg_schema)
        LOGGER.debug("Created sort order spec: %s", sort_order_spec)

        LOGGER.info("Creating table %s", table_id)
        return self.catalog.create_table(
            table_id,
            schema=iceberg_schema,
            partition_spec=partition_spec,
            sort_order=sort_order_spec,
        )


def _ensure_table_schema(iceberg_table: IcebergTable, new_schema: pa.Schema) -> IcebergTable:
    """Ensure the existing table schema matches the new schema.

    :param iceberg_table: the table whose schema may need to evolve
    :param new_schema: Arrow schema of the incoming data
    :return: the same table object that was passed in
    """
    # evolve_schema returns None when the incoming data has no new columns.
    evolved_schema = evolve_schema(iceberg_table.schema(), new_schema)
    if evolved_schema is not None:
        LOGGER.debug("Evolving schema. New schema: %s", evolved_schema)
        with iceberg_table.update_schema() as update:
            update.union_by_name(evolved_schema)

    return iceberg_table
def create_partition_spec(
    partition_config: PartitionConfig | None, iceberg_schema: Schema
) -> PartitionSpec:
    """Build an Iceberg partition spec from a column -> transform mapping.

    Each partition field is named ``<column>_<transform>``, where any bracketed
    argument of the transform (e.g. ``bucket[100]``) is left out of the name.
    Partition field ids are numbered from 1000 in mapping order.

    :param partition_config: mapping of column name to transform string, or None/empty
    :param iceberg_schema: schema used to resolve column names to field ids
    :return: the unpartitioned spec when no config is given, else the built spec
    """
    if not partition_config:
        return UNPARTITIONED_PARTITION_SPEC

    partition_fields = []
    for position, (column, transform_spec) in enumerate(partition_config.items()):
        source_field_id = iceberg_schema.find_field(column).field_id
        parsed_transform = transforms.parse_transform(transform_spec)
        # Keep only the text before "[" unless that would be empty.
        transform_label = transform_spec.partition("[")[0] or transform_spec
        partition_fields.append(
            PartitionField(
                source_id=source_field_id,
                field_id=1000 + position,  # the documentation does this...
                transform=parsed_transform,
                name=f"{column}_{transform_label}",
            )
        )

    return PartitionSpec(*partition_fields)
def arrow_type_to_iceberg(arrow_type: pa.DataType) -> PrimitiveType:
    """Returns the Iceberg type for the given pyarrow data type.

    The checks run in a fixed order and the first match wins. Time types must
    use microsecond precision and timestamps must not use nanosecond precision.

    :param arrow_type: the pyarrow type to convert
    :return: a new instance of the matching Iceberg primitive type
    :raises TypeError: If the type is unknown or is not supported
    """
    types = pa.types

    if types.is_boolean(arrow_type):
        return BooleanType()
    if types.is_int32(arrow_type):
        return IntegerType()
    if types.is_int64(arrow_type):
        return LongType()
    if types.is_float64(arrow_type):
        return DoubleType()
    if types.is_decimal(arrow_type):
        return DecimalType(arrow_type.precision, arrow_type.scale)
    if types.is_string(arrow_type) or types.is_large_string(arrow_type):
        return StringType()
    if types.is_date(arrow_type):
        return DateType()
    if types.is_time(arrow_type):
        if arrow_type.unit != "us":
            raise TypeError(
                f"Iceberg time type only supports 'us' precision. Requested precision='{arrow_type.unit}'."
            )
        return TimeType()
    if types.is_timestamp(arrow_type):
        if arrow_type.unit == "ns":
            raise TypeError("Iceberg v1 & v2 does not support timestamps in 'ns' precision.")
        # A timezone on the Arrow type selects the tz-aware Iceberg timestamp.
        return TimestamptzType() if arrow_type.tz is not None else TimestampType()
    if (
        types.is_binary(arrow_type)
        or types.is_large_binary(arrow_type)
        or types.is_fixed_size_binary(arrow_type)
    ):
        return BinaryType()

    raise TypeError(f"Pyarrow type '{arrow_type}' unknown to type mapper.")
def arrow_field_to_iceberg(column_id: int, arrow_field: pa.Field) -> NestedField:
    """Convert a pyarrow field into an Iceberg ``NestedField``.

    :param column_id: the Iceberg field id to assign
    :param arrow_field: the pyarrow field supplying name, type and nullability
    :return: a field that is required exactly when the Arrow field is not nullable
    :raises TypeError: if the Arrow type cannot be mapped to an Iceberg type
    """
    return NestedField(
        field_id=column_id,
        name=arrow_field.name,
        field_type=arrow_type_to_iceberg(arrow_field.type),
        required=not arrow_field.nullable,
    )


def create_schema(arrow_schema: pa.Schema, identifier_fields: Sequence[str] = ()) -> Schema:
    """Convert a pyarrow schema into an iceberg schema

    Field ids are assigned from 1 in the order of the Arrow fields.

    :param arrow_schema: A pyarrow schema.
    :param identifier_fields: An optional list of fields to mark as identifiers
    """
    iceberg_fields, identifier_field_ids = [], []
    for col_id, arrow_field in enumerate(arrow_schema, start=1):
        iceberg_fields.append(arrow_field_to_iceberg(col_id, arrow_field))
        if arrow_field.name in identifier_fields:
            identifier_field_ids.append(col_id)

    return Schema(*iceberg_fields, identifier_field_ids=identifier_field_ids)


def evolve_schema(iceberg_schema: Schema, new_arrow_schema: pa.Schema) -> Schema | None:
    """Attempt to evolve the schema to match the data.

    Columns of ``new_arrow_schema`` whose names are not in ``iceberg_schema``
    are appended after the existing fields, in Arrow order.

    :param iceberg_schema: the current schema of the table
    :param new_arrow_schema: the Arrow schema of the incoming data
    :return: the new schema if updates were applied, else None
    :raises TypeError: if a new column's Arrow type cannot be mapped
    """
    existing_columns = set(iceberg_schema.column_names)
    new_names = [name for name in new_arrow_schema.names if name not in existing_columns]
    if not new_names:
        return None

    # Number the new fields consecutively after the highest id already in use.
    # Counting existing fields is not safe: ids need not be contiguous, so a
    # count-based id could collide with an existing field.
    first_new_id = iceberg_schema.highest_field_id + 1
    return Schema(
        *iceberg_schema.fields,
        *(
            arrow_field_to_iceberg(column_id, new_arrow_schema.field(name))
            for column_id, name in enumerate(new_names, start=first_new_id)
        ),
    )
def create_sort_order(
    sort_order_config: SortOrderConfig | None, iceberg_schema: Schema
) -> SortOrder:
    """If a sort order hint is provided, create the appropriate SortOrder instance.

    Every sort field uses the identity transform on its source column.

    :param sort_order_config: mapping of column name to sort direction, or None/empty
    :param iceberg_schema: schema used to resolve column names to field ids
    :return: the unsorted order when no config is given, else the built order
    """
    if not sort_order_config:
        return UNSORTED_SORT_ORDER

    sort_fields = []
    for column, direction_value in sort_order_config.items():
        source_field_id = iceberg_schema.find_field(column).field_id
        sort_direction = SortDirection(direction_value)
        sort_fields.append(
            SortField(
                source_id=source_field_id,
                direction=sort_direction,
                transform=transforms.parse_transform("identity"),
            )
        )

    return SortOrder(*sort_fields)
class BaseIO(ABC):
    """Abstract interface for a destination that tables can be written to."""

    @abstractmethod
    def ensure_namespace(self, namespace: str) -> None:
        """Ensure the given namespace exists in the destination."""
        message = "Subclass should implement `ensure_namespace` to ensure the namespace exists."
        raise NotImplementedError(message)

    @abstractmethod
    def write_table(
        self,
        table_id: Identifier,
        data: "pa.Table",
        write_mode: "WriteMode",
        *,
        merge_on: list[str] | None = None,
        partition: "PartitionConfig | None" = None,
        sort_order: "SortOrderConfig | None" = None,
        properties: dict[str, str] | None = None,
    ) -> None:
        """Write ``data`` to the table identified by ``table_id``."""
        message = "Subclass should implement `write_table` to write a table to the destination."
        raise NotImplementedError(message)
@dc.dataclass(frozen=True)
class ELTJobManifest:
    """Parsed representation of an ELT job"""

    # Job name; combined with the domain to build the derived names below.
    name: str
    # Domain the job belongs to.
    domain: str
    # Directory holding the ingest job.
    ingest_job_dir: Path

    @property
    def full_name(self) -> str:
        """The dotted job name: ``{domain}.{name}``."""
        return ".".join((self.domain, self.name))

    @property
    def destination_namespace(self) -> str:
        """The destination namespace for this job: ``{domain}_{name}``."""
        return "_".join((self.domain, self.name))
@pytest.fixture(scope="session")
def sample_schema():
    """Create a sample Iceberg schema for testing.

    Columns, with field ids 1-4 in this order: ``id`` (long, required),
    ``name`` (string, required), ``date_col`` (date, optional) and
    ``ts_col`` (timestamptz, optional).
    """
    columns = (
        ("id", LongType(), True),
        ("name", StringType(), True),
        ("date_col", DateType(), False),
        ("ts_col", TimestamptzType(), False),
    )
    return Schema(
        *(
            NestedField(field_id=position, name=column_name, type=column_type, required=is_required)
            for position, (column_name, column_type, is_required) in enumerate(columns, start=1)
        )
    )
@pytest.fixture
def mock_config(mocker: MockerFixture):
    """Patch the catalog config class so it yields a mock with a 'default' catalog."""
    config = MagicMock()
    config.get_default_catalog_name.return_value = "default"
    config.get_catalog_config.return_value = {"warehouse": "/tmp/warehouse"}
    mocker.patch("elt_common.iceberg.catalog.IcebergCatalogConfig", return_value=config)

    return config


@pytest.fixture
def mock_load_catalog(mocker: MockerFixture):
    """Patch ``load_catalog`` as seen by the catalog module."""
    return mocker.patch("elt_common.iceberg.catalog.load_catalog")


def test_connect_catalog_loads_default_catalog(mock_config, mock_load_catalog):
    """The default catalog name is looked up and used to load the catalog."""
    connect_catalog()

    mock_config.get_default_catalog_name.assert_called_once()
    mock_config.get_catalog_config.assert_called_once_with("default")
    mock_load_catalog.assert_called_once_with("default", warehouse="/tmp/warehouse")


def test_connect_catalog_forwards_all_options_from_pyiceberg_catalog_config(
    mock_config, mock_load_catalog
):
    """Every key of the catalog config is passed through as a keyword argument."""
    catalog_config = {
        "warehouse": "/data/warehouse",
        "uri": "http://localhost:8181",
        "auth": "oauth2",
    }
    mock_config.get_catalog_config.return_value = catalog_config

    connect_catalog()

    mock_load_catalog.assert_called_once_with("default", **catalog_config)


def test_table_id_returns_tuple_identifier():
    """Test that table_identifier returns a tuple with namespace and table name."""
    assert table_identifier("my_namespace", "my_table") == ("my_namespace", "my_table")
@pytest.fixture(scope="session")
def sample_arrow_table():
    """Three-row Arrow table with ``id``, ``name`` and tz-aware ``ts`` columns."""
    timestamps = [
        dt.datetime.fromisoformat(stamp)
        for stamp in (
            "2024-01-15T10:00:00",
            "2024-02-20T11:00:00",
            "2024-03-25T12:00:00",
        )
    ]
    return pa.table(
        {
            "id": pa.array([1, 2, 3], type=pa.int64()),
            "name": pa.array(["a", "b", "c"], type=pa.string()),
            "ts": pa.array(timestamps, type=pa.timestamp("us", tz="UTC")),
        }
    )


@pytest.fixture
def mock_dependencies() -> MockedDependencies:
    """Mock catalog whose created/loaded table hands out one mock transaction."""
    catalog = MagicMock(spec=Catalog)
    table = MagicMock(spec=Table)
    transaction = MagicMock()
    table.transaction.return_value.__enter__.return_value = transaction
    catalog.create_table.return_value = table
    catalog.load_table.return_value = table

    return MockedDependencies(mock_catalog=catalog, mock_transaction=transaction)


def test_ensure_namespace_creates_when_missing(mock_dependencies: MockedDependencies):
    """Tests for IcebergIO.ensure_namespace creates namespace when missing"""
    catalog = mock_dependencies.mock_catalog
    catalog.namespace_exists.return_value = False

    IcebergIO(catalog).ensure_namespace("test_ns")

    catalog.create_namespace.assert_called_once_with("test_ns")


def test_ensure_namespace_noop_when_namespace_exists(mock_dependencies: MockedDependencies):
    """Tests for IcebergIO.ensure_namespace when namespace exists"""
    catalog = mock_dependencies.mock_catalog
    catalog.namespace_exists.return_value = True

    IcebergIO(catalog).ensure_namespace("test_ns")

    catalog.create_namespace.assert_not_called()


def test_read_property_loads_table_and_returns_property_if_exists(
    mock_dependencies: MockedDependencies,
):
    """read_property loads the table and returns the stored value."""
    catalog = mock_dependencies.mock_catalog
    catalog.load_table.return_value.properties = {"test.key": "value"}

    result = IcebergIO(catalog).read_property(("ns", "t"), "test.key")

    assert result == "value"
    catalog.load_table.assert_called_once_with(("ns", "t"))


def test_read_property_raises_KeyError_property_if_missing(
    mock_dependencies: MockedDependencies,
):
    """read_property raises KeyError when the key is not a table property."""
    catalog = mock_dependencies.mock_catalog
    catalog.load_table.return_value.properties = {}
    io = IcebergIO(catalog)

    with pytest.raises(KeyError):
        io.read_property(("ns", "t"), "test.key")

    catalog.load_table.assert_called_once_with(("ns", "t"))


def test_write_table_skips_empty_data(
    mock_dependencies: MockedDependencies, sample_arrow_table: pa.Table
):
    """Tests for IcebergIO.write_table skips empty data"""
    catalog = mock_dependencies.mock_catalog
    no_rows = sample_arrow_table.slice(0, 0)

    IcebergIO(catalog).write_table(("ns", "t"), no_rows, "append")

    catalog.load_table.assert_not_called()
    catalog.create_table.assert_not_called()


def test_write_table_append_creates_and_appends(
    mock_dependencies: MockedDependencies, sample_arrow_table: pa.Table
):
    """Tests for IcebergIO.write_table append mode"""
    catalog, transaction = mock_dependencies
    catalog.table_exists.return_value = False

    IcebergIO(catalog).write_table(("ns", "t"), sample_arrow_table, "append")

    catalog.load_table.assert_not_called()
    catalog.create_table.assert_called_once()
    transaction.append.assert_called_once_with(sample_arrow_table)


def test_write_table_merge_requires_merge_on(
    mock_dependencies: MockedDependencies, sample_arrow_table
):
    """Tests for IcebergIO.write_table merge mode requires merge_on"""
    catalog = mock_dependencies.mock_catalog
    catalog.table_exists.return_value = True
    io = IcebergIO(catalog)

    with pytest.raises(ValueError, match=r".*write mode 'merge' requires 'merge_on' property\."):
        io.write_table(("ns", "t"), sample_arrow_table, "merge")


def test_write_table_merge_calls_upsert(mock_dependencies: MockedDependencies, sample_arrow_table):
    """Tests for IcebergIO.write_table merge mode calls upsert"""
    catalog, transaction = mock_dependencies
    catalog.table_exists.return_value = True

    IcebergIO(catalog).write_table(("ns", "t"), sample_arrow_table, "merge", merge_on=["id"])

    transaction.upsert.assert_called_once_with(
        df=sample_arrow_table,
        join_cols=["id"],
        when_matched_update_all=True,
        when_not_matched_insert_all=True,
        case_sensitive=True,
    )


def test_write_table_sets_properties_if_supplied(
    mock_dependencies: MockedDependencies, sample_arrow_table: pa.Table
):
    """Properties passed to write_table are set on the same transaction as the write."""
    catalog, transaction = mock_dependencies
    catalog.table_exists.return_value = True
    properties = {"timestamp": "2024-01-15T10:00:00Z"}

    IcebergIO(catalog).write_table(
        ("ns", "t"), sample_arrow_table, "append", properties=properties
    )

    transaction.append.assert_called_once_with(sample_arrow_table)
    transaction.set_properties.assert_called_once_with(properties)


def test_write_table_replace_calls_overwrite(
    mock_dependencies: MockedDependencies, sample_arrow_table: pa.Table
):
    """Tests for IcebergIO.write_table replace mode"""
    catalog, transaction = mock_dependencies
    catalog.table_exists.return_value = True

    IcebergIO(catalog).write_table(("ns", "t"), sample_arrow_table, "replace")

    transaction.overwrite.assert_called_once_with(
        sample_arrow_table, overwrite_filter=ALWAYS_TRUE, case_sensitive=True
    )
def _single_field(spec):
    """Assert the spec has exactly one partition field and return it."""
    assert len(spec.fields) == 1
    return spec.fields[0]


def test_create_partition_spec_when_hint_none_or_empty(sample_schema):
    """Test that a None or empty partition hint returns the unpartitioned spec."""
    assert create_partition_spec(None, sample_schema) == UNPARTITIONED_PARTITION_SPEC
    assert create_partition_spec({}, sample_schema) == UNPARTITIONED_PARTITION_SPEC


def test_create_partition_spec_single_column_identity(sample_schema):
    """Test single column with identity transform."""
    field = _single_field(create_partition_spec({"name": "identity"}, sample_schema))

    assert field.name == "name_identity"
    assert field.source_id == 2  # id of 'name' field
    assert field.field_id == 1000
    assert isinstance(field.transform, transforms.IdentityTransform)


@pytest.mark.parametrize(
    "transform,expected_name,expected_transform_type",
    [
        ("year", "date_col_year", transforms.YearTransform),
        ("month", "date_col_month", transforms.MonthTransform),
        ("day", "date_col_day", transforms.DayTransform),
    ],
)
def test_create_partition_spec_single_column_date_transforms(
    sample_schema, transform, expected_name, expected_transform_type
):
    """Test single column with date transforms (year, month, day)."""
    field = _single_field(create_partition_spec({"date_col": transform}, sample_schema))

    assert field.name == expected_name
    assert field.source_id == 3  # id of 'date_col' field
    assert field.field_id == 1000
    assert isinstance(field.transform, expected_transform_type)


def test_create_partition_spec_single_column_hour(sample_schema):
    """Test single column with hour transform."""
    field = _single_field(create_partition_spec({"ts_col": "hour"}, sample_schema))

    assert field.name == "ts_col_hour"
    assert field.source_id == 4  # id of 'ts_col' field
    assert field.field_id == 1000
    assert isinstance(field.transform, transforms.HourTransform)


def test_create_partition_spec_truncate_transform(sample_schema):
    """Test transform with parameter like truncate[10]."""
    field = _single_field(create_partition_spec({"name": "truncate[10]"}, sample_schema))

    # Field name should use only the transform part before bracket
    assert field.name == "name_truncate"
    assert field.source_id == 2
    assert isinstance(field.transform, transforms.TruncateTransform)


def test_create_partition_spec_bucket_transform(sample_schema):
    """Test bucket transform with parameter."""
    field = _single_field(create_partition_spec({"id": "bucket[100]"}, sample_schema))

    assert field.name == "id_bucket"
    assert field.source_id == 1
    assert isinstance(field.transform, transforms.BucketTransform)


def test_create_partition_spec_multiple_columns(sample_schema):
    """Test multiple columns with different transforms."""
    partition_hint = {
        "date_col": "year",
        "name": "identity",
        "ts_col": "month",
    }
    result = create_partition_spec(partition_hint, sample_schema)

    assert len(result.fields) == 3

    # Check field IDs are sequential starting from 1000
    assert sorted(field.field_id for field in result.fields) == [1000, 1001, 1002]

    # Check field names and transform types together
    transform_by_name = {field.name: field.transform for field in result.fields}
    assert set(transform_by_name) == {"date_col_year", "name_identity", "ts_col_month"}
    assert isinstance(transform_by_name["date_col_year"], transforms.YearTransform)
    assert isinstance(transform_by_name["name_identity"], transforms.IdentityTransform)
    assert isinstance(transform_by_name["ts_col_month"], transforms.MonthTransform)


def test_create_partition_spec_multiple_columns_field_id_sequence(sample_schema):
    """Test that field IDs are assigned sequentially, in the order of the hint."""
    partition_hint = {
        "id": "identity",
        "name": "identity",
        "date_col": "year",
    }
    result = create_partition_spec(partition_hint, sample_schema)

    # Asserting the exact sequence covers uniqueness and the 1000 starting point,
    # and actually checks what the test name promises.
    assert [field.field_id for field in result.fields] == [1000, 1001, 1002]


def test_create_partition_spec_nonexistent_column_raises_error(sample_schema):
    """Test that referencing a nonexistent column raises an error."""
    with pytest.raises(ValueError):
        create_partition_spec({"nonexistent_col": "year"}, sample_schema)
@pytest.fixture()
def arrow_schema() -> pa.Schema:
    """Four-column Arrow schema: two required columns followed by two nullable ones."""
    columns = [
        pa.field("row_id", pa.int64(), nullable=False),
        pa.field("entry_name", pa.string(), nullable=False),
        pa.field("entry_timestamp", pa.timestamp(unit="us")),
        pa.field("entry_weight", pa.float64()),
    ]
    return pa.schema(columns)


def test_unsupported_arrow_type_raises():
    """An Arrow type without a mapping raises TypeError."""
    with pytest.raises(TypeError, match="unknown to type mapper"):
        arrow_type_to_iceberg(pa.string_view())


@pytest.mark.parametrize(
    ("arrow_type", "expected_type"),
    [
        (pa.bool_(), BooleanType),
        (pa.int32(), IntegerType),
        (pa.int64(), LongType),
        (pa.float64(), DoubleType),
        (pa.decimal128(20, 5), DecimalType),
        (pa.string(), StringType),
        (pa.large_string(), StringType),
        (pa.date32(), DateType),
        (pa.time64("us"), TimeType),
        (pa.timestamp("us"), TimestampType),
        (pa.timestamp("ms", tz="UTC"), TimestamptzType),
        (pa.binary(), BinaryType),
        (pa.large_binary(), BinaryType),
        (pa.binary(8), BinaryType),
    ],
)
def test_returns_expected_iceberg_type(arrow_type, expected_type):
    """Each supported Arrow type maps to the expected Iceberg type class."""
    assert isinstance(arrow_type_to_iceberg(arrow_type), expected_type)


def test_maps_decimal_precision_and_scale():
    """Decimal precision and scale carry over to the Iceberg type."""
    iceberg_type = arrow_type_to_iceberg(pa.decimal128(12, 3))

    assert isinstance(iceberg_type, DecimalType)
    assert (iceberg_type.precision, iceberg_type.scale) == (12, 3)


@pytest.mark.parametrize("time_type", [pa.time32("s"), pa.time32("ms"), pa.time64("ns")])
def test_time_precision_other_than_microseconds_raises(time_type):
    """Time types that are not microsecond precision are rejected."""
    with pytest.raises(TypeError, match="only supports 'us' precision"):
        arrow_type_to_iceberg(time_type)


def test_timestamp_nanoseconds_raises():
    """Nanosecond timestamps are rejected."""
    with pytest.raises(TypeError, match="does not support timestamps"):
        arrow_type_to_iceberg(pa.timestamp("ns"))


def test_create_empty_schema():
    """An empty Arrow schema produces an Iceberg schema with no fields."""
    iceberg_schema = create_schema(pa.schema([]))

    assert len(iceberg_schema.fields) == 0


@pytest.mark.parametrize("identifier_fields", [(), ["row_id", "entry_name"]])
def test_create_iceberg_schema(arrow_schema: pa.Schema, identifier_fields):
    """Names, order and nullability carry over; identifier ids follow field order."""
    iceberg_schema = create_schema(arrow_schema, identifier_fields)
    iceberg_fields = iceberg_schema.fields

    assert len(iceberg_fields) == len(arrow_schema.names)
    assert [field.name for field in iceberg_fields] == arrow_schema.names
    assert [not field.required for field in iceberg_fields] == [
        field.nullable for field in arrow_schema
    ]

    # assume the types are correct as the type mapping is tested above
    if identifier_fields:
        assert iceberg_schema.identifier_field_ids == [1, 2]


@pytest.mark.parametrize(
    ["iceberg_field_names", "expected_new_field_names"],
    [
        ([], {"row_id", "entry_name", "entry_timestamp", "entry_weight"}),
        (
            ["row_id", "entry_name", "entry_timestamp"],
            {"row_id", "entry_name", "entry_timestamp", "entry_weight"},
        ),
        (["row_id", "entry_name", "entry_timestamp", "entry_weight"], {}),
    ],
)
def test_evolve_schema(
    arrow_schema: pa.Schema, iceberg_field_names: list[str], expected_new_field_names
):
    """evolve_schema returns the full evolved field set, or None when nothing is new."""
    existing_schema = Schema(
        *(
            NestedField(field_id=field_id, name=name, field_type=StringType(), required=False)
            for field_id, name in enumerate(iceberg_field_names, start=1)
        )
    )

    schema_with_new_fields = evolve_schema(existing_schema, arrow_schema)

    if not expected_new_field_names:
        assert schema_with_new_fields is None
        return

    assert schema_with_new_fields is not None
    assert {field.name for field in schema_with_new_fields.fields} == expected_new_field_names
def test_create_sort_order_when_config_empty_or_none(sample_schema):
    """A None or empty config yields the unsorted sort order."""
    assert create_sort_order(None, sample_schema) == UNSORTED_SORT_ORDER
    assert create_sort_order({}, sample_schema) == UNSORTED_SORT_ORDER


@pytest.mark.parametrize(
    "sort_config,expected_source_id,expected_direction",
    [
        ({"id": "asc"}, 1, SortDirection.ASC),
        ({"name": "desc"}, 2, SortDirection.DESC),
    ],
)
def test_create_sort_order_single_column(
    sample_schema, sort_config, expected_source_id, expected_direction
):
    """A single-column config produces one sort field with matching id and direction."""
    result = create_sort_order(sort_config, sample_schema)

    assert len(result.fields) == 1
    field = result.fields[0]
    assert field.source_id == expected_source_id
    assert field.direction == expected_direction


def test_create_sort_order_multiple_columns(sample_schema):
    """Every configured column appears as a sort field."""
    sort_config = {"id": "asc", "name": "desc", "date_col": "asc"}
    result = create_sort_order(sort_config, sample_schema)

    assert len(result.fields) == 3
    # Check that all columns are present
    source_ids = {field.source_id for field in result.fields}
    assert source_ids == {1, 2, 3}  # id, name, date_col


@pytest.mark.parametrize(
    "sort_config",
    [
        {"nonexistent_col": "asc"},  # invalid column name
        {"id": "invalid"},  # invalid direction value
    ],
)
def test_create_sort_order_invalid_config(sample_schema, sort_config):
    """Test that invalid column names and directions raise ValueError."""
    # ValueError rather than Exception, so an unrelated failure such as a
    # TypeError or AttributeError cannot make this test pass by accident.
    with pytest.raises(ValueError):
        create_sort_order(sort_config, sample_schema)