Фильтровать строки массива numpy?

Я хочу применить функцию к каждой строке массива numpy. Если эта функция будет равна true, я сохраню строку, иначе я ее отброшу. Например, моя функция может быть:

def f(row): if sum(row)>10: return True else: return False 

Мне было интересно, есть ли что-то похожее:

 np.apply_over_axes() 

который применяет функцию к каждой строке массива numpy и возвращает результат. Я надеялся на что-то вроде:

 np.filter_over_axes() 

который применил бы функцию к каждой строке массива numpy и только возвращал строки, для которых функция возвращала true. Есть ли что-нибудь подобное? Или я должен просто использовать цикл for?

One Solution collect form web for “Фильтровать строки массива numpy?”

В идеале вы сможете реализовать векторизованную версию своей функции и использовать ее для булевской индексации . Для подавляющего большинства проблем это правильное решение. Numpy предоставляет довольно много функций, которые могут действовать по различным осям, а также все основные операции и сравнения, поэтому большинство полезных условий должны быть векционируемыми.

 import numpy as np x = np.random.randn(20, 3) x_new = x[np.sum(x, axis=1) > .5] 

Если вы абсолютно уверены, что не можете сделать выше, я бы предложил использовать понимание списка (или np.apply_along_axis ) для создания массива bools для индексирования.

 def myfunc(row): return sum(row) > .5 bool_arr = np.array([myfunc(row) for row in x]) x_new = x[bool_arr] 

Это позволит сделать работу относительно чистым способом, но будет значительно медленнее, чем векторная версия. Пример:

 x = np.random.randn(5000, 200) %timeit x[np.sum(x, axis=1) > .5] # 100 loops, best of 3: 5.71 ms per loop %timeit x[np.array([myfunc(row) for row in x])] # 1 loops, best of 3: 217 ms per loop 
  • Каков наиболее эффективный способ преобразования набора результатов MySQL в массив NumPy?
  • Как получить наивысший элемент по абсолютной величине в матрице numpy?
  • python - матрица RGB изображения
  • Объект 'numpy.ndarray' не имеет атрибута 'read'
  • Чтение и сохранение произвольных целых чисел длины байта из файла
  • повторное время с использованием scipy.signal.resample
  • Пиксельные соседи в массиве 2d (изображение) с использованием Python
  • сопоставить каждый элемент с выражением
  • NumPy: как фильтровать матричные линии
  • Python: обработка большого набора данных. Scipy или Rpy? И как?
  • Как определить, установлен ли numpy
  • Python - лучший язык программирования в мире.