This article presents common patterns that can be used to work around a limitation of some AVX-512 instructions: they can only operate on 32-bit and 64-bit elements. For example, the logical instructions VPANDx, VPANDNx, VPORx, VPXORx, and VPTERNLOGx only work with 32-bit or 64-bit elements. Even the AVX512_BW extension doesn't add complementary versions of these instructions that would operate on 8-bit and 16-bit elements. However, other instructions that do operate on 8-bit and 16-bit elements can be used to perform common operations like predicated zeroing or filling with ones.
The following instructions can only be used with 32-bit and 64-bit elements:
- VPANDD, VPANDQ
- VPANDND, VPANDNQ
- VPORD, VPORQ
- VPXORD, VPXORQ
- VPTERNLOGD, VPTERNLOGQ
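In this article we are interested in implementing the following predicated operations on 8-bit and 16-bit elements:
- Clear: set all bits of the selected elements to zero
- Fill: set all bits of the selected elements to one
- Not: complement all bits of the selected elements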
VPXORx clears elements because x ^ x == 0. For element sizes it doesn't support, it can be replaced with another instruction that produces the same result when both inputs are the same register. The simplest and most straightforward approach is subtraction, as x - x == 0. The following code implements efficient predicated zeroing for all element sizes:
#include <immintrin.h>

// AVX-512: Clear 8-bit elements depending on the predicate `k`
static inline __m512i avx512_mask_clear_epi8(__m512i x, __mmask64 k) {
return _mm512_mask_sub_epi8(x, k, x, x);
}
// AVX-512: Clear 16-bit elements depending on the predicate `k`
static inline __m512i avx512_mask_clear_epi16(__m512i x, __mmask32 k) {
return _mm512_mask_sub_epi16(x, k, x, x);
}
// AVX-512: Clear 32-bit elements depending on the predicate `k`
static inline __m512i avx512_mask_clear_epi32(__m512i x, __mmask16 k) {
return _mm512_mask_xor_epi32(x, k, x, x);
}
// AVX-512: Clear 64-bit elements depending on the predicate `k`
static inline __m512i avx512_mask_clear_epi64(__m512i x, __mmask8 k) {
return _mm512_mask_xor_epi64(x, k, x, x);
}
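The 8-bit and 16-bit variants rely on merge-masking. Lanes whose mask bit is set receive x - x == 0, and every other lane keeps the value passed as the first argument. The sketch below is a portable scalar model of that behavior for 8-bit lanes, not real AVX-512 code. It models a 512-bit vector as 64 bytes and a __mmask64 as a uint64_t. The helper names model_mask_sub_epi8 and model_mask_clear_epi8 are hypothetical.

```c
#include <assert.h>
#include <stdint.h>

/* Scalar model (not real AVX-512 code) of the merge-masked subtraction
 * _mm512_mask_sub_epi8(src, k, a, b): byte i receives a[i] - b[i] when bit i
 * of k is set, otherwise it keeps src[i]. */
static void model_mask_sub_epi8(uint8_t dst[64], const uint8_t src[64],
                                uint64_t k, const uint8_t a[64],
                                const uint8_t b[64]) {
    for (int i = 0; i < 64; i++) {
        /* Unsigned 8-bit arithmetic wraps around, like VPSUBB. */
        dst[i] = ((k >> i) & 1) ? (uint8_t)(a[i] - b[i]) : src[i];
    }
}

/* Model of avx512_mask_clear_epi8(x, k): passing x as src, a and b makes the
 * selected bytes x - x == 0 while all other bytes keep their value. */
static void model_mask_clear_epi8(uint8_t dst[64], const uint8_t x[64],
                                  uint64_t k) {
    model_mask_sub_epi8(dst, x, k, x, x);
}
```

With k = 0x00000000FFFF0000, for instance, only bytes 16 to 31 are cleared.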
Setting all bits to ones is trivial with the VPTERNLOGx instruction and the 0xFF immediate. This immediate is a truth table that ignores all inputs and sets every output bit to 1. However, since there are no ternary logic instructions for 8-bit and 16-bit elements (and thus no byte- or word-granular predicate for them), we have to get a little more creative.
Our approach needs an additional vector register with all bits set to ones. Such a register can be used as an input to implement both predicated fill and predicated bitwise negation:
- Fills can be implemented with unsigned saturating addition, which is supported for 8-bit and 16-bit elements.
- Bitwise negation can be implemented as a subtraction from ~0.
First, let's introduce the all-ones constant:
static inline __m512i avx512_ones() {
__m512i u = _mm512_undefined_epi32();
return _mm512_ternarylogic_epi32(u, u, u, 0xFF);
}
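It doesn't matter that avx512_ones() feeds undefined data into VPTERNLOGD. For every bit position, the three input bits only form a 3-bit index into the 8-bit immediate, and all eight entries of 0xFF are 1. The portable scalar model below shows this for one 32-bit lane. It assumes the first intrinsic argument provides the most significant index bit, the convention under which 0xF0, 0xCC and 0xAA select a, b and c respectively. The helper name model_ternarylogic_u32 is hypothetical.

```c
#include <assert.h>
#include <stdint.h>

/* Scalar model (not real AVX-512 code) of one 32-bit lane of VPTERNLOGD,
 * i.e. _mm512_ternarylogic_epi32(a, b, c, imm). For every bit position the
 * three input bits form an index (a << 2) | (b << 1) | c that selects one bit
 * of the 8-bit truth table `imm`. */
static uint32_t model_ternarylogic_u32(uint32_t a, uint32_t b, uint32_t c,
                                       uint8_t imm) {
    uint32_t r = 0;
    for (int bit = 0; bit < 32; bit++) {
        unsigned idx = (((a >> bit) & 1u) << 2) |
                       (((b >> bit) & 1u) << 1) |
                       ((c >> bit) & 1u);
        r |= (uint32_t)((imm >> idx) & 1u) << bit;
    }
    return r;
}
```

With imm = 0xFF every index maps to 1, so the result is 0xFFFFFFFF for any inputs. With 0x96, for comparison, the result is a ^ b ^ c.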
This constant is used in several places. To make sure that the C/C++ compiler doesn't emit this sequence multiple times, it may be better to pass it to all functions as a parameter. This is not done here to keep things simple.
The following functions fill 8-bit and 16-bit elements with ones by using unsigned saturating addition:
// AVX-512: Fill 8-bit elements depending on the predicate `k`
static inline __m512i avx512_mask_fill_epi8(__m512i x, __mmask64 k) {
return _mm512_mask_adds_epu8(x, k, x, avx512_ones());
}
// AVX-512: Fill 16-bit elements depending on the predicate `k`
static inline __m512i avx512_mask_fill_epi16(__m512i x, __mmask32 k) {
return _mm512_mask_adds_epu16(x, k, x, avx512_ones());
}
Alternatively, instead of saturating addition it's possible to use the unsigned maximum, which could be cheaper on some hardware:
// AVX-512: Fill 8-bit elements depending on the predicate `k`
static inline __m512i avx512_mask_fill_epi8(__m512i x, __mmask64 k) {
return _mm512_mask_max_epu8(x, k, x, avx512_ones());
}
// AVX-512: Fill 16-bit elements depending on the predicate `k`
static inline __m512i avx512_mask_fill_epi16(__m512i x, __mmask32 k) {
return _mm512_mask_max_epu16(x, k, x, avx512_ones());
}
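Both fill variants work for the same reason: all ones is the largest unsigned value of the lane. Adding anything to it saturates at all ones, and the maximum of any value with it is all ones. The portable scalar sketch below models one lane of each instruction and checks both identities exhaustively for every 8-bit and 16-bit value. The model_* and fill_identities_hold names are hypothetical.

```c
#include <assert.h>
#include <stdint.h>

/* Scalar models (not real AVX-512 code) of one lane of VPADDUSB and VPMAXUB. */
static uint8_t model_adds_epu8(uint8_t a, uint8_t b) {
    unsigned sum = (unsigned)a + b;                /* cannot overflow here */
    return (uint8_t)(sum > 0xFFu ? 0xFFu : sum);   /* clamp instead of wrap */
}

static uint8_t model_max_epu8(uint8_t a, uint8_t b) {
    return a > b ? a : b;
}

/* 16-bit counterparts, modeling VPADDUSW and VPMAXUW. */
static uint16_t model_adds_epu16(uint16_t a, uint16_t b) {
    uint32_t sum = (uint32_t)a + b;
    return (uint16_t)(sum > 0xFFFFu ? 0xFFFFu : sum);
}

static uint16_t model_max_epu16(uint16_t a, uint16_t b) {
    return a > b ? a : b;
}

/* Returns 1 when saturating addition and unsigned maximum with all ones
 * produce all ones for every possible 8-bit and 16-bit lane value. */
static int fill_identities_hold(void) {
    for (uint32_t v = 0; v <= 0xFFu; v++) {
        if (model_adds_epu8((uint8_t)v, 0xFFu) != 0xFFu) return 0;
        if (model_max_epu8((uint8_t)v, 0xFFu) != 0xFFu) return 0;
    }
    for (uint32_t v = 0; v <= 0xFFFFu; v++) {
        if (model_adds_epu16((uint16_t)v, 0xFFFFu) != 0xFFFFu) return 0;
        if (model_max_epu16((uint16_t)v, 0xFFFFu) != 0xFFFFu) return 0;
    }
    return 1;
}
```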
The same trick can be used to implement bitwise negation, as ~x == ~0 - x:
// AVX-512: Complement (bitwise NOT) 8-bit elements depending on the predicate `k`
static inline __m512i avx512_mask_not_epi8(__m512i x, __mmask64 k) {
return _mm512_mask_sub_epi8(x, k, avx512_ones(), x);
}
// AVX-512: Complement (bitwise NOT) 16-bit elements depending on the predicate `k`
static inline __m512i avx512_mask_not_epi16(__m512i x, __mmask32 k) {
return _mm512_mask_sub_epi16(x, k, avx512_ones(), x);
}
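Why does ~0 - x equal ~x? Every bit of the minuend is 1, so subtracting x never needs a borrow, and each result bit is simply 1 - x_bit. The portable scalar sketch below models one lane of VPSUBB and VPSUBW with the all-ones minuend. It checks the identity for every 8-bit and 16-bit value. The helper names are hypothetical.

```c
#include <assert.h>
#include <stdint.h>

/* Scalar model (not real AVX-512 code) of one lane of VPSUBB/VPSUBW with
 * the all-ones constant as the minuend: no bit ever borrows, so the
 * difference equals the bitwise complement. */
static uint8_t model_not_u8(uint8_t x)    { return (uint8_t)(0xFFu - x); }
static uint16_t model_not_u16(uint16_t x) { return (uint16_t)(0xFFFFu - x); }

/* Returns 1 when ~0 - x == ~x holds for every 8-bit and 16-bit value. */
static int not_identity_holds(void) {
    for (uint32_t v = 0; v <= 0xFFFFu; v++) {
        if (v <= 0xFFu && model_not_u8((uint8_t)v) != (uint8_t)~v) return 0;
        if (model_not_u16((uint16_t)v) != (uint16_t)~v) return 0;
    }
    return 1;
}
```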
With the implementations presented above and a constant mask, it's possible to go further. A single instruction can set particular bytes of a vector to zero, set them to all ones, or keep them untouched. This requires two things:
- a register that holds 0xFF in every byte to be filled and 0x00 in every other byte;
- a predicate k used with zero-masking ({k}{z}), so that bytes whose mask bit is zero are cleared.

See below:
// AVX-512: Example of an approach to set bytes to all ones, all zeros, or to keep them untouched.
__m512i example(__m512i x) {
// Vector of bytes to be filled with ones; bytes that should be kept must be zero.
__m512i ones = _mm512_set1_epi32((int)0xFF000000);
// Mask of bytes to be processed (one bit per byte, all 64 bits are used);
// zero bits indicate bytes in the input to be zeroed.
__mmask64 k = 0xAAAAAAAAAAAAAAAAull;
// A single instruction (VPMAXUB with zero-masking) to fill, clear, or keep bytes
// untouched depending on `ones` and `k`.
return _mm512_maskz_max_epu8(k, x, ones);
}
In the example above, the most significant byte of every 32-bit element of x is set to all ones. The first and third bytes, counting from the least significant one, are cleared. The operation can be written as (x | 0xFF000000) & 0xFF00FF00, yet it's performed with a single instruction. Of course, any 64-byte pattern is possible, and bits can be complemented instead of filled; it all depends on the use case.
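The claim that one zero-masked VPMAXUB equals (x | 0xFF000000) & 0xFF00FF00 can be checked with a portable scalar model of the 64 byte lanes. The sketch below assumes little-endian lane order within each 32-bit element, as on x86. The "ones" vector then has 0xFF in byte 3 of every element. It also assumes the mask has only its odd bits set, so bytes 1 and 3 are processed and bytes 0 and 2 are zeroed. All helper names are hypothetical.

```c
#include <assert.h>
#include <stdint.h>

/* Scalar model (not real AVX-512 code) of _mm512_maskz_max_epu8(k, a, b):
 * byte i becomes max(a[i], b[i]) when bit i of k is set, otherwise zero. */
static void model_maskz_max_epu8(uint8_t dst[64], uint64_t k,
                                 const uint8_t a[64], const uint8_t b[64]) {
    for (int i = 0; i < 64; i++) {
        uint8_t m = a[i] > b[i] ? a[i] : b[i];
        dst[i] = ((k >> i) & 1) ? m : 0;
    }
}

/* Model of example(x): per 32-bit element, byte 3 is filled, byte 1 is
 * kept, and bytes 0 and 2 are cleared. */
static void model_example(uint8_t dst[64], const uint8_t x[64]) {
    uint8_t ones[64];
    for (int i = 0; i < 64; i++)
        ones[i] = (i % 4 == 3) ? 0xFF : 0x00;  /* _mm512_set1_epi32(0xFF000000) */
    const uint64_t k = 0xAAAAAAAAAAAAAAAAull;  /* odd bytes processed */
    model_maskz_max_epu8(dst, k, x, ones);
}

/* Assembles a 32-bit element from 4 bytes, least significant first, so the
 * check does not depend on the host byte order. */
static uint32_t load_u32_le(const uint8_t *p) {
    return (uint32_t)p[0] | ((uint32_t)p[1] << 8) |
           ((uint32_t)p[2] << 16) | ((uint32_t)p[3] << 24);
}

/* Reference formula from the text, applied to one 32-bit element. */
static uint32_t reference_element(const uint8_t *p) {
    return (load_u32_le(p) | 0xFF000000u) & 0xFF00FF00u;
}
```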
If the constant -1 is already in a register, it can be used to implement the following operations (this works for integers of all sizes):
- x - (-1) == x + 1
- x + (-1) == x - 1
- x >= 0 == x > -1 (signed comparison)
- x < 0 == x <= -1 (signed comparison)
- (x + (0x1 << ELEMENT_SIZE)) >> 1 with the VPAVGx instruction, where ELEMENT_SIZE is the element size in bits and the sum is computed without overflow. This can also be written as (x >> 1) | 0x80 with VPAVGB and (x >> 1) | 0x8000 with VPAVGW.

Additionally, with zero instead of -1, VPAVGx can be used as a rounding shift right by 1 bit, although I have never used this instruction for such a purpose.
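VPAVGB computes the unsigned average (a + b + 1) >> 1 with a wider intermediate. With b = -1 (0xFF) the sum is x + 256, which gives (x >> 1) | 0x80. With b = 0 it gives the rounding shift (x + 1) >> 1. The portable scalar sketch below models one VPAVGB and one VPAVGW lane. It checks these identities, and the wrapping -1 additions, for every 8-bit value. The helper names are hypothetical.

```c
#include <assert.h>
#include <stdint.h>

/* Scalar model (not real SIMD code) of one lane of VPAVGB: the unsigned
 * average (a + b + 1) >> 1, computed with a wider intermediate so the ninth
 * bit of the sum is not lost. */
static uint8_t model_avg_epu8(uint8_t a, uint8_t b) {
    return (uint8_t)(((unsigned)a + b + 1u) >> 1);
}

/* Same for one lane of VPAVGW. */
static uint16_t model_avg_epu16(uint16_t a, uint16_t b) {
    return (uint16_t)(((uint32_t)a + b + 1u) >> 1);
}

/* Checks, for every 8-bit value x:
 *   x - (-1) == x + 1 and x + (-1) == x - 1 (wrapping lane arithmetic),
 *   avg(x, -1) == (x >> 1) | 0x80,
 *   avg(x, 0)  == (x + 1) >> 1 (rounding shift right by one). */
static int avg_identities_hold(void) {
    for (unsigned v = 0; v <= 0xFFu; v++) {
        uint8_t x = (uint8_t)v;
        uint8_t m1 = 0xFF;  /* -1 as an 8-bit lane */
        if ((uint8_t)(x - m1) != (uint8_t)(x + 1)) return 0;
        if ((uint8_t)(x + m1) != (uint8_t)(x - 1)) return 0;
        if (model_avg_epu8(x, m1) != (uint8_t)((x >> 1) | 0x80)) return 0;
        if (model_avg_epu8(x, 0) != (uint8_t)((v + 1u) >> 1)) return 0;
    }
    return 1;
}
```

The rounding shift also works at the top of the range: for x = 255, avg(x, 0) gives 128 because the intermediate sum does not wrap.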
This article only aimed to use byte/word instructions to implement some special cases of bitwise operations that are often used in SIMD programming. If you need general predicated bitwise operations on 8-bit and 16-bit elements, the best way is to use the 32-bit or 64-bit versions without a predicate. Then merge the result with a predicated VMOVDQU8 or VMOVDQU16.
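This two-step approach works because bitwise operations never move bits between lanes, so the lane width of the unpredicated operation doesn't matter. Only the byte-granular mask of the final move decides which bytes change. With intrinsics, this would presumably look like _mm512_mask_mov_epi8(src, k, _mm512_and_si512(a, b)). The sketch below is a portable scalar model of the same two steps, not real AVX-512 code, and its helper name is hypothetical.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

/* Scalar model of a predicated 8-bit AND built from two steps: an
 * unpredicated AND on 64-bit lanes (as VPANDQ would do), followed by a
 * byte-granular merge under mask k (as a merge-masked VMOVDQU8 would do). */
static void model_mask_and_epi8(uint8_t dst[64], const uint8_t src[64],
                                uint64_t k, const uint8_t a[64],
                                const uint8_t b[64]) {
    uint8_t tmp[64];
    for (int lane = 0; lane < 8; lane++) {
        uint64_t va, vb, vr;
        memcpy(&va, a + 8 * lane, 8);
        memcpy(&vb, b + 8 * lane, 8);
        vr = va & vb;                      /* unpredicated 64-bit AND */
        memcpy(tmp + 8 * lane, &vr, 8);
    }
    for (int i = 0; i < 64; i++)           /* byte-granular merge */
        dst[i] = ((k >> i) & 1) ? tmp[i] : src[i];
}
```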