Source code for astropy.wcs.wcsapi.sliced_low_level_wcs
import numbers
import numpy as np
from astropy.wcs.wcsapi import BaseLowLevelWCS
# Public names of this module (what ``from <module> import *`` exports).
__all__ = ['sanitize_slices', 'SlicedLowLevelWCS']
def sanitize_slices(slices, ndim):
    """
    Normalize an array-style index into a list of exactly ``ndim`` items.

    Parameters
    ----------
    slices : int, slice, Ellipsis, or tuple/list of these
        The index to normalize. Anything that is not a tuple or list is
        treated as a single index item.
    ndim : int
        Number of dimensions the index applies to.

    Returns
    -------
    list
        A new list of length ``ndim`` containing only integers and
        `slice` objects. A single ``Ellipsis`` is expanded into as many
        ``slice(None)`` items as needed, and missing trailing dimensions
        are padded with ``slice(None)``.

    Raises
    ------
    IndexError
        If the index contains more than one ``Ellipsis``.
    ValueError
        If the index has more items than ``ndim``, if a slice has a step
        other than 1, or if an item is neither an integer nor a slice.
    """
    if not isinstance(slices, (tuple, list)):  # We just have a single item
        slices = (slices,)

    slices = list(slices)

    # Locate Ellipsis by identity. ``in``/``count``/``index`` fall back to
    # ``==``, which is ambiguous for array-like items and would fail before
    # the item validation below gets a chance to report a clear error.
    ellipsis_at = [i for i, item in enumerate(slices) if item is Ellipsis]
    if len(ellipsis_at) > 1:
        raise IndexError("an index can only have a single ellipsis ('...')")

    if ellipsis_at:
        # Replace the Ellipsis with the correct number of slice(None)s
        # (possibly zero when the remaining items already cover every axis).
        pos = ellipsis_at[0]
        n_fill = max(ndim - (len(slices) - 1), 0)
        slices[pos:pos + 1] = [slice(None)] * n_fill

    if len(slices) > ndim:
        raise ValueError(
            "The dimensionality of the specified slice ({0}) can not be "
            "greater than the dimensionality ({1}) of the wcs."
            .format(len(slices), ndim))

    for item in slices:
        if isinstance(item, slice):
            if item.step and item.step != 1:
                raise ValueError("Slicing WCS with a step is not supported.")
        elif not isinstance(item, numbers.Integral):
            raise ValueError("Only integer or range slices are accepted.")

    # Dimensions not mentioned in the index are kept whole.
    slices.extend([slice(None)] * (ndim - len(slices)))

    return slices
class SlicedLowLevelWCS(BaseLowLevelWCS):
    """
    Low-level WCS wrapper that presents an array-style slice of another WCS.

    Pixel dimensions indexed with an integer are dropped and fixed at that
    value; pixel dimensions indexed with a slice are kept and shifted by the
    slice start. A world dimension is kept when the wrapped WCS's
    ``axis_correlation_matrix`` links it to at least one kept pixel
    dimension.

    Parameters
    ----------
    wcs : low-level WCS object
        The WCS being sliced. It must provide the attributes and methods
        this class delegates to (``pixel_n_dim``, ``world_n_dim``,
        ``axis_correlation_matrix``, ``pixel_to_world_values``, ...).
    slices : int, slice, Ellipsis, or tuple/list of these
        Index in array order, normalized with ``sanitize_slices``.
    """

    def __init__(self, wcs, slices):
        self._wcs = wcs
        # Index in array order, and the same index reversed into pixel order.
        self._slices_array = sanitize_slices(slices, self._wcs.pixel_n_dim)
        self._slices_pixel = self._slices_array[::-1]

        # figure out which pixel dimensions have been kept, then use axis
        # correlation matrix to figure out which world dims are kept
        self._pixel_keep = np.nonzero(
            [not isinstance(self._slices_pixel[ip], numbers.Integral)
             for ip in range(self._wcs.pixel_n_dim)])[0]

        # axis_correlation_matrix[world, pixel]
        self._world_keep = np.nonzero(
            self._wcs.axis_correlation_matrix[:, self._pixel_keep].any(axis=1))[0]

    @property
    def pixel_n_dim(self):
        """Number of pixel dimensions kept after slicing."""
        return len(self._pixel_keep)

    @property
    def world_n_dim(self):
        """Number of world dimensions kept after slicing."""
        return len(self._world_keep)

    @property
    def world_axis_physical_types(self):
        """Physical types of the wrapped WCS for the kept world dimensions."""
        physical_types = self._wcs.world_axis_physical_types
        return [physical_types[i] for i in self._world_keep]

    @property
    def world_axis_units(self):
        """Units of the wrapped WCS for the kept world dimensions."""
        units = self._wcs.world_axis_units
        return [units[i] for i in self._world_keep]

    def pixel_to_world_values(self, *pixel_arrays):
        """
        Convert pixel coordinates of the sliced WCS to world values.

        Parameters
        ----------
        *pixel_arrays
            One entry per kept pixel dimension, in pixel order.

        Returns
        -------
        list
            World values from the wrapped WCS, restricted to the kept world
            dimensions.
        """
        pixel_arrays_new = []
        ipix_curr = -1
        for ipix in range(self._wcs.pixel_n_dim):
            slc = self._slices_pixel[ipix]
            # numbers.Integral (not int) so that NumPy integer indices are
            # recognised here exactly as they are in __init__.
            if isinstance(slc, numbers.Integral):
                # Dropped dimension: pinned to the integer index.
                pixel_arrays_new.append(slc)
            else:
                ipix_curr += 1
                if slc.start is not None:
                    # Shift back into the wrapped WCS's pixel frame.
                    pixel_arrays_new.append(pixel_arrays[ipix_curr] + slc.start)
                else:
                    pixel_arrays_new.append(pixel_arrays[ipix_curr])

        world_arrays = self._wcs.pixel_to_world_values(*pixel_arrays_new)
        return [world_arrays[iw] for iw in self._world_keep]

    def array_index_to_world_values(self, *index_arrays):
        """
        Same as `pixel_to_world_values`, with the inputs given in array
        order (reversed relative to pixel order).
        """
        return self.pixel_to_world_values(*index_arrays[::-1])

    def world_to_pixel_values(self, *world_arrays):
        """
        Convert world values of the sliced WCS to pixel coordinates.

        Parameters
        ----------
        *world_arrays
            One entry per kept world dimension.

        Returns
        -------
        list
            Pixel coordinates for the kept pixel dimensions, relative to
            the start of each slice.
        """
        world_arrays_new = []
        iworld_curr = -1
        for iworld in range(self._wcs.world_n_dim):
            if iworld in self._world_keep:
                iworld_curr += 1
                world_arrays_new.append(world_arrays[iworld_curr])
            else:
                # Placeholder for a dropped world dimension.
                # NOTE(review): presumably harmless because dropped world
                # dimensions are uncorrelated with the kept pixel ones.
                world_arrays_new.append(1.)

        pixel_arrays = list(self._wcs.world_to_pixel_values(*world_arrays_new))

        for ipixel in range(self._wcs.pixel_n_dim):
            slc = self._slices_pixel[ipixel]
            if isinstance(slc, slice) and slc.start is not None:
                # Rebind instead of ``-=`` so arrays returned by the wrapped
                # WCS are never modified in place.
                pixel_arrays[ipixel] = pixel_arrays[ipixel] - slc.start

        return [pixel_arrays[ip] for ip in self._pixel_keep]

    def world_to_array_index_values(self, *world_arrays):
        """
        Convert world values to integer array indices.

        The pixel coordinates from `world_to_pixel_values` are reversed
        into array order and rounded with ``floor(pixel + 0.5)``.

        Returns
        -------
        tuple of numpy.ndarray
            Integer indices, one per kept pixel dimension, in array order.
        """
        pixel_arrays = self.world_to_pixel_values(*world_arrays)[::-1]
        # ``int`` is what the removed ``np.int`` alias referred to.
        array_indices = tuple(np.asarray(np.floor(pixel + 0.5), dtype=int)
                              for pixel in pixel_arrays)
        return array_indices

    @property
    def world_axis_object_components(self):
        """Object components of the wrapped WCS for the kept world dimensions."""
        components = self._wcs.world_axis_object_components
        return [components[idx] for idx in self._world_keep]

    @property
    def world_axis_object_classes(self):
        """
        Object classes of the wrapped WCS, limited to the keys referenced
        by `world_axis_object_components`.
        """
        keys_keep = [item[0] for item in self.world_axis_object_components]
        return {key: value
                for key, value in self._wcs.world_axis_object_classes.items()
                if key in keys_keep}

    @property
    def array_shape(self):
        """
        Shape of the sliced array, or `None` when the wrapped WCS reports
        no (or an empty) ``array_shape``.
        """
        if self._wcs.array_shape:
            # Index a zero-copy broadcast dummy array so that NumPy works
            # out the resulting shape.
            return np.broadcast_to(0, self._wcs.array_shape)[tuple(self._slices_array)].shape
        return None

    @property
    def pixel_shape(self):
        """
        `array_shape` reversed into pixel order, or `None` when
        `array_shape` is `None` or empty.
        """
        shape = self.array_shape
        if shape:
            return shape[::-1]
        return None

    @property
    def pixel_bounds(self):
        """
        Pixel bounds of the kept pixel dimensions, shifted by the slice
        starts, or `None` when the wrapped WCS has no ``pixel_bounds``.
        """
        wcs_bounds = self._wcs.pixel_bounds
        if wcs_bounds is None:
            return None

        bounds = []
        for idx in self._pixel_keep:
            start = self._slices_pixel[idx].start
            if start is None:
                bounds.append(wcs_bounds[idx])
            else:
                imin, imax = wcs_bounds[idx]
                bounds.append((imin - start, imax - start))
        return bounds

    @property
    def axis_correlation_matrix(self):
        """
        Correlation matrix of the wrapped WCS, restricted to the kept
        world (rows) and pixel (columns) dimensions.
        """
        return self._wcs.axis_correlation_matrix[self._world_keep][:, self._pixel_keep]