diff --git a/dataframely/columns/array.py b/dataframely/columns/array.py
index e1004493..5d582a3d 100644
--- a/dataframely/columns/array.py
+++ b/dataframely/columns/array.py
@@ -126,9 +126,27 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
         n_elements = n * math.prod(self.shape)
         all_elements = self.inner.sample(generator, n_elements)
 
+        # For complex inner types (List, Array), reshape doesn't work properly
+        # because it tries to reshape based on primitive element count, not the
+        # count of complex elements. We need to manually chunk and construct.
+        inner_dtype = self.inner.dtype
+        if isinstance(inner_dtype, (pl.List, pl.Array)):
+            # Row i is the i-th run of `chunk_size` consecutive sampled elements.
+            # NOTE(review): each row is a flat list; confirm that this matches
+            # `self.dtype` when `self.shape` has more than one dimension.
+            chunk_size = math.prod(self.shape)
+            chunks = [
+                all_elements.slice(i * chunk_size, chunk_size).to_list()
+                for i in range(n)
+            ]
+            result = pl.Series(chunks, dtype=self.dtype)
+        else:
+            # For scalar and struct types, reshape works correctly
+            result = all_elements.reshape((n, *self.shape))
+
         # Finally, apply a null mask
         return generator._apply_null_mask(
-            all_elements.reshape((n, *self.shape)),
+            result,
             null_probability=self._null_probability,
         )
 
diff --git a/tests/columns/test_sample.py b/tests/columns/test_sample.py
index 2c51d1c1..245e0f38 100644
--- a/tests/columns/test_sample.py
+++ b/tests/columns/test_sample.py
@@ -199,6 +199,24 @@ def test_sample_struct(generator: Generator) -> None:
     assert len(samples) == 10_000
 
 
+@pytest.mark.parametrize(("arr_size", "n_samples"), [(1, 1), (2, 1), (3, 2), (2, 10)])
+def test_sample_array_list(arr_size: int, n_samples: int, generator: Generator) -> None:
+    """Test sampling for Array(List(...)) columns."""
+    column = dy.Array(dy.List(dy.Bool()), arr_size)
+    samples = sample_and_validate(column, generator, n=n_samples)
+    assert len(samples) == n_samples
+
+
+def test_sample_nested_array(generator: Generator) -> None:
+    """Test sampling for Array(Array(...)) columns."""
+    column = dy.Array(dy.Array(dy.Int64(), 2), 3)
+    samples = sample_and_validate(column, generator, n=10)
+    assert len(samples) == 10
+    # Check that the shape is correct (accounting for nulls)
+    non_null_lengths = samples.arr.len().drop_nulls()
+    assert (non_null_lengths == 3).all()
+
+
 # --------------------------------------- UTILS -------------------------------------- #
 
 