Оптимизируйте функцию «маска» в Matlab

Для сравнительного сравнения я рассматриваю простую функцию:

function dealiasing2d(where_dealiased, data) [n1, n0, nk] = size(data); for i0=1:n0 for i1=1:n1 if where_dealiased(i1, i0) data(i1, i0, :) = 0.; end end end 

Это может быть полезно при псевдоспектральном моделировании (где data представляют собой 3D-массив комплексных чисел), но в основном он применяет маску к набору изображений, помещая в нули некоторые элементы, для которых where_dealiased истинно.

Я сравниваю производительность разных языков (и реализаций, компиляторов, …) в этом простом случае. Для Matlab я использую функцию timeit . Поскольку я не хочу сравнивать свое незнание в Matlab, я бы хотел оптимизировать эту функцию с помощью этого языка. Каким будет самый быстрый способ сделать это в Matlab?

Простое решение, которое я использую сейчас:

 function dealiasing2d(where_dealiased, data) [n1, n0, nk] = size(data); N = n0*n1; ind_zeros = find(reshape(where_dealiased, 1, [])); for ik=1:nk data(ind_zeros + N*(ik-1)) = 0; end 

Я подозреваю, что это неправильный способ сделать это, так как эквивалентное решение Numpy примерно в 10 раз быстрее.

 import numpy as np def dealiasing(where, data): nk = data.shape[0] N = reduce(lambda x, y: x*y, data.shape[1:]) inds, = np.nonzero(where.flat) for ik in xrange(nk): data.flat[inds + N*ik] = 0. 

Наконец, если кто-то скажет мне что-то вроде «Когда вы хотите быть очень быстрым с определенной функцией в Matlab, вы должны скомпилировать его так: […]», я бы включил такое решение в эталон.


Редактировать:

После двух ответов я сравнивал эти предложения и, похоже, нет заметного улучшения производительности. Это странно, потому что простое решение Python-Numpy действительно (на порядок) намного быстрее, поэтому я все еще ищу лучшее решение с Matlab …

Если я правильно понимаю, это можно сделать легко и быстро с помощью bsxfun :

 data = bsxfun(@times, data, ~where_dealiased); 

Это устанавливает в 0 все компоненты третьего измерения элементов, для которых where_dealiased является true (он умножает их на 0 ), и оставляет остальное как они были (он умножает их на 1 ).

Конечно, это предполагает [size(data,1) size(data,2]==size(where_dealiased) .


Ваше решение с линейной индексацией , вероятно, очень быстро. Чтобы сэкономить некоторое время там, вы можете удалить reshape , потому что find уже возвращает линейные индексы:

 ind_zeros = find(where_dealiased); 

Подход №1: Логическое индексирование С repmat

 data(repmat(where_dealiased,1,1,size(data,3))) = 0; 

Подход №2: Линейное индексирование с помощью bsxfun(@plus

 [m,n,r] = size(data); idx = bsxfun(@plus,find(where_dealiased),[0:r-1]*m*n); %// linear indices data(idx) = 0; 

Это должно быть быстро, если у вас есть несколько ненулевых элементов в where_dealiased .

Никакой оптимизации без бенчмарка! Итак, вот некоторые предлагаемые решения и измерения производительности. Код инициализации:

 N = 2000; nk = 10; where = false([N, N]); where(1:100, 1:100) = 1; data = (5.+j)*ones([N, N, nk]); 

и я выполняю функции с функцией timeit следующим образом:

 timeit(@() dealiasing2d(where, data)) 

Для сравнения, когда я делаю точно то же самое с функцией Numpy, заданной в вопросе, она работает в 0,0167 с.

Начальные функции Matlab с 2 циклами выполняются примерно через 0,34 с, а эквивалентная функция Numpy (с 2 циклами) работает медленнее и работает через 0,42 с. Это может быть потому, что Matlab использует компиляцию JIT.

Луис Мендо упоминает, что я могу удалить reshape потому что find уже возвращает линейные индексы. Мне нравится, так как код намного чище, но reshape в любом случае очень дешево, поэтому оно действительно не улучшает производительность функции:

 function dealiasing2d(where, data) [n1, n0, nk] = size(data); N = n0*n1; ind_zeros = find(where); for ik=1:nk data(ind_zeros + N*(ik-1)) = 0; end 

Эта функция занимает 0.23 с, что быстрее, чем решение с 2 циклами, но очень медленное по сравнению с решением Numpy (~ 14 раз медленнее!). Вот почему я написал свой вопрос.

Луис Мендо также предлагает решение, основанное на функции bsxfun , которая дает:

 function dealiasing2d_bsxfun(where, data) data = bsxfun(@times, data, ~where); 

Это решение включает в себя N*N*nk умножений (на 1 или 0), что явно слишком много, так как нам просто нужно положить в значения массива значения 100*100*nk . Однако эти умножения могут быть векторизованы так, что они «довольно быстрые» по сравнению с другими решениями Matlab: 0,23 с, то же самое, что и первое решение с использованием find !

Оба решения, предлагаемые Divakar, включают в себя создание большого массива размера N*N*nk . Нет петли Matlab, поэтому мы можем надеяться на лучшее исполнение, но …

 function dealiasing2d_bsxfun2(where, data) [n1, n0, nk] = size(data); idx = bsxfun(@plus, find(where), [0:nk-1]*n1*n0); data(idx) = 0; 

занимает 0,23 с (все тот же период времени, что и другие функции!) и

 function dealiasing2d(where, data) data(repmat(where,[1,1,size(data,3)])) = 0; 

составляет 0,30 с (~ 20% больше, чем другие решения Matlab).

В заключение, кажется, что в этом случае есть что-то, что ограничивает производительность Matlab. Также может быть, что в Matlab есть лучшее решение или что я делаю что-то не так с эталоном … Было бы здорово, если бы кто-то из Matlab и Python-Numpy мог предоставить другие тайминги.


Редактировать:

Еще несколько данных относительно комментария Дивакара:

При N = 500; nk = 500:

 Method | time (s) | normalized ----------------|----------|------------ Numpy | 0.05 | 1.0 Numpy loop | 0.05 | 1.0 Matlab bsxfun | 0.70 | 14.0 Matlab find | 0.75 | 15.0 Matlab bsxfun2 | 0.76 | 15.2 Matlab loop | 0.77 | 15.4 Matlab repmat | 0.96 | 19.2 

При N = 500; nk = 100:

 Method | time (s) | normalized ----------------|----------|------------ Numpy | 0.01 | 1.0 Numpy loop | 0.03 | 3.0 Matlab bsxfun | 0.14 | 12.7 Matlab find | 0.15 | 13.6 Matlab bsxfun2 | 0.16 | 14.5 Matlab loop | 0.16 | 14.5 Matlab repmat | 0.20 | 18.2 

Для N = 2000; nk = 10:

 Method | time (s) | normalized | ----------------|----------|------------| Numpy | 0.02 | 1.0 | Matlab find | 0.23 | 13.8 | Matlab bsxfun2 | 0.23 | 13.8 | Matlab bsxfun | 0.24 | 14.4 | Matlab repmat | 0.30 | 18.0 | Matlab loop | 0.34 | 20.4 | Numpy loop | 0.42 | 25.1 | 

Я действительно удивляюсь, почему Matlab кажется настолько медленным по сравнению с Numpy …

Interesting Posts