Как Pytorch Tensor получает индекс определенной стоимости

В списке python можно использовать list.index(somevalue). Как может pytorch сделать это?
Например:

    a=[1,2,3]
    print(a.index(2))

Тогда, 1 будет выводиться. Как может тензор pytorch сделать это без преобразования его в список python?

Ответ 1

Я думаю, что нет прямого перевода из list.index() в функцию pytorch. Однако вы можете добиться аналогичных результатов, используя tensor==number а затем nonzero() функцию. Например:

t = torch.Tensor([1, 2, 3])
print ((t == 2).nonzero())

Этот кусок кода возвращается

1

[torch.LongTensor размера 1x1]

Ответ 2

Может быть сделано путем преобразования в numpy следующим образом

import torch
x = torch.range(1,4)
print(x)
===> tensor([ 1.,  2.,  3.,  4.]) 
nx = x.numpy()
np.where(nx == 3)[0][0]
===> 2

Ответ 3

    import torch
    x_data = variable(torch.Tensor([[1.0], [2.0], [3.0]]))
    print(x_data.data[0])
    >>tensor([1.])

Ответ 4

Для тензоров с плавающей точкой я использую это, чтобы получить индекс элемента в тензоре.

print((torch.abs((torch.max(your_tensor).item()-your_tensor))<0.0001).nonzero())

Здесь я хочу получить индекс max_value в тензоре с плавающей запятой, вы также можете указать свое значение таким образом, чтобы получить индекс любых элементов в тензоре.

print((torch.abs((YOUR_VALUE-your_tensor))<0.0001).nonzero())