編寫自定義數(shù)組容器

NumPy 的分派機制(在numpy版本v1.16中引入)是編寫與numpy API兼容并提供numpy功能的自定義實現(xiàn)的自定義N維數(shù)組容器的推薦方法。
應(yīng)用包括 dask 數(shù)組(分布在多個節(jié)點上的N維數(shù)組)
cupy

數(shù)組(GPU上的N維數(shù)組)。

為了獲得編寫自定義數(shù)組容器的感覺,我們將從一個簡單的示例開始,該示例具有相當(dāng)狹窄的實用程序,但說明了所涉及的概念。

>>> import numpy as np
>>> class DiagonalArray:
...     def __init__(self, N, value):
...         self._N = N
...         self._i = value
...     def __repr__(self):
...         return f"{self.__class__.__name__}(N={self._N}, value={self._i})"
...     def __array__(self):
...         return self._i * np.eye(self._N)
...

我們的自定義數(shù)組可以實例化,如下所示:

>>> arr = DiagonalArray(5, 1)
>>> arr
DiagonalArray(N=5, value=1)

我們可以使用 numpy.array

numpy.asarray

, 轉(zhuǎn)換為numpy數(shù)組,這將調(diào)用它的 __array__ 方法來獲得標(biāo)準 numpy.ndarray。

>>> np.asarray(arr)
array([[1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.]])

如果我們使用 numpy 函數(shù)對 arr 進行操作,numpy 將再次使用 __array__接口將其轉(zhuǎn)換為數(shù)組,然后以通常的方式應(yīng)用該函數(shù)。

>>> np.multiply(arr, 2)
array([[2., 0., 0., 0., 0.],
       [0., 2., 0., 0., 0.],
       [0., 0., 2., 0., 0.],
       [0., 0., 0., 2., 0.],
       [0., 0., 0., 0., 2.]])

注意,返回類型是標(biāo)準 numpy.ndarray。

>>> type(arr)
numpy.ndarray

我們?nèi)绾瓮ㄟ^此函數(shù)傳遞我們的自定義數(shù)組類型?Numpy允許類指示它希望通過交互 __array_ufunc____array_function__ 以自定義方式處理計算。
讓我們一次拿一個,從 __array_ufunc__ 開始。
此方法涵蓋 Universal functions (ufunc),
這是一類函數(shù),包括例如 numpy.multiply
numpy.sin。

_array_ufunc_ 獲得:

  • ufunc, 一個類似 numpy.multiply 的函數(shù)
  • method,一個字符串,區(qū)分 numpy.multiply(...)。
    以及numpy.multiy.outer、numpy.multiy.accumate等變體。對于常見情況,numpy.multiply(...)method='__call__'。
  • inputs, 可能是不同類型的混合
  • kwargs, 傳遞給函數(shù)的關(guān)鍵字參數(shù)

對于這個例子,我們將只處理方法 '__call__

>>> from numbers import Number
>>> class DiagonalArray:
...     def __init__(self, N, value):
...         self._N = N
...         self._i = value
...     def __repr__(self):
...         return f"{self.__class__.__name__}(N={self._N}, value={self._i})"
...     def __array__(self):
...         return self._i * np.eye(self._N)
...     def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
...         if method == '__call__':
...             N = None
...             scalars = []
...             for input in inputs:
...                 if isinstance(input, Number):
...                     scalars.append(input)
...                 elif isinstance(input, self.__class__):
...                     scalars.append(input._i)
...                     if N is not None:
...                         if N != self._N:
...                             raise TypeError("inconsistent sizes")
...                     else:
...                         N = self._N
...                 else:
...                     return NotImplemented
...             return self.__class__(N, ufunc(*scalars, **kwargs))
...         else:
...             return NotImplemented
...

現(xiàn)在讓我們的自定義數(shù)組類型通過numpy的函數(shù)。

>>> arr = DiagonalArray(5, 1)
>>> np.multiply(arr, 3)
DiagonalArray(N=5, value=3)
>>> np.add(arr, 3)
DiagonalArray(N=5, value=4)
>>> np.sin(arr)
DiagonalArray(N=5, value=0.8414709848078965)

此時 arr + 3 不起作用。

>>> arr + 3
TypeError: unsupported operand type(s) for *: 'DiagonalArray' and 'int'

為了支持它,我們需要定義Python接口 __add____lt__ 等,以便調(diào)度到相應(yīng)的ufunc。 我們可以通過繼承mixin NDArrayOperatorsMixin






來方便地實現(xiàn)這一點。

>>> import numpy.lib.mixins
>>> class DiagonalArray(numpy.lib.mixins.NDArrayOperatorsMixin):
...     def __init__(self, N, value):
...         self._N = N
...         self._i = value
...     def __repr__(self):
...         return f"{self.__class__.__name__}(N={self._N}, value={self._i})"
...     def __array__(self):
...         return self._i * np.eye(self._N)
...     def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
...         if method == '__call__':
...             N = None
...             scalars = []
...             for input in inputs:
...                 if isinstance(input, Number):
...                     scalars.append(input)
...                 elif isinstance(input, self.__class__):
...                     scalars.append(input._i)
...                     if N is not None:
...                         if N != self._N:
...                             raise TypeError("inconsistent sizes")
...                     else:
...                         N = self._N
...                 else:
...                     return NotImplemented
...             return self.__class__(N, ufunc(*scalars, **kwargs))
...         else:
...             return NotImplemented
...
>>> arr = DiagonalArray(5, 1)
>>> arr + 3
DiagonalArray(N=5, value=4)
>>> arr > 0
DiagonalArray(N=5, value=True)

現(xiàn)在讓我們來解決 __array_function__。 我們將創(chuàng)建將 numpy 函數(shù)映射到我們的自定義變體的 dict。

>>> HANDLED_FUNCTIONS = {}
>>> class DiagonalArray(numpy.lib.mixins.NDArrayOperatorsMixin):
...     def __init__(self, N, value):
...         self._N = N
...         self._i = value
...     def __repr__(self):
...         return f"{self.__class__.__name__}(N={self._N}, value={self._i})"
...     def __array__(self):
...         return self._i * np.eye(self._N)
...     def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
...         if method == '__call__':
...             N = None
...             scalars = []
...             for input in inputs:
...                 # In this case we accept only scalar numbers or DiagonalArrays.
...                 if isinstance(input, Number):
...                     scalars.append(input)
...                 elif isinstance(input, self.__class__):
...                     scalars.append(input._i)
...                     if N is not None:
...                         if N != self._N:
...                             raise TypeError("inconsistent sizes")
...                     else:
...                         N = self._N
...                 else:
...                     return NotImplemented
...             return self.__class__(N, ufunc(*scalars, **kwargs))
...         else:
...             return NotImplemented
...    def __array_function__(self, func, types, args, kwargs):
...        if func not in HANDLED_FUNCTIONS:
...            return NotImplemented
...        # Note: this allows subclasses that don't override
...        # __array_function__ to handle DiagonalArray objects.
...        if not all(issubclass(t, self.__class__) for t in types):
...            return NotImplemented
...        return HANDLED_FUNCTIONS[func](*args, **kwargs)
...

一個便捷的模式是定義一個可用于向 HANDLED_FUNCTIONS 添加函數(shù)的裝飾器 實現(xiàn)。

>>> def implements(np_function):
...    "Register an __array_function__ implementation for DiagonalArray objects."
...    def decorator(func):
...        HANDLED_FUNCTIONS[np_function] = func
...        return func
...    return decorator
...

現(xiàn)在我們?yōu)?DiagonalArray 編寫numpy函數(shù)的實現(xiàn)。
為了完整性,為了支持使用 arr.sum()
添加一個調(diào)用 numpy.sum(self) 的方法 sum,對于 mean 來說也是一樣的。

>>> @implements(np.sum)
... def sum(a):
...     "Implementation of np.sum for DiagonalArray objects"
...     return arr._i * arr._N
...
>>> @implements(np.mean)
... def sum(a):
...     "Implementation of np.mean for DiagonalArray objects"
...     return arr._i / arr._N
...
>>> arr = DiagonalArray(5, 1)
>>> np.sum(arr)
5
>>> np.mean(arr)
0.2

如果用戶嘗試使用 HANDLED_FUNCTIONS 中未包含的任何numpy函數(shù),
則numpy將引發(fā) TypeError,表示不支持此操作。
例如,連接兩個 DiagonalArrays 不會產(chǎn)生另一個對角線數(shù)組,因此不支持它。

>>> np.concatenate([arr, arr])
TypeError: no implementation found for 'numpy.concatenate' on types that implement __array_function__: [<class '__main__.DiagonalArray'>]

另外,我們的 summean 實現(xiàn)不接受numpy實現(xiàn)的可選參數(shù)。

>>> np.sum(arr, axis=0)
TypeError: sum() got an unexpected keyword argument 'axis'

用戶總是可以選擇使用 numpy.asarray 轉(zhuǎn)換為普通的 numpy.asarray,并使用標(biāo)準的numpy。

>>> np.concatenate([np.asarray(arr), np.asarray(arr)])
array([[1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.],
       [1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.]])

有關(guān)自定義數(shù)組容器的更完整工作示例,請參閱dask源代碼

cupy源代碼

另外可以看一下 NEP 18

作者:柯廣的網(wǎng)絡(luò)日志 ? 編寫自定義數(shù)組容器


微信公眾號:Java大數(shù)據(jù)與數(shù)據(jù)倉庫