Pixie
Loading...
Searching...
No Matches
rmm_btree.h
1#pragma once
2
3#include <pixie/bits.h>
4#include <pixie/bitvector.h>
5#include <pixie/rmm_base.h>
6
7#include <algorithm>
8#include <array>
9#include <bit>
10#include <cstddef>
11#include <cstdint>
12#include <limits>
13#include <optional>
14#include <span>
15#include <stdexcept>
16#include <type_traits>
17#include <utility>
18#include <vector>
19
20namespace pixie::experimental {
21
81template <std::size_t HighCacheLines = 4, std::size_t LowFanout = 32>
82class RmMBTree : public RmMBase<RmMBTree<HighCacheLines, LowFanout>> {
83 public:
84 static_assert(HighCacheLines > 0);
85 static_assert(LowFanout > 0);
86
87 static constexpr std::size_t npos =
89 static constexpr std::size_t kBlockBits = 512;
90 static constexpr std::size_t kBlockWords = kBlockBits / 64;
91 static constexpr std::size_t kCacheLineBytes = 64;
92 static constexpr std::size_t kLowFanout = LowFanout;
93 static constexpr std::size_t kHighFanout =
94 std::max<std::size_t>(2, (512 * HighCacheLines) / (4 * 64));
95 static constexpr std::size_t kMaxFanout = std::max(kLowFanout, kHighFanout);
96 static_assert(kMaxFanout <= 64);
97 static_assert(
98 kBlockBits * kLowFanout <=
99 static_cast<std::size_t>(std::numeric_limits<std::int16_t>::max()));
100
/// Defaulted special members: copying/moving is member-wise.
/// NOTE(review): build() stores the input span in bits_, which looks like a
/// non-owning view — if so, a copy shares the caller's buffer; confirm.
RmMBTree() = default;
RmMBTree(const RmMBTree&) = default;
RmMBTree(RmMBTree&&) noexcept = default;
RmMBTree& operator=(const RmMBTree&) = default;
RmMBTree& operator=(RmMBTree&&) noexcept = default;

/**
 * @brief Builds the range min-max tree over the first `bit_count` bits of
 *        `words`.
 *
 * @param words     Bit data, 64 bits per word; must contain at least
 *                  (bit_count + 63) / 64 words. It is kept in bits_
 *                  (presumably as a view), so the buffer should outlive
 *                  the tree.
 * @param bit_count Number of valid bits in `words`.
 * @param (unnamed) Accepted but ignored; blocks are always kBlockBits wide.
 * @throws std::invalid_argument (from build()) when `words` is too short.
 */
explicit RmMBTree(std::span<const std::uint64_t> words,
                  std::size_t bit_count,
                  std::size_t = kBlockBits) {
  build(words, bit_count);
}
122
123 std::size_t size_impl() const { return bit_count_; }
124
125 std::size_t rank1_impl(std::size_t end_position) const {
126 return rank_index_ ? rank_index_->rank(end_position) : 0;
127 }
128
129 std::size_t rank0_impl(std::size_t end_position) const {
130 return rank_index_ ? rank_index_->rank0(end_position) : 0;
131 }
132
133 std::size_t select1_impl(std::size_t rank) const {
134 if (!rank_index_ || rank == 0) {
135 return npos;
136 }
137 const std::size_t position = rank_index_->select(rank);
138 return position < bit_count_ ? position : npos;
139 }
140
141 std::size_t select0_impl(std::size_t rank) const {
142 if (!rank_index_ || rank == 0) {
143 return npos;
144 }
145 const std::size_t position = rank_index_->select0(rank);
146 return position < bit_count_ ? position : npos;
147 }
148
149 std::size_t rank10_impl(std::size_t end_position) const {
150 if (end_position <= 1 || bit_count_ == 0) {
151 return 0;
152 }
153 end_position = std::min(end_position, bit_count_);
154 std::size_t count = 0;
155 for (std::size_t position = 0; position + 1 < end_position; ++position) {
156 count += bit(position) == 1 && bit(position + 1) == 0;
157 }
158 return count;
159 }
160
161 std::size_t select10_impl(std::size_t rank) const {
162 if (rank == 0) {
163 return npos;
164 }
165 for (std::size_t position = 0; position + 1 < bit_count_; ++position) {
166 if (bit(position) == 1 && bit(position + 1) == 0 && --rank == 0) {
167 return position;
168 }
169 }
170 return npos;
171 }
172
173 int excess_impl(std::size_t end_position) const {
174 end_position = std::min(end_position, bit_count_);
175 return static_cast<int>(
176 2 * static_cast<std::int64_t>(rank1_impl(end_position)) -
177 static_cast<std::int64_t>(end_position));
178 }
179
/**
 * @brief Forward excess search.
 *
 * Returns the smallest position p >= start_position such that the excess of
 * the prefix of length p + 1 equals the excess of the prefix of length
 * start_position plus `delta`; npos if no such p exists or
 * start_position >= bit_count_.
 *
 * The block containing start_position is searched first. After that the
 * search walks up the tree. At each level it scans the right siblings of
 * the current node for one whose [min, max] excess range contains the
 * absolute target, then descends into it.
 */
std::size_t fwdsearch_impl(std::size_t start_position, int delta) const {
  if (start_position >= bit_count_) {
    return npos;
  }

  const std::size_t block_index = start_position / kBlockBits;
  const std::size_t block_begin = block_index * kBlockBits;
  const std::size_t start_offset = start_position - block_begin;
  const std::size_t block_result =
      find_fwd_in_block(block_index, start_offset, delta);
  if (block_result != npos) {
    return block_result;
  }

  // Absolute target excess = excess(start_position) + delta, expressed as
  // block-start excess plus the in-block relative target.
  const std::int64_t relative_target =
      block_excess_at_local(block_index, start_offset) + delta;
  const std::int64_t block_start_excess = prefix_excess_impl(block_begin);
  const std::int64_t target = block_start_excess + relative_target;
  std::size_t level = 0;
  std::size_t index = block_index;
  while (has_parent_level(level)) {
    const std::size_t fanout = fanout_to_parent(level);
    const std::size_t parent = index / fanout;
    // Siblings to the right of `index` under the same parent, clipped to the
    // number of nodes that actually exist at this level.
    const std::size_t sibling_end =
        std::min(level_count(level), parent * fanout + fanout);
    const NodeScanResult scan = scan_children_fwd(
        level, parent, index % fanout + 1, sibling_end - parent * fanout,
        target, prefix_excess_impl(node_end_bit(level, index)));
    if (scan.found) {
      const std::size_t result =
          descend_fwd(level, scan.index, target, scan.node_start_excess);
      if (result != npos) {
        return result;
      }
    }
    level += 1;
    index = parent;
  }
  return npos;
}
220
/**
 * @brief Backward excess search.
 *
 * Returns the largest prefix length p < start_position such that the excess
 * of the first p bits equals the excess of the first start_position bits
 * plus `delta`; npos if none exists or start_position is 0 or exceeds
 * bit_count_.
 *
 * The block holding bit start_position - 1 is searched first. After that
 * the search walks up the tree, scanning left siblings right-to-left. A
 * sibling matches if the target lies in its [min, max] range (an interior
 * position) or equals its start excess (a "boundary-only" match, answered
 * directly by the sibling's start bit).
 */
std::size_t bwdsearch_impl(std::size_t start_position, int delta) const {
  if (start_position == 0 || start_position > bit_count_) {
    return npos;
  }

  const std::size_t block_index = (start_position - 1) / kBlockBits;
  const std::size_t block_begin = block_index * kBlockBits;
  const std::size_t end_offset = start_position - block_begin;
  const std::size_t block_result =
      find_bwd_in_block(block_index, end_offset, delta);
  if (block_result != npos) {
    return block_result;
  }

  // Absolute target excess = excess(start_position) + delta.
  const std::int64_t relative_target =
      block_excess_at_local(block_index, end_offset) + delta;
  const std::int64_t block_start_excess = prefix_excess_impl(block_begin);
  const std::int64_t target = block_start_excess + relative_target;
  std::size_t level = 0;
  std::size_t index = block_index;
  while (has_parent_level(level)) {
    const std::size_t fanout = fanout_to_parent(level);
    const std::size_t parent = index / fanout;
    // Left siblings [0, index % fanout) under the same parent; the excess at
    // the start of `index` is the excess at the end of that range.
    const NodeScanResult scan =
        scan_children_bwd(level, parent, 0, index % fanout, target,
                          prefix_excess_impl(node_start_bit(level, index)));
    if (scan.found) {
      if (scan.boundary_only) {
        return node_start_bit(level, scan.index);
      }
      const std::size_t result =
          descend_bwd(level, scan.index, target, scan.node_start_excess);
      if (result != npos) {
        return result;
      }
    }
    level += 1;
    index = parent;
  }
  return npos;
}
262
263 std::size_t range_min_query_pos_impl(std::size_t range_begin,
264 std::size_t range_end) const {
265 if (range_begin > range_end || range_end >= bit_count_) {
266 return npos;
267 }
268 return range_extreme_query_pos(range_begin, range_end, true);
269 }
270
271 int range_min_query_val_impl(std::size_t range_begin,
272 std::size_t range_end) const {
273 if (range_begin > range_end || range_end >= bit_count_) {
274 return 0;
275 }
276 return range_extreme_query_val(range_begin, range_end, true);
277 }
278
279 std::size_t range_max_query_pos_impl(std::size_t range_begin,
280 std::size_t range_end) const {
281 if (range_begin > range_end || range_end >= bit_count_) {
282 return npos;
283 }
284 return range_extreme_query_pos(range_begin, range_end, false);
285 }
286
287 int range_max_query_val_impl(std::size_t range_begin,
288 std::size_t range_end) const {
289 if (range_begin > range_end || range_end >= bit_count_) {
290 return 0;
291 }
292 return range_extreme_query_val(range_begin, range_end, false);
293 }
294
295 std::size_t mincount_impl(std::size_t range_begin,
296 std::size_t range_end) const {
297 if (range_begin > range_end || range_end >= bit_count_) {
298 return 0;
299 }
300 return range_min_stats(range_begin, range_end).count;
301 }
302
303 std::size_t minselect_impl(std::size_t range_begin,
304 std::size_t range_end,
305 std::size_t rank) const {
306 if (range_begin > range_end || range_end >= bit_count_ || rank == 0) {
307 return npos;
308 }
309 const RangeMinStats stats = range_min_stats(range_begin, range_end);
310 if (rank > stats.count) {
311 return npos;
312 }
313 return range_min_select(range_begin, range_end, stats.value, rank);
314 }
315
316 std::size_t close_impl(std::size_t open_position) const {
317 if (open_position >= bit_count_) {
318 return npos;
319 }
320 if (!bit(open_position)) {
321 return open_position;
322 }
323 return fwd_excess_at(open_position, -1);
324 }
325
326 std::size_t open_impl(std::size_t close_position) const {
327 if (close_position >= bit_count_) {
328 return npos;
329 }
330 if (bit(close_position)) {
331 return close_position;
332 }
333 return bwdsearch_impl(close_position + 1, 0);
334 }
335
336 std::size_t enclose_impl(std::size_t position) const {
337 if (position >= bit_count_) {
338 return npos;
339 }
340 if (!bit(position)) {
341 return open_impl(position);
342 }
343 return bwdsearch_impl(position + 1, -2);
344 }
345
346 private:
358 std::size_t fwd_excess_at(std::size_t position, int delta) const {
359 if (position >= bit_count_) {
360 return npos;
361 }
362 if (position + 1 >= bit_count_) {
363 return npos;
364 }
365 return fwdsearch_impl(position + 1, delta);
366 }
367
/// Aggregate of a contiguous bit range, as produced by summarize_bits() and
/// combined by append(). Excess values are relative to the range start and
/// are sampled after each bit.
struct Summary {
  std::uint64_t size_bits = 0;     // number of bits in the range
  std::uint64_t ones = 0;          // number of 1-bits
  std::int64_t block_excess = 0;   // excess after the last bit
  std::int64_t min_excess = 0;     // minimum running excess
  std::int64_t max_excess = 0;     // maximum running excess
  std::uint64_t min_count = 0;     // how many positions attain min_excess
};

/// One cache-line-aligned internal node holding per-child statistics.
/// prefix_excess[s] is the cumulative excess of children 0..s relative to
/// the node start (see store_child_summary()); min/max/min_count are each
/// child's own values relative to that child's start. Unused slots stay 0.
template <class Excess, class Count, std::size_t Fanout>
struct alignas(kCacheLineBytes) SummaryNode {
  using ExcessType = Excess;
  using CountType = Count;
  static constexpr std::size_t kFanout = Fanout;
  std::array<Excess, Fanout> prefix_excess{};
  std::array<Excess, Fanout> min_excess{};
  std::array<Excess, Fanout> max_excess{};
  std::array<Count, Fanout> min_count{};
};

// Level-1 nodes summarize blocks with 16-bit fields (bounded by the
// static_assert on kBlockBits * kLowFanout); higher levels use 64-bit fields.
using LowNode = SummaryNode<std::int16_t, std::uint16_t, kLowFanout>;
using HighNode = SummaryNode<std::int64_t, std::uint64_t, kHighFanout>;
static_assert(alignof(LowNode) == kCacheLineBytes);
static_assert(alignof(HighNode) == kCacheLineBytes);
static_assert(sizeof(LowNode) % kCacheLineBytes == 0);
static_assert(sizeof(HighNode) % kCacheLineBytes == 0);

/// Excess statistics of one byte, bits read least-significant first
/// (see byte_lut()). pos_first_min / pos_first_max are the bit indices
/// where the minimum / maximum running excess is first reached.
struct ByteAgg {
  std::int8_t block_excess = 0;
  std::int8_t min_excess = 0;
  std::int8_t max_excess = 0;
  std::uint8_t min_count = 0;
  std::uint8_t pos_first_min = 0;
  std::uint8_t pos_first_max = 0;
};

// In-block searches on full blocks proceed in 128-bit chunks
// (two 64-bit words, four chunks per 512-bit block).
static constexpr std::size_t kSearchChunkBits = 128;
static constexpr std::size_t kSearchChunkWords = kSearchChunkBits / 64;
static constexpr std::size_t kSearchChunkCount =
    kBlockBits / kSearchChunkBits;
408
416 static const std::array<ByteAgg, 256>& byte_lut() {
417 static const std::array<ByteAgg, 256> table = [] {
418 std::array<ByteAgg, 256> result{};
419 for (int byte_value = 0; byte_value < 256; ++byte_value) {
420 ByteAgg agg;
421 int current = 0;
422 int minimum = std::numeric_limits<int>::max();
423 int maximum = std::numeric_limits<int>::min();
424 const auto bit_at = [&](int bit_index) {
425 return (byte_value >> bit_index) & 1;
426 };
427 for (int bit_index = 0; bit_index < 8; ++bit_index) {
428 const int value = bit_at(bit_index);
429 current += value ? 1 : -1;
430 if (current < minimum) {
431 minimum = current;
432 agg.min_count = 1;
433 agg.pos_first_min = static_cast<std::uint8_t>(bit_index);
434 } else if (current == minimum) {
435 ++agg.min_count;
436 }
437 if (current > maximum) {
438 maximum = current;
439 agg.pos_first_max = static_cast<std::uint8_t>(bit_index);
440 }
441 }
442 agg.block_excess = static_cast<std::int8_t>(current);
443 agg.min_excess = static_cast<std::int8_t>(minimum);
444 agg.max_excess = static_cast<std::int8_t>(maximum);
445 result[static_cast<std::size_t>(byte_value)] = agg;
446 }
447 return result;
448 }();
449 return table;
450 }
451
462 void build(std::span<const std::uint64_t> words, std::size_t bit_count) {
463 const std::size_t required_words = (bit_count + 63) / 64;
464 if (words.size() < required_words) {
465 throw std::invalid_argument(
466 "RmMBTree input span is shorter than bit_count");
467 }
468
469 bits_ = words;
470 bit_count_ = bit_count;
471 rank_index_.emplace(words, bit_count);
472 block_count_ = (bit_count_ + kBlockBits - 1) / kBlockBits;
473 std::vector<Summary> block_summaries(block_count_);
474
475 for (std::size_t block_index = 0; block_index < block_count_;
476 ++block_index) {
477 const std::size_t block_begin = block_index * kBlockBits;
478 const std::size_t block_end =
479 std::min(bit_count_, block_begin + kBlockBits);
480 block_summaries[block_index] =
481 summarize_bits(block_begin, block_end - block_begin);
482 }
483
484 build_levels(block_summaries);
485 }
486
494 void build_levels(const std::vector<Summary>& block_summaries) {
495 level_counts_.clear();
496 low_levels_.clear();
497 high_levels_.clear();
498 top_summary_ = Summary{};
499 level_counts_.push_back(block_summaries.size());
500 if (block_summaries.empty()) {
501 return;
502 }
503
504 std::vector<Summary> current = block_summaries;
505 current = build_parent_level<LowNode>(current, low_levels_.emplace_back());
506 level_counts_.push_back(current.size());
507 current =
508 build_parent_level<HighNode>(current, high_levels_.emplace_back());
509 level_counts_.push_back(current.size());
510 while (current.size() > 1) {
511 current =
512 build_parent_level<HighNode>(current, high_levels_.emplace_back());
513 level_counts_.push_back(current.size());
514 }
515 top_summary_ = current.front();
516 }
517
528 template <class Node>
529 static std::vector<Summary> build_parent_level(const std::vector<Summary>& in,
530 std::vector<Node>& nodes) {
531 constexpr std::size_t fanout = Node::kFanout;
532 std::vector<Summary> out((in.size() + fanout - 1) / fanout);
533 nodes.resize(out.size());
534 for (std::size_t parent = 0; parent < out.size(); ++parent) {
535 const std::size_t begin = parent * fanout;
536 const std::size_t end = std::min(in.size(), begin + fanout);
537 Summary combined;
538 for (std::size_t i = begin; i < end; ++i) {
539 store_child_summary(nodes[parent], i - begin, combined.block_excess,
540 in[i]);
541 combined = append(combined, in[i]);
542 }
543 out[parent] = combined;
544 }
545 return out;
546 }
547
560 template <class Node>
561 static void store_child_summary(Node& node,
562 std::size_t slot,
563 std::int64_t prefix_excess,
564 const Summary& summary) {
565 node.prefix_excess[slot] =
566 static_cast<typename decltype(node.prefix_excess)::value_type>(
567 prefix_excess + summary.block_excess);
568 node.min_excess[slot] =
569 static_cast<typename decltype(node.min_excess)::value_type>(
570 summary.min_excess);
571 node.max_excess[slot] =
572 static_cast<typename decltype(node.max_excess)::value_type>(
573 summary.max_excess);
574 node.min_count[slot] =
575 static_cast<typename decltype(node.min_count)::value_type>(
576 summary.min_count);
577 }
578
588 Summary summarize_bits(std::size_t begin, std::size_t length) const {
589 Summary summary;
590 summary.size_bits = length;
591 if (length == 0) {
592 return summary;
593 }
594 int current = 0;
595 int minimum = std::numeric_limits<int>::max();
596 int maximum = std::numeric_limits<int>::min();
597 for (std::size_t offset = 0; offset < length; ++offset) {
598 const std::uint8_t value = bit(begin + offset);
599 summary.ones += value;
600 current += value ? 1 : -1;
601 if (current < minimum) {
602 minimum = current;
603 summary.min_count = 1;
604 } else if (current == minimum) {
605 ++summary.min_count;
606 }
607 if (current > maximum) {
608 maximum = current;
609 }
610 }
611 summary.block_excess = current;
612 summary.min_excess = minimum;
613 summary.max_excess = maximum;
614 return summary;
615 }
616
626 static Summary append(Summary left, const Summary& right) {
627 if (left.size_bits == 0) {
628 return right;
629 }
630 if (right.size_bits == 0) {
631 return left;
632 }
633 Summary result;
634 result.size_bits = left.size_bits + right.size_bits;
635 result.ones = left.ones + right.ones;
636 result.block_excess = left.block_excess + right.block_excess;
637 result.min_excess =
638 std::min(left.min_excess, left.block_excess + right.min_excess);
639 result.max_excess =
640 std::max(left.max_excess, left.block_excess + right.max_excess);
641 result.min_count = 0;
642 if (left.min_excess == result.min_excess) {
643 result.min_count += left.min_count;
644 }
645 if (left.block_excess + right.min_excess == result.min_excess) {
646 result.min_count += right.min_count;
647 }
648 return result;
649 }
650
658 std::size_t block_size(std::size_t block_index) const {
659 const std::size_t begin = block_index * kBlockBits;
660 return std::min(bit_count_ - begin, kBlockBits);
661 }
662
670 bool full_block_has_words(std::size_t block_index) const {
671 return block_size(block_index) == kBlockBits &&
672 (block_index + 1) * kBlockWords <= bits_.size();
673 }
674
684 std::int64_t block_excess_at_local(std::size_t block_index,
685 std::size_t offset) const {
686 const std::size_t length = block_size(block_index);
687 offset = std::min(offset, length);
688 if (offset == 0) {
689 return 0;
690 }
691
692 const std::size_t first_word = block_index * kBlockWords;
693 std::size_t remaining = offset;
694 std::int64_t ones = 0;
695 std::size_t word_offset = 0;
696 while (remaining >= 64) {
697 ones += std::popcount(bits_[first_word + word_offset]);
698 remaining -= 64;
699 ++word_offset;
700 }
701 if (remaining != 0) {
702 ones += std::popcount(bits_[first_word + word_offset] &
703 first_bits_mask(remaining));
704 }
705 return 2 * ones - static_cast<std::int64_t>(offset);
706 }
707
714 static int chunk_excess_128(const std::uint64_t* chunk) {
715 return 2 * static_cast<int>(std::popcount(chunk[0]) +
716 std::popcount(chunk[1])) -
717 static_cast<int>(kSearchChunkBits);
718 }
719
/**
 * @brief Forward search restricted to one block.
 *
 * Returns the absolute position p in block `block_index`, with
 * p >= block start + start_offset, such that the in-block excess after bit p
 * equals the in-block excess at start_offset plus `delta`. Returns npos when
 * start_offset is past the block or no such bit exists in the block.
 */
std::size_t find_fwd_in_block(std::size_t block_index,
                              std::size_t start_offset,
                              std::int64_t delta) const {
  const std::size_t length = block_size(block_index);
  if (start_offset >= length) {
    return npos;
  }

  // Fast path: full blocks are scanned in 128-bit chunks.
  // NOTE(review): prefix_excess_128 / forward_search_128 are defined
  // elsewhere. Presumably forward_search_128 returns the in-chunk offset of
  // the first match (kSearchChunkBits if none) and reports the chunk's excess
  // through its out-parameter; `target` is rebased per chunk accordingly.
  // Confirm against their definitions.
  if (full_block_has_words(block_index)) {
    const std::uint64_t* block = bits_.data() + block_index * kBlockWords;
    const std::size_t first_chunk = start_offset / kSearchChunkBits;
    std::int64_t target = delta;
    for (std::size_t chunk = first_chunk; chunk < kSearchChunkCount;
         ++chunk) {
      const std::uint64_t* chunk_words = block + chunk * kSearchChunkWords;
      const std::size_t local_start =
          chunk == first_chunk ? start_offset - chunk * kSearchChunkBits : 0;
      if (chunk == first_chunk) {
        target += prefix_excess_128(chunk_words, local_start);
      }
      int block_excess = 0;
      const std::size_t offset = forward_search_128(
          chunk_words, static_cast<int>(target), local_start, &block_excess);
      if (offset != kSearchChunkBits) {
        return block_index * kBlockBits + chunk * kSearchChunkBits + offset;
      }
      target -= block_excess;
    }
    return npos;
  }

  // Fallback (partial block): walk bit by bit from start_offset.
  std::int64_t current = block_excess_at_local(block_index, start_offset);
  const std::int64_t relative_target = current + delta;
  const std::size_t block_begin = block_index * kBlockBits;
  for (std::size_t offset = start_offset; offset < length; ++offset) {
    current += bit(block_begin + offset) ? 1 : -1;
    if (current == relative_target) {
      return block_index * kBlockBits + offset;
    }
  }
  return npos;
}
773
/**
 * @brief Backward search restricted to one block.
 *
 * Returns block start + p for the largest in-block prefix length
 * p in [0, end_offset - 1] whose excess equals the in-block excess at
 * end_offset plus `delta`; npos when end_offset is 0 or no such p exists.
 */
std::size_t find_bwd_in_block(std::size_t block_index,
                              std::size_t end_offset,
                              std::int64_t delta) const {
  if (end_offset == 0) {
    return npos;
  }
  const std::size_t max_prefix_length = end_offset - 1;

  // Fast path: full blocks are scanned right-to-left in 128-bit chunks.
  // NOTE(review): backward_search_128 / prefix_excess_128 are defined
  // elsewhere; `target` is rebased when stepping to the previous chunk by
  // that chunk's excess. Confirm the exact contract of those helpers.
  if (full_block_has_words(block_index)) {
    const std::uint64_t* block = bits_.data() + block_index * kBlockWords;
    const std::size_t last_chunk = max_prefix_length / kSearchChunkBits;
    std::int64_t target = delta;
    for (std::size_t chunk = last_chunk + 1; chunk > 0;) {
      --chunk;
      const std::uint64_t* chunk_words = block + chunk * kSearchChunkWords;
      const std::size_t local_end =
          chunk == last_chunk ? end_offset - chunk * kSearchChunkBits
                              : kSearchChunkBits;
      if (chunk == last_chunk) {
        target += prefix_excess_128(chunk_words, local_end);
      }
      int block_excess = 0;
      const std::size_t offset = backward_search_128(
          chunk_words, static_cast<int>(target), local_end, &block_excess);
      if (offset != kSearchChunkBits) {
        return block_index * kBlockBits + chunk * kSearchChunkBits + offset;
      }
      if (chunk > 0) {
        target += chunk_excess_128(block + (chunk - 1) * kSearchChunkWords);
      }
    }
    return npos;
  }

  // Fallback (partial block): walk prefix lengths downward; prefix length 0
  // (the block start, excess 0) is checked last.
  const std::int64_t relative_target =
      block_excess_at_local(block_index, end_offset) + delta;
  std::int64_t current =
      block_excess_at_local(block_index, max_prefix_length);
  const std::size_t block_begin = block_index * kBlockBits;
  for (std::size_t prefix_length = max_prefix_length; prefix_length > 0;
       --prefix_length) {
    if (current == relative_target) {
      return block_index * kBlockBits + prefix_length;
    }
    current -= bit(block_begin + prefix_length - 1) ? 1 : -1;
  }
  return relative_target == 0 ? block_index * kBlockBits : npos;
}
833
845 std::size_t descend_fwd(std::size_t level,
846 std::size_t index,
847 std::int64_t target,
848 std::int64_t node_start_excess) const {
849 while (level > 0) {
850 const std::size_t child_level = level - 1;
851 const std::size_t fanout = fanout_to_parent(child_level);
852 const std::size_t child_begin = index * fanout;
853 const std::size_t child_end =
854 std::min(level_count(child_level), child_begin + fanout);
855 const NodeScanResult scan =
856 scan_children_fwd(child_level, index, 0, child_end - child_begin,
857 target, node_start_excess);
858 if (!scan.found) {
859 return npos;
860 }
861 index = scan.index;
862 level = child_level;
863 node_start_excess = scan.node_start_excess;
864 }
865 return find_fwd_in_block(index, 0, target - node_start_excess);
866 }
867
879 std::size_t descend_bwd(std::size_t level,
880 std::size_t index,
881 std::int64_t target,
882 std::int64_t node_start_excess) const {
883 while (level > 0) {
884 const std::size_t child_level = level - 1;
885 const std::size_t fanout = fanout_to_parent(child_level);
886 const std::size_t child_begin = index * fanout;
887 const std::size_t child_end =
888 std::min(level_count(child_level), child_begin + fanout);
889 std::int64_t child_start_excess =
890 node_start_excess + summary_at(level, index).block_excess;
891 const NodeScanResult scan =
892 scan_children_bwd(child_level, index, 0, child_end - child_begin,
893 target, child_start_excess);
894 if (!scan.found) {
895 return npos;
896 }
897 if (scan.boundary_only) {
898 return node_start_bit(child_level, scan.index);
899 }
900 index = scan.index;
901 level = child_level;
902 node_start_excess = scan.node_start_excess;
903 }
904 const std::int64_t block_excess = summary_at(0, index).block_excess;
905 return find_bwd_in_block(index, block_size(index),
906 target - node_start_excess - block_excess);
907 }
908
/// Outcome of scanning a node's children for a target excess.
struct NodeScanResult {
  bool found = false;          // a matching child was found
  bool boundary_only = false;  // (bwd only) the match is the child's start
                               // position, not a position inside the child
  std::size_t index = 0;       // level-wide index of the matching child
  std::int64_t node_start_excess = 0;  // absolute excess at that child start
};
915
928 NodeScanResult scan_children_fwd(std::size_t child_level,
929 std::size_t parent,
930 std::size_t begin_slot,
931 std::size_t end_slot,
932 std::int64_t target,
933 std::int64_t begin_excess) const {
934 if (begin_slot >= end_slot) {
935 return {};
936 }
937 if (child_level == 0) {
938 return scan_node_fwd(low_levels_[0][parent], child_level, parent,
939 begin_slot, end_slot, target, begin_excess);
940 }
941 return scan_node_fwd(high_levels_[child_level - 1][parent], child_level,
942 parent, begin_slot, end_slot, target, begin_excess);
943 }
944
958 NodeScanResult scan_children_bwd(std::size_t child_level,
959 std::size_t parent,
960 std::size_t begin_slot,
961 std::size_t end_slot,
962 std::int64_t target,
963 std::int64_t end_excess) const {
964 if (begin_slot >= end_slot) {
965 return {};
966 }
967 if (child_level == 0) {
968 return scan_node_bwd(low_levels_[0][parent], child_level, parent,
969 begin_slot, end_slot, target, end_excess);
970 }
971 return scan_node_bwd(high_levels_[child_level - 1][parent], child_level,
972 parent, begin_slot, end_slot, target, end_excess);
973 }
974
989 template <class Node>
990 NodeScanResult scan_node_fwd(const Node& node,
991 std::size_t child_level,
992 std::size_t parent,
993 std::size_t begin_slot,
994 std::size_t end_slot,
995 std::int64_t target,
996 std::int64_t begin_excess) const {
997 const std::int64_t node_base_excess =
998 begin_excess - prefix_excess_at(node, begin_slot);
999 for (std::size_t slot = begin_slot; slot < end_slot;) {
1000 const std::size_t lane_count =
1001 std::min(vector_lane_count<Node>(), end_slot - slot);
1002 const std::uint32_t mask = matching_chunk_mask(
1003 node, slot, lane_count, target - node_base_excess, false);
1004 if (mask != 0) {
1005 const std::size_t lane = std::countr_zero(mask);
1006 const std::size_t matched_slot = slot + lane;
1007 return {true, false,
1008 parent * fanout_to_parent(child_level) + matched_slot,
1009 child_start_excess(node, node_base_excess, matched_slot)};
1010 }
1011 slot += lane_count;
1012 }
1013 return {};
1014 }
1015
1030 template <class Node>
1031 NodeScanResult scan_node_bwd(const Node& node,
1032 std::size_t child_level,
1033 std::size_t parent,
1034 std::size_t begin_slot,
1035 std::size_t end_slot,
1036 std::int64_t target,
1037 std::int64_t end_excess) const {
1038 const std::int64_t node_base_excess =
1039 end_excess - prefix_excess_at(node, end_slot);
1040 for (std::size_t slot_end = end_slot; slot_end > begin_slot;) {
1041 const std::size_t lane_count =
1042 std::min(vector_lane_count<Node>(), slot_end - begin_slot);
1043 const std::size_t slot = slot_end - lane_count;
1044 const std::uint32_t mask = matching_chunk_mask(
1045 node, slot, lane_count, target - node_base_excess, true);
1046 if (mask != 0) {
1047 const std::size_t lane =
1048 static_cast<std::size_t>(std::bit_width(mask) - 1);
1049 const std::size_t matched_slot = slot + lane;
1050 const std::int64_t relative_target =
1051 target - child_start_excess(node, node_base_excess, matched_slot);
1052 const bool interior_match =
1053 node.min_excess[matched_slot] <= relative_target &&
1054 relative_target <= node.max_excess[matched_slot];
1055 return {true, !interior_match,
1056 parent * fanout_to_parent(child_level) + matched_slot,
1057 child_start_excess(node, node_base_excess, matched_slot)};
1058 }
1059 slot_end = slot;
1060 }
1061 return {};
1062 }
1063
/// Excess of children [0, slot) relative to the node start (0 for slot 0).
template <class Node>
static std::int64_t prefix_excess_at(const Node& node, std::size_t slot) {
  return slot == 0 ? std::int64_t{0} : prefix_through(node, slot - 1);
}

/// Absolute excess at the start of child `slot`, given the node's absolute
/// start excess.
template <class Node>
static std::int64_t child_start_excess(const Node& node,
                                       std::int64_t node_base_excess,
                                       std::size_t slot) {
  const std::int64_t before = prefix_excess_at(node, slot);
  return node_base_excess + before;
}

/// Excess of children [0, slot] relative to the node start.
template <class Node>
static std::int64_t prefix_through(const Node& node, std::size_t slot) {
  const auto stored = node.prefix_excess[slot];
  return static_cast<std::int64_t>(stored);
}

/// Excess contributed by child `slot` alone.
template <class Node>
static std::int64_t child_excess(const Node& node, std::size_t slot) {
  const std::int64_t through = prefix_through(node, slot);
  const std::int64_t before = prefix_excess_at(node, slot);
  return through - before;
}

/// Slots tested per matching_chunk_mask() call: 16 for 16-bit nodes,
/// 4 for 64-bit nodes.
template <class Node>
static constexpr std::size_t vector_lane_count() {
  return std::is_same_v<typename Node::ExcessType, std::int16_t> ? 16 : 4;
}
1140
/**
 * @brief Bitmask of matching child slots among [slot, slot + lane_count).
 *
 * Bit `lane` is set when, for child slot + lane, the value
 * rel = target_in_node - prefix_excess_at(node, slot + lane) lies inside
 * [min_excess, max_excess] of that child. When include_zero_boundary is
 * true, rel == 0 (the child's start) also counts as a match.
 *
 * @param target_in_node Target excess relative to the node start.
 *
 * With PIXIE_AVX2_SUPPORT, full 16-lane int16 chunks and 4-lane int64 chunks
 * are delegated to rmm_btree_match_mask_i16x16 / rmm_btree_match_mask_i64x4.
 * The 16-lane path is taken only when the target fits in int16. These
 * helpers are defined elsewhere and presumably compute the same mask as the
 * scalar loop below, which handles every other case.
 */
template <class Node>
static std::uint32_t matching_chunk_mask(const Node& node,
                                         std::size_t slot,
                                         std::size_t lane_count,
                                         std::int64_t target_in_node,
                                         bool include_zero_boundary) {
#ifdef PIXIE_AVX2_SUPPORT
  if constexpr (std::is_same_v<typename Node::ExcessType, std::int16_t>) {
    if (lane_count == 16 &&
        target_in_node >= std::numeric_limits<std::int16_t>::min() &&
        target_in_node <= std::numeric_limits<std::int16_t>::max()) {
      alignas(32) std::int16_t prefix_before[16]{};
      fill_prefix_before(node, slot, prefix_before);
      return rmm_btree_match_mask_i16x16(
          prefix_before, node.min_excess.data() + slot,
          node.max_excess.data() + slot,
          static_cast<std::int16_t>(target_in_node), include_zero_boundary);
    }
  } else if constexpr (std::is_same_v<typename Node::ExcessType,
                                      std::int64_t>) {
    if (lane_count == 4) {
      alignas(32) std::int64_t prefix_before[4]{};
      fill_prefix_before(node, slot, prefix_before);
      return rmm_btree_match_mask_i64x4(
          prefix_before, node.min_excess.data() + slot,
          node.max_excess.data() + slot, target_in_node,
          include_zero_boundary);
    }
  }
#endif
  // Scalar fallback: one lane per child slot.
  std::uint32_t result = 0;
  for (std::size_t lane = 0; lane < lane_count; ++lane) {
    const std::int64_t rel =
        target_in_node - prefix_excess_at(node, slot + lane);
    const bool found = (include_zero_boundary && rel == 0) ||
                       (node.min_excess[slot + lane] <= rel &&
                        rel <= node.max_excess[slot + lane]);
    if (found) {
      result |= std::uint32_t{1} << lane;
    }
  }
  return result;
}
1197
1210 template <class Node>
1211 static void fill_prefix_before(const Node& node,
1212 std::size_t slot,
1213 typename Node::ExcessType* out) {
1214 if (slot == 0) {
1215 out[0] = 0;
1216 for (std::size_t lane = 1; lane < vector_lane_count<Node>(); ++lane) {
1217 out[lane] = node.prefix_excess[lane - 1];
1218 }
1219 return;
1220 }
1221 for (std::size_t lane = 0; lane < vector_lane_count<Node>(); ++lane) {
1222 out[lane] = node.prefix_excess[slot + lane - 1];
1223 }
1224 }
1225
1234 static bool contains_fwd(const Summary& summary,
1235 std::int64_t relative_target) {
1236 return summary.min_excess <= relative_target &&
1237 relative_target <= summary.max_excess;
1238 }
1239
1249 static bool contains_bwd(const Summary& summary,
1250 std::int64_t relative_target) {
1251 return relative_target == 0 || contains_fwd(summary, relative_target);
1252 }
1253
1254 std::int64_t prefix_excess_impl(std::size_t end_position) const {
1255 return 2 * static_cast<std::int64_t>(rank1_impl(end_position)) -
1256 static_cast<std::int64_t>(end_position);
1257 }
1258
1265 bool has_parent_level(std::size_t level) const {
1266 return level + 1 < total_levels() && level_count(level + 1) != 0;
1267 }
1268
1275 std::size_t total_levels() const { return level_counts_.size(); }
1276
1283 std::size_t level_count(std::size_t level) const {
1284 return level < level_counts_.size() ? level_counts_[level] : 0;
1285 }
1286
1295 Summary summary_at(std::size_t level, std::size_t index) const {
1296 if (level + 1 >= total_levels()) {
1297 return top_summary_;
1298 }
1299 const std::size_t parent_level = level + 1;
1300 const std::size_t fanout = fanout_to_parent(level);
1301 const std::size_t parent = index / fanout;
1302 const std::size_t slot = index % fanout;
1303 Summary summary;
1304 if (parent_level == 1) {
1305 const LowNode& node = low_levels_[0][parent];
1306 summary.block_excess = child_excess(node, slot);
1307 summary.min_excess = node.min_excess[slot];
1308 summary.max_excess = node.max_excess[slot];
1309 summary.min_count = node.min_count[slot];
1310 } else {
1311 const HighNode& node = high_levels_[parent_level - 2][parent];
1312 summary.block_excess = child_excess(node, slot);
1313 summary.min_excess = node.min_excess[slot];
1314 summary.max_excess = node.max_excess[slot];
1315 summary.min_count = node.min_count[slot];
1316 }
1317 return summary;
1318 }
1319
1327 static std::size_t fanout_to_parent(std::size_t level) {
1328 return level == 0 ? kLowFanout : kHighFanout;
1329 }
1330
/// lhs * rhs, saturating at SIZE_MAX instead of wrapping on overflow.
static std::size_t mul_clamped(std::size_t lhs, std::size_t rhs) {
  constexpr std::size_t kCeiling = std::numeric_limits<std::size_t>::max();
  const bool overflows = lhs != 0 && rhs > kCeiling / lhs;
  return overflows ? kCeiling : lhs * rhs;
}
1345
1353 static std::size_t level_span_bits(std::size_t level) {
1354 std::size_t span = kBlockBits;
1355 if (level >= 1) {
1356 span = mul_clamped(span, kLowFanout);
1357 }
1358 if (level >= 2) {
1359 for (std::size_t i = 2; i <= level; ++i) {
1360 span = mul_clamped(span, kHighFanout);
1361 }
1362 }
1363 return span;
1364 }
1365
1374 std::size_t node_start_bit(std::size_t level, std::size_t index) const {
1375 const std::size_t span = level_span_bits(level);
1376 if (span != 0 && index > std::numeric_limits<std::size_t>::max() / span) {
1377 return bit_count_;
1378 }
1379 return std::min(bit_count_, index * span);
1380 }
1381
1390 std::size_t node_size_bits(std::size_t level, std::size_t index) const {
1391 const std::size_t start = node_start_bit(level, index);
1392 if (start >= bit_count_) {
1393 return 0;
1394 }
1395 return std::min(level_span_bits(level), bit_count_ - start);
1396 }
1397
1406 std::size_t node_end_bit(std::size_t level, std::size_t index) const {
1407 const std::size_t start = node_start_bit(level, index);
1408 const std::uint64_t size = node_size_bits(level, index);
1409 if (size > std::numeric_limits<std::size_t>::max() - start) {
1410 return bit_count_;
1411 }
1412 return std::min(bit_count_, start + static_cast<std::size_t>(size));
1413 }
1414
/// Coordinates of a tree node: `level` (0 = 512-bit blocks) and the node's
/// position `index` within that level.
struct NodeRef {
  std::size_t level{0};
  std::size_t index{0};
};
1419
/// Fixed capacity of a Cover (the node list produced by collect_cover).
static constexpr std::size_t kMaxCoverNodes{512};
1421
1422 struct ScanResult {
1423 std::int64_t block_excess = 0;
1424 std::int64_t min_value = std::numeric_limits<std::int64_t>::max();
1425 std::int64_t max_value = std::numeric_limits<std::int64_t>::min();
1426 std::uint64_t min_count = 0;
1427 std::size_t min_position = npos;
1428 std::size_t max_position = npos;
1429 };
1430
1431 struct RangeExtremeResult {
1432 std::size_t position = npos;
1433 std::int64_t value = 0;
1434 };
1435
/// Outcome of range_min_stats: the minimum prefix excess of a range and the
/// number of positions that attain it.
struct RangeMinStats {
  std::int64_t value{0};
  std::uint64_t count{0};
};
1440
1441 struct Cover {
1442 std::array<NodeRef, kMaxCoverNodes> nodes{};
1443 std::size_t size = 0;
1444
1451 void push(NodeRef node) {
1452 if (size < nodes.size()) {
1453 nodes[size++] = node;
1454 }
1455 }
1456 };
1457
1467 std::size_t range_extreme_query_pos(std::size_t range_begin,
1468 std::size_t range_end,
1469 bool find_min) const {
1470 return range_extreme_query(range_begin, range_end, find_min).position;
1471 }
1472
1482 int range_extreme_query_val(std::size_t range_begin,
1483 std::size_t range_end,
1484 bool find_min) const {
1485 return static_cast<int>(
1486 range_extreme_query(range_begin, range_end, find_min).value);
1487 }
1488
1499 RangeExtremeResult range_extreme_query(std::size_t range_begin,
1500 std::size_t range_end,
1501 bool find_min) const {
1502 std::int64_t value = 0;
1503 std::int64_t best = find_min ? std::numeric_limits<std::int64_t>::max()
1504 : std::numeric_limits<std::int64_t>::min();
1505 std::size_t best_position = npos;
1506 NodeRef best_node;
1507 std::int64_t prefix_at_best_node = 0;
1508 bool best_is_node = false;
1509
1510 auto consider_point = [&](std::int64_t candidate, std::size_t position) {
1511 if ((find_min && candidate < best) || (!find_min && candidate > best)) {
1512 best = candidate;
1513 best_position = position;
1514 best_is_node = false;
1515 }
1516 };
1517
1518 const std::size_t range_end_exclusive = range_end + 1;
1519 const std::size_t first_full_block =
1520 (range_begin + kBlockBits - 1) / kBlockBits;
1521 const std::size_t full_begin =
1522 std::min(range_end_exclusive, first_full_block * kBlockBits);
1523 if (range_begin < full_begin) {
1524 const ScanResult scan = scan_range(range_begin, full_begin);
1525 consider_point(find_min ? scan.min_value : scan.max_value,
1526 find_min ? scan.min_position : scan.max_position);
1527 value += scan.block_excess;
1528 }
1529
1530 const std::size_t last_full_block_exclusive =
1531 range_end_exclusive / kBlockBits;
1532 const std::size_t middle_begin = full_begin;
1533 const std::size_t middle_end =
1534 std::max(middle_begin, last_full_block_exclusive * kBlockBits);
1535 if (middle_begin < middle_end) {
1536 Cover cover;
1537 collect_cover(middle_begin, middle_end, cover);
1538 for (std::size_t i = 0; i < cover.size; ++i) {
1539 const NodeRef& node = cover.nodes[i];
1540 Summary summary = summary_at(node.level, node.index);
1541 const std::int64_t candidate =
1542 value + (find_min ? summary.min_excess : summary.max_excess);
1543 if ((find_min && candidate < best) || (!find_min && candidate > best)) {
1544 best = candidate;
1545 best_node = node;
1546 prefix_at_best_node = value;
1547 best_is_node = true;
1548 }
1549 value += summary.block_excess;
1550 }
1551 }
1552
1553 if (middle_end < range_end_exclusive) {
1554 const ScanResult scan = scan_range(middle_end, range_end_exclusive);
1555 const std::int64_t candidate =
1556 value + (find_min ? scan.min_value : scan.max_value);
1557 consider_point(candidate,
1558 find_min ? scan.min_position : scan.max_position);
1559 }
1560
1561 if (best_is_node) {
1562 best_position =
1563 descend_first_extreme(best_node.level, best_node.index,
1564 best - prefix_at_best_node, find_min);
1565 }
1566 return {best_position,
1567 best == std::numeric_limits<std::int64_t>::max() ||
1568 best == std::numeric_limits<std::int64_t>::min()
1569 ? 0
1570 : best};
1571 }
1572
1581 RangeMinStats range_min_stats(std::size_t range_begin,
1582 std::size_t range_end) const {
1583 std::int64_t value = 0;
1584 std::int64_t best = std::numeric_limits<std::int64_t>::max();
1585 std::uint64_t count = 0;
1586
1587 auto consider = [&](std::int64_t candidate, std::uint64_t candidate_count) {
1588 if (candidate < best) {
1589 best = candidate;
1590 count = candidate_count;
1591 } else if (candidate == best) {
1592 count += candidate_count;
1593 }
1594 };
1595
1596 const std::size_t range_end_exclusive = range_end + 1;
1597 const std::size_t first_full_block =
1598 (range_begin + kBlockBits - 1) / kBlockBits;
1599 const std::size_t full_begin =
1600 std::min(range_end_exclusive, first_full_block * kBlockBits);
1601 if (range_begin < full_begin) {
1602 const ScanResult scan = scan_range(range_begin, full_begin);
1603 consider(scan.min_value, scan.min_count);
1604 value += scan.block_excess;
1605 }
1606
1607 const std::size_t last_full_block_exclusive =
1608 range_end_exclusive / kBlockBits;
1609 const std::size_t middle_begin = full_begin;
1610 const std::size_t middle_end =
1611 std::max(middle_begin, last_full_block_exclusive * kBlockBits);
1612 if (middle_begin < middle_end) {
1613 Cover cover;
1614 collect_cover(middle_begin, middle_end, cover);
1615 for (std::size_t i = 0; i < cover.size; ++i) {
1616 const NodeRef& node = cover.nodes[i];
1617 Summary summary = summary_at(node.level, node.index);
1618 consider(value + summary.min_excess, summary.min_count);
1619 value += summary.block_excess;
1620 }
1621 }
1622
1623 if (middle_end < range_end_exclusive) {
1624 const ScanResult scan = scan_range(middle_end, range_end_exclusive);
1625 consider(value + scan.min_value, scan.min_count);
1626 }
1627 return {best == std::numeric_limits<std::int64_t>::max() ? 0 : best, count};
1628 }
1629
1641 std::size_t range_min_select(std::size_t range_begin,
1642 std::size_t range_end,
1643 std::int64_t target,
1644 std::uint64_t rank) const {
1645 std::int64_t value = 0;
1646 const std::size_t range_end_exclusive = range_end + 1;
1647 const std::size_t first_full_block =
1648 (range_begin + kBlockBits - 1) / kBlockBits;
1649 const std::size_t full_begin =
1650 std::min(range_end_exclusive, first_full_block * kBlockBits);
1651 if (range_begin < full_begin) {
1652 const ScanResult scan = scan_range(range_begin, full_begin);
1653 if (scan.min_value == target) {
1654 if (rank <= scan.min_count) {
1655 return qth_min_in_range(range_begin, full_begin, target, rank);
1656 }
1657 rank -= scan.min_count;
1658 }
1659 value += scan.block_excess;
1660 }
1661
1662 const std::size_t last_full_block_exclusive =
1663 range_end_exclusive / kBlockBits;
1664 const std::size_t middle_begin = full_begin;
1665 const std::size_t middle_end =
1666 std::max(middle_begin, last_full_block_exclusive * kBlockBits);
1667 if (middle_begin < middle_end) {
1668 Cover cover;
1669 collect_cover(middle_begin, middle_end, cover);
1670 for (std::size_t i = 0; i < cover.size; ++i) {
1671 const NodeRef& node = cover.nodes[i];
1672 Summary summary = summary_at(node.level, node.index);
1673 const std::int64_t candidate = value + summary.min_excess;
1674 if (candidate == target) {
1675 if (rank <= summary.min_count) {
1676 return descend_qth_min(node.level, node.index, target - value,
1677 rank);
1678 }
1679 rank -= summary.min_count;
1680 }
1681 value += summary.block_excess;
1682 }
1683 }
1684
1685 if (middle_end < range_end_exclusive) {
1686 const ScanResult scan = scan_range(middle_end, range_end_exclusive);
1687 if (value + scan.min_value == target) {
1688 return qth_min_in_range(middle_end, range_end_exclusive, target - value,
1689 rank);
1690 }
1691 }
1692 return npos;
1693 }
1694
1703 void collect_cover(std::size_t begin, std::size_t end, Cover& out) const {
1704 if (begin >= end || total_levels() == 0 || (begin % kBlockBits) != 0 ||
1705 (end % kBlockBits) != 0) {
1706 return;
1707 }
1708
1709 Cover right_cover;
1710 std::size_t level = 0;
1711 std::size_t left = begin / kBlockBits;
1712 std::size_t right = end / kBlockBits;
1713
1714 while (left < right) {
1715 if (!has_parent_level(level)) {
1716 for (std::size_t index = left; index < right; ++index) {
1717 out.push({level, index});
1718 }
1719 break;
1720 }
1721
1722 const std::size_t fanout = fanout_to_parent(level);
1723 while (left < right && (left % fanout) != 0) {
1724 out.push({level, left});
1725 ++left;
1726 }
1727 while (left < right && (right % fanout) != 0) {
1728 --right;
1729 right_cover.push({level, right});
1730 }
1731 left /= fanout;
1732 right /= fanout;
1733 ++level;
1734 }
1735
1736 while (right_cover.size > 0) {
1737 out.push(right_cover.nodes[--right_cover.size]);
1738 }
1739 }
1740
1751 std::size_t descend_first_extreme(std::size_t level,
1752 std::size_t index,
1753 std::int64_t target,
1754 bool find_min) const {
1755 while (level > 0) {
1756 const std::size_t child_level = level - 1;
1757 const std::size_t fanout = fanout_to_parent(child_level);
1758 const std::size_t child_begin = index * fanout;
1759 const std::size_t child_end =
1760 std::min(level_count(child_level), child_begin + fanout);
1761 std::int64_t prefix = 0;
1762 bool found = false;
1763 for (std::size_t child = child_begin; child < child_end; ++child) {
1764 Summary summary = summary_at(child_level, child);
1765 const std::int64_t candidate =
1766 prefix + (find_min ? summary.min_excess : summary.max_excess);
1767 if (candidate == target) {
1768 index = child;
1769 level = child_level;
1770 target -= prefix;
1771 found = true;
1772 break;
1773 }
1774 prefix += summary.block_excess;
1775 }
1776 if (!found) {
1777 return npos;
1778 }
1779 }
1780 return first_prefix_in_block(index, target);
1781 }
1782
1792 std::size_t first_prefix_in_block(std::size_t block_index,
1793 std::int64_t target) const {
1794 if (full_block_has_words(block_index) && target >= -512 && target <= 512) {
1795 std::uint64_t out[kBlockWords];
1796 excess_positions_512(bits_.data() + block_index * kBlockWords,
1797 static_cast<int>(target), out);
1798 for (std::size_t word = 0; word < kBlockWords; ++word) {
1799 const std::uint64_t mask = out[word];
1800 if (mask != 0) {
1801 return block_index * kBlockBits + word * 64 + std::countr_zero(mask);
1802 }
1803 }
1804 return npos;
1805 }
1806
1807 const std::size_t begin = block_index * kBlockBits;
1808 const std::size_t length = block_size(block_index);
1809 std::int64_t current = 0;
1810 for (std::size_t offset = 0; offset < length; ++offset) {
1811 current += bit(begin + offset) ? 1 : -1;
1812 if (current == target) {
1813 return begin + offset;
1814 }
1815 }
1816 return npos;
1817 }
1818
1829 std::size_t descend_qth_min(std::size_t level,
1830 std::size_t index,
1831 std::int64_t target,
1832 std::uint64_t rank) const {
1833 while (level > 0) {
1834 const std::size_t child_level = level - 1;
1835 const std::size_t fanout = fanout_to_parent(child_level);
1836 const std::size_t child_begin = index * fanout;
1837 const std::size_t child_end =
1838 std::min(level_count(child_level), child_begin + fanout);
1839 std::int64_t prefix = 0;
1840 bool found = false;
1841 for (std::size_t child = child_begin; child < child_end; ++child) {
1842 Summary summary = summary_at(child_level, child);
1843 const std::int64_t candidate = prefix + summary.min_excess;
1844 if (candidate == target) {
1845 if (rank <= summary.min_count) {
1846 index = child;
1847 level = child_level;
1848 target -= prefix;
1849 found = true;
1850 break;
1851 }
1852 rank -= summary.min_count;
1853 }
1854 prefix += summary.block_excess;
1855 }
1856 if (!found) {
1857 return npos;
1858 }
1859 }
1860 return qth_min_in_range(index * kBlockBits,
1861 index * kBlockBits + block_size(index), target,
1862 rank);
1863 }
1864
1875 std::size_t qth_min_in_range(std::size_t begin,
1876 std::size_t end,
1877 std::int64_t target,
1878 std::uint64_t rank) const {
1879 std::int64_t current = 0;
1880 for (std::size_t position = begin; position < end; ++position) {
1881 current += bit(position) ? 1 : -1;
1882 if (current == target && --rank == 0) {
1883 return position;
1884 }
1885 }
1886 return npos;
1887 }
1888
1898 ScanResult scan_range(std::size_t begin, std::size_t end) const {
1899 ScanResult result;
1900 const auto& lut = byte_lut();
1901 while (begin < end && (begin & 7) != 0) {
1902 append_scanned_bit(result, begin);
1903 ++begin;
1904 }
1905 while (begin + 8 <= end) {
1906 const ByteAgg& byte = lut[get_byte(begin)];
1907 const std::int64_t min_candidate = result.block_excess + byte.min_excess;
1908 if (min_candidate < result.min_value) {
1909 result.min_value = min_candidate;
1910 result.min_count = byte.min_count;
1911 result.min_position = begin + byte.pos_first_min;
1912 } else if (min_candidate == result.min_value) {
1913 result.min_count += byte.min_count;
1914 }
1915 const std::int64_t max_candidate = result.block_excess + byte.max_excess;
1916 if (max_candidate > result.max_value) {
1917 result.max_value = max_candidate;
1918 result.max_position = begin + byte.pos_first_max;
1919 }
1920 result.block_excess += byte.block_excess;
1921 begin += 8;
1922 }
1923 while (begin < end) {
1924 append_scanned_bit(result, begin);
1925 ++begin;
1926 }
1927 return result;
1928 }
1929
1937 void append_scanned_bit(ScanResult& result, std::size_t position) const {
1938 result.block_excess += bit(position) ? 1 : -1;
1939 if (result.block_excess < result.min_value) {
1940 result.min_value = result.block_excess;
1941 result.min_count = 1;
1942 result.min_position = position;
1943 } else if (result.block_excess == result.min_value) {
1944 ++result.min_count;
1945 }
1946 if (result.block_excess > result.max_value) {
1947 result.max_value = result.block_excess;
1948 result.max_position = position;
1949 }
1950 }
1951
1960 std::uint8_t get_byte(std::size_t bit_position) const {
1961 return static_cast<std::uint8_t>(
1962 (bits_[bit_position >> 6] >> (bit_position & 63)) & 0xffu);
1963 }
1964
1971 std::uint8_t bit(std::size_t position) const {
1972 return static_cast<std::uint8_t>((bits_[position >> 6] >> (position & 63)) &
1973 1ull);
1974 }
1975
1976 std::span<const std::uint64_t> bits_;
1977 std::optional<BitVector> rank_index_;
1978 std::size_t bit_count_ = 0;
1979 std::size_t block_count_ = 0;
1980 Summary top_summary_;
1981 std::vector<std::size_t> level_counts_;
1982 std::vector<std::vector<LowNode>> low_levels_;
1983 std::vector<std::vector<HighNode>> high_levels_;
1984};
1985
1986} // namespace pixie::experimental
CRTP facade for rank/select and range min-max tree operations.
Definition rmm_base.h:18
RmMBTree(std::span< const std::uint64_t > words, std::size_t bit_count, std::size_t=kBlockBits)
Construct an RmM btree over an external bit-vector span.
Definition rmm_btree.h:117