Add loongarch 256-bit LASX SIMD optimization #1458

Open · wants to merge 1 commit into base: master
7 changes: 7 additions & 0 deletions CMakeLists.txt
@@ -135,6 +135,13 @@ if (APPLE)
option(OPJ_USE_DSYMUTIL "Call dsymutil on binaries after build." OFF)
endif()

if (CMAKE_SYSTEM_PROCESSOR STREQUAL "loongarch64")
# Add -mlsx -mlasx so the LSX/LASX SIMD optimizations are compiled in on LoongArch.
# Add -fno-expensive-optimizations to avoid three unit-test failures in Release builds:
# NR-C1P1-p1_05.j2k-compare2base, NR-JP2-file2.jp2-compare2base and
# NR-DEC-issue205.jp2-253-decode-md5.
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mlsx -mlasx -fno-expensive-optimizations")
endif()
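A more defensive variant of this block would probe the compiler rather than key on the processor string alone; a minimal sketch, assuming CMake's stock CheckCCompilerFlag module (the result variable name is illustrative):

include(CheckCCompilerFlag)
check_c_compiler_flag("-mlasx" C_COMPILER_SUPPORTS_LASX)
if (C_COMPILER_SUPPORTS_LASX)
    # Same flags as above, applied only when the compiler accepts them.
    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mlsx -mlasx -fno-expensive-optimizations")
endif()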

#-----------------------------------------------------------------------------
# Big endian test:
if (NOT EMSCRIPTEN)
162 changes: 147 additions & 15 deletions src/lib/openjp2/dwt.c
@@ -55,6 +55,12 @@
#ifdef __AVX2__
#include <immintrin.h>
#endif
#ifdef __loongarch_sx
#include <lsxintrin.h>
#endif
#ifdef __loongarch_asx
#include <lasxintrin.h>
#endif

#if defined(__GNUC__)
#pragma GCC poison malloc calloc realloc free
@@ -66,7 +72,7 @@
#define OPJ_WS(i) v->mem[(i)*2]
#define OPJ_WD(i) v->mem[(1+(i)*2)]

#ifdef __AVX2__
#if defined(__AVX2__) || defined(__loongarch_asx)
/** Number of int32 values in an AVX2 or LASX register */
#define VREG_INT_COUNT 8
#else
@@ -511,7 +517,7 @@ static void opj_idwt53_h(const opj_dwt_t *dwt,
#endif
}

#if (defined(__SSE2__) || defined(__AVX2__)) && !defined(STANDARD_SLOW_VERSION)
#if (defined(__SSE2__) || defined(__AVX2__) || defined(__loongarch_asx)) && !defined(STANDARD_SLOW_VERSION)

/* Convenience macros to improve the readability of the formulas */
#if __AVX2__
@@ -524,6 +530,16 @@ static void opj_idwt53_h(const opj_dwt_t *dwt,
#define ADD(x,y) _mm256_add_epi32((x),(y))
#define SUB(x,y) _mm256_sub_epi32((x),(y))
#define SAR(x,y) _mm256_srai_epi32((x),(y))
#elif __loongarch_asx
#define VREG __m256i
#define LOAD_CST(x) __lasx_xvreplgr2vr_w(x)
#define LOAD(x) __lasx_xvld((const VREG*)(x), 0)
#define LOADU(x) __lasx_xvld((const VREG*)(x), 0)
#define STORE(x,y) __lasx_xvst((y), (VREG*)(x), 0)
#define STOREU(x,y) __lasx_xvst((y), (VREG*)(x), 0)
#define ADD(x,y) __lasx_xvadd_w((x),(y))
#define SUB(x,y) __lasx_xvsub_w((x),(y))
#define SAR(x,y) __lasx_xvsrai_w((x),(y))
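/* Sketch of how these macros compose (identical across the SSE2, AVX2 and
 * LASX definitions; the variable names below are illustrative, not from the
 * patch): one even-sample 5/3 lifting update over VREG_INT_COUNT columns is
 *
 *     VREG s = LOAD(tmp + i);
 *     s = SUB(s, SAR(ADD(ADD(d_prev, d_next), two), 2));
 *     STORE(tmp + i, s);
 *
 * i.e. s -= (d_prev + d_next + 2) >> 2 in every lane. */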
#else
#define VREG __m128i
#define LOAD_CST(x) _mm_set1_epi32(x)
@@ -558,8 +574,8 @@ void opj_idwt53_v_final_memcpy(OPJ_INT32* tiledp_col,
}

/** Vertical inverse 5x3 wavelet transform for 8 columns in SSE2, or
* 16 in AVX2, when top-most pixel is on even coordinate */
static void opj_idwt53_v_cas0_mcols_SSE2_OR_AVX2(
* 16 in AVX2, or 16 in LASX, when top-most pixel is on even coordinate */
static void opj_idwt53_v_cas0_mcols_SSE2_OR_AVX2_OR_LASX(
OPJ_INT32* tmp,
const OPJ_INT32 sn,
const OPJ_INT32 len,
@@ -576,7 +592,7 @@ static void opj_idwt53_v_cas0_mcols_SSE2_OR_AVX2(
const VREG two = LOAD_CST(2);

assert(len > 1);
#if __AVX2__
#if defined(__AVX2__) || defined(__loongarch_asx)
assert(PARALLEL_COLS_53 == 16);
assert(VREG_INT_COUNT == 8);
#else
@@ -659,8 +675,8 @@ static void opj_idwt53_v_cas0_mcols_SSE2_OR_AVX2(


/** Vertical inverse 5x3 wavelet transform for 8 columns in SSE2, or
* 16 in AVX2, when top-most pixel is on odd coordinate */
static void opj_idwt53_v_cas1_mcols_SSE2_OR_AVX2(
* 16 in AVX2, or 16 in LASX, when top-most pixel is on odd coordinate */
static void opj_idwt53_v_cas1_mcols_SSE2_OR_AVX2_OR_LASX(
OPJ_INT32* tmp,
const OPJ_INT32 sn,
const OPJ_INT32 len,
@@ -678,7 +694,7 @@ static void opj_idwt53_v_cas1_mcols_SSE2_OR_AVX2(
const OPJ_INT32* in_odd = &tiledp_col[0];

assert(len > 2);
#if __AVX2__
#if defined(__AVX2__) || defined(__loongarch_asx)
assert(PARALLEL_COLS_53 == 16);
assert(VREG_INT_COUNT == 8);
#else
@@ -767,7 +783,7 @@ static void opj_idwt53_v_cas1_mcols_SSE2_OR_AVX2(
#undef SUB
#undef SAR

#endif /* (defined(__SSE2__) || defined(__AVX2__)) && !defined(STANDARD_SLOW_VERSION) */
#endif /* (defined(__SSE2__) || defined(__AVX2__) || defined(__loongarch_asx)) && !defined(STANDARD_SLOW_VERSION) */

#if !defined(STANDARD_SLOW_VERSION)
/** Vertical inverse 5x3 wavelet transform for one column, when top-most
@@ -894,11 +910,11 @@ static void opj_idwt53_v(const opj_dwt_t *dwt,
if (dwt->cas == 0) {
/* If len == 1, unmodified value */

#if (defined(__SSE2__) || defined(__AVX2__))
#if (defined(__SSE2__) || defined(__AVX2__) || defined(__loongarch_asx))
if (len > 1 && nb_cols == PARALLEL_COLS_53) {
/* Same as below general case, except that thanks to SSE2/AVX2/LASX */
/* we can efficiently process 8/16 columns in parallel */
opj_idwt53_v_cas0_mcols_SSE2_OR_AVX2(dwt->mem, sn, len, tiledp_col, stride);
opj_idwt53_v_cas0_mcols_SSE2_OR_AVX2_OR_LASX(dwt->mem, sn, len, tiledp_col, stride);
return;
}
#endif
@@ -937,11 +953,11 @@ static void opj_idwt53_v(const opj_dwt_t *dwt,
return;
}

#if (defined(__SSE2__) || defined(__AVX2__))
#if (defined(__SSE2__) || defined(__AVX2__) || defined(__loongarch_asx))
if (len > 2 && nb_cols == PARALLEL_COLS_53) {
/* Same as below general case, except that thanks to SSE2/AVX2/LASX */
/* we can efficiently process 8/16 columns in parallel */
opj_idwt53_v_cas1_mcols_SSE2_OR_AVX2(dwt->mem, sn, len, tiledp_col, stride);
opj_idwt53_v_cas1_mcols_SSE2_OR_AVX2_OR_LASX(dwt->mem, sn, len, tiledp_col, stride);
return;
}
#endif
@@ -2428,6 +2444,24 @@ static void opj_dwt_decode_partial_1_parallel(OPJ_INT32 *a,
}
}
#endif
#ifdef __loongarch_sx
if (i + 1 < i_max) {
const __m128i two = __lsx_vreplgr2vr_w(2);
__m128i Dm1 = __lsx_vld((__m128i * const)(a + 4 + (i - 1) * 8), 0);
for (; i + 1 < i_max; i += 2) {
/* No bound checking */
__m128i S = __lsx_vld((__m128i * const)(a + i * 8), 0);
__m128i D = __lsx_vld((__m128i * const)(a + 4 + i * 8), 0);
__m128i S1 = __lsx_vld((__m128i * const)(a + (i + 1) * 8), 0);
__m128i D1 = __lsx_vld((__m128i * const)(a + 4 + (i + 1) * 8), 0);
S = __lsx_vsub_w(S, __lsx_vsrai_w(__lsx_vadd_w(__lsx_vadd_w(Dm1, D), two), 2));
S1 = __lsx_vsub_w(S1, __lsx_vsrai_w(__lsx_vadd_w(__lsx_vadd_w(D, D1), two), 2));
__lsx_vst(S, (__m128i*)(a + i * 8), 0);
__lsx_vst(S1, (__m128i*)(a + (i + 1) * 8), 0);
Dm1 = D1;
}
}
#endif
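/* Scalar equivalent of the LSX loop above, for each interleaved column
 * off in 0..3 (layout inferred from the loads: four S samples at a + i*8,
 * four D samples at a + 4 + i*8):
 *
 *     a[i*8 + off] -= (a[(i-1)*8 + 4 + off] + a[i*8 + 4 + off] + 2) >> 2;
 */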

for (; i < i_max; i++) {
/* No bound checking */
@@ -2467,6 +2501,23 @@ static void opj_dwt_decode_partial_1_parallel(OPJ_INT32 *a,
}
}
#endif
#ifdef __loongarch_sx
if (i + 1 < i_max) {
__m128i S = __lsx_vld((__m128i * const)(a + i * 8), 0);
for (; i + 1 < i_max; i += 2) {
/* No bound checking */
__m128i D = __lsx_vld((__m128i * const)(a + 4 + i * 8), 0);
__m128i S1 = __lsx_vld((__m128i * const)(a + (i + 1) * 8), 0);
__m128i D1 = __lsx_vld((__m128i * const)(a + 4 + (i + 1) * 8), 0);
__m128i S2 = __lsx_vld((__m128i * const)(a + (i + 2) * 8), 0);
D = __lsx_vadd_w(D, __lsx_vsrai_w(__lsx_vadd_w(S, S1), 1));
D1 = __lsx_vadd_w(D1, __lsx_vsrai_w(__lsx_vadd_w(S1, S2), 1));
__lsx_vst(D, (__m128i*)(a + 4 + i * 8), 0);
__lsx_vst(D1, (__m128i*)(a + 4 + (i + 1) * 8), 0);
S = S2;
}
}
#endif
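/* Scalar equivalent of the LSX loop above (same inferred layout):
 *
 *     a[i*8 + 4 + off] += (a[i*8 + off] + a[(i+1)*8 + off]) >> 1;
 */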

for (; i < i_max; i++) {
/* No bound checking */
@@ -3015,7 +3066,7 @@ static void opj_v8dwt_interleave_partial_v(opj_v8dwt_t* OPJ_RESTRICT dwt,
OPJ_UNUSED(ret);
}

#ifdef __SSE__
#if defined(__SSE__)

static void opj_v8dwt_decode_step1_sse(opj_v8_t* w,
OPJ_UINT32 start,
@@ -3070,6 +3121,66 @@ static void opj_v8dwt_decode_step2_sse(opj_v8_t* l, opj_v8_t* w,
}
}

#elif defined(__loongarch_asx)

static void opj_v8dwt_decode_step1_lasx(opj_v8_t* w,
OPJ_UINT32 start,
OPJ_UINT32 end,
const OPJ_FLOAT32 c)
{
OPJ_UINT32 i;
OPJ_FLOAT32* OPJ_RESTRICT fw = (OPJ_FLOAT32*) w;
__m256 vfw, vmul, vc;
/* To be adapted if NB_ELTS_V8 changes */
fw += start * 16;
vc = (__m256)__lasx_xvldrepl_w(&c, 0);
for (i = start; i < end; ++i) {
vfw = (__m256)__lasx_xvld(fw, 0);
vmul = __lasx_xvfmul_s(vfw, vc);
__lasx_xvst(vmul, fw, 0);
fw += 16;
}
}

static void opj_v8dwt_decode_step2_lasx(opj_v8_t* l, opj_v8_t* w,
OPJ_UINT32 start,
OPJ_UINT32 end,
OPJ_UINT32 m,
OPJ_FLOAT32 c)
{
OPJ_UINT32 i;
OPJ_FLOAT32* fl = (OPJ_FLOAT32*) l;
OPJ_FLOAT32* fw = (OPJ_FLOAT32*) w;
OPJ_UINT32 imax = opj_uint_min(end, m);
__m256 vfl0, vfw0, vfw1, vret;
__m256 vc = (__m256)__lasx_xvldrepl_w(&c, 0);
if (start > 0) {
fw += 2 * NB_ELTS_V8 * start;
fl = fw - 2 * NB_ELTS_V8;
}
/* To be adapted if NB_ELTS_V8 changes */
for (i = start; i < imax; ++i) {
vfl0 = (__m256)__lasx_xvld(fl, 0);
vfw0 = (__m256)__lasx_xvld(fw, 0);
vfw1 = (__m256)__lasx_xvld(fw, -32);
vret = __lasx_xvfadd_s(vfl0, vfw0);
vret = __lasx_xvfmul_s(vret, vc);
vret = __lasx_xvfadd_s(vret, vfw1);
__lasx_xvst(vret, fw, -32);
fl = fw;
fw += 2 * NB_ELTS_V8;
}
if (m < end) {
assert(m + 1 == end);
vc = __lasx_xvfadd_s(vc, vc);
vfl0 = (__m256)__lasx_xvld(fl, 0);
vfw1 = (__m256)__lasx_xvld(fw, -32);
vret = __lasx_xvfmul_s(vfl0, vc);
vret = __lasx_xvfadd_s(vret, vfw1);
__lasx_xvst(vret, fw, -32);
}
}
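/* Scalar form of the two LASX kernels above, per interleaved column c0 in
 * 0..NB_ELTS_V8-1 (a sketch distilled from the loads and stores, not code
 * taken from the patch):
 *
 *     step1:                    fw[c0] *= c;        (stride of 16 floats)
 *     step2:                    fw[-8 + c0] += (fl[c0] + fw[c0]) * c;
 *     step2 tail, when m < end: fw[-8 + c0] += fl[c0] * (2 * c);
 */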

#else

static void opj_v8dwt_decode_step1(opj_v8_t* w,
@@ -3162,7 +3273,7 @@ static void opj_v8dwt_decode(opj_v8dwt_t* OPJ_RESTRICT dwt)
a = 1;
b = 0;
}
#ifdef __SSE__
#if defined(__SSE__)
opj_v8dwt_decode_step1_sse(dwt->wavelet + a, dwt->win_l_x0, dwt->win_l_x1,
_mm_set1_ps(opj_K));
opj_v8dwt_decode_step1_sse(dwt->wavelet + b, dwt->win_h_x0, dwt->win_h_x1,
@@ -3183,6 +3294,27 @@ static void opj_v8dwt_decode(opj_v8dwt_t* OPJ_RESTRICT dwt)
dwt->win_h_x0, dwt->win_h_x1,
(OPJ_UINT32)opj_int_min(dwt->dn, dwt->sn - b),
_mm_set1_ps(-opj_dwt_alpha));
#elif defined(__loongarch_asx)
opj_v8dwt_decode_step1_lasx(dwt->wavelet + a, dwt->win_l_x0, dwt->win_l_x1,
opj_K);
opj_v8dwt_decode_step1_lasx(dwt->wavelet + b, dwt->win_h_x0, dwt->win_h_x1,
two_invK);
opj_v8dwt_decode_step2_lasx(dwt->wavelet + b, dwt->wavelet + a + 1,
dwt->win_l_x0, dwt->win_l_x1,
(OPJ_UINT32)opj_int_min(dwt->sn, dwt->dn - a),
-opj_dwt_delta);
opj_v8dwt_decode_step2_lasx(dwt->wavelet + a, dwt->wavelet + b + 1,
dwt->win_h_x0, dwt->win_h_x1,
(OPJ_UINT32)opj_int_min(dwt->dn, dwt->sn - b),
-opj_dwt_gamma);
opj_v8dwt_decode_step2_lasx(dwt->wavelet + b, dwt->wavelet + a + 1,
dwt->win_l_x0, dwt->win_l_x1,
(OPJ_UINT32)opj_int_min(dwt->sn, dwt->dn - a),
-opj_dwt_beta);
opj_v8dwt_decode_step2_lasx(dwt->wavelet + a, dwt->wavelet + b + 1,
dwt->win_h_x0, dwt->win_h_x1,
(OPJ_UINT32)opj_int_min(dwt->dn, dwt->sn - b),
-opj_dwt_alpha);
#else
opj_v8dwt_decode_step1(dwt->wavelet + a, dwt->win_l_x0, dwt->win_l_x1,
opj_K);
57 changes: 57 additions & 0 deletions src/lib/openjp2/t1.c
@@ -1786,6 +1786,34 @@ static void opj_t1_clbl_decode_processor(void* user_data, opj_tls_t* tls)
}
}
#endif

#ifdef __loongarch_asx
{
asm volatile
Reviewer: Using "volatile" is generally a bad idea (unless there is a side effect that cannot be explained by the input/output/clobber lists). If the code does not work without "volatile", you have likely missed some register input/output/clobber.

Author: Does it affect performance, or something else?

Reviewer: Because "if it's a chicken, don't model it as a duck".

Author: Please provide helpful suggestions.

Reviewer: Is it broken if we remove volatile? If it's not broken, just remove volatile. Otherwise, try to figure out why the compiler breaks it and fix the input/output/clobber list.

volatile means there is some unexplainable side effect, for example writing/reading some CSRs or invoking hardware memory barriers. These just do not apply to a SIMD optimization.

Author: I just want to tell the compiler not to optimize or modify the asm block.

Reviewer (@xry111, Jul 7, 2023), replying to "just want to tell compiler not to optimize or modify the asm block":

It won't modify the asm block, because the compiler does not parse asm at all. The only possible optimizations are:

1. Reordering the asm block (as a whole block: the compiler cannot insert anything into the middle of the asm block, because it does not know how to parse asm) relative to the instructions generated by the compiler itself.
2. Removing the asm block completely (not "removing a part of the asm", again because the compiler does not know how to parse asm).

If the input/output/clobber lists are correct, 2 is disabled, and 1 is performed in a constrained way that does not break the code.

Again, volatile means "there is an unexplainable side effect", not "don't modify the asm block".

Reviewer: Well, and maybe:

3. Duplicating the asm block, due to loop unrolling etc. Again, if the input/output/clobber lists are correct, this happens in a safe way.

(
"srli.w $t0, %[cblk], 4 \n\t"
"xvldrepl.w $xr0, %[step], 0 \n\t"
Reviewer: Is there any reason not to use xvreplgr2vr.w here?

Author: Nice catch, will update.
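The change under discussion would look roughly like this (a sketch, not the committed code; the float's bit pattern is passed through a general-purpose register operand instead of through a pointer):

OPJ_UINT32 step_bits;
memcpy(&step_bits, &stepsize, sizeof(step_bits));
/* ... and in the asm template: */
"xvreplgr2vr.w $xr0, %[step] \n\t"
/* ... with the input operand [step]"r"(&stepsize) becoming [step]"r"(step_bits) */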

"beqz $t0, 999f \n\t"
"16: \n\t"
"xvld $xr1, %[data], 0 \n\t"
"xvld $xr2, %[data], 32 \n\t"
"xvffint.s.w $xr1, $xr1 \n\t"
"xvffint.s.w $xr2, $xr2 \n\t"
"xvfmul.s $xr1, $xr1, $xr0 \n\t"
"xvfmul.s $xr2, $xr2, $xr0 \n\t"
"xvst $xr1, %[data], 0 \n\t"
"xvst $xr2, %[data], 32 \n\t"
"addi.d %[data], %[data], 64 \n\t"
"addi.w $t0, $t0, -1 \n\t"
"bnez $t0, 16b \n\t"
"999: \n\t"
: [data]"+r"(datap)
: [step]"r"(&stepsize), [cblk]"r"(cblk_size)
: "$t0", "$xr0", "$xr1", "$xr2", "memory"
);
i = cblk_size & ~15U;
}
#endif
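/* Review follow-up, a sketch rather than committed code: a drop-in
 * replacement for the asm statement above using LASX intrinsics (names
 * assumed to follow lasxintrin.h, mirroring the mnemonics used in the asm),
 * which lets the compiler track registers itself instead of relying on a
 * hand-written clobber list and on volatile: */
#if 0
{
    const __m256 vstep = (__m256)__lasx_xvldrepl_w(&stepsize, 0);
    for (i = 0; i + 15 < cblk_size; i += 16) {
        __m256 f0 = __lasx_xvffint_s_w(__lasx_xvld(datap + i, 0));
        __m256 f1 = __lasx_xvffint_s_w(__lasx_xvld(datap + i + 8, 0));
        __lasx_xvst((__m256i)__lasx_xvfmul_s(f0, vstep), datap + i, 0);
        __lasx_xvst((__m256i)__lasx_xvfmul_s(f1, vstep), datap + i + 8, 0);
    }
    /* i ends at cblk_size & ~15U, so the scalar tail below still applies. */
}
#endif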
for (; i < cblk_size; ++i) {
OPJ_FLOAT32 tmp = ((OPJ_FLOAT32)(*datap)) * stepsize;
memcpy(datap, &tmp, sizeof(tmp));
@@ -1797,6 +1825,34 @@ static void opj_t1_clbl_decode_processor(void* user_data, opj_tls_t* tls)
(OPJ_SIZE_T)x];
for (j = 0; j < cblk_h; ++j) {
i = 0;
#ifdef __loongarch_asx
{
OPJ_INT32* ptr1 = datap + j * cblk_w;
OPJ_INT32* ptr2 = tiledp + j * (OPJ_SIZE_T)tile_w;
OPJ_UINT32 step_size = 0;
asm volatile
Reviewer: Likewise, try to avoid volatile.

(
"srli.w $t0, %[cblk], 3 \n\t"
"xvxor.v $xr1, $xr1, $xr1 \n\t"
"xvaddi.wu $xr1, $xr1, 2 \n\t"
"beqz $t0, 999f \n\t"
"8: \n\t"
"xvldx $xr0, %[ptr1], %[step] \n\t"
"xvdiv.w $xr0, $xr0, $xr1 \n\t"
"xvstx $xr0, %[ptr2], %[step] \n\t"
"addi.w %[step], %[step], 32 \n\t"
"addi.w $t0, $t0, -1 \n\t"
"bnez $t0, 8b \n\t"
"999: \n\t"
: [step]"+r"(step_size)
: [ptr1]"r"(ptr1),
[ptr2]"r"(ptr2),
[cblk]"r"(cblk_w)
: "$t0", "$xr0", "memory"
);
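/* Review follow-up, same caveats as the sketch above: an intrinsics form
 * of this asm block. xvdiv.w is kept because it matches the round-toward-
 * zero semantics of the scalar tmp / 2; xvsrai.w would round toward
 * negative infinity for negative samples. */
#if 0
{
    const __m256i vtwo = __lasx_xvreplgr2vr_w(2);
    for (i = 0; i + 7 < cblk_w; i += 8) {
        __m256i v = __lasx_xvld(ptr1 + i, 0);
        __lasx_xvst(__lasx_xvdiv_w(v, vtwo), ptr2 + i, 0);
    }
}
#endif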
i = cblk_w & ~7U;
}
#endif
for (; i < (cblk_w & ~(OPJ_UINT32)3U); i += 4U) {
OPJ_INT32 tmp0 = datap[(j * cblk_w) + i + 0U];
OPJ_INT32 tmp1 = datap[(j * cblk_w) + i + 1U];
@@ -1807,6 +1863,7 @@ static void opj_t1_clbl_decode_processor(void* user_data, opj_tls_t* tls)
((OPJ_INT32*)tiledp)[(j * (OPJ_SIZE_T)tile_w) + i + 2U] = tmp2 / 2;
((OPJ_INT32*)tiledp)[(j * (OPJ_SIZE_T)tile_w) + i + 3U] = tmp3 / 2;
}

for (; i < cblk_w; ++i) {
OPJ_INT32 tmp = datap[(j * cblk_w) + i];
((OPJ_INT32*)tiledp)[(j * (OPJ_SIZE_T)tile_w) + i] = tmp / 2;