Monday, January 16, 2017

transposing a non-square float matrix with SSE

The transpose of a matrix can be useful.  During a matrix multiply a transpose of the right hand matrix will allow the result to be expressed as a series of dot products, with the inputs of the dot product sequential in memory.

Often, though, the matrix will be non-square, and so cannot simply be transposed in place.  Instead, let's look at transposing from a source buffer into a separate destination buffer.

We will transpose in 4x4 blocks.  Momentarily ignoring whatever is left over when the width or height is not evenly divisible by 4:
    // Walk the matrix in 4x4 tiles, from the last whole tile back to the first.
    // nHeight&~3 and nWidth&~3 round each dimension down to a multiple of 4,
    // so leftover rows/columns beyond the last whole tile are not touched here.
    for (size_t y = nHeight&~3; y >0;)
    {
        // Decrement first: y is unsigned, so the loop tests y > 0 and then steps
        // down to the top row of the current tile.
        y -= 4;
        for (size_t x = nWidth&~3; x>0;)
        {
            // Same pattern for the left column of the current tile.
            x -= 4;
            __m128 row0;
            __m128 row1;
            __m128 row2;
            __m128 row3;
            // Load four consecutive input rows of the tile; the input is indexed
            // with a row stride of nWidth.  Unaligned loads are used.
            row0 = _mm_loadu_ps(&pMatrixIn[(y + 0)*nWidth + x]);
            row1 = _mm_loadu_ps(&pMatrixIn[(y + 1)*nWidth + x]);
            row2 = _mm_loadu_ps(&pMatrixIn[(y + 2)*nWidth + x]);
            row3 = _mm_loadu_ps(&pMatrixIn[(y + 3)*nWidth + x]);
            // Transpose the 4x4 tile in registers: afterwards rowN holds the
            // four values of input column x+N for rows y..y+3.
            _MM_TRANSPOSE4_PS(row0, row1, row2, row3);
            // Each input column becomes an output row; the output is indexed
            // with a row stride of nHeight.  Unaligned stores are used.
            _mm_storeu_ps(&pMatrixOut[(x + 0)*nHeight + y], row0);
            _mm_storeu_ps(&pMatrixOut[(x + 1)*nHeight + y], row1);
            _mm_storeu_ps(&pMatrixOut[(x + 2)*nHeight + y], row2);
            _mm_storeu_ps(&pMatrixOut[(x + 3)*nHeight + y], row3);
        }
    }


This seems pretty tight...but that's a lot of multiplies.  We can factor out most of the multiplies so that it's just pointer arithmetic and adds.  On x64 the central loop can be accomplished without multiplies and without register spills.

Most of the matrix will be transposed by this set of loops. We will have to handle the rest in two other rectangular strips (making sure not to double-count the corner rectangle).

 void TransposeNonSquareMatrix_SSE(const float *const pMatrixIn, float *const pMatrixOut, const size_t nWidth, const size_t nHeight)
{
    for (size_t y = nHeight; y & 3;)
    {
        y--;
        for (size_t x = nWidth; x-- > 0;)
        {
            pMatrixOut[x*nHeight + y] = pMatrixIn[y*nWidth + x];
        }
    }
    if (nWidth & 3)
    {
        for (size_t y = nHeight&~3; y-- > 0;)
        {
            for (size_t x = nWidth; x & 3;)
            {
                x--;
                pMatrixOut[x*nHeight + y] = pMatrixIn[y*nWidth + x];
            }
        }
    }
    const size_t nHeightMax = nHeight&~3;
    const size_t nWidthMax = nWidth&~3;
    for (size_t y = nHeightMax; y >0;)
    {
        y -= 4;
        const float *const pIn0 = pMatrixIn + y*nWidth;
        const float *const pIn1 = pIn0 + nWidth;
        const float *const pIn2 = pIn1 + nWidth;
        const float * const pIn3 = pIn2 + nWidth;
        float * pOut = &pMatrixOut[nWidthMax*nHeight + y];
        for (size_t x = nWidthMax; x>0;)
        {
            x -= 4;
            __m128 row0;
            __m128 row1;
            __m128 row2;
            __m128 row3;
            row0 = _mm_loadu_ps(&pIn0[x]);
            row1 = _mm_loadu_ps(&pIn1[x]);
            row2 = _mm_loadu_ps(&pIn2[x]);
            row3 = _mm_loadu_ps(&pIn3[x]);
            _MM_TRANSPOSE4_PS(row0, row1, row2, row3);
            _mm_storeu_ps((pOut -= nHeight), row3);
            _mm_storeu_ps((pOut -= nHeight), row2);
            _mm_storeu_ps((pOut -= nHeight), row1);
            _mm_storeu_ps((pOut -= nHeight), row0);
        }
    }
}


The assembly compilation of the central loop is pretty tight:
         {
            x -= 4;
            __m128 row0;
            __m128 row1;
            __m128 row2;
            __m128 row3;
            row0 = _mm_loadu_ps(&pIn0[x]);
            row1 = _mm_loadu_ps(&pIn1[x]);
000000013F401210  movups      xmm0,xmmword ptr [r10-10h] 
            row2 = _mm_loadu_ps(&pIn2[x]);
            row3 = _mm_loadu_ps(&pIn3[x]);
            _MM_TRANSPOSE4_PS(row0, row1, row2, row3);
            _mm_storeu_ps((pOut -= nHeight), row3);
000000013F401215  sub         rdx,rbp 
000000013F401218  lea         r10,[r10-10h] 
000000013F40121C  movups      xmm4,xmmword ptr [r10+r15] 
000000013F401221  movups      xmm3,xmmword ptr [r10+r11] 
000000013F401226  movups      xmm1,xmmword ptr [r10+rcx] 
000000013F40122B  movaps      xmm5,xmm4 
000000013F40122E  movaps      xmm2,xmm3 
000000013F401231  shufps      xmm4,xmm0,0EEh 
000000013F401235  shufps      xmm5,xmm0,44h 
            row2 = _mm_loadu_ps(&pIn2[x]);
            row3 = _mm_loadu_ps(&pIn3[x]);
            _MM_TRANSPOSE4_PS(row0, row1, row2, row3);
            _mm_storeu_ps((pOut -= nHeight), row3);
000000013F401239  movaps      xmm0,xmm4 
000000013F40123C  shufps      xmm3,xmm1,0EEh 
000000013F401240  shufps      xmm0,xmm3,0DDh 
000000013F401244  movups      xmmword ptr [rdx],xmm0 
            _mm_storeu_ps((pOut -= nHeight), row2);
000000013F401247  sub         rdx,rbp 
000000013F40124A  shufps      xmm2,xmm1,44h 
000000013F40124E  movaps      xmm0,xmm5 
000000013F401251  shufps      xmm4,xmm3,88h 
000000013F401255  shufps      xmm0,xmm2,0DDh 
000000013F401259  movups      xmmword ptr [rdx],xmm4 
            _mm_storeu_ps((pOut -= nHeight), row1);
000000013F40125C  sub         rdx,rbp 
000000013F40125F  shufps      xmm5,xmm2,88h 
000000013F401263  movups      xmmword ptr [rdx],xmm0 
            _mm_storeu_ps((pOut -= nHeight), row0);
000000013F401266  sub         rdx,rbp 
000000013F401269  movups      xmmword ptr [rdx],xmm5 
000000013F40126C  sub         rax,1 
000000013F401270  jne         TransposeNonSquareMatrix_SSE+1A0h (013F401210h) 

The compiler has split the loop variable into two, rax and r10.  The final shuffles from the transpose macro are interleaved between the unaligned stores.

No comments:

Post a Comment