Skip to content

Commit

Permalink
replace _mm256_cvtepi8_epi32(_mm_set1_epi8(x)) with _mm256_set1_epi32(x) (#494)
Browse files Browse the repository at this point in the history
  • Loading branch information
ccw1996 authored Dec 6, 2020
1 parent d04481f commit 3fb5dd6
Showing 1 changed file with 75 additions and 75 deletions.
150 changes: 75 additions & 75 deletions src/dev/cpu/op/conv/x86/conv_kernel_x86.c
Original file line number Diff line number Diff line change
Expand Up @@ -1036,10 +1036,10 @@ static void sgemm_i8(int M, int N, int K, int8_t* pA_t, int8_t* pB_t, int32_t* p
int k = 0;
for (; k + 3 < K; k = k + 4) {
// k0
__m256i _va0 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*va));
__m256i _va1 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 1)));
__m256i _va2 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 2)));
__m256i _va3 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 3)));
__m256i _va0 = _mm256_set1_epi32(*va);
__m256i _va1 = _mm256_set1_epi32(*(va + 1));
__m256i _va2 = _mm256_set1_epi32(*(va + 2));
__m256i _va3 = _mm256_set1_epi32(*(va + 3));
__m256i _vb0 = _mm256_cvtepi8_epi32(_mm_loadu_si128((__m128i*)vb));
__m256i _vb1 =
_mm256_cvtepi8_epi32(_mm_loadu_si128((__m128i*)(vb + 8)));
Expand All @@ -1051,10 +1051,10 @@ static void sgemm_i8(int M, int N, int K, int8_t* pA_t, int8_t* pB_t, int32_t* p
_sum1 = _mm256_add_epi32(_mm256_mullo_epi32(_vb0, _va1), _sum1);
_sum2 = _mm256_add_epi32(_mm256_mullo_epi32(_vb0, _va2), _sum2);
_sum3 = _mm256_add_epi32(_mm256_mullo_epi32(_vb0, _va3), _sum3);
_va0 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 4)));
_va1 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 5)));
_va2 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 6)));
_va3 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 7)));
_va0 = _mm256_set1_epi32(*(va + 4));
_va1 = _mm256_set1_epi32(*(va + 5));
_va2 = _mm256_set1_epi32(*(va + 6));
_va3 = _mm256_set1_epi32(*(va + 7));
_sum4 = _mm256_add_epi32(_mm256_mullo_epi32(_vb0, _va0), _sum4);
_sum5 = _mm256_add_epi32(_mm256_mullo_epi32(_vb0, _va1), _sum5);
_sum6 = _mm256_add_epi32(_mm256_mullo_epi32(_vb0, _va2), _sum6);
Expand All @@ -1063,18 +1063,18 @@ static void sgemm_i8(int M, int N, int K, int8_t* pA_t, int8_t* pB_t, int32_t* p
va += 8;

// k1
_va0 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*va));
_va1 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 1)));
_va2 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 2)));
_va3 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 3)));
_va0 = _mm256_set1_epi32(*va);
_va1 = _mm256_set1_epi32(*(va + 1));
_va2 = _mm256_set1_epi32(*(va + 2));
_va3 = _mm256_set1_epi32(*(va + 3));
_sum0 = _mm256_add_epi32(_mm256_mullo_epi32(_vb1, _va0), _sum0);
_sum1 = _mm256_add_epi32(_mm256_mullo_epi32(_vb1, _va1), _sum1);
_sum2 = _mm256_add_epi32(_mm256_mullo_epi32(_vb1, _va2), _sum2);
_sum3 = _mm256_add_epi32(_mm256_mullo_epi32(_vb1, _va3), _sum3);
_va0 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 4)));
_va1 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 5)));
_va2 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 6)));
_va3 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 7)));
_va0 = _mm256_set1_epi32(*(va + 4));
_va1 = _mm256_set1_epi32(*(va + 5));
_va2 = _mm256_set1_epi32(*(va + 6));
_va3 = _mm256_set1_epi32(*(va + 7));
_sum4 = _mm256_add_epi32(_mm256_mullo_epi32(_vb1, _va0), _sum4);
_sum5 = _mm256_add_epi32(_mm256_mullo_epi32(_vb1, _va1), _sum5);
_sum6 = _mm256_add_epi32(_mm256_mullo_epi32(_vb1, _va2), _sum6);
Expand All @@ -1083,18 +1083,18 @@ static void sgemm_i8(int M, int N, int K, int8_t* pA_t, int8_t* pB_t, int32_t* p
va += 8;

// k2
_va0 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*va));
_va1 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 1)));
_va2 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 2)));
_va3 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 3)));
_va0 = _mm256_set1_epi32(*va);
_va1 = _mm256_set1_epi32(*(va + 1));
_va2 = _mm256_set1_epi32(*(va + 2));
_va3 = _mm256_set1_epi32(*(va + 3));
_sum0 = _mm256_add_epi32(_mm256_mullo_epi32(_vb2, _va0), _sum0);
_sum1 = _mm256_add_epi32(_mm256_mullo_epi32(_vb2, _va1), _sum1);
_sum2 = _mm256_add_epi32(_mm256_mullo_epi32(_vb2, _va2), _sum2);
_sum3 = _mm256_add_epi32(_mm256_mullo_epi32(_vb2, _va3), _sum3);
_va0 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 4)));
_va1 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 5)));
_va2 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 6)));
_va3 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 7)));
_va0 = _mm256_set1_epi32(*(va + 4));
_va1 = _mm256_set1_epi32(*(va + 5));
_va2 = _mm256_set1_epi32(*(va + 6));
_va3 = _mm256_set1_epi32(*(va + 7));
_sum4 = _mm256_add_epi32(_mm256_mullo_epi32(_vb2, _va0), _sum4);
_sum5 = _mm256_add_epi32(_mm256_mullo_epi32(_vb2, _va1), _sum5);
_sum6 = _mm256_add_epi32(_mm256_mullo_epi32(_vb2, _va2), _sum6);
Expand All @@ -1103,18 +1103,18 @@ static void sgemm_i8(int M, int N, int K, int8_t* pA_t, int8_t* pB_t, int32_t* p
va += 8;

// k3
_va0 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*va));
_va1 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 1)));
_va2 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 2)));
_va3 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 3)));
_va0 = _mm256_set1_epi32(*va);
_va1 = _mm256_set1_epi32(*(va + 1));
_va2 = _mm256_set1_epi32(*(va + 2));
_va3 = _mm256_set1_epi32(*(va + 3));
_sum0 = _mm256_add_epi32(_mm256_mullo_epi32(_vb3, _va0), _sum0);
_sum1 = _mm256_add_epi32(_mm256_mullo_epi32(_vb3, _va1), _sum1);
_sum2 = _mm256_add_epi32(_mm256_mullo_epi32(_vb3, _va2), _sum2);
_sum3 = _mm256_add_epi32(_mm256_mullo_epi32(_vb3, _va3), _sum3);
_va0 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 4)));
_va1 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 5)));
_va2 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 6)));
_va3 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 7)));
_va0 = _mm256_set1_epi32(*(va + 4));
_va1 = _mm256_set1_epi32(*(va + 5));
_va2 = _mm256_set1_epi32(*(va + 6));
_va3 = _mm256_set1_epi32(*(va + 7));
_sum4 = _mm256_add_epi32(_mm256_mullo_epi32(_vb3, _va0), _sum4);
_sum5 = _mm256_add_epi32(_mm256_mullo_epi32(_vb3, _va1), _sum5);
_sum6 = _mm256_add_epi32(_mm256_mullo_epi32(_vb3, _va2), _sum6);
Expand All @@ -1124,14 +1124,14 @@ static void sgemm_i8(int M, int N, int K, int8_t* pA_t, int8_t* pB_t, int32_t* p
vb += 32;
}
for (; k < K; k++) {
__m256i _va0 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*va));
__m256i _va1 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 1)));
__m256i _va2 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 2)));
__m256i _va3 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 3)));
__m256i _va4 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 4)));
__m256i _va5 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 5)));
__m256i _va6 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 6)));
__m256i _va7 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 7)));
__m256i _va0 = _mm256_set1_epi32(*va);
__m256i _va1 = _mm256_set1_epi32(*(va + 1));
__m256i _va2 = _mm256_set1_epi32(*(va + 2));
__m256i _va3 = _mm256_set1_epi32(*(va + 3));
__m256i _va4 = _mm256_set1_epi32(*(va + 4));
__m256i _va5 = _mm256_set1_epi32(*(va + 5));
__m256i _va6 = _mm256_set1_epi32(*(va + 6));
__m256i _va7 = _mm256_set1_epi32(*(va + 7));
__m256i _vb0 = _mm256_cvtepi8_epi32(_mm_loadu_si128((__m128i*)vb));
_sum0 = _mm256_add_epi32(_mm256_mullo_epi32(_vb0, _va0), _sum0);
_sum1 = _mm256_add_epi32(_mm256_mullo_epi32(_vb0, _va1), _sum1);
Expand Down Expand Up @@ -1219,10 +1219,10 @@ static void sgemm_i8(int M, int N, int K, int8_t* pA_t, int8_t* pB_t, int32_t* p

int k = 0;
for (; k + 3 < K; k = k + 4) {
__m256i _vb0 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*vb));
__m256i _vb1 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(vb + 1)));
__m256i _vb2 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(vb + 2)));
__m256i _vb3 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(vb + 3)));
__m256i _vb0 = _mm256_set1_epi32(*vb);
__m256i _vb1 = _mm256_set1_epi32(*(vb + 1));
__m256i _vb2 = _mm256_set1_epi32(*(vb + 2));
__m256i _vb3 = _mm256_set1_epi32(*(vb + 3));
__m256i _va0 = _mm256_cvtepi8_epi32(_mm_loadu_si128((__m128i*)va));
__m256i _va1 =
_mm256_cvtepi8_epi32(_mm_loadu_si128((__m128i*)(va + 8)));
Expand All @@ -1246,7 +1246,7 @@ static void sgemm_i8(int M, int N, int K, int8_t* pA_t, int8_t* pB_t, int32_t* p
_sum0_7 = _mm256_add_epi32(_sum0_7, _sum2);

for (; k < K; k++) {
__m256i _vb0 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*vb));
__m256i _vb0 = _mm256_set1_epi32(*vb);
__m256i _va = _mm256_cvtepi8_epi32(_mm_loadu_si128((__m128i*)va));

_sum0_7 = _mm256_add_epi32(_mm256_mullo_epi32(_va, _vb0), _sum0_7);
Expand Down Expand Up @@ -1335,10 +1335,10 @@ static void sgemm_i8(int M, int N, int K, int8_t* pA_t, int8_t* pB_t, int32_t* p
int k = 0;
for (; k + 3 < K; k = K + 4) {
// k0
__m256i _va0 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*va));
__m256i _va1 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 1)));
__m256i _va2 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 2)));
__m256i _va3 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 3)));
__m256i _va0 = _mm256_set1_epi32(*va);
__m256i _va1 = _mm256_set1_epi32(*(va + 1));
__m256i _va2 = _mm256_set1_epi32(*(va + 2));
__m256i _va3 = _mm256_set1_epi32(*(va + 3));
__m256i _vb0 = _mm256_cvtepi8_epi32(_mm_loadu_si128((__m128i*)vb));
__m256i _vb1 =
_mm256_cvtepi8_epi32(_mm_loadu_si128((__m128i*)(vb + 8)));
Expand All @@ -1354,10 +1354,10 @@ static void sgemm_i8(int M, int N, int K, int8_t* pA_t, int8_t* pB_t, int32_t* p
va += 4;

// k1
_va0 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*va));
_va1 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 1)));
_va2 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 2)));
_va3 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 3)));
_va0 = _mm256_set1_epi32(*va);
_va1 = _mm256_set1_epi32(*(va + 1));
_va2 = _mm256_set1_epi32(*(va + 2));
_va3 = _mm256_set1_epi32(*(va + 3));
_sum0 = _mm256_add_epi32(_mm256_mullo_epi32(_vb1, _va0), _sum0);
_sum1 = _mm256_add_epi32(_mm256_mullo_epi32(_vb1, _va1), _sum1);
_sum2 = _mm256_add_epi32(_mm256_mullo_epi32(_vb1, _va2), _sum2);
Expand All @@ -1366,10 +1366,10 @@ static void sgemm_i8(int M, int N, int K, int8_t* pA_t, int8_t* pB_t, int32_t* p
va += 4;

// k2
_va0 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*va));
_va1 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 1)));
_va2 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 2)));
_va3 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 3)));
_va0 = _mm256_set1_epi32(*va);
_va1 = _mm256_set1_epi32(*(va + 1));
_va2 = _mm256_set1_epi32(*(va + 2));
_va3 = _mm256_set1_epi32(*(va + 3));
_sum0 = _mm256_add_epi32(_mm256_mullo_epi32(_vb2, _va0), _sum0);
_sum1 = _mm256_add_epi32(_mm256_mullo_epi32(_vb2, _va1), _sum1);
_sum2 = _mm256_add_epi32(_mm256_mullo_epi32(_vb2, _va2), _sum2);
Expand All @@ -1378,10 +1378,10 @@ static void sgemm_i8(int M, int N, int K, int8_t* pA_t, int8_t* pB_t, int32_t* p
va += 4;

// k3
_va0 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*va));
_va1 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 1)));
_va2 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 2)));
_va3 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 3)));
_va0 = _mm256_set1_epi32(*va);
_va1 = _mm256_set1_epi32(*(va + 1));
_va2 = _mm256_set1_epi32(*(va + 2));
_va3 = _mm256_set1_epi32(*(va + 3));
_sum0 = _mm256_add_epi32(_mm256_mullo_epi32(_vb3, _va0), _sum0);
_sum1 = _mm256_add_epi32(_mm256_mullo_epi32(_vb3, _va1), _sum1);
_sum2 = _mm256_add_epi32(_mm256_mullo_epi32(_vb3, _va2), _sum2);
Expand All @@ -1392,10 +1392,10 @@ static void sgemm_i8(int M, int N, int K, int8_t* pA_t, int8_t* pB_t, int32_t* p
}

for (; k < K; k++) {
__m256i _va0 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*va));
__m256i _va1 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 1)));
__m256i _va2 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 2)));
__m256i _va3 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 3)));
__m256i _va0 = _mm256_set1_epi32(*va);
__m256i _va1 = _mm256_set1_epi32(*(va + 1));
__m256i _va2 = _mm256_set1_epi32(*(va + 2));
__m256i _va3 = _mm256_set1_epi32(*(va + 3));
__m256i _vb0 = _mm256_cvtepi8_epi32(_mm_loadu_si128((__m128i*)vb));
_sum0 = _mm256_add_epi32(_mm256_mullo_epi32(_vb0, _va0), _sum0);
_sum1 = _mm256_add_epi32(_mm256_mullo_epi32(_vb0, _va1), _sum1);
Expand Down Expand Up @@ -1458,10 +1458,10 @@ static void sgemm_i8(int M, int N, int K, int8_t* pA_t, int8_t* pB_t, int32_t* p
int k=0;
for (; k + 3 < K; k = k + 4)
{
__m256i _vb0 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*vb));
__m256i _vb1 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(vb + 1)));
__m256i _vb2 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(vb + 2)));
__m256i _vb3 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(vb + 3)));
__m256i _vb0 = _mm256_set1_epi32(*vb);
__m256i _vb1 = _mm256_set1_epi32(*(vb + 1));
__m256i _vb2 = _mm256_set1_epi32(*(vb + 2));
__m256i _vb3 = _mm256_set1_epi32(*(vb + 3));
__m256i _va0 = _mm256_cvtepi8_epi32(_mm_loadu_si128((__m128i*)va));
__m256i _va1 = _mm256_cvtepi8_epi32(_mm_loadu_si128((__m128i*)(va + 4)));
__m256i _va2 = _mm256_cvtepi8_epi32(_mm_loadu_si128((__m128i*)(va + 8)));
Expand All @@ -1483,7 +1483,7 @@ static void sgemm_i8(int M, int N, int K, int8_t* pA_t, int8_t* pB_t, int32_t* p

for (; k < K; k++)
{
__m256i _vb0 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*vb));
__m256i _vb0 = _mm256_set1_epi32(*vb);
__m256i _va = _mm256_cvtepi8_epi32(_mm_loadu_si128((__m128i*)va));

_sum0_3 = _mm256_add_epi32(_mm256_mullo_epi32(_va, _vb0), _sum0_3);
Expand Down Expand Up @@ -1543,10 +1543,10 @@ static void sgemm_i8(int M, int N, int K, int8_t* pA_t, int8_t* pB_t, int32_t* p

int k = 0;
for (; k + 3 < K; k = k + 4) {
__m256i _va0 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*va));
__m256i _va1 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 1)));
__m256i _va2 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 2)));
__m256i _va3 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*(va + 3)));
__m256i _va0 = _mm256_set1_epi32(*va);
__m256i _va1 = _mm256_set1_epi32(*(va + 1));
__m256i _va2 = _mm256_set1_epi32(*(va + 2));
__m256i _va3 = _mm256_set1_epi32(*(va + 3));
__m256i _vb0 = _mm256_cvtepi8_epi32(_mm_loadu_si128((__m128i*)vb));
__m256i _vb1 =
_mm256_cvtepi8_epi32(_mm_loadu_si128((__m128i*)(vb + 8)));
Expand All @@ -1565,7 +1565,7 @@ static void sgemm_i8(int M, int N, int K, int8_t* pA_t, int8_t* pB_t, int32_t* p
}

for (; k < K; k++) {
__m256i _va0 = _mm256_cvtepi8_epi32(_mm_set1_epi8(*va));
__m256i _va0 = _mm256_set1_epi32(*va);
__m256i _vb0 = _mm256_cvtepi8_epi32(_mm_loadu_si128((__m128i*)vb));

_sum0 = _mm256_add_epi32(_mm256_mullo_epi32(_vb0, _va0), _sum0);
Expand Down

0 comments on commit 3fb5dd6

Please sign in to comment.