# Source code for kwarray.util_torch

"""
Torch specific extensions
"""
import re
import sys

import numpy as np


[docs] def _is_in_onnx_export(): torch = sys.modules.get('torch', None) if torch is None: return False try: # Does not exist for older torch versions return torch.onnx.is_in_onnx_export() except AttributeError: return False
def one_hot_embedding(labels, num_classes, dim=1):
    """
    Embedding labels to one-hot form.

    Args:
        labels (LongTensor | ndarray): class labels, sized [N,].
        num_classes (int): number of classes.
        dim (int): dimension of the output where the one-hot axis is
            created. Negative values count from the end of the output
            shape (e.g. ``dim=-1`` places the class axis last).

    Returns:
        Tensor | ndarray: encoded labels, sized [N, #classes].

    References:
        https://discuss.pytorch.org/t/convert-int-into-one-hot-format/507/4

    Example:
        >>> # each element in target has to have 0 <= value < C
        >>> # xdoctest: +REQUIRES(module:torch)
        >>> import torch
        >>> labels = torch.LongTensor([0, 0, 1, 4, 2, 3])
        >>> num_classes = max(labels) + 1
        >>> t = one_hot_embedding(labels, num_classes)
        >>> assert all(row[y] == 1 for row, y in zip(t.numpy(), labels.numpy()))
        >>> import ubelt as ub
        >>> print(ub.urepr(t.numpy().tolist()))
        [
            [1.0, 0.0, 0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 1.0],
            [0.0, 0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 1.0, 0.0],
        ]
        >>> t2 = one_hot_embedding(labels.numpy(), num_classes)
        >>> assert np.all(t2 == t.numpy())
        >>> from kwarray.util_torch import _torch_available_devices
        >>> devices = _torch_available_devices()
        >>> if devices:
        >>>     device = devices[0]
        >>>     try:
        >>>         t3 = one_hot_embedding(labels.to(device), num_classes)
        >>>     except RuntimeError:
        >>>         pass
        >>>     assert np.all(t3.cpu().numpy() == t.numpy())

    Example:
        >>> # xdoctest: +REQUIRES(module:torch)
        >>> import torch
        >>> nC = num_classes = 3
        >>> labels = (torch.rand(10, 11, 12) * nC).long()
        >>> assert one_hot_embedding(labels, nC, dim=0).shape == (3, 10, 11, 12)
        >>> assert one_hot_embedding(labels, nC, dim=1).shape == (10, 3, 11, 12)
        >>> assert one_hot_embedding(labels, nC, dim=2).shape == (10, 11, 3, 12)
        >>> assert one_hot_embedding(labels, nC, dim=3).shape == (10, 11, 12, 3)
        >>> labels = (torch.rand(10, 11) * nC).long()
        >>> assert one_hot_embedding(labels, nC, dim=0).shape == (3, 10, 11)
        >>> assert one_hot_embedding(labels, nC, dim=1).shape == (10, 3, 11)
        >>> labels = (torch.rand(10) * nC).long()
        >>> assert one_hot_embedding(labels, nC, dim=0).shape == (3, 10)
        >>> assert one_hot_embedding(labels, nC, dim=1).shape == (10, 3)
    """
    torch = sys.modules.get('torch', None)
    if torch is not None and torch.is_tensor(labels):
        in_dims = labels.ndimension()
        if dim < 0:
            # The output has ``in_dims + 1`` dimensions, so resolve the
            # negative index relative to the OUTPUT shape.
            # (fixed: was ``in_dims - dim + 1``, which mishandled dim < -1)
            dim = in_dims + 1 + dim
        if dim == 1 and in_dims == 1:
            # normal case where everything is already flat
            y = torch.eye(int(num_classes), device=labels.device)
            y_onehot = y[labels]
        else:
            # non-flat case (note that this would handle the normal case,
            # but why do extra work?): look up on flattened labels, restore
            # the shape, then move the class axis into position.
            y = torch.eye(int(num_classes), device=labels.device)
            flat_y_onehot = y[labels.view(-1)]
            y_onehot = flat_y_onehot.view(*list(labels.shape) + [int(num_classes)])
            if dim != in_dims:
                dim_order = list(range(in_dims))
                dim_order.insert(dim, in_dims)
                y_onehot = y_onehot.permute(*dim_order)
    else:
        if dim < 0:
            # same negative-dim resolution as the torch branch
            dim = labels.ndim + 1 + dim
        # Only the flat 1D case with dim=1 can use a direct eye lookup;
        # everything else goes through the reshape / transpose path.
        # (fixed: previously ndim >= 3 with dim == 1 skipped this path and
        # left the class axis last)
        flag = not (dim == 1 and labels.ndim == 1)
        if flag:
            orig_shape = labels.shape
            labels = labels.reshape(-1)
        y = np.eye(int(num_classes))
        y_onehot = y[labels]
        if flag:
            new_shape = list(orig_shape) + [int(num_classes)]
            y_onehot = y_onehot.reshape(*new_shape)
            # Move the trailing one-hot axis into position ``dim``
            new_axes = list(range(len(orig_shape)))
            new_axes.insert(dim, len(orig_shape))
            y_onehot = y_onehot.transpose(*new_axes)
    return y_onehot
def one_hot_lookup(data, indices):
    """
    Return value of a particular column for each row in data.

    Each item in ``indices`` corresponds to a row in ``data`` and selects
    the value at that column index of the row.

    Args:
        data (ArrayLike): N x C float array of values
        indices (ArrayLike): N integer array between 0 and C. This is an
            column index for each row in ``data``.

    Returns:
        ArrayLike: the selected probability for each row

    Note:
        This is functionally equivalent to
        ``[row[c] for row, c in zip(data, indices)]`` except that it is
        works with pure matrix operations.

    TODO:
        - [ ] Allow the user to specify which dimension indices should be
          zipped over. By default it should be dim=0
        - [ ] Allow the user to specify which dimension indices should select
          from. By default it should be dim=1.

    Example:
        >>> from kwarray.util_torch import *  # NOQA
        >>> data = np.array([
        >>>     [0, 1, 2],
        >>>     [3, 4, 5],
        >>>     [6, 7, 8],
        >>>     [9, 10, 11],
        >>> ])
        >>> indices = np.array([0, 1, 2, 1])
        >>> res = one_hot_lookup(data, indices)
        >>> print('res = {!r}'.format(res))
        res = array([ 0,  4,  8, 10])
        >>> alt = np.array([row[c] for row, c in zip(data, indices)])
        >>> assert np.all(alt == res)

    Example:
        >>> # xdoctest: +REQUIRES(module:torch)
        >>> import torch
        >>> data = torch.from_numpy(np.array([
        >>>     [0, 1, 2],
        >>>     [3, 4, 5],
        >>>     [6, 7, 8],
        >>>     [9, 10, 11],
        >>> ]))
        >>> indices = torch.from_numpy(np.array([0, 1, 2, 1])).long()
        >>> res = one_hot_lookup(data, indices)
        >>> print('res = {!r}'.format(res))
        res = tensor([ 0,  4,  8, 10]...)
        >>> alt = torch.LongTensor([row[c] for row, c in zip(data, indices)])
        >>> assert torch.all(alt == res)
    """
    torch = sys.modules.get('torch', None)
    if torch is not None and torch.is_tensor(indices):
        if _is_in_onnx_export():
            # ONNX tracing cannot express ``torch.eye`` or boolean fancy
            # indexing, so build a one-hot matrix manually and reduce via
            # a multiply-sum instead.
            ASSUME_OPTSET = 10
            device = indices.device
            num_cols = data.shape[1]
            # Flat positions of the diagonal of an (n x n) matrix
            col_range = torch.arange(num_cols, device=device)
            diag_flat_idxs = col_range + (col_range * num_cols)
            if ASSUME_OPTSET >= 11:
                # Opset 11 supports the "put" operation, so write ones
                # directly onto the diagonal elements.
                eye = torch.zeros((num_cols, num_cols), dtype=data.dtype,
                                  device=device)
                flat_eye = eye.view(num_cols * num_cols)
                diag_elem = torch.ones(num_cols, dtype=data.dtype,
                                       device=device)
                flat_eye[diag_flat_idxs] = diag_elem
            elif ASSUME_OPTSET >= 10:
                # Opset 10 has no "put"; instead compare every flat index
                # of an NxN matrix against the diagonal positions and sum
                # away the broadcast dimension.
                all_flat_idxs = torch.arange(num_cols * num_cols)
                hits = all_flat_idxs[:, None] == diag_flat_idxs[None, :]
                flat_eye = hits.to(data.dtype).sum(dim=1)
            else:
                raise AssertionError('ASSUME_OPTSET = {}'.format(ASSUME_OPTSET))
            eye = flat_eye.view(num_cols, num_cols)
            # Do the normal lookup in the eye matrix to get the OHE
            onehot = eye[indices]
            # Pad the OHE with trailing singleton dims so broadcasting
            # lines up when ``data`` has more than two dimensions.
            num_extra = len(data.shape) - 2
            if num_extra > 0:
                onehot = onehot[(Ellipsis,) + (None,) * num_extra]
            # Multiply-sum emulates the masked selection for ONNX
            out = (data * onehot).sum(dim=1)
        else:
            onehot = torch.eye(data.shape[1], dtype=torch.bool,
                               device=indices.device)[indices]
            out = data[onehot]
    else:
        # A boolean eye is smaller/faster to construct than a float
        # one-hot embedding, and works directly as a selection mask.
        onehot = np.eye(data.shape[1], dtype=bool)[indices]
        out = data[onehot]
    return out
[docs] def _torch_available_devices(): """ An attempt to determine what devices this version of torch can use Try and check that cuda is availble AND we have a good kernel image """ torch = sys.modules.get('torch', None) available_devices = [] if torch is not None: if torch.cuda.is_available(): arch_versions = [] for arch in torch.cuda.get_arch_list(): arch_major = int(arch.split('_')[1][0]) arch_minor = int(arch.split('_')[1][1]) arch_ver = (arch_major, arch_minor) arch_versions.append(arch_ver) arch_versions = set(arch_versions) for idx in range(torch.cuda.device_count()): device = torch.device(idx) devprop = torch.cuda.get_device_properties(device) dev_ver = (devprop.major, devprop.minor) if dev_ver in arch_versions: available_devices.append(device) return available_devices
if __name__ == '__main__':
    """
    CommandLine:
        python -m kwarray.util_torch all
    """
    # Run every doctest in this module via xdoctest when executed directly
    import xdoctest
    xdoctest.doctest_module(__file__)