kwarray.util_torch module

Torch-specific extensions

kwarray.util_torch._is_in_onnx_export()[source]
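
No description is rendered for this private helper. As an assumption (not the library's actual code), a minimal sketch would simply defer to torch.onnx.is_in_onnx_export(), which reports whether execution is currently inside a torch.onnx.export trace:

def _is_in_onnx_export():
    # Hypothetical sketch: True only while torch.onnx.export is tracing.
    # Returns False when torch cannot be imported at all.
    try:
        import torch.onnx
    except ImportError:
        return False
    return torch.onnx.is_in_onnx_export()
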
kwarray.util_torch.one_hot_embedding(labels, num_classes, dim=1)[source]

Embed labels in one-hot form.

Parameters:
  • labels (LongTensor) – class labels, sized [N,].

  • num_classes (int) – number of classes.

  • dim (int) – dimension along which the one-hot encoding is created; a negative value counts from the last dimension.

Returns:

encoded labels, sized [N,#classes].

Return type:

Tensor

References

https://discuss.pytorch.org/t/convert-int-into-one-hot-format/507/4

Example

>>> # each element in target has to have 0 <= value < C
>>> # xdoctest: +REQUIRES(module:torch)
>>> import torch
>>> labels = torch.LongTensor([0, 0, 1, 4, 2, 3])
>>> num_classes = max(labels) + 1
>>> t = one_hot_embedding(labels, num_classes)
>>> assert all(row[y] == 1 for row, y in zip(t.numpy(), labels.numpy()))
>>> import ubelt as ub
>>> print(ub.urepr(t.numpy().tolist()))
[
    [1.0, 0.0, 0.0, 0.0, 0.0],
    [1.0, 0.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0, 1.0],
    [0.0, 0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 1.0, 0.0],
]
>>> import numpy as np
>>> t2 = one_hot_embedding(labels.numpy(), num_classes)
>>> assert np.all(t2 == t.numpy())
>>> from kwarray.util_torch import _torch_available_devices
>>> devices = _torch_available_devices()
>>> if devices:
>>>     device = devices[0]
>>>     try:
>>>         t3 = one_hot_embedding(labels.to(device), num_classes)
>>>     except RuntimeError:
>>>         pass  # e.g. no compatible kernel image for this device
>>>     else:
>>>         assert np.all(t3.cpu().numpy() == t.numpy())

Example

>>> # xdoctest: +REQUIRES(module:torch)
>>> import torch
>>> nC = num_classes = 3
>>> labels = (torch.rand(10, 11, 12) * nC).long()
>>> assert one_hot_embedding(labels, nC, dim=0).shape == (3, 10, 11, 12)
>>> assert one_hot_embedding(labels, nC, dim=1).shape == (10, 3, 11, 12)
>>> assert one_hot_embedding(labels, nC, dim=2).shape == (10, 11, 3, 12)
>>> assert one_hot_embedding(labels, nC, dim=3).shape == (10, 11, 12, 3)
>>> labels = (torch.rand(10, 11) * nC).long()
>>> assert one_hot_embedding(labels, nC, dim=0).shape == (3, 10, 11)
>>> assert one_hot_embedding(labels, nC, dim=1).shape == (10, 3, 11)
>>> labels = (torch.rand(10) * nC).long()
>>> assert one_hot_embedding(labels, nC, dim=0).shape == (3, 10)
>>> assert one_hot_embedding(labels, nC, dim=1).shape == (10, 3)
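
For reference, the same behavior can be sketched on top of torch.nn.functional.one_hot, which places the class axis last; Tensor.movedim (available in recent torch versions) then relocates it to the requested position. This is an illustrative re-implementation (one_hot_embedding_sketch is a hypothetical name), not necessarily how kwarray implements it:

import torch
import torch.nn.functional as F

def one_hot_embedding_sketch(labels, num_classes, dim=1):
    # F.one_hot appends the class axis as the trailing dimension;
    # movedim moves it to the requested (possibly negative) position.
    onehot = F.one_hot(labels, num_classes)
    return onehot.movedim(-1, dim).float()
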
kwarray.util_torch.one_hot_lookup(data, indices)[source]

Return value of a particular column for each row in data.

Each item in indices corresponds to a row in data. Returns the value at the specified column for each row.

Parameters:
  • data (ArrayLike) – N x C float array of values

  • indices (ArrayLike) – N integer array with values between 0 and C - 1. This gives a column index for each row in data.

Returns:

the selected probability for each row

Return type:

ArrayLike

Note

This is functionally equivalent to [row[c] for row, c in zip(data, indices)] except that it works with pure matrix operations.
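
In numpy, that "pure matrix operations" equivalence can be sketched with a boolean one-hot mask (an illustrative sketch, not necessarily the library's implementation):

import numpy as np

def one_hot_lookup_sketch(data, indices):
    # Rows of an identity matrix form one-hot masks; indexing data with
    # the boolean mask picks exactly one value per row, in row order.
    mask = np.eye(data.shape[1], dtype=bool)[indices]
    return data[mask]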

Todo

  • [ ] Allow the user to specify which dimension indices should be zipped over. By default it should be dim=0.

  • [ ] Allow the user to specify which dimension indices should select from. By default it should be dim=1.

Example

>>> from kwarray.util_torch import *  # NOQA
>>> data = np.array([
>>>     [0, 1, 2],
>>>     [3, 4, 5],
>>>     [6, 7, 8],
>>>     [9, 10, 11],
>>> ])
>>> indices = np.array([0, 1, 2, 1])
>>> res = one_hot_lookup(data, indices)
>>> print('res = {!r}'.format(res))
res = array([ 0,  4,  8, 10])
>>> alt = np.array([row[c] for row, c in zip(data, indices)])
>>> assert np.all(alt == res)

Example

>>> # xdoctest: +REQUIRES(module:torch)
>>> import torch
>>> data = torch.from_numpy(np.array([
>>>     [0, 1, 2],
>>>     [3, 4, 5],
>>>     [6, 7, 8],
>>>     [9, 10, 11],
>>> ]))
>>> indices = torch.from_numpy(np.array([0, 1, 2, 1])).long()
>>> res = one_hot_lookup(data, indices)
>>> print('res = {!r}'.format(res))
res = tensor([ 0,  4,  8, 10]...)
>>> alt = torch.LongTensor([row[c] for row, c in zip(data, indices)])
>>> assert torch.all(alt == res)
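
In torch, the same per-row selection can also be expressed with torch.gather; a small sketch of that alternative:

>>> # xdoctest: +REQUIRES(module:torch)
>>> import torch
>>> data = torch.arange(12).reshape(4, 3)
>>> indices = torch.tensor([0, 1, 2, 1])
>>> # gather selects data[i, indices[i]] along dim=1 for each row i
>>> res = torch.gather(data, 1, indices[:, None])[:, 0]
>>> assert torch.all(res == torch.tensor([0, 4, 8, 10]))
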
kwarray.util_torch._torch_available_devices()[source]

An attempt to determine which devices this version of torch can use.

Tries to check that CUDA is available AND that we have a good kernel image.
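
The probing logic the docstring hints at might look like the following sketch (an assumption about the approach, not the actual source): enumerate CUDA devices and run a tiny kernel on each so that "no kernel image" errors surface immediately.

import torch

def _torch_available_devices_sketch():
    # Hypothetical re-implementation: keep only devices where a trivial
    # kernel launch succeeds, so bad kernel images are filtered out here.
    devices = []
    if torch.cuda.is_available():
        for index in range(torch.cuda.device_count()):
            device = torch.device('cuda', index)
            try:
                torch.zeros(1, device=device) + 1  # forces a kernel launch
            except RuntimeError:
                continue  # e.g. no kernel image for this architecture
            devices.append(device)
    return devices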