Multiplying two 32-bit numbers together gives a 64-bit result. SSE2 supports an unsigned 32-bit multiply yielding a 64-bit unsigned result (_mm_mul_epu32), but that isn't what we normally need for dot products.
The signed equivalent, _mm_mul_epi32 (a signed 32-bit multiply yielding a 64-bit signed result), was not supported until SSE4.1. While SSE4.1 is pretty standard today, that may not be the case on virtual machines, so there is value in understanding how to do this with SSE2 alone.
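For comparison, here is what those four products look like when SSE4.1 is available. This is just a sketch of mine, not code from the post; it assumes the same element layout as the SSE2 helper below and needs <smmintrin.h>:

// Sketch, assuming SSE4.1: _mm_mul_epi32 multiplies the signed 32-bit values
// in lanes 0 and 2 of each operand, giving two signed 64-bit products.
__forceinline void _mm_mul_epi32_SSE41(const __m128i &x, const __m128i &y, __m128i &mXY02, __m128i &mXY13)
{
    mXY02 = _mm_mul_epi32(x, y);                                         // x0*y0, x2*y2
    mXY13 = _mm_mul_epi32(_mm_srli_epi64(x, 32), _mm_srli_epi64(y, 32)); // x1*y1, x3*y3
}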
We can generate the signed 64-bit multiply result from the unsigned 64-bit multiply result. The procedure is detailed in Hacker's Delight, 2nd edition, in the chapter "Multiplication", section 8-3, "High-Order Product Signed from/to Unsigned".
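The idea: reinterpreting a negative 32-bit value as unsigned adds 2^32 to it, so the unsigned 64-bit product overshoots the signed one by y << 32 when x is negative and by x << 32 when y is negative (everything taken mod 2^64, which is exact because the true result fits in 64 bits). A scalar sketch of that correction, for illustration only (not code from the post):

// Illustration: signed 32x32->64 product recovered from the unsigned product.
__int64 mulSignedFromUnsigned(int x, int y)
{
    unsigned __int64 productU = (unsigned __int64)(unsigned int)x * (unsigned int)y;
    // A wrapping 32-bit add is fine; only the low 32 bits of the correction matter mod 2^64.
    unsigned int correction = (x < 0 ? (unsigned int)y : 0) + (y < 0 ? (unsigned int)x : 0);
    return (__int64)(productU - ((unsigned __int64)correction << 32));
}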
Overall, the four signed 32-bit * 32-bit = 64-bit multiplies can be performed as follows:
__forceinline void _mm_mul_epi32_SSE2(const __m128i &x, const __m128i &y, __m128i &mXY02, __m128i &mXY13)
{
    // Unsigned 32x32->64 products of elements 0,2 and (after shifting the odd elements down) 1,3.
    __m128i mxy02u = _mm_mul_epu32(x, y);
    __m128i mxy13u = _mm_mul_epu32(_mm_srli_epi64(x, 32), _mm_srli_epi64(y, 32));
    // Per-element correction terms: (x < 0 ? y : 0) + (y < 0 ? x : 0); wrapping is fine.
    __m128i mt1 = _mm_and_si128(y, _mm_srai_epi32(x, 31));
    __m128i mt2 = _mm_and_si128(x, _mm_srai_epi32(y, 31));
    __m128i mt1Pt2 = _mm_add_epi32(mt1, mt2);
    // Reorder the corrections to [c0, c2, c1, c3] so the unpacks place c0/c2 against the 0,2
    // products and c1/c3 against the 1,3 products, each in the high dword of its 64-bit lane.
    __m128i mtShuffled = _mm_shuffle_epi32(mt1Pt2, _MM_SHUFFLE(3, 1, 2, 0));
    mXY02 = _mm_sub_epi64(mxy02u, _mm_unpacklo_epi32(_mm_setzero_si128(), mtShuffled));
    mXY13 = _mm_sub_epi64(mxy13u, _mm_unpackhi_epi32(_mm_setzero_si128(), mtShuffled));
}
Please do note that the results end up in a non-standard order that doesn't matter for the dot product:
The first __m128i holds
__int64 xy00Int = ((__int64)x.m128i_i32[0])*((__int64)y.m128i_i32[0]);
__int64 xy22Int = ((__int64)x.m128i_i32[2])*((__int64)y.m128i_i32[2]);
while the second __m128i holds
__int64 xy11Int = ((__int64)x.m128i_i32[1])*((__int64)y.m128i_i32[1]);
__int64 xy33Int = ((__int64)x.m128i_i32[3])*((__int64)y.m128i_i32[3]);
This is easy to correct, but since we are just going to be summing, we don't bother.
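A quick way to convince yourself of the ordering is to store both halves and compare against the scalar products. This is a throwaway check of mine, not from the post:

#include <emmintrin.h>
#include <cstdio>

// Prints the four products in the order described above: x0*y0, x2*y2, x1*y1, x3*y3.
void check_mm_mul_epi32_SSE2_order()
{
    __m128i x = _mm_set_epi32(4, -3, 2, -1);  // lanes: x0=-1, x1=2, x2=-3, x3=4
    __m128i y = _mm_set_epi32(-8, 7, -6, 5);  // lanes: y0=5, y1=-6, y2=7, y3=-8
    __m128i mXY02, mXY13;
    _mm_mul_epi32_SSE2(x, y, mXY02, mXY13);
    __int64 r02[2], r13[2];
    _mm_storeu_si128((__m128i*)r02, mXY02);
    _mm_storeu_si128((__m128i*)r13, mXY13);
    printf("%lld %lld %lld %lld\n", r02[0], r02[1], r13[0], r13[1]); // expect -5 -21 -12 -32
}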
__forceinline __int64 dotProduct_int64result_Ansi(const int *pA, const int *pB, UINT cElements)
{
    __int64 iRet = 0;   // 64-bit accumulator so the products don't overflow
    UINT cElementsRemaining = cElements;
    while (cElementsRemaining > 0)
    {
        cElementsRemaining--;
        iRet += ((__int64)pA[cElementsRemaining]) * ((__int64)pB[cElementsRemaining]);
    }
    return iRet;
}
__int64 dotProduct_int64result_SSE2(const signed int *pA, const signed int *pB, UINT cElements)
{
    // Handle the 0-3 trailing elements that don't fill a whole __m128i with the ANSI version.
    UINT cElements_endOfFour = cElements & ~3;
    __int64 iRet = dotProduct_int64result_Ansi(pA + cElements_endOfFour, pB + cElements_endOfFour, cElements & 3);
    UINT_PTR cElementsRemaining = cElements_endOfFour;
    if (cElementsRemaining > 0)
    {
        __m128i mSummed = _mm_setzero_si128();
        do
        {
            // Walk backwards four elements at a time; loads are unaligned.
            cElementsRemaining -= 4;
            __m128i mA = _mm_loadu_si128((__m128i const*)&pA[cElementsRemaining]);
            __m128i mB = _mm_loadu_si128((__m128i const*)&pB[cElementsRemaining]);
            __m128i mulLeft, mulRight;
            _mm_mul_epi32_SSE2(mA, mB, mulLeft, mulRight);
            mSummed = _mm_add_epi64(mSummed, mulLeft);
            mSummed = _mm_add_epi64(mSummed, mulRight);
        } while (cElementsRemaining > 0);
        iRet += horizontalSum_epi64_SSE2(mSummed);
    }
    return iRet;
}
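horizontalSum_epi64_SSE2 isn't shown in this section. For completeness, here is a minimal sketch of what it could look like; this is my assumption, not necessarily the version used elsewhere in the post:

// Assumed helper: sum the two 64-bit lanes of an __m128i using only SSE2.
__forceinline __int64 horizontalSum_epi64_SSE2(const __m128i &m)
{
    // Bring the high 64-bit lane down, add it onto the low lane, then store the low 64 bits.
    __m128i mSum = _mm_add_epi64(m, _mm_srli_si128(m, 8));
    __int64 iRet;
    _mm_storel_epi64((__m128i*)&iRet, mSum);
    return iRet;
}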
It works, but on SSE2 it's only marginally faster than the ANSI version. Still, it's instructive.