Pixie
Loading...
Searching...
No Matches
bits.h
1#pragma once
2
3#include <immintrin.h>
4
5#include <array>
6#include <bit>
7#include <cstddef>
8#include <cstdint>
9#include <limits>
10#include <numeric>
11
// The AVX-512 code paths in this header need more than F/BW/VPOPCNTDQ:
//   - AVX512VL     for the 256-bit mask forms (e.g. _mm256_cmpge_epu64_mask)
//   - AVX512VBMI2  for _mm512_shldv_epi64
//   - AVX512BITALG for _mm256_popcnt_epi8
// Require all of them, so that a partial flag set selects the AVX2/scalar
// paths instead of failing to compile.
#if defined(__AVX512VPOPCNTDQ__) && defined(__AVX512F__) &&  \
    defined(__AVX512BW__) && defined(__AVX512VL__) &&        \
    defined(__AVX512VBMI2__) && defined(__AVX512BITALG__)
#define PIXIE_AVX512_SUPPORT
#endif
16
#ifdef __AVX2__
#define PIXIE_AVX2_SUPPORT
// clang-format off

// Nibble popcount table for byte shuffles: byte i of each 128-bit half holds
// the number of set bits in the 4-bit value i. The same 16 entries are stored
// in both halves of the register.
static inline const __m256i lookup_popcount_4 = _mm256_setr_epi8(
    // low 128 bits: popcount of 0x0 .. 0xF
    0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4,
    // high 128 bits: identical copy
    0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4);

// Mask with every bit of the low 128 bits set and the high 128 bits clear.
static inline const __m256i mask_first_half =
    _mm256_setr_epi64x(-1, -1, 0, 0);

// clang-format on
#endif
48
/**
 * @brief Builds a mask whose lowest @p num bits are set.
 * @param num Number of low bits to set; any value >= 64 yields all ones.
 * @return The mask as a 64-bit word.
 */
static inline uint64_t first_bits_mask(size_t num) {
  if (num >= 64) {
    // Shifting by the full word width is undefined, so saturate explicitly.
    return UINT64_MAX;
  }
  return (uint64_t{1} << num) - 1;
}
52
73static inline uint64_t rank_512(const uint64_t* x, uint64_t count) {
74#ifdef PIXIE_AVX512_SUPPORT
75
76 __m512i a = _mm512_maskz_set1_epi64((1ull << ((count >> 6))) - 1,
77 std::numeric_limits<uint64_t>::max());
78 __m512i b = _mm512_maskz_set1_epi64((1ull << ((count >> 6) + 1)) - 1,
79 std::numeric_limits<uint64_t>::max());
80 __m512i mask = _mm512_shldv_epi64(a, b, _mm512_set1_epi64(count % 64));
81
82 __m512i res = _mm512_loadu_epi64(x);
83 res = _mm512_and_epi64(res, mask);
84 __m512i cnt = _mm512_popcnt_epi64(res);
85 return _mm512_reduce_add_epi64(cnt);
86
87#else
88
89 uint64_t last_uint = count < 512 ? count >> 6 : 8;
90
91 uint64_t pop_val = 0;
92
93 for (int i = 0; i < last_uint; i++) {
94 pop_val += std::popcount(x[i]);
95 }
96
97 pop_val += count < 512
98 ? std::popcount(x[last_uint] & first_bits_mask(count & 63))
99 : 0;
100 return pop_val;
101
102#endif
103}
104
/**
 * @brief Position of the rank-th (0-based) set bit of @p x.
 * @param x    Word to search.
 * @param rank 0-based index among the set bits of x; must be < 64.
 * @return Bit index in [0, 63], or 64 when x has fewer than rank + 1 set bits.
 */
static inline uint64_t select_64(uint64_t x, uint64_t rank) {
#if (defined(__BMI__) && defined(__BMI2__)) || \
    (defined(_MSC_VER) && defined(__AVX2__))
  // PDEP moves source bit `rank` onto the rank-th set position of x; the
  // trailing-zero count then yields that position (64 if nothing was placed).
  return _tzcnt_u64(_pdep_u64(1ull << rank, x));
#else
  // Portable fallback for targets without BMI/BMI2: walk the set bits in
  // ascending order until `rank` of them have been skipped.
  for (uint64_t pos = 0; pos < 64; ++pos) {
    if ((x >> pos) & 1u) {
      if (rank == 0) {
        return pos;
      }
      --rank;
    }
  }
  return 64;
#endif
}
111
129static inline uint64_t select_512(const uint64_t* x, uint64_t rank) {
130#ifdef PIXIE_AVX512_SUPPORT
131
132 __m512i res = _mm512_loadu_epi64(x);
133 __m512i counts = _mm512_popcnt_epi64(res);
134 __m512i prefix = counts;
135
136 const __m512i idx_shift1 = _mm512_set_epi64(6, 5, 4, 3, 2, 1, 0, 0);
137 const __m512i idx_shift2 = _mm512_set_epi64(5, 4, 3, 2, 1, 0, 0, 0);
138 const __m512i idx_shift4 = _mm512_set_epi64(3, 2, 1, 0, 0, 0, 0, 0);
139
140 __m512i tmp = _mm512_maskz_permutexvar_epi64(0xFE, idx_shift1, prefix);
141 prefix = _mm512_add_epi64(prefix, tmp);
142 tmp = _mm512_maskz_permutexvar_epi64(0xFC, idx_shift2, prefix);
143 prefix = _mm512_add_epi64(prefix, tmp);
144 tmp = _mm512_maskz_permutexvar_epi64(0xF0, idx_shift4, prefix);
145 prefix = _mm512_add_epi64(prefix, tmp);
146
147 __mmask8 mask = _mm512_cmpgt_epu64_mask(prefix, _mm512_set1_epi64(rank));
148 uint32_t i = _tzcnt_u32(static_cast<uint32_t>(mask));
149 uint64_t prev = 0;
150 if (i != 0) {
151 __m512i idx_prev = _mm512_set1_epi64(static_cast<int64_t>(i - 1));
152 __m512i prev_vec = _mm512_permutexvar_epi64(idx_prev, prefix);
153 prev = static_cast<uint64_t>(
154 _mm_cvtsi128_si64(_mm512_castsi512_si128(prev_vec)));
155 }
156 return i * 64 + select_64(x[i], rank - prev);
157
158#else
159
160 size_t i = 0;
161 int popcount = std::popcount(x[0]);
162 while (i < 7 && popcount <= rank) {
163 rank -= popcount;
164 popcount = std::popcount(x[++i]);
165 }
166 return i * 64 + select_64(x[i], rank);
167
168#endif
169}
170
175static inline uint64_t select0_512(const uint64_t* x, uint64_t rank0) {
176#ifdef PIXIE_AVX512_SUPPORT
177
178 __m512i res = _mm512_loadu_epi64(x);
179 res = _mm512_xor_epi64(res, _mm512_set1_epi64(-1));
180 __m512i counts = _mm512_popcnt_epi64(res);
181 __m512i prefix = counts;
182
183 const __m512i idx_shift1 = _mm512_set_epi64(6, 5, 4, 3, 2, 1, 0, 0);
184 const __m512i idx_shift2 = _mm512_set_epi64(5, 4, 3, 2, 1, 0, 0, 0);
185 const __m512i idx_shift4 = _mm512_set_epi64(3, 2, 1, 0, 0, 0, 0, 0);
186
187 __m512i tmp = _mm512_maskz_permutexvar_epi64(0xFE, idx_shift1, prefix);
188 prefix = _mm512_add_epi64(prefix, tmp);
189 tmp = _mm512_maskz_permutexvar_epi64(0xFC, idx_shift2, prefix);
190 prefix = _mm512_add_epi64(prefix, tmp);
191 tmp = _mm512_maskz_permutexvar_epi64(0xF0, idx_shift4, prefix);
192 prefix = _mm512_add_epi64(prefix, tmp);
193
194 __mmask8 mask = _mm512_cmpgt_epu64_mask(prefix, _mm512_set1_epi64(rank0));
195 uint32_t i = _tzcnt_u32(static_cast<uint32_t>(mask));
196 uint64_t prev = 0;
197 if (i != 0) {
198 __m512i idx_prev = _mm512_set1_epi64(static_cast<int64_t>(i - 1));
199 __m512i prev_vec = _mm512_permutexvar_epi64(idx_prev, prefix);
200 prev = static_cast<uint64_t>(
201 _mm_cvtsi128_si64(_mm512_castsi512_si128(prev_vec)));
202 }
203 return i * 64 + select_64(~x[i], rank0 - prev);
204
205#else
206
207 size_t i = 0;
208 int popcount = std::popcount(~x[0]);
209 while (i < 7 && popcount <= rank0) {
210 rank0 -= popcount;
211 popcount = std::popcount(~x[++i]);
212 }
213 return i * 64 + select_64(~x[i], rank0);
214
215#endif
216}
217
/**
 * @brief Index of the first of 4 unsigned 64-bit values that is >= @p y.
 * @param x Pointer to 4 consecutive values.
 * @param y Threshold.
 * @return Smallest i in [0, 4) with x[i] >= y, or 4 if there is none. All
 *         code paths return the same value for the same input.
 */
static inline uint16_t lower_bound_4x64(const uint64_t* x, uint64_t y) {
#if defined(PIXIE_AVX512_SUPPORT)

  const __m256i y_4 = _mm256_set1_epi64x(static_cast<long long>(y));
  const __m256i reg_256 = _mm256_loadu_epi64(x);
  const __mmask8 ge = _mm256_cmpge_epu64_mask(reg_256, y_4);

  // Bit 4 is a sentinel: with no match the trailing-zero count is 4, not the
  // operand width.
  return static_cast<uint16_t>(_tzcnt_u32(static_cast<uint32_t>(ge) | 0x10u));

#elif defined(PIXIE_AVX2_SUPPORT)

  // AVX2 only has signed 64-bit compares; flipping the sign bit of both sides
  // turns the signed compare into an unsigned one.
  const __m256i sign =
      _mm256_set1_epi64x(static_cast<long long>(0x8000000000000000ULL));
  const __m256i x_biased = _mm256_xor_si256(
      _mm256_loadu_si256(reinterpret_cast<const __m256i*>(x)), sign);
  const __m256i y_biased = _mm256_xor_si256(
      _mm256_set1_epi64x(static_cast<long long>(y)), sign);

  // Compute "y > x" and invert it. Testing "x > y - 1" instead would wrap for
  // y == 0 and report that nothing matches.
  const uint32_t less = static_cast<uint32_t>(
      _mm256_movemask_epi8(_mm256_cmpgt_epi64(y_biased, x_biased)));

  // 8 mask bits per lane; an all-"less" mask inverts to 0 -> tzcnt 32 -> 4.
  return static_cast<uint16_t>(_tzcnt_u32(~less) >> 3);

#else

  for (uint16_t i = 0; i < 4; ++i) {
    if (x[i] >= y) {
      return i;
    }
  }
  return 4;

#endif
}
257
/**
 * @brief Index of the first of 4 derived values that is >= @p y, where the
 *        i-th derived value is delta_array[i] + delta_scalar - x[i]
 *        (unsigned 64-bit wrap-around arithmetic).
 * @param x            Pointer to 4 values that are subtracted.
 * @param y            Threshold.
 * @param delta_array  Pointer to 4 per-element offsets.
 * @param delta_scalar Offset added to every element.
 * @return Smallest i in [0, 4) whose derived value is >= y, or 4 if none.
 */
static inline uint16_t lower_bound_delta_4x64(const uint64_t* x,
                                              uint64_t y,
                                              const uint64_t* delta_array,
                                              uint64_t delta_scalar) {
#if defined(PIXIE_AVX512_SUPPORT)

  const __m256i deltas = _mm256_loadu_epi64(delta_array);
  const __m256i subtrahend = _mm256_loadu_epi64(x);
  const __m256i base = _mm256_set1_epi64x(static_cast<long long>(delta_scalar));
  const __m256i y_4 = _mm256_set1_epi64x(static_cast<long long>(y));

  const __m256i values =
      _mm256_sub_epi64(_mm256_add_epi64(base, deltas), subtrahend);
  const __mmask8 ge = _mm256_cmpge_epu64_mask(values, y_4);

  // Bit 4 is a sentinel so that "no match" yields 4.
  return static_cast<uint16_t>(_tzcnt_u32(static_cast<uint32_t>(ge) | 0x10u));

#elif defined(PIXIE_AVX2_SUPPORT)

  const __m256i deltas =
      _mm256_loadu_si256(reinterpret_cast<const __m256i*>(delta_array));
  const __m256i subtrahend =
      _mm256_loadu_si256(reinterpret_cast<const __m256i*>(x));
  const __m256i base = _mm256_set1_epi64x(static_cast<long long>(delta_scalar));
  const __m256i values =
      _mm256_sub_epi64(_mm256_add_epi64(base, deltas), subtrahend);

  // Flip the sign bit on both sides to get an unsigned compare out of the
  // signed AVX2 instruction.
  const __m256i sign =
      _mm256_set1_epi64x(static_cast<long long>(0x8000000000000000ULL));
  const __m256i v_biased = _mm256_xor_si256(values, sign);
  const __m256i y_biased = _mm256_xor_si256(
      _mm256_set1_epi64x(static_cast<long long>(y)), sign);

  // Compute "y > value" and invert it. Testing "value > y - 1" instead would
  // wrap for y == 0 and report that nothing matches.
  const uint32_t less = static_cast<uint32_t>(
      _mm256_movemask_epi8(_mm256_cmpgt_epi64(y_biased, v_biased)));

  // 8 mask bits per lane; an all-"less" mask inverts to 0 -> tzcnt 32 -> 4.
  return static_cast<uint16_t>(_tzcnt_u32(~less) >> 3);

#else

  for (uint16_t i = 0; i < 4; ++i) {
    if (delta_array[i] + delta_scalar - x[i] >= y) {
      return i;
    }
  }
  return 4;

#endif
}
320
/**
 * @brief Index of the first of 8 unsigned 64-bit values that is >= @p y.
 * @param x Pointer to 8 consecutive values.
 * @param y Threshold.
 * @return Smallest i in [0, 8) with x[i] >= y, or 8 if there is none.
 */
static inline uint16_t lower_bound_8x64(const uint64_t* x, uint64_t y) {
#if defined(PIXIE_AVX512_SUPPORT)

  const __m512i y_8 = _mm512_set1_epi64(static_cast<long long>(y));
  const __m512i reg_512 = _mm512_loadu_epi64(x);
  const __mmask8 ge = _mm512_cmpge_epu64_mask(reg_512, y_8);

  // Bit 8 is a sentinel: with no match the trailing-zero count is 8, matching
  // the other code paths, rather than the operand width.
  return static_cast<uint16_t>(_tzcnt_u32(static_cast<uint32_t>(ge) | 0x100u));

#elif defined(PIXIE_AVX2_SUPPORT)

  // Search the low half first; only fall through to the high half when the
  // low half reports "none" (4).
  const uint16_t low = lower_bound_4x64(x, y);
  if (low < 4) {
    return low;
  }
  return static_cast<uint16_t>(low + lower_bound_4x64(x + 4, y));

#else

  for (uint16_t i = 0; i < 8; ++i) {
    if (x[i] >= y) {
      return i;
    }
  }
  return 8;

#endif
}
357
/**
 * @brief Index of the first of 8 derived values that is >= @p y, where the
 *        i-th derived value is delta_array[i] + delta_scalar - x[i]
 *        (unsigned 64-bit wrap-around arithmetic).
 * @param x            Pointer to 8 values that are subtracted.
 * @param y            Threshold.
 * @param delta_array  Pointer to 8 per-element offsets.
 * @param delta_scalar Offset added to every element.
 * @return Smallest i in [0, 8) whose derived value is >= y, or 8 if none.
 */
static inline uint16_t lower_bound_delta_8x64(const uint64_t* x,
                                              uint64_t y,
                                              const uint64_t* delta_array,
                                              uint64_t delta_scalar) {
#if defined(PIXIE_AVX512_SUPPORT)

  const __m512i deltas = _mm512_loadu_epi64(delta_array);
  const __m512i subtrahend = _mm512_loadu_epi64(x);
  const __m512i base = _mm512_set1_epi64(static_cast<long long>(delta_scalar));
  const __m512i y_8 = _mm512_set1_epi64(static_cast<long long>(y));

  const __m512i values =
      _mm512_sub_epi64(_mm512_add_epi64(base, deltas), subtrahend);
  const __mmask8 ge = _mm512_cmpge_epu64_mask(values, y_8);

  // Bit 8 is a sentinel so that "no match" yields 8, matching the other code
  // paths.
  return static_cast<uint16_t>(_tzcnt_u32(static_cast<uint32_t>(ge) | 0x100u));

#elif defined(PIXIE_AVX2_SUPPORT)

  // Search the low half first; only fall through to the high half when the
  // low half reports "none" (4).
  const uint16_t low = lower_bound_delta_4x64(x, y, delta_array, delta_scalar);
  if (low < 4) {
    return low;
  }
  return static_cast<uint16_t>(
      low + lower_bound_delta_4x64(x + 4, y, delta_array + 4, delta_scalar));

#else

  for (uint16_t i = 0; i < 8; ++i) {
    if (delta_array[i] + delta_scalar - x[i] >= y) {
      return i;
    }
  }
  return 8;

#endif
}
411
/**
 * @brief Counts how many of 32 unsigned 16-bit values are strictly less than
 *        @p y.
 *
 * For input sorted in ascending order this count is the index of the first
 * element >= y.
 *
 * @param x Pointer to 32 consecutive values.
 * @param y Threshold.
 * @return Number of elements x[i] with x[i] < y, in [0, 32].
 */
// `static inline` like the other helpers in this header: a plain definition in
// a header is emitted by every including translation unit and breaks linking.
static inline uint16_t lower_bound_32x16(const uint16_t* x, uint16_t y) {
#if defined(PIXIE_AVX512_SUPPORT)

  const __m512i y_32 = _mm512_set1_epi16(static_cast<short>(y));
  const __m512i reg_512 = _mm512_loadu_epi16(x);
  const __mmask32 less = _mm512_cmplt_epu16_mask(reg_512, y_32);
  return static_cast<uint16_t>(std::popcount(less));

#elif defined(PIXIE_AVX2_SUPPORT)

  // AVX2 only has signed 16-bit compares; flipping the sign bit of both sides
  // makes the signed compare behave like an unsigned one.
  const __m256i sign = _mm256_set1_epi16(static_cast<short>(0x8000));
  const __m256i y_biased =
      _mm256_xor_si256(_mm256_set1_epi16(static_cast<short>(y)), sign);

  uint16_t count = 0;
  for (int half = 0; half < 2; ++half) {
    const __m256i x_biased = _mm256_xor_si256(
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(x + 16 * half)),
        sign);
    const uint32_t less = static_cast<uint32_t>(
        _mm256_movemask_epi8(_mm256_cmpgt_epi16(y_biased, x_biased)));
    // movemask yields two bits per 16-bit lane.
    count = static_cast<uint16_t>(count + (std::popcount(less) >> 1));
  }
  return count;

#else

  uint16_t cnt = 0;
  for (uint16_t i = 0; i < 32; ++i) {
    if (x[i] < y) {
      cnt++;
    }
  }
  return cnt;

#endif
}
457
/**
 * @brief Counts how many of 32 derived 16-bit values are strictly less than
 *        @p y, where the i-th derived value is
 *        delta_array[i] + delta_scalar - x[i] reduced modulo 2^16.
 * @param x            Pointer to 32 values that are subtracted.
 * @param y            Threshold.
 * @param delta_array  Pointer to 32 per-element offsets.
 * @param delta_scalar Offset added to every element.
 * @return Number of derived values below y, in [0, 32]. All code paths use
 *         the same wrap-around arithmetic.
 */
// `static inline` like the other helpers in this header: a plain definition in
// a header is emitted by every including translation unit and breaks linking.
static inline uint16_t lower_bound_delta_32x16(const uint16_t* x,
                                               uint16_t y,
                                               const uint16_t* delta_array,
                                               uint16_t delta_scalar) {
#if defined(PIXIE_AVX512_SUPPORT)

  const __m512i deltas = _mm512_loadu_epi16(delta_array);
  const __m512i subtrahend = _mm512_loadu_epi16(x);
  const __m512i base = _mm512_set1_epi16(static_cast<short>(delta_scalar));
  const __m512i y_32 = _mm512_set1_epi16(static_cast<short>(y));

  const __m512i values =
      _mm512_sub_epi16(_mm512_add_epi16(base, deltas), subtrahend);
  const __mmask32 less = _mm512_cmplt_epu16_mask(values, y_32);
  return static_cast<uint16_t>(std::popcount(less));

#elif defined(PIXIE_AVX2_SUPPORT)

  const __m256i base = _mm256_set1_epi16(static_cast<short>(delta_scalar));
  // Flip the sign bit on both sides to get an unsigned compare out of the
  // signed AVX2 instruction.
  const __m256i sign = _mm256_set1_epi16(static_cast<short>(0x8000));
  const __m256i y_biased =
      _mm256_xor_si256(_mm256_set1_epi16(static_cast<short>(y)), sign);

  uint16_t count = 0;
  for (int half = 0; half < 2; ++half) {
    const __m256i deltas = _mm256_loadu_si256(
        reinterpret_cast<const __m256i*>(delta_array + 16 * half));
    const __m256i subtrahend =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(x + 16 * half));
    const __m256i values =
        _mm256_sub_epi16(_mm256_add_epi16(base, deltas), subtrahend);

    const __m256i v_biased = _mm256_xor_si256(values, sign);
    const uint32_t less = static_cast<uint32_t>(
        _mm256_movemask_epi8(_mm256_cmpgt_epi16(y_biased, v_biased)));
    // movemask yields two bits per 16-bit lane.
    count = static_cast<uint16_t>(count + (std::popcount(less) >> 1));
  }
  return count;

#else

  uint16_t cnt = 0;
  for (uint16_t i = 0; i < 32; ++i) {
    // The operands are promoted to int; truncate back to 16 bits so the
    // result wraps exactly like the SIMD paths.
    const uint16_t value =
        static_cast<uint16_t>(delta_array[i] + delta_scalar - x[i]);
    if (value < y) {
      cnt++;
    }
  }
  return cnt;

#endif
}
531
/**
 * @brief Computes the popcount of 64 packed 4-bit values.
 * @param x      Pointer to 32 input bytes; each byte holds two nibbles.
 * @param result Pointer to 32 output bytes. Output byte i holds the popcount
 *               of the low nibble of x[i] in its low 4 bits and the popcount
 *               of the high nibble in its high 4 bits.
 */
// `static inline` like the other helpers in this header: a plain definition in
// a header is emitted by every including translation unit and breaks linking.
static inline void popcount_64x4(const uint8_t* x, uint8_t* result) {
  // The vector path needs nothing beyond AVX2, so gate it on AVX2 rather than
  // on AVX-512.
#ifdef PIXIE_AVX2_SUPPORT
  const __m256i data =
      _mm256_loadu_si256(reinterpret_cast<const __m256i*>(x));

  // Mask that isolates one nibble per byte.
  const __m256i low_bits_mask = _mm256_set1_epi8(0x0F);

  // Low nibbles: table lookup via byte shuffle.
  const __m256i low_bits = _mm256_and_si256(data, low_bits_mask);
  const __m256i low_count = _mm256_shuffle_epi8(lookup_popcount_4, low_bits);

  // High nibbles: shift down, mask off bits that crossed a byte boundary,
  // then use the same lookup.
  const __m256i high_bits =
      _mm256_and_si256(_mm256_srli_epi16(data, 4), low_bits_mask);
  const __m256i high_count = _mm256_shuffle_epi8(lookup_popcount_4, high_bits);

  // Each count is at most 4, so shifting by 4 within 16-bit lanes cannot spill
  // into the neighbouring byte.
  const __m256i packed =
      _mm256_or_si256(low_count, _mm256_slli_epi16(high_count, 4));
  _mm256_storeu_si256(reinterpret_cast<__m256i*>(result), packed);
#else
  // Scalar fallback using the same nibble table as the vector path.
  static constexpr uint8_t kNibblePopcount[16] = {0, 1, 1, 2, 1, 2, 2, 3,
                                                  1, 2, 2, 3, 2, 3, 3, 4};
  for (size_t i = 0; i < 32; i++) {
    const uint8_t low_count = kNibblePopcount[x[i] & 0x0F];
    const uint8_t high_count = kNibblePopcount[x[i] >> 4];
    result[i] = static_cast<uint8_t>(low_count | (high_count << 4));
  }
#endif
}
578
590void popcount_32x8(const uint8_t* x, uint8_t* result) {
591#ifdef PIXIE_AVX512_SUPPORT
592 // Load 64 4-bit integers (256 bits total)
593 __m256i data = _mm256_loadu_si256((__m256i const*)x);
594 auto popcount_8 = _mm256_popcnt_epi8(data);
595 _mm256_storeu_si256((__m256i*)result, popcount_8);
596#else
597#ifdef PIXIE_AVX2_SUPPORT
598 // Load 64 4-bit integers (256 bits total)
599 __m256i data = _mm256_loadu_si256((__m256i const*)x);
600
601 // Masks for extracting the lower and upper nibbles
602 const __m256i low_bits_mask = _mm256_set1_epi8(0x0F);
603
604 // Count bits in lower half
605 __m256i low_bits = _mm256_and_si256(data, low_bits_mask);
606 __m256i low_count = _mm256_shuffle_epi8(lookup_popcount_4, low_bits);
607
608 // Count bits upper half
609 __m256i high_bits = _mm256_srli_epi16(data, 4);
610 high_bits = _mm256_and_si256(high_bits, low_bits_mask);
611 __m256i high_count = _mm256_shuffle_epi8(lookup_popcount_4, high_bits);
612
613 __m256i result_vec = _mm256_add_epi8(low_count, high_count);
614 _mm256_storeu_si256((__m256i*)result, result_vec);
615#else
616 // Fallback implementation for non-AVX2 platforms
617 for (size_t i = 0; i < 32; i++) {
618 result[i] = std::popcount(x[i]);
619 }
620#endif
621#endif
622}
623
635void rank_32x8(const uint8_t* x, uint8_t* result) {
636#ifdef PIXIE_AVX512_SUPPORT
637 // Step 1: Calculate popcount of each byte
638 popcount_32x8(x, result);
639 __m256i prefix_sums = _mm256_loadu_si256((__m256i const*)result);
640 const __m256i zero = _mm256_setzero_si256();
641
642 prefix_sums = _mm256_add_epi8(prefix_sums,
643 _mm256_alignr_epi8(prefix_sums, zero, 16 - 1));
644 prefix_sums = _mm256_add_epi8(prefix_sums,
645 _mm256_alignr_epi8(prefix_sums, zero, 16 - 2));
646 prefix_sums = _mm256_add_epi8(prefix_sums,
647 _mm256_alignr_epi8(prefix_sums, zero, 16 - 4));
648 prefix_sums = _mm256_add_epi8(prefix_sums,
649 _mm256_alignr_epi8(prefix_sums, zero, 16 - 8));
650
651 // At this point we have prefix sums for two halfs, the last step is to
652 // extract 16-th value and add it to the whole second half
653 __m128i low_lane = _mm256_extracti128_si256(prefix_sums, 0);
654 __m128i high_lane = _mm256_extracti128_si256(prefix_sums, 1);
655 auto last_val_low = _mm_extract_epi8(low_lane, 15);
656 __m128i add_to_high = _mm_set1_epi8(last_val_low);
657 high_lane = _mm_add_epi8(high_lane, add_to_high);
658 prefix_sums = _mm256_set_m128i(high_lane, low_lane);
659 _mm256_storeu_epi8(result, prefix_sums);
660#else
661 // Scalar fallback implementation
662 result[0] = std::popcount(x[0]);
663 for (size_t i = 1; i < 32; ++i) {
664 result[i] = std::popcount(x[i]) + result[i - 1];
665 }
666#endif
667}