diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..9855d94 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,6 @@ +[pytest] +testpaths = tests +python_files = test_*.py +python_classes = Test* +python_functions = test_* +addopts = -v --tb=short diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..14eec5b --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +# Tests package for vhlab-library-python diff --git a/tests/test_cdm.py b/tests/test_cdm.py new file mode 100644 index 0000000..5af37a0 --- /dev/null +++ b/tests/test_cdm.py @@ -0,0 +1,498 @@ +""" +Comprehensive tests for the vhlib.CDM module. +""" + +import pytest +import os +import tempfile +import shutil + + +class TestCellname2Nameref: + """Tests for cellname2nameref function.""" + + def test_basic_cellname(self): + """Test parsing a standard cell name.""" + from vhlib.CDM import cellname2nameref + + nameref, index, datestr = cellname2nameref('cell_ctx_0003_001_2003_05_27') + + assert nameref['name'] == 'ctx' + assert nameref['ref'] == 3 + assert index == 1 + assert datestr == '2003_05_27' + + def test_different_name(self): + """Test parsing cell name with different region name.""" + from vhlib.CDM import cellname2nameref + + nameref, index, datestr = cellname2nameref('cell_v1_0015_042_2020_12_01') + + assert nameref['name'] == 'v1' + assert nameref['ref'] == 15 + assert index == 42 + assert datestr == '2020_12_01' + + def test_large_numbers(self): + """Test parsing cell name with large ref and index numbers.""" + from vhlib.CDM import cellname2nameref + + nameref, index, datestr = cellname2nameref('cell_area_9999_0999_1999_01_31') + + assert nameref['ref'] == 9999 + assert index == 999 + + def test_invalid_cellname_too_short(self): + """Test that invalid cell names raise ValueError.""" + from vhlib.CDM import cellname2nameref + + with pytest.raises(ValueError): + cellname2nameref('cell_ctx_0003') + + def test_invalid_cellname_empty(self): + """Test that empty cell name raises ValueError.""" + from vhlib.CDM import cellname2nameref + + with pytest.raises(ValueError): + cellname2nameref('') + + +class TestCellname2Date: + """Tests for cellname2date function.""" + + def test_basic_date_extraction(self): + """Test extracting date from cell name.""" + from vhlib.CDM import cellname2date + + date = cellname2date('cell_ctx_0003_001_2003_05_27') + + assert date == '2003-05-27' + + def test_different_date(self): + """Test extracting different date.""" + from vhlib.CDM import cellname2date + + date = cellname2date('cell_v1_0015_042_2020_12_01') + + assert date == '2020-12-01' + + def test_january_date(self): + """Test date with January.""" + from vhlib.CDM import cellname2date + + date = cellname2date('cell_abc_0001_001_2025_01_15') + + assert date == '2025-01-15' + + +class TestNameref2Cellname: + """Tests for nameref2cellname function.""" + + def test_basic_conversion(self): + """Test basic conversion from nameref to cellname.""" + from vhlib.CDM import nameref2cellname + + # Using a string path as ds + cellname = nameref2cellname('/path/to/2003-05-27', 'ctx', 3, 1) + + assert cellname == 'cell_ctx_003_001_2003_05_27' + + def test_with_trailing_slash(self): + """Test path with trailing slash.""" + from vhlib.CDM import nameref2cellname + + cellname = nameref2cellname('/path/to/2003-05-27/', 'ctx', 3, 1) + + assert cellname == 'cell_ctx_003_001_2003_05_27' + + def test_formatting_with_zero_padding(self): + """Test that ref and index are zero-padded.""" + from vhlib.CDM import nameref2cellname + + cellname = nameref2cellname('/path/2020-01-15', 'v1', 1, 1) + + assert cellname == 'cell_v1_001_001_2020_01_15' + + def test_large_numbers(self): + """Test with larger ref and index numbers.""" + from vhlib.CDM import nameref2cellname + + cellname = nameref2cellname('/path/2020-01-15', 'area', 999, 123) + + assert cellname == 'cell_area_999_123_2020_01_15' + + def test_invalid_date_format(self): + """Test that invalid date format raises ValueError.""" + from vhlib.CDM import nameref2cellname + + with pytest.raises(ValueError): + nameref2cellname('/path/to/20030527', 'ctx', 3, 1) + + +class TestRoundTrip: + """Tests for round-trip conversion between cellname and nameref.""" + + def test_roundtrip_conversion(self): + """Test that cellname -> nameref -> cellname preserves data.""" + from vhlib.CDM import cellname2nameref, nameref2cellname + + original = 'cell_ctx_003_001_2003_05_27' + nameref, index, datestr = cellname2nameref(original) + + # Create path with date in expected format + path = f'/path/to/{datestr.replace("_", "-")}' + reconstructed = nameref2cellname(path, nameref['name'], nameref['ref'], index) + + assert reconstructed == original + + +class TestTrainingHelp: + """Tests for training help functions.""" + + def test_trainingtype_prints_docstring(self, capsys): + """Test that trainingtype prints its docstring.""" + from vhlib.CDM import trainingtype + + trainingtype() + captured = capsys.readouterr() + + assert 'trainingtype.txt' in captured.out + assert 'Bidirectional' in captured.out + + def test_trainingangle_prints_docstring(self, capsys): + """Test that trainingangle prints its docstring.""" + from vhlib.CDM import trainingangle + + trainingangle() + captured = capsys.readouterr() + + assert 'trainingangle.txt' in captured.out + + def test_trainingstim_prints_docstring(self, capsys): + """Test that trainingstim prints its docstring.""" + from vhlib.CDM import trainingstim + + trainingstim() + captured = capsys.readouterr() + + assert 'trainingstim.txt' in captured.out + + def test_trainingtemporalfrequency_prints_docstring(self, capsys): + """Test that trainingtemporalfrequency prints its docstring.""" + from vhlib.CDM import trainingtemporalfrequency + + trainingtemporalfrequency() + captured = capsys.readouterr() + + assert 'trainingtemporalfrequency.txt' in captured.out + + +class TestUnitquality: + """Tests for unitquality help function.""" + + def test_unitquality_prints_docstring(self, capsys): + """Test that unitquality prints its docstring.""" + from vhlib.CDM import unitquality + + unitquality() + captured = capsys.readouterr() + + assert 'unitquality.txt' in captured.out + assert 'channel' in captured.out + + +class TestHelpFiles: + """Tests for help file functions.""" + + def test_unitquality_channelshift_prints_docstring(self, capsys): + """Test that unitquality_channelshift prints docstring.""" + from vhlib.CDM import unitquality_channelshift + + unitquality_channelshift() + captured = capsys.readouterr() + + assert 'unitquality_channelshift.txt' in captured.out + + def test_testdirinfo_prints_docstring(self, capsys): + """Test that testdirinfo prints docstring.""" + from vhlib.CDM import testdirinfo + + testdirinfo() + captured = capsys.readouterr() + + assert 'testdirinfo.txt' in captured.out + + +class TestAssociateVariablesTxt: + """Tests for associate_variables_txt help function.""" + + def test_associate_variables_txt_prints_docstring(self, capsys): + """Test that associate_variables_txt prints docstring.""" + from vhlib.CDM import associate_variables_txt + + associate_variables_txt() + captured = capsys.readouterr() + + assert 'associate_variables.txt' in captured.out + + +class TestFilterByIndex: + """Tests for filter_by_index function.""" + + def test_filter_within_range(self): + """Test filtering cells within index range.""" + from vhlib.CDM import filter_by_index + + cells = ['cell_a', 'cell_b', 'cell_c', 'cell_d'] + cellnames = [ + 'cell_ctx_0001_001_2020_01_01', + 'cell_ctx_0001_005_2020_01_01', + 'cell_ctx_0001_010_2020_01_01', + 'cell_ctx_0001_015_2020_01_01' + ] + + filtered_cells, filtered_names, indices = filter_by_index(cells, cellnames, 1, 10) + + assert len(filtered_cells) == 3 + assert len(filtered_names) == 3 + assert indices == [0, 1, 2] + + def test_filter_exact_match(self): + """Test filtering with exact min=max.""" + from vhlib.CDM import filter_by_index + + cells = ['a', 'b', 'c'] + cellnames = [ + 'cell_ctx_0001_005_2020_01_01', + 'cell_ctx_0001_010_2020_01_01', + 'cell_ctx_0001_015_2020_01_01' + ] + + filtered_cells, filtered_names, indices = filter_by_index(cells, cellnames, 10, 10) + + assert len(filtered_cells) == 1 + assert filtered_cells == ['b'] + assert indices == [1] + + def test_filter_no_matches(self): + """Test filtering with no matches.""" + from vhlib.CDM import filter_by_index + + cells = ['a', 'b'] + cellnames = [ + 'cell_ctx_0001_001_2020_01_01', + 'cell_ctx_0001_002_2020_01_01' + ] + + filtered_cells, filtered_names, indices = filter_by_index(cells, cellnames, 100, 200) + + assert len(filtered_cells) == 0 + assert indices == [] + + +class TestFilterByReference: + """Tests for filter_by_reference function.""" + + def test_filter_within_range(self): + """Test filtering cells within reference range.""" + from vhlib.CDM import filter_by_reference + + cells = ['cell_a', 'cell_b', 'cell_c', 'cell_d'] + cellnames = [ + 'cell_ctx_0001_001_2020_01_01', + 'cell_ctx_0005_001_2020_01_01', + 'cell_ctx_0010_001_2020_01_01', + 'cell_ctx_0015_001_2020_01_01' + ] + + filtered_cells, filtered_names, indices = filter_by_reference(cells, cellnames, 1, 10) + + assert len(filtered_cells) == 3 + assert len(filtered_names) == 3 + assert indices == [0, 1, 2] + + def test_filter_exact_ref(self): + """Test filtering with exact reference.""" + from vhlib.CDM import filter_by_reference + + cells = ['a', 'b', 'c'] + cellnames = [ + 'cell_ctx_0005_001_2020_01_01', + 'cell_ctx_0010_001_2020_01_01', + 'cell_ctx_0015_001_2020_01_01' + ] + + filtered_cells, filtered_names, indices = filter_by_reference(cells, cellnames, 10, 10) + + assert len(filtered_cells) == 1 + assert filtered_cells == ['b'] + + +class TestReadTrainingType: + """Tests for read_trainingtype function.""" + + def setup_method(self): + """Create temporary directory for tests.""" + self.test_dir = tempfile.mkdtemp() + + def teardown_method(self): + """Clean up temporary directory.""" + shutil.rmtree(self.test_dir) + + def test_read_bidirectional(self): + """Test reading bidirectional training type.""" + from vhlib.CDM import read_trainingtype + + # Create trainingtype.txt + with open(os.path.join(self.test_dir, 'trainingtype.txt'), 'w') as f: + f.write('Bidirectional\n') + + assoc = read_trainingtype(self.test_dir) + + assert len(assoc) == 1 + assert assoc[0]['type'] == 'Training Type' + assert assoc[0]['data'] == 'bidirectional' + + def test_read_unidirectional(self): + """Test reading unidirectional training type.""" + from vhlib.CDM import read_trainingtype + + with open(os.path.join(self.test_dir, 'trainingtype.txt'), 'w') as f: + f.write('uni\n') + + assoc = read_trainingtype(self.test_dir) + + assert assoc[0]['data'] == 'unidirectional' + + def test_read_flash(self): + """Test reading flash training type.""" + from vhlib.CDM import read_trainingtype + + with open(os.path.join(self.test_dir, 'trainingtype.txt'), 'w') as f: + f.write('Flash\n') + + assoc = read_trainingtype(self.test_dir) + + assert assoc[0]['data'] == 'flash' + + def test_read_none(self): + """Test reading none training type.""" + from vhlib.CDM import read_trainingtype + + with open(os.path.join(self.test_dir, 'trainingtype.txt'), 'w') as f: + f.write('None\n') + + assoc = read_trainingtype(self.test_dir) + + assert assoc[0]['data'] == 'none' + + def test_read_counterphase(self): + """Test reading counterphase training type.""" + from vhlib.CDM import read_trainingtype + + with open(os.path.join(self.test_dir, 'trainingtype.txt'), 'w') as f: + f.write('CP\n') + + assoc = read_trainingtype(self.test_dir) + + assert assoc[0]['data'] == 'counterphase' + + def test_read_training_angle(self): + """Test reading training angle.""" + from vhlib.CDM import read_trainingtype + + with open(os.path.join(self.test_dir, 'trainingtype.txt'), 'w') as f: + f.write('Bi\n') + with open(os.path.join(self.test_dir, 'trainingangle.txt'), 'w') as f: + f.write('45.0 135.0\n') + + assoc = read_trainingtype(self.test_dir) + + assert len(assoc) == 2 + angle_assoc = [a for a in assoc if a['type'] == 'Training Angle'][0] + assert angle_assoc['data'] == [45.0, 135.0] + + def test_read_training_tf(self): + """Test reading training temporal frequency.""" + from vhlib.CDM import read_trainingtype + + with open(os.path.join(self.test_dir, 'trainingtype.txt'), 'w') as f: + f.write('Bi\n') + with open(os.path.join(self.test_dir, 'trainingtemporalfrequency.txt'), 'w') as f: + f.write('4.0 8.0\n') + + assoc = read_trainingtype(self.test_dir) + + tf_assoc = [a for a in assoc if a['type'] == 'Training TF'][0] + assert tf_assoc['data'] == [4.0, 8.0] + + def test_read_training_stim(self): + """Test reading training stim.""" + from vhlib.CDM import read_trainingtype + + with open(os.path.join(self.test_dir, 'trainingtype.txt'), 'w') as f: + f.write('scrambled\n') + with open(os.path.join(self.test_dir, 'trainingstim.txt'), 'w') as f: + f.write('b5\n') + + assoc = read_trainingtype(self.test_dir) + + stim_assoc = [a for a in assoc if a['type'] == 'Training Stim'][0] + assert stim_assoc['data'] == 'B5' + + def test_no_training_type_file(self): + """Test when training type file doesn't exist.""" + from vhlib.CDM import read_trainingtype + + # Don't create any files + assoc = read_trainingtype(self.test_dir) + + assert len(assoc) == 0 + + def test_error_if_no_training_type_requested(self): + """Test error raised when requested and file missing.""" + from vhlib.CDM import read_trainingtype + + with pytest.raises(FileNotFoundError): + read_trainingtype(self.test_dir, ErrorIfNoTrainingType=True) + + def test_unknown_training_type_raises_error(self): + """Test that unknown training type raises ValueError.""" + from vhlib.CDM import read_trainingtype + + with open(os.path.join(self.test_dir, 'trainingtype.txt'), 'w') as f: + f.write('UnknownType\n') + + with pytest.raises(ValueError): + read_trainingtype(self.test_dir) + + +class TestRepeatedMeasurementAssociates: + """Tests for repeated_measurement_associates function.""" + + def test_find_repeated_associates(self): + """Test finding repeated measurement associates.""" + from vhlib.CDM import repeated_measurement_associates + + # Create a cell with associates + cell = { + 'associates': [ + {'type': 'SP F0 TFOP0 TF Response curve', 'owner': 'test', 'data': 1, 'desc': 'test'}, + {'type': 'SP F0 TFOP1 TF Response curve', 'owner': 'test', 'data': 2, 'desc': 'test'}, + {'type': 'SP F0 TFOP3 TF Response curve', 'owner': 'test', 'data': 3, 'desc': 'test'}, + ] + } + + n = repeated_measurement_associates(cell, 'SP F0 TFOP%d TF Response curve', 5) + + assert n == [0, 1, 3] + + def test_no_repeated_associates(self): + """Test when no repeated associates found.""" + from vhlib.CDM import repeated_measurement_associates + + cell = {'associates': []} + + n = repeated_measurement_associates(cell, 'Test%d', 10) + + assert n == [] diff --git a/tests/test_md.py b/tests/test_md.py new file mode 100644 index 0000000..b8b7ea4 --- /dev/null +++ b/tests/test_md.py @@ -0,0 +1,576 @@ +""" +Comprehensive tests for the vhlib.md module. +""" + +import pytest +import numpy as np + + +class TestMeasuredData: + """Tests for MeasuredData class.""" + + def test_create_basic(self): + """Test creating a basic MeasuredData object.""" + from vhlib.md import MeasuredData + + intervals = [[0, 1], [2, 3], [4, 5]] + md = MeasuredData(intervals, 'Long description', 'Brief') + + assert md.intervals == intervals + assert md.description_long == 'Long description' + assert md.description_brief == 'Brief' + assert len(md.associates) == 0 + + def test_create_with_numpy_intervals(self): + """Test creating MeasuredData with numpy array intervals.""" + from vhlib.md import MeasuredData + + intervals = np.array([[0, 1], [2, 3]]) + md = MeasuredData(intervals) + + assert md.intervals.shape == (2, 2) + + def test_create_with_empty_intervals(self): + """Test creating MeasuredData with empty intervals.""" + from vhlib.md import MeasuredData + + md = MeasuredData([]) + + assert len(md.intervals) == 0 + + def test_invalid_intervals_shape(self): + """Test that invalid intervals shape raises error.""" + from vhlib.md import MeasuredData + + with pytest.raises(ValueError): + MeasuredData([[1, 2, 3]]) # 3 columns instead of 2 + + +class TestMeasuredDataAssociate: + """Tests for MeasuredData.associate method.""" + + def test_add_associate(self): + """Test adding an associate.""" + from vhlib.md import MeasuredData + + md = MeasuredData([[0, 1]]) + md.associate('test_type', 'test_owner', {'value': 42}, 'test description') + + assert md.numassociates() == 1 + assoc = md.getassociate(0) + assert assoc['type'] == 'test_type' + assert assoc['owner'] == 'test_owner' + assert assoc['data'] == {'value': 42} + assert assoc['desc'] == 'test description' + + def test_add_associate_dict(self): + """Test adding an associate using dict format.""" + from vhlib.md import MeasuredData + + md = MeasuredData([[0, 1]]) + assoc_dict = { + 'type': 'dict_type', + 'owner': 'dict_owner', + 'data': [1, 2, 3], + 'desc': 'dict description' + } + md.associate(assoc_dict) + + assert md.numassociates() == 1 + assoc = md.getassociate(0) + assert assoc['type'] == 'dict_type' + + def test_replace_existing_associate(self): + """Test that adding duplicate associate replaces it.""" + from vhlib.md import MeasuredData + + md = MeasuredData([[0, 1]]) + md.associate('type1', 'owner1', 'data1', 'desc1') + md.associate('type1', 'owner1', 'data2', 'desc1') + + assert md.numassociates() == 1 + assoc = md.getassociate(0) + assert assoc['data'] == 'data2' + + def test_add_multiple_associates(self): + """Test adding multiple different associates.""" + from vhlib.md import MeasuredData + + md = MeasuredData([[0, 1]]) + md.associate('type1', 'owner1', 'data1', 'desc1') + md.associate('type2', 'owner2', 'data2', 'desc2') + md.associate('type3', 'owner3', 'data3', 'desc3') + + assert md.numassociates() == 3 + + def test_associate_returns_self(self): + """Test that associate returns self for chaining.""" + from vhlib.md import MeasuredData + + md = MeasuredData([[0, 1]]) + result = md.associate('type', 'owner', 'data', 'desc') + + assert result is md + + def test_associate_invalid_type(self): + """Test that non-string type raises error.""" + from vhlib.md import MeasuredData + + md = MeasuredData([[0, 1]]) + + with pytest.raises(ValueError): + md.associate(123, 'owner', 'data', 'desc') + + def test_associate_invalid_owner(self): + """Test that non-string owner raises error.""" + from vhlib.md import MeasuredData + + md = MeasuredData([[0, 1]]) + + with pytest.raises(ValueError): + md.associate('type', 123, 'data', 'desc') + + def test_associate_invalid_description(self): + """Test that non-string description raises error.""" + from vhlib.md import MeasuredData + + md = MeasuredData([[0, 1]]) + + with pytest.raises(ValueError): + md.associate('type', 'owner', 'data', 123) + + +class TestMeasuredDataFindassociate: + """Tests for MeasuredData.findassociate method.""" + + def test_find_by_type(self): + """Test finding associate by type.""" + from vhlib.md import MeasuredData + + md = MeasuredData([[0, 1]]) + md.associate('type1', 'owner1', 'data1', 'desc1') + md.associate('type2', 'owner2', 'data2', 'desc2') + + matches, indices = md.findassociate('type1', '', '') + + assert len(matches) == 1 + assert matches[0]['type'] == 'type1' + assert indices == [0] + + def test_find_by_owner(self): + """Test finding associate by owner.""" + from vhlib.md import MeasuredData + + md = MeasuredData([[0, 1]]) + md.associate('type1', 'owner1', 'data1', 'desc1') + md.associate('type2', 'owner1', 'data2', 'desc2') + md.associate('type3', 'owner2', 'data3', 'desc3') + + matches, indices = md.findassociate('', 'owner1', '') + + assert len(matches) == 2 + assert indices == [0, 1] + + def test_find_by_description(self): + """Test finding associate by description.""" + from vhlib.md import MeasuredData + + md = MeasuredData([[0, 1]]) + md.associate('type1', 'owner1', 'data1', 'desc1') + md.associate('type2', 'owner2', 'data2', 'desc1') + + matches, indices = md.findassociate('', '', 'desc1') + + assert len(matches) == 2 + + def test_find_by_multiple_criteria(self): + """Test finding associate by multiple criteria.""" + from vhlib.md import MeasuredData + + md = MeasuredData([[0, 1]]) + md.associate('type1', 'owner1', 'data1', 'desc1') + md.associate('type1', 'owner2', 'data2', 'desc1') + md.associate('type2', 'owner1', 'data3', 'desc1') + + matches, indices = md.findassociate('type1', 'owner1', 'desc1') + + assert len(matches) == 1 + assert indices == [0] + + def test_find_all_with_empty_criteria(self): + """Test finding all associates with empty criteria.""" + from vhlib.md import MeasuredData + + md = MeasuredData([[0, 1]]) + md.associate('type1', 'owner1', 'data1', 'desc1') + md.associate('type2', 'owner2', 'data2', 'desc2') + + matches, indices = md.findassociate('', '', '') + + assert len(matches) == 2 + assert indices == [0, 1] + + def test_find_no_matches(self): + """Test finding with no matches.""" + from vhlib.md import MeasuredData + + md = MeasuredData([[0, 1]]) + md.associate('type1', 'owner1', 'data1', 'desc1') + + matches, indices = md.findassociate('nonexistent', '', '') + + assert len(matches) == 0 + assert indices == [] + + +class TestMeasuredDataDisassociate: + """Tests for MeasuredData.disassociate method.""" + + def test_disassociate_single(self): + """Test removing single associate.""" + from vhlib.md import MeasuredData + + md = MeasuredData([[0, 1]]) + md.associate('type1', 'owner1', 'data1', 'desc1') + md.associate('type2', 'owner2', 'data2', 'desc2') + + md.disassociate(0) + + assert md.numassociates() == 1 + assert md.getassociate(0)['type'] == 'type2' + + def test_disassociate_multiple(self): + """Test removing multiple associates.""" + from vhlib.md import MeasuredData + + md = MeasuredData([[0, 1]]) + md.associate('type1', 'owner1', 'data1', 'desc1') + md.associate('type2', 'owner2', 'data2', 'desc2') + md.associate('type3', 'owner3', 'data3', 'desc3') + + md.disassociate([0, 2]) + + assert md.numassociates() == 1 + assert md.getassociate(0)['type'] == 'type2' + + def test_disassociate_returns_self(self): + """Test that disassociate returns self.""" + from vhlib.md import MeasuredData + + md = MeasuredData([[0, 1]]) + md.associate('type1', 'owner1', 'data1', 'desc1') + + result = md.disassociate(0) + + assert result is md + + def test_disassociate_invalid_index(self): + """Test that invalid index is handled gracefully.""" + from vhlib.md import MeasuredData + + md = MeasuredData([[0, 1]]) + md.associate('type1', 'owner1', 'data1', 'desc1') + + # Should not raise, just ignore invalid index + md.disassociate(100) + + assert md.numassociates() == 1 + + +class TestMeasuredDataAssociates2struct: + """Tests for MeasuredData.associates2struct method.""" + + def test_basic_conversion(self): + """Test converting associates to struct.""" + from vhlib.md import MeasuredData + + md = MeasuredData([[0, 1]]) + md.associate('Type One', 'owner', 'data1', 'desc') + md.associate('Type Two', 'owner', 'data2', 'desc') + + s = md.associates2struct() + + assert s['Type_One'] == 'data1' + assert s['Type_Two'] == 'data2' + + def test_empty_associates(self): + """Test converting empty associates.""" + from vhlib.md import MeasuredData + + md = MeasuredData([[0, 1]]) + + s = md.associates2struct() + + assert s == {} + + +class TestMeasuredDataGetassociate: + """Tests for MeasuredData.getassociate method.""" + + def test_get_single(self): + """Test getting single associate by index.""" + from vhlib.md import MeasuredData + + md = MeasuredData([[0, 1]]) + md.associate('type1', 'owner1', 'data1', 'desc1') + + assoc = md.getassociate(0) + + assert assoc['type'] == 'type1' + + def test_get_multiple(self): + """Test getting multiple associates by indices.""" + from vhlib.md import MeasuredData + + md = MeasuredData([[0, 1]]) + md.associate('type1', 'owner1', 'data1', 'desc1') + md.associate('type2', 'owner2', 'data2', 'desc2') + md.associate('type3', 'owner3', 'data3', 'desc3') + + assocs = md.getassociate([0, 2]) + + assert len(assocs) == 2 + assert assocs[0]['type'] == 'type1' + assert assocs[1]['type'] == 'type3' + + +class TestModuleLevelFunctions: + """Tests for module-level functions in vhlib.md.""" + + def test_findassociate_with_dict(self): + """Test findassociate with dict input.""" + from vhlib.md import findassociate + + cell = { + 'associates': [ + {'type': 'type1', 'owner': 'owner1', 'data': 'data1', 'desc': 'desc1'}, + {'type': 'type2', 'owner': 'owner2', 'data': 'data2', 'desc': 'desc2'}, + ] + } + + matches, indices = findassociate(cell, 'type1', '', '') + + assert len(matches) == 1 + assert indices == [0] + + def test_findassociate_with_measureddata(self): + """Test findassociate with MeasuredData input.""" + from vhlib.md import findassociate, MeasuredData + + md = MeasuredData([[0, 1]]) + md.associate('type1', 'owner1', 'data1', 'desc1') + + matches, indices = findassociate(md, 'type1', '', '') + + assert len(matches) == 1 + + def test_findassociate_empty_associates(self): + """Test findassociate with empty associates list.""" + from vhlib.md import findassociate + + cell = {'associates': []} + + matches, indices = findassociate(cell, 'type1', '', '') + + assert matches == [] + assert indices == [] + + def test_findassociate_invalid_input(self): + """Test findassociate with invalid input.""" + from vhlib.md import findassociate + + with pytest.raises(ValueError): + findassociate("invalid", 'type', '', '') + + def test_associate_with_dict(self): + """Test associate with dict input.""" + from vhlib.md import associate + + cell = {'associates': []} + + result = associate(cell, 'type1', 'owner1', 'data1', 'desc1') + + assert len(result['associates']) == 1 + assert result['associates'][0]['type'] == 'type1' + + def test_associate_with_dict_struct(self): + """Test associate with dict struct input.""" + from vhlib.md import associate + + cell = {'associates': []} + assoc_struct = { + 'type': 'type1', + 'owner': 'owner1', + 'data': 'data1', + 'desc': 'desc1' + } + + result = associate(cell, assoc_struct) + + assert len(result['associates']) == 1 + + def test_associate_replaces_existing(self): + """Test that associate replaces existing matching associate.""" + from vhlib.md import associate + + cell = { + 'associates': [ + {'type': 'type1', 'owner': 'owner1', 'data': 'old_data', 'desc': 'desc1'} + ] + } + + result = associate(cell, 'type1', 'owner1', 'new_data', 'desc1') + + assert len(result['associates']) == 1 + assert result['associates'][0]['data'] == 'new_data' + + def test_disassociate_with_dict(self): + """Test disassociate with dict input.""" + from vhlib.md import disassociate + + cell = { + 'associates': [ + {'type': 'type1', 'owner': 'owner1', 'data': 'data1', 'desc': 'desc1'}, + {'type': 'type2', 'owner': 'owner2', 'data': 'data2', 'desc': 'desc2'} + ] + } + + result = disassociate(cell, 0) + + assert len(result['associates']) == 1 + assert result['associates'][0]['type'] == 'type2' + + def test_disassociate_multiple_indices(self): + """Test disassociate with multiple indices.""" + from vhlib.md import disassociate + + cell = { + 'associates': [ + {'type': 'type1', 'owner': 'owner1', 'data': 'data1', 'desc': 'desc1'}, + {'type': 'type2', 'owner': 'owner2', 'data': 'data2', 'desc': 'desc2'}, + {'type': 'type3', 'owner': 'owner3', 'data': 'data3', 'desc': 'desc3'} + ] + } + + result = disassociate(cell, [0, 2]) + + assert len(result['associates']) == 1 + assert result['associates'][0]['type'] == 'type2' + + def test_associate_all(self): + """Test associate_all function.""" + from vhlib.md import associate_all + + cells = [ + {'associates': []}, + {'associates': []}, + ] + assoclist = [ + {'type': 'type1', 'owner': 'owner1', 'data': 'data1', 'desc': 'desc1'}, + {'type': 'type2', 'owner': 'owner2', 'data': 'data2', 'desc': 'desc2'}, + ] + + result = associate_all(cells, assoclist) + + assert len(result) == 2 + assert len(result[0]['associates']) == 2 + assert len(result[1]['associates']) == 2 + + def test_associate_all_single_cell(self): + """Test associate_all with single cell (not list).""" + from vhlib.md import associate_all + + cell = {'associates': []} + assoclist = [ + {'type': 'type1', 'owner': 'owner1', 'data': 'data1', 'desc': 'desc1'}, + ] + + result = associate_all(cell, assoclist) + + assert len(result['associates']) == 1 + + +class TestSpikeTriggeredAverage: + """Tests for spiketriggeredaverage function.""" + + def test_basic_sta(self): + """Test basic spike-triggered average calculation.""" + from vhlib.md import spiketriggeredaverage + + # Create test signal + signal_t = np.arange(0, 1, 0.001) # 1 second, 1kHz + signal = np.sin(2 * np.pi * 10 * signal_t) # 10 Hz sine wave + + # Spike times - ensure they have enough room for the window (50ms before and after) + spiketimes = [0.1, 0.2, 0.3, 0.4, 0.5] + + sta, t_sta, count = spiketriggeredaverage( + spiketimes, signal, signal_t, [-0.05, 0.05]) + + assert count == 5 + assert len(sta) == len(t_sta) + assert t_sta[0] < 0 # Starts before spike + assert t_sta[-1] > 0 # Ends after spike + + def test_sta_no_spikes_in_range(self): + """Test STA with no spikes in valid range.""" + from vhlib.md import spiketriggeredaverage + + signal_t = np.arange(0, 1, 0.001) + signal = np.ones_like(signal_t) + + # Spikes outside signal range + spiketimes = [2.0, 3.0] + + sta, t_sta, count = spiketriggeredaverage( + spiketimes, signal, signal_t, [-0.05, 0.05]) + + assert count == 0 + + def test_sta_single_spike(self): + """Test STA with single spike.""" + from vhlib.md import spiketriggeredaverage + + signal_t = np.arange(0, 1, 0.001) + signal = np.zeros_like(signal_t) + signal[500] = 1.0 # Impulse at t=0.5 + + spiketimes = [0.5] + + sta, t_sta, count = spiketriggeredaverage( + spiketimes, signal, signal_t, [-0.01, 0.01]) + + assert count == 1 + # STA should have peak near center + center_idx = len(sta) // 2 + assert sta[center_idx] > 0 + + def test_sta_short_signal(self): + """Test STA with very short signal.""" + from vhlib.md import spiketriggeredaverage + + signal_t = np.array([0]) # Single sample + signal = np.array([1.0]) + + spiketimes = [0.0] + + sta, t_sta, count = spiketriggeredaverage( + spiketimes, signal, signal_t, [-0.01, 0.01]) + + # Should return None or empty due to insufficient data + assert sta is None or count == 0 + + def test_sta_window_larger_than_signal(self): + """Test STA when window is larger than available signal.""" + from vhlib.md import spiketriggeredaverage + + signal_t = np.arange(0, 0.1, 0.001) # 100ms + signal = np.ones_like(signal_t) + + spiketimes = [0.05] + + sta, t_sta, count = spiketriggeredaverage( + spiketimes, signal, signal_t, [-0.1, 0.1]) # 200ms window + + # Spike at 0.05 can't have 100ms before it in 100ms signal + assert count == 0 diff --git a/tests/test_stimdecode.py b/tests/test_stimdecode.py new file mode 100644 index 0000000..9a7d4a4 --- /dev/null +++ b/tests/test_stimdecode.py @@ -0,0 +1,400 @@ +""" +Comprehensive tests for the vhlib.StimDecode module. +""" + +import pytest +import os +import tempfile +import shutil +import numpy as np + + +class TestReadStimtimesTxt: + """Tests for read_stimtimes_txt function.""" + + def setup_method(self): + """Create temporary directory for tests.""" + self.test_dir = tempfile.mkdtemp() + + def teardown_method(self): + """Clean up temporary directory.""" + shutil.rmtree(self.test_dir) + + def test_read_basic_stimtimes(self): + """Test reading basic stimtimes.txt file.""" + from vhlib.StimDecode import read_stimtimes_txt + + # Create stimtimes.txt + content = """1 0.00000 0.10000 0.20000 0.30000 +2 1.00000 1.10000 1.20000 1.30000 +3 2.00000 2.10000 2.20000 2.30000 +""" + with open(os.path.join(self.test_dir, 'stimtimes.txt'), 'w') as f: + f.write(content) + + stimids, stimtimes, frametimes = read_stimtimes_txt(self.test_dir) + + assert len(stimids) == 3 + assert stimids[0] == 1 + assert stimids[1] == 2 + assert stimids[2] == 3 + + assert len(stimtimes) == 3 + np.testing.assert_almost_equal(stimtimes[0], 0.0, decimal=5) + np.testing.assert_almost_equal(stimtimes[1], 1.0, decimal=5) + np.testing.assert_almost_equal(stimtimes[2], 2.0, decimal=5) + + assert len(frametimes) == 3 + np.testing.assert_array_almost_equal(frametimes[0], [0.1, 0.2, 0.3], decimal=5) + + def test_read_single_line(self): + """Test reading single line stimtimes.txt.""" + from vhlib.StimDecode import read_stimtimes_txt + + content = "5 10.50000 10.60000 10.70000\n" + with open(os.path.join(self.test_dir, 'stimtimes.txt'), 'w') as f: + f.write(content) + + stimids, stimtimes, frametimes = read_stimtimes_txt(self.test_dir) + + assert len(stimids) == 1 + assert stimids[0] == 5 + np.testing.assert_almost_equal(stimtimes[0], 10.5, decimal=5) + + def test_read_no_frametimes(self): + """Test reading stimtimes with no frame times.""" + from vhlib.StimDecode import read_stimtimes_txt + + content = "1 0.50000\n2 1.50000\n" + with open(os.path.join(self.test_dir, 'stimtimes.txt'), 'w') as f: + f.write(content) + + stimids, stimtimes, frametimes = read_stimtimes_txt(self.test_dir) + + assert len(stimids) == 2 + assert len(frametimes[0]) == 0 # No frame times + assert len(frametimes[1]) == 0 + + def test_file_not_found(self): + """Test that missing file raises IOError.""" + from vhlib.StimDecode import read_stimtimes_txt + + with pytest.raises(IOError): + read_stimtimes_txt(self.test_dir, 'nonexistent.txt') + + def test_custom_filename(self): + """Test reading from custom filename.""" + from vhlib.StimDecode import read_stimtimes_txt + + content = "1 0.00000\n" + with open(os.path.join(self.test_dir, 'custom_stim.txt'), 'w') as f: + f.write(content) + + stimids, stimtimes, frametimes = read_stimtimes_txt(self.test_dir, 'custom_stim.txt') + + assert len(stimids) == 1 + + +class TestWriteStimtimesTxt: + """Tests for write_stimtimes_txt function.""" + + def setup_method(self): + """Create temporary directory for tests.""" + self.test_dir = tempfile.mkdtemp() + + def teardown_method(self): + """Clean up temporary directory.""" + shutil.rmtree(self.test_dir) + + def test_write_basic_stimtimes(self): + """Test writing basic stimtimes.txt file.""" + from vhlib.StimDecode import write_stimtimes_txt, read_stimtimes_txt + + stimids = [1, 2, 3] + stimtimes = [0.0, 1.0, 2.0] + frametimes = [ + np.array([0.1, 0.2, 0.3]), + np.array([1.1, 1.2, 1.3]), + np.array([2.1, 2.2, 2.3]) + ] + + write_stimtimes_txt(self.test_dir, stimids, stimtimes, frametimes, filename='test_stim.txt') + + # Verify file was created + assert os.path.exists(os.path.join(self.test_dir, 'test_stim.txt')) + + # Read back and verify + r_stimids, r_stimtimes, r_frametimes = read_stimtimes_txt(self.test_dir, 'test_stim.txt') + + np.testing.assert_array_equal(r_stimids, stimids) + np.testing.assert_array_almost_equal(r_stimtimes, stimtimes, decimal=5) + + def test_write_without_frametimes(self): + """Test writing stimtimes without frame times.""" + from vhlib.StimDecode import write_stimtimes_txt + + stimids = [1, 2] + stimtimes = [0.5, 1.5] + + write_stimtimes_txt(self.test_dir, stimids, stimtimes, filename='no_frames.txt') + + assert os.path.exists(os.path.join(self.test_dir, 'no_frames.txt')) + + def test_file_already_exists_error(self): + """Test that writing to existing file raises IOError.""" + from vhlib.StimDecode import write_stimtimes_txt + + # Create file first + with open(os.path.join(self.test_dir, 'existing.txt'), 'w') as f: + f.write('test') + + with pytest.raises(IOError): + write_stimtimes_txt(self.test_dir, [1], [0.0], filename='existing.txt') + + +class TestGetStimdirectoryTime: + """Tests for getstimdirectorytime function.""" + + def setup_method(self): + """Create temporary directory for tests.""" + self.test_dir = tempfile.mkdtemp() + + def teardown_method(self): + """Clean up temporary directory.""" + shutil.rmtree(self.test_dir) + + def test_read_time_from_filetime(self): + """Test reading time from filetime.txt.""" + from vhlib.StimDecode import getstimdirectorytime + + # Create required files + with open(os.path.join(self.test_dir, 'stims.mat'), 'w') as f: + f.write('') # Empty mat file placeholder + with open(os.path.join(self.test_dir, 'spike2data.smr'), 'w') as f: + f.write('') # Empty smr file placeholder + with open(os.path.join(self.test_dir, 'filetime.txt'), 'w') as f: + f.write('36000.0') # 10:00 AM in seconds + + time_val = getstimdirectorytime(self.test_dir) + + assert time_val == 36000.0 + + def test_early_morning_adjustment(self): + """Test early morning time adjustment.""" + from vhlib.StimDecode import getstimdirectorytime + + # Create required files + with open(os.path.join(self.test_dir, 'stims.mat'), 'w') as f: + f.write('') + with open(os.path.join(self.test_dir, 'spike2data.smr'), 'w') as f: + f.write('') + with open(os.path.join(self.test_dir, 'filetime.txt'), 'w') as f: + f.write('7200.0') # 2:00 AM in seconds + + time_val = getstimdirectorytime(self.test_dir) + + # Should add 24 hours since it's before 5 AM cutoff + assert time_val == 7200.0 + 86400.0 + + def test_no_early_morning_warning(self): + """Test disabling early morning warning.""" + from vhlib.StimDecode import getstimdirectorytime + + with open(os.path.join(self.test_dir, 'stims.mat'), 'w') as f: + f.write('') + with open(os.path.join(self.test_dir, 'spike2data.smr'), 'w') as f: + f.write('') + with open(os.path.join(self.test_dir, 'filetime.txt'), 'w') as f: + f.write('7200.0') + + # Should not raise, just return adjusted time + time_val = getstimdirectorytime(self.test_dir, WarnOnEarlyMorning=False) + + assert time_val == 7200.0 + 86400.0 + + def test_missing_files_error(self): + """Test error when required files are missing.""" + from vhlib.StimDecode import getstimdirectorytime + + with pytest.raises(FileNotFoundError): + getstimdirectorytime(self.test_dir) + + def test_missing_files_no_error(self): + """Test no error when ErrorIfEmpty=False.""" + from vhlib.StimDecode import getstimdirectorytime + + time_val = getstimdirectorytime(self.test_dir, ErrorIfEmpty=False) + + assert np.isnan(time_val) + + +class TestVhinterconnectDecode: + """Tests for vhinterconnect_decode function.""" + + def test_basic_decode(self): + """Test basic decoding of interconnect signals.""" + from vhlib.StimDecode import vhinterconnect_decode + + # Create test data + time = np.array([0.0, 0.001, 0.002, 0.003, 0.004, 0.005]) + # Bit 0 (StimTrigger): 0,0,1,1,1,0 -> transition at index 2 + # Signal values: bit 0 set at indices 2,3,4 + input_sig = np.array([0, 0, 1, 1, 1, 0], dtype=np.uint16) + + out = vhinterconnect_decode(time, input_sig) + + assert 'StimTrigger' in out + assert 'StimTriggerSamples' in out + + def test_custom_polarity(self): + """Test decoding with custom polarity.""" + from vhlib.StimDecode import vhinterconnect_decode + + time = np.array([0.0, 0.001, 0.002, 0.003]) + input_sig = np.array([0, 0, 1, 1], dtype=np.uint16) + + # Custom polarity with some NaN values (should use defaults) + polarity = np.array([1, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan]) + + out = vhinterconnect_decode(time, input_sig, polarity=polarity) + + assert 'StimTrigger' in out + + def test_invalid_polarity_length(self): + """Test that invalid polarity length raises error.""" + from vhlib.StimDecode import vhinterconnect_decode + + time = np.array([0.0, 0.001]) + input_sig = np.array([0, 1], dtype=np.uint16) + polarity = np.array([1, 1, 1]) # Wrong length + + with pytest.raises(ValueError): + vhinterconnect_decode(time, input_sig, polarity=polarity) + + def test_stim_code_extraction(self): + """Test extraction of stimulus codes from upper bits.""" + from vhlib.StimDecode import vhinterconnect_decode + + time = np.array([0.0, 0.001, 0.002, 0.003]) + # Create signal with stim code in upper byte + # Upper byte = 5, lower bit 0 set for trigger + # 5 << 8 = 1280, plus 1 for trigger bit = 1281 + input_sig = np.array([0, 0, 1281, 1281], dtype=np.uint16) + + out = vhinterconnect_decode(time, input_sig) + + if len(out.get('StimTriggerSamples', [])) > 0: + assert 'StimCode' in out + + +class TestStimscriptgraph: + """Tests for stimscriptgraph function.""" + + def test_not_implemented(self): + """Test that stimscriptgraph raises NotImplementedError.""" + from vhlib.StimDecode import stimscriptgraph + + with pytest.raises(NotImplementedError): + stimscriptgraph('/some/path') + + +class TestVhlabcorrectmti: + """Tests for vhlabcorrectmti function.""" + + def test_not_implemented(self): + """Test that vhlabcorrectmti raises NotImplementedError.""" + from vhlib.StimDecode import vhlabcorrectmti + + with pytest.raises(NotImplementedError): + vhlabcorrectmti({}, 'file.txt') + + +class TestWriteInterconnectTextfiles: + """Tests for write_interconnect_textfiles function.""" + + def setup_method(self): + """Create temporary directory for tests.""" + self.test_dir = tempfile.mkdtemp() + + def teardown_method(self): + """Clean up temporary directory.""" + shutil.rmtree(self.test_dir) + + def test_write_basic_textfiles(self): + """Test writing interconnect text files.""" + from vhlib.StimDecode import write_interconnect_textfiles + + out = { + 'StimTrigger': np.array([0.5, 1.5, 2.5]), + 'StimCode': np.array([1, 2, 3]), + 'FrameTriggerRaw': np.array([0.6, 0.7, 1.6, 1.7, 2.6, 2.7]), + 'TwoPhotonFrameTrigger': np.array([0.55, 1.55, 2.55]), + 'StimulusMonitorVerticalRefresh': np.array([0.51, 0.52, 1.51, 1.52]) + } + + write_interconnect_textfiles(self.test_dir, out) + + # Check that files were created + assert os.path.exists(os.path.join(self.test_dir, 'stimtimes.txt')) + assert os.path.exists(os.path.join(self.test_dir, 'stimontimes.txt')) + assert os.path.exists(os.path.join(self.test_dir, 'twophotontimes.txt')) + assert os.path.exists(os.path.join(self.test_dir, 'verticalblanking.txt')) + assert os.path.exists(os.path.join(self.test_dir, 'Intan_decoding_finished.txt')) + + def test_removes_existing_files(self): + """Test that existing files are removed before writing.""" + from vhlib.StimDecode import write_interconnect_textfiles + + # Create existing files + for fname in ['stimtimes.txt', 'stimontimes.txt']: + with open(os.path.join(self.test_dir, fname), 'w') as f: + f.write('old content') + + out = { + 'StimTrigger': np.array([0.5]), + 'StimCode': np.array([1]), + 'FrameTriggerRaw': np.array([0.6]) + } + + write_interconnect_textfiles(self.test_dir, out) + + # Files should be recreated + assert os.path.exists(os.path.join(self.test_dir, 'stimtimes.txt')) + + +class TestReadWriteRoundTrip: + """Tests for round-trip read/write operations.""" + + def setup_method(self): + """Create temporary directory for tests.""" + self.test_dir = tempfile.mkdtemp() + + def teardown_method(self): + """Clean up temporary directory.""" + shutil.rmtree(self.test_dir) + + def test_stimtimes_roundtrip(self): + """Test that writing and reading stimtimes preserves data.""" + from vhlib.StimDecode import write_stimtimes_txt, read_stimtimes_txt + + original_stimids = [1, 5, 10, 15] + original_stimtimes = [0.0, 1.5, 3.0, 4.5] + original_frametimes = [ + np.array([0.1, 0.2]), + np.array([1.6, 1.7]), + np.array([3.1, 3.2]), + np.array([4.6, 4.7]) + ] + + write_stimtimes_txt(self.test_dir, original_stimids, original_stimtimes, + original_frametimes, filename='roundtrip.txt') + + read_stimids, read_stimtimes, read_frametimes = read_stimtimes_txt( + self.test_dir, 'roundtrip.txt') + + np.testing.assert_array_equal(read_stimids, original_stimids) + np.testing.assert_array_almost_equal(read_stimtimes, original_stimtimes, decimal=5) + + for i in range(len(original_frametimes)): + np.testing.assert_array_almost_equal( + read_frametimes[i], original_frametimes[i], decimal=5)