diff --git a/src/dev/cpu/op/conv/x86/conv_kernel_x86.c b/src/dev/cpu/op/conv/x86/conv_kernel_x86.c index 65fa80802..c4b6d9fb1 100644 --- a/src/dev/cpu/op/conv/x86/conv_kernel_x86.c +++ b/src/dev/cpu/op/conv/x86/conv_kernel_x86.c @@ -1036,10 +1036,10 @@ static void sgemm_i8(int M, int N, int K, int8_t* pA_t, int8_t* pB_t, int32_t* p int k = 0; for (; k + 3 < K; k = k + 4) { // k0 - __m256i _va0 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*va)); - __m256i _va1 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 1))); - __m256i _va2 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 2))); - __m256i _va3 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 3))); + __m256i _va0 = _mm256_set1_epi32(*va); + __m256i _va1 = _mm256_set1_epi32(*(va + 1)); + __m256i _va2 = _mm256_set1_epi32(*(va + 2)); + __m256i _va3 = _mm256_set1_epi32(*(va + 3)); __m256i _vb0 = _mm256_cvtepi8_epi32(_mm_loadu_si128((__m128i*)vb)); __m256i _vb1 = _mm256_cvtepi8_epi32(_mm_loadu_si128((__m128i*)(vb + 8))); @@ -1051,10 +1051,10 @@ static void sgemm_i8(int M, int N, int K, int8_t* pA_t, int8_t* pB_t, int32_t* p _sum1 = _mm256_add_epi32(_mm256_mullo_epi32(_vb0, _va1), _sum1); _sum2 = _mm256_add_epi32(_mm256_mullo_epi32(_vb0, _va2), _sum2); _sum3 = _mm256_add_epi32(_mm256_mullo_epi32(_vb0, _va3), _sum3); - _va0 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 4))); - _va1 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 5))); - _va2 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 6))); - _va3 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 7))); + _va0 = _mm256_set1_epi32(*(va + 4)); + _va1 = _mm256_set1_epi32(*(va + 5)); + _va2 = _mm256_set1_epi32(*(va + 6)); + _va3 = _mm256_set1_epi32(*(va + 7)); _sum4 = _mm256_add_epi32(_mm256_mullo_epi32(_vb0, _va0), _sum4); _sum5 = _mm256_add_epi32(_mm256_mullo_epi32(_vb0, _va1), _sum5); _sum6 = _mm256_add_epi32(_mm256_mullo_epi32(_vb0, _va2), _sum6); @@ -1063,18 +1063,18 @@ static void sgemm_i8(int M, int N, int K, int8_t* pA_t, int8_t* pB_t, int32_t* p va += 8; // k1 - _va0 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*va)); - _va1 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 1))); - _va2 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 2))); - _va3 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 3))); + _va0 = _mm256_set1_epi32(*va); + _va1 = _mm256_set1_epi32(*(va + 1)); + _va2 = _mm256_set1_epi32(*(va + 2)); + _va3 = _mm256_set1_epi32(*(va + 3)); _sum0 = _mm256_add_epi32(_mm256_mullo_epi32(_vb1, _va0), _sum0); _sum1 = _mm256_add_epi32(_mm256_mullo_epi32(_vb1, _va1), _sum1); _sum2 = _mm256_add_epi32(_mm256_mullo_epi32(_vb1, _va2), _sum2); _sum3 = _mm256_add_epi32(_mm256_mullo_epi32(_vb1, _va3), _sum3); - _va0 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 4))); - _va1 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 5))); - _va2 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 6))); - _va3 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 7))); + _va0 = _mm256_set1_epi32(*(va + 4)); + _va1 = _mm256_set1_epi32(*(va + 5)); + _va2 = _mm256_set1_epi32(*(va + 6)); + _va3 = _mm256_set1_epi32(*(va + 7)); _sum4 = _mm256_add_epi32(_mm256_mullo_epi32(_vb1, _va0), _sum4); _sum5 = _mm256_add_epi32(_mm256_mullo_epi32(_vb1, _va1), _sum5); _sum6 = _mm256_add_epi32(_mm256_mullo_epi32(_vb1, _va2), _sum6); @@ -1083,18 +1083,18 @@ static void sgemm_i8(int M, int N, int K, int8_t* pA_t, int8_t* pB_t, int32_t* p va += 8; // k2 - _va0 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*va)); - _va1 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 1))); - _va2 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 2))); - _va3 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 3))); + _va0 = _mm256_set1_epi32(*va); + _va1 = _mm256_set1_epi32(*(va + 1)); + _va2 = _mm256_set1_epi32(*(va + 2)); + _va3 = _mm256_set1_epi32(*(va + 3)); _sum0 = _mm256_add_epi32(_mm256_mullo_epi32(_vb2, _va0), _sum0); _sum1 = _mm256_add_epi32(_mm256_mullo_epi32(_vb2, _va1), _sum1); _sum2 = _mm256_add_epi32(_mm256_mullo_epi32(_vb2, _va2), _sum2); _sum3 = _mm256_add_epi32(_mm256_mullo_epi32(_vb2, _va3), _sum3); - _va0 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 4))); - _va1 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 5))); - _va2 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 6))); - _va3 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 7))); + _va0 = _mm256_set1_epi32(*(va + 4)); + _va1 = _mm256_set1_epi32(*(va + 5)); + _va2 = _mm256_set1_epi32(*(va + 6)); + _va3 = _mm256_set1_epi32(*(va + 7)); _sum4 = _mm256_add_epi32(_mm256_mullo_epi32(_vb2, _va0), _sum4); _sum5 = _mm256_add_epi32(_mm256_mullo_epi32(_vb2, _va1), _sum5); _sum6 = _mm256_add_epi32(_mm256_mullo_epi32(_vb2, _va2), _sum6); @@ -1103,18 +1103,18 @@ static void sgemm_i8(int M, int N, int K, int8_t* pA_t, int8_t* pB_t, int32_t* p va += 8; // k3 - _va0 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*va)); - _va1 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 1))); - _va2 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 2))); - _va3 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 3))); + _va0 = _mm256_set1_epi32(*va); + _va1 = _mm256_set1_epi32(*(va + 1)); + _va2 = _mm256_set1_epi32(*(va + 2)); + _va3 = _mm256_set1_epi32(*(va + 3)); _sum0 = _mm256_add_epi32(_mm256_mullo_epi32(_vb3, _va0), _sum0); _sum1 = _mm256_add_epi32(_mm256_mullo_epi32(_vb3, _va1), _sum1); _sum2 = _mm256_add_epi32(_mm256_mullo_epi32(_vb3, _va2), _sum2); _sum3 = _mm256_add_epi32(_mm256_mullo_epi32(_vb3, _va3), _sum3); - _va0 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 4))); - _va1 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 5))); - _va2 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 6))); - _va3 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 7))); + _va0 = _mm256_set1_epi32(*(va + 4)); + _va1 = _mm256_set1_epi32(*(va + 5)); + _va2 = _mm256_set1_epi32(*(va + 6)); + _va3 = _mm256_set1_epi32(*(va + 7)); _sum4 = _mm256_add_epi32(_mm256_mullo_epi32(_vb3, _va0), _sum4); _sum5 = _mm256_add_epi32(_mm256_mullo_epi32(_vb3, _va1), _sum5); _sum6 = _mm256_add_epi32(_mm256_mullo_epi32(_vb3, _va2), _sum6); @@ -1124,14 +1124,14 @@ static void sgemm_i8(int M, int N, int K, int8_t* pA_t, int8_t* pB_t, int32_t* p vb += 32; } for (; k < K; k++) { - __m256i _va0 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*va)); - __m256i _va1 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 1))); - __m256i _va2 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 2))); - __m256i _va3 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 3))); - __m256i _va4 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 4))); - __m256i _va5 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 5))); - __m256i _va6 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 6))); - __m256i _va7 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 7))); + __m256i _va0 = _mm256_set1_epi32(*va); + __m256i _va1 = _mm256_set1_epi32(*(va + 1)); + __m256i _va2 = _mm256_set1_epi32(*(va + 2)); + __m256i _va3 = _mm256_set1_epi32(*(va + 3)); + __m256i _va4 = _mm256_set1_epi32(*(va + 4)); + __m256i _va5 = _mm256_set1_epi32(*(va + 5)); + __m256i _va6 = _mm256_set1_epi32(*(va + 6)); + __m256i _va7 = _mm256_set1_epi32(*(va + 7)); __m256i _vb0 = _mm256_cvtepi8_epi32(_mm_loadu_si128((__m128i*)vb)); _sum0 = _mm256_add_epi32(_mm256_mullo_epi32(_vb0, _va0), _sum0); _sum1 = _mm256_add_epi32(_mm256_mullo_epi32(_vb0, _va1), _sum1); @@ -1219,10 +1219,10 @@ static void sgemm_i8(int M, int N, int K, int8_t* pA_t, int8_t* pB_t, int32_t* p int k = 0; for (; k + 3 < K; k = k + 4) { - __m256i _vb0 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*vb)); - __m256i _vb1 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(vb + 1))); - __m256i _vb2 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(vb + 2))); - __m256i _vb3 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(vb + 3))); + __m256i _vb0 = _mm256_set1_epi32(*vb); + __m256i _vb1 = _mm256_set1_epi32(*(vb + 1)); + __m256i _vb2 = _mm256_set1_epi32(*(vb + 2)); + __m256i _vb3 = _mm256_set1_epi32(*(vb + 3)); __m256i _va0 = _mm256_cvtepi8_epi32(_mm_loadu_si128((__m128i*)va)); __m256i _va1 = _mm256_cvtepi8_epi32(_mm_loadu_si128((__m128i*)(va + 8))); @@ -1246,7 +1246,7 @@ static void sgemm_i8(int M, int N, int K, int8_t* pA_t, int8_t* pB_t, int32_t* p _sum0_7 = _mm256_add_epi32(_sum0_7, _sum2); for (; k < K; k++) { - __m256i _vb0 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*vb)); + __m256i _vb0 = _mm256_set1_epi32(*vb); __m256i _va = _mm256_cvtepi8_epi32(_mm_loadu_si128((__m128i*)va)); _sum0_7 = _mm256_add_epi32(_mm256_mullo_epi32(_va, _vb0), _sum0_7); @@ -1335,10 +1335,10 @@ static void sgemm_i8(int M, int N, int K, int8_t* pA_t, int8_t* pB_t, int32_t* p int k = 0; for (; k + 3 < K; k = K + 4) { // k0 - __m256i _va0 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*va)); - __m256i _va1 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 1))); - __m256i _va2 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 2))); - __m256i _va3 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 3))); + __m256i _va0 = _mm256_set1_epi32(*va); + __m256i _va1 = _mm256_set1_epi32(*(va + 1)); + __m256i _va2 = _mm256_set1_epi32(*(va + 2)); + __m256i _va3 = _mm256_set1_epi32(*(va + 3)); __m256i _vb0 = _mm256_cvtepi8_epi32(_mm_loadu_si128((__m128i*)vb)); __m256i _vb1 = _mm256_cvtepi8_epi32(_mm_loadu_si128((__m128i*)(vb + 8))); @@ -1354,10 +1354,10 @@ static void sgemm_i8(int M, int N, int K, int8_t* pA_t, int8_t* pB_t, int32_t* p va += 4; // k1 - _va0 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*va)); - _va1 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 1))); - _va2 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 2))); - _va3 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 3))); + _va0 = _mm256_set1_epi32(*va); + _va1 = _mm256_set1_epi32(*(va + 1)); + _va2 = _mm256_set1_epi32(*(va + 2)); + _va3 = _mm256_set1_epi32(*(va + 3)); _sum0 = _mm256_add_epi32(_mm256_mullo_epi32(_vb1, _va0), _sum0); _sum1 = _mm256_add_epi32(_mm256_mullo_epi32(_vb1, _va1), _sum1); _sum2 = _mm256_add_epi32(_mm256_mullo_epi32(_vb1, _va2), _sum2); @@ -1366,10 +1366,10 @@ static void sgemm_i8(int M, int N, int K, int8_t* pA_t, int8_t* pB_t, int32_t* p va += 4; // k2 - _va0 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*va)); - _va1 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 1))); - _va2 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 2))); - _va3 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 3))); + _va0 = _mm256_set1_epi32(*va); + _va1 = _mm256_set1_epi32(*(va + 1)); + _va2 = _mm256_set1_epi32(*(va + 2)); + _va3 = _mm256_set1_epi32(*(va + 3)); _sum0 = _mm256_add_epi32(_mm256_mullo_epi32(_vb2, _va0), _sum0); _sum1 = _mm256_add_epi32(_mm256_mullo_epi32(_vb2, _va1), _sum1); _sum2 = _mm256_add_epi32(_mm256_mullo_epi32(_vb2, _va2), _sum2); @@ -1378,10 +1378,10 @@ static void sgemm_i8(int M, int N, int K, int8_t* pA_t, int8_t* pB_t, int32_t* p va += 4; // k3 - _va0 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*va)); - _va1 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 1))); - _va2 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 2))); - _va3 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 3))); + _va0 = _mm256_set1_epi32(*va); + _va1 = _mm256_set1_epi32(*(va + 1)); + _va2 = _mm256_set1_epi32(*(va + 2)); + _va3 = _mm256_set1_epi32(*(va + 3)); _sum0 = _mm256_add_epi32(_mm256_mullo_epi32(_vb3, _va0), _sum0); _sum1 = _mm256_add_epi32(_mm256_mullo_epi32(_vb3, _va1), _sum1); _sum2 = _mm256_add_epi32(_mm256_mullo_epi32(_vb3, _va2), _sum2); @@ -1392,10 +1392,10 @@ static void sgemm_i8(int M, int N, int K, int8_t* pA_t, int8_t* pB_t, int32_t* p } for (; k < K; k++) { - __m256i _va0 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*va)); - __m256i _va1 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 1))); - __m256i _va2 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 2))); - __m256i _va3 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 3))); + __m256i _va0 = _mm256_set1_epi32(*va); + __m256i _va1 = _mm256_set1_epi32(*(va + 1)); + __m256i _va2 = _mm256_set1_epi32(*(va + 2)); + __m256i _va3 = _mm256_set1_epi32(*(va + 3)); __m256i _vb0 = _mm256_cvtepi8_epi32(_mm_loadu_si128((__m128i*)vb)); _sum0 = _mm256_add_epi32(_mm256_mullo_epi32(_vb0, _va0), _sum0); _sum1 = _mm256_add_epi32(_mm256_mullo_epi32(_vb0, _va1), _sum1); @@ -1458,10 +1458,10 @@ static void sgemm_i8(int M, int N, int K, int8_t* pA_t, int8_t* pB_t, int32_t* p int k=0; for (; k + 3 < K; k = k + 4) { - __m256i _vb0 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*vb)); - __m256i _vb1 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(vb + 1))); - __m256i _vb2 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(vb + 2))); - __m256i _vb3 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(vb + 3))); + __m256i _vb0 = _mm256_set1_epi32(*vb); + __m256i _vb1 = _mm256_set1_epi32(*(vb + 1)); + __m256i _vb2 = _mm256_set1_epi32(*(vb + 2)); + __m256i _vb3 = _mm256_set1_epi32(*(vb + 3)); __m256i _va0 = _mm256_cvtepi8_epi32(_mm_loadu_si128((__m128i*)va)); __m256i _va1 = _mm256_cvtepi8_epi32(_mm_loadu_si128((__m128i*)(va + 4))); __m256i _va2 = _mm256_cvtepi8_epi32(_mm_loadu_si128((__m128i*)(va + 8))); @@ -1483,7 +1483,7 @@ static void sgemm_i8(int M, int N, int K, int8_t* pA_t, int8_t* pB_t, int32_t* p for (; k < K; k++) { - __m256i _vb0 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*vb)); + __m256i _vb0 = _mm256_set1_epi32(*vb); __m256i _va = _mm256_cvtepi8_epi32(_mm_loadu_si128((__m128i*)va)); _sum0_3 = _mm256_add_epi32(_mm256_mullo_epi32(_va, _vb0), _sum0_3); @@ -1543,10 +1543,10 @@ static void sgemm_i8(int M, int N, int K, int8_t* pA_t, int8_t* pB_t, int32_t* p int k = 0; for (; k + 3 < K; k = k + 4) { - __m256i _va0 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*va)); - __m256i _va1 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 1))); - __m256i _va2 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 2))); - __m256i _va3 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 3))); + __m256i _va0 = _mm256_set1_epi32(*va); + __m256i _va1 = _mm256_set1_epi32(*(va + 1)); + __m256i _va2 = _mm256_set1_epi32(*(va + 2)); + __m256i _va3 = _mm256_set1_epi32(*(va + 3)); __m256i _vb0 = _mm256_cvtepi8_epi32(_mm_loadu_si128((__m128i*)vb)); __m256i _vb1 = _mm256_cvtepi8_epi32(_mm_loadu_si128((__m128i*)(vb + 8))); @@ -1565,7 +1565,7 @@ static void sgemm_i8(int M, int N, int K, int8_t* pA_t, int8_t* pB_t, int32_t* p } for (; k < K; k++) { - __m256i _va0 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*va)); + __m256i _va0 = _mm256_set1_epi32(*va); __m256i _vb0 = _mm256_cvtepi8_epi32(_mm_loadu_si128((__m128i*)vb)); _sum0 = _mm256_add_epi32(_mm256_mullo_epi32(_vb0, _va0), _sum0);