kwarray.util_torch

Torch-specific extensions

Module Contents

Functions

_is_in_onnx_export()

one_hot_embedding(labels, num_classes, dim=1)

Embed labels into one-hot form.

one_hot_lookup(data, indices)

Return the value of a particular column for each row in data.

Attributes

torch

kwarray.util_torch.torch
kwarray.util_torch._is_in_onnx_export()
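
This private helper is undocumented. A minimal sketch of what such a guard typically looks like, assuming it defers to torch.onnx.is_in_onnx_export and treats torch versions without that query as not exporting (hypothetical; the actual implementation may differ):

>>> # Hypothetical sketch, not necessarily this library's implementation.
>>> import torch
>>> def _is_in_onnx_export_sketch():
>>>     try:
>>>         # Reports whether the surrounding code is currently being
>>>         # traced by torch.onnx.export
>>>         return torch.onnx.is_in_onnx_export()
>>>     except AttributeError:
>>>         # Very old torch versions lack this query; assume not exporting
>>>         return False
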
kwarray.util_torch.one_hot_embedding(labels, num_classes, dim=1)

Embed labels into one-hot form.

Parameters
  • labels (LongTensor) – class labels, sized [N,].

  • num_classes (int) – number of classes.

  • dim (int) – the dimension along which the new one-hot axis is inserted; negative values count from the end.

Returns

encoded labels, sized [N,#classes].

Return type

Tensor
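
A common construction indexes an identity matrix by the labels; a minimal sketch of that idea for the default dim=1 case (illustrative only, not necessarily this library's exact implementation):

>>> # Sketch: row y of eye(num_classes) is the one-hot vector for class y
>>> import torch
>>> labels = torch.LongTensor([0, 2, 1])
>>> onehot = torch.eye(3)[labels]
>>> assert onehot.shape == (3, 3)
>>> assert all(row[y] == 1 for row, y in zip(onehot, labels))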

References

https://discuss.pytorch.org/t/convert-int-into-one-hot-format/507/4

Example

>>> # xdoctest: +REQUIRES(module:torch)
>>> # each element in labels must satisfy 0 <= value < num_classes
>>> labels = torch.LongTensor([0, 0, 1, 4, 2, 3])
>>> num_classes = max(labels) + 1
>>> t = one_hot_embedding(labels, num_classes)
>>> assert all(row[y] == 1 for row, y in zip(t.numpy(), labels.numpy()))
>>> import ubelt as ub
>>> print(ub.repr2(t.numpy().tolist()))
[
    [1.0, 0.0, 0.0, 0.0, 0.0],
    [1.0, 0.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0, 1.0],
    [0.0, 0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 1.0, 0.0],
]
>>> import numpy as np
>>> t2 = one_hot_embedding(labels.numpy(), num_classes)
>>> assert np.all(t2 == t.numpy())
>>> if torch.cuda.is_available():
>>>     t3 = one_hot_embedding(labels.to(0), num_classes)
>>>     assert np.all(t3.cpu().numpy() == t.numpy())

Example

>>> # xdoctest: +REQUIRES(module:torch)
>>> nC = num_classes = 3
>>> labels = (torch.rand(10, 11, 12) * nC).long()
>>> assert one_hot_embedding(labels, nC, dim=0).shape == (3, 10, 11, 12)
>>> assert one_hot_embedding(labels, nC, dim=1).shape == (10, 3, 11, 12)
>>> assert one_hot_embedding(labels, nC, dim=2).shape == (10, 11, 3, 12)
>>> assert one_hot_embedding(labels, nC, dim=3).shape == (10, 11, 12, 3)
>>> labels = (torch.rand(10, 11) * nC).long()
>>> assert one_hot_embedding(labels, nC, dim=0).shape == (3, 10, 11)
>>> assert one_hot_embedding(labels, nC, dim=1).shape == (10, 3, 11)
>>> labels = (torch.rand(10) * nC).long()
>>> assert one_hot_embedding(labels, nC, dim=0).shape == (3, 10)
>>> assert one_hot_embedding(labels, nC, dim=1).shape == (10, 3)
kwarray.util_torch.one_hot_lookup(data, indices)

Return the value of a particular column for each row in data.

Each item in indices corresponds to a row in data. Returns the value at the column specified for each row.

Parameters
  • data (ArrayLike) – N x C float array of values

  • indices (ArrayLike) – length-N integer array with values in [0, C). This is a column index for each row in data.

Returns

the selected value for each row

Return type

ArrayLike

Notes

This is functionally equivalent to [row[c] for row, c in zip(data, indices)], except that it works with pure matrix operations.
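
One way to realize this with matrix operations is a boolean one-hot mask; a minimal numpy sketch of the idea (illustrative, not necessarily the library's exact implementation):

>>> import numpy as np
>>> data = np.arange(12).reshape(4, 3)
>>> indices = np.array([0, 1, 2, 1])
>>> # eye(C)[indices] builds an N x C mask with exactly one True per row
>>> mask = np.eye(data.shape[1], dtype=bool)[indices]
>>> # boolean indexing then yields one value per row, in row order
>>> assert np.all(data[mask] == np.array([0, 4, 8, 10]))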

Todo

  • [ ] Allow the user to specify which dimension indices should be zipped over; by default it should be dim=0 (see the sketch after this list).

  • [ ] Allow the user to specify which dimension indices should select from; by default it should be dim=1.
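
A hypothetical generalization along those lines could lean on np.take_along_axis, which already accepts an explicit axis (sketch only; this dim-aware variant is not part of the library's API):

>>> import numpy as np
>>> data = np.arange(12).reshape(4, 3)
>>> indices = np.array([0, 1, 2, 1])
>>> # axis=1 selects one column per row; then drop the kept axis
>>> picked = np.take_along_axis(data, indices[:, None], axis=1)[:, 0]
>>> assert np.all(picked == np.array([0, 4, 8, 10]))

The torch analogue of the same pattern would be torch.gather(data, 1, indices[:, None])[:, 0].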

Example

>>> from kwarray.util_torch import *  # NOQA
>>> import numpy as np
>>> data = np.array([
>>>     [0, 1, 2],
>>>     [3, 4, 5],
>>>     [6, 7, 8],
>>>     [9, 10, 11],
>>> ])
>>> indices = np.array([0, 1, 2, 1])
>>> res = one_hot_lookup(data, indices)
>>> print('res = {!r}'.format(res))
res = array([ 0,  4,  8, 10])
>>> alt = np.array([row[c] for row, c in zip(data, indices)])
>>> assert np.all(alt == res)

Example

>>> # xdoctest: +REQUIRES(module:torch)
>>> import torch
>>> data = torch.from_numpy(np.array([
>>>     [0, 1, 2],
>>>     [3, 4, 5],
>>>     [6, 7, 8],
>>>     [9, 10, 11],
>>> ]))
>>> indices = torch.from_numpy(np.array([0, 1, 2, 1])).long()
>>> res = one_hot_lookup(data, indices)
>>> print('res = {!r}'.format(res))
res = tensor([ 0,  4,  8, 10]...)
>>> alt = torch.LongTensor([row[c] for row, c in zip(data, indices)])
>>> assert torch.all(alt == res)

Ignore

>>> # xdoctest: +REQUIRES(module:torch, module:onnx, module:onnx_tf)
>>> # Test if this converts to ONNX
>>> from kwarray.util_torch import *  # NOQA
>>> import torch.onnx
>>> import io
>>> import onnx
>>> import onnx_tf.backend
>>> import numpy as np
>>> data = torch.from_numpy(np.array([
>>>     [0, 1, 2],
>>>     [3, 4, 5],
>>>     [6, 7, 8],
>>>     [9, 10, 11],
>>> ]))
>>> indices = torch.from_numpy(np.array([0, 1, 2, 1])).long()
>>> class TFConvertWrapper(torch.nn.Module):
>>>     def forward(self, data, indices):
>>>         return one_hot_lookup(data, indices)
>>> ###
>>> # Test the ONNX export
>>> wrapped = TFConvertWrapper()
>>> onnx_file = io.BytesIO()
>>> torch.onnx.export(
>>>     wrapped, tuple([data, indices]),
>>>     input_names=['data', 'indices'],
>>>     output_names=['out'],
>>>     f=onnx_file,
>>>     opset_version=11,
>>>     verbose=1,
>>> )
>>> onnx_file.seek(0)
>>> onnx_model = onnx.load(onnx_file)
>>> onnx_tf_model = onnx_tf.backend.prepare(onnx_model)
>>> # Test that the resulting graph tensors are concretely sized.
>>> import tensorflow as tf
>>> onnx_gd = onnx_tf_model.graph.as_graph_def()
>>> output_tensors = tf.import_graph_def(
>>>     onnx_gd,
>>>     input_map={},
>>>     return_elements=[onnx_tf_model.tensor_dict[ol].name for ol in onnx_tf_model.outputs]
>>> )
>>> assert all(isinstance(d.value, int) for t in output_tensors for d in t.shape)
>>> tf_outputs = onnx_tf_model.run([data, indices])
>>> pt_outputs = wrapped(data, indices)
>>> print('tf_outputs = {!r}'.format(tf_outputs))
>>> print('pt_outputs = {!r}'.format(pt_outputs))
>>> ###
>>> # Test if data is more than 2D
>>> shape = (4, 3, 8)
>>> data = torch.arange(int(np.prod(shape))).view(*shape).float()
>>> indices = torch.from_numpy(np.array([0, 1, 2, 1])).long()
>>> onnx_file = io.BytesIO()
>>> torch.onnx.export(
>>>     wrapped, tuple([data, indices]),
>>>     input_names=['data', 'indices'],
>>>     output_names=['out'],
>>>     f=onnx_file,
>>>     opset_version=11,
>>>     verbose=1,
>>> )
>>> onnx_file.seek(0)
>>> onnx_model = onnx.load(onnx_file)
>>> onnx_tf_model = onnx_tf.backend.prepare(onnx_model)
>>> # Test that the resulting graph tensors are concretely sized.
>>> import tensorflow as tf
>>> onnx_gd = onnx_tf_model.graph.as_graph_def()
>>> output_tensors = tf.import_graph_def(
>>>     onnx_gd,
>>>     input_map={},
>>>     return_elements=[onnx_tf_model.tensor_dict[ol].name for ol in onnx_tf_model.outputs]
>>> )
>>> assert all(isinstance(d.value, int) for t in output_tensors for d in t.shape)
>>> tf_outputs = onnx_tf_model.run([data, indices])
>>> pt_outputs = wrapped(data, indices)
>>> print('tf_outputs = {!r}'.format(tf_outputs))
>>> print('pt_outputs = {!r}'.format(pt_outputs))