diff --git a/src/lib/openjp2/dwt.c b/src/lib/openjp2/dwt.c index 11aae472d..a7837a07f 100644 --- a/src/lib/openjp2/dwt.c +++ b/src/lib/openjp2/dwt.c @@ -55,6 +55,9 @@ #if (defined(__AVX2__) || defined(__AVX512F__)) #include <immintrin.h> #endif +#ifdef __ARM_NEON +#include <arm_neon.h> +#endif #if defined(__GNUC__) #pragma GCC poison malloc calloc realloc free @@ -73,7 +76,7 @@ /** Number of int32 values in a AVX2 register */ #define VREG_INT_COUNT 8 #else -/** Number of int32 values in a SSE2 register */ +/** Number of int32 values in a SSE2 or NEON register */ #define VREG_INT_COUNT 4 #endif @@ -699,7 +702,7 @@ static void opj_idwt53_h(const opj_dwt_t *dwt, #endif } -#if (defined(__SSE2__) || defined(__AVX2__) || defined(__AVX512F__)) && !defined(STANDARD_SLOW_VERSION) +#if (defined(__ARM_NEON) || defined(__SSE2__) || defined(__AVX2__) || defined(__AVX512F__)) && !defined(STANDARD_SLOW_VERSION) /* Conveniency macros to improve the readability of the formulas */ #if defined(__AVX512F__) @@ -722,6 +725,16 @@ static void opj_idwt53_h(const opj_dwt_t *dwt, #define ADD(x,y) _mm256_add_epi32((x),(y)) #define SUB(x,y) _mm256_sub_epi32((x),(y)) #define SAR(x,y) _mm256_srai_epi32((x),(y)) +#elif defined(__ARM_NEON) +#define VREG int32x4_t +#define LOAD_CST(x) vdupq_n_s32(x) +#define LOAD(x) vld1q_s32((const int32_t*)(x)) +#define LOADU(x) vld1q_s32((const int32_t*)(x)) +#define STORE(x,y) vst1q_s32((int32_t*)(x),(y)) +#define STOREU(x,y) vst1q_s32((int32_t*)(x),(y)) +#define ADD(x,y) vaddq_s32((x),(y)) +#define SUB(x,y) vsubq_s32((x),(y)) +#define SAR(x,y) vshrq_n_s32((x),(y)) #else #define VREG __m128i #define LOAD_CST(x) _mm_set1_epi32(x) @@ -755,9 +768,9 @@ void opj_idwt53_v_final_memcpy(OPJ_INT32* tiledp_col, } } -/** Vertical inverse 5x3 wavelet transform for 8 columns in SSE2, or - * 16 in AVX2, when top-most pixel is on even coordinate */ -static void opj_idwt53_v_cas0_mcols_SSE2_OR_AVX2( +/** Vertical inverse 5x3 wavelet transform for 8 columns in SSE2 and NEON, + * or 16 in AVX2, when 
top-most pixel is on even coordinate */ +static void opj_idwt53_v_cas0_mcols_SIMD( OPJ_INT32* tmp, const OPJ_INT32 sn, const OPJ_INT32 len, @@ -862,9 +875,9 @@ static void opj_idwt53_v_cas0_mcols_SSE2_OR_AVX2( } -/** Vertical inverse 5x3 wavelet transform for 8 columns in SSE2, or - * 16 in AVX2, when top-most pixel is on odd coordinate */ -static void opj_idwt53_v_cas1_mcols_SSE2_OR_AVX2( +/** Vertical inverse 5x3 wavelet transform for 8 columns in SSE2 and NEON, + * or 16 in AVX2, when top-most pixel is on odd coordinate */ +static void opj_idwt53_v_cas1_mcols_SIMD( OPJ_INT32* tmp, const OPJ_INT32 sn, const OPJ_INT32 len, @@ -1104,11 +1117,11 @@ static void opj_idwt53_v(const opj_dwt_t *dwt, if (dwt->cas == 0) { /* If len == 1, unmodified value */ -#if (defined(__SSE2__) || defined(__AVX2__)) +#if (defined(__ARM_NEON) || defined(__SSE2__) || defined(__AVX2__)) if (len > 1 && nb_cols == PARALLEL_COLS_53) { - /* Same as below general case, except that thanks to SSE2/AVX2 */ + /* Same as below general case, except that thanks to SIMD */ /* we can efficiently process 8/16 columns in parallel */ - opj_idwt53_v_cas0_mcols_SSE2_OR_AVX2(dwt->mem, sn, len, tiledp_col, stride); + opj_idwt53_v_cas0_mcols_SIMD(dwt->mem, sn, len, tiledp_col, stride); return; } #endif @@ -1147,11 +1160,11 @@ static void opj_idwt53_v(const opj_dwt_t *dwt, return; } -#if (defined(__SSE2__) || defined(__AVX2__)) +#if (defined(__ARM_NEON) || defined(__SSE2__) || defined(__AVX2__)) if (len > 2 && nb_cols == PARALLEL_COLS_53) { - /* Same as below general case, except that thanks to SSE2/AVX2 */ + /* Same as below general case, except that thanks to SIMD */ /* we can efficiently process 8/16 columns in parallel */ - opj_idwt53_v_cas1_mcols_SSE2_OR_AVX2(dwt->mem, sn, len, tiledp_col, stride); + opj_idwt53_v_cas1_mcols_SIMD(dwt->mem, sn, len, tiledp_col, stride); return; } #endif