kwarray.util_torch
¶
Torch specific extensions
Module Contents¶
Functions¶
| one_hot_embedding(labels, num_classes, dim=1) | Embedding labels to one-hot form. |
| one_hot_lookup(data, indices) | Return value of a particular column for each row in data. |
Attributes¶
- kwarray.util_torch.torch¶
- kwarray.util_torch._is_in_onnx_export()¶
- kwarray.util_torch.one_hot_embedding(labels, num_classes, dim=1)¶
Embedding labels to one-hot form.
- Parameters
labels – (LongTensor) class labels, sized [N,].
num_classes – (int) number of classes.
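dim (int) – index of the new dimension which will be created to hold the one-hot classes (default 1). The original description trails off at "if negative"; the handling of negative values is not documented here.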
- Returns
encoded labels, sized [N,#classes].
- Return type
Tensor
References
https://discuss.pytorch.org/t/convert-int-into-one-hot-format/507/4
Example
>>> # each element in target has to have 0 <= value < C
>>> # xdoctest: +REQUIRES(module:torch)
>>> labels = torch.LongTensor([0, 0, 1, 4, 2, 3])
>>> num_classes = max(labels) + 1
>>> t = one_hot_embedding(labels, num_classes)
>>> assert all(row[y] == 1 for row, y in zip(t.numpy(), labels.numpy()))
>>> import ubelt as ub
>>> print(ub.repr2(t.numpy().tolist()))
[
    [1.0, 0.0, 0.0, 0.0, 0.0],
    [1.0, 0.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0, 1.0],
    [0.0, 0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 1.0, 0.0],
]
>>> t2 = one_hot_embedding(labels.numpy(), num_classes)
>>> assert np.all(t2 == t.numpy())
>>> if torch.cuda.is_available():
>>>     t3 = one_hot_embedding(labels.to(0), num_classes)
>>>     assert np.all(t3.cpu().numpy() == t.numpy())
Example
>>> # xdoctest: +REQUIRES(module:torch)
>>> nC = num_classes = 3
>>> labels = (torch.rand(10, 11, 12) * nC).long()
>>> assert one_hot_embedding(labels, nC, dim=0).shape == (3, 10, 11, 12)
>>> assert one_hot_embedding(labels, nC, dim=1).shape == (10, 3, 11, 12)
>>> assert one_hot_embedding(labels, nC, dim=2).shape == (10, 11, 3, 12)
>>> assert one_hot_embedding(labels, nC, dim=3).shape == (10, 11, 12, 3)
>>> labels = (torch.rand(10, 11) * nC).long()
>>> assert one_hot_embedding(labels, nC, dim=0).shape == (3, 10, 11)
>>> assert one_hot_embedding(labels, nC, dim=1).shape == (10, 3, 11)
>>> labels = (torch.rand(10) * nC).long()
>>> assert one_hot_embedding(labels, nC, dim=0).shape == (3, 10)
>>> assert one_hot_embedding(labels, nC, dim=1).shape == (10, 3)
- kwarray.util_torch.one_hot_lookup(data, indices)¶
Return value of a particular column for each row in data.
Each item in indices corresponds to a row in data. Returns the value at the specified index in each row.
- Parameters
data (ArrayLike) – N x C float array of values
indices (ArrayLike) – N integer array with values in the range [0, C). This is a column index for each row in data.
- Returns
the selected probability for each row
- Return type
ArrayLike
Notes
This is functionally equivalent to
[row[c] for row, c in zip(data, indices)]
except that it works with pure matrix operations.
Todo
- [ ] Allow the user to specify which dimension indices should be
zipped over. By default it should be dim=0
- [ ] Allow the user to specify which dimension indices should select
from. By default it should be dim=1.
Example
>>> from kwarray.util_torch import *  # NOQA
>>> data = np.array([
>>>     [0, 1, 2],
>>>     [3, 4, 5],
>>>     [6, 7, 8],
>>>     [9, 10, 11],
>>> ])
>>> indices = np.array([0, 1, 2, 1])
>>> res = one_hot_lookup(data, indices)
>>> print('res = {!r}'.format(res))
res = array([ 0,  4,  8, 10])
>>> alt = np.array([row[c] for row, c in zip(data, indices)])
>>> assert np.all(alt == res)
Example
>>> # xdoctest: +REQUIRES(module:torch)
>>> import torch
>>> data = torch.from_numpy(np.array([
>>>     [0, 1, 2],
>>>     [3, 4, 5],
>>>     [6, 7, 8],
>>>     [9, 10, 11],
>>> ]))
>>> indices = torch.from_numpy(np.array([0, 1, 2, 1])).long()
>>> res = one_hot_lookup(data, indices)
>>> print('res = {!r}'.format(res))
res = tensor([ 0,  4,  8, 10]...)
>>> alt = torch.LongTensor([row[c] for row, c in zip(data, indices)])
>>> assert torch.all(alt == res)
- Ignore:
>>> # xdoctest: +REQUIRES(module:torch, module:onnx, module:onnx_tf)
>>> # Test if this converts to ONNX
>>> from kwarray.util_torch import *  # NOQA
>>> import torch.onnx
>>> import io
>>> import onnx
>>> import onnx_tf.backend
>>> import numpy as np
>>> data = torch.from_numpy(np.array([
>>>     [0, 1, 2],
>>>     [3, 4, 5],
>>>     [6, 7, 8],
>>>     [9, 10, 11],
>>> ]))
>>> indices = torch.from_numpy(np.array([0, 1, 2, 1])).long()
>>> class TFConvertWrapper(torch.nn.Module):
>>>     def forward(self, data, indices):
>>>         return one_hot_lookup(data, indices)
>>> ###
>>> # Test the ONNX export
>>> wrapped = TFConvertWrapper()
>>> onnx_file = io.BytesIO()
>>> torch.onnx.export(
>>>     wrapped, tuple([data, indices]),
>>>     input_names=['data', 'indices'],
>>>     output_names=['out'],
>>>     f=onnx_file,
>>>     opset_version=11,
>>>     verbose=1,
>>> )
>>> onnx_file.seek(0)
>>> onnx_model = onnx.load(onnx_file)
>>> onnx_tf_model = onnx_tf.backend.prepare(onnx_model)
>>> # Test that the resulting graph tensors are concretely sized.
>>> import tensorflow as tf
>>> onnx_gd = onnx_tf_model.graph.as_graph_def()
>>> output_tensors = tf.import_graph_def(
>>>     onnx_gd,
>>>     input_map={},
>>>     return_elements=[onnx_tf_model.tensor_dict[ol].name for ol in onnx_tf_model.outputs]
>>> )
>>> assert all(isinstance(d.value, int) for t in output_tensors for d in t.shape)
>>> tf_outputs = onnx_tf_model.run([data, indices])
>>> pt_outputs = wrapped(data, indices)
>>> print('tf_outputs = {!r}'.format(tf_outputs))
>>> print('pt_outputs = {!r}'.format(pt_outputs))
>>> ###
>>> # Test if data is more than 2D
>>> shape = (4, 3, 8)
>>> data = torch.arange(int(np.prod(shape))).view(*shape).float()
>>> indices = torch.from_numpy(np.array([0, 1, 2, 1])).long()
>>> onnx_file = io.BytesIO()
>>> torch.onnx.export(
>>>     wrapped, tuple([data, indices]),
>>>     input_names=['data', 'indices'],
>>>     output_names=['out'],
>>>     f=onnx_file,
>>>     opset_version=11,
>>>     verbose=1,
>>> )
>>> onnx_file.seek(0)
>>> onnx_model = onnx.load(onnx_file)
>>> onnx_tf_model = onnx_tf.backend.prepare(onnx_model)
>>> # Test that the resulting graph tensors are concretely sized.
>>> import tensorflow as tf
>>> onnx_gd = onnx_tf_model.graph.as_graph_def()
>>> output_tensors = tf.import_graph_def(
>>>     onnx_gd,
>>>     input_map={},
>>>     return_elements=[onnx_tf_model.tensor_dict[ol].name for ol in onnx_tf_model.outputs]
>>> )
>>> assert all(isinstance(d.value, int) for t in output_tensors for d in t.shape)
>>> tf_outputs = onnx_tf_model.run([data, indices])
>>> pt_outputs = wrapped(data, indices)
>>> print('tf_outputs = {!r}'.format(tf_outputs))
>>> print('pt_outputs = {!r}'.format(pt_outputs))