Python/Numpy is 10x slower than MATLAB. Advice on how to speed up?


I have a script that requires multiple loops. The Python version takes ~88 seconds while the MATLAB version takes ~9 seconds. I am already using vectorization to remove one of the loops.

MATLAB:

ntotal = 2000;
nj = 4000;
r = zeros(ntotal,nj,3);
ncorr = 200;
temp2 = zeros(ncorr,nj);
final = zeros(3,ncorr);


tic
for k = 1:3
    for j = 1:nj
        for n0 = 1:ncorr
            temp1=(r(1:ncorr-n0,j,k)-r(n0:ncorr-1,j,k)).^2;
            temp2(n0,j) = mean(temp1);
        end
    end
    final(k,:) = mean(temp2,2);
end
toc

Python:

import time
import numpy as np

ntotal = 2000
nj = 4000
r = np.zeros((ntotal,nj,3))
ncorr = 200
temp2 = np.zeros((ncorr,nj))
final = np.zeros((3,ncorr))

t0 = time.time()

for k in range(3):
    for j in range(nj):
        for n0 in range(ncorr):
            temp1 = (r[0:ntotal-n0,j,k]-r[n0:ntotal,j,k]) ** 2
            temp2[n0,j] = np.mean(temp1)
    final[k,:] = np.mean(temp2,axis=1)

t1 = time.time()
print(t1-t0)

When I increased nj to 10000, the timings became 21 (MATLAB) vs 248 seconds (Python). Any advice on how to make the Python script run faster? Thanks.


There are 3 answers

Answer from Maksym Machkovskiy:

Try Cython. It's a statically typed, compiled superset of Python, and it really makes code faster, particularly with nested loops. I think you could get a better result than with MATLAB.
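As a rough sketch of what that might look like (untested, not benchmarked; the function name `compute` and the explicit inner accumulation loop are my own choices, replacing `np.mean` with a manual sum so the loop compiles to pure C):

```cython
# msd.pyx -- build with cythonize; a sketch, assuming the shapes from the question
import numpy as np
cimport cython

@cython.boundscheck(False)
@cython.wraparound(False)
def compute(double[:, :, :] r, int ntotal, int nj, int ncorr):
    cdef double[:, :] final = np.zeros((3, ncorr))
    cdef double[:, :] temp2 = np.zeros((ncorr, nj))
    cdef double acc
    cdef Py_ssize_t k, j, n0, i
    for k in range(3):
        for j in range(nj):
            for n0 in range(ncorr):
                # Manual mean of the squared differences, so no Python
                # objects are touched inside the hot loop
                acc = 0.0
                for i in range(ntotal - n0):
                    acc += (r[i, j, k] - r[i + n0, j, k]) ** 2
                temp2[n0, j] = acc / (ntotal - n0)
        for n0 in range(ncorr):
            acc = 0.0
            for j in range(nj):
                acc += temp2[n0, j]
            final[k, n0] = acc / nj
    return np.asarray(final)
```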

Answer from hpaulj:

I don't think you need to explicitly loop on j and k. Here's a tentative rewrite:

In [85]: temp2 = np.zeros((ncorr,nj,3))
    ...: for n0 in range(ncorr):
    ...:     temp1 = (r[0:ntotal-n0,:,:]-r[n0:ntotal,:,:]) ** 2
    ...:     temp2[n0,:] = np.mean(temp1, axis=0)
    ...: final  = temp2.mean(axis=1)

In [86]: final.shape
Out[86]: (200, 3)

This appears to be much faster; however, I did not let your version run to completion, since it was taking too long on my small computer. Also, since your r is all zeros, there isn't much point in verifying the final mean values.

It may be possible to rework the n0 loop, but since that is only 200, while nj and ntotal are 4000 and 2000, I wouldn't expect as significant a speed gain.
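One way around the all-zeros problem is to fill r with random numbers at small, made-up sizes and check the two versions against each other directly:

```python
import numpy as np

# Small sizes (chosen arbitrarily) so the slow triple loop finishes quickly
rng = np.random.default_rng(0)
ntotal, nj, ncorr = 50, 6, 10
r = rng.standard_normal((ntotal, nj, 3))

# Original triple loop from the question
final_orig = np.zeros((3, ncorr))
temp2 = np.zeros((ncorr, nj))
for k in range(3):
    for j in range(nj):
        for n0 in range(ncorr):
            temp1 = (r[0:ntotal - n0, j, k] - r[n0:ntotal, j, k]) ** 2
            temp2[n0, j] = np.mean(temp1)
    final_orig[k, :] = np.mean(temp2, axis=1)

# Vectorized rewrite: only the n0 loop remains
temp2v = np.zeros((ncorr, nj, 3))
for n0 in range(ncorr):
    temp1 = (r[0:ntotal - n0, :, :] - r[n0:ntotal, :, :]) ** 2
    temp2v[n0, :] = np.mean(temp1, axis=0)
final_vec = temp2v.mean(axis=1)        # shape (ncorr, 3)

# The vectorized result is the transpose of the original (3, ncorr) layout
assert np.allclose(final_orig, final_vec.T)
print("match")
```

Note that the vectorized version comes out with shape (ncorr, 3), so it has to be transposed before comparing against the question's (3, ncorr) layout.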

Answer from Nick ODell:

Given the nature of the problem, I would suggest using an appropriate memory ordering, and using Numba.

To begin with, I put your code into a function, so that I had something to compare against.

def function_orig(ntotal, nj, r, ncorr, temp2):
    final = np.zeros((3,ncorr))
    for k in range(3):
        for j in range(nj):
            for n0 in range(ncorr):
                temp1 = (r[0:ntotal-n0,j,k]-r[n0:ntotal,j,k]) ** 2
                temp2[n0,j] = np.mean(temp1)
        final[k,:] = np.mean(temp2,axis=1)
    return final

(I couldn't tell if temp2 is intended to be a parameter here. It probably ought to be a local variable, but I wasn't sure.)

The first thing to notice is that the expression r[0:ntotal-n0,j,k] selects many elements on the first axis and a single element on each remaining axis. This access pattern is efficient on column-major arrays but inefficient on row-major arrays. NumPy uses row-major order by default; MATLAB uses column-major. For more background, see the Wikipedia article on row- and column-major order.

If you switch to column-major for this one array, it's 2x faster.

def function_colmajor(ntotal, nj, r, ncorr, temp2):
    r = np.asfortranarray(r)
    final = np.zeros((3,ncorr))
    for k in range(3):
        for j in range(nj):
            for n0 in range(ncorr):
                temp1 = (r[0:ntotal-n0,j,k]-r[n0:ntotal,j,k]) ** 2
                temp2[n0,j] = np.mean(temp1)
        final[k,:] = np.mean(temp2,axis=1)
    return final
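The effect can be seen directly in the arrays' strides (using a small made-up shape here, just to illustrate; the question's r is (2000, 4000, 3)):

```python
import numpy as np

r_c = np.zeros((200, 400, 3))           # NumPy default: row-major (C order)
r_f = np.asfortranarray(r_c)            # column-major copy, like MATLAB

print(r_c.flags['C_CONTIGUOUS'])        # True
print(r_f.flags['F_CONTIGUOUS'])        # True

# Stepping along axis 0 (what r[0:ntotal-n0, j, k] does) jumps
# 400 * 3 * 8 = 9600 bytes per step in C order, but only 8 bytes
# (one float64) in Fortran order -- hence the better cache behavior.
print(r_c.strides[0], r_f.strides[0])   # 9600 8
```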

Next, I applied Numba, a JIT compiler that is frequently helpful for accelerating NumPy-heavy loops. I also turned on fastmath=True; see the Numba performance tips for an explanation of fastmath: https://numba.pydata.org/numba-doc/latest/user/performance-tips.html#fastmath

import numba as nb


@nb.njit(fastmath=True)
def function_numba(ntotal, nj, r, ncorr, temp2):
    r = np.asfortranarray(r)
    final = np.zeros((3,ncorr))
    for k in range(3):
        for j in range(nj):
            for n0 in range(ncorr):
                temp1 = (r[0:ntotal-n0,j,k]-r[n0:ntotal,j,k]) ** 2
                temp2[n0,j] = np.mean(temp1)
        final[k,:] = mean_along_axis_1(temp2)
    return final


@nb.njit(fastmath=True)
def mean_along_axis_1(array):
    assert len(array.shape) == 2, "array not 2d"
    output = np.zeros(array.shape[0])
    for i in range(array.shape[0]):
        output[i] = array[i].mean()
    return output

Timings on my hardware:

Original
1min 21s ± 1.46 s per loop (mean ± std. dev. of 3 runs, 1 loop each)
Column major
41.4 s ± 217 ms per loop (mean ± std. dev. of 3 runs, 1 loop each)
Numbaized
2.51 s ± 16.2 ms per loop (mean ± std. dev. of 3 runs, 1 loop each)