Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [0.50.0] - 2026-03-06

### Added

- `tilebox-datasets`: Added dataset-level `find` and `query` methods on both sync and async `DatasetClient` to query
across multiple collections.


## [0.49.0] - 2026-02-19

### Added
Expand Down Expand Up @@ -333,7 +341,8 @@ the first client that does not cache data (since it's already on the local file
- Released under the [MIT](https://opensource.org/license/mit) license.
- Released packages: `tilebox-datasets`, `tilebox-workflows`, `tilebox-storage`, `tilebox-grpc`

[Unreleased]: https://github.com/tilebox/tilebox-python/compare/v0.49.0...HEAD
[Unreleased]: https://github.com/tilebox/tilebox-python/compare/v0.50.0...HEAD
[0.50.0]: https://github.com/tilebox/tilebox-python/compare/v0.49.0...v0.50.0
[0.49.0]: https://github.com/tilebox/tilebox-python/compare/v0.48.0...v0.49.0
[0.48.0]: https://github.com/tilebox/tilebox-python/compare/v0.47.0...v0.48.0
[0.47.0]: https://github.com/tilebox/tilebox-python/compare/v0.46.0...v0.47.0
Expand Down
62 changes: 62 additions & 0 deletions tilebox-datasets/tests/test_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,68 @@ def test_timeseries_dataset_collection_find_not_found() -> None:
mocked.collection.find("14eb91a2-a42f-421f-9397-1dab577f05a9")


@settings(max_examples=1)
@given(example_datapoints(generated_fields=True, missing_fields=True))
def test_timeseries_dataset_find_multiple_collections(expected_datapoint: ExampleDatapoint) -> None:
    """Test that DatasetClient.find() supports querying by mixed collection reference types."""
    dataset, service = _mocked_dataset()

    # Two collections exist in the dataset; only the first is selected (by name) below.
    first = CollectionInfo(Collection(uuid4(), "named-collection"), None, None)
    second = CollectionInfo(Collection(uuid4(), "other-collection"), None, None)
    service.get_collections.return_value = Promise.resolve([first, second])

    wire_message = AnyMessage(example_dataset_type_url(), expected_datapoint.SerializeToString())
    service.query_by_id.return_value = Promise.resolve(wire_message)

    wanted_id = uuid_message_to_uuid(expected_datapoint.id)
    result = dataset.find(wanted_id, [first.collection.name])

    assert isinstance(result, xr.Dataset)
    # Collection names must be resolved against the dataset's collection listing.
    service.get_collections.assert_called_once_with(dataset._dataset.id, True, True)
    # Only the id of the selected collection is forwarded to the backend.
    service.query_by_id.assert_called_once_with(
        dataset._dataset.id,
        [first.collection.id],
        wanted_id,
        False,
    )


@settings(max_examples=1)
@given(pages=paginated_query_results())
def test_timeseries_dataset_query_multiple_collections(pages: list[QueryResultPage]) -> None:
    """Test that DatasetClient.query() forwards all selected collection ids to the backend query endpoint."""
    dataset, service = _mocked_dataset()

    first = CollectionInfo(Collection(uuid4(), "named-collection"), None, None)
    second = CollectionInfo(Collection(uuid4(), "other-collection"), None, None)
    service.get_collections.return_value = Promise.resolve([first, second])

    # Each backend call pops the next pre-generated page.
    service.query.side_effect = [Promise.resolve(page) for page in pages]

    interval = TimeInterval(datetime.now(), datetime.now() + timedelta(days=1))
    queried = dataset.query(collections=[first.collection.name], temporal_extent=interval)

    _assert_datapoints_match(queried, pages)
    # Collection names must be resolved against the dataset's collection listing.
    service.get_collections.assert_called_once_with(dataset._dataset.id, True, True)
    # Verify the first backend call targeted the dataset and exactly the selected collection.
    dataset_id_arg, collection_ids_arg = service.query.call_args_list[0][0][:2]
    assert dataset_id_arg == dataset._dataset.id
    assert collection_ids_arg == [first.collection.id]


@patch("tilebox.datasets.sync.pagination.tqdm")
@patch("tilebox.datasets.progress.tqdm")
@settings(deadline=1000, max_examples=3) # increase deadline to 1s to not timeout because of the progress bar
Expand Down
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
214 changes: 188 additions & 26 deletions tilebox-datasets/tilebox/datasets/aio/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from _tilebox.grpc.aio.producer_consumer import async_producer_consumer
from _tilebox.grpc.error import ArgumentError, NotFoundError
from tilebox.datasets.aio.pagination import with_progressbar, with_time_progress_callback, with_time_progressbar
from tilebox.datasets.data.collection import CollectionInfo
from tilebox.datasets.data.collection import Collection, CollectionInfo
from tilebox.datasets.data.data_access import QueryFilters, SpatialFilter, SpatialFilterLike
from tilebox.datasets.data.datapoint import QueryResultPage
from tilebox.datasets.data.datasets import Dataset
Expand Down Expand Up @@ -139,6 +139,122 @@ async def delete_collection(self, collection: "str | UUID | CollectionClient") -

await self._service.delete_collection(self._dataset.id, collection_id)

async def find(
    self,
    datapoint_id: str | UUID,
    collections: "list[str] | list[UUID] | list[Collection] | list[CollectionInfo] | list[CollectionClient] | None" = None,
    skip_data: bool = False,
) -> xr.Dataset:
    """
    Find a specific datapoint in one of the specified collections by its id.

    Args:
        datapoint_id: The id of the datapoint to find.
        collections: The collections to search in. Supports collection names, ids or collection objects.
            If not specified, all collections in the dataset are searched.
        skip_data: Whether to skip the actual data of the datapoint. If True, only
            datapoint metadata is returned.

    Returns:
        The datapoint as an xarray dataset.

    Raises:
        ValueError: If the datapoint id is not a valid UUID.
        NotFoundError: If no datapoint with the given id exists in the searched collections.
    """
    resolved_ids = await self._collection_ids(collections)
    try:
        raw = await self._service.query_by_id(
            self._dataset.id,
            resolved_ids,
            as_uuid(datapoint_id),
            skip_data,
        )
    except ArgumentError:
        raise ValueError(f"Invalid datapoint id: {datapoint_id} is not a valid UUID") from None
    except NotFoundError:
        raise NotFoundError(f"No such datapoint {datapoint_id}") from None

    # Decode the protobuf Any payload into its concrete message type.
    payload = get_message_type(raw.type_url).FromString(raw.value)

    # Convert the single message to an xarray dataset and drop the time dimension
    # the converter introduces, so callers get a scalar datapoint.
    converter = MessageToXarrayConverter(initial_capacity=1)
    converter.convert(payload)
    return converter.finalize("time", skip_empty_fields=skip_data).isel(time=0)

async def query(
    self,
    *,
    collections: "list[str] | list[UUID] | list[Collection] | list[CollectionInfo] | list[CollectionClient] | dict[str, CollectionClient] | None" = None,
    temporal_extent: TimeIntervalLike,
    spatial_extent: SpatialFilterLike | None = None,
    skip_data: bool = False,
    show_progress: bool | ProgressCallback = False,
) -> xr.Dataset:
    """
    Query datapoints in the specified collections and temporal extent.

    Args:
        collections: The collections to query in. Supports collection names, ids or collection objects.
            If not specified, all collections in the dataset are queried.
        temporal_extent: The temporal extent to query data for. (Required)
        spatial_extent: The spatial extent to query data in. (Optional)
        skip_data: Whether to skip the actual data of the datapoint. If True, only
            datapoint metadata is returned.
        show_progress: Whether to show a progress bar while loading the data.
            If a callable is specified it is used as callback to report progress percentages.

    Returns:
        Matching datapoints in the given temporal and spatial extent as an xarray dataset.

    Raises:
        ValueError: If no temporal_extent is specified, or a referenced collection is not
            part of this dataset.
    """
    if temporal_extent is None:
        raise ValueError("A temporal_extent for your query must be specified")

    # collections now defaults to None ("query all collections"), matching the docstring
    # and the default of find() on this class.
    collection_ids = await self._collection_ids(collections)
    pages = _iter_query_pages(
        self._service,
        self._dataset.id,
        collection_ids,
        temporal_extent,
        spatial_extent,
        skip_data,
        dataset_name=self.name,
        show_progress=show_progress,
    )
    return await _convert_to_dataset(pages, skip_empty_fields=skip_data)

async def _collection_id(self, collection: "UUID | Collection | CollectionInfo | CollectionClient") -> UUID:
    """Resolve a single collection reference (client, info, collection or raw UUID) to its collection id."""
    # Check the wrapper types in order from most to least specific; a plain UUID falls through.
    for reference_type, extract_id in (
        (CollectionClient, lambda c: c._collection.id),
        (CollectionInfo, lambda c: c.collection.id),
        (Collection, lambda c: c.id),
    ):
        if isinstance(collection, reference_type):
            return extract_id(collection)
    return collection

async def _collection_ids(
    self,
    collections: "list[str] | list[UUID] | list[Collection] | list[CollectionInfo] | list[CollectionClient] | dict[str, CollectionClient] | None",
) -> list[UUID]:
    """Resolve a mixed list of collection references to ids, validating each against this dataset's collections."""
    if collections is None:
        # no explicit selection: an empty id list queries across all collections
        return []

    infos: list[CollectionInfo] = await self._service.get_collections(self._dataset.id, True, True)
    # lookup tables for the valid collection names and ids of this dataset
    name_to_id = {info.collection.name: info.collection.id for info in infos}
    known_ids = {info.collection.id for info in infos}

    resolved: list[UUID] = []
    for reference in collections:
        if isinstance(reference, str):
            if reference not in name_to_id:
                raise ValueError(f"Collection {reference} not found in dataset {self.name}")
            resolved.append(name_to_id[reference])
            continue
        reference_id = await self._collection_id(reference)
        if reference_id not in known_ids:
            raise ValueError(f"Collection {reference_id} is not part of the dataset {self.name}")
        resolved.append(reference_id)

    return resolved

def __repr__(self) -> str:
    """Human-readable summary of this timeseries dataset client."""
    summary = self._dataset.summary
    return f"{self.name} [Timeseries Dataset]: {summary}"

Expand Down Expand Up @@ -221,7 +337,7 @@ async def find(self, datapoint_id: str | UUID, skip_data: bool = False) -> xr.Da
"""
try:
datapoint = await self._dataset._service.query_by_id(
[self._collection.id], as_uuid(datapoint_id), skip_data
self._dataset._dataset.id, [self._collection.id], as_uuid(datapoint_id), skip_data
)
except ArgumentError:
raise ValueError(f"Invalid datapoint id: {datapoint_id} is not a valid UUID") from None
Expand Down Expand Up @@ -259,8 +375,14 @@ async def _find_interval(
filters = QueryFilters(temporal_extent=IDInterval.parse(datapoint_id_interval, end_inclusive=end_inclusive))

async def request(page: PaginationProtocol) -> QueryResultPage:
query_page = Pagination(page.limit, page.starting_after)
return await self._dataset._service.query([self._collection.id], filters, skip_data, query_page)
return await _query_page(
self._dataset._service,
self._dataset._dataset.id,
[self._collection.id],
filters,
skip_data,
page,
)

initial_page = Pagination()
pages = paginated_request(request, initial_page)
Expand Down Expand Up @@ -350,7 +472,16 @@ async def query(
if temporal_extent is None:
raise ValueError("A temporal_extent for your query must be specified")

pages = self._iter_pages(temporal_extent, spatial_extent, skip_data, show_progress=show_progress)
pages = _iter_query_pages(
self._dataset._service,
self._dataset._dataset.id,
[self._collection.id],
temporal_extent,
spatial_extent,
skip_data,
dataset_name=self._dataset.name,
show_progress=show_progress,
)
return await _convert_to_dataset(pages, skip_empty_fields=skip_data)

async def _iter_pages(
Expand All @@ -361,29 +492,19 @@ async def _iter_pages(
show_progress: bool | ProgressCallback = False,
page_size: int | None = None,
) -> AsyncIterator[QueryResultPage]:
time_interval = TimeInterval.parse(temporal_extent)
filters = QueryFilters(time_interval, SpatialFilter.parse(spatial_extent) if spatial_extent else None)

request = partial(self._query_page, filters, skip_data)

initial_page = Pagination(limit=page_size)
pages = paginated_request(request, initial_page)

if callable(show_progress):
pages = with_time_progress_callback(pages, time_interval, show_progress)
elif show_progress:
message = f"Fetching {self._dataset.name}"
pages = with_time_progressbar(pages, time_interval, message)

async for page in pages:
async for page in _iter_query_pages(
self._dataset._service,
self._dataset._dataset.id,
[self._collection.id],
temporal_extent,
spatial_extent,
skip_data,
dataset_name=self._dataset.name,
show_progress=show_progress,
page_size=page_size,
):
yield page

async def _query_page(
self, filters: QueryFilters, skip_data: bool, page: PaginationProtocol | None = None
) -> QueryResultPage:
query_page = Pagination(page.limit, page.starting_after) if page else Pagination()
return await self._dataset._service.query([self._collection.id], filters, skip_data, query_page)

async def ingest(
self,
data: IngestionData,
Expand Down Expand Up @@ -477,6 +598,47 @@ async def delete(self, datapoints: DatapointIDs, *, show_progress: bool | Progre
return num_deleted


async def _query_page(  # noqa: PLR0913
    service: TileboxDatasetService,
    dataset_id: UUID,
    collection_ids: list[UUID] | None,
    filters: QueryFilters,
    skip_data: bool,
    page: PaginationProtocol | None = None,
) -> QueryResultPage:
    """Fetch a single result page from the backend for the given query filters and pagination state."""
    if page is None:
        pagination = Pagination()
    else:
        pagination = Pagination(page.limit, page.starting_after)
    # None collection ids are normalized to an empty list before hitting the service.
    ids = collection_ids if collection_ids else []
    return await service.query(dataset_id, ids, filters, skip_data, pagination)


async def _iter_query_pages(  # noqa: PLR0913
    service: TileboxDatasetService,
    dataset_id: UUID,
    collection_ids: list[UUID] | None,
    temporal_extent: TimeIntervalLike,
    spatial_extent: SpatialFilterLike | None = None,
    skip_data: bool = False,
    *,
    dataset_name: str,
    show_progress: bool | ProgressCallback = False,
    page_size: int | None = None,
) -> AsyncIterator[QueryResultPage]:
    """Lazily yield all result pages for a temporal/spatial query, optionally reporting progress."""
    interval = TimeInterval.parse(temporal_extent)
    spatial = SpatialFilter.parse(spatial_extent) if spatial_extent else None
    filters = QueryFilters(interval, spatial)

    async def fetch(page: PaginationProtocol) -> QueryResultPage:
        return await _query_page(service, dataset_id, collection_ids, filters, skip_data, page)

    pages = paginated_request(fetch, Pagination(limit=page_size))

    # Wire up progress reporting: a callable receives percentage callbacks, True shows a progress bar.
    if callable(show_progress):
        pages = with_time_progress_callback(pages, interval, show_progress)
    elif show_progress:
        pages = with_time_progressbar(pages, interval, f"Fetching {dataset_name}")

    async for result_page in pages:
        yield result_page


async def _convert_to_dataset(pages: AsyncIterator[QueryResultPage], skip_empty_fields: bool = False) -> xr.Dataset:
"""
Convert an async iterator of QueryResultPages into a single xarray Dataset
Expand Down
Loading
Loading