AVX 512 matrix multiplication with column-wise traversal on B


I wrote a single-precision floating-point matrix multiplication using AVX-512 intrinsics:

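for (int i=200; i<400; i++) {
    for (int k=1200; k<1400; k++) {
        // val is stored column-major inside its block: element (i, k)
        tmp = val[440000 + (k-1200)*200 + (i-200)];
        sv = _mm512_set1_ps(tmp);              // broadcast the scalar to all 16 lanes
        for (int j=0; j<512/16; j++) {         // 32 vectors of 16 floats per row
            cv = _mm512_loadu_ps(&y[i*512 + 16*j]);
            bv = _mm512_loadu_ps(&x[k*512 + 16*j]);
            cv = _mm512_fmadd_ps(sv, bv, cv);  // cv = sv * bv + cv
            _mm512_storeu_ps(&y[i*512 + 16*j], cv);
        }
    }
}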

Note the line tmp=val[440000+ (k-1200)*200 + (i-200)];. val is traversed in column-major order, which is why the k offset is multiplied by 200, the length of one column (i and k both span 200 values here).

i, k and j have the usual meanings: i and k index the rows and columns of the first matrix (val), and k and j index the rows and columns of the second matrix (x). The result is accumulated into y.
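To pin down exactly what the intrinsics version is meant to compute, here is a scalar reference for one block. It is only a sketch: block_update_ref and its parameter list are hypothetical names made up for this post, not part of my program, and it assumes the stride between consecutive k values in val equals the number of rows in the block (200 in the loop above).

```c
#include <stddef.h>

/* Scalar reference for one block of the update:
 *   y[i][j] += val(i, k) * x[k][j]
 * where val(i, k) is read column-major from a block starting at val_base.
 * Hypothetical helper: the name and parameters are illustrative only. */
void block_update_ref(float *y, const float *x, const float *val,
                      size_t val_base,   /* e.g. 440000 */
                      int i0, int i1,    /* e.g. 200, 400 */
                      int k0, int k1,    /* e.g. 1200, 1400 */
                      int ncols)         /* e.g. 512 */
{
    const int nrows = i1 - i0;  /* column length = stride between k values */
    for (int i = i0; i < i1; i++) {
        for (int k = k0; k < k1; k++) {
            /* same index expression as in the intrinsics loop */
            const float a = val[val_base + (size_t)(k - k0) * nrows + (i - i0)];
            for (int j = 0; j < ncols; j++) {
                y[(size_t)i * ncols + j] += a * x[(size_t)k * ncols + j];
            }
        }
    }
}
```

Calling block_update_ref(y, x, val, 440000, 200, 400, 1200, 1400, 512) should correspond to the first loop nest, up to floating-point rounding differences from FMA and -ffast-math.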

I am currently compiling the program with gcc and the flags -O3 -march=native -mavx -mprefer-vector-width=512 -ffast-math.

I have also tried unrolling and jamming, which resulted in the following:

for (int i=200; i<400; i++) {
    for (int k=1400; k<1600; k+=2) {
        sv[0] = _mm512_set1_ps(val[520000 + (k-1400) * 200 + (i-200)]);
        sv[1] = _mm512_set1_ps(val[520000 + (k-1399) * 200 + (i-200)]);
        for (int j=0; j<512/16; j++){
            cv = _mm512_loadu_ps(&y[i*512 + 16*j]);
            bv[0] = _mm512_loadu_ps(&x[(k+0)*512 + 16*j]);
            cv = _mm512_fmadd_ps(sv[0], bv[0], cv);
            bv[1] = _mm512_loadu_ps(&x[(k+1)*512 + 16*j]);
            cv = _mm512_fmadd_ps(sv[1], bv[1], cv);
            _mm512_storeu_ps(&y[512*i+16*j], cv);
        }
    }
}

Is there a way to optimize these loops further, short of parallelizing them?

Could I trivially dispatch to cblas_sgemm?

Other things I have tried:

  1. Unrolling the j loop
  2. Tiling the j loop
  3. Transposing the val array, so that the access becomes val[440000 + (i-200)*200 + (k-1200)]
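To make item 3 concrete, the transposition I mean is roughly the following. This is a simplified sketch rather than my actual code: transpose_block and its parameters are hypothetical names, and it assumes a square n x n block (n = 200 here) copied out of place into a separate buffer.

```c
#include <stddef.h>

/* Copy an n x n block stored column-major (element (i, k) at src[k*n + i])
 * into row-major order (element (i, k) at dst[i*n + k]).
 * Hypothetical helper; src and dst must not overlap. */
void transpose_block(float *dst, const float *src, int n)
{
    for (int k = 0; k < n; k++) {
        for (int i = 0; i < n; i++) {
            dst[(size_t)i * n + k] = src[(size_t)k * n + i];
        }
    }
}
```

After transposing the block that starts at val[440000], consecutive k values for a fixed i are adjacent in memory, which is what the val[440000 + (i-200)*200 + (k-1200)] indexing relies on.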