Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion phaser/utils/_jax_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import jax # pyright: ignore[reportMissingImports]
import jax.numpy as jnp # pyright: ignore[reportMissingImports]

from phaser.utils.image import _InterpBoundaryMode


Device: t.TypeAlias = t.Any

Expand All @@ -14,7 +16,7 @@ def to_nd(arr: jax.Array, n: int) -> jax.Array:
if arr.ndim > n:
arr = arr.reshape(-1, *arr.shape[arr.ndim - n + 1:])
elif arr.ndim < n:
arr = jax.lax.expand_dims(arr, [0] * (n - arr.ndim))
arr = jax.lax.expand_dims(arr, tuple(range((n - arr.ndim))))

return arr

Expand Down Expand Up @@ -115,6 +117,46 @@ def affine_transform(
)(to_nd(input, n_axes + 1)).reshape((*input.shape[:-n_axes], *output_shape))


# convert scipy boundary mode to numpy.pad mode
_INTERP_TO_PAD: t.Dict[_InterpBoundaryMode, str] = {
'reflect': 'symmetric',
'mirror': 'reflect',
'nearest': 'edge',
'grid-mirror': 'reflect',
'grid-wrap': 'wrap',
'grid-constant': 'constant',
}


def convolve1d(
arr: jnp.ndarray, weights: jnp.ndarray, axis: int = -1, *,
mode: _InterpBoundaryMode, cval: float = 0.
) -> jnp.ndarray:
r = len(weights) // 2
pad_mode = _INTERP_TO_PAD.get(mode, mode)
pad_kwargs = {'constant_values': cval} if pad_mode == 'constant' else {}

# transpose, pad
arr = jnp.moveaxis(arr, axis, -1)
out_shape_t = arr.shape
arr = jnp.pad(
arr,
((0, 0),) * (arr.ndim-1) + ((len(weights) - r - 1, r),),
mode=pad_mode, **pad_kwargs
)

# convolve
arr = jax.lax.conv_general_dilated(
arr.reshape(-1, 1, arr.shape[-1]),
to_3d(jnp.flip(weights)).astype(arr.dtype),
window_strides=(1,), padding='VALID',
dimension_numbers=('NCW', 'OIW', 'NCW'),
)

# unflatten, untranspose
return jnp.moveaxis(arr.reshape(out_shape_t), -1, axis)


def get_devices() -> t.Tuple[Device, ...]:
devices = []

Expand Down
123 changes: 92 additions & 31 deletions phaser/utils/_torch_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,27 +199,89 @@ def split(
return torch.split(arr, arr.shape[axis] // sections, axis)


def _pad_idxs_edge(idxs: torch.Tensor, left: int, right: int, size: int) -> t.Tuple[torch.Tensor, torch.Tensor]:
return (idxs.new_zeros(left), idxs.new_full((right,), size - 1))


def _pad_idxs_wrap(idxs: torch.Tensor, left: int, right: int, size: int) -> t.Tuple[torch.Tensor, torch.Tensor]:
left_idx = torch.arange(-left, 0, dtype=idxs.dtype, device=idxs.device) % size
right_idx = torch.arange(size, size + right, dtype=idxs.dtype, device=idxs.device) % size
return (left_idx, right_idx)


def _pad_idxs_reflect(idxs: torch.Tensor, left: int, right: int, size: int) -> t.Tuple[torch.Tensor, torch.Tensor]:
if size == 1:
return (idxs.new_zeros(left), idxs.new_zeros(right))
period = 2 * (size - 1)
def fold(i: torch.Tensor) -> torch.Tensor:
i = i % period
return (size - 1) - ((i - (size - 1)).abs())
left_idx = fold(torch.arange(left, 0, -1, dtype=idxs.dtype, device=idxs.device))
right_idx = fold(torch.arange(size, size + right, dtype=idxs.dtype, device=idxs.device))
return (left_idx, right_idx)


def _pad_idxs_symmetric(idxs: torch.Tensor, left: int, right: int, size: int) -> t.Tuple[torch.Tensor, torch.Tensor]:
period = 2 * size
def fold(i: torch.Tensor) -> torch.Tensor:
i = i % period
return torch.where(i < size, i, period - 1 - i)
left_idx = fold(torch.arange(-left, 0, dtype=idxs.dtype, device=idxs.device) % period)
right_idx = fold(torch.arange(size, size + right, dtype=idxs.dtype, device=idxs.device))
return (left_idx, right_idx)


_FAST_PAD_MODES: t.FrozenSet[str] = frozenset(('constant', 'edge', 'reflect', 'wrap'))
_PAD_MODES: t.FrozenSet[str] = _FAST_PAD_MODES | frozenset(('symmetric',))
_MAKE_PAD_IDXS: t.Dict[_PadMode, t.Callable[[torch.Tensor, int, int, int], t.Tuple[torch.Tensor, torch.Tensor]]] = {
'edge': _pad_idxs_edge,
'wrap': _pad_idxs_wrap,
'reflect': _pad_idxs_reflect,
'symmetric': _pad_idxs_symmetric,
}


def pad(
arr: torch.Tensor, pad_width: t.Union[int, t.Tuple[int, int], t.Sequence[t.Tuple[int, int]]], /, *,
mode: _PadMode = 'constant', cval: float = 0.
) -> torch.Tensor:
if mode not in ('constant', 'edge', 'reflect', 'wrap'):
if mode not in _PAD_MODES:
raise ValueError(f"Unsupported padding mode '{mode}'")

pad = (pad_width, pad_width) if isinstance(pad_width, int) else pad_width

if isinstance(pad[0], int):
pad = (pad,)
pad = (t.cast(t.Tuple[int, int], pad),)

if len(pad) == 1:
pad = tuple(pad) * arr.ndim
elif len(pad) != arr.ndim:
raise ValueError(f"Invalid `pad_width` '{pad_width}'.")

pad = tuple(itertools.chain.from_iterable(t.cast(t.Sequence[t.Tuple[int, int]], reversed(pad))))
# check for fast path (F.pad)
# checks supported mode, dim <= 3, pad lengths all less than array size
# constant padding has no restrictions
if mode == 'constant' or (
mode in _FAST_PAD_MODES
and arr.ndim <= 3
and all(p <= s - 1 if isinstance(p, int) else all(p1 <= s - 1 for p1 in p) for (p, s) in zip(pad, arr.shape))
):
pad = tuple(itertools.chain.from_iterable(t.cast(t.Sequence[t.Tuple[int, int]], reversed(pad))))
kwargs = {'value': cval} if mode == 'constant' else {}
return _MockTensor(F.pad(arr.reshape(1, *arr.shape), pad, mode=_PAD_MODE_MAP[mode], **kwargs)[0])

kwargs = {'value': cval} if mode == 'constant' else {}
return _MockTensor(F.pad(arr, pad, mode=_PAD_MODE_MAP[mode], **kwargs))
# slow path
for dim, (p, size) in enumerate(zip(pad, arr.shape)):
(left, right) = (p, p) if isinstance(p, int) else p
if left == 0 and right == 0:
continue

idxs = torch.arange(size, dtype=torch.int64, device=arr.device)
(left_idx, right_idx) = _MAKE_PAD_IDXS[mode](idxs, left, right, size)
idxs = torch.cat([left_idx, idxs, right_idx]).to(arr.device)
arr = arr.index_select(dim, idxs)

return _MockTensor(arr)


def unwrap(arr: torch.Tensor, discont: t.Optional[float] = None, axis: int = -1, *,
Expand Down Expand Up @@ -406,42 +468,41 @@ def _map_coordinates_constant(
return result.type(arr.dtype)


_INTERP_TO_TORCH_PAD: t.Dict[_InterpBoundaryMode, str] = {
'nearest': 'replicate',
'wrap': 'circular',
'grid-wrap': 'circular',
'constant': 'constant',
'grid-constant': 'constant',
# convert scipy boundary mode to numpy.pad mode
_INTERP_TO_PAD: t.Dict[_InterpBoundaryMode, str] = {
'reflect': 'symmetric',
'mirror': 'reflect',
'nearest': 'edge',
'grid-mirror': 'reflect',
'grid-wrap': 'wrap',
'grid-constant': 'constant',
}


def _convolve1d(
def convolve1d(
arr: torch.Tensor, weights: torch.Tensor, axis: int, *,
mode: _InterpBoundaryMode, cval: float = 0.
) -> torch.Tensor:
pad_mode = _INTERP_TO_TORCH_PAD.get(mode)
if pad_mode is None:
raise ValueError(f"Pad mode '{mode}' not implemented for torch backend")

# reorder to last axis
reorder = axis != arr.ndim - 1
if reorder:
arr = torch.moveaxis(arr, axis, -1)
leading_shape = arr.shape[:-1]
arr = arr.reshape((-1, arr.shape[-1]))
r = len(weights) // 2

# torch's conv1d is actually a correlation
weights = weights.flip(0)

# TODO: this will fail for some pads where weights is large, investigate further
arr = F.pad(arr, (len(weights) - r - 1, r), mode=pad_mode, value=cval)
pad_mode = t.cast(_PadMode, _INTERP_TO_PAD.get(mode, mode))

# reorder to last axis, pad
arr = torch.moveaxis(arr, axis, -1)
out_shape_t = arr.shape
# pad
arr = pad(
arr,
((0, 0),) * (arr.ndim-1) + ((len(weights) - r - 1, r),),
mode=pad_mode, cval=cval
)

# convolve
arr = F.conv1d(
arr[:, None, :], weights[None, None, :]
)[:, 0].reshape((*leading_shape, -1))
arr.reshape((-1, 1, arr.shape[-1])),
weights.flip(0).to(arr.dtype)[None, None, :]
).reshape(out_shape_t)

return torch.moveaxis(arr, -1, axis) if reorder else arr
return torch.moveaxis(arr, -1, axis)


def get_devices() -> t.Tuple[torch.device, ...]:
Expand Down
42 changes: 10 additions & 32 deletions phaser/utils/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,17 +278,6 @@ def square_pixel_transfer(shape: t.Tuple[int, int], *, xp: t.Any = None) -> NDAr
return xp.sinc(ky) * xp.sinc(kx)


# convert scipy boundary mode to numpy.pad mode
_INTERP_TO_PAD: t.Dict[_InterpBoundaryMode, str] = {
'reflect': 'symmetric',
'mirror': 'reflect',
'nearest': 'edge',
'grid-mirror': 'reflect',
'grid-wrap': 'wrap',
'grid-constant': 'constant',
}


def _canonicalize_axis(axis: int, num_dims: int) -> int:
"""Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims)."""
axis = axis.__index__()
Expand All @@ -312,34 +301,23 @@ def convolve1d(
axis = _canonicalize_axis(axis, arr.ndim)

if xp_is_torch(xp):
from ._torch_kernels import _convolve1d, _MockTensor
from ._torch_kernels import convolve1d, _MockTensor

return t.cast(NDArray[NumT], _convolve1d(
return t.cast(NDArray[NumT], convolve1d(
t.cast(_MockTensor, arr),
t.cast(_MockTensor, weights),
axis=axis, mode=mode, cval=cval
))

scipy = get_scipy_module(arr, weights)

if xp_is_jax(xp):
r = len(weights) // 2
pad_mode = _INTERP_TO_PAD.get(mode, mode)
pad_kwargs = {'constant_values': cval} if pad_mode == 'constant' else {}

pad = tuple(
(len(weights) - r - 1, r) if i == axis else (0, 0)
for i in range(arr.ndim)
)
weights = weights[tuple(
slice(None) if i == axis else None
for i in range(arr.ndim)
)]
# TODO: use jax.lax.conv_general_dilated directly
return scipy.signal.convolve(
xp.pad(arr, pad, mode=pad_mode, **pad_kwargs), # type: ignore
weights, mode='valid', method='direct'
).astype(arr.dtype)
from ._jax_kernels import convolve1d

return t.cast(NDArray[NumT], convolve1d(
arr, weights, axis, # type: ignore
mode=mode, cval=cval
))

scipy = get_scipy_module(arr, weights)

return scipy.ndimage.convolve1d(
arr, weights, axis, mode=mode, cval=cval
Expand Down
12 changes: 6 additions & 6 deletions phaser/utils/num.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,7 +754,7 @@ def abs2(x: ArrayLike) -> NDArray[numpy.floating]:
return x.real**2 + x.imag**2 # type: ignore


_PadMode: t.TypeAlias = t.Literal['constant', 'edge', 'reflect', 'wrap']
_PadMode: t.TypeAlias = t.Literal['constant', 'edge', 'reflect', 'wrap', 'symmetric']


@t.overload
Expand All @@ -778,12 +778,12 @@ def pad(
xp = get_array_module(arr)

if xp_is_torch(xp):
pass
#from ._torch_kernels import pad
#return pad(arr, pad_width, mode=mode, cval=cval) # type: ignore

return xp.pad(arr, pad_width, mode=mode, constant_values=cval)
from ._torch_kernels import pad
return pad(arr, pad_width, mode=mode, cval=cval) # type: ignore

if mode == 'constant':
return xp.pad(arr, pad_width, mode=mode, constant_values=cval)
return xp.pad(arr, pad_width, mode=mode)


@t.overload
Expand Down
11 changes: 8 additions & 3 deletions tests/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,14 @@ def test_resample(
([[[1, 2], [3, 4]], [[2, 3], [4, 5]], [[3, 4], [5, 6]]], [1, 2, 3], 0),
([[[1, 2], [3, 4]], [[2, 3], [4, 5]], [[3, 4], [5, 6]]], [1, 2, 3], -1),
([1+1.j, 2+2.j, 3+3.j], [1-1.j, 2-1.j], 0),
# casting of weights
([1+1.j, 2+2.j, 3+3.j], [1.0, 2.0], 0),
([[[1, 2], [3, 4]], [[2, 3], [4, 5]], [[3, 4], [5, 6]]], [1, 2, 3], 1),
([1, 2, 3, 4, 5], [2], 0),
# kernel longer than array
([1, 2, 3], [1, 2, 3, 4, 5, 6, 7], 0),
# length-1 along conv axis
([[[1, 2], [3, 4]]], [1, 2, 3], 0),
])
@pytest.mark.parametrize(('mode', 'cval'), [
('constant', 1.0), ('nearest', 0.0), ('mirror', 0.0),
Expand All @@ -93,9 +101,6 @@ def test_convolve1d(
arr, weights, axis, mode, cval,
backend: BackendName,
):
if mode == 'reflect' and backend == 'torch':
pytest.xfail("'reflect' not supported on torch")

arr = numpy.asarray(arr)
weights = numpy.asarray(weights)

Expand Down
Loading
Loading