How to overlay images with alpha blending using AVX512 instructions?


I have two images A and B that are stored as byte arrays of ARGB data:

Image A: [a0, r0, g0, b0, a1, r1, g1, b1, ...]
Image B: [a0, r0, g0, b0, a1, r1, g1, b1, ...]

I would like to overlay image B on top of A using the alpha blending formula.

How can I achieve this with AVX512 instructions that operate on multiple pixels at a time?

I don't mind using 256 instead of 255 in the calculations if that makes things simpler.
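
For reference, the per-channel calculation I have in mind (with the alpha value taken from image B's pixel, and A_channel/B_channel standing in for any of the r/g/b values) is:

result = (B_channel * alpha + A_channel * (255 - alpha)) / 255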

Edit:

I tried implementing this based on another Stack Overflow answer. However, it seems to be slower than the non-AVX512 code that processes one pixel at a time. What am I doing wrong?

I also tried dropping lazy_static! (because I think it uses a locking data structure) and passing the constants into the function instead, but it was still slower. Is this just not a good problem to solve with AVX512? It seems like it should be.

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
unsafe fn overlay_row_avx512(this_chunk: &mut [u8], image_chunk: &[u8]) {
    use std::arch::x86_64::*;

    let this_ptr = this_chunk.as_mut_ptr() as *mut i8;
    let image_ptr = image_chunk.as_ptr() as *const i8;

    let this_argb = _mm512_loadu_epi8(this_ptr);
    let image_argb = _mm512_loadu_epi8(image_ptr);

    // Pick out the upper 256-bits and calculate inv_alpha.
    let this_upper = _mm512_shuffle_epi8(this_argb, *UPPER_TO_U16);
    let image_upper = _mm512_shuffle_epi8(image_argb, *UPPER_TO_U16);
    let alpha_upper = _mm512_shuffle_epi8(image_argb, *UPPER_ALPHA_TO_U16);
    let inv_alpha_upper = _mm512_subs_epu8(*U8_MAX_VALUE, alpha_upper);

    // Apply the blend function and store the result in blended_upper_u8.
    let this_blended_upper = _mm512_mullo_epi16(this_upper, inv_alpha_upper);
    let image_blended_upper = _mm512_mullo_epi16(image_upper, alpha_upper); // TODO: premultiply alpha
    let blended_upper = _mm512_add_epi16(this_blended_upper, image_blended_upper);
    let blended_upper_u8 = _mm512_shuffle_epi8(blended_upper, *UPPER_U16_TO_U8);

    // Repeat for the lower 256-bits.
    let this_lower = _mm512_shuffle_epi8(this_argb, *LOWER_TO_U16);
    let image_lower = _mm512_shuffle_epi8(image_argb, *LOWER_TO_U16);
    let alpha_lower = _mm512_shuffle_epi8(image_argb, *LOWER_ALPHA_TO_U16);
    let inv_alpha_lower = _mm512_subs_epu8(*U8_MAX_VALUE, alpha_lower);

    let this_blended_lower = _mm512_mullo_epi16(this_lower, inv_alpha_lower);
    let image_blended_lower = _mm512_mullo_epi16(image_lower, alpha_lower); // TODO: premultiply alpha
    let blended_lower = _mm512_add_epi16(this_blended_lower, image_blended_lower);
    let blended_lower_u8 = _mm512_add_epi16(blended_lower, *LOWER_U16_TO_U8);

    // OR together the upper and lower 256-bits.
    let blended = _mm512_or_si512(blended_upper_u8, blended_lower_u8);

    _mm512_storeu_epi8(this_ptr, blended);
}

const X: i8 = -1; // a -1 index tells _mm512_shuffle_epi8 to zero the destination byte

lazy_static! {
    static ref U8_MAX_VALUE: __m512i = unsafe { _mm512_set1_epi8(-1) };

    static ref UPPER_TO_U16: __m512i = unsafe {
        _mm512_set_epi8(
            X, 63, X, 62, X, 61, X, 60, X, 59, X, 58, X, 57, X, 56,
            X, 55, X, 54, X, 53, X, 52, X, 51, X, 50, X, 49, X, 48,
            X, 47, X, 46, X, 45, X, 44, X, 43, X, 42, X, 41, X, 40,
            X, 39, X, 38, X, 37, X, 36, X, 35, X, 34, X, 33, X, 32,
        )
    };

    static ref LOWER_TO_U16: __m512i = unsafe {
        _mm512_set_epi8(
            X, 31, X, 30, X, 29, X, 28, X, 27, X, 26, X, 25, X, 24,
            X, 23, X, 22, X, 21, X, 20, X, 19, X, 18, X, 17, X, 16,
            X, 15, X, 14, X, 13, X, 12, X, 11, X, 10, X,  9, X,  8,
            X,  7, X,  6, X,  5, X,  4, X,  3, X,  2, X,  1, X,  0,
        )
    };

    static ref UPPER_ALPHA_TO_U16: __m512i = unsafe {
        _mm512_set_epi8(
            X, 63, X, 63, X, 63, X, 63, X, 59, X, 59, X, 59, X, 59,
            X, 55, X, 55, X, 55, X, 55, X, 51, X, 51, X, 51, X, 51,
            X, 47, X, 47, X, 47, X, 47, X, 43, X, 43, X, 43, X, 43,
            X, 39, X, 39, X, 39, X, 39, X, 35, X, 35, X, 35, X, 35,
        )
    };

    static ref LOWER_ALPHA_TO_U16: __m512i = unsafe {
        _mm512_set_epi8(
            X, 31, X, 31, X, 31, X, 31, X, 27, X, 27, X, 27, X, 27,
            X, 23, X, 23, X, 23, X, 23, X, 19, X, 19, X, 19, X, 19,
            X, 15, X, 15, X, 15, X, 15, X, 11, X, 11, X, 11, X, 11,
            X,  7, X,  7, X,  7, X,  7, X,  3, X,  3, X,  3, X,  3,
        )
    };

    // Pick out the upper 8-bits of each 16-bit u16.
    // This effectively divides by 256.
    static ref UPPER_U16_TO_U8: __m512i = unsafe {
        _mm512_set_epi8(
            63, X, 62, X, 61, X, 60, X, 59, X, 58, X, 57, X, 56, X,
            55, X, 54, X, 53, X, 52, X, 51, X, 50, X, 49, X, 48, X,
            47, X, 46, X, 45, X, 44, X, 43, X, 42, X, 41, X, 40, X,
            39, X, 38, X, 37, X, 36, X, 35, X, 34, X, 33, X, 32, X,
        )
    };

    static ref LOWER_U16_TO_U8: __m512i = unsafe {
        _mm512_set_epi8(
            31, X, 30, X, 29, X, 28, X, 27, X, 26, X, 25, X, 24, X,
            23, X, 22, X, 21, X, 20, X, 19, X, 18, X, 17, X, 16, X,
            15, X, 14, X, 13, X, 12, X, 11, X, 10, X,  9, X,  8, X,
             7, X,  6, X,  5, X,  4, X,  3, X,  2, X,  1, X,  0, X,
        )
    };
}

For comparison, here's my code that runs one pixel at a time:

// A chunk is just 4 bytes in this case rather than 64 bytes.
fn overlay_row_without_simd(this_chunk: &mut [u8], image_chunk: &[u8]) {
    let alpha = image_chunk[0] as u32;
    let inv_alpha = 255 - alpha;

    this_chunk[1] = ((this_chunk[1] as u32 * inv_alpha + image_chunk[1] as u32 * alpha) / 255) as u8;
    this_chunk[2] = ((this_chunk[2] as u32 * inv_alpha + image_chunk[2] as u32 * alpha) / 255) as u8;
    this_chunk[3] = ((this_chunk[3] as u32 * inv_alpha + image_chunk[3] as u32 * alpha) / 255) as u8;
}

1 Answer

Answered by Chris:

I managed to figure out what was going on!

Basically, the implementation in my question is wrong and ends up generating a video full of random-looking pixels. In my case I'm piping each frame to ffmpeg, which had a really difficult time compressing all of that random color data. That, not anything to do with the AVX512 code, is why my program ran twice as slow.

I spent a long time figuring out how to do this properly in AVX512 and fixing my code. Benchmarking the function directly, it runs 5.38x faster than the one-pixel-at-a-time code. I intentionally wrote it so that it only relies on the avx512f and avx512bw features, for better CPU compatibility. It might be possible to save a couple of instructions by using _mm512_permutexvar_epi8, but that requires avx512vbmi.

My working implementation is below and processes 8 pixels (32 bytes of ARGB data) at a time:

use std::arch::x86_64::*;

unsafe fn overlay_chunk(this_chunk: &mut [u8], image_chunk: &[u8], c: &AVX512Constants) {
    let this_ptr = this_chunk.as_mut_ptr() as *mut i8;
    let image_ptr = image_chunk.as_ptr() as *const i8;

    let this_argb = _mm256_loadu_epi8(this_ptr);
    let image_argb = _mm256_loadu_epi8(image_ptr);

    // Extend each 8-bit integer into a 16-bit integer (zero filled).
    let this_u16 = _mm512_cvtepu8_epi16(this_argb);
    let image_u16 = _mm512_cvtepu8_epi16(image_argb);

    // Copy the alpha channel over each rgb channel.
    let image_alpha = _mm512_shuffle_epi8(image_u16, c.copy_alpha_to_rgb);

    // Calculate (255 - alpha) and set each u16 alpha value to 256.
    // We shift right by 8 bits later and 256 >> 8 equals 1.
    let image_inv_alpha = _mm512_sub_epi8(c.inv_alpha_minuend, image_alpha);

    // Apply the alpha blending formula (https://graphics.fandom.com/wiki/Alpha_blending).
    let this_blended = _mm512_mullo_epi16(this_u16, image_inv_alpha);
    let image_blended = _mm512_mullo_epi16(image_u16, image_alpha); // TODO: premultiply alpha

    let blended = _mm512_add_epi16(this_blended, image_blended);

    // Shift the u16 values right by 8 bits which divides by 256. We should
    // divide by 255 but this is faster and is close enough. The alpha value
    // of this_argb is preserved because of the 1 bits in the minuend.
    let divided = _mm512_srli_epi16(blended, 8);

    // Convert back to 8-bit integers, discarding the high bits that are zero.
    let divided_u8 = _mm512_cvtepi16_epi8(divided);

    _mm256_storeu_epi8(this_ptr, divided_u8);
}

struct AVX512Constants {
    copy_alpha_to_rgb: __m512i,
    inv_alpha_minuend: __m512i,
}

const X: i8 = -1;

impl AVX512Constants {
    fn new() -> Self {
        unsafe {
            Self {
                copy_alpha_to_rgb: _mm512_set_epi8(
                  X, 56, X, 56, X, 56, X, X, X, 48, X, 48, X, 48, X, X,
                  X, 40, X, 40, X, 40, X, X, X, 32, X, 32, X, 32, X, X,
                  X, 24, X, 24, X, 24, X, X, X, 16, X, 16, X, 16, X, X,
                  X, 8,  X, 8,  X, 8,  X, X, X, 0,  X, 0,  X, 0,  X, X, // right to left
                                                            //    v  v
                                                            // high  low
                ),
                inv_alpha_minuend: _mm512_set_epi8(
                    0, -1, 0, -1, 0, -1, 1, 0, 0, -1, 0, -1, 0, -1, 1, 0,
                    0, -1, 0, -1, 0, -1, 1, 0, 0, -1, 0, -1, 0, -1, 1, 0,
                    0, -1, 0, -1, 0, -1, 1, 0, 0, -1, 0, -1, 0, -1, 1, 0,
                    0, -1, 0, -1, 0, -1, 1, 0, 0, -1, 0, -1, 0, -1, 1, 0, // right to left
                                                              //    v  v
                                                              // high  low
                ),
            }
        }
    }
}
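
One way to drive this over a full row, shown as a sketch with hypothetical names (and assuming the caller has already verified avx512f/avx512bw support, e.g. with is_x86_feature_detected!): process 32-byte chunks with the AVX512 path and fall back to the per-pixel version from the question for any leftover pixels.

fn overlay_row(this_row: &mut [u8], image_row: &[u8], c: &AVX512Constants) {
    // Hypothetical driver: blend 32-byte (8 pixel) chunks with the AVX512 path.
    let mut this_chunks = this_row.chunks_exact_mut(32);
    let mut image_chunks = image_row.chunks_exact(32);

    for (this_chunk, image_chunk) in (&mut this_chunks).zip(&mut image_chunks) {
        // Safety: the caller checked that avx512f and avx512bw are available.
        unsafe { overlay_chunk(this_chunk, image_chunk, c) };
    }

    // Blend any remaining pixels (fewer than 8) one at a time.
    let this_rest = this_chunks.into_remainder();
    let image_rest = image_chunks.remainder();

    for (this_chunk, image_chunk) in this_rest.chunks_exact_mut(4).zip(image_rest.chunks_exact(4)) {
        overlay_row_without_simd(this_chunk, image_chunk);
    }
}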