From 50700c381b6beecca28c58bbb5080c698154755b Mon Sep 17 00:00:00 2001 From: icmeyer Date: Wed, 5 Jun 2024 14:19:47 -0400 Subject: [PATCH] added n_records argument to read_ntuple, added tests --- tests/test_ntuple.py | 15 +++++++++++++++ topas2numpy/ntuple.py | 8 +++++--- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/tests/test_ntuple.py b/tests/test_ntuple.py index 7d0fee8..ff2bea1 100755 --- a/tests/test_ntuple.py +++ b/tests/test_ntuple.py @@ -135,6 +135,21 @@ class TestBinaryOtherNtuple(unittest.TestCase, CommonOtherTests): def setUp(self): self.result = read_ntuple(binary_other_path) +class TestAsciiNtupleNrecords(unittest.TestCase, CommonTests): + def setUp(self): + self.result = read_ntuple(ascii_path, n_records=50) + self.column_names = column_names + + def test_size(self): + self.assertEqual(self.result.size, 50) + +class TestBinaryNtupleNrecords(unittest.TestCase, CommonTests): + def setUp(self): + self.result = read_ntuple(binary_path, n_records=50) + self.column_names = column_names + + def test_size(self): + self.assertEqual(self.result.size, 50) if __name__ == '__main__': import sys diff --git a/topas2numpy/ntuple.py b/topas2numpy/ntuple.py index 981de87..4443e74 100644 --- a/topas2numpy/ntuple.py +++ b/topas2numpy/ntuple.py @@ -34,7 +34,7 @@ ] -def read_ntuple(filepath): +def read_ntuple(filepath, n_records=-1): root, ext = os.path.splitext(filepath) ntuple_path = root + '.phsp' header_path = root + '.header' @@ -42,12 +42,14 @@ def read_ntuple(filepath): file_format, col_names = _sniff_format(header_path) if file_format == 'ascii': + max_rows_arg = None if n_records == -1 else n_records # preserve column names => cannot be viewed as a np.recarray # http://docs.scipy.org/doc/numpy-1.10.1/user/basics.io.genfromtxt.html#validating-names - return np.genfromtxt(ntuple_path, names=col_names, deletechars=set(), replace_space='') + return np.genfromtxt(ntuple_path, names=col_names, deletechars=set(), + replace_space='', max_rows=max_rows_arg) elif file_format == 'binary': - return np.fromfile(ntuple_path, dtype=np.dtype(col_names)) + return np.fromfile(ntuple_path, dtype=np.dtype(col_names), count=n_records) else: raise IOError('Unrecognized file format: "%s"' % filepath)