diff --git a/src/imas_streams/cli.py b/src/imas_streams/cli.py index f763373..608066a 100644 --- a/src/imas_streams/cli.py +++ b/src/imas_streams/cli.py @@ -3,8 +3,11 @@ import click import imas +from imas.ids_defs import CLOSEST_INTERP, IDS_TIME_MODE_HOMOGENEOUS -from imas_streams import BatchedIDSConsumer +from imas_streams import BatchedIDSConsumer, StreamingIDSProducer + +_PROGRESS_BAR_UPDATE_MINSTEP = 1001 @click.group(invoke_without_command=True, no_args_is_help=True) @@ -22,6 +25,116 @@ def main() -> None: ) +@main.command() +@click.argument("imas_uri") +@click.argument("kafka_host") +@click.argument("kafka_topic") +@click.option( + "--get", + is_flag=True, + help="Get full IDS instead of iteratively requesting a time slice with get_slice.", +) +@click.option("-n", default=0, help="Maximum number of time slices to stream") +def imasentry_to_kafka( + imas_uri: str, kafka_host: str, kafka_topic: str, get: bool, n: int +) -> None: + """Stream data from an existing IMAS data entry to a Kafka topic. + + The input data must be limited to dynamic floating point data, and array shapes must + remain constant for all time slices. An error will be displayed when this is not + adhered to. + + \b + Arguments: + IMAS_URI IMAS URI (including IDS and optionally occurrence) with the data to + be streamed. For example: "imas:hdf5?path=./testdata#magnetics". + KAFKA_HOST Kafka host and port (aka bootstrap.servers). E.g. 'localhost:9092'. + KAFKA_TOPIC Name of the kafka topic to stream the data to. + """ + # Local import: kafka is an optional dependency + from imas_streams.kafka import KafkaProducer, KafkaSettings + + # Extract IDS/occurrence + base_uri, _, ids_and_occurrence = imas_uri.partition("#") + idsname, _, occurrence = ids_and_occurrence.partition(":") + if not idsname: + raise click.UsageError( + f"Invalid IMAS URI '{imas_uri}': no IDS name given. Hint: " + "add '#' to your URI." + ) + if occurrence: + try: + occurrence = int(occurrence) + except ValueError: + raise click.UsageError( + f"Invalid IMAS URI '{imas_uri}': " + f"occurrence '{occurrence}' is not an integer." + ) from None + else: + occurrence = 0 + + logging.info("Opening data entry...") + with imas.DBEntry(base_uri, "r") as entry: + logging.info("Reading IDS...") + # Ensure IDS uses homogeneous time, extract all time points + lazy_ids = entry.get(idsname, occurrence, lazy=True, autoconvert=False) + if lazy_ids.ids_properties.homogeneous_time != IDS_TIME_MODE_HOMOGENEOUS: + raise click.ClickException("The loaded IDS is not using homogeneous time.") + times = lazy_ids.time[:] + del lazy_ids + logging.info("Found %d time slices to stream", len(times)) + if n and n < len(times): + logging.info("Streaming first %d time slices", n) + times = times[:n] + n = len(times) + + # Get first time slice to obtain the static and metadata + ids = entry.get_slice( + idsname, times[0], CLOSEST_INTERP, occurrence, autoconvert=False + ) + ids_producer = StreamingIDSProducer(ids) + kafka_producer = KafkaProducer( + KafkaSettings(host=kafka_host, topic_name=kafka_topic), + ids_producer.metadata, + ) + + if get: + logging.info("Loading full IDS...") + ids = entry.get(idsname, occurrence, autoconvert=False) + logging.info("IDS loaded.") + + with click.progressbar( + ids_producer.messages_from_batch(ids), + length=n, + label="Streaming time slices", + show_pos=True, + update_min_steps=_PROGRESS_BAR_UPDATE_MINSTEP, + ) as bar: + for i, data in enumerate(bar): + if i == n: + break + kafka_producer.produce(bytes(data)) + # Make bar go to 100% + bar.make_step(n % _PROGRESS_BAR_UPDATE_MINSTEP) + bar.render_progress() + return + + # Send remaining time slices + with click.progressbar( + times, label="Streaming time slices", show_pos=True + ) as bar: + for time in bar: + ids = entry.get_slice( + idsname, + time, + CLOSEST_INTERP, + occurrence, + autoconvert=False, + lazy=True, + ) + kafka_producer.produce(bytes(ids_producer.create_message(ids))) + + @main.command() @click.argument("kafka_host") @click.argument("kafka_topic") @@ -43,11 +156,9 @@ def kafka_to_imasentry( ): """Consume streaming IMAS data from Kafka and store data in an IMAS Data Entry. - N.B. This program requires the optional kafka dependency. - \b Arguments: - KAFKA_HOST Kafka host and port (aka bootstrap.servers). E.g. 'localhost:9092' + KAFKA_HOST Kafka host and port (aka bootstrap.servers). E.g. 'localhost:9092'. KAFKA_TOPIC Name of the kafka topic with streaming IMAS data. IMAS_URI IMAS URI to store the data at, for example 'imas:hdf5?path=./out'. The program will not overwrite existing data (unless the --overwrite diff --git a/src/imas_streams/producer.py b/src/imas_streams/producer.py index 7c2bb93..be08de8 100644 --- a/src/imas_streams/producer.py +++ b/src/imas_streams/producer.py @@ -1,4 +1,5 @@ import copy +from collections.abc import Iterator import imas import numpy as np @@ -9,7 +10,9 @@ from imas_streams.metadata import DynamicData, StreamingIMASMetadata -def _metadata_from_time_slice(time_slice: IDSToplevel, static_paths: list[str]): +def _metadata_from_time_slice( + time_slice: IDSToplevel, static_paths: list[str] +) -> StreamingIMASMetadata: # -- Data sanity checks -- # The IDS must use homogeneous time mode if time_slice.ids_properties.homogeneous_time != IDS_TIME_MODE_HOMOGENEOUS: @@ -191,6 +194,7 @@ def metadata(self) -> StreamingIMASMetadata: return self._metadata def create_message(self, time_slice: IDSToplevel) -> bytearray: + """Create a single IMAS Streams message from the provided time slice.""" buffer = bytearray(self._buffersize) curindex = 0 for dyndata in self._metadata.dynamic_data: @@ -214,3 +218,38 @@ def create_message(self, time_slice: IDSToplevel) -> bytearray: if not (curindex == len(buffer) == self._buffersize): raise RuntimeError("Internal error: incorrect size of data buffer") return buffer + + def messages_from_batch(self, ids: IDSToplevel) -> Iterator[bytearray]: + """Create an IMAS Streams message for each time slice in the provided IDS. + + N.B. This method currently doesn't support streaming dynamic arrays of + structures. + """ + if ids.ids_properties.homogeneous_time != IDS_TIME_MODE_HOMOGENEOUS: + raise ValueError("The provided IDS doesn't use homogeneous time") + + nodes = [] + for dyndata in self._metadata.dynamic_data: + node = ids[dyndata.path] + if ( + dyndata.path != "time" + and not node.metadata.coordinates[-1].is_time_coordinate + ): + raise NotImplementedError( + "messages_from_batch() does not implement streaming data in dynamic" + " arrays of structures. Please use create_message() instead." + ) + nodes.append(node) + + buffer = bytearray(self._buffersize) + for i in range(len(ids.time)): + curindex = 0 + for node in nodes: + arr: np.ndarray = node.value[..., i] + nbytes = arr.nbytes + buffer[curindex : curindex + nbytes] = arr.tobytes() + curindex += nbytes + + if not (curindex == len(buffer) == self._buffersize): + raise RuntimeError("Internal error: incorrect size of data buffer") + yield buffer diff --git a/tests/test_muscle3.py b/tests/test_muscle3.py index 858fc6e..26542b6 100644 --- a/tests/test_muscle3.py +++ b/tests/test_muscle3.py @@ -31,10 +31,6 @@ """ -@pytest.mark.xfail( - tuple(map(int, ymmsl.__version__.split(".")[:3])) < (0, 15, 1), - reason="Test needs YMMSL Entry Points plugins", -) def test_load_ymmsl_config(): config = ymmsl.load_as(Configuration, ymmsl_config) resolve(Reference([]), config)