"""
Defines the :class:`SlidingWindow` and :class:`Sticher` classes.
The :class:`SlidingWindow` generates a grid of slices over an
:func:`numpy.ndarray`, which can then be used to compute on subsets of the
data. The :class:`Stitcher` can then take these results and recombine them into
a final result that matches the larger array.
"""
import ubelt as ub
import numpy as np
import itertools as it
class SlidingWindow(ub.NiceRepr):
    """
    Slide a window of a certain shape over an array with a larger shape.

    This can be used for iterating over a grid of sub-regions of 2d-images,
    3d-volumes, or any n-dimensional array.

    Yields slices of shape `window` that can be used to index into an array
    with shape `shape` via numpy / torch fancy indexing. This allows for fast
    iteration over subregions of a larger image. Because we generate a
    grid-basis using only shapes, the larger image does not need to be in
    memory; only its width/height/depth/etc... need to be known.

    Args:
        shape (Tuple[int, ...]): shape of source array to slide across.

        window (Tuple[int | None, ...] | None): shape of window that will be
            slid over the larger image. If None, the window is the entire
            shape. Any None entry is replaced by the corresponding entry of
            ``shape``.

        overlap (float | Tuple[float, ...], default=None): a number between 0
            and 1 indicating the fraction of overlap that parts will have.
            Specifying this is mutually exclusive with `stride`. Must be
            `0 <= overlap < 1`. If neither `overlap` nor `stride` is given,
            an overlap of 0 is used.

        stride (int | Tuple[int, ...], default=None): the number of cells
            (pixels) moved on each step of the window. Mutually exclusive with
            overlap. Must be positive.

        keepbound (bool, default=False): if True, a non-uniform stride will be
            taken to ensure that the right / bottom of the image is returned as
            a slice if needed. Such a slice will not obey the overlap
            constraints.

        allow_overshoot (bool, default=False): if False, we will raise an
            error if the window doesn't slide perfectly over the input shape.

    Attributes:
        basis_shape - shape of the grid corresponding to the number of strides
            the sliding window will take.
        basis_slices - slices that will be taken in every dimension

    Yields:
        Tuple[slice, ...]: slices used for numpy indexing, with one slice per
            dimension of ``shape``.

    Raises:
        ValueError: if the window and shape have a different number of
            dimensions, if both overlap and stride are given, if an overlap is
            outside ``[0, 1)``, if a stride is not positive, or if the window
            overshoots the shape while ``allow_overshoot`` is False.

    Note:
        For each dimension, we generate a basis (which defines a grid), and we
        slide over that basis.

    TODO:
        - [ ] have an option that is allowed to go outside of the window bounds
              on the right and bottom when the slider overshoots.

    Example:
        >>> from kwarray.util_slider import *  # NOQA
        >>> shape = (10, 10)
        >>> window = (5, 5)
        >>> self = SlidingWindow(shape, window)
        >>> for i, index in enumerate(self):
        >>>     print('i={}, index={}'.format(i, index))
        i=0, index=(slice(0, 5, None), slice(0, 5, None))
        i=1, index=(slice(0, 5, None), slice(5, 10, None))
        i=2, index=(slice(5, 10, None), slice(0, 5, None))
        i=3, index=(slice(5, 10, None), slice(5, 10, None))

    Example:
        >>> from kwarray.util_slider import *  # NOQA
        >>> shape = (16, 16)
        >>> window = (4, 4)
        >>> self = SlidingWindow(shape, window, overlap=(.5, .25))
        >>> print('self.stride = {!r}'.format(self.stride))
        self.stride = [2, 3]
        >>> list(ub.chunks(self.grid, 5))
        [[(0, 0), (0, 1), (0, 2), (0, 3), (0, 4)],
         [(1, 0), (1, 1), (1, 2), (1, 3), (1, 4)],
         [(2, 0), (2, 1), (2, 2), (2, 3), (2, 4)],
         [(3, 0), (3, 1), (3, 2), (3, 3), (3, 4)],
         [(4, 0), (4, 1), (4, 2), (4, 3), (4, 4)],
         [(5, 0), (5, 1), (5, 2), (5, 3), (5, 4)],
         [(6, 0), (6, 1), (6, 2), (6, 3), (6, 4)]]

    Example:
        >>> # Test shapes that dont fit
        >>> # When the window is bigger than the shape, the left-aligned slices
        >>> # are returned.
        >>> self = SlidingWindow((3, 3), (12, 12), allow_overshoot=True, keepbound=True)
        >>> print(list(self))
        [(slice(0, 12, None), slice(0, 12, None))]
        >>> print(list(SlidingWindow((3, 3), None, allow_overshoot=True, keepbound=True)))
        [(slice(0, 3, None), slice(0, 3, None))]
        >>> print(list(SlidingWindow((3, 3), (None, 2), allow_overshoot=True, keepbound=True)))
        [(slice(0, 3, None), slice(0, 2, None)), (slice(0, 3, None), slice(1, 3, None))]
    """

    def __init__(self, shape, window, overlap=None, stride=None,
                 keepbound=False, allow_overshoot=False):
        stride, overlap, window = self._compute_stride(
            overlap, stride, shape, window)

        # One set of :func:`_slices1d` keyword arguments per dimension. The
        # strict bounds check is only requested when neither keepbound nor
        # allow_overshoot relaxes it.
        stide_kw = [dict(margin=d, stop=D, step=s, keepbound=keepbound,
                         check=not keepbound and not allow_overshoot)
                    for d, D, s in zip(window, shape, stride)]

        undershot_shape = []
        overshoots = []
        for kw in stide_kw:
            final_pos = (kw['stop'] - kw['margin'])
            n_steps = final_pos // kw['step']
            overshoot = final_pos % kw['step']
            undershot_shape.append(n_steps + 1)
            overshoots.append(overshoot)

        self._final_step = overshoots

        if not allow_overshoot and any(overshoots):
            raise ValueError('overshoot={} stide_kw={}'.format(overshoots,
                                                                stide_kw))

        self.stride = stride
        self.overlap = overlap
        self.window = window
        self.input_shape = shape

        # The undershot basis shape, only contains indices that correspond
        # perfectly to the input. It may crop a bit of the ends. If this is
        # equal to basis_shape, then the self perfectly fits the input.
        self.undershot_shape = undershot_shape

        # make a slice generator for each dimension
        # NOTE: if we have overshot, then basis shape will not perfectly
        # align to the original image. This shape will be a bit bigger.
        self.basis_slices = [list(_slices1d(**kw)) for kw in stide_kw]
        self.basis_shape = [len(b) for b in self.basis_slices]
        # Cast to int: np.prod of an empty list is the float 1.0, which
        # ``len()`` would reject.
        self.n_total = int(np.prod(self.basis_shape))

    def __nice__(self):
        return 'bshape={}, shape={}, window={}, stride={}'.format(
            tuple(self.basis_shape),
            tuple(self.input_shape),
            tuple(self.window),
            tuple(self.stride)
        )

    def _compute_stride(self, overlap, stride, shape, window):
        """
        Normalize the window / stride / overlap specification.

        Ensures that stride and overlap have one entry per dimension. If stride
        is not provided, it is computed from the desired overlap. The
        fractional overlap is always recomputed from the final integer stride.

        Returns:
            Tuple[list | tuple, list, Sequence]: stride, overlap, window

        Raises:
            ValueError: on dimension mismatches, when both overlap and stride
                are given, when an overlap is outside ``[0, 1)``, or when a
                stride is not positive.
        """
        if window is None:
            window = shape
        if isinstance(stride, np.ndarray):
            stride = tuple(stride)
        # TODO: some auto overlap?
        if isinstance(overlap, np.ndarray):
            overlap = tuple(overlap)
        if len(window) != len(shape):
            raise ValueError('incompatible dims: {} {}'.format(len(window),
                                                               len(shape)))
        if any(d is None for d in window):
            window = [D if d is None else d for d, D in zip(window, shape)]
        if overlap is None and stride is None:
            overlap = 0
        if not (overlap is None) ^ (stride is None):
            raise ValueError('specify overlap({}) XOR stride ({})'.format(
                overlap, stride))
        if stride is None:
            if not isinstance(overlap, (list, tuple)):
                overlap = [overlap] * len(window)
            if any(frac < 0 or frac >= 1 for frac in overlap):
                raise ValueError((
                    'part overlap was {}, but fractional overlaps must be '
                    'in the range [0, 1)').format(overlap))
            stride = [int(round(d - d * frac))
                      for frac, d in zip(overlap, window)]
        else:
            if not isinstance(stride, (list, tuple)):
                stride = [stride] * len(window)
        # Explicit exception (not ``assert``) so the check survives ``-O``.
        if len(stride) != len(shape):
            raise ValueError('incompatible dims: {} {}'.format(len(stride),
                                                               len(shape)))
        # A zero or negative step would never advance the window.
        if any(s <= 0 for s in stride):
            raise ValueError(
                'Step must be positive everywhere. Got={}'.format(stride))
        # Recompute fractional overlap after integer stride is computed
        overlap = [(d - s) / d for s, d in zip(stride, window)]
        return stride, overlap, window

    def __len__(self):
        return self.n_total

    def _iter_basis_frac(self):
        """
        Yield, for each window, the start of each slice as a fraction of the
        corresponding input dimension.

        Yields:
            List[float]
        """
        for slices in self:
            frac = [sl.start / D for sl, D in zip(slices, self.input_shape)]
            yield frac

    def __iter__(self):
        for slices in it.product(*self.basis_slices):
            yield slices

    def __getitem__(self, index):
        """
        Get a specific item by its flat (raveled) index

        Example:
            >>> from kwarray.util_slider import *  # NOQA
            >>> window = (10, 10)
            >>> shape = (20, 20)
            >>> self = SlidingWindow(shape, window, stride=5)
            >>> itered_items = list(self)
            >>> assert len(itered_items) == len(self)
            >>> indexed_items = [self[i] for i in range(len(self))]
            >>> assert itered_items[0] == self[0]
            >>> assert itered_items[-1] == self[-1]
            >>> assert itered_items == indexed_items
        """
        if index < 0:
            index = len(self) + index
        # Find the nd location in the grid
        basis_idx = np.unravel_index(index, self.basis_shape)
        # Take the slice for each of the n dimensions
        slices = tuple([bdim[i]
                        for bdim, i in zip(self.basis_slices, basis_idx)])
        return slices

    @property
    def grid(self):
        """
        Generate indices into the "basis" slice for each dimension.
        This enumerates the nd indices of the grid.

        Yields:
            Tuple[int, ...]
        """
        # Generates basis for "sliding window" slices to break a large image
        # into smaller pieces. Use it.product to slide across the coordinates.
        basis_indices = map(range, self.basis_shape)
        for basis_idxs in it.product(*basis_indices):
            yield basis_idxs

    @property
    def slices(self):
        """
        Generate slices for each window (equivalent to iter(self))

        Example:
            >>> shape = (220, 220)
            >>> window = (10, 10)
            >>> self = SlidingWindow(shape, window, stride=5)
            >>> list(self)[41:45]
            [(slice(0, 10, None), slice(205, 215, None)),
             (slice(0, 10, None), slice(210, 220, None)),
             (slice(5, 15, None), slice(0, 10, None)),
             (slice(5, 15, None), slice(5, 15, None))]
            >>> print('self.overlap = {!r}'.format(self.overlap))
            self.overlap = [0.5, 0.5]
        """
        return iter(self)

    @property
    def centers(self):
        """
        Generate centers of each window

        Yields:
            Tuple[float, ...]: the center coordinate of the slice

        Example:
            >>> shape = (4, 4)
            >>> window = (3, 3)
            >>> self = SlidingWindow(shape, window, stride=1)
            >>> list(zip(self.centers, self.slices))
            [((1.0, 1.0), (slice(0, 3, None), slice(0, 3, None))),
             ((1.0, 2.0), (slice(0, 3, None), slice(1, 4, None))),
             ((2.0, 1.0), (slice(1, 4, None), slice(0, 3, None))),
             ((2.0, 2.0), (slice(1, 4, None), slice(1, 4, None)))]
            >>> shape = (3, 3)
            >>> window = (2, 2)
            >>> self = SlidingWindow(shape, window, stride=1)
            >>> list(zip(self.centers, self.slices))
            [((0.5, 0.5), (slice(0, 2, None), slice(0, 2, None))),
             ((0.5, 1.5), (slice(0, 2, None), slice(1, 3, None))),
             ((1.5, 0.5), (slice(1, 3, None), slice(0, 2, None))),
             ((1.5, 1.5), (slice(1, 3, None), slice(1, 3, None)))]
        """
        for slices in self:
            center = tuple(sl_.start + (sl_.stop - sl_.start - 1) / 2
                           for sl_ in slices)
            yield center

    __devnote__ = '''
    TODO:
        - [ ] Look at the old "add_fast" code in the netharn version and see if
              it is worth porting. This code is kept in the dev folder in
              ../dev/_dev_slider.py
    '''
class Stitcher(ub.NiceRepr):
    """
    Stitches multiple possibly overlapping slices into a larger array.

    This is used to invert the SlidingWindow. For semantic segmentation the
    patches are probability chips. Overlapping chips are averaged together.

    SeeAlso:
        :class:`kwarray.RunningStats` - similarly performs running means, but
            can also track other statistics.

    Example:
        >>> from kwarray.util_slider import *  # NOQA
        >>> import sys
        >>> # Build a high resolution image and slice it into chips
        >>> highres = np.random.rand(5, 200, 200).astype(np.float32)
        >>> target_shape = (1, 50, 50)
        >>> slider = SlidingWindow(highres.shape, target_shape, overlap=(0, .5, .5))
        >>> # Show how Stitcher can be used to reconstruct the original image
        >>> stitcher = Stitcher(slider.input_shape)
        >>> for sl in list(slider):
        ...     chip = highres[sl]
        ...     stitcher.add(sl, chip)
        >>> assert stitcher.weights.max() == 4, 'some parts should be processed 4 times'
        >>> recon = stitcher.finalize()

    Example:
        >>> from kwarray.util_slider import *  # NOQA
        >>> import sys
        >>> # Demo stitching 3 patterns where one has nans
        >>> pat1 = np.full((32, 32), fill_value=0.2)
        >>> pat2 = np.full((32, 32), fill_value=0.4)
        >>> pat3 = np.full((32, 32), fill_value=0.8)
        >>> pat1[:, 16:] = 0.6
        >>> pat2[16:, :] = np.nan
        >>> # Test with nan_policy=omit
        >>> stitcher = Stitcher(shape=(32, 64), nan_policy='omit')
        >>> stitcher[0:32, 0:32](pat1)
        >>> stitcher[0:32, 16:48](pat2)
        >>> stitcher[0:32, 33:64](pat3[:, 1:])
        >>> final1 = stitcher.finalize()
        >>> # Test with nan_policy=propogate
        >>> stitcher = Stitcher(shape=(32, 64), nan_policy='propogate')
        >>> stitcher[0:32, 0:32](pat1)
        >>> stitcher[0:32, 16:48](pat2)
        >>> stitcher[0:32, 33:64](pat3[:, 1:])
        >>> final2 = stitcher.finalize()
        >>> # Checks
        >>> assert np.isnan(final1).sum() == 16, 'only should contain nan where no data was stitched'
        >>> assert np.isnan(final2).sum() == 512, 'should contain nan wherever a nan was stitched'
        >>> # xdoctest: +REQUIRES(--show)
        >>> # xdoctest: +REQUIRES(module:kwplot)
        >>> import kwplot
        >>> import kwimage
        >>> kwplot.autompl()
        >>> kwplot.imshow(pat1, title='pat1', pnum=(3, 3, 1))
        >>> kwplot.imshow(kwimage.nodata_checkerboard(pat2, square_shape=1), title='pat2 (has nans)', pnum=(3, 3, 2))
        >>> kwplot.imshow(pat3, title='pat3', pnum=(3, 3, 3))
        >>> kwplot.imshow(kwimage.nodata_checkerboard(final1, square_shape=1), title='stitched (nan_policy=omit)', pnum=(3, 1, 2))
        >>> kwplot.imshow(kwimage.nodata_checkerboard(final2, square_shape=1), title='stitched (nan_policy=propogate)', pnum=(3, 1, 3))

    Example:
        >>> # Example of weighted stitching
        >>> # xdoctest: +REQUIRES(module:kwimage)
        >>> from kwarray.util_slider import *  # NOQA
        >>> import kwimage
        >>> import kwarray
        >>> import sys
        >>> data = kwimage.Mask.demo().data.astype(np.float32)
        >>> data_dims = data.shape
        >>> window_dims = (8, 8)
        >>> # We are going to slide a window over the data, do some processing
        >>> # and then stitch it all back together. There are a few ways we
        >>> # can do it. Lets demo the params.
        >>> basis = {
        >>>     # Vary the overlap of the slider
        >>>     'overlap': (0, 0.5, .9),
        >>>     # Vary if we are using weighted stitching or not
        >>>     'weighted': ['none', 'gauss'],
        >>>     'keepbound': [True, False]
        >>> }
        >>> results = []
        >>> gauss_weights = kwimage.gaussian_patch(window_dims)
        >>> gauss_weights = kwimage.normalize(gauss_weights)
        >>> for params in ub.named_product(basis):
        >>>     if params['weighted'] == 'none':
        >>>         weights = None
        >>>     elif params['weighted'] == 'gauss':
        >>>         weights = gauss_weights
        >>>     # Build the slider and stitcher
        >>>     slider = kwarray.SlidingWindow(
        >>>         data_dims, window_dims, overlap=params['overlap'],
        >>>         allow_overshoot=True,
        >>>         keepbound=params['keepbound'])
        >>>     stitcher = kwarray.Stitcher(data_dims)
        >>>     # Loop over the regions
        >>>     for sl in list(slider):
        >>>          chip = data[sl]
        >>>          # This is our dummy function for this example.
        >>>          predicted = np.ones_like(chip) * chip.sum() / chip.size
        >>>          stitcher.add(sl, predicted, weight=weights)
        >>>     final = stitcher.finalize()
        >>>     results.append({
        >>>         'final': final,
        >>>         'params': params,
        >>>     })
        >>> # xdoctest: +REQUIRES(--show)
        >>> # xdoctest: +REQUIRES(module:kwplot)
        >>> import kwplot
        >>> kwplot.autompl()
        >>> pnum_ = kwplot.PlotNums(nCols=3, nSubplots=len(results) + 2)
        >>> kwplot.imshow(data, pnum=pnum_(), title='input image')
        >>> kwplot.imshow(gauss_weights, pnum=pnum_(), title='Gaussian weights')
        >>> pnum_()
        >>> for result in results:
        >>>     param_key = ub.urepr(result['params'], compact=1)
        >>>     final = result['final']
        >>>     canvas = kwarray.normalize(final)
        >>>     canvas = kwimage.fill_nans_with_checkers(canvas)
        >>>     kwplot.imshow(canvas, pnum=pnum_(), title=param_key)
    """

    def __init__(self, shape, device='numpy', dtype='float32',
                 nan_policy='propogate'):
        """
        Args:
            shape (tuple): dimensions of the large image that will be created
                from the smaller pixels or patches.

            device (str | int | torch.device):
                default is 'numpy', but if given as a torch device, then
                underlying operations will be done with torch tensors instead.

            dtype (str):
                the datatype to use in the underlying accumulator. For torch
                devices a string such as 'float32' is mapped to the matching
                ``torch`` dtype; an unrecognized value falls back to torch's
                default dtype.

            nan_policy (str):
                if 'omit', check for nans and convert any to zero weight items
                in stitching. If 'raise', raise a ValueError when a patch
                contains nans. 'propogate' (default; 'propagate' is also
                accepted) performs no nan checks.

        Raises:
            ValueError: if ``nan_policy`` is not a recognized value.
        """
        # Validate up front so nothing is allocated for a bad policy.
        if nan_policy not in {'propogate', 'propagate', 'omit', 'raise'}:
            raise ValueError(nan_policy)
        self.nan_policy = nan_policy
        self.shape = shape
        self.device = device
        if device == 'numpy':
            self.sums = np.zeros(shape, dtype=dtype)
            self.weights = np.zeros(shape, dtype=dtype)
            self.sumview = self.sums.ravel()
            self.weightview = self.weights.ravel()
            isnan, any_ = np.isnan, np.any
        else:
            import torch
            # Honor the dtype argument on the torch path as well; only values
            # that resolve to a real torch dtype are used.
            if isinstance(dtype, str):
                torch_dtype = getattr(torch, dtype, None)
            else:
                torch_dtype = dtype
            if not isinstance(torch_dtype, torch.dtype):
                torch_dtype = None
            self.sums = torch.zeros(shape, device=device, dtype=torch_dtype)
            self.weights = torch.zeros(shape, device=device, dtype=torch_dtype)
            self.sumview = self.sums.view(-1)
            self.weightview = self.weights.view(-1)
            isnan, any_ = torch.isnan, torch.any
        if self.nan_policy in {'omit', 'raise'}:
            self._isnan = isnan
            self._any = any_

    def __nice__(self):
        return str(self.sums.shape)

    def add(self, indices, patch, weight=None):
        """
        Incorporate a new (possibly overlapping) patch or pixel using a
        weighted sum.

        Args:
            indices (slice | tuple | None):
                typically a Tuple[slice] of pixels or a single pixel, but this
                can be any numpy fancy index.

            patch (ndarray): data to patch into the bigger image.

            weight (float | ndarray): weight of this patch (default to 1.0)

        Raises:
            ValueError: if ``nan_policy`` is 'raise' and the patch has nans.
        """
        if self.nan_policy == 'omit':
            mask = self._isnan(patch)
            if self._any(mask):
                # Detect nans, set weight and value to zero. The caller's
                # patch is copied so it is not mutated.
                valid = ~mask
                if self.device == 'numpy':
                    valid_weight = valid.astype(self.weights.dtype)
                    patch = patch.copy()
                else:
                    # torch tensors use ``to`` / ``clone`` in place of
                    # numpy's ``astype`` / ``copy``.
                    valid_weight = valid.to(self.weights.dtype)
                    patch = patch.clone()
                patch[mask] = 0
                if weight is None:
                    weight = valid_weight
                else:
                    weight = weight * valid_weight
        elif self.nan_policy == 'raise':
            mask = self._isnan(patch)
            if self._any(mask):
                raise ValueError('nan_policy is raise')

        if weight is None:
            self.sums[indices] += patch
            self.weights[indices] += 1.0
        else:
            self.sums[indices] += (patch * weight)
            self.weights[indices] += weight

    def __getitem__(self, indices):
        """
        Convenience function to use slice notation directly, e.g.
        ``stitcher[0:32, 0:32](patch)`` is equivalent to
        ``stitcher.add((slice(0, 32), slice(0, 32)), patch)``.
        """
        from functools import partial
        return partial(self.add, indices)

    def average(self):
        """
        Averages out contributions from overlapping adds using weighted average

        Returns:
            ndarray: out - the stitched image
        """
        return self.finalize()

    def finalize(self, indices=None):
        """
        Averages out contributions from overlapping adds

        Locations that never received any weight evaluate to 0 / 0 and are nan
        in the result.

        Args:
            indices (None | slice | tuple): if None, finalize the entire
                block, otherwise only finalize a subregion.

        Returns:
            ndarray: final - the stitched image
        """
        if indices is None:
            sums, weights = self.sums, self.weights
        else:
            sums, weights = self.sums[indices], self.weights[indices]
        if self.device == 'numpy':
            # 0 / 0 -> nan at unvisited locations is intentional; silence the
            # "invalid value" warning numpy would otherwise emit for it.
            with np.errstate(invalid='ignore'):
                final = sums / weights
        else:
            final = sums / weights
        return final
[docs]
def _slices1d(margin, stop, step=None, start=0, keepbound=False, check=True):
"""
Helper to generates slices in a single dimension.
Args:
margin (int): the length of the slice (window)
stop (int): the length of the image dimension
step (int, default=None): the length of each step / distance between
slices
start (int, default=0): starting point (in most cases set this to 0)
keepbound (bool): if True, a non-uniform step will be taken to ensure
that the right / bottom of the image is returned as a slice if
needed. Such a slice will not obey the overlap constraints.
(Defaults to False)
check (bool): if True an error will be raised if the window does not
cover the entire extent from start to stop, even if keepbound is
True.
Yields:
slice : slice in one dimension of size (margin)
Example:
>>> stop, margin, step = 2000, 360, 360
>>> keepbound = True
>>> strides = list(_slices1d(margin, stop, step, keepbound, check=False))
>>> assert all([(s.stop - s.start) == margin for s in strides])
Example:
>>> stop, margin, step = 200, 46, 7
>>> keepbound = True
>>> strides = list(_slices1d(margin, stop, step, keepbound=False, check=True))
>>> starts = np.array([s.start for s in strides])
>>> stops = np.array([s.stop for s in strides])
>>> widths = stops - starts
>>> assert np.all(np.diff(starts) == step)
>>> assert np.all(widths == margin)
Example:
>>> import pytest
>>> stop, margin, step = 200, 36, 7
>>> with pytest.raises(ValueError):
... list(_slices1d(margin, stop, step))
"""
if step is None:
step = margin
if check:
# see how far off the end we would fall if we didnt check bounds
perfect_final_pos = (stop - start - margin)
overshoot = perfect_final_pos % step
if overshoot > 0:
raise ValueError(
('margin={} and step={} overshoot endpoint={} '
'by {} units when starting from={}').format(
margin, step, stop, overshoot, start))
pos = start
# probably could be more efficient with numpy here
while True:
endpos = pos + margin
yield slice(pos, endpos)
# Stop once we reached the end
if endpos == stop:
break
pos += step
if pos + margin > stop:
if keepbound:
# Ensure the boundary is always used even if steps
# would overshoot Could do some other strategy here
pos = stop - margin
if pos < 0:
break
else:
break