diff --git a/phaser/utils/_jax_kernels.py b/phaser/utils/_jax_kernels.py index 3c19f91..94f5be5 100644 --- a/phaser/utils/_jax_kernels.py +++ b/phaser/utils/_jax_kernels.py @@ -6,6 +6,8 @@ import jax # pyright: ignore[reportMissingImports] import jax.numpy as jnp # pyright: ignore[reportMissingImports] +from phaser.utils.image import _InterpBoundaryMode + Device: t.TypeAlias = t.Any @@ -14,7 +16,7 @@ def to_nd(arr: jax.Array, n: int) -> jax.Array: if arr.ndim > n: arr = arr.reshape(-1, *arr.shape[arr.ndim - n + 1:]) elif arr.ndim < n: - arr = jax.lax.expand_dims(arr, [0] * (n - arr.ndim)) + arr = jax.lax.expand_dims(arr, tuple(range((n - arr.ndim)))) return arr @@ -115,6 +117,46 @@ def affine_transform( )(to_nd(input, n_axes + 1)).reshape((*input.shape[:-n_axes], *output_shape)) +# convert scipy boundary mode to numpy.pad mode +_INTERP_TO_PAD: t.Dict[_InterpBoundaryMode, str] = { + 'reflect': 'symmetric', + 'mirror': 'reflect', + 'nearest': 'edge', + 'grid-mirror': 'reflect', + 'grid-wrap': 'wrap', + 'grid-constant': 'constant', +} + + +def convolve1d( + arr: jnp.ndarray, weights: jnp.ndarray, axis: int = -1, *, + mode: _InterpBoundaryMode, cval: float = 0. +) -> jnp.ndarray: + r = len(weights) // 2 + pad_mode = _INTERP_TO_PAD.get(mode, mode) + pad_kwargs = {'constant_values': cval} if pad_mode == 'constant' else {} + + # transpose, pad + arr = jnp.moveaxis(arr, axis, -1) + out_shape_t = arr.shape + arr = jnp.pad( + arr, + ((0, 0),) * (arr.ndim-1) + ((len(weights) - r - 1, r),), + mode=pad_mode, **pad_kwargs + ) + + # convolve + arr = jax.lax.conv_general_dilated( + arr.reshape(-1, 1, arr.shape[-1]), + to_3d(jnp.flip(weights)).astype(arr.dtype), + window_strides=(1,), padding='VALID', + dimension_numbers=('NCW', 'OIW', 'NCW'), + ) + + # unflatten, untranspose + return jnp.moveaxis(arr.reshape(out_shape_t), -1, axis) + + def get_devices() -> t.Tuple[Device, ...]: devices = [] diff --git a/phaser/utils/_torch_kernels.py b/phaser/utils/_torch_kernels.py index ba21d3f..3c5930a 100644 --- a/phaser/utils/_torch_kernels.py +++ b/phaser/utils/_torch_kernels.py @@ -199,27 +199,89 @@ def split( return torch.split(arr, arr.shape[axis] // sections, axis) +def _pad_idxs_edge(idxs: torch.Tensor, left: int, right: int, size: int) -> t.Tuple[torch.Tensor, torch.Tensor]: + return (idxs.new_zeros(left), idxs.new_full((right,), size - 1)) + + +def _pad_idxs_wrap(idxs: torch.Tensor, left: int, right: int, size: int) -> t.Tuple[torch.Tensor, torch.Tensor]: + left_idx = torch.arange(-left, 0, dtype=idxs.dtype, device=idxs.device) % size + right_idx = torch.arange(size, size + right, dtype=idxs.dtype, device=idxs.device) % size + return (left_idx, right_idx) + + +def _pad_idxs_reflect(idxs: torch.Tensor, left: int, right: int, size: int) -> t.Tuple[torch.Tensor, torch.Tensor]: + if size == 1: + return (idxs.new_zeros(left), idxs.new_zeros(right)) + period = 2 * (size - 1) + def fold(i: torch.Tensor) -> torch.Tensor: + i = i % period + return (size - 1) - ((i - (size - 1)).abs()) + left_idx = fold(torch.arange(left, 0, -1, dtype=idxs.dtype, device=idxs.device)) + right_idx = fold(torch.arange(size, size + right, dtype=idxs.dtype, device=idxs.device)) + return (left_idx, right_idx) + + +def _pad_idxs_symmetric(idxs: torch.Tensor, left: int, right: int, size: int) -> t.Tuple[torch.Tensor, torch.Tensor]: + period = 2 * size + def fold(i: torch.Tensor) -> torch.Tensor: + i = i % period + return torch.where(i < size, i, period - 1 - i) + left_idx = fold(torch.arange(-left, 0, dtype=idxs.dtype, device=idxs.device) % period) + right_idx = fold(torch.arange(size, size + right, dtype=idxs.dtype, device=idxs.device)) + return (left_idx, right_idx) + + +_FAST_PAD_MODES: t.FrozenSet[str] = frozenset(('constant', 'edge', 'reflect', 'wrap')) +_PAD_MODES: t.FrozenSet[str] = _FAST_PAD_MODES | frozenset(('symmetric',)) +_MAKE_PAD_IDXS: t.Dict[_PadMode, t.Callable[[torch.Tensor, int, int, int], t.Tuple[torch.Tensor, torch.Tensor]]] = { + 'edge': _pad_idxs_edge, + 'wrap': _pad_idxs_wrap, + 'reflect': _pad_idxs_reflect, + 'symmetric': _pad_idxs_symmetric, +} + + def pad( arr: torch.Tensor, pad_width: t.Union[int, t.Tuple[int, int], t.Sequence[t.Tuple[int, int]]], /, *, mode: _PadMode = 'constant', cval: float = 0. ) -> torch.Tensor: - if mode not in ('constant', 'edge', 'reflect', 'wrap'): + if mode not in _PAD_MODES: raise ValueError(f"Unsupported padding mode '{mode}'") pad = (pad_width, pad_width) if isinstance(pad_width, int) else pad_width if isinstance(pad[0], int): - pad = (pad,) + pad = (t.cast(t.Tuple[int, int], pad),) if len(pad) == 1: pad = tuple(pad) * arr.ndim elif len(pad) != arr.ndim: raise ValueError(f"Invalid `pad_width` '{pad_width}'.") - pad = tuple(itertools.chain.from_iterable(t.cast(t.Sequence[t.Tuple[int, int]], reversed(pad)))) + # check for fast path (F.pad) + # checks supported mode, dim <= 3, pad lengths all less than array size + # constant padding has no restrictions + if mode == 'constant' or ( + mode in _FAST_PAD_MODES + and arr.ndim <= 3 + and all(p <= s - 1 if isinstance(p, int) else all(p1 <= s - 1 for p1 in p) for (p, s) in zip(pad, arr.shape)) + ): + pad = tuple(itertools.chain.from_iterable(t.cast(t.Sequence[t.Tuple[int, int]], reversed(pad)))) + kwargs = {'value': cval} if mode == 'constant' else {} + return _MockTensor(F.pad(arr.reshape(1, *arr.shape), pad, mode=_PAD_MODE_MAP[mode], **kwargs)[0]) - kwargs = {'value': cval} if mode == 'constant' else {} - return _MockTensor(F.pad(arr, pad, mode=_PAD_MODE_MAP[mode], **kwargs)) + # slow path + for dim, (p, size) in enumerate(zip(pad, arr.shape)): + (left, right) = (p, p) if isinstance(p, int) else p + if left == 0 and right == 0: + continue + + idxs = torch.arange(size, dtype=torch.int64, device=arr.device) + (left_idx, right_idx) = _MAKE_PAD_IDXS[mode](idxs, left, right, size) + idxs = torch.cat([left_idx, idxs, right_idx]).to(arr.device) + arr = arr.index_select(dim, idxs) + + return _MockTensor(arr) def unwrap(arr: torch.Tensor, discont: t.Optional[float] = None, axis: int = -1, *, @@ -406,42 +468,41 @@ def _map_coordinates_constant( return result.type(arr.dtype) -_INTERP_TO_TORCH_PAD: t.Dict[_InterpBoundaryMode, str] = { - 'nearest': 'replicate', - 'wrap': 'circular', - 'grid-wrap': 'circular', - 'constant': 'constant', - 'grid-constant': 'constant', +# convert scipy boundary mode to numpy.pad mode +_INTERP_TO_PAD: t.Dict[_InterpBoundaryMode, str] = { + 'reflect': 'symmetric', 'mirror': 'reflect', + 'nearest': 'edge', + 'grid-mirror': 'reflect', + 'grid-wrap': 'wrap', + 'grid-constant': 'constant', } -def _convolve1d( +def convolve1d( arr: torch.Tensor, weights: torch.Tensor, axis: int, *, mode: _InterpBoundaryMode, cval: float = 0. ) -> torch.Tensor: - pad_mode = _INTERP_TO_TORCH_PAD.get(mode) - if pad_mode is None: - raise ValueError(f"Pad mode '{mode}' not implemented for torch backend") - - # reorder to last axis - reorder = axis != arr.ndim - 1 - if reorder: - arr = torch.moveaxis(arr, axis, -1) - leading_shape = arr.shape[:-1] - arr = arr.reshape((-1, arr.shape[-1])) r = len(weights) // 2 - - # torch's conv1d is actually a correlation - weights = weights.flip(0) - - # TODO: this will fail for some pads where weights is large, investigate further - arr = F.pad(arr, (len(weights) - r - 1, r), mode=pad_mode, value=cval) + pad_mode = t.cast(_PadMode, _INTERP_TO_PAD.get(mode, mode)) + + # reorder to last axis, pad + arr = torch.moveaxis(arr, axis, -1) + out_shape_t = arr.shape + # pad + arr = pad( + arr, + ((0, 0),) * (arr.ndim-1) + ((len(weights) - r - 1, r),), + mode=pad_mode, cval=cval + ) + + # convolve arr = F.conv1d( - arr[:, None, :], weights[None, None, :] - )[:, 0].reshape((*leading_shape, -1)) + arr.reshape((-1, 1, arr.shape[-1])), + weights.flip(0).to(arr.dtype)[None, None, :] + ).reshape(out_shape_t) - return torch.moveaxis(arr, -1, axis) if reorder else arr + return torch.moveaxis(arr, -1, axis) def get_devices() -> t.Tuple[torch.device, ...]: diff --git a/phaser/utils/image.py b/phaser/utils/image.py index b63d25a..20ecc5a 100644 --- a/phaser/utils/image.py +++ b/phaser/utils/image.py @@ -278,17 +278,6 @@ def square_pixel_transfer(shape: t.Tuple[int, int], *, xp: t.Any = None) -> NDAr return xp.sinc(ky) * xp.sinc(kx) -# convert scipy boundary mode to numpy.pad mode -_INTERP_TO_PAD: t.Dict[_InterpBoundaryMode, str] = { - 'reflect': 'symmetric', - 'mirror': 'reflect', - 'nearest': 'edge', - 'grid-mirror': 'reflect', - 'grid-wrap': 'wrap', - 'grid-constant': 'constant', -} - - def _canonicalize_axis(axis: int, num_dims: int) -> int: """Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims).""" axis = axis.__index__() @@ -312,34 +301,23 @@ def convolve1d( axis = _canonicalize_axis(axis, arr.ndim) if xp_is_torch(xp): - from ._torch_kernels import _convolve1d, _MockTensor + from ._torch_kernels import convolve1d, _MockTensor - return t.cast(NDArray[NumT], _convolve1d( + return t.cast(NDArray[NumT], convolve1d( t.cast(_MockTensor, arr), t.cast(_MockTensor, weights), axis=axis, mode=mode, cval=cval )) - scipy = get_scipy_module(arr, weights) - if xp_is_jax(xp): - r = len(weights) // 2 - pad_mode = _INTERP_TO_PAD.get(mode, mode) - pad_kwargs = {'constant_values': cval} if pad_mode == 'constant' else {} - - pad = tuple( - (len(weights) - r - 1, r) if i == axis else (0, 0) - for i in range(arr.ndim) - ) - weights = weights[tuple( - slice(None) if i == axis else None - for i in range(arr.ndim) - )] - # TODO: use jax.lax.conv_general_dilated directly - return scipy.signal.convolve( - xp.pad(arr, pad, mode=pad_mode, **pad_kwargs), # type: ignore - weights, mode='valid', method='direct' - ).astype(arr.dtype) + from ._jax_kernels import convolve1d + + return t.cast(NDArray[NumT], convolve1d( + arr, weights, axis, # type: ignore + mode=mode, cval=cval + )) + + scipy = get_scipy_module(arr, weights) return scipy.ndimage.convolve1d( arr, weights, axis, mode=mode, cval=cval diff --git a/phaser/utils/num.py b/phaser/utils/num.py index 3ca6639..a74f58c 100644 --- a/phaser/utils/num.py +++ b/phaser/utils/num.py @@ -754,7 +754,7 @@ def abs2(x: ArrayLike) -> NDArray[numpy.floating]: return x.real**2 + x.imag**2 # type: ignore -_PadMode: t.TypeAlias = t.Literal['constant', 'edge', 'reflect', 'wrap'] +_PadMode: t.TypeAlias = t.Literal['constant', 'edge', 'reflect', 'wrap', 'symmetric'] @t.overload @@ -778,12 +778,12 @@ def pad( xp = get_array_module(arr) if xp_is_torch(xp): - pass - #from ._torch_kernels import pad - #return pad(arr, pad_width, mode=mode, cval=cval) # type: ignore - - return xp.pad(arr, pad_width, mode=mode, constant_values=cval) + from ._torch_kernels import pad + return pad(arr, pad_width, mode=mode, cval=cval) # type: ignore + if mode == 'constant': + return xp.pad(arr, pad_width, mode=mode, constant_values=cval) + return xp.pad(arr, pad_width, mode=mode) @t.overload diff --git a/tests/test_image.py b/tests/test_image.py index d3e5199..d78b84e 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -84,6 +84,14 @@ def test_resample( ([[[1, 2], [3, 4]], [[2, 3], [4, 5]], [[3, 4], [5, 6]]], [1, 2, 3], 0), ([[[1, 2], [3, 4]], [[2, 3], [4, 5]], [[3, 4], [5, 6]]], [1, 2, 3], -1), ([1+1.j, 2+2.j, 3+3.j], [1-1.j, 2-1.j], 0), + # casting of weights + ([1+1.j, 2+2.j, 3+3.j], [1.0, 2.0], 0), + ([[[1, 2], [3, 4]], [[2, 3], [4, 5]], [[3, 4], [5, 6]]], [1, 2, 3], 1), + ([1, 2, 3, 4, 5], [2], 0), + # kernel longer than array + ([1, 2, 3], [1, 2, 3, 4, 5, 6, 7], 0), + # length-1 along conv axis + ([[[1, 2], [3, 4]]], [1, 2, 3], 0), ]) @pytest.mark.parametrize(('mode', 'cval'), [ ('constant', 1.0), ('nearest', 0.0), ('mirror', 0.0), @@ -93,9 +101,6 @@ def test_convolve1d( arr, weights, axis, mode, cval, backend: BackendName, ): - if mode == 'reflect' and backend == 'torch': - pytest.xfail("'reflect' not supported on torch") - arr = numpy.asarray(arr) weights = numpy.asarray(weights) diff --git a/tests/test_num.py b/tests/test_num.py index b60756c..343643d 100644 --- a/tests/test_num.py +++ b/tests/test_num.py @@ -1,5 +1,7 @@ +import typing as t import numpy +from numpy.typing import ArrayLike from numpy.testing import assert_array_almost_equal, assert_array_equal import pytest @@ -12,6 +14,7 @@ fft2, ifft2, abs2, to_numpy, as_array, ufunc_outer, + pad, _PadMode ) @@ -224,3 +227,83 @@ def test_ufunc_outer(backend: BackendName): expected = numpy.multiply.outer(xs, ys) actual = to_numpy(ufunc_outer(xp.multiply, xp.array(xs), xp.array(ys))) assert_array_equal(expected, actual) + + +@with_backends('numpy', 'jax', 'cupy', 'torch') +@pytest.mark.parametrize(('mode', 'pad_width', 'expected'), [ + ('constant', 0, [1, 2, 3, 4, 5]), + ('constant', 2, [3, 3, 1, 2, 3, 4, 5, 3, 3]), + ('constant', (2, 1), [3, 3, 1, 2, 3, 4, 5, 3]), + ('constant', (0, 3), [1, 2, 3, 4, 5, 3, 3, 3]), + ('constant', (3, 0), [3, 3, 3, 1, 2, 3, 4, 5]), + ('edge', 2, [1, 1, 1, 2, 3, 4, 5, 5, 5]), + ('edge', (2, 1), [1, 1, 1, 2, 3, 4, 5, 5]), + ('edge', (0, 3), [1, 2, 3, 4, 5, 5, 5, 5]), + ('edge', (3, 0), [1, 1, 1, 1, 2, 3, 4, 5]), + ('edge', 6, [1, 1, 1, 1, 1, 1, 1, 2, 3, 4, 5, 5, 5, 5, 5, 5, 5]), + ('reflect', 2, [3, 2, 1, 2, 3, 4, 5, 4, 3]), + ('reflect', (2, 1), [3, 2, 1, 2, 3, 4, 5, 4]), + ('reflect', (1, 3), [2, 1, 2, 3, 4, 5, 4, 3, 2]), + ('reflect', (0, 3), [1, 2, 3, 4, 5, 4, 3, 2]), + ('reflect', 6, [3, 4, 5, 4, 3, 2, 1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3]), + ('symmetric', 2, [2, 1, 1, 2, 3, 4, 5, 5, 4]), + ('symmetric', (2, 1), [2, 1, 1, 2, 3, 4, 5, 5]), + ('symmetric', (1, 3), [1, 1, 2, 3, 4, 5, 5, 4, 3]), + ('symmetric', (0, 3), [1, 2, 3, 4, 5, 5, 4, 3]), + ('symmetric', 6, [5, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 5, 4, 3, 2, 1, 1]), + ('wrap', 2, [4, 5, 1, 2, 3, 4, 5, 1, 2]), + ('wrap', (2, 1), [4, 5, 1, 2, 3, 4, 5, 1]), + ('wrap', (1, 3), [5, 1, 2, 3, 4, 5, 1, 2, 3]), + ('wrap', (0, 3), [1, 2, 3, 4, 5, 1, 2, 3]), + ('wrap', 6, [5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1]), +]) +def test_pad_1d( + mode: _PadMode, + pad_width: t.Union[int, t.Tuple[int, int], t.Tuple[t.Tuple[int, int]]], + expected: ArrayLike, + backend: BackendName, +): + xp = get_backend_module(backend) + in_arr = xp.array([1, 2, 3, 4, 5]) + + assert_array_equal( + to_numpy(pad(in_arr, pad_width, mode=mode, cval=3)), + numpy.array(expected), + ) + + +@with_backends('jax', 'cupy', 'torch') +@pytest.mark.parametrize('mode', ('constant', 'edge', 'reflect', 'symmetric', 'wrap')) +@pytest.mark.parametrize(('shape', 'pad_width'), [ + ((4, 6), 2), + ((4, 6), (2, 3)), + ((4, 6), ((1, 2), (3, 1))), + ((4, 6), ((0, 3), (2, 0))), + ((3, 3), 4), + ((3, 5), ((4, 4), (6, 6))), + ((2, 4, 6), 2), + ((2, 4, 6), ((1, 1), (2, 3), (0, 2))), + ((2, 3, 4), 5), + ((2, 3, 4, 5), 2), + ((2, 3, 4, 5), ((1, 2), (0, 3), (2, 1), (1, 1))), + ((5,), 3), + ((1, 4), 2), + ((1, 5), ((0, 0), (2, 2))), +]) +def test_pad_nd( + mode: _PadMode, + shape: t.Sequence[int], + pad_width: t.Union[int, t.Tuple[int, int], t.Tuple[t.Tuple[int, int]]], + backend: BackendName, +): + xp = get_backend_module(backend) + + kwargs = {'constant_values': 3} if mode == 'constant' else {} + + rng = numpy.random.default_rng() + in_arr = rng.integers(1024, size=shape) + + assert_array_equal( + to_numpy(pad(xp.array(in_arr), pad_width, mode=mode, cval=3)), + numpy.pad(in_arr, pad_width, mode=mode, **kwargs) # type: ignore + ) \ No newline at end of file