From ac64da36bb3d75b8641c4a4f66e319f8a195c43b Mon Sep 17 00:00:00 2001 From: Svetlana Titova Date: Tue, 26 May 2026 16:15:30 -0700 Subject: [PATCH 1/6] =?UTF-8?q?=EF=BB=BFfirst=20checkin;=20remove=20migrat?= =?UTF-8?q?ion=20(it=20will=20be=20in=20emodpy),=20adding=20Node-=20and=20?= =?UTF-8?q?Coordiantor-=20level=20events.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- emod_api/campaign.py | 392 ++++++--- emod_api/demographics/demographics.py | 9 +- emod_api/demographics/demographics_base.py | 3 + emod_api/migration/README.md | 10 - emod_api/migration/__init__.py | 0 emod_api/migration/__main__.py | 22 - emod_api/migration/migration.py | 782 ------------------ emod_api/utils/str_enum.py | 3 + tests/test_campaign_module.py | 363 +++++++-- tests/test_migration.py | 875 --------------------- tests/unittests/test_migration_imports.py | 27 - 11 files changed, 613 insertions(+), 1873 deletions(-) delete mode 100644 emod_api/migration/README.md delete mode 100644 emod_api/migration/__init__.py delete mode 100644 emod_api/migration/__main__.py delete mode 100644 emod_api/migration/migration.py delete mode 100644 tests/test_migration.py delete mode 100644 tests/unittests/test_migration_imports.py diff --git a/emod_api/campaign.py b/emod_api/campaign.py index f164ab3c..fa9011dd 100644 --- a/emod_api/campaign.py +++ b/emod_api/campaign.py @@ -1,52 +1,108 @@ #!/usr/bin/env python -""" -You use this simple campaign builder by importing it, adding valid events via "add", and writing it out with "save". +"""Simple campaign builder for EMOD simulations. + +Import this module, add valid campaign events via ``add``, and write the +campaign file with ``save``. """ import json +import warnings from emod_api import schema_to_class as s2c schema_path = None _schema_json = None campaign_dict = {"Events": [], "Use_Defaults": 1} -pubsub_signals_subbing = [] -pubsub_signals_pubbing = [] -adhocs = [] -custom_coordinator_events = [] -custom_node_events = [] -event_map = {} +individual_events_listened = [] +individual_events_broadcast = [] +node_events_broadcast = [] +node_events_listened = [] +coordinator_events_broadcast = [] +coordinator_events_listened = [] use_old_adhoc_handling = False unsafe = False implicits = list() -trigger_list = None +individual_builtin_events = [] +node_builtin_events = [] +coordinator_builtin_events = [] def reset(): + """Reset all campaign state to defaults. + + Clears accumulated events, signal tracking lists, event mappings, + and the schema cache. + """ campaign_dict["Events"].clear() - pubsub_signals_subbing.clear() - pubsub_signals_pubbing.clear() - adhocs.clear() - custom_coordinator_events.clear() - custom_node_events.clear() + individual_events_listened.clear() + individual_events_broadcast.clear() + node_events_broadcast.clear() + node_events_listened.clear() + coordinator_events_broadcast.clear() + coordinator_events_listened.clear() implicits.clear() + individual_builtin_events.clear() + node_builtin_events.clear() + coordinator_builtin_events.clear() + s2c.clear_schema_cache() - event_map.clear() - s2c.clear_schema_cache() +def _find_builtin_events(schema, reporter_key, events_key): + """Recursively find a builtin events using reporter entry and extract its event list. + Walks the schema looking for ``reporter_key`` as a dict key. When + found, looks up ``events_key`` inside it and returns the + ``"Built-in"`` list if present, otherwise the ``"enum"`` list (used in EMOD-Generic) -def set_schema(schema_path_in): + Args: + schema: The schema JSON object (or sub-object) to search. + reporter_key: The reporter key to find (e.g. + ``"ReportEventRecorder"``). + events_key: The events parameter inside the reporter (e.g. + ``"Report_Event_Recorder_Events"``). + + Returns: + A list of event name strings, or ``None`` if the reporter + or its event list is not found. """ - Set the (path to) the schema file. And reset all campaign variables. This is essentially a - "start_building_campaign" function. + if isinstance(schema, dict): + if reporter_key in schema: + events_entry = schema[reporter_key] + if isinstance(events_entry, dict): + events_param = events_entry.get(events_key) + if isinstance(events_param, dict): + builtin = events_param.get("Built-in") + if isinstance(builtin, list): + return builtin + enum = events_param.get("enum") + if isinstance(enum, list): + return enum + return None + for value in schema.values(): + result = _find_builtin_events(value, reporter_key, events_key) + if result is not None: + return result + elif isinstance(schema, list): + for item in schema: + result = _find_builtin_events(item, reporter_key, events_key) + if result is not None: + return result + return None - Parameters: - schema_path_in (str): The path to a schema.json file - Returns: +def set_schema(schema_path_in): + """Set the schema file path and reset all campaign state. + + This is essentially the "start building a campaign" entry point. + It clears any previously accumulated events and loads the new schema. + Also extracts built-in event lists for individual, node, and + coordinator levels by recursively searching for + ``ReportEventRecorder``, ``ReportEventRecorderNode``, and + ``ReportEventRecorderCoordinator`` in the schema. + Args: + schema_path_in: Path to a ``schema.json`` file. """ reset() global schema_path, _schema_json @@ -55,116 +111,258 @@ def set_schema(schema_path_in): with open(schema_path_in) as schema_file: _schema_json = json.load(schema_file) + found = _find_builtin_events(_schema_json, "ReportEventRecorder", "Report_Event_Recorder_Events") + if found: + individual_builtin_events.extend(found) + + found = _find_builtin_events(_schema_json, "ReportEventRecorderNode", "Report_Node_Event_Recorder_Events") + if found: + node_builtin_events.extend(found) + + found = _find_builtin_events(_schema_json, "ReportEventRecorderCoordinator", "Report_Coordinator_Event_Recorder_Events") + if found: + coordinator_builtin_events.extend(found) + def get_schema(): + """Return the loaded schema JSON dictionary. + + Returns: + The parsed schema dictionary, or ``None`` if ``set_schema`` has + not been called. + """ return _schema_json -def add(event, name=None, first=False): - """ - Add a complete campaign event to the campaign builder. The new event is assumed to be a Python dict, and a - valid event. The new event is not validated here. - Set the first flag to True if this is the first event in a campaign because it functions as an - accumulator and in some situations like sweeps it might have been used recently. +def add(event, note: str = None): + """Add a complete campaign event to the campaign builder. + + The event is assumed to be valid and is not validated here. + + Args: + event: A complete campaign event object. It must support + ``finalize()`` and dict-style key assignment. + note: An optional human-readable note added to the event + inside the output ``campaign.json`` file. """ event.finalize() - if first: - print("Use of 'first' flag is deprecated. Use set_schema to start build a new, empty campaign.") - campaign_dict["Events"].clear() - if "Event_Name" not in event and name is not None: - event["Event_Name"] = name - if "Listening" in event: - pubsub_signals_subbing.extend(event["Listening"]) - event.pop("Listening") - if "Broadcasting" in event: - pubsub_signals_pubbing.extend(event["Broadcasting"]) - event.pop("Broadcasting") + if note is not None: + event["Note"] = note campaign_dict["Events"].append(event) -def get_trigger_list(): - global trigger_list - if get_schema(): - # This needs to be fixed in the schema post-processor: maybe create a new idmTime:EventEnum and replace - # all the occurrences with a reference to that. - try: - trigger_list = get_schema()["idmTypes"]["idmAbstractType:EventCoordinator"]["BroadcastCoordinatorEvent"][ - "Broadcast_Event"]["enum"] - except Exception: - trigger_list = get_schema()["idmTypes"]["idmType:IncidenceCounter"]["Trigger_Condition_List"]["Built-in"] - return trigger_list +def save(filename: str = "campaign.json"): + """Save the accumulated campaign events to a JSON file. -def save(filename="campaign.json"): - """ - Save 'campaign_dict' as file named 'filename'. + Args: + filename: Output file path. + + Returns: + The filename that was written. """ with open(filename, "w") as camp_file: json.dump(campaign_dict, camp_file, sort_keys=True, indent=4) - import copy - ignored_events = copy.deepcopy(set(pubsub_signals_pubbing)) - non_camp_events = set() - if len(pubsub_signals_subbing) > 0: - for event in set(pubsub_signals_subbing): - if event in ignored_events: - ignored_events.remove(event) - if len(non_camp_events) > 0: - for event in set(non_camp_events): - if event in get_adhocs() and not unsafe: - raise RuntimeError(f"ERROR: Report is configured to LISTEN to the following non-existent event: \n" - f"{event} \nPlease fix the error.\n") + return filename -def get_adhocs(): - return event_map +def _validate_custom_events(listened_list, broadcast_list, builtin_list, level): + """Validate that listened-to events are broadcast and vice versa. + + Built-in events are excluded from validation since they are + handled by the simulation engine and do not need to be explicitly + broadcast or listened to in the campaign. + + Args: + listened_list: List of event names being listened to. + broadcast_list: List of event names being broadcast. + builtin_list: List of built-in event names to exclude from + validation. + level: Label for the event level (e.g. ``"coordinator"`` + or ``"node"``) used in error/warning messages. + + Returns: + A deduplicated list of custom (non-built-in) broadcast event + name strings. + + Raises: + ValueError: If any events are listened to but never broadcast. + """ + builtins = set(builtin_list) + + broadcast_matching_builtins = set(broadcast_list) & builtins + if broadcast_matching_builtins: + warnings.warn( + f"The following {level}-level broadcast events mirror built-in {level}-level events, " + f"therefore these events will be broadcast by the simulation as well as the campaign: " + f"{sorted(broadcast_matching_builtins)}") + + listened = set(listened_list) - builtins + broadcast = set(broadcast_list) - builtins + + listened_not_broadcast = listened - broadcast + if listened_not_broadcast: + raise ValueError( + f"The following {level}-level events are listened to but never broadcast. This means that any campaign " + f"interventions that rely on listening to these events will never fire. Please fix the error by either " + f"broadcasting these events in the campaign or removing the interventions that are listening for them:\n" + f"{sorted(listened_not_broadcast)}") + + broadcast_not_listened = broadcast - listened + if broadcast_not_listened: + warnings.warn( + f"The following {level} events are broadcast but nothing in the campaign " + f"is listening to them: {sorted(broadcast_not_listened)}") + + return list(broadcast) def get_custom_coordinator_events(): - return list(set(custom_coordinator_events)) + """Validate and return deduplicated custom coordinator-level events. + + Returns: + A list of unique coordinator event name strings that are broadcast + in the campaign. + + Raises: + ValueError: If any coordinator events are listened to but + never broadcast. + """ + return _validate_custom_events(coordinator_events_listened, coordinator_events_broadcast, coordinator_builtin_events, "coordinator") def get_custom_node_events(): - return list(set(custom_node_events)) + """Validate and return deduplicated custom node-level events. + Returns: + A list of unique node event name strings that are broadcast + in the campaign. -def get_recv_trigger(trigger, old=use_old_adhoc_handling): + Raises: + ValueError: If any node events are listened to but + never broadcast. """ - Get the correct representation of a trigger (also called signal or even event) that is being listened to. + return _validate_custom_events(node_events_listened, node_events_broadcast, node_builtin_events, "node") + + +def get_custom_individual_events(): + """Validate and return deduplicated custom individual-level events. + + Returns: + A list of unique individual event name strings that are broadcast + in the campaign. + + Raises: + ValueError: If any individual events are listened to but + never broadcast. """ - pubsub_signals_subbing.append(trigger) - return get_event(trigger, old) + return _validate_custom_events(individual_events_listened, individual_events_broadcast, individual_builtin_events, "individual") -def get_send_trigger(trigger, old=use_old_adhoc_handling): +def get_recv_trigger(trigger, old=use_old_adhoc_handling): + """Register an individual-level event as listened to. + + Tracks which individual events are used throughout the simulation + so that ``get_custom_individual_events`` can validate that every + listened-to event has a corresponding broadcast. + + Args: + trigger: The individual event name string. + old: Unused. Kept for backwards compatibility. + + Returns: + The event name, unchanged. + """ + if not trigger: + raise ValueError("Event name must not be None or empty.") + individual_events_listened.append(trigger) + return trigger + +def set_listened_node_event(event: str) -> str: + """Register a node-level event as listened to. + + Tracks which node events are used throughout the simulation so + that ``get_custom_node_events`` can validate that every listened-to + event has a corresponding broadcast. + + Args: + event: The node event name string. + + Returns: + The event name, unchanged. """ - Get the correct representation of a trigger (also called signal or even event) that is being broadcast. + if not event: + raise ValueError("Event name must not be None or empty.") + node_events_listened.append(event) + return event + +def set_listened_coordinator_event(event: str) -> str: + """Register a coordinator-level event as listened to. + + Tracks which coordinator events are used throughout the simulation + so that ``get_custom_coordinator_events`` can validate that every + listened-to event has a corresponding broadcast. + + Args: + event: The coordinator event name string. + + Returns: + The event name, unchanged. """ - pubsub_signals_pubbing.append(trigger) - return get_event(trigger, old) + if not event: + raise ValueError("Event name must not be None or empty.") + coordinator_events_listened.append(event) + return event + +def get_send_trigger(trigger, old=use_old_adhoc_handling): + """Register an individual-level event as broadcast. + Args: + trigger: The individual event name string. + old: Unused. Kept for backwards compatibility. -def get_event(event, old=False): + Returns: + The event name, unchanged. """ - Basic placeholder functionality for now. This will map new ad-hoc events to GP_EVENTs and manage that 'cache' - If event in built-ins, return event, else if in adhoc map, return mapped event, else add to adhoc_map and return - mapped event. + if not trigger: + raise ValueError("Event name must not be None or empty.") + individual_events_broadcast.append(trigger) + return trigger + +def set_broadcast_node_event(event: str) -> str: + """Register a node-level event as broadcast. + + Tracks which node events are used throughout the simulation so + that ``get_custom_node_events`` can validate that every broadcast + event has something listening to it. + + Args: + event: The node event name string. + + Returns: + The event name, unchanged. """ - if event is None or event == "": - raise ValueError("campaign.get_event() called with an empty event. Please specify a string.") + if not event: + raise ValueError("Event name must not be None or empty.") + node_events_broadcast.append(event) + return event + +def set_broadcast_coordinator_event(event: str) -> str: + """Register a coordinator-level event as broadcast. - return_event = None - global trigger_list - if trigger_list is None: - trigger_list = get_trigger_list() + Tracks which coordinator events are used throughout the simulation + so that ``get_custom_coordinator_events`` can validate that every + broadcast event has something listening to it. + + Args: + event: The coordinator event name string. + + Returns: + The event name, unchanged. + """ + if not event: + raise ValueError("Event name must not be None or empty.") + coordinator_events_broadcast.append(event) + return event - if event in trigger_list: - return_event = event - elif event in event_map: - return_event = event_map[event] - else: - # get next entry in GP_EVENT_xxx - new_event_name = event if old else f'GP_EVENT_{len(event_map):03d}' - event_map[event] = new_event_name - return_event = event_map[event] - return return_event diff --git a/emod_api/demographics/demographics.py b/emod_api/demographics/demographics.py index 5b6832ca..d23e5c84 100644 --- a/emod_api/demographics/demographics.py +++ b/emod_api/demographics/demographics.py @@ -30,9 +30,14 @@ def __init__(self, nodes: list[Node], idref: str = None, default_node: Node = No """ super().__init__(nodes=nodes, idref=idref, default_node=default_node) - # No current default settings + # set some standard EMOD defaults. set_defaults should always be True unless reading from a demographics file, + # as False allows setting default_node.node_attributes exactly as they are in the file. Loading via + # Demographics.from_file() is deprecated, see below. if set_defaults: - pass + self.default_node.node_attributes.airport = 1 + self.default_node.node_attributes.seaport = 1 + self.default_node.node_attributes.region = 1 + self.default_node.node_attributes.altitude = 0 def to_file(self, path: Union[str, Path] = "demographics.json", indent: int = 4) -> None: """ diff --git a/emod_api/demographics/demographics_base.py b/emod_api/demographics/demographics_base.py index 7af7496d..6bf579e0 100644 --- a/emod_api/demographics/demographics_base.py +++ b/emod_api/demographics/demographics_base.py @@ -305,6 +305,9 @@ def set_age_distribution(self, Args: distribution: The distribution to set. Can either be a BaseDistribution object for a simple distribution or AgeDistribution object for complex. + Note: When using BaseDistribution, the parameter ages are in days. Ex: UniformDistribution(0, 365*50) for + a uniform distribution of ages between 0 and 50 years. When using AgeDistribution, the parameter + ages are in years. node_ids: The node id(s) to apply changes to. None or 0 means the default node. Returns: diff --git a/emod_api/migration/README.md b/emod_api/migration/README.md deleted file mode 100644 index 29ee50aa..00000000 --- a/emod_api/migration/README.md +++ /dev/null @@ -1,10 +0,0 @@ -# Migration - -This submodule provides scripts for creating (and reading) migration input files that are directly ingested by the DTK (EMOD) for determining how individuals migrate between nodes over time during a simulation. Going forward, all reading and writing of these migration control files should be done via this submodule. - -It currently consists of a migration script for reading a DTK migration file (not writing yet). Sample usage: - - - ` import emod_api.migration.migration as mig` - - ` MyMigrationFile = mig.MigrationFile(filename=filename)` - - ` node_destination_rates = MyMigrationFile.rates()` - diff --git a/emod_api/migration/__init__.py b/emod_api/migration/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/emod_api/migration/__main__.py b/emod_api/migration/__main__.py deleted file mode 100644 index ee91bd84..00000000 --- a/emod_api/migration/__main__.py +++ /dev/null @@ -1,22 +0,0 @@ -#! /usr/bin/env python3 - -from argparse import ArgumentParser -from pathlib import Path -import sys - -from .migration import to_csv, examine_file - - -if __name__ == "__main__": - parser = ArgumentParser(prog='migration') - parser.add_argument("-c", "--csv", type=Path, default=None, - help="Dump contents of to stdout in CSV format.", metavar='') - parser.add_argument("-e", "--examine", type=Path, default=None, help="Display metadata from on stdout.", - metavar='') - args = parser.parse_args() - - if len(sys.argv) > 1: - to_csv(args.csv) if args.csv else None - examine_file(args.examine) if args.examine else None - else: - parser.print_help() diff --git a/emod_api/migration/migration.py b/emod_api/migration/migration.py deleted file mode 100644 index 16e03de4..00000000 --- a/emod_api/migration/migration.py +++ /dev/null @@ -1,782 +0,0 @@ -from collections import defaultdict -from datetime import datetime -from functools import partial -import json -from numbers import Integral -from os import environ, SEEK_SET -from pathlib import Path -from platform import system -from warnings import warn - -import numpy as np -import csv - -from emod_api.demographics.demographics import Demographics - -# for from_demog_and_param_gravity() -from geographiclib.geodesic import Geodesic - - -class Layer(dict): - - """ - The Layer object represents a mapping from source node (IDs) to destination node (IDs) for a particular - age, gender, age+gender combination, or all users if no age or gender dependence. Users will not generally - interact directly with Layer objects. - """ - - def __init__(self): - - super().__init__() - - return - - @property - def DatavalueCount(self) -> int: - """Get (maximum) number of data values for any node in this layer - - Returns: - Maximum number of data values for any node in this layer - - """ - count = max([len(entry) for entry in self.values()]) if len(self) else 0 - return count - - @property - def NodeCount(self) -> int: - """Get the number of (source) nodes with rates in this layer - - Returns: - Number of (source) nodes with rates in this layer - - """ - return len(self) - - # @property - # def Nodes(self) -> dict: - # return self._nodes - - def __getitem__(self, - key: int) -> dict: - """Allows indexing directly into this object with source node id - - Args: - key: source node id - - Returns: - Dictionary of outbound rates for the given node id - """ - if key not in self: - if isinstance(key, Integral): - super().__setitem__(key, defaultdict(float)) - else: - raise RuntimeError(f"Migration node IDs must be integer values (key = {key}).") - return super().__getitem__(key) - - -_METADATA = "Metadata" -_AUTHOR = "Author" -_DATECREATED = "DateCreated" -_TOOLNAME = "Tool" -_IDREFERENCE = "IdReference" -_MIGRATIONTYPE = "MigrationType" -_NODECOUNT = "NodeCount" -_DATAVALUECOUNT = "DatavalueCount" -_GENDERDATATYPE = "GenderDataType" -_AGESYEARS = "AgesYears" -_INTERPOLATIONTYPE = "InterpolationType" -_NODEOFFSETS = "NodeOffsets" -_EMODAPI = "emod-api" - - -class Migration(object): - - """Represents migration data in a mapping from source node (IDs) to destination node (IDs) with rates for each pairing. - - Migration data may be age dependent, gender dependent, both, or the same for all ages and genders. - A migration file (along with JSON metadata) can be loaded from the static method Migration.from_file() and - inspected and/or modified. - Migration objects can be started from scratch with Migration(), and populated with appropriate source-dest rate data - and saved to a file with the to_file() method. - Given migration = Migration(), syntax is as follows: - - age and gender agnostic: `migration[source_id][dest_id]` - age dependent: `migration[source_id:age]` # age should be >= 0, ages > last bucket value use last bucket value - gender dependent: `migration[source_id:gender]` # gender one of Migration.MALE or Migration.FEMALE - age and gender dependent: `migration[source_id:gender:age]` # gender one of Migration.MALE or Migration.FEMALE - - EMOD/DTK format migration files (and associated metadata files) can be written with migration.to_file(). - EMOD/DTK format migration files (with associated metadata files) can be read with migration.from_file(). - """ - - SAME_FOR_BOTH_GENDERS = 0 - ONE_FOR_EACH_GENDER = 1 - - LINEAR_INTERPOLATION = 0 - PIECEWISE_CONSTANT = 1 - - LOCAL = 1 - AIR = 2 - REGIONAL = 3 - SEA = 4 - FAMILY = 5 - INTERVENTION = 6 - - IDREF_LEGACY = "Legacy" - IDREF_GRUMP30ARCSEC = "Gridded world grump30arcsec" - IDREF_GRUMP2PT5ARCMIN = "Gridded world grump2.5arcmin" - IDREF_GRUMP1DEGREE = "Gridded world grump1degree" - - MALE = 0 - FEMALE = 1 - - MAX_AGE = 125 - - def __init__(self): - - self._agesyears = [] - try: - self._author = _author() - except Exception: - self._author = "Mystery Guest" - self._datecreated = datetime.now() - self._genderdatatype = self.SAME_FOR_BOTH_GENDERS - self._idreference = self.IDREF_LEGACY - self._interpolationtype = self.PIECEWISE_CONSTANT - self._migrationtype = self.LOCAL - self._tool = _EMODAPI - - self._create_layers() - - return - - def _create_layers(self): - - self._layers = [] - for gender in range(0, self._genderdatatype + 1): - for age in range(0, len(self.AgesYears) if self.AgesYears else 1): - self._layers.append(Layer()) - - return - - @property - def AgesYears(self) -> list: - """ - List of ages - ages < first value use first bucket, ages > last value use last bucket. - """ - return self._agesyears - - @AgesYears.setter - def AgesYears(self, ages: list) -> None: - """ - List of ages - ages < first value use first bucket, ages > last value use last bucket. - """ - if sorted(ages) != self.AgesYears: - if self.NodeCount > 0: - warn("Changing age buckets clears existing migration information.", category=UserWarning) - self._agesyears = sorted(ages) - self._create_layers() - return - - @property - def Author(self) -> str: - """str: Author value for metadata for this migration datafile""" - return self._author - - @Author.setter - def Author(self, author: str) -> None: - self._author = author - return - - @property - def DatavalueCount(self) -> int: - """int: Maximum data value count for any layer in this migration datafile""" - count = max([layer.DatavalueCount for layer in self._layers]) - return count - - @property - def DateCreated(self) -> datetime: - """datetime: date/time stamp of this datafile""" - return self._datecreated - - @DateCreated.setter - def DateCreated(self, value) -> None: - if not isinstance(value, datetime): - raise RuntimeError(f"DateCreated must be a datetime value (got {type(value)}).") - self._datecreated = value - return - - @property - def GenderDataType(self) -> int: - """int: gender data type for this datafile - SAME_FOR_BOTH_GENDERS or ONE_FOR_EACH_GENDER""" - return self._genderdatatype - - @GenderDataType.setter - def GenderDataType(self, value: int) -> None: - - # integer value - if value in Migration._GENDER_DATATYPE_ENUMS.keys(): - value = int(value) - # string value - elif value in Migration._GENDER_DATATYPE_LOOKUP.keys(): - value = Migration._GENDER_DATATYPE_LOOKUP[value] - else: - expected = [f"{key}/{value}" for key, value in Migration._GENDER_DATATYPE_LOOKUP.items()] - raise RuntimeError(f"Unknown gender data type, {value}, expected one of {expected}.") - - if (self.NodeCount > 0) and (value != self._genderdatatype): - warn("Changing gender data type clears existing migration information.", category=UserWarning) - - if value != self._genderdatatype: - self._genderdatatype = int(value) - self._create_layers() - return - - @property - def IdReference(self) -> str: - """str: ID reference metadata value""" - return self._idreference - - @IdReference.setter - def IdReference(self, value: str) -> None: - self._idreference = str(value) - return - - @property - def InterpolationType(self) -> int: - """int: interpolation type for this migration data file - LINEAR_INTERPOLATION or PIECEWISE_CONSTANT""" - return self._interpolationtype - - @InterpolationType.setter - def InterpolationType(self, value: int) -> None: - - # integer value - if value in Migration._INTERPOLATION_TYPE_ENUMS.keys(): - self._interpolationtype = int(value) - # string value - elif value in Migration._INTERPOLATION_TYPE_LOOKUP.keys(): - self._interpolationtype = Migration._INTERPOLATION_TYPE_LOOKUP[value] - else: - expected = [f"{key}/{value}" for key, value in Migration._INTERPOLATION_TYPE_LOOKUP.items()] - raise RuntimeError(f"Unknown interpolation type, {value}, expected one of {expected}.") - return - - @property - def MigrationType(self) -> int: - """int: migration type for this migration data file - LOCAL | AIR | REGIONAL | SEA | FAMILY | INTERVENTION""" - return self._migrationtype - - @MigrationType.setter - def MigrationType(self, value: int) -> None: - - # integer value - if value in Migration._MIGRATION_TYPE_ENUMS.keys(): - self._migrationtype = int(value) - elif value in Migration._MIGRATION_TYPE_LOOKUP.keys(): - self._migrationtype = Migration._MIGRATION_TYPE_LOOKUP[value] - else: - expected = [f"{key}/{value}" for key, value in Migration._MIGRATION_TYPE_LOOKUP.items()] - raise RuntimeError(f"Unknown migration type, {value}, expected one of {expected}.") - return - - @property - def Nodes(self) -> list: - node_ids = set() - for layer in self._layers: - node_ids |= set(layer.keys()) - node_ids = sorted(node_ids) - return node_ids - - @property - def NodeCount(self) -> int: - """int: maximum number of source nodes in any layer of this migration data file""" - count = max([layer.NodeCount for layer in self._layers]) - return count - - def get_node_offsets(self, limit: int = 100) -> dict: - nodes = set() - for layer in self._layers: - nodes |= set(key for key in layer.keys()) - count = min(self.DatavalueCount, limit) - # offsets = {} - # for index, node in enumerate(sorted(nodes)): - # offsets[node] = index * 12 * count - offsets = {node: 12 * index * count for index, node in enumerate(sorted(nodes))} - return offsets - - @property - def NodeOffsets(self) -> dict: - """dict: mapping from source node id to offset to destination and rate data in binary data""" - return self.get_node_offsets() - - @property - def Tool(self) -> str: - """str: tool metadata value""" - return self._tool - - @Tool.setter - def Tool(self, value: str) -> None: - self._tool = str(value) - return - - def __getitem__(self, key): - """allows indexing on this object to read/write rate data - Args: - key (slice): source node id:gender:age (gender and age depend on GenderDataType and AgesYears properties) - Returns: - dict for specified node/gender/age - """ - if self.GenderDataType == Migration.SAME_FOR_BOTH_GENDERS: - if not self.AgesYears: - # Case 1 - no gender or age differentiation - key (integer) == node id - return self._layers[0][key] - else: - # Case 3 - age buckets, no gender differentiation - key (tuple or slice) == node id:age - if isinstance(key, tuple): - node_id, age = key - elif isinstance(key, slice): - node_id, age = key.start, key.stop - else: - raise RuntimeError(f"Invalid indexing for migration - {key}") - layer_index = self._index_for_gender_and_age(None, age) - return self._layers[layer_index][node_id] - else: - if not self.AgesYears: - # Case 2 - by gender, no age differentiation - key (tuple or slice) == node id:gender - if isinstance(key, tuple): - node_id, gender = key - elif isinstance(key, slice): - node_id, gender = key.start, key.stop - else: - raise RuntimeError(f"Invalid indexing for migration - {key}") - if gender not in [Migration.SAME_FOR_BOTH_GENDERS, Migration.ONE_FOR_EACH_GENDER]: - raise RuntimeError(f"Invalid gender ({gender}) for migration.") - layer_index = self._index_for_gender_and_age(gender, None) - return self._layers[layer_index][node_id] - else: - # Case 4 - by gender and age - key (slice) == node id:gender:age - if isinstance(key, tuple): - node_id, gender, age = key - elif isinstance(key, slice): - node_id, gender, age = key.start, key.stop, key.step - else: - raise RuntimeError(f"Invalid indexing for migration - {key}") - if gender not in [Migration.SAME_FOR_BOTH_GENDERS, Migration.ONE_FOR_EACH_GENDER]: - raise RuntimeError(f"Invalid gender ({gender}) for migration.") - layer_index = self._index_for_gender_and_age(gender, age) - return self._layers[layer_index][node_id] - - # raise RuntimeError("Invalid state.") - - def _index_for_gender_and_age(self, gender: int, age: float) -> int: - """ - Use age to determine age bucket, 0 if no age differentiation. - Use gender data type to offset by # age buckets if gender data type is one for each gender and gender is female - Ages < first value use first bucket, ages > last value use last bucket. - """ - age_offset = 0 - for age_offset, edge in enumerate(self.AgesYears): - if edge >= age: - break - gender_span = len(self.AgesYears) if self.AgesYears else 1 - gender_offset = gender * gender_span if self.GenderDataType == Migration.ONE_FOR_EACH_GENDER else 0 - index = gender_offset + age_offset - return index - - def __iter__(self): - return iter(self._layers) - - _MIGRATION_TYPE_ENUMS = { - LOCAL: "LOCAL_MIGRATION", - AIR: "AIR_MIGRATION", - REGIONAL: "REGIONAL_MIGRATION", - SEA: "SEA_MIGRATION", - FAMILY: "FAMILY_MIGRATION", - INTERVENTION: "INTERVENTION_MIGRATION" - } - - _GENDER_DATATYPE_ENUMS = { - SAME_FOR_BOTH_GENDERS: "SAME_FOR_BOTH_GENDERS", - ONE_FOR_EACH_GENDER: "ONE_FOR_EACH_GENDER" - } - - _INTERPOLATION_TYPE_ENUMS = { - LINEAR_INTERPOLATION: "LINEAR_INTERPOLATION", - PIECEWISE_CONSTANT: "PIECEWISE_CONSTANT" - } - - def to_file(self, binaryfile: Path, metafile: Path = None, value_limit: int = 100): - """Write current data to given file (and .json metadata file) - - Args: - binaryfile (Path): path to output file (metadata will be written to same path with ".json" appended) - metafile (Path): override standard metadata file naming - value_limit (int): limit on number of destination values to write for each source node (default = 100) - - Returns: - (Path): path to binary file - """ - binaryfile = Path(binaryfile).absolute() - metafile = metafile if metafile else binaryfile.parent / (binaryfile.name + ".json") - - actual_datavalue_count = min(self.DatavalueCount, value_limit) # limited to 100 destinations - - node_ids = set() - for layer in self._layers: - node_ids |= set(layer.keys()) - node_ids = sorted(node_ids) - - offsets = self.get_node_offsets(actual_datavalue_count) - node_offsets_string = ''.join([f"{node:08x}{offsets[node]:08x}" for node in sorted(offsets.keys())]) - - metadata = { - _METADATA: { - _AUTHOR: self.Author, - _DATECREATED: f"{self.DateCreated:%a %b %d %Y %H:%M:%S}", - _TOOLNAME: self.Tool, - _IDREFERENCE: self.IdReference, - _MIGRATIONTYPE: self._MIGRATION_TYPE_ENUMS[self.MigrationType], - _NODECOUNT: self.NodeCount, - _DATAVALUECOUNT: actual_datavalue_count, - # could omit this if SAME_FOR_BOTH_GENDERS since it is the default - _GENDERDATATYPE: self._GENDER_DATATYPE_ENUMS[self.GenderDataType], - # _AGESYEARS: self.AgesYears, # see below - _INTERPOLATIONTYPE: self._INTERPOLATION_TYPE_ENUMS[self.InterpolationType] - }, - _NODEOFFSETS: node_offsets_string - } - if self.AgesYears: - # older versions of Eradication do not handle empty AgesYears lists robustly - metadata[_METADATA][_AGESYEARS] = self.AgesYears - - print(f"Writing metadata to '{metafile}'") - with metafile.open("w") as handle: - json.dump(metadata, handle, indent=4, separators=(",", ": ")) - - def key_func(k, d=None): - return d[k] - - # layers are in age bucket order by gender, e.g. male 0-5, 5-10, 10+, female 0-5, 5-10, 10+ - # see _index_for_gender_and_age() - print(f"Writing binary data to '{binaryfile}'") - with binaryfile.open("wb") as file: - for layer in self: - for node in node_ids: - destinations = np.zeros(actual_datavalue_count, dtype=np.uint32) - rates = np.zeros(actual_datavalue_count, dtype=np.float64) - if node in layer: - - # Sort keys descending on rate and ascending on node ID. - # That way if we are truncating the list, we include the "most important" nodes. - keys = sorted(layer[node].keys()) # sorted ascending on node ID - keys = sorted(keys, key=partial(key_func, d=layer[node]), reverse=True) # descending on rate - - if len(keys) > actual_datavalue_count: - keys = keys[0:actual_datavalue_count] - # save rates in ascending order so small rates are not lost when looking at the cumulative sum - keys = list(reversed(keys)) - destinations[0:len(keys)] = keys - rates[0:len(keys)] = [layer[node][key] for key in keys] - else: - warn(f"No destination nodes found for node {node}", category=UserWarning) - destinations.tofile(file) - rates.tofile(file) - - return binaryfile - - _MIGRATION_TYPE_LOOKUP = { - "LOCAL_MIGRATION": LOCAL, - "AIR_MIGRATION": AIR, - "REGIONAL_MIGRATION": REGIONAL, - "SEA_MIGRATION": SEA, - "FAMILY_MIGRATION": FAMILY, - "INTERVENTION_MIGRATION": INTERVENTION - } - - _GENDER_DATATYPE_LOOKUP = { - "SAME_FOR_BOTH_GENDERS": SAME_FOR_BOTH_GENDERS, - "ONE_FOR_EACH_GENDER": ONE_FOR_EACH_GENDER - } - - _INTERPOLATION_TYPE_LOOKUP = { - "LINEAR_INTERPOLATION": LINEAR_INTERPOLATION, - "PIECEWISE_CONSTANT": PIECEWISE_CONSTANT - } - - -def from_file(binaryfile: Path, - metafile: Path = None) -> Migration: - """Reads migration data file from given binary (and associated JSON metadata file) - - Args: - binaryfile (Path): path to binary file (metadata file is assumed to be at same location with ".json" suffix) - metafile (Path): use given metafile rather than inferring metafile name from the binary file name - - Returns: - Migration object representing binary data in the given file. - """ - binaryfile = Path(binaryfile).absolute() - metafile = metafile if metafile else binaryfile.parent / (binaryfile.name + ".json") - - if not binaryfile.exists(): - raise RuntimeError(f"Cannot find migration binary file '{binaryfile}'") - if not metafile.exists(): - raise RuntimeError(f"Cannot find migration metadata file '{metafile}'.") - with metafile.open("r") as file: - jason = json.load(file) - - # these are the minimum required entries to load a migration file - assert _METADATA in jason, f"Metadata file '{metafile}' does not have a 'Metadata' entry." - metadata = jason[_METADATA] - assert _NODECOUNT in metadata, f"Metadata file '{metafile}' does not have a 'NodeCount' entry." - assert _DATAVALUECOUNT in metadata, f"Metadata file '{metafile}' does not have a 'DatavalueCount' entry." - assert _NODEOFFSETS in jason, f"Metadata file '{metafile}' does not have a 'NodeOffsets' entry." - - migration = Migration() - migration.Author = _value_with_default(metadata, _AUTHOR, _author()) - migration.DateCreated = _try_parse_date(metadata[_DATECREATED]) if _DATECREATED in metadata else datetime.now() - migration.Tool = _value_with_default(metadata, _TOOLNAME, _EMODAPI) - migration.IdReference = _value_with_default(metadata, _IDREFERENCE, Migration.IDREF_LEGACY) - migration.MigrationType = Migration._MIGRATION_TYPE_LOOKUP[_value_with_default(metadata, - _MIGRATIONTYPE, - "LOCAL_MIGRATION")] - migration.GenderDataType = Migration._GENDER_DATATYPE_LOOKUP[_value_with_default(metadata, - _GENDERDATATYPE, - "SAME_FOR_BOTH_GENDERS")] - migration.AgesYears = _value_with_default(metadata, _AGESYEARS, []) - migration.InterpolationType = Migration._INTERPOLATION_TYPE_LOOKUP[_value_with_default(metadata, - _INTERPOLATIONTYPE, - "PIECEWISE_CONSTANT")] - - node_count = metadata[_NODECOUNT] - node_offsets = jason[_NODEOFFSETS] - if len(node_offsets) != 16 * node_count: - raise RuntimeError(f"Length of node offsets string {len(node_offsets)} != 16 * node count {node_count}.") - offsets = _parse_node_offsets(node_offsets, node_count) - datavalue_count = metadata[_DATAVALUECOUNT] - with binaryfile.open("rb") as file: - for gender in range(1 if migration.GenderDataType == Migration.SAME_FOR_BOTH_GENDERS else 2): - for age in migration.AgesYears if migration.AgesYears else [0]: - layer = migration._layers[migration._index_for_gender_and_age(gender, age)] - for node, offset in offsets.items(): - file.seek(offset, SEEK_SET) - destinations = np.fromfile(file, dtype=np.uint32, count=datavalue_count) - rates = np.fromfile(file, dtype=np.float64, count=datavalue_count) - for destination, rate in zip(destinations, rates): - if rate > 0: - layer[node][destination] = rate - - return migration - - -def examine_file(filename): - - def name_for_gender_datatype(e: int) -> str: - return Migration._GENDER_DATATYPE_ENUMS[e] if e in Migration._GENDER_DATATYPE_ENUMS else "unknown" - - def name_for_interpolation(e: int) -> str: - return Migration._INTERPOLATION_TYPE_ENUMS[e] if e in Migration._INTERPOLATION_TYPE_ENUMS else "unknown" - - def name_for_migration_type(e: int) -> str: - return Migration._MIGRATION_TYPE_ENUMS[e] if e in Migration._MIGRATION_TYPE_ENUMS else "unknown" - - migration = from_file(filename) - print(f"Author: {migration.Author}") - print(f"DatavalueCount: {migration.DatavalueCount}") - print(f"DateCreated: {migration.DateCreated:%a %B %d %Y %H:%M}") - print(f"GenderDataType: {migration.GenderDataType} ({name_for_gender_datatype(migration.GenderDataType)})") - print(f"IdReference: {migration.IdReference}") - print(f"InterpolationType: {migration.InterpolationType} ({name_for_interpolation(migration.InterpolationType)})") - print(f"MigrationType: {migration.MigrationType} ({name_for_migration_type(migration.MigrationType)})") - print(f"NodeCount: {migration.NodeCount}") - print(f"NodeOffsets: {migration.NodeOffsets}") - print(f"Tool: {migration.Tool}") - print(f"Nodes: {migration.Nodes}") - - return - - -def _author() -> str: - username = "Unknown" - if system() == "Windows": - username = environ["USERNAME"] - elif "USER" in environ: - username = environ["USER"] - return username - - -def _parse_node_offsets(string: str, count: int) -> dict: - - assert len(string) == 16 * count, f"Length of node offsets string {len(string)} != 16 * node count {count}." - - offsets = {} - for index in range(count): - base = 16 * index - offset = base + 8 - offsets[int(string[base:base + 8], 16)] = int(string[offset:offset + 8], 16) - - return offsets - - -def _try_parse_date(string: str) -> datetime: - - patterns = [ - "%a %b %d %Y %H:%M:%S", - "%a %b %d %H:%M:%S %Y", - "%m/%d/%Y", - "%Y-%m-%d %H:%M:%S.%f" - ] - - for pattern in patterns: - try: - timestamp = datetime.strptime(string, pattern) - return timestamp - except ValueError: - pass - - timestamp = datetime.now() - warn(f"Could not parse date stamp '{string}', using datetime.now() ({timestamp})") - - return timestamp - - -def _value_with_default(dictionary: dict, key: str, default: object) -> object: - return dictionary[key] if key in dictionary else default - - -""" -utility functions emodpy-utils? -""" - - -def from_demog_and_param_gravity(demographics_file_path, gravity_params, id_ref, migration_type=Migration.LOCAL): - demog = Demographics.from_file(demographics_file_path) - return _from_demog_and_param_gravity(demog, gravity_params, id_ref, migration_type) - - -def _from_demog_and_param_gravity(demographics, gravity_params, id_ref, migration_type=Migration.LOCAL): - """ - Create migration files from a gravity model and an input demographics file. - """ - - def _compute_migr_prob(grav_params, home_pop, dest_pop, dist): - """ - Utility function for computing migration probabilities for gravity model. - """ - - # If home/dest node has 0 pop, assume this node is the regional work node-- no local migration allowed - if home_pop == 0 or dest_pop == 0: - return 0. - else: - num_trips = grav_params[0] * home_pop ** grav_params[1] * dest_pop ** grav_params[2] * dist ** grav_params[3] - prob_trip = np.min([1., num_trips / home_pop]) - return prob_trip - - def _compute_migr_dict(node_list, grav_params, **kwargs): - """ - Utility function for computing migration value map. - """ - - excluded_nodes = set(kwargs["exclude_nodes"]) if "exclude_nodes" in kwargs else set() - - mig = Migration() - geodesic = Geodesic.WGS84 - - for source_node in node_list: - - source_id = source_node["NodeID"] - src_lat = source_node["NodeAttributes"]["Latitude"] - src_long = source_node["NodeAttributes"]["Longitude"] - src_pop = source_node["NodeAttributes"]["InitialPopulation"] - - if source_id in excluded_nodes: - continue - - for destination_node in node_list: - - if destination_node == source_node: - continue - - dest_id = destination_node["NodeID"] - - if dest_id in excluded_nodes: - continue - - dst_lat = destination_node["NodeAttributes"]["Latitude"] - dst_long = destination_node["NodeAttributes"]["Longitude"] - dst_pop = destination_node["NodeAttributes"]["InitialPopulation"] - - distance = geodesic.Inverse(src_lat, src_long, dst_lat, dst_long, Geodesic.DISTANCE)['s12'] / 1000 # km - probability = _compute_migr_prob(grav_params, src_pop, dst_pop, distance) - - mig[source_id][dest_id] = probability - - return mig - - # load - nodes = [node.to_dict() for node in demographics.nodes] - migration = _compute_migr_dict(nodes, gravity_params) - migration.IdReference = id_ref - migration.MigrationType = migration_type - - return migration - - -# by gender, by age -_mapping_fns = { - (False, False): lambda m, i, g, a: m[i], - (False, True): lambda m, i, g, a: m[i:a], - (True, False): lambda m, i, g, a: m[i:g], - (True, True): lambda m, i, g, a: m[i:g:a] -} - -# by gender, by age -_display_fns = { - (False, False): lambda i, g, a, d, r: f"{i},{d},{r}", # id only - (False, True): lambda i, g, a, d, r: f"{i},{a},{d},{r}", # id:age - (True, False): lambda i, g, a, d, r: f"{i},{g},{d},{r}", # id:gender - (True, True): lambda i, g, a, d, r: f"{i},{g},{a},{d},{r}" # id:gender:age -} - - -def to_csv(filename: Path): - - migration = from_file(filename) - - mapping = _mapping_fns[(migration.GenderDataType == Migration.ONE_FOR_EACH_GENDER, bool(migration.AgesYears))] - display = _display_fns[(migration.GenderDataType == Migration.ONE_FOR_EACH_GENDER, bool(migration.AgesYears))] - - print(display("node", "gender", "age", "destination", "rate")) - for gender in range(1 if migration.GenderDataType == Migration.SAME_FOR_BOTH_GENDERS else 2): - for age in migration.AgesYears if migration.AgesYears else [0]: - for node in migration.Nodes: - for destination, rate in mapping(migration, node, gender, age).items(): - print(display(node, gender, age, destination, rate)) - - -def from_csv(filename: Path, - id_ref, - mig_type=None) -> Migration: - """Create migration from csv file. The file should have columns 'source' for the source node, 'destination' for the destination node, and 'rate' for the migration rate. - - Args: - filename: csv file - - Returns: - Migration object - """ - migration = Migration() - migration.IdReference = id_ref - if not mig_type: - mig_type = Migration.LOCAL - else: - migration._migrationtype = mig_type - with filename.open("r") as csvfile: - reader = csv.DictReader(csvfile) - csv_data_read = False - for row in reader: - csv_data_read = True - migration[int(row['source'])][int(row['destination'])] = float(row['rate']) - assert csv_data_read, "Csv file %s does not contain migration data." % filename - - return migration diff --git a/emod_api/utils/str_enum.py b/emod_api/utils/str_enum.py index c3d98a87..958ea951 100644 --- a/emod_api/utils/str_enum.py +++ b/emod_api/utils/str_enum.py @@ -4,3 +4,6 @@ class StrEnum(str, Enum): def __str__(self) -> str: return self.value + + def __repr__(self) -> str: + return self.value diff --git a/tests/test_campaign_module.py b/tests/test_campaign_module.py index f5974cb1..b750e6c6 100644 --- a/tests/test_campaign_module.py +++ b/tests/test_campaign_module.py @@ -1,10 +1,30 @@ import unittest import json import os +import tempfile +import warnings + from emod_api import campaign as api_campaign from emod_api import schema_to_class as s2c -from tests import manifest +CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) +OUTPUT_FOLDER = os.path.join(CURRENT_DIR, 'output') +if not os.path.isdir(OUTPUT_FOLDER): + os.mkdir(OUTPUT_FOLDER) + +SCHEMA_CANDIDATES = [ + os.path.join(CURRENT_DIR, 'package', 'common', 'schema.json'), + os.path.join(CURRENT_DIR, 'package', 'generic', 'schema.json'), + os.path.join(CURRENT_DIR, 'package', 'malaria', 'schema.json'), + os.path.join(os.path.dirname(CURRENT_DIR), '..', 'emodpy-malaria', + 'tests', 'unittests', 'current_schema', 'schema.json'), +] + +SCHEMA_PATH = None +for candidate in SCHEMA_CANDIDATES: + if os.path.isfile(candidate): + SCHEMA_PATH = candidate + break def generate_sample_campaign_event(my_campaign, schema_path): @@ -14,7 +34,8 @@ def generate_sample_campaign_event(my_campaign, schema_path): broadcast_event = s2c.get_class_with_defaults("BroadcastEvent", schema_json=schema_json) broadcast_event.Broadcast_Event = my_campaign.get_send_trigger("Test_Event", old=True) - coordinator = s2c.get_class_with_defaults("StandardInterventionDistributionEventCoordinator", schema_json=schema_json) + coordinator = s2c.get_class_with_defaults( + "StandardInterventionDistributionEventCoordinator", schema_json=schema_json) coordinator.Intervention_Config = broadcast_event event = s2c.get_class_with_defaults("CampaignEvent", schema_json=schema_json) @@ -22,81 +43,307 @@ def generate_sample_campaign_event(my_campaign, schema_path): return event -class TestCampaign(unittest.TestCase): +@unittest.skipIf(SCHEMA_PATH is None, "No schema.json found") +class TestCampaignWithSchema(unittest.TestCase): + """Tests that require a real schema file.""" + def setUp(self): self.campaign = api_campaign - self.schema_path = manifest.common_schema_path + self.campaign.set_schema(SCHEMA_PATH) - def test_reset(self): - self.campaign.set_schema(self.schema_path) - sample_event = generate_sample_campaign_event(self.campaign, manifest.common_schema_path) + def tearDown(self): + self.campaign.reset() + + def test_reset_clears_all_state(self): + sample_event = generate_sample_campaign_event(self.campaign, SCHEMA_PATH) self.campaign.add(sample_event) + self.campaign.get_recv_trigger("Evt1") + self.campaign.get_send_trigger("Evt2") + self.campaign.set_listened_node_event("NEvt1") + self.campaign.set_broadcast_node_event("NEvt2") + self.campaign.set_listened_coordinator_event("CEvt1") + self.campaign.set_broadcast_coordinator_event("CEvt2") + self.campaign.reset() + self.assertEqual(self.campaign.campaign_dict["Events"], []) + self.assertEqual(self.campaign.individual_events_listened, []) + self.assertEqual(self.campaign.individual_events_broadcast, []) + self.assertEqual(self.campaign.node_events_listened, []) + self.assertEqual(self.campaign.node_events_broadcast, []) + self.assertEqual(self.campaign.coordinator_events_listened, []) + self.assertEqual(self.campaign.coordinator_events_broadcast, []) + self.assertEqual(self.campaign.individual_builtin_events, []) + self.assertEqual(self.campaign.node_builtin_events, []) + self.assertEqual(self.campaign.coordinator_builtin_events, []) - def test_set_schema(self): - self.campaign.set_schema(self.schema_path) - self.assertEqual(self.campaign.schema_path, self.schema_path) - self.assertEqual(len(self.campaign.campaign_dict['Events']), 0) + def test_set_schema_sets_path(self): + self.assertEqual(self.campaign.schema_path, SCHEMA_PATH) - def test_get_schema(self): - self.campaign.set_schema(self.schema_path) + def test_set_schema_populates_individual_builtin_events(self): + self.assertGreater(len(self.campaign.individual_builtin_events), 0) + + def test_get_schema_returns_loaded_json(self): schema = self.campaign.get_schema() - with open(self.schema_path) as schema_file: - expected_schema = json.load(schema_file) - self.assertDictEqual(schema, expected_schema) self.assertIsNotNone(schema) + with open(SCHEMA_PATH) as f: + expected = json.load(f) + self.assertDictEqual(schema, expected) - def test_add(self): - self.campaign.set_schema(self.schema_path) - sample_event = generate_sample_campaign_event(self.campaign, manifest.common_schema_path) - self.campaign.add(sample_event, name="TestEvent") + def test_add_event(self): + sample_event = generate_sample_campaign_event(self.campaign, SCHEMA_PATH) + self.campaign.add(sample_event, note="TestNote") self.assertEqual(len(self.campaign.campaign_dict["Events"]), 1) - self.assertEqual(self.campaign.campaign_dict["Events"][0]["Event_Name"], "TestEvent") - - def test_get_trigger_list(self): - self.campaign.set_schema(self.schema_path) - trigger_list = self.campaign.get_trigger_list() - self.assertIsNotNone(trigger_list) - self.assertNotIn("Test_Event", trigger_list) + self.assertEqual(self.campaign.campaign_dict["Events"][0]["Note"], "TestNote") def test_save(self): - filename = os.path.join(manifest.output_folder, 'test_campaign.json') - self.campaign.set_schema(self.schema_path) - sample_event = generate_sample_campaign_event(self.campaign, manifest.common_schema_path) + filename = os.path.join(OUTPUT_FOLDER, 'test_campaign.json') + sample_event = generate_sample_campaign_event(self.campaign, SCHEMA_PATH) self.campaign.add(sample_event) - saved_filename = self.campaign.save(filename) - self.assertEqual(saved_filename, filename) - with open(filename, "r") as file: - data = json.load(file) - self.assertDictEqual(data, self.campaign.campaign_dict) - - def test_get_adhocs(self): - self.campaign.set_schema(self.schema_path) - sample_event = generate_sample_campaign_event(self.campaign, manifest.common_schema_path) - self.campaign.add(sample_event) - adhocs = self.campaign.get_adhocs() - self.assertDictEqual(adhocs, {'Test_Event': 'Test_Event'}) + saved = self.campaign.save(filename) + self.assertEqual(saved, filename) + with open(filename) as f: + data = json.load(f) + self.assertDictEqual(data, self.campaign.campaign_dict) + + def test_get_custom_individual_events_builtin_excluded(self): + if not self.campaign.individual_builtin_events: + self.skipTest("No individual builtin events in schema") + builtin_event = self.campaign.individual_builtin_events[0] + self.campaign.get_recv_trigger(builtin_event) + result = self.campaign.get_custom_individual_events() + self.assertNotIn(builtin_event, result) + + def test_get_custom_individual_events_broadcast_mirrors_builtin_warns(self): + if not self.campaign.individual_builtin_events: + self.skipTest("No individual builtin events in schema") + builtin_event = self.campaign.individual_builtin_events[0] + self.campaign.get_send_trigger(builtin_event) + with self.assertWarns(UserWarning): + self.campaign.get_custom_individual_events() + + +class TestEventRegistration(unittest.TestCase): + """Tests for event registration functions (no schema needed).""" + + def setUp(self): + self.campaign = api_campaign + self.campaign.reset() + + def tearDown(self): + self.campaign.reset() def test_get_recv_trigger(self): - self.campaign.set_schema(self.schema_path) - trigger = "TestTrigger1" - recv_trigger = self.campaign.get_recv_trigger(trigger, old=True) - self.assertIn(trigger, self.campaign.pubsub_signals_subbing) - self.assertEqual(recv_trigger, trigger) + result = self.campaign.get_recv_trigger("MyListenEvent") + self.assertEqual(result, "MyListenEvent") + self.assertIn("MyListenEvent", self.campaign.individual_events_listened) def test_get_send_trigger(self): - self.campaign.set_schema(self.schema_path) - trigger = "TestTrigger2" - send_trigger = self.campaign.get_send_trigger(trigger, old=True) - self.assertIn(trigger, self.campaign.pubsub_signals_pubbing) - self.assertEqual(send_trigger, trigger) - - def test_get_event(self): - self.campaign.set_schema(self.schema_path) - event = "TestEvent" - mapped_event = self.campaign.get_event(event, old=True) - self.assertEqual(mapped_event, event) + result = self.campaign.get_send_trigger("MyBroadcastEvent") + self.assertEqual(result, "MyBroadcastEvent") + self.assertIn("MyBroadcastEvent", self.campaign.individual_events_broadcast) + + def test_set_listened_node_event(self): + result = self.campaign.set_listened_node_event("NodeListenEvent") + self.assertEqual(result, "NodeListenEvent") + self.assertIn("NodeListenEvent", self.campaign.node_events_listened) + + def test_set_broadcast_node_event(self): + result = self.campaign.set_broadcast_node_event("NodeBroadcastEvent") + self.assertEqual(result, "NodeBroadcastEvent") + self.assertIn("NodeBroadcastEvent", self.campaign.node_events_broadcast) + + def test_set_listened_coordinator_event(self): + result = self.campaign.set_listened_coordinator_event("CoordListenEvent") + self.assertEqual(result, "CoordListenEvent") + self.assertIn("CoordListenEvent", self.campaign.coordinator_events_listened) + + def test_set_broadcast_coordinator_event(self): + result = self.campaign.set_broadcast_coordinator_event("CoordBroadcastEvent") + self.assertEqual(result, "CoordBroadcastEvent") + self.assertIn("CoordBroadcastEvent", self.campaign.coordinator_events_broadcast) + + +class TestValidateCustomEvents(unittest.TestCase): + """Tests for _validate_custom_events and get_custom_* functions.""" + + def setUp(self): + self.campaign = api_campaign + self.campaign.reset() + + def tearDown(self): + self.campaign.reset() + + # --- individual --- + + def test_individual_valid_pair(self): + self.campaign.get_recv_trigger("CustomEvt") + self.campaign.get_send_trigger("CustomEvt") + result = self.campaign.get_custom_individual_events() + self.assertIn("CustomEvt", result) + + def test_individual_listened_not_broadcast_raises(self): + self.campaign.get_recv_trigger("OrphanedEvt") + with self.assertRaises(ValueError): + self.campaign.get_custom_individual_events() + + def test_individual_broadcast_not_listened_warns(self): + self.campaign.get_send_trigger("UnlistenedEvt") + with self.assertWarns(UserWarning): + self.campaign.get_custom_individual_events() + + # --- node --- + + def test_node_valid_pair(self): + self.campaign.set_listened_node_event("NodeEvt") + self.campaign.set_broadcast_node_event("NodeEvt") + result = self.campaign.get_custom_node_events() + self.assertIn("NodeEvt", result) + + def test_node_listened_not_broadcast_raises(self): + self.campaign.set_listened_node_event("OrphanedNodeEvt") + with self.assertRaises(ValueError): + self.campaign.get_custom_node_events() + + def test_node_broadcast_not_listened_warns(self): + self.campaign.set_broadcast_node_event("UnlistenedNodeEvt") + with self.assertWarns(UserWarning): + self.campaign.get_custom_node_events() + + # --- coordinator --- + + def test_coordinator_valid_pair(self): + self.campaign.set_listened_coordinator_event("CoordEvt") + self.campaign.set_broadcast_coordinator_event("CoordEvt") + result = self.campaign.get_custom_coordinator_events() + self.assertIn("CoordEvt", result) + + def test_coordinator_listened_not_broadcast_raises(self): + self.campaign.set_listened_coordinator_event("OrphanedCoordEvt") + with self.assertRaises(ValueError): + self.campaign.get_custom_coordinator_events() + + def test_coordinator_broadcast_not_listened_warns(self): + self.campaign.set_broadcast_coordinator_event("UnlistenedCoordEvt") + with self.assertWarns(UserWarning): + self.campaign.get_custom_coordinator_events() + + # --- builtin filtering --- + + def test_builtin_events_excluded_from_validation(self): + builtins = ["BuiltinA", "BuiltinB"] + result = api_campaign._validate_custom_events( + listened_list=["BuiltinA", "CustomEvt"], + broadcast_list=["CustomEvt"], + builtin_list=builtins, + level="test" + ) + self.assertIn("CustomEvt", result) + self.assertNotIn("BuiltinA", result) + + def test_broadcast_mirrors_builtin_warns(self): + with self.assertWarns(UserWarning): + api_campaign._validate_custom_events( + listened_list=[], + broadcast_list=["BuiltinA"], + builtin_list=["BuiltinA"], + level="test" + ) + + def test_empty_lists_returns_empty(self): + result = api_campaign._validate_custom_events( + listened_list=[], + broadcast_list=[], + builtin_list=[], + level="test" + ) + self.assertEqual(result, []) + + +class TestFindBuiltinEvents(unittest.TestCase): + """Tests for _find_builtin_events schema search.""" + + def test_finds_builtin_key(self): + schema = { + "idmTypes": { + "ReportEventRecorder": { + "Report_Event_Recorder_Events": { + "Built-in": ["Births", "Deaths"], + "enum": ["ShouldNotUseThis"] + } + } + } + } + result = api_campaign._find_builtin_events( + schema, "ReportEventRecorder", "Report_Event_Recorder_Events") + self.assertEqual(result, ["Births", "Deaths"]) + + def test_falls_back_to_enum(self): + schema = { + "idmTypes": { + "ReportEventRecorder": { + "Report_Event_Recorder_Events": { + "enum": ["Births", "Deaths"] + } + } + } + } + result = api_campaign._find_builtin_events( + schema, "ReportEventRecorder", "Report_Event_Recorder_Events") + self.assertEqual(result, ["Births", "Deaths"]) + + def test_returns_none_when_reporter_not_found(self): + schema = {"idmTypes": {"SomethingElse": {}}} + result = api_campaign._find_builtin_events( + schema, "ReportEventRecorder", "Report_Event_Recorder_Events") + self.assertIsNone(result) + + def test_returns_none_when_reporter_has_no_events_key(self): + schema = { + "idmTypes": { + "ReportEventRecorder": { + "SomeOtherParam": 42 + } + } + } + result = api_campaign._find_builtin_events( + schema, "ReportEventRecorder", "Report_Event_Recorder_Events") + self.assertIsNone(result) + + def test_stops_recursing_children_after_reporter_key_match(self): + schema = { + "wrapper": { + "ReportEventRecorder": {"NotTheRightKey": {}}, + "nested_child": { + "ReportEventRecorder": { + "Report_Event_Recorder_Events": { + "Built-in": ["ShouldNotReach"] + } + } + } + } + } + result = api_campaign._find_builtin_events( + schema, "ReportEventRecorder", "Report_Event_Recorder_Events") + self.assertIsNone(result) + + def test_finds_deeply_nested_reporter(self): + schema = { + "level1": { + "level2": { + "level3": { + "ReportEventRecorderNode": { + "Report_Node_Event_Recorder_Events": { + "Built-in": ["NodeEvt1", "NodeEvt2"] + } + } + } + } + } + } + result = api_campaign._find_builtin_events( + schema, "ReportEventRecorderNode", "Report_Node_Event_Recorder_Events") + self.assertEqual(result, ["NodeEvt1", "NodeEvt2"]) if __name__ == '__main__': diff --git a/tests/test_migration.py b/tests/test_migration.py deleted file mode 100644 index 597843a2..00000000 --- a/tests/test_migration.py +++ /dev/null @@ -1,875 +0,0 @@ -from collections import namedtuple -from contextlib import contextmanager -from datetime import datetime -import json -import numpy as np -import os -import math -from os import close, environ -from pathlib import Path -from platform import system -from tempfile import mkstemp -import unittest -from emod_api.migration.migration import Migration, from_file, from_demog_and_param_gravity, to_csv, examine_file, from_csv -import csv -import io -from contextlib import redirect_stdout -from tests import manifest - - -class MigrationTests(unittest.TestCase): - - user = "unknown" - kenya_regional_migration = None - guinea_pig = None - - @classmethod - def setUpClass(cls): - - cls.user = environ["USERNAME"] if system() == "Windows" else environ["USER"] - filename = os.path.join(manifest.migration_folder, "Kenya_Regional_Migration_from_Census.bin") - cls.kenya_regional_migration = from_file(filename) - cls.guinea_pig = Migration() - - def test_defaults(self): - """ - Changing the defaults is a breaking change. - """ - - migration = Migration() - self.assertListEqual(migration.AgesYears, []) - self.assertEqual(migration.Author, self.user) - self.assertEqual(migration.DatavalueCount, 0) - # weekday, month, day, year, hour, minute - might fail if minute rolls over - self.assertEqual(f"{migration.DateCreated:%a %b %d %Y %H:%M}", f"{datetime.now():%a %b %d %Y %H:%M}") - self.assertEqual(migration.GenderDataType, Migration.SAME_FOR_BOTH_GENDERS) - self.assertEqual(migration.IdReference, Migration.IDREF_LEGACY) - self.assertEqual(migration.InterpolationType, Migration.PIECEWISE_CONSTANT) - self.assertEqual(migration.MigrationType, Migration.LOCAL) - self.assertEqual(migration.NodeCount, 0) - self.assertEqual(migration.NodeOffsets, {}) - self.assertEqual(migration.Tool, "emod-api") - - return - - def test_get_agesyears(self): - """AgesYears from migration [metadata] file is readable.""" - self.assertListEqual(self.kenya_regional_migration.AgesYears, - [0.0, 2.5, 7.5, 12.5, 17.5, 22.5, 27.5, 32.5, 37.5, 42.5, 47.5, 52.5, 57.5]) - return - - def test_set_agesyears(self): - """AgesYears is settable and readable.""" - ages = [0, 5, 20, 125] - self.guinea_pig.AgesYears = ages - self.assertListEqual(self.guinea_pig.AgesYears, ages) - return - - def test_get_author(self): - """Author from migration [metadata] file is readable.""" - self.assertEqual(self.kenya_regional_migration.Author, "dbridenbecker") - return - - def test_set_author(self): - """Author is settable and readable.""" - author = "Rumpelstiltskin" - self.guinea_pig.Author = author - self.assertEqual(self.guinea_pig.Author, author) - return - - def test_get_datavaluecount(self): - """DatavalueCount from migration [metadata] file is readable.""" - self.assertEqual(self.kenya_regional_migration.DatavalueCount, 7) - return - - # may not directly set DatavalueCount - derived from internal data - def test_set_datavaluecount(self): - """DatavalueCount is _not_ directly settable (derived from underlying data).""" - with self.assertRaises(AttributeError): - self.guinea_pig.DatavalueCount = 42 - return - - def test_get_datecreated(self): - """DateCreated from migration [metadata] file is readable.""" - # Mon May 2 20:30:12 2016 - self.assertEqual(self.kenya_regional_migration.DateCreated, - datetime(year=2016, month=5, day=2, hour=20, minute=30, second=12)) - return - - def test_set_datecreated(self): - """DateCreated is settable and readable.""" - timestamp = datetime.now() - self.guinea_pig.DateCreated = timestamp - self.assertEqual(self.guinea_pig.DateCreated, timestamp) - return - - def test_set_bad_datecreated(self): - """DateCreated must be a datetime object.""" - with self.assertRaises(RuntimeError) as context: - self.guinea_pig.DateCreated = "yesterday" - self.assertTrue("DateCreated must be a datetime" in str(context.exception)) - return - - def test_get_genderdatatype(self): - """GenderDataType from migration [metadata] file is readable.""" - self.assertEqual(self.kenya_regional_migration.GenderDataType, Migration.ONE_FOR_EACH_GENDER) - return - - def test_set_genderdatatype(self): - """GenderDataType is settable and readable.""" - self.guinea_pig.GenderDataType = 1 - self.assertEqual(self.guinea_pig.GenderDataType, Migration.ONE_FOR_EACH_GENDER) - self.guinea_pig.GenderDataType = "SAME_FOR_BOTH_GENDERS" - self.assertEqual(self.guinea_pig.GenderDataType, Migration.SAME_FOR_BOTH_GENDERS) - return - - def test_set_bad_genderdatatype_one(self): - """GenderDataType must be in the range Migration.SAME_FOR_BOTH_GENDERS (0) ... Migration.ONE_FOR_EACH_GENDER (1).""" - with self.assertRaises(RuntimeError) as context: - self.guinea_pig.GenderDataType = -1 - self.assertTrue("Unknown gender data type, -1" in str(context.exception)) - return - - def test_set_bad_genderdatatype_two(self): - """GenderDataType must be in the range Migration.SAME_FOR_BOTH_GENDERS (0) ... Migration.ONE_FOR_EACH_GENDER (1).""" - with self.assertRaises(RuntimeError) as context: - self.guinea_pig.GenderDataType = 13 - self.assertTrue("Unknown gender data type, 13" in str(context.exception)) - return - - def test_set_bad_genderdatatype_three(self): - """GenderDataType enum must be 'SAME_FOR_BOTH_GENDERS' or 'ONE_FOR_EACH_GENDER'.""" - with self.assertRaises(RuntimeError) as context: - self.guinea_pig.GenderDataType = "GENDER_NEUTRAL" - self.assertTrue("Unknown gender data type, GENDER_NEUTRAL" in str(context.exception)) - return - - def test_get_idreference(self): - """IdReference from migration [metadata] file is readable.""" - self.assertEqual(self.kenya_regional_migration.IdReference, "0") - return - - def test_set_idreference(self): - """IdReference is settable and readable.""" - reference = "Cool Custom Test IdReference" - self.guinea_pig.IdReference = reference - self.assertEqual(self.guinea_pig.IdReference, reference) - return - - def test_get_interpolationtype(self): - """InterpolationType from migration [metadata] file is readable.""" - self.assertEqual(self.kenya_regional_migration.InterpolationType, Migration.PIECEWISE_CONSTANT) - return - - def test_set_interpolationtype(self): - """InterpolationType is settable from constant or string.""" - # should be starting at PIECEWISE_CONSTANT (default value) - self.guinea_pig.InterpolationType = Migration.LINEAR_INTERPOLATION - self.assertEqual(self.guinea_pig.InterpolationType, Migration.LINEAR_INTERPOLATION) - self.guinea_pig.InterpolationType = "PIECEWISE_CONSTANT" - self.assertEqual(self.guinea_pig.InterpolationType, Migration.PIECEWISE_CONSTANT) - self.guinea_pig.InterpolationType = "LINEAR_INTERPOLATION" - self.assertEqual(self.guinea_pig.InterpolationType, Migration.LINEAR_INTERPOLATION) - self.guinea_pig.InterpolationType = Migration.PIECEWISE_CONSTANT - self.assertEqual(self.guinea_pig.InterpolationType, Migration.PIECEWISE_CONSTANT) - return - - def test_set_bad_interpolationtype_one(self): - """InterpolationType must be LINEAR_INTERPOLATION (0) or PIECEWISE_CONSTANT (1).""" - with self.assertRaises(RuntimeError) as context: - self.guinea_pig.InterpolationType = -1 - self.assertTrue("Unknown interpolation type, -1" in str(context.exception)) - return - - def test_set_bad_interpolationtype_two(self): - """InterpolationType must be LINEAR_INTERPOLATION (0) or PIECEWISE_CONSTANT (1).""" - with self.assertRaises(RuntimeError) as context: - self.guinea_pig.InterpolationType = 13 - self.assertTrue("Unknown interpolation type, 13" in str(context.exception)) - return - - def test_set_bad_interpolationtype_three(self): - """InterpolationType must be 'LINEAR_INTERPOLATION' or 'PIECEWISE_CONSTANT'.""" - with self.assertRaises(RuntimeError) as context: - self.guinea_pig.InterpolationType = "Complex Integration" - self.assertTrue("Unknown interpolation type, Complex Integration" in str(context.exception)) - return - - def test_get_migrationtype(self): - """MigrationType from migration [metadata] file is readable.""" - self.assertEqual(self.kenya_regional_migration.MigrationType, Migration.REGIONAL) - return - - def test_set_migrationtype(self): - """MigrationType is settable from constant or string.""" - self.guinea_pig.MigrationType = Migration.AIR - self.assertEqual(self.guinea_pig.MigrationType, Migration.AIR) - self.guinea_pig.MigrationType = "REGIONAL_MIGRATION" - self.assertEqual(self.guinea_pig.MigrationType, Migration.REGIONAL) - self.guinea_pig.MigrationType = 1 - self.assertEqual(self.guinea_pig.MigrationType, Migration.LOCAL) - self.guinea_pig.MigrationType = Migration.SEA - self.assertEqual(self.guinea_pig.MigrationType, Migration.SEA) - self.guinea_pig.MigrationType = "FAMILY_MIGRATION" - self.assertEqual(self.guinea_pig.MigrationType, Migration.FAMILY) - self.guinea_pig.MigrationType = 6 - self.assertEqual(self.guinea_pig.MigrationType, Migration.INTERVENTION) - return - - def test_set_bad_migrationtype_one(self): - """MigrationType must be one of LOCAL (1), AIR (2), REGIONAL (3), SEA (4), FAMILY (5), or INTERVENTION (6).""" - with self.assertRaises(RuntimeError) as context: - self.guinea_pig.MigrationType = -1 - self.assertTrue("Unknown migration type, -1" in str(context.exception)) - return - - def test_set_bad_migrationtype_two(self): - """MigrationType must be one of LOCAL (1), AIR (2), REGIONAL (3), SEA (4), FAMILY (5), or INTERVENTION (6).""" - with self.assertRaises(RuntimeError) as context: - self.guinea_pig.MigrationType = 7 - self.assertTrue("Unknown migration type, 7" in str(context.exception)) - return - - def test_set_bad_migrationtype_three(self): - """ - MigrationType must be one of 'LOCAL_MIGRATION', 'AIR_MIGRATION', 'REGIONAL_MIGRATION', 'SEA_MIGRATION', - 'FAMILY_MIGRATION', or 'INTERVENTION_MIGRATION'. - """ - with self.assertRaises(RuntimeError) as context: - self.guinea_pig.MigrationType = "DISPLACED_MIGRATION" - self.assertTrue("Unknown migration type, DISPLACED_MIGRATION" in str(context.exception)) - return - - def test_get_nodecount(self): - """NodeCount from migration file is correct.""" - self.assertEqual(self.kenya_regional_migration.NodeCount, 8) - return - - def test_set_nodecount(self): - """NodeCount cannot be set directly, derives from underlying data.""" - with self.assertRaises(AttributeError): - self.kenya_regional_migration.NodeCount = 42 - return - - def test_get_nodeoffsets(self): - """NodeOffsets from migration file are correct.""" - offsets = { - 1: int("00000000", 16), - 2: int("00000054", 16), - 3: int("000000A8", 16), - 4: int("000000FC", 16), - 5: int("00000150", 16), - 6: int("000001A4", 16), - 7: int("000001F8", 16), - 8: int("0000024C", 16) - } - self.assertDictEqual(self.kenya_regional_migration.NodeOffsets, offsets) - return - - def test_set_nodeoffsets(self): - """NodeOffsets cannot be set directly, derives from underlying data.""" - with self.assertRaises(AttributeError): - self.guinea_pig.NodeOffsets = {0: 0, 1: 12, 2: 24} - return - - def test_get_tool(self): - """Tool from migration [metadata] file is readable.""" - self.assertEqual(self.kenya_regional_migration.Tool, "convert_json_to_bin.py") - return - - def test_set_tool(self): - """Tool is settable and readable.""" - self.guinea_pig.Tool = Path(__file__).name - self.assertEqual(self.guinea_pig.Tool, Path(__file__).name) - return - - def test_set_rates(self): - """Migration rates can be set and read back for each scenario.""" - - # Case 1 - no gender or age differentiation - key is node id - migration = Migration() - migration[20201202][19991231] = 0.125 - self.assertEqual(migration._layers[0][20201202][19991231], 0.125) - - # Case 2 - age buckets w/out gender differentiation - key is node id:age - migration = Migration() - migration.AgesYears = [5, 20] - migration[20201202:10][19991231] = 0.125 - migration[20201202, 5][19690720] = 0.25 - self.assertEqual(migration._layers[1][20201202][19991231], 0.125) - self.assertEqual(migration._layers[0][20201202][19690720], 0.25) - - # Case 3 - by gender w/out age differentiation - key is node id:gender - migration = Migration() - migration.GenderDataType = Migration.ONE_FOR_EACH_GENDER - migration[20201202:Migration.FEMALE][19991231] = 0.125 - migration[20201202, Migration.MALE][19690720] = 0.25 - self.assertEqual(migration._layers[1][20201202][19991231], 0.125) - self.assertEqual(migration._layers[0][20201202][19690720], 0.25) - - # Case 4 - both gender and age buckets - key is node id:gender:age - migration = Migration() - migration.GenderDataType = Migration.ONE_FOR_EACH_GENDER - migration.AgesYears = [5, 20] - migration[20201202:Migration.MALE:10][19991231] = 0.125 - migration[20201202, Migration.FEMALE, 25][19690720] = 0.25 - self.assertEqual(migration._layers[1][20201202][19991231], 0.125) - self.assertEqual(migration._layers[3][20201202][19690720], 0.25) - - return - - def test_no_connection(self): - """Node 13 is not in the rate map, should return 0.""" - migration = self._three_square() - self.assertEqual(migration[1][13], 0.0) - - return - - def test_non_integral_node_id(self): - """"Node IDs must be integers.""" - migration = self._three_square() - with self.assertRaises(RuntimeError): - migration[3.14159][5] = 0.125 - return - - def test_non_numeric_node_id(self): - """Node IDs must be integers.""" - migration = self._three_square() - with self.assertRaises(RuntimeError): - migration["three"][5] = 0.125 - return - - def test_warning_on_age_dependency(self): - """Should get warning if changing age dependency after data has been recorded.""" - migration = Migration() - migration[0][1] = 0.0625 - migration[0][2] = 0.0625 - migration[0][3] = 0.03125 - with self.assertWarns(UserWarning): - migration.AgesYears = [0, 5, 10, 15, 20, 125] - return - - def test_warning_on_gender_dependency(self): - """Should get warning if changing gender dependency after data has been recorded.""" - migration = Migration() - migration[0][1] = 0.0625 - migration[0][2] = 0.0625 - migration[0][3] = 0.03125 - with self.assertWarns(UserWarning): - migration.GenderDataType = Migration.ONE_FOR_EACH_GENDER - return - - def test_nodes_property(self): - """Nodes property should return sorted list of all node IDs.""" - migration = self._three_square() - self.assertListEqual(migration.Nodes, list(range(1, 10))) - return - - def test_age_dependent_indexing_raises(self): - """Age dependent migration must have ID:AGE indexing.""" - migration = Migration() - migration.AgesYears = [0, 5, 10, 15, 20, 125] - with self.assertRaises(RuntimeError): - migration[0][1] = 0.0625 - return - - def test_gender_dependent_indexing_raises_one(self): - """Gender dependent migration must have ID:GENDER indexing.""" - migration = Migration() - migration.GenderDataType = Migration.ONE_FOR_EACH_GENDER - with self.assertRaises(RuntimeError): - migration[0][1] = 0.0625 - return - - def test_gender_dependent_indexing_raises_two(self): - """Gender dependent migration must have ID:GENDER indexing.""" - migration = Migration() - migration.GenderDataType = Migration.ONE_FOR_EACH_GENDER - with self.assertRaises(RuntimeError): - migration["zero":Migration.MALE][1] = 0.0625 - return - - def test_gender_dependent_indexing_raises_three(self): - """Gender dependent migration must have ID:GENDER indexing.""" - migration = Migration() - migration.GenderDataType = Migration.ONE_FOR_EACH_GENDER - with self.assertRaises(RuntimeError): - migration[0:13][1] = 0.0625 - return - - def test_age_and_gender_dependent_indexing_raises_one(self): - """Gender and age dependent migration must have ID:GENDER:AGE indexing.""" - migration = Migration() - migration.AgesYears = [0, 5, 125] - migration.GenderDataType = Migration.ONE_FOR_EACH_GENDER - with self.assertRaises(RuntimeError): - migration[0][1] = 0.0625 - return - - def test_age_and_gender_dependent_indexing_raises_two(self): - """Gender and age dependent migration must have ID:GENDER:AGE indexing.""" - migration = Migration() - migration.AgesYears = [0, 5, 125] - migration.GenderDataType = Migration.ONE_FOR_EACH_GENDER - with self.assertRaises(RuntimeError): - migration["zero":Migration.MALE:25][1] = 0.0625 - return - - def test_age_and_gender_dependent_indexing_raises_three(self): - """Gender and age dependent migration must have ID:GENDER:AGE indexing.""" - migration = Migration() - migration.AgesYears = [0, 5, 125] - migration.GenderDataType = Migration.ONE_FOR_EACH_GENDER - with self.assertRaises(RuntimeError): - migration[0:13:25][1] = 0.0625 - return - - @staticmethod - def _three_square(): - """Create 3x3 grid with migration to N/S/E/W neighbors (not diagonal and not wrapped around).""" - migration = Migration() - Link = namedtuple("Link", ["source", "destination", "rate"]) - rates = [ - Link(1, 2, 0.12), # NW -> N - Link(1, 4, 0.14), # NW -> W - Link(2, 1, 0.21), # N -> NW - Link(2, 3, 0.23), # N -> NE - Link(2, 5, 0.25), # N -> center - Link(3, 2, 0.32), # NE -> N - Link(3, 6, 0.36), # NE -> E - Link(4, 1, 0.41), # W -> NW - Link(4, 5, 0.45), # W -> center - Link(4, 7, 0.47), # W -> SW - Link(5, 2, 0.52), # center -> N - Link(5, 4, 0.54), # center -> W - Link(5, 6, 0.56), # center -> E - Link(5, 8, 0.58), # center -> S - Link(6, 3, 0.63), # E -> NE - Link(6, 5, 0.65), # E -> center - Link(6, 9, 0.69), # E -> SE - Link(7, 5, 0.75), # SW -> W - Link(7, 8, 0.78), # SW -> S - Link(8, 5, 0.85), # S -> center - Link(8, 7, 0.87), # S -> SW - Link(8, 9, 0.89), # S -> SE - Link(9, 6, 0.96), # SE -> E - Link(9, 8, 0.98) # SE -> S - ] - - for link in rates: - migration[link.source][link.destination] = link.rate - migration.MigrationType = Migration.REGIONAL - - return migration - - @staticmethod - @contextmanager - def _temp_filename(prefix: str = "mig-", suffix: str = ".bin"): - """Create temporary directory for migration file, return filename and matching metadata file name.""" - handle, filename = mkstemp(prefix=prefix, suffix=suffix) - close(handle) - filename = Path(filename).absolute() - metafile = filename.parent / (filename.name + ".json") - try: - yield filename, metafile - finally: - filename.unlink() if filename.exists() else None - metafile.unlink() if metafile.exists() else None - - return - - def test_to_file(self): - """Write migration to file, check metadata and spot check binary data.""" - migration = self._three_square() - - with self._temp_filename() as (filename, metafile): - migration.to_file(filename) - - metafile = filename.parent / (filename.name + ".json") - - self.assertTrue(metafile.exists()) - with metafile.open("r") as metafile_handle: - metadata = json.load(metafile_handle) - self.assertEqual(metadata["Metadata"]["Tool"], "emod-api") - self.assertEqual(metadata["Metadata"]["IdReference"], "Legacy") - self.assertEqual(metadata["Metadata"]["MigrationType"], "REGIONAL_MIGRATION") - self.assertEqual(metadata["Metadata"]["NodeCount"], 9) - self.assertEqual(metadata["Metadata"]["DatavalueCount"], 4) - self.assertEqual(metadata["Metadata"]["GenderDataType"], "SAME_FOR_BOTH_GENDERS") - - self.assertTrue(filename.exists()) - - expected_size = 9 * 4 * 12 # #nodes x #links x #bytes (4 + 8) - self.assertEqual(filename.stat().st_size, expected_size) - - with filename.open("rb") as filename_handle: - destinations = np.fromfile(filename_handle, dtype=np.uint32, count=4) - values = np.fromfile(filename_handle, dtype=np.float64, count=4) - self.assertEqual(destinations[0], 2) # first destination is node 2 (sorted in rate ascending order) - self.assertEqual(values[1], 0.14) # rate between nodes 1 and 4 is 0.14 - - return - - def test_to_csv(self): - filename = os.path.join(manifest.migration_folder, "Seattle_30arcsec_local_migration.bin") - - f = io.StringIO() - with redirect_stdout(f): - to_csv(filename) - out = f.getvalue() - - data = io.StringIO(out) - csv_obj = csv.reader(data, dialect='unix') - headers = next(csv_obj, None) - self.assertEqual(len(headers), 3) - num_row = 0 - for csv_row in csv_obj: - for row_val in csv_row: - self.assertGreater(len(row_val), 0) - num_row += 1 - self.assertEqual(num_row, 515) - - def test_examine_file(self): - filename = os.path.join(manifest.migration_folder, "Seattle_30arcsec_local_migration.bin") - output = os.path.join(manifest.output_folder, "seattle_csv.csv") - - expected_output = ["Author:", "DatavalueCount:", "DateCreated:", "GenderDataType:", "IdReference:", - "InterpolationType:", "MigrationType:", "NodeCount:", "NodeOffsets:", "Tool:", "Nodes:"] - - f = io.StringIO() - with redirect_stdout(f): - examine_file(filename) - output = f.getvalue() - - for expected in expected_output: - self.assertTrue(expected in output) - - def test_to_file_age_dependent(self): - """Write migration file with age dependent rates. Check for 'AgesYears' in metadata file.""" - migration = Migration() - migration.AgesYears = [0, 5, 10, 15, 20, 125] - for age in migration.AgesYears: - migration[0:age][1] = 0.125 / age if age else 0 - migration[1:age][0] = 0.125 / age if age else 0 - with self._temp_filename(prefix="age-mig-") as (filename, metafile): - migration.to_file(filename) - with metafile.open("r") as file: - metadata = json.load(file) - self.assertListEqual(metadata["Metadata"]["AgesYears"], migration.AgesYears) - - return - - def test_limited_datavalues(self): - """Test limiting datavalues when writing file to disk.""" - # migration = self._three_square() - - # Create SIZExSIZE grid migration data - SIZE = 5 - migration = Migration() - for source in range(0, SIZE * SIZE): - source_x = source % SIZE - source_y = source // SIZE - for destination in range(0, SIZE * SIZE): - if destination == source: - continue - destination_x = destination % SIZE - destination_y = destination // SIZE - distance = abs(destination_x - source_x) + abs(destination_y - source_y) # Manhattan distance - rate = 0.1 / distance - migration[source + 1][destination + 1] = rate # IDs go from 1..(SIZE*SIZE) - - # Ensure setup - for node in migration.Nodes: - self.assertTrue(len(migration[node]) == (SIZE * SIZE - 1)) # Nodes have no entry for self - - LIMIT = 3 - with self._temp_filename() as (filename, metafile): - migration.to_file(filename, value_limit=LIMIT) - with metafile.open("r") as file: - metadata = json.load(file) - self.assertEqual(metadata["Metadata"]["DatavalueCount"], LIMIT) - actual = from_file(filename) - self.assertEqual(actual.DatavalueCount, LIMIT) - for node in actual.Nodes: - self.assertTrue(len(actual[node]) == LIMIT) - - return - - # TODO - test that when values are truncated in to_file(), see above, the saved values are - # 1. the largest N from the values in the Migration object and - # 2. sorted from smallest to largest - - def test_missing_source_nodes_warning(self): - """Test warning when a source node in one layer has no entries in another layer.""" - migration = Migration() - migration.GenderDataType = Migration.ONE_FOR_EACH_GENDER - source = 4 - for destination in range(9): - if destination != source: - migration[4:Migration.MALE][destination] = 0.0625 - # note, not setting values in Migration.FEMALE layer - with self._temp_filename(prefix="gender-mig-") as (filename, metafile): - with self.assertWarns(UserWarning): - migration.to_file(filename) - - return - - def test_raise_from_file_missing_binary(self): - """Test exception for missing binary file in from_file() call.""" - migration = self._three_square() - with self._temp_filename() as (filename, metafile): - migration.to_file(filename) - filename.unlink() - with self.assertRaises(RuntimeError): - _ = from_file(filename) - - return - - def test_raise_from_file_missing_metadata(self): - """Test exception for missing metadata file in from_file() call.""" - migration = self._three_square() - with self._temp_filename() as (filename, metafile): - migration.to_file(filename) - metafile.unlink() - with self.assertRaises(RuntimeError): - _ = from_file(filename) - - return - - def test_raise_from_file_bad_nodeoffsets(self): - """Test exception for NodeOffsets not matching expected size for NodeCount and DatavalueCount.""" - migration = self._three_square() - with self._temp_filename() as (filename, metafile): - migration.to_file(filename) - with metafile.open("r") as file: - metadata = json.load(file) - metadata["NodeOffsets"] = "0000004200000000" - with metafile.open("w") as file: - json.dump(metadata, file) - with self.assertRaises(RuntimeError): - _ = from_file(filename) - - return - - def test_warn_on_datecreated_parsing(self): - """Test warning when DateCreated field cannot be parsed.""" - migration = self._three_square() - with self._temp_filename() as (filename, metafile): - migration.to_file(filename) - with metafile.open("r") as file: - metadata = json.load(file) - metadata["Metadata"]["DateCreated"] = "Thursday December 10th 2020" - with metafile.open("w") as file: - json.dump(metadata, file) - with self.assertWarns(UserWarning): - _ = from_file(filename) - - return - - def test_from_file(self): - """Test happy path for from_file().""" - local = from_file(os.path.join(manifest.migration_folder, "Seattle_30arcsec_local_migration.bin")) - self.assertEqual(local.Author, "jsteinkraus") - self.assertEqual(local.DateCreated, datetime(year=2011, month=9, day=26, hour=9, minute=59, second=35)) - self.assertEqual(local.DatavalueCount, 8) - self.assertEqual(local.IdReference, "Legacy") - self.assertEqual(local.NodeCount, 124) - self.assertEqual(local.Tool, "createmigrationheader.py") - self.assertEqual(local.MigrationType, 1) - - regional = from_file(os.path.join(manifest.migration_folder, "Seattle_30arcsec_regional_migration.bin")) - self.assertEqual(regional.Author, "jsteinkraus") - self.assertEqual(regional.DateCreated, datetime(year=2011, month=9, day=26, hour=9, minute=59, second=35)) - self.assertEqual(regional.DatavalueCount, 30) - self.assertEqual(regional.IdReference, "Legacy") - self.assertEqual(regional.NodeCount, 124) - self.assertEqual(regional.Tool, "createmigrationheader.py") - self.assertEqual(regional.MigrationType, 3) - - return - - def test_to_and_from_file_with_nonstandard_metadata_filename(self): - """Use options to write and subsequently read from a file with a non-standard metadata filename""" - memory = self._three_square() - with self._temp_filename() as (filename, _): - with self._temp_filename(prefix="meta-", suffix=".json") as (metafile, _): - memory.to_file(filename, metafile=metafile) - self.assertTrue(filename.exists()) - self.assertTrue(metafile.exists()) - disk = from_file(filename, metafile=metafile) - self.assertEqual(disk.Author, memory.Author) - # self.assertEqual(disk.DateCreated, memory.DateCreated) # memory includes microseconds != 0 - self.assertEqual(f"{disk.DateCreated:%a %b %d %Y %H:%M}", f"{memory.DateCreated:%a %b %d %Y %H:%M}") - self.assertEqual(disk.DatavalueCount, memory.DatavalueCount) - self.assertEqual(disk.IdReference, memory.IdReference) - self.assertEqual(disk.NodeCount, memory.NodeCount) - self.assertEqual(disk.Tool, memory.Tool) - for node in memory.Nodes: - self.assertDictEqual(disk[node], memory[node]) - - return - - def test_from_demog_and_param_gravity(self): - demographics_file = os.path.join(manifest.demo_folder, 'Seattle_30arcsec_demographics.json') - migration = from_demog_and_param_gravity(demographics_file, gravity_params=[0.1, 0.2, 0.3, 0.4], - id_ref='from_demog_and_param_gravity_test', - migration_type=Migration.LOCAL) - self.assertEqual(migration.NodeCount, 124) - self.assertEqual(migration.DatavalueCount, 123) - self.assertEqual(migration.IdReference, "from_demog_and_param_gravity_test") - self.assertEqual(migration.MigrationType, Migration.LOCAL) - - def test_from_demog_and_param_gravity_distance(self): - def get_distance(lat1, lon1, lat2, lon2): - r = 6371 - d_lat = deg2rad(lat2 - lat1) - - d_lon = deg2rad(lon2 - lon1) - a = math.sin(d_lat / 2) * math.sin(d_lat / 2) + \ - math.cos(deg2rad(lat1)) * math.cos(deg2rad(lat2)) * \ - math.sin(d_lon / 2) * math.sin(d_lon / 2) - c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a)) - d = r * c - return d - - def deg2rad(deg): - return deg * (math.pi / 180) - - def verify_distance(migration_rate_list, node_locations): - for rate_val in migration_rate_list[1:-1]: - rate_val = rate_val.split(',') - distance = 1 / float(rate_val[-1]) - src = int(rate_val[0]) - des = int(rate_val[1]) - lat1, lon1 = node_locations[src - 1] - lat2, lon2 = node_locations[des - 1] - calculated_distance = get_distance(lat1, lon1, lat2, lon2) - self.assertAlmostEqual(distance, calculated_distance, delta=2) # km - - id_ref = 'from_demog_and_param_gravity_distance' - demographics_file = os.path.join(manifest.demo_folder, 'gravity_webservice_vs_local_distance_only.json') - locations = [[0, 0], - [0, 1], - [0, 2], - [1, 0], - [1, 1], - [1, 2], - [2, 0], - [2, 1], - [2, 2]] - - # This demographics file is generated with the following code: - # from emod_api.demographics.Node import Node - # import emod_api.demographics.Demographics as Demographics - # - # nodes = [] - # # Add nodes to demographics - # n_nodes = 9 - # for idx in range(n_nodes): - # nodes.append(Node(forced_id=idx + 1, pop=1, lat=locations[idx][0], - # lon=locations[idx][1])) - # - # demog = Demographics.Demographics(nodes=nodes, idref=id_ref) - # demog.SetDefaultProperties() - # demog.generate_file(demographics_file) - - migration_local = from_demog_and_param_gravity(demographics_file, gravity_params=[1, 1, 1, -1], - id_ref=id_ref, migration_type=Migration.REGIONAL) - - migration_local_file = Path(os.path.join(manifest.output_folder, 'gravity_distance.bin')) - migration_local.to_file(migration_local_file) - - f = io.StringIO() - with redirect_stdout(f): - to_csv(migration_local_file) - migration_rate = f.getvalue().split("\n") - - verify_distance(migration_rate, locations) - - def test_from_demog_and_param_gravity_with_reference(self): - demographics_file = os.path.join(manifest.demo_folder, 'Seattle_30arcsec_demographics.json') - - migration = from_demog_and_param_gravity(demographics_file, gravity_params=[0.1, 0.2, 0.3, 0.4], - id_ref='from_demog_and_param_gravity_test', - migration_type=Migration.LOCAL) - - migration_file = Path(os.path.join(manifest.output_folder, 'test_from_demog_and_param_gravity_with_reference.bin')) - migration.to_file(migration_file) - - reference_file = Path(os.path.join(manifest.migration_folder, 'migration_gravity_model_reference.bin')) - self.compare_migration_file_to_reference(migration_file, reference_file, exact_compare=False) - - def test_from_csv(self): - temp = {'source': [1, 2, 5], - 'destination': [2, 3, 4], - 'rate': [0.1, 0.2, 0.3]} - - csv_file = Path(os.path.join(manifest.output_folder, "test_migration.csv")) - with open(csv_file, "w") as fid01: - csv_obj = csv.writer(fid01, dialect='unix', quoting=csv.QUOTE_MINIMAL) - header_vals = list(temp.keys()) - csv_obj.writerow(header_vals) - for row_idx in range(len(temp[header_vals[0]])): - csv_obj.writerow([temp[h_val][row_idx] for h_val in header_vals]) - - migration = from_csv(csv_file, id_ref="testing") - - migration_file = os.path.join(manifest.output_folder, "test_migration.bin") - migration.to_file(migration_file) - migration_from_bin = from_file(migration_file) - - for source, destination, rate in zip(temp['source'], temp['destination'], temp['rate']): - self.assertEqual(migration[source][destination], rate) - self.assertEqual(migration_from_bin[source][destination], rate) - - def test_from_csv_empty_file(self): - with self.assertRaises(AssertionError): - from_csv(Path(os.path.join(manifest.migration_folder, "test_migration_without_content.csv")), id_ref="testing") - - def compare_migration_file_to_reference(self, migration_file, migration_reference_file, exact_compare=True): - self.assertTrue(migration_file.is_file()) - self.assertTrue(migration_reference_file.is_file()) - f = io.StringIO() - with redirect_stdout(f): - to_csv(migration_file) - migration_rate = f.getvalue().split("\n") - f = io.StringIO() - with redirect_stdout(f): - to_csv(migration_reference_file) - migration_rate_reference = f.getvalue().split("\n") - if exact_compare: - self.assertListEqual(migration_rate, migration_rate_reference) - else: - # create numpy array [[src, dst, rate]], first and last row does not contain numbers - migration_rate_from_file = np.array( - [[float(i) for i in r.split(",")] for r in migration_rate[1:-1]]) - reference_rate_from_file = np.array( - [[float(i) for i in r.split(",")] for r in migration_rate_reference[1:-1]]) - - msg = "The migration rates calculated locally and by the webservice are not equal." - np.testing.assert_array_almost_equal(migration_rate_from_file, reference_rate_from_file, - decimal=6, err_msg=msg) - - # compare .json file - migration_json_file = migration_file.parent / (migration_file.name + '.json') - reference_json_file = migration_reference_file.parent / (migration_reference_file.name + '.json') - - self.assertTrue(migration_json_file.is_file()) - self.assertTrue(reference_json_file.is_file()) - with migration_json_file.open('r') as migration_json_f: - migration_json = json.load(migration_json_f) - - with reference_json_file.open('r') as reference_json_f: - reference_json = json.load(reference_json_f) - - migration_json["Metadata"].pop("Author") - migration_json["Metadata"].pop("DateCreated") - - reference_json["Metadata"].pop("Author") - reference_json["Metadata"].pop("DateCreated") - self.maxDiff = None - self.assertDictEqual(migration_json, reference_json) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/unittests/test_migration_imports.py b/tests/unittests/test_migration_imports.py deleted file mode 100644 index 8f63d2c4..00000000 --- a/tests/unittests/test_migration_imports.py +++ /dev/null @@ -1,27 +0,0 @@ -import unittest - - -class EmodapiMigrationImportTest(unittest.TestCase): - def setUp(self) -> None: - self.expected_items = None - self.found_items = None - - def verify_expected_items_present(self, namespace): - self.found_items = dir(namespace) - for item in self.expected_items: - self.assertIn( - item, - self.found_items - ) - - def test_migration_migration_import(self): - self.expected_items = [ - 'Layer', - 'Migration', - 'from_file', - 'examine_file', - 'from_demog_and_param_gravity', - 'to_csv' - ] - import emod_api.migration.migration as migration - self.verify_expected_items_present(namespace=migration) From 848b513d4ede14ca85b22baf207cba2e820976e7 Mon Sep 17 00:00:00 2001 From: Svetlana Titova Date: Tue, 26 May 2026 17:36:22 -0700 Subject: [PATCH 2/6] aligning with main --- emod_api/demographics/demographics.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/emod_api/demographics/demographics.py b/emod_api/demographics/demographics.py index d23e5c84..fa86ef7b 100644 --- a/emod_api/demographics/demographics.py +++ b/emod_api/demographics/demographics.py @@ -30,14 +30,10 @@ def __init__(self, nodes: list[Node], idref: str = None, default_node: Node = No """ super().__init__(nodes=nodes, idref=idref, default_node=default_node) - # set some standard EMOD defaults. set_defaults should always be True unless reading from a demographics file, - # as False allows setting default_node.node_attributes exactly as they are in the file. Loading via - # Demographics.from_file() is deprecated, see below. + # No current default settings if set_defaults: - self.default_node.node_attributes.airport = 1 - self.default_node.node_attributes.seaport = 1 - self.default_node.node_attributes.region = 1 - self.default_node.node_attributes.altitude = 0 + pass + def to_file(self, path: Union[str, Path] = "demographics.json", indent: int = 4) -> None: """ From a8e38d1a3ed45296edd74b5493f6a06be5c02951 Mon Sep 17 00:00:00 2001 From: Svetlana Titova Date: Tue, 26 May 2026 17:54:46 -0700 Subject: [PATCH 3/6] adding missing tests --- tests/test_campaign_module.py | 90 +++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/tests/test_campaign_module.py b/tests/test_campaign_module.py index b750e6c6..a6e74c16 100644 --- a/tests/test_campaign_module.py +++ b/tests/test_campaign_module.py @@ -6,6 +6,7 @@ from emod_api import campaign as api_campaign from emod_api import schema_to_class as s2c +from emod_api.utils.str_enum import StrEnum CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) OUTPUT_FOLDER = os.path.join(CURRENT_DIR, 'output') @@ -122,6 +123,22 @@ def test_get_custom_individual_events_broadcast_mirrors_builtin_warns(self): with self.assertWarns(UserWarning): self.campaign.get_custom_individual_events() + def test_set_schema_populates_node_builtin_events(self): + if not self.campaign.node_builtin_events: + self.skipTest("No node builtin events in schema") + self.assertGreater(len(self.campaign.node_builtin_events), 0) + + def test_set_schema_populates_coordinator_builtin_events(self): + if not self.campaign.coordinator_builtin_events: + self.skipTest("No coordinator builtin events in schema") + self.assertGreater(len(self.campaign.coordinator_builtin_events), 0) + + def test_add_event_without_note(self): + sample_event = generate_sample_campaign_event(self.campaign, SCHEMA_PATH) + self.campaign.add(sample_event) + self.assertEqual(len(self.campaign.campaign_dict["Events"]), 1) + self.assertNotIn("Note", self.campaign.campaign_dict["Events"][0]) + class TestEventRegistration(unittest.TestCase): """Tests for event registration functions (no schema needed).""" @@ -163,6 +180,54 @@ def test_set_broadcast_coordinator_event(self): self.assertEqual(result, "CoordBroadcastEvent") self.assertIn("CoordBroadcastEvent", self.campaign.coordinator_events_broadcast) + def test_get_recv_trigger_none_raises(self): + with self.assertRaises(ValueError): + self.campaign.get_recv_trigger(None) + + def test_get_recv_trigger_empty_raises(self): + with self.assertRaises(ValueError): + self.campaign.get_recv_trigger("") + + def test_get_send_trigger_none_raises(self): + with self.assertRaises(ValueError): + self.campaign.get_send_trigger(None) + + def test_get_send_trigger_empty_raises(self): + with self.assertRaises(ValueError): + self.campaign.get_send_trigger("") + + def test_set_listened_node_event_none_raises(self): + with self.assertRaises(ValueError): + self.campaign.set_listened_node_event(None) + + def test_set_listened_node_event_empty_raises(self): + with self.assertRaises(ValueError): + self.campaign.set_listened_node_event("") + + def test_set_broadcast_node_event_none_raises(self): + with self.assertRaises(ValueError): + self.campaign.set_broadcast_node_event(None) + + def test_set_broadcast_node_event_empty_raises(self): + with self.assertRaises(ValueError): + self.campaign.set_broadcast_node_event("") + + def test_set_listened_coordinator_event_none_raises(self): + with self.assertRaises(ValueError): + self.campaign.set_listened_coordinator_event(None) + + def test_set_listened_coordinator_event_empty_raises(self): + with self.assertRaises(ValueError): + self.campaign.set_listened_coordinator_event("") + + def test_set_broadcast_coordinator_event_none_raises(self): + with self.assertRaises(ValueError): + self.campaign.set_broadcast_coordinator_event(None) + + def test_set_broadcast_coordinator_event_empty_raises(self): + with self.assertRaises(ValueError): + self.campaign.set_broadcast_coordinator_event("") + class TestValidateCustomEvents(unittest.TestCase): """Tests for _validate_custom_events and get_custom_* functions.""" @@ -259,6 +324,15 @@ def test_empty_lists_returns_empty(self): ) self.assertEqual(result, []) + def test_duplicate_events_are_deduplicated(self): + result = api_campaign._validate_custom_events( + listened_list=["Evt", "Evt", "Evt"], + broadcast_list=["Evt", "Evt"], + builtin_list=[], + level="test" + ) + self.assertEqual(result, ["Evt"]) + class TestFindBuiltinEvents(unittest.TestCase): """Tests for _find_builtin_events schema search.""" @@ -346,5 +420,21 @@ def test_finds_deeply_nested_reporter(self): self.assertEqual(result, ["NodeEvt1", "NodeEvt2"]) +class TestStrEnum(unittest.TestCase): + """Tests for StrEnum __str__ and __repr__.""" + + def setUp(self): + class Color(StrEnum): + RED = "red" + BLUE = "blue" + self.Color = Color + + def test_str_returns_value(self): + self.assertEqual(str(self.Color.RED), "red") + + def test_repr_returns_value(self): + self.assertEqual(repr(self.Color.RED), "red") + + if __name__ == '__main__': unittest.main() From 55c087506440327794f20688e87feb165713f23b Mon Sep 17 00:00:00 2001 From: Svetlana Titova Date: Thu, 11 Jun 2026 17:35:26 -0700 Subject: [PATCH 4/6] adding node properties, removing malaria-specific demographics elements --- emod_api/campaign.py | 4 +- emod_api/demographics/demographics.py | 8 +- emod_api/demographics/demographics_base.py | 161 +++++++--- emod_api/demographics/implicit_functions.py | 11 +- emod_api/demographics/node.py | 41 --- .../demographics/properties_and_attributes.py | 173 ++++++---- emod_api/utils/emod_enum.py | 13 + ...ala_four_node_demographics_for_Thomas.json | 45 +-- ...imple_distribution_implicit_functions.json | 8 +- .../single_node_demographics.json | 299 +++++++++--------- tests/test_demographics.py | 177 ++++++++++- tests/test_node.py | 26 ++ tests/unittests/test_node_properties.py | 118 +++++++ 13 files changed, 726 insertions(+), 358 deletions(-) create mode 100644 emod_api/utils/emod_enum.py create mode 100644 tests/unittests/test_node_properties.py diff --git a/emod_api/campaign.py b/emod_api/campaign.py index fa9011dd..b8de5055 100644 --- a/emod_api/campaign.py +++ b/emod_api/campaign.py @@ -212,8 +212,8 @@ def _validate_custom_events(listened_list, broadcast_list, builtin_list, level): broadcast_not_listened = broadcast - listened if broadcast_not_listened: warnings.warn( - f"The following {level} events are broadcast but nothing in the campaign " - f"is listening to them: {sorted(broadcast_not_listened)}") + f"The following {level} events are broadcast but nothing is listening to them within " + f"the campaign: {sorted(broadcast_not_listened)}") return list(broadcast) diff --git a/emod_api/demographics/demographics.py b/emod_api/demographics/demographics.py index fa86ef7b..31336295 100644 --- a/emod_api/demographics/demographics.py +++ b/emod_api/demographics/demographics.py @@ -7,7 +7,7 @@ from emod_api.demographics.demographics_base import DemographicsBase from emod_api.demographics.node import Node -from emod_api.demographics.properties_and_attributes import NodeAttributes +from emod_api.demographics.properties_and_attributes import NodeAttributes, NodeProperty, NodeProperties from emod_api.demographics.service import service @@ -93,6 +93,12 @@ def from_file(cls, path: str) -> "Demographics": demographics = cls(nodes=nodes, default_node=default_node, idref=idref, set_defaults=False) demographics.metadata = metadata demographics.implicits.extend(implicit_functions) + + node_properties_list = demographics_dict.get("NodeProperties") + if node_properties_list: + for np_dict in node_properties_list: + demographics.node_properties.add(NodeProperty.from_dict(np_dict)) + return demographics @classmethod diff --git a/emod_api/demographics/demographics_base.py b/emod_api/demographics/demographics_base.py index 6bf579e0..76cd0e81 100644 --- a/emod_api/demographics/demographics_base.py +++ b/emod_api/demographics/demographics_base.py @@ -10,9 +10,10 @@ from emod_api.demographics.mortality_distribution import MortalityDistribution from emod_api.demographics.node import Node from emod_api.demographics.demographic_exceptions import InvalidNodeIdException -from emod_api.demographics.properties_and_attributes import IndividualProperty +from emod_api.demographics.properties_and_attributes import IndividualProperty, NodeProperty, NodeProperties from emod_api.demographics.susceptibility_distribution import SusceptibilityDistribution from emod_api.utils.distributions.base_distribution import BaseDistribution +from emod_api.utils.emod_enum import BirthRateDependence class DemographicsBase(BaseInputFile): @@ -38,6 +39,7 @@ def __init__(self, nodes: list[Node], idref: str = None, default_node: Node = No """ super().__init__(idref=idref) self.nodes = nodes + self.node_properties = NodeProperties() self.implicits = list() self.migration_files = list() @@ -268,32 +270,70 @@ def to_dict(self) -> dict: 'Metadata': self.metadata } demographics_dict["Metadata"]["NodeCount"] = len(self.nodes) + if self.node_properties: + demographics_dict["NodeProperties"] = self.node_properties.to_dict() return demographics_dict - def set_birth_rate(self, rate: float, node_ids: list[int] = None): + def set_birth_rate(self, rate: float, node_ids: list[int] = None, birth_rate_dependence: Union[str, BirthRateDependence] = "POPULATION_DEP_RATE"): """ - Sets a specified population-dependent birth rate value on the target node(s). Automatically handles any - necessary config updates. + Sets the BirthRate on the target node(s) and configures how EMOD interprets it via + Birth_Rate_Dependence. Automatically registers the corresponding config implicit. Args: - rate: (float) The birth rate to set in units of births/year/1000-women - node_ids: (list[int]) The node id(s) to apply changes to. None or 0 means the default node. + rate: The birth rate to set on the target node(s). The units of this value depend on the + birth_rate_dependence setting, see below. + node_ids: Node id(s) to apply rate to. ``None`` or ``0`` targets the default node. Please note that the + birth rate dependence setting will be applied to all nodes, regardless of which node(s) the birth + rate is applied to. + birth_rate_dependence: How EMOD uses the BirthRate value. + Accepts a :class:`~emod_api.demographics.implicit_functions.BirthRateDependence` + member or its string value. Defaults to ``POPULATION_DEP_RATE``. + - ``FIXED_BIRTH_RATE`` — 'rate' is used as an absolute daily birth rate with which new individuals are born. + units: number of births per year + - ``POPULATION_DEP_RATE`` — 'rate' is scaled by node population to determine the daily birth rate. + units: number of births per 1000 people per year + max: 1000 (equivalent to 1 birth per year for every person in the population) + - ``DEMOGRAPHIC_DEP_RATE`` — 'rate' is scaled by number of possible mothers (female population in + fertility age range of 15–44 years). + units: number of births per 8 fertile women per year + max: 8 (equivalent to 1 birth per year for every possible mother in the population) + - ``INDIVIDUAL_PREGNANCIES`` — like DEMOGRAPHIC_DEP_RATE, but pregnancies are + assigned individually with a 40-week gestation period. An individual fertile female person becomes + pregnant based on the birth rate and then gives birth 40 weeks later. This setup is required for + using IsPregnant targeting in campaigns. + units: number of pregnancies per 8 fertile women per year + max: 8 (equivalent to 1 pregnancy per year for every possible mother in the population) + + """ + from emod_api.demographics.implicit_functions import ( _set_birth_rate_dependence) + + if not isinstance(birth_rate_dependence, BirthRateDependence): + try: + birth_rate_dependence = BirthRateDependence(birth_rate_dependence) + except ValueError: + raise ValueError( + f"Invalid birth_rate_dependence {birth_rate_dependence!r}. " + f"Valid options: {[e.value for e in BirthRateDependence]}") + + if birth_rate_dependence == BirthRateDependence.POPULATION_DEP_RATE: + if rate > 1000: + raise ValueError(f"Births per 1000 people per year cannot exceed 1000. Provided rate: {rate}") + rate = rate / 365 / 1000 # converting to per day per 1000 people + elif (birth_rate_dependence == BirthRateDependence.DEMOGRAPHIC_DEP_RATE or + birth_rate_dependence == BirthRateDependence.INDIVIDUAL_PREGNANCIES): + if rate > 8: + raise ValueError(f"Births per 8 fertile women per year cannot exceed 8. Provided rate: {rate}") + rate = rate / 365 / 8 # converting to per day per 8 fertile women - Returns: - - """ - from emod_api.demographics.implicit_functions import _set_population_dependent_birth_rate - - rate = rate / 365 / 1000 # converting to births/day/woman, which is what EMOD internally uses. nodes = self.get_nodes_by_id(node_ids=node_ids) for _, node in nodes.items(): node.birth_rate = rate - self.implicits.append(_set_population_dependent_birth_rate) + self.implicits.append(partial(_set_birth_rate_dependence, + birth_rate_dependence=birth_rate_dependence)) # # These distribution setters accept either a simple or complex distribution # - def set_age_distribution(self, distribution: Union[BaseDistribution, AgeDistribution], node_ids: list[int] = None) -> None: @@ -397,44 +437,6 @@ def set_migration_heterogeneity_distribution(self, simple_distribution_implicits=implicits, node_ids=node_ids) - # TODO: This belongs in emodpy-malaria, as that is the one disease that uses this set of parameters. - # Should be moved into a subclass of emodpy Demographics inside emodpy-malaria during a 2.0 conversion of it. - # https://github.com/EMOD-Hub/emodpy-malaria/issues/126 - # def set_innate_immune_distribution(self, - # distribution: BaseDistribution, - # innate_immune_variation_type: str, - # node_ids: list[int] = None) -> None: - # """ - # Sets a innate immune distribution on the demographics object. Automatically handles any necessary config - # updates. - # - # Args: - # distribution: The distribution to set. Must be a BaseDistribution object for a simple distribution. - # innate_immune_variation_type: the variation type to configure in EMOD. Must be either CYTOKINE_KILLING - # or PYROGENIC_THRESHOLD to be compatible with setting a innate immune distribution. - # node_ids: The node id(s) to apply changes to. None or 0 means the default node. - # - # Returns: - # Nothing - # """ - # from emod_api.demographics.implicit_functions import _set_immune_variation_type_cytokine_killing, \ - # _set_immune_variation_type_pyrogenic_threshold - # - # valid_types = [self.CYTOKINE_KILLING, self.PYROGENIC_THRESHOLD] - # if innate_immune_variation_type == self.CYTOKINE_KILLING: - # implicits = [_set_immune_variation_type_cytokine_killing] - # elif innate_immune_variation_type == self.PYROGENIC_THRESHOLD: - # implicits = [_set_immune_variation_type_pyrogenic_threshold] - # else: - # valid_types_str = ', '.join(valid_types) - # raise ValueError(f'innate_immune_variation_type must be one of: {valid_types_str} ... to allow use of a ' - # f'distribution.') - # - # self._set_distribution(distribution=distribution, - # use_case='innate_immune', - # simple_distribution_implicits=implicits, - # node_ids=node_ids) - # # These distribution setters only accept complex distributions # @@ -565,3 +567,58 @@ def add_individual_property(self, raise ValueError(f"Property key '{property}' already present in IndividualProperties list") node.individual_properties.add(individual_property=individual_property, overwrite=overwrite_existing) + + def add_node_property(self, + property: str, + values: list[str], + initial_distribution: list[float] = None, + overwrite_existing: bool = False) -> None: + """ + Adds a new node property to the demographics object. + + Node properties are top-level in the demographics file and define property labels + on nodes that can be used for identifying and targeting subsets of nodes in campaign + elements and reports. For example, nodes may be given a property ('Place') with + values like 'URBAN' or 'RURAL'. + + Each node is randomly assigned a value from the ``initial_distribution`` at + initialization. To override the drawn value for specific nodes, use + ``set_node_property_values``. + + Args: + property: A node property key to add (e.g. ``'Place'``). + values: A list of valid string values for the property (e.g. ``['URBAN', 'RURAL']``). + initial_distribution: The fractional (0 to 1) initial distribution of each value. + Order must match the values argument. Must sum to 1. + overwrite_existing: When True, overwrites an existing node property with the same + key. If False, raises an exception if the property already exists. + + Returns: + None + """ + node_property = NodeProperty(property=property, + values=values, + initial_distribution=initial_distribution) + self.node_properties.add(node_property=node_property, overwrite=overwrite_existing) + + def set_node_property_values(self, + node_ids: list[int], + values: list[str]) -> None: + """ + Set per-node ``NodePropertyValues`` overrides inside ``NodeAttributes``. + + When a node has ``NodePropertyValues`` set, those values override whatever was + drawn from the ``Initial_Distribution`` of the top-level ``NodeProperties``. + + Args: + node_ids: The node ids to apply the overrides to. Must be specific node ids + (not None/0 default node). + values: A list of ``"Property:Value"`` strings (e.g. + ``["Place:RURAL", "InterventionStatus:SPRAYED_B"]``). + + Returns: + None + """ + nodes = self.get_nodes_by_id(node_ids=node_ids) + for _, node in nodes.items(): + node.node_attributes.node_property_values = values diff --git a/emod_api/demographics/implicit_functions.py b/emod_api/demographics/implicit_functions.py index 15ad72bf..b079f72e 100644 --- a/emod_api/demographics/implicit_functions.py +++ b/emod_api/demographics/implicit_functions.py @@ -100,13 +100,10 @@ def _set_fertility_age_year(config): return config -def _set_population_dependent_birth_rate(config): - config.parameters.Birth_Rate_Dependence = "POPULATION_DEP_RATE" +def _set_birth_rate_dependence(config, birth_rate_dependence): + config.parameters.Birth_Rate_Dependence = str(birth_rate_dependence) return config -# Risk - -def _set_enable_demog_risk(config): - config.parameters.Enable_Demographics_Risk = 1 - return config +def _set_population_dependent_birth_rate(config): + return _set_birth_rate_dependence(config, BirthRateDependence.POPULATION_DEP_RATE) diff --git a/emod_api/demographics/node.py b/emod_api/demographics/node.py index b71f1bd0..a4bdd5b6 100644 --- a/emod_api/demographics/node.py +++ b/emod_api/demographics/node.py @@ -371,47 +371,6 @@ def _set_mortality_male_complex_distribution(self, distribution: MortalityDistri """ self.individual_attributes.mortality_distribution_male = distribution - # malaria only - # TODO: Move to emodpy-malaria? - # https://github.com/InstituteforDiseaseModeling/emodpy-malaria-old/issues/707 - def _set_innate_immune_simple_distribution(self, flag: int, value1: float, value2: float): - """ - Properly sets a simple innate immune distribution. For details on the simple distribution flag and value - meanings, see: - https://docs.idmod.org/projects/emod-generic/en/latest/parameter-demographics.html#simple-distributions - - Args: - flag: simple distribution flag determines the type of simple distribution to use - value1: simple distribution type-dependent parameter number 1 - value2: simple distribution type-dependent parameter number 2 - - Returns: - Nothing - """ - self.individual_attributes.innate_immune_distribution_flag = flag - self.individual_attributes.innate_immune_distribution1 = value1 - self.individual_attributes.innate_immune_distribution2 = value2 - - # malaria only - # TODO: Move to emodpy-malaria? - # https://github.com/InstituteforDiseaseModeling/emodpy-malaria-old/issues/707 - def _set_risk_simple_distribution(self, flag: int, value1: float, value2: float): - """ - Properly sets a simple risk distribution. For details on the simple distribution flag and value meanings, see: - https://docs.idmod.org/projects/emod-generic/en/latest/parameter-demographics.html#simple-distributions - - Args: - flag: simple distribution flag determines the type of simple distribution to use - value1: simple distribution type-dependent parameter number 1 - value2: simple distribution type-dependent parameter number 2 - - Returns: - Nothing - """ - self.individual_attributes.risk_distribution_flag = flag - self.individual_attributes.risk_distribution1 = value1 - self.individual_attributes.risk_distribution2 = value2 - # HIV only def _set_fertility_complex_distribution(self, distribution: FertilityDistribution): """ diff --git a/emod_api/demographics/properties_and_attributes.py b/emod_api/demographics/properties_and_attributes.py index 1e0ef5bd..024acc37 100644 --- a/emod_api/demographics/properties_and_attributes.py +++ b/emod_api/demographics/properties_and_attributes.py @@ -5,7 +5,7 @@ from emod_api.demographics.fertility_distribution import FertilityDistribution from emod_api.demographics.implicit_functions import _set_age_simple, _set_age_complex, _set_suscept_simple, \ _set_suscept_complex, _set_init_prev, _set_migration_model_fixed_rate, _set_enable_migration_model_heterogeneity, \ - _set_enable_natural_mortality, _set_mortality_age_gender_year, _set_mortality_age_gender, _set_enable_demog_risk, \ + _set_enable_natural_mortality, _set_mortality_age_gender_year, _set_mortality_age_gender, \ _set_fertility_age_year from emod_api.demographics.mortality_distribution import MortalityDistribution from emod_api.demographics.susceptibility_distribution import SusceptibilityDistribution @@ -208,6 +208,114 @@ def __len__(self): return len(self.individual_properties) +class NodeProperty(Updateable): + def __init__(self, + property: str, + values: list[str], + initial_distribution: list[float] = None): + """ + A node-level property for EMOD simulations. + + Node properties act as labels on nodes that can be used for identifying and targeting + subpopulations of nodes in campaign elements and reports. For example, nodes may be given + a property ('Place') with values like 'Urban' or 'Rural'. + + Note: EMOD requires node property keys and values (property and values args) to be the + same across all nodes. The initial distributions can vary across nodes. + + Args: + property: The node property key (e.g. 'Place'). + values: A list of valid string values for the property (e.g. ['Urban', 'Rural']). + initial_distribution: The fractional (0 to 1) initial distribution of each value. + Order must match the values argument. Must sum to 1. + """ + super().__init__() + if initial_distribution: + for i in initial_distribution: + if i < 0 or i > 1: + raise ValueError("initial_distribution values must be between 0 and 1.") + if sum(initial_distribution) != 1: + raise ValueError("initial_distribution values must sum to 1.") + if len(initial_distribution) != len(values): + raise ValueError("initial_distribution must have the same number of entries as values.") + + self.property = property + self.values = values + self.initial_distribution = initial_distribution + + def to_dict(self) -> dict: + node_property = self.parameter_dict + node_property.update({"Property": self.property, "Values": self.values}) + if self.initial_distribution: + node_property.update({"Initial_Distribution": self.initial_distribution}) + return node_property + + @classmethod + def from_dict(cls, np_dict: dict) -> '__class__': + return cls(property=np_dict["Property"], + values=np_dict["Values"], + initial_distribution=np_dict.get("Initial_Distribution")) + + def __eq__(self, other) -> bool: + return self.to_dict() == other.to_dict() + + +class NodeProperties(Updateable): + """ + A container class for holding NodeProperty objects used by Node objects. + """ + + class DuplicateNodePropertyException(Exception): + pass + + class NoSuchNodePropertyException(Exception): + pass + + def __init__(self, node_properties: list[NodeProperty] = None): + super().__init__() + self.node_properties = [] if node_properties is None else node_properties + + def add(self, node_property: NodeProperty, overwrite=False) -> None: + has_np = self.has_node_property(property_key=node_property.property) + if has_np: + if overwrite: + self.remove_node_property(property_key=node_property.property) + else: + msg = f"Property {node_property.property} already present in NodeProperties" + raise self.DuplicateNodePropertyException(msg) + self.node_properties.append(node_property) + + def add_parameter(self, key, value): + raise NotImplementedError("A parameter cannot be added to NodeProperties.") + + @property + def np_by_name(self): + return {np.property: np for np in self.node_properties} + + def has_node_property(self, property_key: str) -> bool: + return property_key in self.np_by_name.keys() + + def get_node_property(self, property_key: str) -> NodeProperty: + np = self.np_by_name.get(property_key, None) + if np is None: + msg = f"No NodeProperty exists with the property key: {property_key}" + raise self.NoSuchNodePropertyException(msg) + return np + + def remove_node_property(self, property_key: str): + nps_to_keep = [np for np in self.node_properties if np.property != property_key] + self.node_properties = nps_to_keep + + def to_dict(self) -> list[dict]: + return [np.to_dict() for np in self.node_properties] + + def __getitem__(self, index: int): + return self.node_properties[index] + + def __len__(self): + return len(self.node_properties) + + class IndividualAttributes(Updateable): # TODO: consider refactoring to use objects instead of a big list of potential parameters here: # https://github.com/InstituteforDiseaseModeling/emod-api-old/issues/750 @@ -223,18 +331,12 @@ def __init__(self, prevalence_distribution_flag: int = None, prevalence_distribution1: int = None, prevalence_distribution2: int = None, - risk_distribution_flag: int = None, - risk_distribution1: int = None, - risk_distribution2: int = None, migration_heterogeneity_distribution_flag: int = None, migration_heterogeneity_distribution1: int = None, migration_heterogeneity_distribution2: int = None, fertility_distribution: FertilityDistribution = None, mortality_distribution_male: MortalityDistribution = None, mortality_distribution_female: MortalityDistribution = None, - innate_immune_distribution_flag: int = None, - innate_immune_distribution1: int = None, - innate_immune_distribution2: int = None ): """ Defines the initial distribution of attributes for model agents for all disease setups. These are used by Node @@ -354,16 +456,6 @@ class in emodpy.demographics . self.fertility_distribution = fertility_distribution - # risk and innate_immune are only used by malaria - - self.risk_distribution_flag = risk_distribution_flag - self.risk_distribution1 = risk_distribution1 - self.risk_distribution2 = risk_distribution2 - - self.innate_immune_distribution_flag = innate_immune_distribution_flag - self.innate_immune_distribution1 = innate_immune_distribution1 - self.innate_immune_distribution2 = innate_immune_distribution2 - # New names for by-gender mortality distributions to support emodpy Demographics setting of all distributions # using the same code (see properties here). @@ -455,27 +547,6 @@ def to_dict(self) -> dict: value2_key="MigrationHeterogeneityDistribution2") individual_attributes.update(migration_heterogeneity_distribution_dict) - # malaria only - possible to move this to emodpy-malaria in the future if desired. - if self.risk_distribution_flag is not None: - risk_distribution_dict = { - "RiskDistributionFlag": self.risk_distribution_flag, - "RiskDistribution1": self.risk_distribution1, - "RiskDistribution2": self.risk_distribution2 - } - self._ensure_valid_value2_value(distribution_dict=risk_distribution_dict, value2_key="RiskDistribution2") - individual_attributes.update(risk_distribution_dict) - - # malaria only - possible to move this to emodpy-malaria in the future if desired. - if self.innate_immune_distribution_flag is not None: - innate_immune_distribution_dict = { - "InnateImmuneDistributionFlag": self.innate_immune_distribution_flag, - "InnateImmuneDistribution1": self.innate_immune_distribution1, - "InnateImmuneDistribution2": self.innate_immune_distribution2 - } - self._ensure_valid_value2_value(distribution_dict=innate_immune_distribution_dict, - value2_key="InnateImmuneDistribution2") - individual_attributes.update(innate_immune_distribution_dict) - # The following distributions can only be complex, not simple if self.fertility_distribution is not None: @@ -568,24 +639,6 @@ def from_dict(self, individual_attributes: dict) -> Tuple["IndividualAttributes" self.mortality_distribution = MortalityDistribution.from_dict(distribution_dict=distribution_dict) implicit_functions.extend([_set_enable_natural_mortality, _set_mortality_age_gender]) - # malaria only - possible to move this to emodpy-malaria in the future if desired. - self.innate_immune_distribution_flag = individual_attributes.get("InnateImmuneDistributionFlag", None) - self.innate_immune_distribution1 = individual_attributes.get("InnateImmuneDistribution1", None) - self.innate_immune_distribution2 = individual_attributes.get("InnateImmuneDistribution2", None) - if self.innate_immune_distribution_flag is not None: - import warnings - warnings.warn("InnateImmuneDistribution loaded by file. Pyrogenic vs. cytokine-killing vs NONE (ignore) is " - "unknown. Config may need updating to ensure parameter Innate_Immune_Variation_Type is set " - "properly.", - Warning, stacklevel=2) - - # malaria only - possible to move this to emodpy-malaria in the future if desired. - self.risk_distribution_flag = individual_attributes.get("RiskDistributionFlag", None) - self.risk_distribution1 = individual_attributes.get("RiskDistribution1", None) - self.risk_distribution2 = individual_attributes.get("RiskDistribution2", None) - if self.risk_distribution_flag is not None: - implicit_functions.append(_set_enable_demog_risk) - distribution_dict = individual_attributes.get("FertilityDistribution", None) if distribution_dict is None: self.fertility_distribution = None @@ -611,6 +664,7 @@ def __init__(self, larval_habitat_multiplier: Optional[list[float]] = None, initial_vectors_per_species: Union[dict, int, None] = None, infectivity_multiplier: float = None, + node_property_values: list[str] = None, extra_attributes: dict = None): """ Defines node-specific attributes for all disease setups, utilized by Node objects. @@ -635,6 +689,8 @@ def __init__(self, initial_vectors_per_species ((dict or int), optional): The initial number of vectors per species in the node. infectivity_multiplier (float, optional): TODO: unknown + node_property_values (list[str], optional): Per-node overrides for node property values. + Each entry is a ``"Property:Value"`` string (e.g. ``"Place:RURAL"``). extra_attributes (dict, optional): An arbitrary dict of attribute key/values to add to the node. """ super().__init__() @@ -651,6 +707,7 @@ def __init__(self, self.metadata = metadata self.name = name self.infectivity_multiplier = infectivity_multiplier + self.node_property_values = node_property_values self.extra_attributes = extra_attributes def from_dict(self, node_attributes: dict): @@ -667,6 +724,7 @@ def from_dict(self, node_attributes: dict): self.metadata = node_attributes.get("Metadata") self.name = node_attributes.get("FacilityName") self.infectivity_multiplier = node_attributes.get("InfectivityMultiplier") + self.node_property_values = node_attributes.get("NodePropertyValues") # Legacy keys key_list = ["Airport", "Region", "Seaport"] @@ -720,6 +778,9 @@ def to_dict(self) -> dict: if self.infectivity_multiplier is not None: node_attributes.update({"InfectivityMultiplier": self.infectivity_multiplier}) + if self.node_property_values is not None: + node_attributes.update({"NodePropertyValues": self.node_property_values}) + if self.extra_attributes is not None: node_attributes.update(self.extra_attributes) diff --git a/emod_api/utils/emod_enum.py b/emod_api/utils/emod_enum.py new file mode 100644 index 00000000..f34563c5 --- /dev/null +++ b/emod_api/utils/emod_enum.py @@ -0,0 +1,13 @@ +from emod_api.utils.str_enum import StrEnum + +class BirthRateDependence(StrEnum): + """How BirthRate from the demographics file is interpreted by EMOD. + + Only modes that USE the BirthRate value are included here. + to use ``INDIVIDUAL_PREGNANCIES_BY_AGE_AND_YEAR`` + use FertilityDistribution instead + """ + FIXED_BIRTH_RATE = "FIXED_BIRTH_RATE" + POPULATION_DEP_RATE = "POPULATION_DEP_RATE" + DEMOGRAPHIC_DEP_RATE = "DEMOGRAPHIC_DEP_RATE" + INDIVIDUAL_PREGNANCIES = "INDIVIDUAL_PREGNANCIES" \ No newline at end of file diff --git a/tests/data/demographics/Namawala_four_node_demographics_for_Thomas.json b/tests/data/demographics/Namawala_four_node_demographics_for_Thomas.json index 244ce9e4..76e08f6b 100644 --- a/tests/data/demographics/Namawala_four_node_demographics_for_Thomas.json +++ b/tests/data/demographics/Namawala_four_node_demographics_for_Thomas.json @@ -45,9 +45,6 @@ "PrevalenceDistribution1": 0, "PrevalenceDistribution2": 0, "PrevalenceDistributionFlag": 0, - "RiskDistribution1": 1, - "RiskDistribution2": 0, - "RiskDistributionFlag": 0, "SusceptibilityDistribution1": 1, "SusceptibilityDistribution2": 0, "SusceptibilityDistributionFlag": 0 @@ -78,8 +75,8 @@ "Airport": 0, "Altitude": 0, "BirthRate": 0.0002, - "InitialPopulation": 1000, "FacilityName": "default_node", + "InitialPopulation": 1000, "InitialVectorsPerSpecies": { "arabiensis": 500, "funestus": 500, @@ -97,7 +94,7 @@ "Region": 1, "Seaport": 0 }, - "NodeID": 0 + "NodeID": 0 }, "Metadata": { "Author": "jsteinkraus", @@ -113,9 +110,6 @@ "AgeDistribution1": 0.000118, "AgeDistribution2": 0, "AgeDistributionFlag": 3, - "SusceptibilityDistribution1": 1, - "SusceptibilityDistribution2": 0, - "SusceptibilityDistributionFlag": 0, "MigrationHeterogeneityDistribution1": 1, "MigrationHeterogeneityDistribution2": 0, "MigrationHeterogeneityDistributionFlag": 0, @@ -157,9 +151,9 @@ "PrevalenceDistribution1": 0.1, "PrevalenceDistribution2": 0.2, "PrevalenceDistributionFlag": 1, - "RiskDistribution1": 1, - "RiskDistribution2": 0, - "RiskDistributionFlag": 0 + "SusceptibilityDistribution1": 1, + "SusceptibilityDistribution2": 0, + "SusceptibilityDistributionFlag": 0 }, "IndividualProperties": [ { @@ -208,9 +202,6 @@ "AgeDistribution1": 0.000118, "AgeDistribution2": 0, "AgeDistributionFlag": 3, - "SusceptibilityDistribution1": 1, - "SusceptibilityDistribution2": 0, - "SusceptibilityDistributionFlag": 0, "MigrationHeterogeneityDistribution1": 1, "MigrationHeterogeneityDistribution2": 0, "MigrationHeterogeneityDistributionFlag": 0, @@ -252,9 +243,9 @@ "PrevalenceDistribution1": 0.1, "PrevalenceDistribution2": 0.2, "PrevalenceDistributionFlag": 1, - "RiskDistribution1": 1, - "RiskDistribution2": 0, - "RiskDistributionFlag": 0 + "SusceptibilityDistribution1": 1, + "SusceptibilityDistribution2": 0, + "SusceptibilityDistributionFlag": 0 }, "IndividualProperties": [ { @@ -301,9 +292,6 @@ "AgeDistribution1": 0.003, "AgeDistribution2": 0, "AgeDistributionFlag": 3, - "SusceptibilityDistribution1": 1, - "SusceptibilityDistribution2": 0, - "SusceptibilityDistributionFlag": 0, "MigrationHeterogeneityDistribution1": 1, "MigrationHeterogeneityDistribution2": 0, "MigrationHeterogeneityDistributionFlag": 0, @@ -345,9 +333,9 @@ "PrevalenceDistribution1": 0.1, "PrevalenceDistribution2": 0.2, "PrevalenceDistributionFlag": 1, - "RiskDistribution1": 1, - "RiskDistribution2": 0, - "RiskDistributionFlag": 0 + "SusceptibilityDistribution1": 1, + "SusceptibilityDistribution2": 0, + "SusceptibilityDistributionFlag": 0 }, "NodeAttributes": { "Airport": 0, @@ -372,9 +360,6 @@ "AgeDistribution1": 0.000118, "AgeDistribution2": 0, "AgeDistributionFlag": 3, - "SusceptibilityDistribution1": 1, - "SusceptibilityDistribution2": 0, - "SusceptibilityDistributionFlag": 0, "MigrationHeterogeneityDistribution1": 1, "MigrationHeterogeneityDistribution2": 0, "MigrationHeterogeneityDistributionFlag": 0, @@ -416,9 +401,9 @@ "PrevalenceDistribution1": 0.1, "PrevalenceDistribution2": 0.2, "PrevalenceDistributionFlag": 1, - "RiskDistribution1": 1, - "RiskDistribution2": 0, - "RiskDistributionFlag": 0 + "SusceptibilityDistribution1": 1, + "SusceptibilityDistribution2": 0, + "SusceptibilityDistributionFlag": 0 }, "NodeAttributes": { "Airport": 0, @@ -439,4 +424,4 @@ "NodeID": 340461479 } ] -} \ No newline at end of file +} diff --git a/tests/data/demographics/demographics_test_from_file_sets_necessary_simple_distribution_implicit_functions.json b/tests/data/demographics/demographics_test_from_file_sets_necessary_simple_distribution_implicit_functions.json index 23c9a010..83bd246b 100644 --- a/tests/data/demographics/demographics_test_from_file_sets_necessary_simple_distribution_implicit_functions.json +++ b/tests/data/demographics/demographics_test_from_file_sets_necessary_simple_distribution_implicit_functions.json @@ -4,18 +4,12 @@ "AgeDistribution1": 0.1, "AgeDistribution2": 0, "AgeDistributionFlag": 3, - "InnateImmuneDistribution1": 0.1, - "InnateImmuneDistribution2": 0, - "InnateImmuneDistributionFlag": 3, "MigrationHeterogeneityDistribution1": 0.1, "MigrationHeterogeneityDistribution2": 0, "MigrationHeterogeneityDistributionFlag": 3, "PrevalenceDistribution1": 0.1, "PrevalenceDistribution2": 0, "PrevalenceDistributionFlag": 3, - "RiskDistribution1": 0.1, - "RiskDistribution2": 0, - "RiskDistributionFlag": 3, "SusceptibilityDistribution1": 0.1, "SusceptibilityDistribution2": 0, "SusceptibilityDistributionFlag": 3 @@ -40,4 +34,4 @@ "Tool": "emod-api" }, "Nodes": [] -} \ No newline at end of file +} diff --git a/tests/data/demographics/single_node_demographics.json b/tests/data/demographics/single_node_demographics.json index 1676c256..30edd3a9 100644 --- a/tests/data/demographics/single_node_demographics.json +++ b/tests/data/demographics/single_node_demographics.json @@ -1,13 +1,10 @@ { "Defaults": { "IndividualAttributes": { - "InnateImmuneDistribution1": 1, - "InnateImmuneDistribution2": 0, - "InnateImmuneDistributionFlag": 0, "MigrationHeterogeneityDistribution1": 1, "MigrationHeterogeneityDistribution2": 0, "MigrationHeterogeneityDistributionFlag": 0, - "MortalityDistributionMale": { + "MortalityDistributionFemale": { "AxisNames": [ "age", "year" @@ -29,14 +26,14 @@ "ResultUnits": "annual death rate for an individual", "ResultValues": [ [ - 23.0 + 24.0 ], [ - 23.0 + 24.0 ] ] }, - "MortalityDistributionFemale": { + "MortalityDistributionMale": { "AxisNames": [ "age", "year" @@ -58,33 +55,30 @@ "ResultUnits": "annual death rate for an individual", "ResultValues": [ [ - 24.0 + 23.0 ], [ - 24.0 + 23.0 ] ] }, "PrevalenceDistribution1": 0.13, "PrevalenceDistribution2": 0.15, - "PrevalenceDistributionFlag": 1, - "RiskDistribution1": 1, - "RiskDistribution2": 0, - "RiskDistributionFlag": 0 + "PrevalenceDistributionFlag": 1 }, "NodeAttributes": { "Airport": 0, "Altitude": 0, - "Region": 1, - "Seaport": 0, - "FacilityName": "default_node", "BirthRate": 36.30287175670945, + "FacilityName": "default_node", "Metadata": { - "AbovePoverty": 0.5, - "BirthRateSource": "World Bank", - "World Bank Year": "2016", - "Urban": 0 - } + "AbovePoverty": 0.5, + "BirthRateSource": "World Bank", + "Urban": 0, + "World Bank Year": "2016" + }, + "Region": 1, + "Seaport": 0 }, "NodeID": 0 }, @@ -101,145 +95,116 @@ "IndividualAttributes": { "AgeDistribution": { "DistributionValues": [ - 0.0, - 0.09283128862223852, - 0.1772211025454774, - 0.25393702279895114, - 0.32367683741088754, - 0.3870071913378208, - 0.4446451552016249, - 0.49704189320403297, - 0.5446739889406104, - 0.5879746921806008, - 0.6273377246229421, - 0.6631206791912141, - 0.6956497619307801, - 0.7251893347323798, - 0.7520742768551991, - 0.7765142764013582, - 0.7987315336335542, - 0.8189284949356203, - 0.8372888655321308, - 0.8539796470598181, - 0.8691525528112302, - 0.8829308140781711, - 0.8954708922721066, - 0.9068706500664176, - 0.9172337771432669, - 0.9266544386796152, - 0.9352183517697102, - 0.9430035165849779, - 0.9500807446461047, - 0.9565075587977739, - 0.9623567260094198, - 0.9676739687563408, - 0.9725076946297527, - 0.9769018698125737, - 0.980896462811316, - 0.9845277584117492, - 0.9878288232462242, - 0.9908265109620692, - 0.9935548072877166, - 0.9960350071420613, - 0.9982896317086243, - 1.00000000000000 + 0.0, + 0.09283128862223852, + 0.1772211025454774, + 0.25393702279895114, + 0.32367683741088754, + 0.3870071913378208, + 0.4446451552016249, + 0.49704189320403297, + 0.5446739889406104, + 0.5879746921806008, + 0.6273377246229421, + 0.6631206791912141, + 0.6956497619307801, + 0.7251893347323798, + 0.7520742768551991, + 0.7765142764013582, + 0.7987315336335542, + 0.8189284949356203, + 0.8372888655321308, + 0.8539796470598181, + 0.8691525528112302, + 0.8829308140781711, + 0.8954708922721066, + 0.9068706500664176, + 0.9172337771432669, + 0.9266544386796152, + 0.9352183517697102, + 0.9430035165849779, + 0.9500807446461047, + 0.9565075587977739, + 0.9623567260094198, + 0.9676739687563408, + 0.9725076946297527, + 0.9769018698125737, + 0.980896462811316, + 0.9845277584117492, + 0.9878288232462242, + 0.9908265109620692, + 0.9935548072877166, + 0.9960350071420613, + 0.9982896317086243, + 1.0 ], "ResultScaleFactor": 365.0, "ResultValues": [ - 0.0, - 2.4493150684931506, - 4.898630136986301, - 7.347945205479452, - 9.797260273972602, - 12.243835616438357, - 14.693150684931506, - 17.14246575342466, - 19.59178082191781, - 22.041095890410958, - 24.49041095890411, - 26.93972602739726, - 29.389041095890413, - 31.835616438356166, - 34.28493150684932, - 36.73424657534247, - 39.18356164383562, - 41.632876712328766, - 44.082191780821915, - 46.53150684931507, - 48.98082191780822, - 51.42739726027397, - 53.87671232876713, - 56.326027397260276, - 58.775342465753425, - 61.224657534246575, - 63.673972602739724, - 66.12328767123287, - 68.57260273972602, - 71.01917808219179, - 73.46849315068494, - 75.91780821917808, - 78.36712328767123, - 80.81643835616438, - 83.26575342465753, - 85.71506849315068, - 88.16438356164383, - 90.6109589041096, - 93.06027397260274, - 95.5095890410959, - 97.95890410958904, - 100.40821917808219 + 0.0, + 2.4493150684931506, + 4.898630136986301, + 7.347945205479452, + 9.797260273972602, + 12.243835616438357, + 14.693150684931506, + 17.14246575342466, + 19.59178082191781, + 22.041095890410958, + 24.49041095890411, + 26.93972602739726, + 29.389041095890413, + 31.835616438356166, + 34.28493150684932, + 36.73424657534247, + 39.18356164383562, + 41.632876712328766, + 44.082191780821915, + 46.53150684931507, + 48.98082191780822, + 51.42739726027397, + 53.87671232876713, + 56.326027397260276, + 58.775342465753425, + 61.224657534246575, + 63.673972602739724, + 66.12328767123287, + 68.57260273972602, + 71.01917808219179, + 73.46849315068494, + 75.91780821917808, + 78.36712328767123, + 80.81643835616438, + 83.26575342465753, + 85.71506849315068, + 88.16438356164383, + 90.6109589041096, + 93.06027397260274, + 95.5095890410959, + 97.95890410958904, + 100.40821917808219 ] }, - "MortalityDistributionMale": { - "AxisNames": [ - "age", - "year" - ], - "AxisScaleFactors": [ - 365.0, - 1 - ], - "PopulationGroups": [ - [ - 0, - 1 - ], - [ - 2020 - ] - ], - "ResultScaleFactor": 0.0027397260273972603, - "ResultUnits": "annual death rate for an individual", - "ResultValues": [ + "MortalityDistributionFemale": { + "AxisNames": [ + "age", + "year" + ], + "AxisScaleFactors": [ + 365.0, + 1 + ], + "PopulationGroups": [ [ - 38.92 + 0, + 1 ], [ - 38.92 + 2020 ] - ] - }, - "MortalityDistributionFemale": { - "AxisNames": [ - "age", - "year" - ], - "AxisScaleFactors": [ - 365.0, - 1 - ], - "PopulationGroups": [ - [ - 0, - 1 - ], - [ - 2020 - ] - ], - "ResultScaleFactor": 0.0027397260273972603, - "ResultUnits": "annual death rate for an individual", - "ResultValues": [ + ], + "ResultScaleFactor": 0.0027397260273972603, + "ResultUnits": "annual death rate for an individual", + "ResultValues": [ [ 39.92 ], @@ -247,8 +212,36 @@ 39.92 ] ] + }, + "MortalityDistributionMale": { + "AxisNames": [ + "age", + "year" + ], + "AxisScaleFactors": [ + 365.0, + 1 + ], + "PopulationGroups": [ + [ + 0, + 1 + ], + [ + 2020 + ] + ], + "ResultScaleFactor": 0.0027397260273972603, + "ResultUnits": "annual death rate for an individual", + "ResultValues": [ + [ + 38.92 + ], + [ + 38.92 + ] + ] } - }, "NodeAttributes": { "BirthRate": 0.10663013698630137, @@ -266,4 +259,4 @@ "NodeID": 1 } ] -} \ No newline at end of file +} diff --git a/tests/test_demographics.py b/tests/test_demographics.py index 96a22828..f9d7e7ef 100644 --- a/tests/test_demographics.py +++ b/tests/test_demographics.py @@ -18,7 +18,8 @@ from emod_api.demographics.node import Node from emod_api.demographics.overlay_node import OverlayNode from emod_api.demographics.properties_and_attributes import (IndividualAttributes, IndividualProperty, - IndividualProperties, NodeAttributes) + IndividualProperties, NodeAttributes, + NodeProperty, NodeProperties) from emod_api.demographics.susceptibility_distribution import SusceptibilityDistribution from emod_api.utils.distributions.exponential_distribution import ExponentialDistribution from emod_api.utils.distributions.gaussian_distribution import GaussianDistribution @@ -273,7 +274,6 @@ def test_demo_node(self): def test_from_file_sets_necessary_simple_distribution_implicit_functions(self): from emod_api.demographics.implicit_functions import _set_age_simple - from emod_api.demographics.implicit_functions import _set_enable_demog_risk from emod_api.demographics.implicit_functions import _set_enable_migration_model_heterogeneity from emod_api.demographics.implicit_functions import _set_init_prev from emod_api.demographics.implicit_functions import _set_migration_model_fixed_rate @@ -281,8 +281,9 @@ def test_from_file_sets_necessary_simple_distribution_implicit_functions(self): input_filepath = Path(manifest.demo_folder, "demographics_test_from_file_sets_necessary_simple_distribution_implicit_functions.json") + # Note: _set_enable_demog_risk was removed — risk distribution is now handled by emodpy-malaria's MalariaNode expected_implicits = [_set_age_simple, _set_suscept_simple, _set_init_prev, _set_migration_model_fixed_rate, - _set_enable_migration_model_heterogeneity, _set_enable_demog_risk] + _set_enable_migration_model_heterogeneity] # # Use this if needing to regenerate the input_filepath # default_node = Node(lat=0, lon=0, pop=1000, forced_id=0) @@ -628,6 +629,7 @@ def test_all_members_to_dict(self): larval_habitat_multiplier=[{"Larval": 123}], initial_vectors_per_species=123, infectivity_multiplier=0.5, + node_property_values=["Place:RURAL"], extra_attributes={"Test_Parameter_1": 123}) node_attributes.add_parameter("Test_Parameter_2", 123) @@ -1199,20 +1201,82 @@ def setUp(self): self.distribution = ExponentialDistribution(mean=0.0001) def test_set_birth_rate(self): - # ok, this isn't a simple distribution, but it needs to be tested and the code is adjacent - # to the following simple distribution tests in demographics.py - from emod_api.demographics.implicit_functions import _set_population_dependent_birth_rate - rate = 50 # births/year/1000 women self.demographics.set_birth_rate(rate=rate) - expected = 50 / 365 / 1000 # birth rate is auto-converted to what EMOD uses: births/day/woman + expected = 50 / 365 / 1000 expected_delta = expected * 1e-6 self.assertAlmostEqual(self.demographics.default_node.birth_rate, expected, delta=expected_delta) self.assertAlmostEqual(self.demographics.default_node.node_attributes.birth_rate, expected, delta=expected_delta) self.assertEqual(len(self.demographics.implicits), 1) - self.assertIn(_set_population_dependent_birth_rate, self.demographics.implicits) + # default should set POPULATION_DEP_RATE + class MockConfig: + class parameters: + Birth_Rate_Dependence = None + self.demographics.implicits[0](MockConfig()) + self.assertEqual(MockConfig.parameters.Birth_Rate_Dependence, "POPULATION_DEP_RATE") + + def test_set_birth_rate_individual_pregnancies(self): + from emod_api.utils.emod_enum import BirthRateDependence + + self.demographics.set_birth_rate( + rate=4, + birth_rate_dependence=BirthRateDependence.INDIVIDUAL_PREGNANCIES, + ) + + expected = 4 / 365 / 8 # INDIVIDUAL_PREGNANCIES converts per 8 fertile women per year + expected_delta = expected * 1e-6 + self.assertAlmostEqual(self.demographics.default_node.birth_rate, expected, delta=expected_delta) + + self.assertEqual(len(self.demographics.implicits), 1) + class MockConfig: + class parameters: + Birth_Rate_Dependence = None + self.demographics.implicits[0](MockConfig()) + self.assertEqual(MockConfig.parameters.Birth_Rate_Dependence, "INDIVIDUAL_PREGNANCIES") + + def test_set_birth_rate_demographic_dep(self): + self.demographics.set_birth_rate(rate=2, birth_rate_dependence="DEMOGRAPHIC_DEP_RATE") + + expected = 2 / 365 / 8 + expected_delta = expected * 1e-6 + self.assertAlmostEqual(self.demographics.default_node.birth_rate, expected, delta=expected_delta) + + self.assertEqual(len(self.demographics.implicits), 1) + class MockConfig: + class parameters: + Birth_Rate_Dependence = None + self.demographics.implicits[0](MockConfig()) + self.assertEqual(MockConfig.parameters.Birth_Rate_Dependence, "DEMOGRAPHIC_DEP_RATE") + + def test_set_birth_rate_fixed(self): + self.demographics.set_birth_rate(rate=12, birth_rate_dependence="FIXED_BIRTH_RATE") + + self.assertEqual(self.demographics.default_node.birth_rate, 12) + + self.assertEqual(len(self.demographics.implicits), 1) + class MockConfig: + class parameters: + Birth_Rate_Dependence = None + self.demographics.implicits[0](MockConfig()) + self.assertEqual(MockConfig.parameters.Birth_Rate_Dependence, "FIXED_BIRTH_RATE") + + def test_set_birth_rate_population_dep_max_exceeded(self): + with self.assertRaises(ValueError): + self.demographics.set_birth_rate(rate=1001) + + def test_set_birth_rate_demographic_dep_max_exceeded(self): + with self.assertRaises(ValueError): + self.demographics.set_birth_rate(rate=9, birth_rate_dependence="DEMOGRAPHIC_DEP_RATE") + + def test_set_birth_rate_individual_pregnancies_max_exceeded(self): + with self.assertRaises(ValueError): + self.demographics.set_birth_rate(rate=9, birth_rate_dependence="INDIVIDUAL_PREGNANCIES") + + def test_set_birth_rate_invalid_dependence(self): + with self.assertRaises(ValueError): + self.demographics.set_birth_rate(rate=12, birth_rate_dependence="INVALID") def test_set_age_distribution_simple(self): @@ -1372,3 +1436,98 @@ def test_simple_and_complex_susceptibility_distribution_specification_throws_an_ self.assertRaises(demog_ex.ConflictingDistributionsException, self.demographics.to_dict) + + +class NodePropertiesDemographicsTest(unittest.TestCase): + + def setUp(self): + nodes = [Node(lat=0, lon=0, pop=1000, forced_id=1), + Node(lat=1, lon=1, pop=2000, forced_id=2)] + self.demographics = Demographics(nodes=nodes) + + def test_add_node_property(self): + self.demographics.add_node_property( + property="Place", values=["RURAL", "URBAN"], initial_distribution=[0.6, 0.4]) + d = self.demographics.to_dict() + self.assertIn("NodeProperties", d) + self.assertEqual(len(d["NodeProperties"]), 1) + self.assertEqual(d["NodeProperties"][0]["Property"], "Place") + self.assertEqual(d["NodeProperties"][0]["Values"], ["RURAL", "URBAN"]) + self.assertEqual(d["NodeProperties"][0]["Initial_Distribution"], [0.6, 0.4]) + + def test_add_multiple_node_properties(self): + self.demographics.add_node_property( + property="Place", values=["RURAL", "URBAN"], initial_distribution=[0.6, 0.4]) + self.demographics.add_node_property( + property="Risk", values=["HIGH", "MED", "LOW"], initial_distribution=[0.1, 0.2, 0.7]) + d = self.demographics.to_dict() + self.assertEqual(len(d["NodeProperties"]), 2) + props = {np["Property"] for np in d["NodeProperties"]} + self.assertEqual(props, {"Place", "Risk"}) + + def test_add_duplicate_node_property_raises(self): + self.demographics.add_node_property( + property="Place", values=["RURAL", "URBAN"], initial_distribution=[0.6, 0.4]) + with self.assertRaises(NodeProperties.DuplicateNodePropertyException): + self.demographics.add_node_property( + property="Place", values=["A", "B"], initial_distribution=[0.5, 0.5]) + + def test_add_node_property_overwrite(self): + self.demographics.add_node_property( + property="Place", values=["RURAL", "URBAN"], initial_distribution=[0.6, 0.4]) + self.demographics.add_node_property( + property="Place", values=["CITY", "VILLAGE"], initial_distribution=[0.3, 0.7], + overwrite_existing=True) + d = self.demographics.to_dict() + self.assertEqual(len(d["NodeProperties"]), 1) + self.assertEqual(d["NodeProperties"][0]["Values"], ["CITY", "VILLAGE"]) + + def test_no_node_properties_omitted_from_dict(self): + d = self.demographics.to_dict() + self.assertNotIn("NodeProperties", d) + + def test_set_node_property_values(self): + self.demographics.add_node_property( + property="Place", values=["RURAL", "URBAN"], initial_distribution=[0.6, 0.4]) + self.demographics.set_node_property_values( + node_ids=[2], values=["Place:RURAL"]) + d = self.demographics.to_dict() + node2 = [n for n in d["Nodes"] if n["NodeID"] == 2][0] + self.assertEqual(node2["NodeAttributes"]["NodePropertyValues"], ["Place:RURAL"]) + node1 = [n for n in d["Nodes"] if n["NodeID"] == 1][0] + self.assertNotIn("NodePropertyValues", node1["NodeAttributes"]) + + def test_set_node_property_values_multiple(self): + self.demographics.add_node_property( + property="Place", values=["RURAL", "URBAN"], initial_distribution=[0.6, 0.4]) + self.demographics.add_node_property( + property="InterventionStatus", values=["NONE", "SPRAYED"], + initial_distribution=[0.5, 0.5]) + self.demographics.set_node_property_values( + node_ids=[1], values=["Place:RURAL", "InterventionStatus:SPRAYED"]) + d = self.demographics.to_dict() + node1 = [n for n in d["Nodes"] if n["NodeID"] == 1][0] + self.assertEqual(node1["NodeAttributes"]["NodePropertyValues"], + ["Place:RURAL", "InterventionStatus:SPRAYED"]) + + def test_roundtrip_with_reference_file(self): + ref_path = os.path.join( + "C:\\", "github", "emod1", "regression", "generic", + "57_Generic_NodeProperties", "demographics_node_properties.json") + if not os.path.isfile(ref_path): + self.skipTest("Reference demographics file not found") + import warnings + try: + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + demog = Demographics.from_file(ref_path) + except Exception as e: + self.skipTest(f"Cannot load reference file (likely schema mismatch): {e}") + d = demog.to_dict() + self.assertIn("NodeProperties", d) + self.assertEqual(len(d["NodeProperties"]), 3) + props = {np["Property"] for np in d["NodeProperties"]} + self.assertEqual(props, {"Place", "Risk", "InterventionStatus"}) + node100 = [n for n in d["Nodes"] if n["NodeID"] == 100][0] + self.assertEqual(node100["NodeAttributes"]["NodePropertyValues"], + ["Place:RURAL", "InterventionStatus:SPRAYED_B"]) diff --git a/tests/test_node.py b/tests/test_node.py index 6799205e..e488d2e3 100644 --- a/tests/test_node.py +++ b/tests/test_node.py @@ -151,6 +151,32 @@ def test_node_property_birth_rate(self): node.birth_rate = 0.5 self.assertEqual(node.birth_rate, 0.5) + def test_node_property_values_in_to_dict(self): + na = NodeAttributes(node_property_values=["Place:RURAL", "Risk:HIGH"]) + node = Node(lat=0, lon=0, pop=100, node_attributes=na) + d = node.to_dict() + self.assertEqual(d["NodeAttributes"]["NodePropertyValues"], ["Place:RURAL", "Risk:HIGH"]) + + def test_node_property_values_not_in_dict_when_none(self): + node = Node(lat=0, lon=0, pop=100) + d = node.to_dict() + self.assertNotIn("NodePropertyValues", d["NodeAttributes"]) + + def test_node_property_values_set_directly(self): + node = Node(lat=0, lon=0, pop=100) + node.node_attributes.node_property_values = ["Place:URBAN"] + d = node.to_dict() + self.assertEqual(d["NodeAttributes"]["NodePropertyValues"], ["Place:URBAN"]) + + def test_node_property_values_from_data_roundtrip(self): + na = NodeAttributes(node_property_values=["Place:RURAL", "InterventionStatus:SPRAYED_B"]) + node = Node(lat=1, lon=2, pop=100, forced_id=1, node_attributes=na) + d = node.to_dict() + restored_node, _ = Node.from_data(d) + restored_d = restored_node.to_dict() + self.assertEqual(d["NodeAttributes"]["NodePropertyValues"], + restored_d["NodeAttributes"]["NodePropertyValues"]) + if __name__ == '__main__': unittest.main() diff --git a/tests/unittests/test_node_properties.py b/tests/unittests/test_node_properties.py new file mode 100644 index 00000000..2a6d53ff --- /dev/null +++ b/tests/unittests/test_node_properties.py @@ -0,0 +1,118 @@ +import unittest + +from emod_api.demographics.properties_and_attributes import NodeProperties, NodeProperty + + +class NodePropertyTest(unittest.TestCase): + + def test_to_dict(self): + np = NodeProperty(property='Place', values=['Urban', 'Rural'], initial_distribution=[0.3, 0.7]) + d = np.to_dict() + self.assertEqual(d["Property"], "Place") + self.assertEqual(d["Values"], ["Urban", "Rural"]) + self.assertEqual(d["Initial_Distribution"], [0.3, 0.7]) + + def test_to_dict_without_distribution(self): + np = NodeProperty(property='Place', values=['Urban', 'Rural']) + d = np.to_dict() + self.assertEqual(d["Property"], "Place") + self.assertEqual(d["Values"], ["Urban", "Rural"]) + self.assertNotIn("Initial_Distribution", d) + + def test_from_dict(self): + original = NodeProperty(property='Place', values=['Urban', 'Rural'], initial_distribution=[0.4, 0.6]) + restored = NodeProperty.from_dict(original.to_dict()) + self.assertEqual(original, restored) + + def test_from_dict_without_distribution(self): + original = NodeProperty(property='Place', values=['Urban', 'Rural']) + restored = NodeProperty.from_dict(original.to_dict()) + self.assertEqual(original, restored) + + def test_equality(self): + np1 = NodeProperty(property='Place', values=['Urban', 'Rural'], initial_distribution=[0.5, 0.5]) + np2 = NodeProperty(property='Place', values=['Urban', 'Rural'], initial_distribution=[0.5, 0.5]) + self.assertEqual(np1, np2) + + def test_inequality(self): + np1 = NodeProperty(property='Place', values=['Urban', 'Rural'], initial_distribution=[0.5, 0.5]) + np2 = NodeProperty(property='Place', values=['Urban', 'Rural'], initial_distribution=[0.3, 0.7]) + self.assertNotEqual(np1, np2) + + def test_invalid_distribution_out_of_range(self): + with self.assertRaises(ValueError): + NodeProperty(property='Place', values=['Urban', 'Rural'], initial_distribution=[1.1, -0.1]) + + def test_invalid_distribution_wrong_sum(self): + with self.assertRaises(ValueError): + NodeProperty(property='Place', values=['Urban', 'Rural'], initial_distribution=[0.3, 0.3]) + + def test_invalid_distribution_wrong_length(self): + with self.assertRaises(ValueError): + NodeProperty(property='Place', values=['Urban', 'Rural'], initial_distribution=[0.3, 0.3, 0.4]) + + +class NodePropertiesTest(unittest.TestCase): + + def setUp(self): + self.np = NodeProperty(property='Place', values=['Urban', 'Rural'], initial_distribution=[0.4, 0.6]) + self.nps = NodeProperties() + self.nps.node_properties = [self.np] + + self.new_np = NodeProperty(property='Risk', values=['High', 'Low'], initial_distribution=[0.2, 0.8]) + self.new_np_v2 = NodeProperty(property='Risk', values=['High', 'Low'], initial_distribution=[0.5, 0.5]) + + def test_has_node_property(self): + self.assertTrue(self.nps.has_node_property(property_key='Place')) + self.assertFalse(self.nps.has_node_property(property_key='Risk')) + + def test_get_node_property(self): + np = self.nps.get_node_property(property_key='Place') + self.assertEqual(np, self.np) + + def test_get_nonexistent_raises(self): + with self.assertRaises(NodeProperties.NoSuchNodePropertyException): + self.nps.get_node_property(property_key='Risk') + + def test_remove_node_property(self): + self.assertEqual(len(self.nps), 1) + self.nps.remove_node_property(property_key='Place') + self.assertEqual(len(self.nps), 0) + + def test_remove_nonexistent_is_no_op(self): + self.assertEqual(len(self.nps), 1) + self.nps.remove_node_property(property_key='Risk') + self.assertEqual(len(self.nps), 1) + + def test_add_new(self): + self.nps.add(node_property=self.new_np) + self.assertEqual(len(self.nps), 2) + self.assertEqual(self.nps.get_node_property('Risk'), self.new_np) + + def test_add_duplicate_raises(self): + with self.assertRaises(NodeProperties.DuplicateNodePropertyException): + self.nps.add(node_property=self.np, overwrite=False) + + def test_add_with_overwrite(self): + self.nps.add(node_property=self.new_np) + self.assertEqual(len(self.nps), 2) + self.nps.add(node_property=self.new_np_v2, overwrite=True) + self.assertEqual(len(self.nps), 2) + self.assertEqual(self.nps.get_node_property('Risk'), self.new_np_v2) + + def test_add_parameter_raises(self): + with self.assertRaises(NotImplementedError): + self.nps.add_parameter("key", "value") + + def test_to_dict(self): + result = self.nps.to_dict() + self.assertEqual(len(result), 1) + self.assertEqual(result[0]["Property"], "Place") + + def test_len(self): + self.assertEqual(len(self.nps), 1) + self.nps.add(node_property=self.new_np) + self.assertEqual(len(self.nps), 2) + + def test_getitem(self): + self.assertEqual(self.nps[0], self.np) From 7df2f6bdaa2252e9f184a78e9687c040e721f6a3 Mon Sep 17 00:00:00 2001 From: Svetlana Titova Date: Thu, 11 Jun 2026 18:25:22 -0700 Subject: [PATCH 5/6] updating reference file --- .../demographics_node_properties.json | 875 ++++++++++++++++++ tests/test_demographics.py | 6 +- 2 files changed, 877 insertions(+), 4 deletions(-) create mode 100644 tests/data/demographics/demographics_node_properties.json diff --git a/tests/data/demographics/demographics_node_properties.json b/tests/data/demographics/demographics_node_properties.json new file mode 100644 index 00000000..573196e3 --- /dev/null +++ b/tests/data/demographics/demographics_node_properties.json @@ -0,0 +1,875 @@ +{ + "Metadata": { + "DateCreated": "Fri Dec 9, 2016", + "Author": "dbridenbecker", + "IdReference": "0", + "NodeCount": 100 + }, + "Defaults" : { + "NodeAttributes": { + "Altitude": 0, + "Airport": 0, + "Region": 1, + "Seaport": 0, + "BirthRate": 0.0001 + }, + "IndividualAttributes": { + "AgeDistributionFlag": 2, + "AgeDistribution1": 10950, + "AgeDistribution2": 7300, + "PrevalenceDistributionFlag": 0, + "PrevalenceDistribution1": 0.05, + "PrevalenceDistribution2": 0.0, + "SusceptibilityDistributionFlag": 0, + "SusceptibilityDistribution1": 1, + "SusceptibilityDistribution2": 0, + "RiskDistributionFlag": 0, + "RiskDistribution1": 1, + "RiskDistribution2": 0, + "MigrationHeterogeneityDistributionFlag": 0, + "MigrationHeterogeneityDistribution1": 1, + "MigrationHeterogeneityDistribution2": 0, + "MortalityDistribution": { + "NumDistributionAxes": 2, + "AxisNames": [ "gender", "age" ], + "AxisUnits": [ "male=0,female=1", "years" ], + "AxisScaleFactors": [ 1, 365 ], + "NumPopulationGroups": [ 2, 3 ], + "PopulationGroups": [ + [ 0, 1 ], + [ 0, 100, 2000 ] + ], + "ResultUnits": "annual deaths per 1000 individuals", + "ResultScaleFactor": 0.00000273972602739726027397260273973, + "ResultValues": [ + [ 0, 20.0000035, 400.00007 ], + [ 0, 20.0000035, 400.00007 ] + ] + } + } + }, + "NodeProperties": [ + { + "Property": "Place", + "Values": [ "RURAL", "URBAN"], + "Initial_Distribution": [ 0.6, 0.4] + }, + { + "Property": "Risk", + "Values": [ "HIGH", "MED", "LOW"], + "Initial_Distribution": [ 0.1, 0.2, 0.7 ] + }, + { + "Property": "InterventionStatus", + "Values": [ "NONE", "SPRAYED_A", "SPRAYED_B", "FENCE_AND_TRAP" ], + "Initial_Distribution": [ 0.4, 0.2, 0.3, 0.1 ] + } + ], + "Nodes": [ + { + "NodeID": 1, + "NodeAttributes": { + "Latitude": 47.7601, + "Longitude": -122.2038, + "InitialPopulation": 10 + } + }, + { + "NodeID": 2, + "NodeAttributes": { + "Latitude": 47.7602, + "Longitude": -122.2038, + "InitialPopulation": 20 + } + }, + { + "NodeID": 3, + "NodeAttributes": { + "Latitude": 47.7603, + "Longitude": -122.2038, + "InitialPopulation": 30 + } + }, + { + "NodeID": 4, + "NodeAttributes": { + "Latitude": 47.7604, + "Longitude": -122.2038, + "InitialPopulation": 40 + } + }, + { + "NodeID": 5, + "NodeAttributes": { + "Latitude": 47.7605, + "Longitude": -122.2038, + "InitialPopulation": 50 + } + }, + { + "NodeID": 6, + "NodeAttributes": { + "Latitude": 47.7606, + "Longitude": -122.2038, + "InitialPopulation": 60 + } + }, + { + "NodeID": 7, + "NodeAttributes": { + "Latitude": 47.7607, + "Longitude": -122.2038, + "InitialPopulation": 70 + } + }, + { + "NodeID": 8, + "NodeAttributes": { + "Latitude": 47.7608, + "Longitude": -122.2038, + "InitialPopulation": 80 + } + }, + { + "NodeID": 9, + "NodeAttributes": { + "Latitude": 47.7609, + "Longitude": -122.2038, + "InitialPopulation": 90 + } + }, + { + "NodeID": 10, + "NodeAttributes": { + "Latitude": 47.7610, + "Longitude": -122.2038, + "InitialPopulation": 100 + } + }, + { + "NodeID": 11, + "NodeAttributes": { + "Latitude": 47.7611, + "Longitude": -122.2038, + "InitialPopulation": 110 + } + }, + { + "NodeID": 12, + "NodeAttributes": { + "Latitude": 47.7612, + "Longitude": -122.2038, + "InitialPopulation": 120 + } + }, + { + "NodeID": 13, + "NodeAttributes": { + "Latitude": 47.7613, + "Longitude": -122.2038, + "InitialPopulation": 130 + } + }, + { + "NodeID": 14, + "NodeAttributes": { + "Latitude": 47.7614, + "Longitude": -122.2038, + "InitialPopulation": 140 + } + }, + { + "NodeID": 15, + "NodeAttributes": { + "Latitude": 47.7615, + "Longitude": -122.2038, + "InitialPopulation": 150 + } + }, + { + "NodeID": 16, + "NodeAttributes": { + "Latitude": 47.7616, + "Longitude": -122.2038, + "InitialPopulation": 160 + } + }, + { + "NodeID": 17, + "NodeAttributes": { + "Latitude": 47.7617, + "Longitude": -122.2038, + "InitialPopulation": 170 + } + }, + { + "NodeID": 18, + "NodeAttributes": { + "Latitude": 47.7618, + "Longitude": -122.2038, + "InitialPopulation": 180 + } + }, + { + "NodeID": 19, + "NodeAttributes": { + "Latitude": 47.7619, + "Longitude": -122.2038, + "InitialPopulation": 190 + } + }, + { + "NodeID": 20, + "NodeAttributes": { + "Latitude": 47.7620, + "Longitude": -122.2038, + "InitialPopulation": 200 + } + }, + { + "NodeID": 21, + "NodeAttributes": { + "Latitude": 47.7621, + "Longitude": -122.2038, + "InitialPopulation": 210 + } + }, + { + "NodeID": 22, + "NodeAttributes": { + "Latitude": 47.7622, + "Longitude": -122.2038, + "InitialPopulation": 220 + } + }, + { + "NodeID": 23, + "NodeAttributes": { + "Latitude": 47.7623, + "Longitude": -122.2038, + "InitialPopulation": 230 + } + }, + { + "NodeID": 24, + "NodeAttributes": { + "Latitude": 47.7624, + "Longitude": -122.2038, + "InitialPopulation": 240 + } + }, + { + "NodeID": 25, + "NodeAttributes": { + "Latitude": 47.7625, + "Longitude": -122.2038, + "InitialPopulation": 250 + } + }, + { + "NodeID": 26, + "NodeAttributes": { + "Latitude": 47.7626, + "Longitude": -122.2038, + "InitialPopulation": 260 + } + }, + { + "NodeID": 27, + "NodeAttributes": { + "Latitude": 47.7627, + "Longitude": -122.2038, + "InitialPopulation": 270 + } + }, + { + "NodeID": 28, + "NodeAttributes": { + "Latitude": 47.7628, + "Longitude": -122.2038, + "InitialPopulation": 280 + } + }, + { + "NodeID": 29, + "NodeAttributes": { + "Latitude": 47.7629, + "Longitude": -122.2038, + "InitialPopulation": 290 + } + }, + { + "NodeID": 30, + "NodeAttributes": { + "Latitude": 47.7630, + "Longitude": -122.2038, + "InitialPopulation": 300 + } + }, + { + "NodeID": 31, + "NodeAttributes": { + "Latitude": 47.7631, + "Longitude": -122.2038, + "InitialPopulation": 310 + } + }, + { + "NodeID": 32, + "NodeAttributes": { + "Latitude": 47.7632, + "Longitude": -122.2038, + "InitialPopulation": 320 + } + }, + { + "NodeID": 33, + "NodeAttributes": { + "Latitude": 47.7633, + "Longitude": -122.2038, + "InitialPopulation": 330 + } + }, + { + "NodeID": 34, + "NodeAttributes": { + "Latitude": 47.7634, + "Longitude": -122.2038, + "InitialPopulation": 340 + } + }, + { + "NodeID": 35, + "NodeAttributes": { + "Latitude": 47.7635, + "Longitude": -122.2038, + "InitialPopulation": 350 + } + }, + { + "NodeID": 36, + "NodeAttributes": { + "Latitude": 47.7636, + "Longitude": -122.2038, + "InitialPopulation": 360 + } + }, + { + "NodeID": 37, + "NodeAttributes": { + "Latitude": 47.7637, + "Longitude": -122.2038, + "InitialPopulation": 370 + } + }, + { + "NodeID": 38, + "NodeAttributes": { + "Latitude": 47.7638, + "Longitude": -122.2038, + "InitialPopulation": 380 + } + }, + { + "NodeID": 39, + "NodeAttributes": { + "Latitude": 47.7639, + "Longitude": -122.2038, + "InitialPopulation": 390 + } + }, + { + "NodeID": 40, + "NodeAttributes": { + "Latitude": 47.7640, + "Longitude": -122.2038, + "InitialPopulation": 400 + } + }, + { + "NodeID": 41, + "NodeAttributes": { + "Latitude": 47.7641, + "Longitude": -122.2038, + "InitialPopulation": 410 + } + }, + { + "NodeID": 42, + "NodeAttributes": { + "Latitude": 47.7642, + "Longitude": -122.2038, + "InitialPopulation": 420 + } + }, + { + "NodeID": 43, + "NodeAttributes": { + "Latitude": 47.7643, + "Longitude": -122.2038, + "InitialPopulation": 430 + } + }, + { + "NodeID": 44, + "NodeAttributes": { + "Latitude": 47.7644, + "Longitude": -122.2038, + "InitialPopulation": 440 + } + }, + { + "NodeID": 45, + "NodeAttributes": { + "Latitude": 47.7645, + "Longitude": -122.2038, + "InitialPopulation": 450 + } + }, + { + "NodeID": 46, + "NodeAttributes": { + "Latitude": 47.7646, + "Longitude": -122.2038, + "InitialPopulation": 460 + } + }, + { + "NodeID": 47, + "NodeAttributes": { + "Latitude": 47.7647, + "Longitude": -122.2038, + "InitialPopulation": 470 + } + }, + { + "NodeID": 48, + "NodeAttributes": { + "Latitude": 47.7648, + "Longitude": -122.2038, + "InitialPopulation": 480 + } + }, + { + "NodeID": 49, + "NodeAttributes": { + "Latitude": 47.7649, + "Longitude": -122.2038, + "InitialPopulation": 490 + } + }, + { + "NodeID": 50, + "NodeAttributes": { + "Latitude": 47.7650, + "Longitude": -122.2038, + "InitialPopulation": 500 + } + }, + { + "NodeID": 51, + "NodeAttributes": { + "Latitude": 47.7651, + "Longitude": -122.2038, + "InitialPopulation": 510 + } + }, + { + "NodeID": 52, + "NodeAttributes": { + "Latitude": 47.7652, + "Longitude": -122.2038, + "InitialPopulation": 520 + } + }, + { + "NodeID": 53, + "NodeAttributes": { + "Latitude": 47.7653, + "Longitude": -122.2038, + "InitialPopulation": 530 + } + }, + { + "NodeID": 54, + "NodeAttributes": { + "Latitude": 47.7654, + "Longitude": -122.2038, + "InitialPopulation": 540 + } + }, + { + "NodeID": 55, + "NodeAttributes": { + "Latitude": 47.7655, + "Longitude": -122.2038, + "InitialPopulation": 550 + } + }, + { + "NodeID": 56, + "NodeAttributes": { + "Latitude": 47.7656, + "Longitude": -122.2038, + "InitialPopulation": 560 + } + }, + { + "NodeID": 57, + "NodeAttributes": { + "Latitude": 47.7657, + "Longitude": -122.2038, + "InitialPopulation": 570 + } + }, + { + "NodeID": 58, + "NodeAttributes": { + "Latitude": 47.7658, + "Longitude": -122.2038, + "InitialPopulation": 580 + } + }, + { + "NodeID": 59, + "NodeAttributes": { + "Latitude": 47.7659, + "Longitude": -122.2038, + "InitialPopulation": 590 + } + }, + { + "NodeID": 60, + "NodeAttributes": { + "Latitude": 47.7660, + "Longitude": -122.2038, + "InitialPopulation": 600 + } + }, + { + "NodeID": 61, + "NodeAttributes": { + "Latitude": 47.7661, + "Longitude": -122.2038, + "InitialPopulation": 610 + } + }, + { + "NodeID": 62, + "NodeAttributes": { + "Latitude": 47.7662, + "Longitude": -122.2038, + "InitialPopulation": 620 + } + }, + { + "NodeID": 63, + "NodeAttributes": { + "Latitude": 47.7663, + "Longitude": -122.2038, + "InitialPopulation": 630 + } + }, + { + "NodeID": 64, + "NodeAttributes": { + "Latitude": 47.7664, + "Longitude": -122.2038, + "InitialPopulation": 640 + } + }, + { + "NodeID": 65, + "NodeAttributes": { + "Latitude": 47.7665, + "Longitude": -122.2038, + "InitialPopulation": 650 + } + }, + { + "NodeID": 66, + "NodeAttributes": { + "Latitude": 47.7666, + "Longitude": -122.2038, + "InitialPopulation": 660 + } + }, + { + "NodeID": 67, + "NodeAttributes": { + "Latitude": 47.7667, + "Longitude": -122.2038, + "InitialPopulation": 670 + } + }, + { + "NodeID": 68, + "NodeAttributes": { + "Latitude": 47.7668, + "Longitude": -122.2038, + "InitialPopulation": 680 + } + }, + { + "NodeID": 69, + "NodeAttributes": { + "Latitude": 47.7669, + "Longitude": -122.2038, + "InitialPopulation": 690 + } + }, + { + "NodeID": 70, + "NodeAttributes": { + "Latitude": 47.7670, + "Longitude": -122.2038, + "InitialPopulation": 700 + } + }, + { + "NodeID": 71, + "NodeAttributes": { + "Latitude": 47.7671, + "Longitude": -122.2038, + "InitialPopulation": 710 + } + }, + { + "NodeID": 72, + "NodeAttributes": { + "Latitude": 47.7672, + "Longitude": -122.2038, + "InitialPopulation": 720 + } + }, + { + "NodeID": 73, + "NodeAttributes": { + "Latitude": 47.7673, + "Longitude": -122.2038, + "InitialPopulation": 730 + } + }, + { + "NodeID": 74, + "NodeAttributes": { + "Latitude": 47.7674, + "Longitude": -122.2038, + "InitialPopulation": 740 + } + }, + { + "NodeID": 75, + "NodeAttributes": { + "Latitude": 47.7675, + "Longitude": -122.2038, + "InitialPopulation": 750 + } + }, + { + "NodeID": 76, + "NodeAttributes": { + "Latitude": 47.7676, + "Longitude": -122.2038, + "InitialPopulation": 760 + } + }, + { + "NodeID": 77, + "NodeAttributes": { + "Latitude": 47.7677, + "Longitude": -122.2038, + "InitialPopulation": 770 + } + }, + { + "NodeID": 78, + "NodeAttributes": { + "Latitude": 47.7678, + "Longitude": -122.2038, + "InitialPopulation": 780 + } + }, + { + "NodeID": 79, + "NodeAttributes": { + "Latitude": 47.7679, + "Longitude": -122.2038, + "InitialPopulation": 790 + } + }, + { + "NodeID": 80, + "NodeAttributes": { + "Latitude": 47.7680, + "Longitude": -122.2038, + "InitialPopulation": 800 + } + }, + { + "NodeID": 81, + "NodeAttributes": { + "Latitude": 47.7681, + "Longitude": -122.2038, + "InitialPopulation": 810 + } + }, + { + "NodeID": 82, + "NodeAttributes": { + "Latitude": 47.7682, + "Longitude": -122.2038, + "InitialPopulation": 820 + } + }, + { + "NodeID": 83, + "NodeAttributes": { + "Latitude": 47.7683, + "Longitude": -122.2038, + "InitialPopulation": 830 + } + }, + { + "NodeID": 84, + "NodeAttributes": { + "Latitude": 47.7684, + "Longitude": -122.2038, + "InitialPopulation": 840 + } + }, + { + "NodeID": 85, + "NodeAttributes": { + "Latitude": 47.7685, + "Longitude": -122.2038, + "InitialPopulation": 850 + } + }, + { + "NodeID": 86, + "NodeAttributes": { + "Latitude": 47.7686, + "Longitude": -122.2038, + "InitialPopulation": 860 + } + }, + { + "NodeID": 87, + "NodeAttributes": { + "Latitude": 47.7687, + "Longitude": -122.2038, + "InitialPopulation": 870 + } + }, + { + "NodeID": 88, + "NodeAttributes": { + "Latitude": 47.7688, + "Longitude": -122.2038, + "InitialPopulation": 880 + } + }, + { + "NodeID": 89, + "NodeAttributes": { + "Latitude": 47.7689, + "Longitude": -122.2038, + "InitialPopulation": 890 + } + }, + { + "NodeID": 90, + "NodeAttributes": { + "Latitude": 47.7690, + "Longitude": -122.2038, + "InitialPopulation": 900 + } + }, + { + "NodeID": 91, + "NodeAttributes": { + "Latitude": 47.7691, + "Longitude": -122.2038, + "InitialPopulation": 910 + } + }, + { + "NodeID": 92, + "NodeAttributes": { + "Latitude": 47.7692, + "Longitude": -122.2038, + "InitialPopulation": 920 + } + }, + { + "NodeID": 93, + "NodeAttributes": { + "Latitude": 47.7693, + "Longitude": -122.2038, + "InitialPopulation": 930 + } + }, + { + "NodeID": 94, + "NodeAttributes": { + "Latitude": 47.7694, + "Longitude": -122.2038, + "InitialPopulation": 940 + } + }, + { + "NodeID": 95, + "NodeAttributes": { + "Latitude": 47.7695, + "Longitude": -122.2038, + "InitialPopulation": 950 + } + }, + { + "NodeID": 96, + "NodeAttributes": { + "Latitude": 47.7696, + "Longitude": -122.2038, + "InitialPopulation": 960 + } + }, + { + "NodeID": 97, + "NodeAttributes": { + "Latitude": 47.7697, + "Longitude": -122.2038, + "InitialPopulation": 970 + } + }, + { + "NodeID": 98, + "NodeAttributes": { + "Latitude": 47.7698, + "Longitude": -122.2038, + "InitialPopulation": 980 + } + }, + { + "NodeID": 99, + "NodeAttributes": { + "Latitude": 47.7699, + "Longitude": -122.2038, + "InitialPopulation": 990 + } + }, + { + "NodeID": 100, + "NodeAttributes": { + "Latitude": 47.7700, + "Longitude": -122.2038, + "InitialPopulation": 1000, + "NodePropertyValues" : + [ + "Place:RURAL", + "InterventionStatus:SPRAYED_B" + ] + } + } + ] +} diff --git a/tests/test_demographics.py b/tests/test_demographics.py index f9d7e7ef..1f8a77e5 100644 --- a/tests/test_demographics.py +++ b/tests/test_demographics.py @@ -1512,10 +1512,8 @@ def test_set_node_property_values_multiple(self): def test_roundtrip_with_reference_file(self): ref_path = os.path.join( - "C:\\", "github", "emod1", "regression", "generic", - "57_Generic_NodeProperties", "demographics_node_properties.json") - if not os.path.isfile(ref_path): - self.skipTest("Reference demographics file not found") + os.path.dirname(__file__), "data", "demographics", + "demographics_node_properties.json") import warnings try: with warnings.catch_warnings(): From 701adc3a5240e675e5bbf1fb82f74ce4f2c96a2a Mon Sep 17 00:00:00 2001 From: Svetlana Titova Date: Thu, 11 Jun 2026 19:37:01 -0700 Subject: [PATCH 6/6] fixing linting issues --- emod_api/campaign.py | 7 +++++-- emod_api/demographics/demographics.py | 3 +-- emod_api/demographics/demographics_base.py | 8 ++++---- emod_api/demographics/implicit_functions.py | 3 +++ emod_api/utils/emod_enum.py | 3 ++- 5 files changed, 15 insertions(+), 9 deletions(-) diff --git a/emod_api/campaign.py b/emod_api/campaign.py index b8de5055..86bd6dee 100644 --- a/emod_api/campaign.py +++ b/emod_api/campaign.py @@ -151,7 +151,6 @@ def add(event, note: str = None): campaign_dict["Events"].append(event) - def save(filename: str = "campaign.json"): """Save the accumulated campaign events to a JSON file. @@ -279,6 +278,7 @@ def get_recv_trigger(trigger, old=use_old_adhoc_handling): individual_events_listened.append(trigger) return trigger + def set_listened_node_event(event: str) -> str: """Register a node-level event as listened to. @@ -297,6 +297,7 @@ def set_listened_node_event(event: str) -> str: node_events_listened.append(event) return event + def set_listened_coordinator_event(event: str) -> str: """Register a coordinator-level event as listened to. @@ -315,6 +316,7 @@ def set_listened_coordinator_event(event: str) -> str: coordinator_events_listened.append(event) return event + def get_send_trigger(trigger, old=use_old_adhoc_handling): """Register an individual-level event as broadcast. @@ -330,6 +332,7 @@ def get_send_trigger(trigger, old=use_old_adhoc_handling): individual_events_broadcast.append(trigger) return trigger + def set_broadcast_node_event(event: str) -> str: """Register a node-level event as broadcast. @@ -348,6 +351,7 @@ def set_broadcast_node_event(event: str) -> str: node_events_broadcast.append(event) return event + def set_broadcast_coordinator_event(event: str) -> str: """Register a coordinator-level event as broadcast. @@ -365,4 +369,3 @@ def set_broadcast_coordinator_event(event: str) -> str: raise ValueError("Event name must not be None or empty.") coordinator_events_broadcast.append(event) return event - diff --git a/emod_api/demographics/demographics.py b/emod_api/demographics/demographics.py index 31336295..247d9a4e 100644 --- a/emod_api/demographics/demographics.py +++ b/emod_api/demographics/demographics.py @@ -7,7 +7,7 @@ from emod_api.demographics.demographics_base import DemographicsBase from emod_api.demographics.node import Node -from emod_api.demographics.properties_and_attributes import NodeAttributes, NodeProperty, NodeProperties +from emod_api.demographics.properties_and_attributes import NodeAttributes, NodeProperty, NodeProperties # noqa: F401 from emod_api.demographics.service import service @@ -34,7 +34,6 @@ def __init__(self, nodes: list[Node], idref: str = None, default_node: Node = No if set_defaults: pass - def to_file(self, path: Union[str, Path] = "demographics.json", indent: int = 4) -> None: """ Write the Demographics object to an EMOD demograhpics json file. diff --git a/emod_api/demographics/demographics_base.py b/emod_api/demographics/demographics_base.py index 76cd0e81..9fbd7a07 100644 --- a/emod_api/demographics/demographics_base.py +++ b/emod_api/demographics/demographics_base.py @@ -274,7 +274,7 @@ def to_dict(self) -> dict: demographics_dict["NodeProperties"] = self.node_properties.to_dict() return demographics_dict - def set_birth_rate(self, rate: float, node_ids: list[int] = None, birth_rate_dependence: Union[str, BirthRateDependence] = "POPULATION_DEP_RATE"): + def set_birth_rate(self, rate: float, node_ids: list[int] = None, birth_rate_dependence: Union[str, BirthRateDependence] = "POPULATION_DEP_RATE"): """ Sets the BirthRate on the target node(s) and configures how EMOD interprets it via Birth_Rate_Dependence. Automatically registers the corresponding config implicit. @@ -305,7 +305,7 @@ def set_birth_rate(self, rate: float, node_ids: list[int] = None, birth_rate_dep max: 8 (equivalent to 1 pregnancy per year for every possible mother in the population) """ - from emod_api.demographics.implicit_functions import ( _set_birth_rate_dependence) + from emod_api.demographics.implicit_functions import _set_birth_rate_dependence if not isinstance(birth_rate_dependence, BirthRateDependence): try: @@ -319,8 +319,8 @@ def set_birth_rate(self, rate: float, node_ids: list[int] = None, birth_rate_dep if rate > 1000: raise ValueError(f"Births per 1000 people per year cannot exceed 1000. Provided rate: {rate}") rate = rate / 365 / 1000 # converting to per day per 1000 people - elif (birth_rate_dependence == BirthRateDependence.DEMOGRAPHIC_DEP_RATE or - birth_rate_dependence == BirthRateDependence.INDIVIDUAL_PREGNANCIES): + elif birth_rate_dependence in (BirthRateDependence.DEMOGRAPHIC_DEP_RATE, + BirthRateDependence.INDIVIDUAL_PREGNANCIES): if rate > 8: raise ValueError(f"Births per 8 fertile women per year cannot exceed 8. Provided rate: {rate}") rate = rate / 365 / 8 # converting to per day per 8 fertile women diff --git a/emod_api/demographics/implicit_functions.py b/emod_api/demographics/implicit_functions.py index b079f72e..f6b02a95 100644 --- a/emod_api/demographics/implicit_functions.py +++ b/emod_api/demographics/implicit_functions.py @@ -1,3 +1,6 @@ +from emod_api.utils.emod_enum import BirthRateDependence + + # Migration def _set_migration_model_fixed_rate(config): diff --git a/emod_api/utils/emod_enum.py b/emod_api/utils/emod_enum.py index f34563c5..3a3c46c5 100644 --- a/emod_api/utils/emod_enum.py +++ b/emod_api/utils/emod_enum.py @@ -1,5 +1,6 @@ from emod_api.utils.str_enum import StrEnum + class BirthRateDependence(StrEnum): """How BirthRate from the demographics file is interpreted by EMOD. @@ -10,4 +11,4 @@ class BirthRateDependence(StrEnum): FIXED_BIRTH_RATE = "FIXED_BIRTH_RATE" POPULATION_DEP_RATE = "POPULATION_DEP_RATE" DEMOGRAPHIC_DEP_RATE = "DEMOGRAPHIC_DEP_RATE" - INDIVIDUAL_PREGNANCIES = "INDIVIDUAL_PREGNANCIES" \ No newline at end of file + INDIVIDUAL_PREGNANCIES = "INDIVIDUAL_PREGNANCIES"