Description
Currently, IfElse supports only a limited form of inplace optimization.
Because of how PyTensor checks for valid inplace operations in graphs, IfElse cannot declare that an output is a view of both branches (e.g., `view_map = {0: [1, 2]}`). The dependency checker does not support "one output, view of multiple inputs" logic, even though the branches are mutually exclusive, so the output could never alias both of them at the same time.
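As a toy model of this constraint (not PyTensor's actual dependency checker), the sketch below treats a `view_map` the way PyTensor Ops declare it, as a dict from output index to a list of input indices, and rejects any output that claims to view more than one input. The function `check_single_view_root` is hypothetical.

```python
from typing import Dict, List

ViewMap = Dict[int, List[int]]


def check_single_view_root(view_map: ViewMap) -> None:
    """Toy stand-in for the dependency checker: an output may alias at most one input."""
    for out_idx, in_idxs in view_map.items():
        if len(in_idxs) > 1:
            raise NotImplementedError(
                f"output {out_idx} would be a view of multiple inputs {in_idxs}"
            )


# What IfElse would like to declare for a single output: it may alias
# either the true-branch input (index 1) or the false-branch input (index 2).
desired: ViewMap = {0: [1, 2]}

# What is accepted today: a single root input per output.
accepted: ViewMap = {0: [1]}

check_single_view_root(accepted)  # passes silently
try:
    check_single_view_root(desired)
except NotImplementedError as err:
    print("rejected:", err)
```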
Right now, IfElse is restricted to either `as_view=False` (copying everything) or `as_view=True` (which only aliases the "True" branch).
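To make the current runtime behavior concrete, here is a toy NumPy model of what `as_view=True` means; it is not PyTensor's implementation, and `toy_ifelse_perform` is a hypothetical name. It assumes, as described in this issue, that true-branch values are returned as-is while false-branch values are always copied.

```python
import copy

import numpy as np


def toy_ifelse_perform(cond, true_vals, false_vals, as_view=True):
    """Toy model (not PyTensor's source) of IfElse's runtime behavior.

    With as_view=True, the true-branch values are returned as-is (aliased);
    the false-branch values are always deep-copied.
    """
    if cond:
        if as_view:
            return list(true_vals)
        return [copy.deepcopy(v) for v in true_vals]
    return [copy.deepcopy(v) for v in false_vals]


a = np.arange(3.0)
b = np.zeros(3)

(out,) = toy_ifelse_perform(True, [a], [b])
print(np.shares_memory(out, a))  # True: the true branch is aliased

(out,) = toy_ifelse_perform(False, [a], [b])
print(np.shares_memory(out, b))  # False: the false branch was copied
```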
Proposed improvements
Fine-grained inplace specifications
The inplace logic could be made more granular than "all true branches" or "none".
- For instance, branches containing constants can never be destroyed or returned directly, whereas valid intermediate tensors could be.
- It should be possible to define, for each output, exactly which branch input is allowed to be returned as a view. This would allow IfElse to return a view of the "False" branch when the "True" branch cannot be aliased (or vice versa), rather than defaulting strictly to the "True" branch.
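A minimal sketch of per-output view selection, assuming IfElse's inputs are laid out as `[condition, *true_branch, *false_branch]` (consistent with the `v={0: [1]}` shown in the dprint outputs of this issue) and that a branch input is viewable unless it is a constant. The function `fine_grained_view_map` and its `viewable_*` arguments are hypothetical, not part of PyTensor.

```python
from typing import Dict, List, Sequence


def fine_grained_view_map(
    viewable_true: Sequence[bool], viewable_false: Sequence[bool]
) -> Dict[int, List[int]]:
    """Hypothetical per-output view_map for an IfElse with n outputs.

    Assumed input layout: index 0 is the condition, 1..n the true branch,
    n+1..2n the false branch. Each output views at most one input: the
    true-branch input if it is viewable, otherwise the false-branch input
    if that is viewable, otherwise nothing (the output is copied).
    """
    n_outs = len(viewable_true)
    if len(viewable_false) != n_outs:
        raise ValueError("both branches must have one entry per output")
    view_map: Dict[int, List[int]] = {}
    for i in range(n_outs):
        if viewable_true[i]:
            view_map[i] = [1 + i]
        elif viewable_false[i]:
            view_map[i] = [1 + n_outs + i]
    return view_map


# Mirrors the first code example in this issue: the true branch is a
# constant (not viewable) and the false branch is x * 2 (viewable).
print(fine_grained_view_map([False], [True]))  # {0: [2]}
```

Because each output still views a single input, this stays within what the dependency checker already accepts.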
"Likely" branch hint
Alternatively, a `likely` hint flag could be added to the Op. This would allow users or the optimizer to specify which branch is executed most frequently, making it the priority target for inplace optimization (viewing).
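A sketch of how such a hint could drive the choice, under the same assumed input layout as above: each output prefers to view the branch marked as likely and falls back to the other branch if the preferred input is not viewable. The `likely` parameter values `"true"`/`"false"` and the function name are hypothetical.

```python
from typing import Dict, List, Literal, Sequence


def view_map_with_likely(
    viewable_true: Sequence[bool],
    viewable_false: Sequence[bool],
    likely: Literal["true", "false"] = "true",
) -> Dict[int, List[int]]:
    """Hypothetical: prefer viewing the likely branch, fall back to the other one."""
    n_outs = len(viewable_true)
    if len(viewable_false) != n_outs:
        raise ValueError("both branches must have one entry per output")
    view_map: Dict[int, List[int]] = {}
    for i in range(n_outs):
        true_choice = (viewable_true[i], 1 + i)
        false_choice = (viewable_false[i], 1 + n_outs + i)
        order = [true_choice, false_choice] if likely == "true" else [false_choice, true_choice]
        for ok, idx in order:
            if ok:
                view_map[i] = [idx]
                break
    return view_map


# Both branches are viewable, but the false branch is marked as likely.
print(view_map_with_likely([True], [True], likely="false"))  # {0: [2]}
```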
Code example
The following function requires a deep copy, because we don't return aliases of constants:
```python
import pytensor
import pytensor.tensor as pt

x = pt.vector("x", shape=(10,))
out = pytensor.ifelse(x.sum() > 0, [x.zeros_like()], [x * 2])
pytensor.function([x], out).dprint(print_memory_map=True)
```

```
DeepCopyOp [id A] 4
 └─ if{inplace} [id B] v={0: [1]} 3
    ├─ Gt [id C] 2
    │  ├─ Sum{axes=None} [id D] 1
    │  │  └─ x [id E]
    │  └─ 0 [id F]
    ├─ [0. 0. 0. ... 0. 0. 0.] [id G]
    └─ Mul [id H] 0
       ├─ [2.] [id I]
       └─ x [id E]
```
Whereas the following returns the intermediate `x * 2` directly (when that branch is taken):
```python
out = pytensor.ifelse(x.sum() < 0, [x * 2], [x.zeros_like()])
pytensor.function([x], out).dprint(print_memory_map=True)
```

```
if{inplace} [id A] v={0: [1]} 3
 ├─ Lt [id B] 2
 │  ├─ Sum{axes=None} [id C] 1
 │  │  └─ x [id D]
 │  └─ 0 [id E]
 ├─ Mul [id F] 0
 │  ├─ [2.] [id G]
 │  └─ x [id D]
 └─ [0. 0. 0. ... 0. 0. 0.] [id H]
```
Note that in both cases the second branch (taken when the condition is false) is copied, whether it is a constant or not.
In this case we could rewrite the graph by negating the condition and swapping the order of the branches, but in the general case, with multiple inputs per branch, it may not be obvious that one branch dominates the other, whereas deciding pairwise (per output) lets us make better choices.
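A simplified counting model of that argument, assuming each output has one (true, false) input pair, that at most one input per output can be viewed, and that a constant can never be viewed. It compares the current rule (always view the true branch), a global rewrite that negates the condition and swaps the branches (always view the false branch), and a per-output choice. The function name is hypothetical.

```python
from typing import List, Tuple

# Each pair is (true_input_is_constant, false_input_is_constant).
Pairs = List[Tuple[bool, bool]]


def outputs_needing_copy(pairs: Pairs) -> Tuple[int, int, int]:
    """Count outputs whose chosen branch cannot be returned as a view.

    Returns counts for: viewing the true branch only (current behavior),
    viewing the false branch only (negated condition, swapped branches),
    and choosing a viewable branch per output.
    """
    true_only = sum(t_const for t_const, _ in pairs)
    swapped = sum(f_const for _, f_const in pairs)
    per_output = sum(t_const and f_const for t_const, f_const in pairs)
    return true_only, swapped, per_output


# Two outputs whose constants sit on opposite branches: neither global
# orientation avoids a copy, but a per-output choice does.
print(outputs_needing_copy([(True, False), (False, True)]))  # (1, 1, 0)
```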
Also, as I mentioned, the `likely` flag would let users tell us which branch they expect to be taken, and we could use it to reduce the chance or frequency of copies.
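As a rough model of that benefit, assume (hypothetically) that the condition is true with probability $p$ and that both branch inputs of each output $i$ are viewable. Viewing branch $b_i$ avoids a copy only when that branch runs, so with $[\cdot]$ equal to 1 when its argument holds and 0 otherwise, the expected number of copies per call is:

```math
\mathbb{E}[\text{copies}] = \sum_{i} \Big( [b_i = \mathrm{T}]\,(1 - p) + [b_i = \mathrm{F}]\,p \Big)
```

This is minimized by viewing the true branch for every such output when $p > 1/2$ and the false branch when $p < 1/2$, so a `likely` hint only needs to convey which side of $1/2$ the probability falls on, not its exact value.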