trx/tests/test_cli.py: 35 additions & 2 deletions
@@ -30,6 +30,27 @@
fetch_data(get_testing_files_dict(), keys=["DSI.zip", "trx_from_scratch.zip"])


def _normalize_dtype_dict(dtype_dict):
    """Normalize dtype dict to use explicit little-endian byte order.

    On little-endian systems, numpy may use '=' (native) or '<' (explicit)
    interchangeably. This normalizes all dtypes to '<' for consistent
    comparison.
    """
    normalized = {}
    for key, value in dtype_dict.items():
        if isinstance(value, dict):
            normalized[key] = _normalize_dtype_dict(value)
        elif isinstance(value, np.dtype):
            # Normalize to little-endian for multi-byte types
            if value.byteorder == "=" and value.itemsize > 1:
                normalized[key] = value.newbyteorder("<")
            else:
                normalized[key] = value
        else:
            normalized[key] = value
    return normalized


Review thread on the byte-order normalization above:

Collaborator: Not sure that I understand this. If we are on a big-endian system, couldn't this still be set to "=" (native)? And this would change it?

Collaborator (author): The goal here is to make sure we are in little-endian whatever the machine (big-endian system or little-endian system). For now, we do not want to stay in native.

> If we are on a big endian system, couldn't this still be set to "=" (native)?

Yes, but internally we switch it to little-endian.

Collaborator: Gotcha, looks good to me! I wonder if @frheault could give it another look, to make sure that I have not missed anything.

Collaborator: Seems good. I tested using some tracking scripts with 0 seeds, and it now works great for saving an empty tractogram (lazy saving too) and loading it.
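
As a standalone illustration of the byte-order characters discussed in this thread (not part of the diff; assumes a little-endian machine, the common case):

    import numpy as np

    native = np.dtype("int32")
    print(native.byteorder)  # "=" (native byte order)
    forced = native.newbyteorder("<")  # request explicit little-endian
    # On a little-endian machine both dtypes describe the same memory layout,
    # so they compare equal even when the byteorder character differs.
    assert forced == native
    print(np.dtype("uint8").byteorder)  # "|" (not applicable: itemsize is 1)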


# Tests for standalone CLI commands (trx_* commands)
class TestStandaloneCommands:
"""Tests for standalone CLI commands."""
@@ -391,7 +412,13 @@ def test_execution_manipulate_trx_datatype(self):
"groups": {"g_AF_L": np.dtype("int32"), "g_AF_R": np.dtype("int32")},
}

-        assert DeepDiff(trx.get_dtype_dict(), expected_dtype) == {}
        assert (
            DeepDiff(
                trx.get_dtype_dict(),
                _normalize_dtype_dict(expected_dtype),
            )
            == {}
        )
        trx.close()

        generated_dtype = {
@@ -416,5 +443,11 @@ def test_execution_manipulate_trx_datatype(self):
        out_gen_path = os.path.join(tmp_dir, "generated.trx")
        manipulate_trx_datatype(expected_trx, out_gen_path, generated_dtype)
        trx = tmm.load(out_gen_path)
-        assert DeepDiff(trx.get_dtype_dict(), generated_dtype) == {}
        assert (
            DeepDiff(
                trx.get_dtype_dict(),
                _normalize_dtype_dict(generated_dtype),
            )
            == {}
        )
        trx.close()
trx/tests/test_memmap.py: 38 additions & 3 deletions
@@ -1,6 +1,8 @@
# -*- coding: utf-8 -*-

import os
import tempfile
import zipfile

from nibabel.streamlines import LazyTractogram
from nibabel.streamlines.tests.test_tractogram import make_dummy_streamline
@@ -341,7 +343,37 @@ def test_copy_fixed_arrays_from():


def test_initialize_empty_trx():
-    pass
    """Test creating, saving, and loading an empty TRX file."""
    trx = tmm.TrxFile()
    assert trx.header["NB_STREAMLINES"] == 0
    assert trx.header["NB_VERTICES"] == 0
    assert len(trx.streamlines) == 0

    with tempfile.TemporaryDirectory() as tmp_dir:
        out_path = os.path.join(tmp_dir, "empty.trx")
        tmm.save(trx, out_path)

        assert os.path.exists(out_path)
        file_size = os.path.getsize(out_path)
        assert file_size < 500  # Should be very small, just header.json in zip

        with zipfile.ZipFile(out_path, "r") as zf:
            filenames = [info.filename for info in zf.filelist]
            assert "header.json" in filenames
            positions_files = [f for f in filenames if f.startswith("positions")]
            offsets_files = [f for f in filenames if f.startswith("offsets")]
            assert len(positions_files) == 0
            assert len(offsets_files) == 0

        loaded_trx = tmm.load(out_path)
        assert loaded_trx.header["NB_STREAMLINES"] == 0
        assert loaded_trx.header["NB_VERTICES"] == 0
        assert len(loaded_trx.streamlines) == 0
        assert len(loaded_trx.groups) == 0
        assert len(loaded_trx.data_per_streamline) == 0
        assert len(loaded_trx.data_per_vertex) == 0
        assert len(loaded_trx.data_per_group) == 0
        loaded_trx.close()


def test_create_trx_from_pointer():
@@ -421,8 +453,11 @@ def test__ensure_little_endian(dtype, test_value):
    # Result should be little-endian (or native if system is little-endian)
    assert result.dtype.byteorder in ("<", "=", "|")

-    # Values should be preserved
-    assert result[0] == test_value
    # Values should be preserved (use isclose for float types due to precision)
    if np.issubdtype(dtype, np.floating):
        assert np.isclose(result[0], test_value, rtol=1e-6)
    else:
        assert result[0] == test_value
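
To see why exact equality is too strict for floats here, a quick standalone check (illustration only, not part of the diff):

    import numpy as np

    # 0.1 has no exact binary representation, and float32 rounds it
    # differently than Python's float64, so exact comparison fails.
    x = np.float32(0.1)
    print(x == 0.1)  # False: 0.10000000149011612 != 0.1
    print(np.isclose(x, 0.1, rtol=1e-6))  # True: difference is ~1.5e-9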


def test__ensure_little_endian_big_endian_input():
trx/trx_file_memmap.py: 22 additions & 16 deletions
@@ -935,23 +935,25 @@ def deepcopy(self) -> Type["TrxFile"]:  # noqa: C901
        json.dump(tmp_header, out_json)
        out_json.close()

-        positions_filename = _generate_filename_from_data(
-            to_dump, os.path.join(tmp_dir.name, "positions")
-        )
-        _ensure_little_endian(to_dump).tofile(positions_filename)
-
-        if not self._copy_safe:
-            to_dump = _append_last_offsets(
-                self.streamlines.copy()._offsets, self.header["NB_VERTICES"]
-            )
-        else:
-            to_dump = _append_last_offsets(
-                self.streamlines._offsets, self.header["NB_VERTICES"]
-            )
-        offsets_filename = _generate_filename_from_data(
-            self.streamlines._offsets, os.path.join(tmp_dir.name, "offsets")
-        )
-        _ensure_little_endian(to_dump).tofile(offsets_filename)

        # Only write positions and offsets if TRX is not empty
        if tmp_header["NB_STREAMLINES"] > 0 and tmp_header["NB_VERTICES"] > 0:
            positions_filename = _generate_filename_from_data(
                to_dump, os.path.join(tmp_dir.name, "positions")
            )
            _ensure_little_endian(to_dump).tofile(positions_filename)

            if not self._copy_safe:
                to_dump = _append_last_offsets(
                    self.streamlines.copy()._offsets, self.header["NB_VERTICES"]
                )
            else:
                to_dump = _append_last_offsets(
                    self.streamlines._offsets, self.header["NB_VERTICES"]
                )
            offsets_filename = _generate_filename_from_data(
                self.streamlines._offsets, os.path.join(tmp_dir.name, "offsets")
            )
            _ensure_little_endian(to_dump).tofile(offsets_filename)

        if len(self.data_per_vertex.keys()) > 0:
            os.mkdir(os.path.join(tmp_dir.name, "dpv/"))
@@ -1243,9 +1245,13 @@ def _create_trx_from_pointer(  # noqa: C901
    TrxFile
        A TrxFile constructed from the pointer provided.
    """
    # TODO support empty positions, using optional tag?
    trx = TrxFile()
    trx.header = header

    # Handle empty TRX files early - no positions/offsets to load
    if header["NB_STREAMLINES"] == 0 and header["NB_VERTICES"] == 0:
        return trx

    positions, offsets = None, None
    for elem_filename in dict_pointer_size.keys():
        if root_zip:


Review thread on the empty-file check:

Collaborator: Could this be an "or"? Since streamlines with 0 or 1 points are invalid anyway, maybe this could catch some corner case and force empty TRX?
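
For reference, the stricter check suggested in the review comment above would read as follows (hypothetical sketch, not what this diff implements):

    # Hypothetical variant: treat the file as empty when EITHER count is zero,
    # since a streamline with 0 or 1 points is invalid anyway.
    if header["NB_STREAMLINES"] == 0 or header["NB_VERTICES"] == 0:
        return trx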