Найти, где массив numpy равен любому значению списка значений

У меня есть массив целых чисел и вы хотите найти, где этот массив равен любому значению в списке из нескольких значений. Это легко сделать, обрабатывая каждое значение по отдельности или используя несколько "или" операторов в цикле, но я чувствую, что должен быть лучший/более быстрый способ сделать это. Я на самом деле занимаюсь массивами размером 4000х2000, но вот упрощенная редакция проблемы:

fake=arange(9).reshape((3,3))
array([[0, 1, 2],
       [3, 4, 5],
       [6, 7, 8]])
want=(fake==0)+(fake==2)+(fake==6)+(fake==8)
print want 
array([[ True, False,  True],
       [False, False, False],
       [ True, False,  True]], dtype=bool)

Я хотел бы получить способ получить want из одной команды с участием fake и списка значений [0,2,6,8]. Я мог бы написать команду сам, но я предполагаю, что есть пакет, который включает это уже уже, что будет значительно быстрее, чем если бы я просто написал функцию с циклом в python.

Спасибо, -Adam

Ответ 1

Функция numpy.in1d, похоже, делает то, что вы хотите. Единственные проблемы в том, что он работает только на 1d массивах, поэтому вы должны использовать его следующим образом:

In [9]: np.in1d(fake, [0,2,6,8]).reshape(fake.shape)
Out[9]: 
array([[ True, False,  True],
       [False, False, False],
       [ True, False,  True]], dtype=bool)

Я не знаю, почему это ограничивается только 1d массивами. Рассматривая исходный код сначала он сглаживает два массива, после чего он делает некоторые умные трюки сортировки. Но ничто не остановило бы его от окончательного развязывания результата в конце, как я должен был сделать вручную здесь.

Ответ 2

@Ваш ответ - тот, который вы, вероятно, ищете. Но вот еще один способ сделать это, используя трюк numpy vectorize:

import numpy as np
S = set([0,2,6,8])

@np.vectorize
def contained(x):
    return x in S

contained(fake)
=> array([[ True, False,  True],
          [False, False, False],
          [ True, False,  True]], dtype=bool)

Кон этим решением является то, что contained() вызывается для каждого элемента (т.е. в python-пространстве), что делает это намного медленнее, чем решение с чисто-numpy.