Friday, December 16, 2016

another dot product - signed short

There's another dot product... yes, yet another dot product.  It's almost as though the instruction set was designed around them.

The integer multiply instructions in SSE2 are a little sparse.  A full four-lane 32bit*32bit=32bit multiply can only be constructed from multiple _mm_mul_epu32 instructions.  It wasn't until SSE4.1 added _mm_mullo_epi32 that this seemingly fundamental operation was directly supported.
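
As a rough sketch of that construction (the helper name mullo_epi32_SSE2 is made up for illustration and is not part of the intrinsics headers): two _mm_mul_epu32 calls cover the even and odd lanes, then a shuffle and an unpack keep the low 32 bits of each product.  The low 32 bits of a product are the same whether the inputs are read as signed or unsigned, so this should match what _mm_mullo_epi32 produces.

```cpp
#include <emmintrin.h>  // SSE2 intrinsics

// Hypothetical helper: four-lane 32bit*32bit=32bit multiply using only SSE2.
static inline __m128i mullo_epi32_SSE2(__m128i mA, __m128i mB)
{
    // _mm_mul_epu32 multiplies lanes 0 and 2, giving two 64-bit products.
    __m128i mEven = _mm_mul_epu32(mA, mB);
    // Shift lanes 1 and 3 down into positions 0 and 2 and multiply those too.
    __m128i mOdd = _mm_mul_epu32(_mm_srli_si128(mA, 4), _mm_srli_si128(mB, 4));
    // The low 32 bits of each 64-bit product sit in lanes 0 and 2;
    // move them to lanes 0 and 1.
    __m128i mEvenLo = _mm_shuffle_epi32(mEven, _MM_SHUFFLE(0, 0, 2, 0));
    __m128i mOddLo = _mm_shuffle_epi32(mOdd, _MM_SHUFFLE(0, 0, 2, 0));
    // Interleave back into the original lane order: p0, p1, p2, p3.
    return _mm_unpacklo_epi32(mEvenLo, mOddLo);
}
```

That is two multiplies, two byte shifts, two shuffles and an unpack for something SSE4.1 does in one instruction.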

A single 32bit * 32bit = 64bit multiply can be composed of four 16bit*16bit = 32bit multiplies.  The _mm_mul_epu32 instruction performs two of these 64-bit multiplies.  The equivalent of these eight 16-bit multiplies is exposed in a couple of different ways.
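
To make the "four multiplies" claim concrete, here is a small scalar sketch (the function name is invented for this example).  It splits each unsigned 32-bit operand into 16-bit halves, forms the four partial products, and adds them back together at the right bit offsets.  It only models the arithmetic, not how the hardware is actually wired.

```cpp
#include <cstdint>

// Hypothetical helper: unsigned 32bit*32bit=64bit multiply built from
// four 16bit*16bit=32bit partial products.
std::uint64_t mul32x32_from16(std::uint32_t a, std::uint32_t b)
{
    std::uint32_t aLo = a & 0xFFFF, aHi = a >> 16;
    std::uint32_t bLo = b & 0xFFFF, bHi = b >> 16;
    // Each partial product fits in 32 bits (at most 0xFFFF * 0xFFFF = 0xFFFE0001).
    std::uint32_t ll = aLo * bLo;
    std::uint32_t lh = aLo * bHi;
    std::uint32_t hl = aHi * bLo;
    std::uint32_t hh = aHi * bHi;
    // Recombine: the cross terms are weighted by 2^16, the high term by 2^32.
    return (std::uint64_t)ll
         + ((std::uint64_t)lh << 16)
         + ((std::uint64_t)hl << 16)
         + ((std::uint64_t)hh << 32);
}
```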

The _mm_madd_epi16 instruction seems directly designed for dot product.  It is documented as "Multiply packed signed 16-bit integers in a and b, producing intermediate signed 32-bit integers. Horizontally add adjacent pairs of intermediate 32-bit integers, and pack the results in dst."
Internally, this requires eight 16bit*16bit = 32bit multiplies.
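
A scalar model of that description may help.  This is only a sketch of the documented behaviour, with an invented function name and plain arrays standing in for the 128-bit registers.

```cpp
#include <cstdint>

// Hypothetical scalar model of _mm_madd_epi16: eight signed 16-bit products,
// adjacent pairs summed into four 32-bit results.
void madd_epi16_scalar(const std::int16_t a[8], const std::int16_t b[8], std::int32_t dst[4])
{
    for (int i = 0; i < 4; ++i)
    {
        std::int32_t lo = (std::int32_t)a[2 * i] * b[2 * i];
        std::int32_t hi = (std::int32_t)a[2 * i + 1] * b[2 * i + 1];
        // Add as unsigned so the one overflowing case (all four inputs of a
        // pair equal to -32768) wraps around instead of being undefined in C++.
        dst[i] = (std::int32_t)((std::uint32_t)lo + (std::uint32_t)hi);
    }
}
```

Each output lane is already a two-element dot product, which is why the instruction fits the loop so neatly.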

The code for a dot product using this instruction should look familiar by now:

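#include <emmintrin.h>  // SSE2 intrinsics

// Scalar reference version.  UINT is the Windows typedef for unsigned int.
int dotProduct_Ansi(const signed short *pA, const signed short *pB, UINT cElements)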
{
    int iRet = 0;
    UINT cElementsRemaining = cElements;
    while (cElementsRemaining > 0)
    {
        cElementsRemaining--;
        iRet += ((int)pA[cElementsRemaining]) * ((int)pB[cElementsRemaining]);
    }
    return iRet;
}

// Adds the four 32-bit lanes of mABCD together and returns the total.
__forceinline int horizontalSum_SSE2(const __m128i &mABCD)
{
    __m128i mCD = _mm_srli_si128(mABCD, 8);
    __m128i mApCBpD = _mm_add_epi32(mABCD, mCD);
    __m128i mBpD = _mm_srli_si128(mApCBpD, 4);
    __m128i mApBpCpD = _mm_add_epi32(mApCBpD, mBpD);
    return _mm_cvtsi128_si32(mApBpCpD);
}
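
// The 0-7 trailing elements go through the scalar loop; the rest is
// processed eight elements at a time, walking backwards from the end.
int dotProduct_SSE2(const signed short *pA, const signed short *pB, UINT cElements)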
{
    UINT cElements_endOfEight = cElements&~7;
    int iRet = dotProduct_Ansi(pA + cElements_endOfEight, pB + cElements_endOfEight, cElements & 7);
    UINT cElementsRemaining = cElements_endOfEight;
    if (cElementsRemaining > 0)
    {
        __m128i mSummedAB = _mm_setzero_si128();
        do
        {
            cElementsRemaining -= 8;
            __m128i mA = _mm_loadu_si128((__m128i const*)&pA[cElementsRemaining]);
            __m128i mB = _mm_loadu_si128((__m128i const*)&pB[cElementsRemaining]);
            mSummedAB = _mm_add_epi32(mSummedAB, _mm_madd_epi16(mA, mB));
        } while (cElementsRemaining > 0);
        iRet += horizontalSum_SSE2(mSummedAB);
    }
    return iRet;
}

16-bit dot products come up frequently with artificial neural networks.
