There's another dot product... yes, yet another dot product. It's almost as though the instruction set was designed around them.
The integer multiply instructions in SSE2 are a little sparse. A full four-lane 32bit*32bit=32bit multiply can only be constructed from two _mm_mul_epu32 instructions plus some shuffling, because _mm_mul_epu32 only multiplies the even-numbered 32-bit lanes. It isn't until SSE4.1 and _mm_mullo_epi32 that this seemingly fundamental operation is directly supported.
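For comparison, here is a minimal sketch of that SSE2 route. It assumes an x86 target with SSE2 and the standard <emmintrin.h> intrinsics. The names mullo_epi32_sse2 and mullo4 are hypothetical. The result should match what _mm_mullo_epi32 returns, because the low 32 bits of a product are the same whether the inputs are treated as signed or unsigned.

```cpp
#include <emmintrin.h> // SSE2 intrinsics
#include <cstdint>

// Hypothetical helper: low 32 bits of a*b in each of the four 32-bit lanes,
// built from SSE2 only (what SSE4.1's _mm_mullo_epi32 does in one instruction).
static inline __m128i mullo_epi32_sse2(__m128i a, __m128i b)
{
    // _mm_mul_epu32 multiplies lanes 0 and 2, giving two 64-bit products.
    __m128i prod02 = _mm_mul_epu32(a, b);
    // Shift lanes 1 and 3 down into the even positions, then multiply them.
    __m128i prod13 = _mm_mul_epu32(_mm_srli_si128(a, 4), _mm_srli_si128(b, 4));
    // Keep the low 32 bits of each 64-bit product (32-bit elements 0 and 2),
    // moving them into positions 0 and 1.
    __m128i lo02 = _mm_shuffle_epi32(prod02, _MM_SHUFFLE(0, 0, 2, 0));
    __m128i lo13 = _mm_shuffle_epi32(prod13, _MM_SHUFFLE(0, 0, 2, 0));
    // Interleave back into lane order 0, 1, 2, 3.
    return _mm_unpacklo_epi32(lo02, lo13);
}

// Hypothetical convenience wrapper: out[i] = low 32 bits of a[i] * b[i].
static inline void mullo4(const int32_t *a, const int32_t *b, int32_t *out)
{
    __m128i va = _mm_loadu_si128((const __m128i *)a);
    __m128i vb = _mm_loadu_si128((const __m128i *)b);
    _mm_storeu_si128((__m128i *)out, mullo_epi32_sse2(va, vb));
}
```

Two multiplies, two byte shifts, two shuffles and an unpack stand in for a single SSE4.1 instruction.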
A single 32bit*32bit=64bit multiply can be composed of four 16bit*16bit=32bit multiplies. _mm_mul_epu32 performs two of these 64-bit multiplies, so it does the equivalent of eight 16-bit multiplies. That same eight-multiply budget is exposed in a couple of different ways.
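A scalar sketch of that decomposition (schoolbook multiplication on 16-bit halves) may make the counting concrete. mul32x32_via16 is a hypothetical helper and only covers the unsigned case.

```cpp
#include <cstdint>

// Hypothetical helper: unsigned 32bit*32bit=64bit product from four
// 16bit*16bit=32bit partial products.
static inline uint64_t mul32x32_via16(uint32_t a, uint32_t b)
{
    uint32_t aLo = a & 0xFFFFu, aHi = a >> 16;
    uint32_t bLo = b & 0xFFFFu, bHi = b >> 16;

    // Each partial product fits in 32 bits, since (2^16 - 1)^2 < 2^32.
    uint32_t ll = aLo * bLo;
    uint32_t lh = aLo * bHi;
    uint32_t hl = aHi * bLo;
    uint32_t hh = aHi * bHi;

    // Recombine: a*b = hh*2^32 + (lh + hl)*2^16 + ll, summed in 64 bits.
    return ((uint64_t)hh << 32) + ((uint64_t)lh << 16) + ((uint64_t)hl << 16) + ll;
}
```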
The _mm_madd_epi16 instruction seems directly designed for dot product. It is documented as "Multiply packed signed 16-bit integers in a and b, producing intermediate signed 32-bit integers. Horizontally add adjacent pairs of intermediate 32-bit integers, and pack the results in dst."
Internally, this requires eight 16bit*16bit=32bit multiplies, the same count as a single _mm_mul_epu32, but arranged for a dot product.
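To pin down exactly what _mm_madd_epi16 computes, here is a scalar model next to the intrinsic. Both helper names are hypothetical. One input worth knowing about: if all four 16-bit values feeding one output lane are -32768, the sum is 2^31, which wraps to INT32_MIN. The model reproduces that wrap.

```cpp
#include <emmintrin.h> // SSE2 intrinsics
#include <cstdint>

// Hypothetical scalar model of _mm_madd_epi16:
// out[i] = a[2i]*b[2i] + a[2i+1]*b[2i+1], for i = 0..3.
static inline void madd_epi16_reference(const int16_t a[8], const int16_t b[8], int32_t out[4])
{
    for (int i = 0; i < 4; ++i)
    {
        // Each 16bit*16bit product fits in 32 bits (at most 2^30 in magnitude).
        int32_t p0 = (int32_t)a[2 * i] * b[2 * i];
        int32_t p1 = (int32_t)a[2 * i + 1] * b[2 * i + 1];
        // Add as unsigned so the single overflowing case wraps like the hardware.
        out[i] = (int32_t)((uint32_t)p0 + (uint32_t)p1);
    }
}

// The same computation using the real instruction, for comparison.
static inline void madd_epi16_sse2(const int16_t a[8], const int16_t b[8], int32_t out[4])
{
    __m128i va = _mm_loadu_si128((const __m128i *)a);
    __m128i vb = _mm_loadu_si128((const __m128i *)b);
    _mm_storeu_si128((__m128i *)out, _mm_madd_epi16(va, vb));
}
```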
The code for a dot product using this instruction should look familiar:
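```cpp
#include <windows.h> // UINT (Win32 typedef); __forceinline below is an MSVC keyword

// Scalar version; also used for the leftover elements in the SSE2 version.
int dotProduct_Ansi(const signed short *pA, const signed short *pB, UINT cElements)
{
    int iRet = 0;
    UINT cElementsRemaining = cElements;
    while (cElementsRemaining > 0)
    {
        cElementsRemaining--;
        // Each 16bit*16bit product fits in a 32-bit int.
        iRet += ((int)pA[cElementsRemaining]) * ((int)pB[cElementsRemaining]);
    }
    return iRet;
}
```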
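```cpp
#include <emmintrin.h> // SSE2 intrinsics

// Sum the four 32-bit lanes [A, B, C, D] of a vector.
__forceinline int horizontalSum_SSE2(const __m128i &mABCD)
{
    __m128i mCD = _mm_srli_si128(mABCD, 8);          // [C, D, 0, 0]
    __m128i mApCBpD = _mm_add_epi32(mABCD, mCD);     // [A+C, B+D, C, D]
    __m128i mBpD = _mm_srli_si128(mApCBpD, 4);       // [B+D, C, D, 0]
    __m128i mApBpCpD = _mm_add_epi32(mApCBpD, mBpD); // lane 0 = A+B+C+D
    return _mm_cvtsi128_si32(mApBpCpD);
}
```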
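The shift-based reduction is one option. A common alternative, swapped in here purely for comparison, uses _mm_shuffle_epi32 so that every lane ends up holding the total. horizontalSum_shuffle is a hypothetical name, and the sketch uses plain static inline rather than MSVC's __forceinline so it compiles elsewhere too.

```cpp
#include <emmintrin.h> // SSE2 intrinsics

// Hypothetical alternative: same two adds, but shuffles instead of byte shifts.
static inline int horizontalSum_shuffle(__m128i mABCD)
{
    // [C, D, A, B]
    __m128i mCDAB = _mm_shuffle_epi32(mABCD, _MM_SHUFFLE(1, 0, 3, 2));
    __m128i mSum1 = _mm_add_epi32(mABCD, mCDAB); // [A+C, B+D, C+A, D+B]
    // [B+D, A+C, D+B, C+A]
    __m128i mSwap = _mm_shuffle_epi32(mSum1, _MM_SHUFFLE(2, 3, 0, 1));
    __m128i mSum2 = _mm_add_epi32(mSum1, mSwap); // every lane = A+B+C+D
    return _mm_cvtsi128_si32(mSum2);
}
```

Only lane 0 is read at the end, so whether the upper lanes hold zeros or copies of the total makes no difference to the result.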
```cpp
int dotProduct_SSE2(const signed short *pA, const signed short *pB, UINT cElements)
{
    // Handle the last (cElements % 8) elements with the scalar code.
    UINT cElements_endOfEight = cElements & ~7;
    int iRet = dotProduct_Ansi(pA + cElements_endOfEight, pB + cElements_endOfEight, cElements & 7);
    UINT cElementsRemaining = cElements_endOfEight;
    if (cElementsRemaining > 0)
    {
        // Four 32-bit partial sums, one per lane.
        __m128i mSummedAB = _mm_setzero_si128();
        do
        {
            cElementsRemaining -= 8;
            // Load eight 16-bit values from each array (no alignment required).
            __m128i mA = _mm_loadu_si128((__m128i const*)&pA[cElementsRemaining]);
            __m128i mB = _mm_loadu_si128((__m128i const*)&pB[cElementsRemaining]);
            // Eight multiplies, added in adjacent pairs into four 32-bit lanes.
            mSummedAB = _mm_add_epi32(mSummedAB, _mm_madd_epi16(mA, mB));
        } while (cElementsRemaining > 0);
        iRet += horizontalSum_SSE2(mSummedAB);
    }
    return iRet;
}
```
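One thing to keep in mind is that the partial sums above live in 32-bit lanes. With inputs near ±32767, each _mm_madd_epi16 lane is already close to 2^31, so as few as two iterations can wrap a lane. The following is a minimal sketch, under the assumption that a 64-bit result is acceptable, of widening each madd result to 64 bits before accumulating. dotProduct_SSE2_wide is a hypothetical name; it walks forward rather than backward and uses standard int16_t/size_t types instead of UINT.

```cpp
#include <emmintrin.h> // SSE2 intrinsics
#include <cstdint>
#include <cstddef>

// Hypothetical variant: same madd core, but partial sums widened to int64.
// The per-pair madd result is still 32-bit, so its single wrap case
// (all four inputs of a lane equal to -32768) is not addressed here.
static int64_t dotProduct_SSE2_wide(const int16_t *pA, const int16_t *pB, size_t cElements)
{
    size_t cEnd = cElements & ~(size_t)7;
    int64_t iRet = 0;
    // Scalar tail for the last (cElements % 8) elements.
    for (size_t i = cEnd; i < cElements; ++i)
        iRet += (int32_t)pA[i] * pB[i];

    __m128i mSum01 = _mm_setzero_si128(); // int64 sums of madd lanes 0 and 1
    __m128i mSum23 = _mm_setzero_si128(); // int64 sums of madd lanes 2 and 3
    for (size_t i = 0; i < cEnd; i += 8)
    {
        __m128i mA = _mm_loadu_si128((const __m128i *)&pA[i]);
        __m128i mB = _mm_loadu_si128((const __m128i *)&pB[i]);
        __m128i mProd = _mm_madd_epi16(mA, mB);
        // SSE2 has no 32->64 sign extension; pair each lane with its sign mask.
        __m128i mSign = _mm_srai_epi32(mProd, 31);
        mSum01 = _mm_add_epi64(mSum01, _mm_unpacklo_epi32(mProd, mSign));
        mSum23 = _mm_add_epi64(mSum23, _mm_unpackhi_epi32(mProd, mSign));
    }
    // Reduce the four int64 lanes.
    int64_t lanes[4];
    _mm_storeu_si128((__m128i *)&lanes[0], mSum01);
    _mm_storeu_si128((__m128i *)&lanes[2], mSum23);
    return iRet + lanes[0] + lanes[1] + lanes[2] + lanes[3];
}
```

The widening costs a shift, two unpacks and two 64-bit adds per iteration. Flushing the 32-bit accumulator into 64 bits only every few iterations would be a cheaper design, at the price of reasoning carefully about the input range.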
16-bit dot products come up frequently with artificial neural networks.