-
Notifications
You must be signed in to change notification settings - Fork 37
Expand file tree
/
Copy pathmultitensor_systems.py
More file actions
212 lines (186 loc) · 7.54 KB
/
multitensor_systems.py
File metadata and controls
212 lines (186 loc) · 7.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
import functools

import numpy as np
import torch
# Seed NumPy's and PyTorch's global RNGs once, at import time, so that any
# downstream randomness is reproducible.
_GLOBAL_SEED = 0
np.random.seed(_GLOBAL_SEED)
torch.manual_seed(_GLOBAL_SEED)

# Number of multitensor axes: examples, colors, directions, x, y.
NUM_DIMENSIONS = 5


class MultiTensorSystem:
    """
    Describes the five axes of a multitensor (examples, colors, directions,
    x and y) and enumerates which subsets of those axes are allowed.

    A subset of axes is written as ``dims``: a list of 0/1 flags, one per
    axis, in the order listed above.
    """

    def __init__(self, n_examples, n_colors, n_x, n_y, task):
        """
        Args:
            n_examples (int): Length of the examples axis.
            n_colors (int): Length of the colors axis.
            n_x (int): Length of the x axis.
            n_y (int): Length of the y axis.
            task: ARC task that this multitensor system is created for. It is
                stored on the instance; no method of this class reads it.
        """
        self.task = task
        self.n_examples = n_examples
        self.n_colors = n_colors
        self.n_directions = 8  # the directions axis always has length 8
        self.n_x = n_x
        self.n_y = n_y
        # Axis lengths, in the same order as the flags of a dims list.
        self.dim_lengths = [
            n_examples,
            n_colors,
            self.n_directions,
            n_x,
            n_y,
        ]

    def dims_valid(self, dims):
        """
        Report whether ``dims`` is an allowed combination of axes.

        A combination is rejected when either
          1. x or y is selected (dims[3] / dims[4]) but examples (dims[0])
             is not, or
          2. none of colors, directions, x, y is selected (dims[1:] sums to 0).

        Args:
            dims (list[int]): 0/1 flags selecting axes.

        Returns:
            bool: True when neither rule is violated.
        """
        # Rule 1: x/y require the examples axis.
        spatial = dims[3] or dims[4]
        if spatial and not dims[0]:
            return False
        # Rule 2: at least one axis other than examples must be selected.
        return not (sum(dims[1:]) == 0)

    def shape(self, dims, extra_dim=None):
        """
        Build the tensor shape implied by ``dims``.

        Args:
            dims (list[int]): 0/1 flags selecting axes.
            extra_dim (int, optional): If not None, appended as a trailing axis.

        Returns:
            list[int]: Lengths of the selected axes in axis order, followed by
            ``extra_dim`` when it is given.
        """
        selected = [
            length
            for axis, length in enumerate(self.dim_lengths)
            if dims[axis]
        ]
        if extra_dim is not None:
            selected.append(extra_dim)
        return selected

    def _generate_dims_combinations(self):
        """
        Yield all 2**NUM_DIMENSIONS lists of 0/1 flags.

        They come in increasing order of the integer whose bit ``k`` equals
        ``dims[k]``, so dims[0] toggles fastest.
        """
        for mask in range(1 << NUM_DIMENSIONS):
            yield [(mask >> axis) & 1 for axis in range(NUM_DIMENSIONS)]

    def __iter__(self):
        """Yield, in generation order, every dims list accepted by dims_valid."""
        yield from filter(self.dims_valid, self._generate_dims_combinations())

    def _make_multitensor(self, default, index):
        """
        Build a nested list of width 2 and depth ``NUM_DIMENSIONS - index``.

        Each level is a freshly created two-element list, so no inner list is
        shared between branches. Every leaf is the very same ``default``
        object; it is not copied.

        Args:
            default (Any): Object placed at every leaf.
            index (int): Current depth. The recursion bottoms out when it
                equals NUM_DIMENSIONS.

        Returns:
            list or Any: The nested list, or ``default`` itself at full depth.
        """
        if index != NUM_DIMENSIONS:
            deeper = index + 1
            return [
                self._make_multitensor(default, deeper),
                self._make_multitensor(default, deeper),
            ]
        return default

    def make_multitensor(self, default=None):
        """
        Create a MultiTensor bound to this system with ``default`` at every leaf.

        Args:
            default (Any): Object stored at every leaf (shared, not copied).
                Defaults to None.

        Returns:
            MultiTensor: Wrapper around the freshly built nested list.
        """
        tree = self._make_multitensor(default, 0)
        return MultiTensor(tree, self)


class MultiTensor:
    """
    Wrapper around a tree of nested lists whose entries are addressed by a
    dims list: each element of dims is the index taken at one nesting level.
    """

    def __init__(self, data, multitensor_system):
        """
        Args:
            data (nested list): Root of the nested-list tree holding the values.
            multitensor_system (MultiTensorSystem): System this MultiTensor
                belongs to.
        """
        self.data = data
        self.multitensor_system = multitensor_system

    def _descend(self, path):
        """Start at the root and follow each index in ``path``; return the node reached."""
        node = self.data
        for step in path:
            node = node[step]
        return node

    def __getitem__(self, dims):
        """
        Return the entry found by following ``dims`` from the root.

        Args:
            dims (list[int]): Path of indices (0 or 1), one per nesting level.

        Returns:
            Any: Whatever is stored at that position. A path shorter than the
            tree depth yields an inner list.
        """
        return self._descend(dims)

    def __setitem__(self, dims, value):
        """
        Store ``value`` at the position found by following ``dims`` from the root.

        Args:
            dims (list[int]): Path of indices (0 or 1), one per nesting level.
                It must be non-empty; ``dims[-1]`` raises IndexError otherwise.
            value (Any): The value to store.
        """
        # Walk to the parent list, then overwrite the final slot in place.
        parent = self._descend(dims[:-1])
        parent[dims[-1]] = value


def multify(fn):
    """
    Lift a per-dims function so that it can be called with MultiTensor arguments.

    The decorated function must take ``dims`` as its first positional
    parameter, followed by its ordinary arguments.

    * If no positional or keyword argument is a MultiTensor, ``fn`` is called
      exactly once as ``fn(None, *args, **kwargs)``. Its result is returned
      unchanged.
    * Otherwise the MultiTensorSystem is taken from the first MultiTensor
      argument; positional arguments are searched before keyword arguments.
      For every valid dims combination yielded by that system:
        - each MultiTensor argument is replaced by its entry at ``dims``;
        - all other arguments are passed through as-is;
        - ``fn(dims, ...)`` is called and its output is stored at ``dims`` in
          a fresh MultiTensor.
      That MultiTensor is returned. Entries at dims the system does not yield
      stay ``None``.

    Args:
        fn (callable): Function with signature ``fn(dims, *args, **kwargs)``.

    Returns:
        callable: The wrapped function. ``functools.wraps`` copies ``fn``'s
        name, docstring and other metadata onto it.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # First MultiTensor among the arguments, positional before keyword.
        # NOTE(review): all MultiTensor arguments presumably share one system;
        # only this first one's system is consulted.
        first_multitensor = next(
            (
                value
                for value in (*args, *kwargs.values())
                if isinstance(value, MultiTensor)
            ),
            None,
        )
        # Plain call: nothing to broadcast over, so no dims are passed.
        if first_multitensor is None:
            return fn(None, *args, **kwargs)

        multitensor_system = first_multitensor.multitensor_system
        # Every slot starts as None; only the valid dims are filled in below.
        result = multitensor_system.make_multitensor()
        for dims in multitensor_system:
            # Pick each MultiTensor argument's entry at these dims; others pass through.
            sliced_args = [
                arg[dims] if isinstance(arg, MultiTensor) else arg
                for arg in args
            ]
            sliced_kwargs = {
                key: value[dims] if isinstance(value, MultiTensor) else value
                for key, value in kwargs.items()
            }
            result[dims] = fn(dims, *sliced_args, **sliced_kwargs)
        return result

    return wrapper