diff --git a/trx/tests/test_cli.py b/trx/tests/test_cli.py index 201ff03..ddb70bd 100644 --- a/trx/tests/test_cli.py +++ b/trx/tests/test_cli.py @@ -30,6 +30,27 @@ fetch_data(get_testing_files_dict(), keys=["DSI.zip", "trx_from_scratch.zip"]) +def _normalize_dtype_dict(dtype_dict): + """Normalize dtype dict to use explicit little-endian byte order. + + On little-endian systems, numpy may use '=' (native) or '<' (explicit) + interchangeably. This normalizes all dtypes to '<' for consistent comparison. + """ + normalized = {} + for key, value in dtype_dict.items(): + if isinstance(value, dict): + normalized[key] = _normalize_dtype_dict(value) + elif isinstance(value, np.dtype): + # Normalize to little-endian for multi-byte types + if value.byteorder == "=" and value.itemsize > 1: + normalized[key] = value.newbyteorder("<") + else: + normalized[key] = value + else: + normalized[key] = value + return normalized + + # Tests for standalone CLI commands (trx_* commands) class TestStandaloneCommands: """Tests for standalone CLI commands.""" @@ -391,7 +412,13 @@ def test_execution_manipulate_trx_datatype(self): "groups": {"g_AF_L": np.dtype("int32"), "g_AF_R": np.dtype("int32")}, } - assert DeepDiff(trx.get_dtype_dict(), expected_dtype) == {} + assert ( + DeepDiff( + trx.get_dtype_dict(), + _normalize_dtype_dict(expected_dtype), + ) + == {} + ) trx.close() generated_dtype = { @@ -416,5 +443,11 @@ def test_execution_manipulate_trx_datatype(self): out_gen_path = os.path.join(tmp_dir, "generated.trx") manipulate_trx_datatype(expected_trx, out_gen_path, generated_dtype) trx = tmm.load(out_gen_path) - assert DeepDiff(trx.get_dtype_dict(), generated_dtype) == {} + assert ( + DeepDiff( + trx.get_dtype_dict(), + _normalize_dtype_dict(generated_dtype), + ) + == {} + ) trx.close() diff --git a/trx/tests/test_memmap.py b/trx/tests/test_memmap.py index 6028353..ed1f581 100644 --- a/trx/tests/test_memmap.py +++ b/trx/tests/test_memmap.py @@ -1,6 +1,8 @@ # -*- coding: utf-8 -*- import os +import tempfile +import zipfile from nibabel.streamlines import LazyTractogram from nibabel.streamlines.tests.test_tractogram import make_dummy_streamline @@ -341,7 +343,37 @@ def test_copy_fixed_arrays_from(): def test_initialize_empty_trx(): - pass + """Test creating, saving, and loading an empty TRX file.""" + trx = tmm.TrxFile() + assert trx.header["NB_STREAMLINES"] == 0 + assert trx.header["NB_VERTICES"] == 0 + assert len(trx.streamlines) == 0 + + with tempfile.TemporaryDirectory() as tmp_dir: + out_path = os.path.join(tmp_dir, "empty.trx") + tmm.save(trx, out_path) + + assert os.path.exists(out_path) + file_size = os.path.getsize(out_path) + assert file_size < 500 # Should be very small, just header.json in zip + + with zipfile.ZipFile(out_path, "r") as zf: + filenames = [info.filename for info in zf.filelist] + assert "header.json" in filenames + positions_files = [f for f in filenames if f.startswith("positions")] + offsets_files = [f for f in filenames if f.startswith("offsets")] + assert len(positions_files) == 0 + assert len(offsets_files) == 0 + + loaded_trx = tmm.load(out_path) + assert loaded_trx.header["NB_STREAMLINES"] == 0 + assert loaded_trx.header["NB_VERTICES"] == 0 + assert len(loaded_trx.streamlines) == 0 + assert len(loaded_trx.groups) == 0 + assert len(loaded_trx.data_per_streamline) == 0 + assert len(loaded_trx.data_per_vertex) == 0 + assert len(loaded_trx.data_per_group) == 0 + loaded_trx.close() def test_create_trx_from_pointer(): @@ -421,8 +453,11 @@ def test__ensure_little_endian(dtype, test_value): # Result should be little-endian (or native if system is little-endian) assert result.dtype.byteorder in ("<", "=", "|") - # Values should be preserved - assert result[0] == test_value + # Values should be preserved (use isclose for float types due to precision) + if np.issubdtype(dtype, np.floating): + assert np.isclose(result[0], test_value, rtol=1e-6) + else: + assert result[0] == test_value def test__ensure_little_endian_big_endian_input(): diff --git a/trx/trx_file_memmap.py b/trx/trx_file_memmap.py index 1619c1c..26b578c 100644 --- a/trx/trx_file_memmap.py +++ b/trx/trx_file_memmap.py @@ -935,23 +935,25 @@ def deepcopy(self) -> Type["TrxFile"]: # noqa: C901 json.dump(tmp_header, out_json) out_json.close() - positions_filename = _generate_filename_from_data( - to_dump, os.path.join(tmp_dir.name, "positions") - ) - _ensure_little_endian(to_dump).tofile(positions_filename) - - if not self._copy_safe: - to_dump = _append_last_offsets( - self.streamlines.copy()._offsets, self.header["NB_VERTICES"] + # Only write positions and offsets if TRX is not empty + if tmp_header["NB_STREAMLINES"] > 0 and tmp_header["NB_VERTICES"] > 0: + positions_filename = _generate_filename_from_data( + to_dump, os.path.join(tmp_dir.name, "positions") ) - else: - to_dump = _append_last_offsets( - self.streamlines._offsets, self.header["NB_VERTICES"] + _ensure_little_endian(to_dump).tofile(positions_filename) + + if not self._copy_safe: + to_dump = _append_last_offsets( + self.streamlines.copy()._offsets, self.header["NB_VERTICES"] + ) + else: + to_dump = _append_last_offsets( + self.streamlines._offsets, self.header["NB_VERTICES"] + ) + offsets_filename = _generate_filename_from_data( + self.streamlines._offsets, os.path.join(tmp_dir.name, "offsets") ) - offsets_filename = _generate_filename_from_data( - self.streamlines._offsets, os.path.join(tmp_dir.name, "offsets") - ) - _ensure_little_endian(to_dump).tofile(offsets_filename) + _ensure_little_endian(to_dump).tofile(offsets_filename) if len(self.data_per_vertex.keys()) > 0: os.mkdir(os.path.join(tmp_dir.name, "dpv/")) @@ -1243,9 +1245,13 @@ def _create_trx_from_pointer( # noqa: C901 TrxFile A TrxFile constructed from the pointer provided. """ - # TODO support empty positions, using optional tag? trx = TrxFile() trx.header = header + + # Handle empty TRX files early - no positions/offsets to load + if header["NB_STREAMLINES"] == 0 and header["NB_VERTICES"] == 0: + return trx + positions, offsets = None, None for elem_filename in dict_pointer_size.keys(): if root_zip: