Pixie
Loading...
Searching...
No Matches
bitvector.h
1#pragma once
2
3#include <pixie/bits.h>
4#include <pixie/cache_line.h>
5
6#include <algorithm>
7#include <bit>
8#include <cstdint>
9#include <span>
10#include <string>
11#include <vector>
12
13#ifdef PIXIE_DIAGNOSTICS
14#include <spdlog/spdlog.h>
15#endif
16
17namespace pixie {
18
63class BitVector {
64 private:
65 constexpr static size_t kWordSize = 64;
66 constexpr static size_t kSuperBlockRankIntSize = 64;
67 constexpr static size_t kBasicBlockRankIntSize = 16;
68 constexpr static size_t kBasicBlockSize = 512;
69 constexpr static size_t kWordsPerBlock = 8;
70 constexpr static size_t kSuperBlockSize = 65536;
71 constexpr static size_t kBlocksPerSuperBlock = 128;
72 constexpr static size_t kSelectSampleFrequency = 16384;
73
74 alignas(64) uint64_t delta_super[8];
75 alignas(64) uint16_t delta_basic[32];
76
77 AlignedStorage super_block_rank_; // 64-bit global prefix sums
78 AlignedStorage basic_block_rank_; // 16-bit local prefix sums
79 AlignedStorage select1_samples_; // 64-bit global positions
80 AlignedStorage select0_samples_; // 64-bit global positions
81 const size_t num_bits_;
82 const size_t padded_size_;
83 size_t max_rank_;
84
85 std::span<const uint64_t> bits_;
86
90 void build_rank() {
91 size_t num_superblocks = 8 + (padded_size_ - 1) / kSuperBlockSize;
92 // Add more blocks to ease SIMD processing
93 // num_basicblocks to fully cover superblock, i.e. 128
94 // This reduces branching in select
95 num_superblocks = ((num_superblocks + 7) / 8) * 8;
96 size_t num_basicblocks = num_superblocks * kBlocksPerSuperBlock;
97 super_block_rank_.resize(num_superblocks * 64);
98 basic_block_rank_.resize(num_basicblocks * 16);
99
100 auto super_block_rank = super_block_rank_.As64BitInts();
101 auto basic_block_rank = basic_block_rank_.As16BitInts();
102
103 uint64_t super_block_sum = 0;
104 uint16_t basic_block_sum = 0;
105
106 for (size_t i = 0; i / kBasicBlockSize < basic_block_rank.size();
107 i += kWordSize) {
108 if (i % kSuperBlockSize == 0) {
109 super_block_sum += basic_block_sum;
110 super_block_rank[i / kSuperBlockSize] = super_block_sum;
111 basic_block_sum = 0;
112 }
113 if (i % kBasicBlockSize == 0) {
114 basic_block_rank[i / kBasicBlockSize] = basic_block_sum;
115 }
116 if (i / kWordSize < bits_.size()) {
117 basic_block_sum += std::popcount(bits_[i / kWordSize]);
118 }
119 }
120 max_rank_ = super_block_sum + basic_block_sum;
121 }
122
126 void build_select() {
127 uint64_t milestone = kSelectSampleFrequency;
128 uint64_t milestone0 = kSelectSampleFrequency;
129 uint64_t rank = 0;
130 uint64_t rank0 = 0;
131
132 size_t num_one_samples =
133 1 + (max_rank_ + kSelectSampleFrequency - 1) / kSelectSampleFrequency;
134 size_t num_zero_samples =
135 1 + (num_bits_ - max_rank_ + kSelectSampleFrequency - 1) /
136 kSelectSampleFrequency;
137
138 select1_samples_.resize(num_one_samples * 64);
139 select0_samples_.resize(num_zero_samples * 64);
140 auto select1_samples = select1_samples_.As64BitInts();
141 auto select0_samples = select0_samples_.As64BitInts();
142
143 select1_samples[0] = 0;
144 select0_samples[0] = 0;
145
146 size_t num_zeros = 1, num_ones = 1;
147
148 for (size_t i = 0; i < bits_.size(); ++i) {
149 auto ones = std::popcount(bits_[i]);
150 auto zeros = 64 - ones;
151 if (rank + ones >= milestone) {
152 auto pos = select_64(bits_[i], milestone - rank - 1);
153 // TODO: try including global rank into select samples to save
154 // a cache miss on global rank scan
155 select1_samples[num_ones++] = (64 * i + pos) / kSuperBlockSize;
156 milestone += kSelectSampleFrequency;
157 }
158 if (rank0 + zeros >= milestone0) {
159 auto pos = select_64(~bits_[i], milestone0 - rank0 - 1);
160 select0_samples[num_zeros++] = (64 * i + pos) / kSuperBlockSize;
161 milestone0 += kSelectSampleFrequency;
162 }
163 rank += ones;
164 rank0 += zeros;
165 }
166
167 for (size_t i = 0; i < 8; ++i) {
168 delta_super[i] = i * kSuperBlockSize;
169 }
170 for (size_t i = 0; i < 32; ++i) {
171 delta_basic[i] = i * kBasicBlockSize;
172 }
173 }
174
180 uint64_t find_superblock(uint64_t rank) const {
181 auto select1_samples = select1_samples_.AsConst64BitInts();
182 auto super_block_rank = super_block_rank_.AsConst64BitInts();
183
184 uint64_t left = select1_samples[rank / kSelectSampleFrequency];
185
186 while (left + 7 < super_block_rank.size()) {
187 auto len = lower_bound_8x64(&super_block_rank[left], rank);
188 if (len < 8) {
189 return left + len - 1;
190 }
191 left += 8;
192 }
193 if (left + 3 < super_block_rank.size()) {
194 auto len = lower_bound_4x64(&super_block_rank[left], rank);
195 if (len < 4) {
196 return left + len - 1;
197 }
198 left += 4;
199 }
200 while (left < super_block_rank.size() && super_block_rank[left] < rank) {
201 left++;
202 }
203 return left - 1;
204 }
205
211 uint64_t find_superblock_zeros(uint64_t rank0) const {
212 auto select0_samples = select0_samples_.AsConst64BitInts();
213 auto super_block_rank = super_block_rank_.AsConst64BitInts();
214
215 uint64_t left = select0_samples[rank0 / kSelectSampleFrequency];
216
217 while (left + 7 < super_block_rank.size()) {
218 auto len = lower_bound_delta_8x64(&super_block_rank[left], rank0,
219 delta_super, kSuperBlockSize * left);
220 if (len < 8) {
221 return left + len - 1;
222 }
223 left += 8;
224 }
225 if (left + 3 < super_block_rank.size()) {
226 auto len = lower_bound_delta_4x64(&super_block_rank[left], rank0,
227 delta_super, kSuperBlockSize * left);
228 if (len < 4) {
229 return left + len - 1;
230 }
231 left += 4;
232 }
233 while (left < super_block_rank.size() &&
234 kSuperBlockSize * left - super_block_rank[left] < rank0) {
235 left++;
236 }
237 return left - 1;
238 }
239
251 uint64_t find_basicblock(uint16_t local_rank, uint64_t s_block) const {
252 auto basic_block_rank = basic_block_rank_.AsConst16BitInts();
253
254 for (size_t pos = 0; pos < kBlocksPerSuperBlock; pos += 32) {
255 auto count = lower_bound_32x16(
256 &basic_block_rank[kBlocksPerSuperBlock * s_block + pos], local_rank);
257 if (count < 32) {
258 return kBlocksPerSuperBlock * s_block + pos + count - 1;
259 }
260 }
261 return kBlocksPerSuperBlock * s_block + kBlocksPerSuperBlock - 1;
262 }
263
275 uint64_t find_basicblock_zeros(uint16_t local_rank0, uint64_t s_block) const {
276 auto basic_block_rank = basic_block_rank_.AsConst16BitInts();
277 for (size_t pos = 0; pos < kBlocksPerSuperBlock; pos += 32) {
278 auto count = lower_bound_delta_32x16(
279 &basic_block_rank[kBlocksPerSuperBlock * s_block + pos], local_rank0,
280 delta_basic, kBasicBlockSize * pos);
281 if (count < 32) {
282 return kBlocksPerSuperBlock * s_block + pos + count - 1;
283 }
284 }
285 return kBlocksPerSuperBlock * s_block + kBlocksPerSuperBlock - 1;
286 }
287
305 uint64_t find_basicblock_is(uint16_t local_rank, uint64_t s_block) const {
306 auto super_block_rank = super_block_rank_.AsConst64BitInts();
307 auto basic_block_rank = basic_block_rank_.AsConst16BitInts();
308
309 auto lower = super_block_rank[s_block];
310 auto upper = super_block_rank[s_block + 1];
311
312 uint64_t pos = kBlocksPerSuperBlock * local_rank / (upper - lower);
313 pos = pos + 16 < 32 ? 0 : (pos - 16);
314 pos = pos > 96 ? 96 : pos;
315 while (pos < 96) {
316 auto count = lower_bound_32x16(
317 &basic_block_rank[kBlocksPerSuperBlock * s_block + pos], local_rank);
318 if (count == 0) {
319 return find_basicblock(local_rank, s_block);
320 }
321 if (count < 32) {
322 return kBlocksPerSuperBlock * s_block + pos + count - 1;
323 }
324 pos += 32;
325 }
326 pos = 96;
327 auto count = lower_bound_32x16(
328 &basic_block_rank[kBlocksPerSuperBlock * s_block + pos], local_rank);
329 if (count == 0) {
330 return find_basicblock(local_rank, s_block);
331 }
332 return kBlocksPerSuperBlock * s_block + pos + count - 1;
333 }
334
353 uint64_t find_basicblock_is_zeros(uint16_t local_rank0,
354 uint64_t s_block) const {
355 auto super_block_rank = super_block_rank_.AsConst64BitInts();
356 auto basic_block_rank = basic_block_rank_.AsConst16BitInts();
357
358 auto lower = kSuperBlockSize * s_block - super_block_rank[s_block];
359 auto upper =
360 kSuperBlockSize * (s_block + 1) - super_block_rank[s_block + 1];
361
362 uint64_t pos = kBlocksPerSuperBlock * local_rank0 / (upper - lower);
363 pos = pos + 16 < 32 ? 0 : (pos - 16);
364 pos = pos > 96 ? 96 : pos;
365 while (pos < 96) {
366 auto count = lower_bound_delta_32x16(
367 &basic_block_rank[kBlocksPerSuperBlock * s_block + pos], local_rank0,
368 delta_basic, kBasicBlockSize * pos);
369 if (count == 0) {
370 return find_basicblock_zeros(local_rank0, s_block);
371 }
372 if (count < 32) {
373 return kBlocksPerSuperBlock * s_block + pos + count - 1;
374 }
375 pos += 32;
376 }
377 pos = 96;
378 auto count = lower_bound_delta_32x16(
379 &basic_block_rank[kBlocksPerSuperBlock * s_block + pos], local_rank0,
380 delta_basic, kBasicBlockSize * pos);
381 if (count == 0) {
382 return find_basicblock_zeros(local_rank0, s_block);
383 }
384 return kBlocksPerSuperBlock * s_block + pos + count - 1;
385 }
386
387 public:
388#ifdef PIXIE_DIAGNOSTICS
389 struct DiagnosticsBytes {
390 size_t source_bitvector_bytes = 0;
391 size_t super_block_rank_bytes = 0;
392 size_t basic_block_rank_bytes = 0;
393 size_t select1_samples_bytes = 0;
394 size_t select0_samples_bytes = 0;
395 size_t total_bytes = 0;
396 };
397
401 DiagnosticsBytes diagnostics_bytes() const {
402 DiagnosticsBytes result;
403 result.source_bitvector_bytes = (num_bits_ + 7) / 8;
404 result.super_block_rank_bytes = super_block_rank_.AsConstBytes().size();
405 result.basic_block_rank_bytes = basic_block_rank_.AsConstBytes().size();
406 result.select1_samples_bytes = select1_samples_.AsConstBytes().size();
407 result.select0_samples_bytes = select0_samples_.AsConstBytes().size();
408 result.total_bytes =
409 result.super_block_rank_bytes + result.basic_block_rank_bytes +
410 result.select1_samples_bytes + result.select0_samples_bytes;
411 return result;
412 }
413
417 void memory_report() const {
418 const auto diagnostics = diagnostics_bytes();
419 const double source_bytes =
420 static_cast<double>(diagnostics.source_bitvector_bytes);
421 const auto log_bytes = [&](std::string_view label, size_t bytes) {
422 const double percentage =
423 source_bytes > 0.0 ? 100.0 * static_cast<double>(bytes) / source_bytes
424 : 0.0;
425 spdlog::info("BitVector {}: {} bytes ({:.2f}% of source)", label, bytes,
426 percentage);
427 };
428 log_bytes("source_bitvector", diagnostics.source_bitvector_bytes);
429 log_bytes("super_block_rank", diagnostics.super_block_rank_bytes);
430 log_bytes("basic_block_rank", diagnostics.basic_block_rank_bytes);
431 log_bytes("select1_samples", diagnostics.select1_samples_bytes);
432 log_bytes("select0_samples", diagnostics.select0_samples_bytes);
433 log_bytes("total", diagnostics.total_bytes);
434 }
435#endif
443 explicit BitVector(std::span<const uint64_t> bit_vector, size_t num_bits)
444 : num_bits_(std::min(num_bits, bit_vector.size() * kWordSize)),
445 padded_size_(((num_bits_ + kWordSize - 1) / kWordSize) * kWordSize),
446 bits_(bit_vector) {
447 build_rank();
448 build_select();
449 }
450
454 size_t size() const { return num_bits_; }
455
461 int operator[](size_t pos) const {
462 size_t word_idx = pos / kWordSize;
463 size_t bit_off = pos % kWordSize;
464
465 return (bits_[word_idx] >> bit_off) & 1;
466 }
467
474 uint64_t rank(size_t pos) const {
475 if (pos >= bits_.size() * kWordSize) [[unlikely]] {
476 return max_rank_;
477 }
478
479 auto super_block_rank = super_block_rank_.AsConst64BitInts();
480 auto basic_block_rank = basic_block_rank_.AsConst16BitInts();
481
482 uint64_t b_block = pos / kBasicBlockSize;
483 uint64_t s_block = pos / kSuperBlockSize;
484 // Precomputed rank
485 uint64_t result = super_block_rank[s_block] + basic_block_rank[b_block];
486 // Basic block tail
487 result += rank_512(&bits_[b_block * kWordsPerBlock],
488 pos - (b_block * kBasicBlockSize));
489 return result;
490 }
491
498 uint64_t rank0(size_t pos) const {
499 if (pos >= bits_.size() * kWordSize) [[unlikely]] {
500 return bits_.size() * kWordSize - max_rank_;
501 }
502 return pos - rank(pos);
503 }
504
512 uint64_t select(size_t rank) const {
513 if (rank > max_rank_) [[unlikely]] {
514 return num_bits_;
515 }
516 if (rank == 0) [[unlikely]] {
517 return 0;
518 }
519 auto super_block_rank = super_block_rank_.AsConst64BitInts();
520 auto basic_block_rank = basic_block_rank_.AsConst16BitInts();
521
522 uint64_t s_block = find_superblock(rank);
523 rank -= super_block_rank[s_block];
524 auto pos = find_basicblock_is(rank, s_block);
525 rank -= basic_block_rank[pos];
526 pos *= kWordsPerBlock;
527
528 // Final search
529 if (pos + kWordsPerBlock - 1 < kWordsPerBlock) [[unlikely]] {
530 size_t ones = std::popcount(bits_[pos]);
531 while (pos < bits_.size() && ones < rank) {
532 rank -= ones;
533 ones = std::popcount(bits_[++pos]);
534 }
535 return kWordSize * pos + select_64(bits_[pos], rank - 1);
536 }
537 return kWordSize * pos + select_512(&bits_[pos], rank - 1);
538 }
539
547 uint64_t select0(size_t rank0) const {
548 if (rank0 > num_bits_ - max_rank_) [[unlikely]] {
549 return num_bits_;
550 }
551 if (rank0 == 0) [[unlikely]] {
552 return 0;
553 }
554 auto super_block_rank = super_block_rank_.AsConst64BitInts();
555 auto basic_block_rank = basic_block_rank_.AsConst16BitInts();
556
557 uint64_t s_block = find_superblock_zeros(rank0);
558 rank0 -= kSuperBlockSize * s_block - super_block_rank[s_block];
559 auto pos = find_basicblock_is_zeros(rank0, s_block);
560 auto pos_in_super_block = pos & (kBlocksPerSuperBlock - 1);
561 rank0 -= kBasicBlockSize * pos_in_super_block - basic_block_rank[pos];
562 pos *= kWordsPerBlock;
563
564 // Final search
565 if (pos + kWordsPerBlock - 1 < kWordsPerBlock) [[unlikely]] {
566 size_t zeros = std::popcount(~bits_[pos]);
567 while (pos < bits_.size() && zeros < rank0) {
568 rank0 -= zeros;
569 zeros = std::popcount(~bits_[++pos]);
570 }
571 return kWordSize * pos + select_64(~bits_[pos], rank0 - 1);
572 }
573 return kWordSize * pos + select0_512(&bits_[pos], rank0 - 1);
574 }
575
579 std::string to_string() const {
580 std::string result;
581 result.reserve(num_bits_);
582
583 for (size_t i = 0; i < num_bits_; i++) {
584 result.push_back(operator[](i) ? '1' : '0');
585 }
586
587 return result;
588 }
589};
590
609 private:
610 constexpr static size_t kWordSize = 64;
611 constexpr static size_t kSuperBlockRankIntSize = 64;
612 constexpr static size_t kBasicBlockRankIntSize = 16;
616 constexpr static size_t kBasicBlockSize = 496;
623 constexpr static size_t kSuperBlockSize = 63488;
624 constexpr static size_t kBlocksPerSuperBlock = 128;
625 constexpr static size_t kWordsPerBlock = 8;
626
627 const size_t num_bits_;
628 std::vector<uint64_t> bits_interleaved;
629 std::vector<uint64_t> super_block_rank_;
630
631 class BitReader {
632 size_t iterator_64_ = 0;
633 size_t offset_size_ = 0;
634 size_t offset_bits_ = 0;
635 std::span<const uint64_t> bits_;
636
637 public:
638 BitReader(std::span<const uint64_t> bits) : bits_(bits) {}
639 uint64_t ReadBits64(size_t num_bits) {
640 if (num_bits > 64) {
641 num_bits = 64;
642 }
643 uint64_t result = offset_bits_ & first_bits_mask(num_bits);
644 if (offset_size_ >= num_bits) {
645 offset_bits_ >>= num_bits;
646 offset_size_ -= num_bits;
647 return result;
648 }
649 uint64_t next = iterator_64_ < bits_.size() ? bits_[iterator_64_++] : 0;
650 result ^= (next & first_bits_mask(num_bits - offset_size_))
651 << offset_size_;
652 offset_bits_ = (num_bits - offset_size_ == 64)
653 ? 0
654 : next >> (num_bits - offset_size_);
655 offset_size_ = 64 - (num_bits - offset_size_);
656 return result;
657 }
658 };
659
660 public:
668 explicit BitVectorInterleaved(std::span<const uint64_t> bit_vector,
669 size_t num_bits)
670 : num_bits_(std::min(num_bits, bit_vector.size() * kWordSize)) {
671 build_rank_interleaved(bit_vector, num_bits);
672 }
673
677 static inline uint64_t first_bits_mask(size_t num) {
678 return num >= 64 ? UINT64_MAX : ((1llu << num) - 1);
679 }
680
684 size_t size() const { return num_bits_; }
685
691 int operator[](size_t pos) const {
692 size_t block_id = pos / kBasicBlockSize;
693 size_t block_bit = pos - block_id * kBasicBlockSize;
694 size_t word_id = block_id * kWordsPerBlock + block_bit / kWordSize;
695 size_t word_bit = block_bit % kWordSize;
696 kWordSize;
697
698 return (bits_interleaved[word_id] >> word_bit) & 1;
699 }
700
708 void build_rank_interleaved(std::span<const uint64_t> bits, size_t num_bits) {
709 size_t num_superblocks = 1 + (num_bits_ - 1) / kSuperBlockSize;
710 super_block_rank_.resize(num_superblocks);
711 size_t num_basicblocks = 1 + (num_bits_ - 1) / kBasicBlockSize;
712 bits_interleaved.resize(num_basicblocks * (512 / kWordSize));
713
714 uint64_t super_block_sum = 0;
715 uint16_t basic_block_sum = 0;
716 auto bit_reader = BitReader(bits);
717
718 for (size_t i = 0; i * kBasicBlockSize < num_bits; ++i) {
719 if (i % (kSuperBlockSize / kBasicBlockSize) == 0) {
720 super_block_sum += basic_block_sum;
721 super_block_rank_[i / (kSuperBlockSize / kBasicBlockSize)] =
722 super_block_sum;
723 basic_block_sum = 0;
724 }
725 bits_interleaved[i * (kWordsPerBlock) + 7] =
726 static_cast<uint64_t>(basic_block_sum) << 48;
727
728 for (size_t j = 0; j < 7 && kWordSize * (i + j) < num_bits; ++j) {
729 bits_interleaved[i * (kWordsPerBlock) + j] =
730 bit_reader.ReadBits64(std::min<uint64_t>(
731 64ull, num_bits - i * kBasicBlockSize + j * kWordSize));
732 basic_block_sum +=
733 std::popcount(bits_interleaved[i * (kWordsPerBlock) + j]);
734 }
735 if ((i + 7) * kWordSize < num_bits) {
736 auto v = bit_reader.ReadBits64(std::min<uint64_t>(
737 48ull, num_bits - (i * kBasicBlockSize + 7 * kWordSize)));
738 bits_interleaved[i * (kWordsPerBlock) + 7] ^= v;
739 basic_block_sum += std::popcount(v);
740 }
741 }
742 }
743
750 uint64_t rank(size_t pos) const {
751 // Multiplication/devisions
752 uint64_t b_block = pos / kBasicBlockSize;
753 uint64_t s_block = b_block / kBlocksPerSuperBlock;
754 uint64_t b_block_pos = b_block * kWordsPerBlock;
755 // Super block rank
756 uint64_t result = super_block_rank_[s_block];
763 // __builtin_prefetch(&bits_interleaved[b_block_pos]);
764 result += rank_512(&bits_interleaved[b_block_pos],
765 pos - (b_block * kBasicBlockSize));
766 result += bits_interleaved[b_block_pos + 7] >> 48;
767 return result;
768 }
769
773 std::string to_string() const {
774 std::string result;
775 result.reserve(num_bits_);
776
777 for (size_t i = 0; i < num_bits_; i++) {
778 result.push_back(operator[](i) ? '1' : '0');
779 }
780
781 return result;
782 }
783};
784
785} // namespace pixie
A simple aligned storage for cache-line sized blocks.
Definition cache_line.h:23
void resize(size_t bits)
Resize storage to hold at least bits bits, rounded up to 512.
Definition cache_line.h:38
int operator[](size_t pos) const
Returns the bit at the given position.
Definition bitvector.h:691
size_t size() const
Returns the number of valid bits.
Definition bitvector.h:684
std::string to_string() const
Convert to a binary string (debug helper).
Definition bitvector.h:773
BitVectorInterleaved(std::span< const uint64_t > bit_vector, size_t num_bits)
Construct from an external array of 64-bit words.
Definition bitvector.h:668
uint64_t rank(size_t pos) const
Rank of 1s up to position pos (exclusive).
Definition bitvector.h:750
static uint64_t first_bits_mask(size_t num)
Mask with the lowest num bits set.
Definition bitvector.h:677
void build_rank_interleaved(std::span< const uint64_t > bits, size_t num_bits)
Build the interleaved layout and rank index.
Definition bitvector.h:708
BitVector(std::span< const uint64_t > bit_vector, size_t num_bits)
Construct from an external array of 64-bit words.
Definition bitvector.h:443
size_t size() const
Returns the number of valid bits.
Definition bitvector.h:454
uint64_t rank(size_t pos) const
Rank of 1s up to position pos (exclusive).
Definition bitvector.h:474
int operator[](size_t pos) const
Returns the bit at the given position.
Definition bitvector.h:461
uint64_t select0(size_t rank0) const
Select the position of the rank0-th 0-bit (1-indexed).
Definition bitvector.h:547
uint64_t select(size_t rank) const
Select the position of the rank-th 1-bit (1-indexed).
Definition bitvector.h:512
std::string to_string() const
Convert to a binary string (debug helper).
Definition bitvector.h:579
uint64_t rank0(size_t pos) const
Rank of 0s up to position pos (exclusive).
Definition bitvector.h:498