Pixie
Loading...
Searching...
No Matches
excess.h
1#pragma once
2
3#include <pixie/bits.h>
4
5#include <bit>
6#include <cstddef>
7#include <cstdint>
8
9namespace pixie::experimental {
10
11#ifdef PIXIE_AVX2_SUPPORT
12// clang-format off
// Per-nibble match tables used by excess_positions_512_branching_lut.
//
// Every table is indexed by a 4-bit nibble value whose bits are consumed from
// bit 0 upwards; a set bit moves the running excess by +1, a clear bit by -1.
// Bit j of entry `excess_branch_lut_eX[n]` is set when the running excess
// after bit j of nibble n equals X ("em4" = -4 ... "e4" = +4).
//
// _mm256_shuffle_epi8 looks entries up inside each 128-bit lane separately,
// so the same 16 entries are broadcast into both lanes.
static inline const __m256i excess_branch_lut_em4 =
    _mm256_broadcastsi128_si256(_mm_setr_epi8(
        0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00));

static inline const __m256i excess_branch_lut_em3 =
    _mm256_broadcastsi128_si256(_mm_setr_epi8(
        0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00));

static inline const __m256i excess_branch_lut_em2 =
    _mm256_broadcastsi128_si256(_mm_setr_epi8(
        0x02, 0x08, 0x08, 0x00, 0x0A, 0x00, 0x00, 0x00,
        0x0A, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00));

static inline const __m256i excess_branch_lut_em1 =
    _mm256_broadcastsi128_si256(_mm_setr_epi8(
        0x01, 0x04, 0x05, 0x00, 0x05, 0x00, 0x01, 0x00,
        0x01, 0x04, 0x05, 0x00, 0x05, 0x00, 0x01, 0x00));

static inline const __m256i excess_branch_lut_e0 =
    _mm256_broadcastsi128_si256(_mm_setr_epi8(
        0x00, 0x02, 0x02, 0x08, 0x00, 0x0A, 0x0A, 0x00,
        0x00, 0x0A, 0x0A, 0x00, 0x08, 0x02, 0x02, 0x00));

static inline const __m256i excess_branch_lut_e1 =
    _mm256_broadcastsi128_si256(_mm_setr_epi8(
        0x00, 0x01, 0x00, 0x05, 0x00, 0x05, 0x04, 0x01,
        0x00, 0x01, 0x00, 0x05, 0x00, 0x05, 0x04, 0x01));

static inline const __m256i excess_branch_lut_e2 =
    _mm256_broadcastsi128_si256(_mm_setr_epi8(
        0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x0A,
        0x00, 0x00, 0x00, 0x0A, 0x00, 0x08, 0x08, 0x02));

static inline const __m256i excess_branch_lut_e3 =
    _mm256_broadcastsi128_si256(_mm_setr_epi8(
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04));

static inline const __m256i excess_branch_lut_e4 =
    _mm256_broadcastsi128_si256(_mm_setr_epi8(
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08));
66// clang-format on
67
// Returns 16 lanes of 16 bits where lane j holds the single bit (1 << j).
// AND-ing a broadcast 16-bit value with this vector isolates bit j in lane j.
static inline __m256i excess_bit_masks_16x() noexcept {
  return _mm256_setr_epi16(
      1 << 0, 1 << 1, 1 << 2, 1 << 3, 1 << 4, 1 << 5, 1 << 6, 1 << 7,
      1 << 8, 1 << 9, 1 << 10, 1 << 11, 1 << 12, 1 << 13, 1 << 14,
      static_cast<int16_t>(0x8000));
}
73
// Inclusive prefix sum over the 16 signed 16-bit lanes of `v`
// (lane j of the result is v[0] + ... + v[j], with 16-bit wrap-around).
static inline __m256i excess_prefix_sum_16x_i16(__m256i v) noexcept {
  // _mm256_slli_si256 shifts each 128-bit half on its own, so these three
  // shift-and-add steps produce a prefix sum inside each half of 8 lanes.
  __m256i sums = _mm256_add_epi16(v, _mm256_slli_si256(v, 2));
  sums = _mm256_add_epi16(sums, _mm256_slli_si256(sums, 4));
  sums = _mm256_add_epi16(sums, _mm256_slli_si256(sums, 8));

  // Carry the total of the low half into every lane of the high half.
  const __m128i low_half = _mm256_castsi256_si128(sums);
  const __m128i high_half = _mm256_extracti128_si256(sums, 1);
  const int16_t low_total =
      static_cast<int16_t>(_mm_extract_epi16(low_half, 7));
  const __m128i high_fixed =
      _mm_add_epi16(high_half, _mm_set1_epi16(low_total));

  return _mm256_inserti128_si256(sums, high_fixed, 1);
}
92
// Returns the last (16th) signed 16-bit lane of `pref`, i.e. the total when
// `pref` is an inclusive prefix sum.
static inline int16_t excess_last_prefix_16x_i16(__m256i pref) noexcept {
  const __m128i upper_half = _mm256_extracti128_si256(pref, 1);
  const int last_lane = _mm_extract_epi16(upper_half, 7);
  return static_cast<int16_t>(last_lane);
}
97
// Returns 32 byte lanes where lane j holds the single bit (1 << (j % 8)),
// i.e. the byte pattern 01 02 04 08 10 20 40 80 repeated four times.
static inline __m256i excess_bit_masks_32x8() noexcept {
  // Little-endian byte order of this 64-bit constant is 01,02,...,80.
  constexpr unsigned long long kBitPerByte = 0x8040201008040201ULL;
  return _mm256_set1_epi64x(static_cast<long long>(kBitPerByte));
}
104
// Returns a byte-shuffle control where lanes 8*i .. 8*i+7 all hold index i
// (i = 0..3), so that source byte i is replicated into eight adjacent lanes.
static inline __m256i excess_byte_selectors_32x8() noexcept {
  return _mm256_setr_epi64x(0x0000000000000000LL, 0x0101010101010101LL,
                            0x0202020202020202LL, 0x0303030303030303LL);
}
109
// Inclusive prefix sum over the 32 signed 8-bit lanes of `v`
// (lane j of the result is v[0] + ... + v[j], with 8-bit wrap-around).
static inline __m256i excess_prefix_sum_32x_i8(__m256i v) noexcept {
  // Byte shifts work per 128-bit half, so these four shift-and-add steps
  // yield a prefix sum inside each half of 16 lanes.
  __m256i sums = _mm256_add_epi8(v, _mm256_slli_si256(v, 1));
  sums = _mm256_add_epi8(sums, _mm256_slli_si256(sums, 2));
  sums = _mm256_add_epi8(sums, _mm256_slli_si256(sums, 4));
  sums = _mm256_add_epi8(sums, _mm256_slli_si256(sums, 8));

  // Carry the total of the low half into every lane of the high half.
  const __m128i low_half = _mm256_castsi256_si128(sums);
  const __m128i high_half = _mm256_extracti128_si256(sums, 1);
  const int8_t low_total = static_cast<int8_t>(_mm_extract_epi8(low_half, 15));
  const __m128i high_fixed = _mm_add_epi8(high_half, _mm_set1_epi8(low_total));

  return _mm256_inserti128_si256(sums, high_fixed, 1);
}
130
// Returns the last (32nd) signed 8-bit lane of `pref`, i.e. the total when
// `pref` is an inclusive prefix sum.
static inline int8_t excess_last_prefix_32x_i8(__m256i pref) noexcept {
  const __m128i upper_half = _mm256_extracti128_si256(pref, 1);
  const int last_lane = _mm_extract_epi8(upper_half, 15);
  return static_cast<int8_t>(last_lane);
}
135
136static inline void excess_positions_512_branching_lut(const uint64_t* s,
137 int target_x,
138 uint64_t* out) noexcept {
139 out[0] = out[1] = out[2] = out[3] = 0;
140 out[4] = out[5] = out[6] = out[7] = 0;
141
142 if (target_x < -512 || target_x > 512) {
143 return;
144 }
145
146 int cur = 0;
147 const __m256i vdelta =
148 _mm256_setr_epi8(-4, -2, -2, 0, -2, 0, 0, 2, -2, 0, 0, 2, 0, 2, 2, 4, -4,
149 -2, -2, 0, -2, 0, 0, 2, -2, 0, 0, 2, 0, 2, 2, 4);
150 const __m256i vmult = _mm256_set1_epi16(0x1001);
151 const __m128i vnibble_mask = _mm_set1_epi8(0x0F);
152
153 for (int k = 0; k < 4; ++k) {
154 __m128i word_vec = _mm_loadu_si128((const __m128i*)&s[2 * k]);
155 __m128i lo_nibbles = _mm_and_si128(word_vec, vnibble_mask);
156 __m128i hi_nibbles =
157 _mm_and_si128(_mm_srli_epi16(word_vec, 4), vnibble_mask);
158
159 __m128i unpack_lo = _mm_unpacklo_epi8(lo_nibbles, hi_nibbles);
160 __m128i unpack_hi = _mm_unpackhi_epi8(lo_nibbles, hi_nibbles);
161 __m256i nibbles = _mm256_inserti128_si256(_mm256_castsi128_si256(unpack_lo),
162 unpack_hi, 1);
163
164 __m256i ps = _mm256_shuffle_epi8(vdelta, nibbles);
165 ps = _mm256_add_epi8(ps, _mm256_slli_si256(ps, 1));
166 ps = _mm256_add_epi8(ps, _mm256_slli_si256(ps, 2));
167 ps = _mm256_add_epi8(ps, _mm256_slli_si256(ps, 4));
168 ps = _mm256_add_epi8(ps, _mm256_slli_si256(ps, 8));
169
170 __m128i ps_lo = _mm256_castsi256_si128(ps);
171 __m128i ps_hi = _mm256_extracti128_si256(ps, 1);
172 __m128i carry = _mm_set1_epi8((int8_t)_mm_extract_epi8(ps_lo, 15));
173 ps_hi = _mm_add_epi8(ps_hi, carry);
174 ps = _mm256_inserti128_si256(_mm256_castsi128_si256(ps_lo), ps_hi, 1);
175
176 __m256i b = _mm256_permute2x128_si256(ps, ps, 0x08);
177 __m256i excl_ps = _mm256_alignr_epi8(ps, b, 15);
178
179 int target_rel = target_x - cur;
180 int block_delta =
181 2 * (std::popcount(s[2 * k]) + std::popcount(s[2 * k + 1])) - 128;
182
183 const int d = 2 * target_rel - block_delta;
184 if (d < -128 || d > 128) {
185 cur += block_delta;
186 continue;
187 }
188
189 if (target_rel == 128 || target_rel == -128) {
190 out[2 * k + 1] |= (uint64_t{1} << 63);
191 cur += block_delta;
192 continue;
193 }
194
195 __m256i t = _mm256_sub_epi8(_mm256_set1_epi8((int8_t)target_rel), excl_ps);
196 __m256i total_match = _mm256_setzero_si256();
197 __m256i t_eq = _mm256_cmpeq_epi8(t, _mm256_set1_epi8(-4));
198 total_match = _mm256_or_si256(
199 total_match,
200 _mm256_and_si256(t_eq,
201 _mm256_shuffle_epi8(excess_branch_lut_em4, nibbles)));
202 t_eq = _mm256_cmpeq_epi8(t, _mm256_set1_epi8(-3));
203 total_match = _mm256_or_si256(
204 total_match,
205 _mm256_and_si256(t_eq,
206 _mm256_shuffle_epi8(excess_branch_lut_em3, nibbles)));
207 t_eq = _mm256_cmpeq_epi8(t, _mm256_set1_epi8(-2));
208 total_match = _mm256_or_si256(
209 total_match,
210 _mm256_and_si256(t_eq,
211 _mm256_shuffle_epi8(excess_branch_lut_em2, nibbles)));
212 t_eq = _mm256_cmpeq_epi8(t, _mm256_set1_epi8(-1));
213 total_match = _mm256_or_si256(
214 total_match,
215 _mm256_and_si256(t_eq,
216 _mm256_shuffle_epi8(excess_branch_lut_em1, nibbles)));
217 t_eq = _mm256_cmpeq_epi8(t, _mm256_set1_epi8(0));
218 total_match = _mm256_or_si256(
219 total_match,
220 _mm256_and_si256(t_eq,
221 _mm256_shuffle_epi8(excess_branch_lut_e0, nibbles)));
222 t_eq = _mm256_cmpeq_epi8(t, _mm256_set1_epi8(1));
223 total_match = _mm256_or_si256(
224 total_match,
225 _mm256_and_si256(t_eq,
226 _mm256_shuffle_epi8(excess_branch_lut_e1, nibbles)));
227 t_eq = _mm256_cmpeq_epi8(t, _mm256_set1_epi8(2));
228 total_match = _mm256_or_si256(
229 total_match,
230 _mm256_and_si256(t_eq,
231 _mm256_shuffle_epi8(excess_branch_lut_e2, nibbles)));
232 t_eq = _mm256_cmpeq_epi8(t, _mm256_set1_epi8(3));
233 total_match = _mm256_or_si256(
234 total_match,
235 _mm256_and_si256(t_eq,
236 _mm256_shuffle_epi8(excess_branch_lut_e3, nibbles)));
237 t_eq = _mm256_cmpeq_epi8(t, _mm256_set1_epi8(4));
238 total_match = _mm256_or_si256(
239 total_match,
240 _mm256_and_si256(t_eq,
241 _mm256_shuffle_epi8(excess_branch_lut_e4, nibbles)));
242
243 __m256i res = _mm256_maddubs_epi16(total_match, vmult);
244 __m128i packed = _mm_packus_epi16(_mm256_castsi256_si128(res),
245 _mm256_extracti128_si256(res, 1));
246 _mm_storeu_si128((__m128i*)&out[2 * k], packed);
247
248 cur += block_delta;
249 }
250}
251#else
252static inline void excess_positions_512_branching_lut(const uint64_t* s,
253 int target_x,
254 uint64_t* out) noexcept {
255 excess_positions_512(s, target_x, out);
256}
257#endif
258
259#ifdef PIXIE_AVX512_SUPPORT
// Nibble-indexed table of the net excess of a whole nibble
// (2 * popcount(nibble) - 4), repeated in all four 128-bit lanes.
static inline __m512i excess_lut_delta_64x() noexcept {
  const __m128i per_nibble =
      _mm_setr_epi8(-4, -2, -2, 0, -2, 0, 0, 2, -2, 0, 0, 2, 0, 2, 2, 4);
  return _mm512_broadcast_i32x4(per_nibble);
}
264
// Nibble-indexed table of the excess after the first bit of a nibble
// (+1 when bit 0 is set, -1 otherwise), repeated in all four 128-bit lanes.
static inline __m512i excess_lut_pos0_64x() noexcept {
  const __m128i per_nibble =
      _mm_setr_epi8(-1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1);
  return _mm512_broadcast_i32x4(per_nibble);
}
269
// Nibble-indexed table of the excess after the first two bits of a nibble,
// repeated in all four 128-bit lanes.
static inline __m512i excess_lut_pos1_64x() noexcept {
  const __m128i per_nibble =
      _mm_setr_epi8(-2, 0, 0, 2, -2, 0, 0, 2, -2, 0, 0, 2, -2, 0, 0, 2);
  return _mm512_broadcast_i32x4(per_nibble);
}
274
// Nibble-indexed table of the excess after the first three bits of a nibble,
// repeated in all four 128-bit lanes.
static inline __m512i excess_lut_pos2_64x() noexcept {
  const __m128i per_nibble =
      _mm_setr_epi8(-3, -1, -1, 1, -1, 1, 1, 3, -3, -1, -1, 1, -1, 1, 1, 3);
  return _mm512_broadcast_i32x4(per_nibble);
}
279
// Returns 64 byte lanes where lane j holds the single bit (1 << (j % 8)),
// i.e. the byte pattern 01 02 04 08 10 20 40 80 repeated eight times.
static inline __m512i excess_bit_masks_64x8() noexcept {
  // Little-endian byte order of this 64-bit constant is 01,02,...,80.
  constexpr unsigned long long kBitPerByte = 0x8040201008040201ULL;
  return _mm512_set1_epi64(static_cast<long long>(kBitPerByte));
}
292
// Returns a byte-shuffle control where lanes 8*i .. 8*i+7 all hold index i
// (i = 0..7), so that source byte i is replicated into eight adjacent lanes.
static inline __m512i excess_byte_selectors_64x8() noexcept {
  return _mm512_setr_epi64(0x0000000000000000LL, 0x0101010101010101LL,
                           0x0202020202020202LL, 0x0303030303030303LL,
                           0x0404040404040404LL, 0x0505050505050505LL,
                           0x0606060606060606LL, 0x0707070707070707LL);
}
300
// Inclusive prefix sum over all 64 signed 8-bit lanes of `v`
// (lane j of the result is v[0] + ... + v[j], with 8-bit wrap-around).
static inline __m512i excess_prefix_sum_64x_i8(__m512i v) noexcept {
  // Step 1: prefix sum inside each 128-bit lane (byte shifts are per lane).
  __m512i sums = _mm512_add_epi8(v, _mm512_bslli_epi128(v, 1));
  sums = _mm512_add_epi8(sums, _mm512_bslli_epi128(sums, 2));
  sums = _mm512_add_epi8(sums, _mm512_bslli_epi128(sums, 4));
  sums = _mm512_add_epi8(sums, _mm512_bslli_epi128(sums, 8));

  // Step 2: broadcast each lane's total (its byte 15) across that lane.
  const __m512i lane_totals = _mm512_shuffle_epi8(sums, _mm512_set1_epi8(15));

  // Step 3: move the totals up by one, two and three lanes (zero-filled) so
  // that every lane can pick up the totals of all lanes below it.
  const __m512i up_one_lane = _mm512_maskz_permutexvar_epi32(
      0xFFF0,
      _mm512_setr_epi32(0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11),
      lane_totals);
  const __m512i up_two_lanes = _mm512_maskz_permutexvar_epi32(
      0xFF00,
      _mm512_setr_epi32(0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7),
      lane_totals);
  const __m512i up_three_lanes = _mm512_maskz_permutexvar_epi32(
      0xF000,
      _mm512_setr_epi32(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3),
      lane_totals);

  const __m512i lane_offsets = _mm512_add_epi8(
      _mm512_add_epi8(up_one_lane, up_two_lanes), up_three_lanes);
  return _mm512_add_epi8(sums, lane_offsets);
}
329
// Two independent inclusive prefix sums of 32 signed 8-bit lanes each: one
// over the low 256 bits of `v` and one over the high 256 bits.
static inline __m512i excess_prefix_sum_2x32_i8(__m512i v) noexcept {
  // Prefix sum inside each 128-bit lane (byte shifts are per lane).
  __m512i sums = _mm512_add_epi8(v, _mm512_bslli_epi128(v, 1));
  sums = _mm512_add_epi8(sums, _mm512_bslli_epi128(sums, 2));
  sums = _mm512_add_epi8(sums, _mm512_bslli_epi128(sums, 4));
  sums = _mm512_add_epi8(sums, _mm512_bslli_epi128(sums, 8));

  // Each lane's total (its byte 15) broadcast across that lane.
  const __m512i lane_totals = _mm512_shuffle_epi8(sums, _mm512_set1_epi8(15));

  // Lanes 1 and 3 (mask 0xF0F0) receive the total of lanes 0 and 2
  // respectively; lanes 0 and 2 receive nothing.
  const __m512i first_lane_of_half =
      _mm512_setr_epi32(0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 8, 8, 8, 8, 8, 8);
  const __mmask16 second_lane_of_each_half = 0xF0F0;
  const __m512i lane_offsets = _mm512_maskz_permutexvar_epi32(
      second_lane_of_each_half, first_lane_of_half, lane_totals);

  return _mm512_add_epi8(sums, lane_offsets);
}
350
// Expands the 32 bytes of `words` into 64 byte lanes holding one nibble each:
// lane 2*i is the low nibble of byte i, lane 2*i+1 is its high nibble.
static inline __m512i excess_nibbles_64x_from_256(__m256i words) noexcept {
  // Each byte b becomes the 16-bit value b.
  const __m512i widened = _mm512_cvtepu8_epi16(words);
  // b | (b << 4) places the high nibble of b in bits 8..11 while bits 0..3
  // still hold the low nibble; the 0x0F0F mask drops everything else.
  const __m512i spread =
      _mm512_or_si512(widened, _mm512_slli_epi16(widened, 4));
  return _mm512_and_si512(spread, _mm512_set1_epi16(0x0F0F));
}
358
// Converts two inclusive 32-lane prefix sums (low and high 256 bits of
// `pref`) into exclusive ones: every lane receives the value of the lane
// before it, and the first lane of each 256-bit half becomes 0.
static inline __m512i excess_exclusive_prefix_2x32_i8(__m512i pref) noexcept {
  // Shift by one byte inside each 128-bit lane; byte 0 of every lane is 0.
  const __m512i shifted = _mm512_bslli_epi128(pref, 1);

  // Byte 15 of each lane broadcast across the lane, then the value of lane 0
  // copied over the whole low half and the value of lane 2 over the high half.
  const __m512i lane_totals = _mm512_shuffle_epi8(pref, _mm512_set1_epi8(15));
  const __m512i first_lane_of_half =
      _mm512_setr_epi32(0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 8, 8, 8, 8, 8, 8);
  const __m512i carried =
      _mm512_permutexvar_epi32(first_lane_of_half, lane_totals);

  // Only the first byte of lanes 1 and 3 (bytes 16 and 48) lost its
  // predecessor to the per-lane shift; patch exactly those two bytes.
  const __mmask64 lane_seams = (uint64_t{1} << 16) | (uint64_t{1} << 48);
  return _mm512_mask_mov_epi8(shifted, lane_seams, carried);
}
375
// Replicates the low 8 bits of `value` (taken as a signed byte) into all
// eight bytes of a 64-bit word.
static inline uint64_t excess_repeat_byte(int value) noexcept {
  const uint8_t low_byte = static_cast<uint8_t>(static_cast<int8_t>(value));
  constexpr uint64_t kEveryByteOne = 0x0101010101010101ULL;
  return kEveryByteOne * low_byte;
}
380
381static inline void excess_positions_512_lut_avx512(const uint64_t* s,
382 int target_x,
383 uint64_t* out) noexcept {
384 out[0] = out[1] = out[2] = out[3] = 0;
385 out[4] = out[5] = out[6] = out[7] = 0;
386
387 if (target_x < -512 || target_x > 512) {
388 return;
389 }
390
391 static const __m512i vdelta = excess_lut_delta_64x();
392 static const __m512i vpos0 = excess_lut_pos0_64x();
393 static const __m512i vpos1 = excess_lut_pos1_64x();
394 static const __m512i vpos2 = excess_lut_pos2_64x();
395 static const __m512i vbit0 = _mm512_set1_epi8(1);
396 static const __m512i vbit1 = _mm512_set1_epi8(2);
397 static const __m512i vbit2 = _mm512_set1_epi8(4);
398 static const __m512i vbit3 = _mm512_set1_epi8(8);
399 static const __m512i vmult = _mm512_set1_epi16(0x1001);
400
401 for (int k = 0; k < 2; ++k) {
402 const int base_word = 4 * k;
403 const int delta0 =
404 2 * (std::popcount(s[base_word]) + std::popcount(s[base_word + 1])) -
405 128;
406 const int delta1 = 2 * (std::popcount(s[base_word + 2]) +
407 std::popcount(s[base_word + 3])) -
408 128;
409 const int target0 = target_x;
410 const int target1 = target_x - delta0;
411 const bool reachable0 = [&] {
412 const int d = 2 * target0 - delta0;
413 return -128 <= d && d <= 128;
414 }();
415 const bool reachable1 = [&] {
416 const int d = 2 * target1 - delta1;
417 return -128 <= d && d <= 128;
418 }();
419
420 if (!reachable0 && !reachable1) {
421 target_x -= delta0 + delta1;
422 continue;
423 }
424
425 const __m256i words =
426 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(&s[base_word]));
427 const __m512i nibbles = excess_nibbles_64x_from_256(words);
428 const __m512i ps =
429 excess_prefix_sum_2x32_i8(_mm512_shuffle_epi8(vdelta, nibbles));
430 const __m512i excl_ps = excess_exclusive_prefix_2x32_i8(ps);
431 const uint64_t repeated0 = excess_repeat_byte(target0);
432 const uint64_t repeated1 = excess_repeat_byte(target1);
433 const __m512i vtgt =
434 _mm512_setr_epi64(repeated0, repeated0, repeated0, repeated0, repeated1,
435 repeated1, repeated1, repeated1);
436 const __m512i t = _mm512_sub_epi8(vtgt, excl_ps);
437
438 const __mmask64 cmp0 =
439 _mm512_cmpeq_epi8_mask(_mm512_shuffle_epi8(vpos0, nibbles), t);
440 const __mmask64 cmp1 =
441 _mm512_cmpeq_epi8_mask(_mm512_shuffle_epi8(vpos1, nibbles), t);
442 const __mmask64 cmp2 =
443 _mm512_cmpeq_epi8_mask(_mm512_shuffle_epi8(vpos2, nibbles), t);
444 const __mmask64 cmp3 = _mm512_cmpeq_epi8_mask(ps, vtgt);
445 __m512i total_match = _mm512_maskz_mov_epi8(cmp0, vbit0);
446 total_match =
447 _mm512_or_si512(total_match, _mm512_maskz_mov_epi8(cmp1, vbit1));
448 total_match =
449 _mm512_or_si512(total_match, _mm512_maskz_mov_epi8(cmp2, vbit2));
450 total_match =
451 _mm512_or_si512(total_match, _mm512_maskz_mov_epi8(cmp3, vbit3));
452
453 const __mmask64 active =
454 (reachable0 ? __mmask64{0x00000000FFFFFFFFull} : __mmask64{0}) |
455 (reachable1 ? __mmask64{0xFFFFFFFF00000000ull} : __mmask64{0});
456 total_match = _mm512_maskz_mov_epi8(active, total_match);
457
458 const __m512i res = _mm512_maddubs_epi16(total_match, vmult);
459 const __m256i packed = _mm512_cvtepi16_epi8(res);
460 _mm256_storeu_si256(reinterpret_cast<__m256i*>(&out[base_word]), packed);
461
462 target_x -= delta0 + delta1;
463 }
464}
465#else
466static inline void excess_positions_512_lut_avx512(const uint64_t* s,
467 int target_x,
468 uint64_t* out) noexcept {
469 excess_positions_512(s, target_x, out);
470}
471#endif
472
// Marks every position of a 512-bit sequence whose running excess equals
// `target_x`.
//
// The sequence is read from s[0..7], bit 0 of each word first; a set bit adds
// +1 to the running excess and a clear bit adds -1. Bit i of the 512-bit
// result in out[0..7] is set when the excess after bit i equals `target_x`.
// `out` is always fully overwritten; it is all zero when |target_x| > 512.
//
// With PIXIE_AVX2_SUPPORT the input is expanded 16 bits at a time into
// 16-bit lanes; otherwise a plain bit-by-bit scan is used.
static inline void excess_positions_512_expand(const uint64_t* s,
                                               int target_x,
                                               uint64_t* out) noexcept {
  for (int w = 0; w < 8; ++w) {
    out[w] = 0;
  }

  if (target_x < -512 || target_x > 512) {
    return;
  }

#ifdef PIXIE_AVX2_SUPPORT
  static const __m256i lane_bit = excess_bit_masks_16x();
  static const __m256i zero = _mm256_setzero_si256();
  static const __m256i plus_one = _mm256_set1_epi16(1);
  static const __m256i minus_one = _mm256_set1_epi16(-1);
  const __m256i wanted = _mm256_set1_epi16(static_cast<int16_t>(target_x));

  int running = 0;
  for (int block = 0; block < 4; ++block) {
    // Far-away targets: check cheaply whether this 128-bit block can contain
    // the target at all and skip it when it cannot.
    const int target_rel = target_x - running;
    if (target_rel <= -64 || target_rel >= 64) {
      const int block_delta =
          2 * (std::popcount(s[2 * block]) + std::popcount(s[2 * block + 1])) -
          128;
      const int reach = 2 * target_rel - block_delta;
      if (reach < -128 || reach > 128) {
        running += block_delta;
        continue;
      }
    }

    const int first_chunk = 8 * block;
    for (int chunk = first_chunk; chunk < first_chunk + 8; ++chunk) {
      const size_t word_idx = static_cast<size_t>(chunk) / 4;
      const size_t shift = static_cast<size_t>(chunk % 4) * 16;
      const uint16_t bits16 = static_cast<uint16_t>(s[word_idx] >> shift);

      // Lane j gets +1 when bit j of the chunk is set, -1 otherwise.
      const __m256i isolated = _mm256_and_si256(
          _mm256_set1_epi16(static_cast<int16_t>(bits16)), lane_bit);
      const __m256i bit_clear = _mm256_cmpeq_epi16(isolated, zero);
      const __m256i steps = _mm256_blendv_epi8(plus_one, minus_one, bit_clear);

      const __m256i pref_rel = excess_prefix_sum_16x_i16(steps);
      const __m256i pref_abs = _mm256_add_epi16(
          pref_rel, _mm256_set1_epi16(static_cast<int16_t>(running)));

      // movemask yields two identical bits per 16-bit lane; keep one of them.
      const uint32_t byte_hits = static_cast<uint32_t>(
          _mm256_movemask_epi8(_mm256_cmpeq_epi16(pref_abs, wanted)));
      const uint16_t lane_hits =
          static_cast<uint16_t>(_pext_u32(byte_hits, 0xAAAAAAAAu));

      out[word_idx] |= static_cast<uint64_t>(lane_hits) << shift;
      running += static_cast<int>(excess_last_prefix_16x_i16(pref_rel));
    }
  }
#else
  int running = 0;
  for (int w = 0; w < 8; ++w) {
    const uint64_t word = s[w];
    uint64_t hits = 0;
    for (int bit = 0; bit < 64; ++bit) {
      if ((word >> bit) & 1ull) {
        ++running;
      } else {
        --running;
      }
      if (running == target_x) {
        hits |= uint64_t{1} << bit;
      }
    }
    out[w] |= hits;
  }
#endif
}
543
// Marks every position of a 512-bit sequence whose running excess equals
// `target_x`.
//
// The sequence is read from s[0..7], bit 0 of each word first; a set bit adds
// +1 to the running excess and a clear bit adds -1. Bit i of the 512-bit
// result in out[0..7] is set when the excess after bit i equals `target_x`.
// `out` is always fully overwritten; it is all zero when |target_x| > 512.
//
// With PIXIE_AVX2_SUPPORT the input is expanded 32 bits at a time into 8-bit
// lanes; otherwise a plain bit-by-bit scan is used.
static inline void excess_positions_512_expand8(const uint64_t* s,
                                                int target_x,
                                                uint64_t* out) noexcept {
  for (int w = 0; w < 8; ++w) {
    out[w] = 0;
  }

  if (target_x < -512 || target_x > 512) {
    return;
  }

#ifdef PIXIE_AVX2_SUPPORT
  static const __m256i byte_selectors = excess_byte_selectors_32x8();
  static const __m256i lane_bit = excess_bit_masks_32x8();
  static const __m256i zero = _mm256_setzero_si256();
  static const __m256i plus_one = _mm256_set1_epi8(1);
  static const __m256i minus_one = _mm256_set1_epi8(-1);

  int running = 0;
  for (int half = 0; half < 16; ++half) {
    const size_t word_idx = static_cast<size_t>(half) / 2;
    const size_t shift = static_cast<size_t>(half % 2) * 32;
    const uint32_t bits32 = static_cast<uint32_t>(s[word_idx] >> shift);

    // The excess moves by at most 32 inside this chunk, so a target further
    // away cannot be hit here; only account for the chunk's net excess.
    const int target_rel = target_x - running;
    if (target_rel < -32 || target_rel > 32) {
      running += 2 * static_cast<int>(std::popcount(bits32)) - 32;
      continue;
    }

    // Lane j gets +1 when bit j of the chunk is set, -1 otherwise.
    const __m256i spread = _mm256_shuffle_epi8(
        _mm256_set1_epi32(static_cast<int>(bits32)), byte_selectors);
    const __m256i bit_clear =
        _mm256_cmpeq_epi8(_mm256_and_si256(spread, lane_bit), zero);
    const __m256i steps = _mm256_blendv_epi8(plus_one, minus_one, bit_clear);

    const __m256i pref_rel = excess_prefix_sum_32x_i8(steps);
    const __m256i hit_lanes = _mm256_cmpeq_epi8(
        pref_rel, _mm256_set1_epi8(static_cast<int8_t>(target_rel)));
    const uint32_t hits = static_cast<uint32_t>(_mm256_movemask_epi8(hit_lanes));

    out[word_idx] |= static_cast<uint64_t>(hits) << shift;
    running += static_cast<int>(excess_last_prefix_32x_i8(pref_rel));
  }
#else
  int running = 0;
  for (int w = 0; w < 8; ++w) {
    const uint64_t word = s[w];
    uint64_t hits = 0;
    for (int bit = 0; bit < 64; ++bit) {
      if ((word >> bit) & 1ull) {
        ++running;
      } else {
        --running;
      }
      if (running == target_x) {
        hits |= uint64_t{1} << bit;
      }
    }
    out[w] |= hits;
  }
#endif
}
603
604static inline void excess_positions_512_expand_avx512(const uint64_t* s,
605 int target_x,
606 uint64_t* out) noexcept {
607 out[0] = out[1] = out[2] = out[3] = 0;
608 out[4] = out[5] = out[6] = out[7] = 0;
609
610 if (target_x < -512 || target_x > 512) {
611 return;
612 }
613
614#ifdef PIXIE_AVX512_SUPPORT
615 static const __m512i byte_selectors = excess_byte_selectors_64x8();
616 static const __m512i masks = excess_bit_masks_64x8();
617 static const __m512i vzero = _mm512_setzero_si512();
618 static const __m512i vallones = _mm512_set1_epi8(-1);
619 static const __m512i vminus1 = _mm512_set1_epi8(-1);
620 static const __m512i vtwo = _mm512_set1_epi8(2);
621
622 int cur = 0;
623 for (int k = 0; k < 8; ++k) {
624 const uint64_t bits64 = s[k];
625 const int target_rel = target_x - cur;
626 if (target_rel < -64 || target_rel > 64) {
627 cur += 2 * static_cast<int>(std::popcount(bits64)) - 64;
628 continue;
629 }
630
631 const __m512i src = _mm512_set1_epi64(static_cast<int64_t>(bits64));
632 const __m512i bytes = _mm512_shuffle_epi8(src, byte_selectors);
633 const __m512i m = _mm512_and_si512(bytes, masks);
634 const __mmask64 is_zero = _mm512_cmpeq_epi8_mask(m, vzero);
635 const __m512i is_set = _mm512_maskz_mov_epi8(~is_zero, vallones);
636 const __m512i steps =
637 _mm512_add_epi8(vminus1, _mm512_and_si512(is_set, vtwo));
638
639 const __m512i pref_rel = excess_prefix_sum_64x_i8(steps);
640 const __mmask64 match =
641 _mm512_cmpeq_epi8_mask(pref_rel, _mm512_set1_epi8((int8_t)target_rel));
642 out[k] = static_cast<uint64_t>(match);
643 cur += 2 * static_cast<int>(std::popcount(bits64)) - 64;
644 }
645#else
646 excess_positions_512_expand8(s, target_x, out);
647#endif
648}
649
// Compile-time lookup tables describing the running excess inside one byte.
//
// Bits are consumed from bit 0 upwards; a set bit adds +1 to the running
// excess and a clear bit adds -1.
//   masks[b][t + 8]  bit i is set when the excess after bit i of byte b is t
//                    (t ranges over -8..8, hence 17 columns).
//   deltas[b]        excess after all eight bits of byte b.
struct ExcessByteLut {
  uint8_t masks[256][17];  // target index: T + 8
  int8_t deltas[256];

  constexpr ExcessByteLut() : masks{}, deltas{} {
    for (int value = 0; value < 256; ++value) {
      // One walk over the byte fills every target column at once: the excess
      // reached after each bit selects the column that receives that bit.
      int running = 0;
      for (int bit = 0; bit < 8; ++bit) {
        running += ((value >> bit) & 1) ? 1 : -1;
        masks[value][running + 8] |= static_cast<uint8_t>(1u << bit);
      }
      deltas[value] = static_cast<int8_t>(running);
    }
  }
};

inline constexpr ExcessByteLut kExcessByteLut;
683
684static inline void excess_positions_512_byte_lut(const uint64_t* s,
685 int target_x,
686 uint64_t* out) noexcept {
687 out[0] = out[1] = out[2] = out[3] = 0;
688 out[4] = out[5] = out[6] = out[7] = 0;
689
690 if (target_x < -512 || target_x > 512) {
691 return;
692 }
693
694 const uint8_t* bytes = reinterpret_cast<const uint8_t*>(s);
695 uint8_t* out_bytes = reinterpret_cast<uint8_t*>(out);
696
697 int cur = 0;
698 for (int i = 0; i < 64; ++i) {
699 const uint8_t b = bytes[i];
700 const int target_rel = target_x - cur;
701 if (target_rel >= -8 && target_rel <= 8) {
702 out_bytes[i] = kExcessByteLut.masks[b][target_rel + 8];
703 }
704 cur += kExcessByteLut.deltas[b];
705 }
706}
707
708} // namespace pixie::experimental
Definition excess.h:650