This repository was archived by the owner on Dec 2, 2023. It is now read-only.

Support Numpy Duck Arrays #74

@mrocklin


Tangent provides source-to-source automatic differentiation of functions written with NumPy syntax:

In [1]: import numpy as np

In [2]: def f(x):
   ...:     return np.sum(np.exp(x)) + 1

In [3]: x = np.arange(5)
In [4]: f(x)
Out[4]: 86.7910248837216

In [5]: import tangent
In [6]: df = tangent.grad(f)
In [7]: df(x)
Out[7]: array([ 1.        ,  2.71828183,  7.3890561 , 20.08553692, 54.59815003])

It currently has a pluggable mechanism that explicitly supports both NumPy arrays and TensorFlow arrays. However, it would be nice if it also supported other NumPy-like arrays through duck typing. At the moment this does not appear to work: the forward function accepts a dask array, but the generated gradient fails.

In [8]: import dask.array as da
In [9]: x = da.arange(5, chunks=(2,))
In [10]: f(x)
Out[10]: dask.array<add, shape=(), dtype=float64, chunksize=()>

In [11]: _.compute()
Out[11]: 86.7910248837216

In [12]: df(x)
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-12-31ac6e885892> in <module>()
----> 1 df(x)

/tmp/tmp3sxcen8j/tangent_b64e.py in dfdx(x, b_return)
      3     np_sum_np_exp_x = np.sum(np_exp_x)
      4     _return = np_sum_np_exp_x + 1
----> 5     assert tangent.shapes_match(_return, b_return
      6         ), 'Shape mismatch between return value (%s) and seed derivative (%s)' % (
      7         numpy.shape(_return), numpy.shape(b_return))

~/workspace/tangent/tangent/utils.py in shapes_match(a, b)
    627     return match
    628   else:
--> 629     shape_checker = shape_checkers[(type(a), type(b))]
    630     return shape_checker(a, b)
    631 

KeyError: (<class 'dask.array.core.Array'>, <class 'float'>)

It would be convenient if tangent could be used with other objects that "quack like a numpy.ndarray", of which there are a few today (numpy, sparse, dask.array, cupy).
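
The traceback above shows the shape check failing because the lookup is keyed on the exact pair of types, so any unregistered array type raises KeyError. As a rough illustration of what a duck-typed fallback might look like, here is a minimal sketch. The names `duck_shape`, `duck_shapes_match`, and `FakeDuckArray` are hypothetical and are not part of tangent; the only assumption made about an array-like object is that it exposes a `.shape` attribute, as numpy, dask.array, sparse, and cupy arrays do.

```python
import numpy as np


def duck_shape(x):
    """Return the shape of x, preferring a `.shape` attribute if one exists.

    Hypothetical helper, not tangent's API. Plain Python scalars and
    sequences have no `.shape`, so they fall back to numpy.shape.
    """
    shape = getattr(x, "shape", None)
    if shape is not None:
        return tuple(shape)
    return np.shape(x)


def duck_shapes_match(a, b):
    """Compare shapes without dispatching on the concrete types of a and b."""
    return duck_shape(a) == duck_shape(b)


class FakeDuckArray:
    """Stand-in for a NumPy-like array; it only carries a shape."""

    def __init__(self, shape):
        self.shape = shape


# A zero-dimensional duck array against a Python float, mirroring the
# (dask Array, float) pair from the KeyError above.
scalar_case = duck_shapes_match(FakeDuckArray(()), 1.0)

# A real ndarray against a duck array of the same shape.
vector_case = duck_shapes_match(np.arange(5), FakeDuckArray((5,)))

# Mismatched shapes are still rejected.
mismatch_case = duck_shapes_match(np.arange(5), 1.0)
```

This only covers the shape check that raised here. Whether the rest of the generated gradient code would work on duck arrays is a separate question that this sketch does not address.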

cc @njsmith @shoyer @ericmjl @hameerabbasi
