Skip to content
This repository was archived by the owner on Dec 14, 2025. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion daceml/autodiff/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from dace.sdfg import utils as sdfg_utils
from dace.transformation.passes import analysis

AccessSets = Dict[SDFGState, Tuple[Set[str], Set[str]]]
AccessSets = Dict[int, Dict[SDFGState, Tuple[Set[str], Set[str]]]]


def dependency_analysis(sdfg: SDFG) -> Dict[str, Set[str]]:
Expand All @@ -37,6 +37,7 @@ def dependency_analysis(sdfg: SDFG) -> Dict[str, Set[str]]:

def inverse_reachability(sdfg: SDFG) -> Dict[SDFGState, Set[SDFGState]]:
reachability = analysis.StateReachability().apply_pass(sdfg, {})
reachability = reachability[sdfg.sdfg_id]
inverse_reachability = collections.defaultdict(set)
for pred, successors in reachability.items():
for successor in successors:
Expand All @@ -62,6 +63,8 @@ def is_previously_written(sdfg: SDFG,

if access_sets is None:
access_sets = analysis.AccessSets().apply_pass(sdfg, {})
# We're only interested in the top level SDFG
access_sets = access_sets[sdfg.sdfg_id]

reachable = inverse_reachability(sdfg)

Expand Down
8 changes: 4 additions & 4 deletions daceml/autodiff/implementations/onnx_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,10 @@ def backward(
access_output_grad = nstate.add_read(result.given_grad_names["Output"])

def create_access_node(connector: str) -> nd.AccessNode:
nsdfg.add_datadesc(
connector,
butils.forward_in_desc_with_name(forward_node, context,
connector))
forward_desc = butils.forward_in_desc_with_name(
forward_node, context, connector)
desc = copy.deepcopy(forward_desc)
nsdfg.add_datadesc(connector, desc)
return nstate.add_read(connector)

# the forward inputs we will require
Expand Down
3 changes: 3 additions & 0 deletions daceml/autodiff/library/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,9 @@ def determine_forward_state(
if access_sets is None:
access_sets = analysis.AccessSets().apply_pass(sdfg, {})

# We're only interested in the top level sdfg
access_sets = access_sets[sdfg.sdfg_id]

candidate_states = []
for cand in sdfg.states():
_, write_set = access_sets[cand]
Expand Down
2 changes: 1 addition & 1 deletion daceml/onnx/op_implementations/cudnn_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import dace
from dace import SDFGState, nodes as nd, SDFG, dtypes, data as dt
from dace.codegen.targets.common import sym2cpp
from dace.codegen.common import sym2cpp

from daceml.onnx import environments
from daceml.onnx.converters import clean_onnx_name
Expand Down
2 changes: 1 addition & 1 deletion daceml/torch/dispatchers/cpp_torch_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from dace.codegen.codeobject import CodeObject
from dace.codegen.compiled_sdfg import CompiledSDFG
from dace.codegen.prettycode import CodeIOStream
from dace.codegen.targets.common import sym2cpp
from dace.codegen.common import sym2cpp

from daceml.autodiff import BackwardResult
from daceml.torch.environments import PyTorch
Expand Down
14 changes: 9 additions & 5 deletions daceml/torch/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,10 +393,11 @@ def forward(self, *actual_inputs):

def __sdfg__(self, *args):
if self.sdfg is None:
raise ValueError("""
Using a PyTorch model in a DaceProgram requires that the model is initialized first.
Either call this model using some inputs, or pass 'dummy_inputs' to the constructor.
""")
raise ValueError(
"Using a PyTorch model in a DaceProgram requires"
" that the model is initialized first. Either call this model"
" using some inputs, or pass 'dummy_inputs' to the constructor."
)

for name, param in self._exported_parameters.items():
onnx_name = clean_onnx_name(name)
Expand Down Expand Up @@ -440,7 +441,10 @@ def _add_gradient_buffers(self) -> List[str]:
def __sdfg_signature__(self):
if self.dace_model is None:
raise ValueError(
"Can't determine signature before SDFG is generated.")
"Using a PyTorch model in a DaceProgram requires"
" that the model is initialized first. Either call this model"
" using some inputs, or pass 'dummy_inputs' to the constructor."
)
inputs = [clean_onnx_name(name) for name in self.dace_model.inputs]
grad_buffers = self._add_gradient_buffers()
inputs.extend(grad_buffers)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
exclude=["*.tests", "*.tests.*", "tests.*", "tests"]),
package_data={'': (['*.cpp'] + runtime_files)},
install_requires=[
'dace@git+https://github.com/spcl/dace.git@backport-0.13-fixes',
'dace@git+https://github.com/spcl/dace.git',
'onnx == 1.7.0', # we support opset v12
'torch',
'protobuf == 3.19',
Expand Down