82class RmMBTree :
public RmMBase<RmMBTree<HighCacheLines, LowFanout>> {
84 static_assert(HighCacheLines > 0);
85 static_assert(LowFanout > 0);
87 static constexpr std::size_t npos =
89 static constexpr std::size_t kBlockBits = 512;
90 static constexpr std::size_t kBlockWords = kBlockBits / 64;
91 static constexpr std::size_t kCacheLineBytes = 64;
92 static constexpr std::size_t kLowFanout = LowFanout;
93 static constexpr std::size_t kHighFanout =
94 std::max<std::size_t>(2, (512 * HighCacheLines) / (4 * 64));
95 static constexpr std::size_t kMaxFanout = std::max(kLowFanout, kHighFanout);
96 static_assert(kMaxFanout <= 64);
98 kBlockBits * kLowFanout <=
99 static_cast<std::size_t
>(std::numeric_limits<std::int16_t>::max()));
101 RmMBTree() =
default;
102 RmMBTree(
const RmMBTree&) =
default;
103 RmMBTree(RmMBTree&&)
noexcept =
default;
104 RmMBTree& operator=(
const RmMBTree&) =
default;
105 RmMBTree& operator=(RmMBTree&&)
noexcept =
default;
117 explicit RmMBTree(std::span<const std::uint64_t> words,
118 std::size_t bit_count,
119 std::size_t = kBlockBits) {
120 build(words, bit_count);
123 std::size_t size_impl()
const {
return bit_count_; }
125 std::size_t rank1_impl(std::size_t end_position)
const {
126 return rank_index_ ? rank_index_->rank(end_position) : 0;
129 std::size_t rank0_impl(std::size_t end_position)
const {
130 return rank_index_ ? rank_index_->rank0(end_position) : 0;
133 std::size_t select1_impl(std::size_t rank)
const {
134 if (!rank_index_ || rank == 0) {
137 const std::size_t position = rank_index_->select(rank);
138 return position < bit_count_ ? position : npos;
141 std::size_t select0_impl(std::size_t rank)
const {
142 if (!rank_index_ || rank == 0) {
145 const std::size_t position = rank_index_->select0(rank);
146 return position < bit_count_ ? position : npos;
149 std::size_t rank10_impl(std::size_t end_position)
const {
150 if (end_position <= 1 || bit_count_ == 0) {
153 end_position = std::min(end_position, bit_count_);
154 std::size_t count = 0;
155 for (std::size_t position = 0; position + 1 < end_position; ++position) {
156 count += bit(position) == 1 && bit(position + 1) == 0;
161 std::size_t select10_impl(std::size_t rank)
const {
165 for (std::size_t position = 0; position + 1 < bit_count_; ++position) {
166 if (bit(position) == 1 && bit(position + 1) == 0 && --rank == 0) {
173 int excess_impl(std::size_t end_position)
const {
174 end_position = std::min(end_position, bit_count_);
175 return static_cast<int>(
176 2 *
static_cast<std::int64_t
>(rank1_impl(end_position)) -
177 static_cast<std::int64_t
>(end_position));
180 std::size_t fwdsearch_impl(std::size_t start_position,
int delta)
const {
181 if (start_position >= bit_count_) {
185 const std::size_t block_index = start_position / kBlockBits;
186 const std::size_t block_begin = block_index * kBlockBits;
187 const std::size_t start_offset = start_position - block_begin;
188 const std::size_t block_result =
189 find_fwd_in_block(block_index, start_offset, delta);
190 if (block_result != npos) {
194 const std::int64_t relative_target =
195 block_excess_at_local(block_index, start_offset) + delta;
196 const std::int64_t block_start_excess = prefix_excess_impl(block_begin);
197 const std::int64_t target = block_start_excess + relative_target;
198 std::size_t level = 0;
199 std::size_t index = block_index;
200 while (has_parent_level(level)) {
201 const std::size_t fanout = fanout_to_parent(level);
202 const std::size_t parent = index / fanout;
203 const std::size_t sibling_end =
204 std::min(level_count(level), parent * fanout + fanout);
205 const NodeScanResult scan = scan_children_fwd(
206 level, parent, index % fanout + 1, sibling_end - parent * fanout,
207 target, prefix_excess_impl(node_end_bit(level, index)));
209 const std::size_t result =
210 descend_fwd(level, scan.index, target, scan.node_start_excess);
211 if (result != npos) {
221 std::size_t bwdsearch_impl(std::size_t start_position,
int delta)
const {
222 if (start_position == 0 || start_position > bit_count_) {
226 const std::size_t block_index = (start_position - 1) / kBlockBits;
227 const std::size_t block_begin = block_index * kBlockBits;
228 const std::size_t end_offset = start_position - block_begin;
229 const std::size_t block_result =
230 find_bwd_in_block(block_index, end_offset, delta);
231 if (block_result != npos) {
235 const std::int64_t relative_target =
236 block_excess_at_local(block_index, end_offset) + delta;
237 const std::int64_t block_start_excess = prefix_excess_impl(block_begin);
238 const std::int64_t target = block_start_excess + relative_target;
239 std::size_t level = 0;
240 std::size_t index = block_index;
241 while (has_parent_level(level)) {
242 const std::size_t fanout = fanout_to_parent(level);
243 const std::size_t parent = index / fanout;
244 const NodeScanResult scan =
245 scan_children_bwd(level, parent, 0, index % fanout, target,
246 prefix_excess_impl(node_start_bit(level, index)));
248 if (scan.boundary_only) {
249 return node_start_bit(level, scan.index);
251 const std::size_t result =
252 descend_bwd(level, scan.index, target, scan.node_start_excess);
253 if (result != npos) {
263 std::size_t range_min_query_pos_impl(std::size_t range_begin,
264 std::size_t range_end)
const {
265 if (range_begin > range_end || range_end >= bit_count_) {
268 return range_extreme_query_pos(range_begin, range_end,
true);
271 int range_min_query_val_impl(std::size_t range_begin,
272 std::size_t range_end)
const {
273 if (range_begin > range_end || range_end >= bit_count_) {
276 return range_extreme_query_val(range_begin, range_end,
true);
279 std::size_t range_max_query_pos_impl(std::size_t range_begin,
280 std::size_t range_end)
const {
281 if (range_begin > range_end || range_end >= bit_count_) {
284 return range_extreme_query_pos(range_begin, range_end,
false);
287 int range_max_query_val_impl(std::size_t range_begin,
288 std::size_t range_end)
const {
289 if (range_begin > range_end || range_end >= bit_count_) {
292 return range_extreme_query_val(range_begin, range_end,
false);
295 std::size_t mincount_impl(std::size_t range_begin,
296 std::size_t range_end)
const {
297 if (range_begin > range_end || range_end >= bit_count_) {
300 return range_min_stats(range_begin, range_end).count;
303 std::size_t minselect_impl(std::size_t range_begin,
304 std::size_t range_end,
305 std::size_t rank)
const {
306 if (range_begin > range_end || range_end >= bit_count_ || rank == 0) {
309 const RangeMinStats stats = range_min_stats(range_begin, range_end);
310 if (rank > stats.count) {
313 return range_min_select(range_begin, range_end, stats.value, rank);
316 std::size_t close_impl(std::size_t open_position)
const {
317 if (open_position >= bit_count_) {
320 if (!bit(open_position)) {
321 return open_position;
323 return fwd_excess_at(open_position, -1);
326 std::size_t open_impl(std::size_t close_position)
const {
327 if (close_position >= bit_count_) {
330 if (bit(close_position)) {
331 return close_position;
333 return bwdsearch_impl(close_position + 1, 0);
336 std::size_t enclose_impl(std::size_t position)
const {
337 if (position >= bit_count_) {
340 if (!bit(position)) {
341 return open_impl(position);
343 return bwdsearch_impl(position + 1, -2);
358 std::size_t fwd_excess_at(std::size_t position,
int delta)
const {
359 if (position >= bit_count_) {
362 if (position + 1 >= bit_count_) {
365 return fwdsearch_impl(position + 1, delta);
369 std::uint64_t size_bits = 0;
370 std::uint64_t ones = 0;
371 std::int64_t block_excess = 0;
372 std::int64_t min_excess = 0;
373 std::int64_t max_excess = 0;
374 std::uint64_t min_count = 0;
377 template <
class Excess,
class Count, std::
size_t Fanout>
378 struct alignas(kCacheLineBytes) SummaryNode {
379 using ExcessType = Excess;
380 using CountType = Count;
381 static constexpr std::size_t kFanout = Fanout;
382 std::array<Excess, Fanout> prefix_excess{};
383 std::array<Excess, Fanout> min_excess{};
384 std::array<Excess, Fanout> max_excess{};
385 std::array<Count, Fanout> min_count{};
388 using LowNode = SummaryNode<std::int16_t, std::uint16_t, kLowFanout>;
389 using HighNode = SummaryNode<std::int64_t, std::uint64_t, kHighFanout>;
390 static_assert(
alignof(LowNode) == kCacheLineBytes);
391 static_assert(
alignof(HighNode) == kCacheLineBytes);
392 static_assert(
sizeof(LowNode) % kCacheLineBytes == 0);
393 static_assert(
sizeof(HighNode) % kCacheLineBytes == 0);
396 std::int8_t block_excess = 0;
397 std::int8_t min_excess = 0;
398 std::int8_t max_excess = 0;
399 std::uint8_t min_count = 0;
400 std::uint8_t pos_first_min = 0;
401 std::uint8_t pos_first_max = 0;
404 static constexpr std::size_t kSearchChunkBits = 128;
405 static constexpr std::size_t kSearchChunkWords = kSearchChunkBits / 64;
406 static constexpr std::size_t kSearchChunkCount =
407 kBlockBits / kSearchChunkBits;
416 static const std::array<ByteAgg, 256>& byte_lut() {
417 static const std::array<ByteAgg, 256> table = [] {
418 std::array<ByteAgg, 256> result{};
419 for (
int byte_value = 0; byte_value < 256; ++byte_value) {
422 int minimum = std::numeric_limits<int>::max();
423 int maximum = std::numeric_limits<int>::min();
424 const auto bit_at = [&](
int bit_index) {
425 return (byte_value >> bit_index) & 1;
427 for (
int bit_index = 0; bit_index < 8; ++bit_index) {
428 const int value = bit_at(bit_index);
429 current += value ? 1 : -1;
430 if (current < minimum) {
433 agg.pos_first_min =
static_cast<std::uint8_t
>(bit_index);
434 }
else if (current == minimum) {
437 if (current > maximum) {
439 agg.pos_first_max =
static_cast<std::uint8_t
>(bit_index);
442 agg.block_excess =
static_cast<std::int8_t
>(current);
443 agg.min_excess =
static_cast<std::int8_t
>(minimum);
444 agg.max_excess =
static_cast<std::int8_t
>(maximum);
445 result[
static_cast<std::size_t
>(byte_value)] = agg;
462 void build(std::span<const std::uint64_t> words, std::size_t bit_count) {
463 const std::size_t required_words = (bit_count + 63) / 64;
464 if (words.size() < required_words) {
465 throw std::invalid_argument(
466 "RmMBTree input span is shorter than bit_count");
470 bit_count_ = bit_count;
471 rank_index_.emplace(words, bit_count);
472 block_count_ = (bit_count_ + kBlockBits - 1) / kBlockBits;
473 std::vector<Summary> block_summaries(block_count_);
475 for (std::size_t block_index = 0; block_index < block_count_;
477 const std::size_t block_begin = block_index * kBlockBits;
478 const std::size_t block_end =
479 std::min(bit_count_, block_begin + kBlockBits);
480 block_summaries[block_index] =
481 summarize_bits(block_begin, block_end - block_begin);
484 build_levels(block_summaries);
494 void build_levels(
const std::vector<Summary>& block_summaries) {
495 level_counts_.clear();
497 high_levels_.clear();
498 top_summary_ = Summary{};
499 level_counts_.push_back(block_summaries.size());
500 if (block_summaries.empty()) {
504 std::vector<Summary> current = block_summaries;
505 current = build_parent_level<LowNode>(current, low_levels_.emplace_back());
506 level_counts_.push_back(current.size());
508 build_parent_level<HighNode>(current, high_levels_.emplace_back());
509 level_counts_.push_back(current.size());
510 while (current.size() > 1) {
512 build_parent_level<HighNode>(current, high_levels_.emplace_back());
513 level_counts_.push_back(current.size());
515 top_summary_ = current.front();
528 template <
class Node>
529 static std::vector<Summary> build_parent_level(
const std::vector<Summary>& in,
530 std::vector<Node>& nodes) {
531 constexpr std::size_t fanout = Node::kFanout;
532 std::vector<Summary> out((in.size() + fanout - 1) / fanout);
533 nodes.resize(out.size());
534 for (std::size_t parent = 0; parent < out.size(); ++parent) {
535 const std::size_t begin = parent * fanout;
536 const std::size_t end = std::min(in.size(), begin + fanout);
538 for (std::size_t i = begin; i < end; ++i) {
539 store_child_summary(nodes[parent], i - begin, combined.block_excess,
541 combined = append(combined, in[i]);
543 out[parent] = combined;
560 template <
class Node>
561 static void store_child_summary(Node& node,
563 std::int64_t prefix_excess,
564 const Summary& summary) {
565 node.prefix_excess[slot] =
566 static_cast<typename decltype(node.prefix_excess)::value_type
>(
567 prefix_excess + summary.block_excess);
568 node.min_excess[slot] =
569 static_cast<typename decltype(node.min_excess)::value_type
>(
571 node.max_excess[slot] =
572 static_cast<typename decltype(node.max_excess)::value_type
>(
574 node.min_count[slot] =
575 static_cast<typename decltype(node.min_count)::value_type
>(
588 Summary summarize_bits(std::size_t begin, std::size_t length)
const {
590 summary.size_bits = length;
595 int minimum = std::numeric_limits<int>::max();
596 int maximum = std::numeric_limits<int>::min();
597 for (std::size_t offset = 0; offset < length; ++offset) {
598 const std::uint8_t value = bit(begin + offset);
599 summary.ones += value;
600 current += value ? 1 : -1;
601 if (current < minimum) {
603 summary.min_count = 1;
604 }
else if (current == minimum) {
607 if (current > maximum) {
611 summary.block_excess = current;
612 summary.min_excess = minimum;
613 summary.max_excess = maximum;
626 static Summary append(Summary left,
const Summary& right) {
627 if (left.size_bits == 0) {
630 if (right.size_bits == 0) {
634 result.size_bits = left.size_bits + right.size_bits;
635 result.ones = left.ones + right.ones;
636 result.block_excess = left.block_excess + right.block_excess;
638 std::min(left.min_excess, left.block_excess + right.min_excess);
640 std::max(left.max_excess, left.block_excess + right.max_excess);
641 result.min_count = 0;
642 if (left.min_excess == result.min_excess) {
643 result.min_count += left.min_count;
645 if (left.block_excess + right.min_excess == result.min_excess) {
646 result.min_count += right.min_count;
658 std::size_t block_size(std::size_t block_index)
const {
659 const std::size_t begin = block_index * kBlockBits;
660 return std::min(bit_count_ - begin, kBlockBits);
670 bool full_block_has_words(std::size_t block_index)
const {
671 return block_size(block_index) == kBlockBits &&
672 (block_index + 1) * kBlockWords <= bits_.size();
684 std::int64_t block_excess_at_local(std::size_t block_index,
685 std::size_t offset)
const {
686 const std::size_t length = block_size(block_index);
687 offset = std::min(offset, length);
692 const std::size_t first_word = block_index * kBlockWords;
693 std::size_t remaining = offset;
694 std::int64_t ones = 0;
695 std::size_t word_offset = 0;
696 while (remaining >= 64) {
697 ones += std::popcount(bits_[first_word + word_offset]);
701 if (remaining != 0) {
702 ones += std::popcount(bits_[first_word + word_offset] &
703 first_bits_mask(remaining));
705 return 2 * ones -
static_cast<std::int64_t
>(offset);
714 static int chunk_excess_128(
const std::uint64_t* chunk) {
715 return 2 *
static_cast<int>(std::popcount(chunk[0]) +
716 std::popcount(chunk[1])) -
717 static_cast<int>(kSearchChunkBits);
731 std::size_t find_fwd_in_block(std::size_t block_index,
732 std::size_t start_offset,
733 std::int64_t delta)
const {
734 const std::size_t length = block_size(block_index);
735 if (start_offset >= length) {
739 if (full_block_has_words(block_index)) {
740 const std::uint64_t* block = bits_.data() + block_index * kBlockWords;
741 const std::size_t first_chunk = start_offset / kSearchChunkBits;
742 std::int64_t target = delta;
743 for (std::size_t chunk = first_chunk; chunk < kSearchChunkCount;
745 const std::uint64_t* chunk_words = block + chunk * kSearchChunkWords;
746 const std::size_t local_start =
747 chunk == first_chunk ? start_offset - chunk * kSearchChunkBits : 0;
748 if (chunk == first_chunk) {
749 target += prefix_excess_128(chunk_words, local_start);
751 int block_excess = 0;
752 const std::size_t offset = forward_search_128(
753 chunk_words,
static_cast<int>(target), local_start, &block_excess);
754 if (offset != kSearchChunkBits) {
755 return block_index * kBlockBits + chunk * kSearchChunkBits + offset;
757 target -= block_excess;
762 std::int64_t current = block_excess_at_local(block_index, start_offset);
763 const std::int64_t relative_target = current + delta;
764 const std::size_t block_begin = block_index * kBlockBits;
765 for (std::size_t offset = start_offset; offset < length; ++offset) {
766 current += bit(block_begin + offset) ? 1 : -1;
767 if (current == relative_target) {
768 return block_index * kBlockBits + offset;
785 std::size_t find_bwd_in_block(std::size_t block_index,
786 std::size_t end_offset,
787 std::int64_t delta)
const {
788 if (end_offset == 0) {
791 const std::size_t max_prefix_length = end_offset - 1;
793 if (full_block_has_words(block_index)) {
794 const std::uint64_t* block = bits_.data() + block_index * kBlockWords;
795 const std::size_t last_chunk = max_prefix_length / kSearchChunkBits;
796 std::int64_t target = delta;
797 for (std::size_t chunk = last_chunk + 1; chunk > 0;) {
799 const std::uint64_t* chunk_words = block + chunk * kSearchChunkWords;
800 const std::size_t local_end =
801 chunk == last_chunk ? end_offset - chunk * kSearchChunkBits
803 if (chunk == last_chunk) {
804 target += prefix_excess_128(chunk_words, local_end);
806 int block_excess = 0;
807 const std::size_t offset = backward_search_128(
808 chunk_words,
static_cast<int>(target), local_end, &block_excess);
809 if (offset != kSearchChunkBits) {
810 return block_index * kBlockBits + chunk * kSearchChunkBits + offset;
813 target += chunk_excess_128(block + (chunk - 1) * kSearchChunkWords);
819 const std::int64_t relative_target =
820 block_excess_at_local(block_index, end_offset) + delta;
821 std::int64_t current =
822 block_excess_at_local(block_index, max_prefix_length);
823 const std::size_t block_begin = block_index * kBlockBits;
824 for (std::size_t prefix_length = max_prefix_length; prefix_length > 0;
826 if (current == relative_target) {
827 return block_index * kBlockBits + prefix_length;
829 current -= bit(block_begin + prefix_length - 1) ? 1 : -1;
831 return relative_target == 0 ? block_index * kBlockBits : npos;
845 std::size_t descend_fwd(std::size_t level,
848 std::int64_t node_start_excess)
const {
850 const std::size_t child_level = level - 1;
851 const std::size_t fanout = fanout_to_parent(child_level);
852 const std::size_t child_begin = index * fanout;
853 const std::size_t child_end =
854 std::min(level_count(child_level), child_begin + fanout);
855 const NodeScanResult scan =
856 scan_children_fwd(child_level, index, 0, child_end - child_begin,
857 target, node_start_excess);
863 node_start_excess = scan.node_start_excess;
865 return find_fwd_in_block(index, 0, target - node_start_excess);
879 std::size_t descend_bwd(std::size_t level,
882 std::int64_t node_start_excess)
const {
884 const std::size_t child_level = level - 1;
885 const std::size_t fanout = fanout_to_parent(child_level);
886 const std::size_t child_begin = index * fanout;
887 const std::size_t child_end =
888 std::min(level_count(child_level), child_begin + fanout);
889 std::int64_t child_start_excess =
890 node_start_excess + summary_at(level, index).block_excess;
891 const NodeScanResult scan =
892 scan_children_bwd(child_level, index, 0, child_end - child_begin,
893 target, child_start_excess);
897 if (scan.boundary_only) {
898 return node_start_bit(child_level, scan.index);
902 node_start_excess = scan.node_start_excess;
904 const std::int64_t block_excess = summary_at(0, index).block_excess;
905 return find_bwd_in_block(index, block_size(index),
906 target - node_start_excess - block_excess);
909 struct NodeScanResult {
911 bool boundary_only =
false;
912 std::size_t index = 0;
913 std::int64_t node_start_excess = 0;
928 NodeScanResult scan_children_fwd(std::size_t child_level,
930 std::size_t begin_slot,
931 std::size_t end_slot,
933 std::int64_t begin_excess)
const {
934 if (begin_slot >= end_slot) {
937 if (child_level == 0) {
938 return scan_node_fwd(low_levels_[0][parent], child_level, parent,
939 begin_slot, end_slot, target, begin_excess);
941 return scan_node_fwd(high_levels_[child_level - 1][parent], child_level,
942 parent, begin_slot, end_slot, target, begin_excess);
958 NodeScanResult scan_children_bwd(std::size_t child_level,
960 std::size_t begin_slot,
961 std::size_t end_slot,
963 std::int64_t end_excess)
const {
964 if (begin_slot >= end_slot) {
967 if (child_level == 0) {
968 return scan_node_bwd(low_levels_[0][parent], child_level, parent,
969 begin_slot, end_slot, target, end_excess);
971 return scan_node_bwd(high_levels_[child_level - 1][parent], child_level,
972 parent, begin_slot, end_slot, target, end_excess);
989 template <
class Node>
990 NodeScanResult scan_node_fwd(
const Node& node,
991 std::size_t child_level,
993 std::size_t begin_slot,
994 std::size_t end_slot,
996 std::int64_t begin_excess)
const {
997 const std::int64_t node_base_excess =
998 begin_excess - prefix_excess_at(node, begin_slot);
999 for (std::size_t slot = begin_slot; slot < end_slot;) {
1000 const std::size_t lane_count =
1001 std::min(vector_lane_count<Node>(), end_slot - slot);
1002 const std::uint32_t mask = matching_chunk_mask(
1003 node, slot, lane_count, target - node_base_excess,
false);
1005 const std::size_t lane = std::countr_zero(mask);
1006 const std::size_t matched_slot = slot + lane;
1007 return {
true,
false,
1008 parent * fanout_to_parent(child_level) + matched_slot,
1009 child_start_excess(node, node_base_excess, matched_slot)};
1030 template <
class Node>
1031 NodeScanResult scan_node_bwd(
const Node& node,
1032 std::size_t child_level,
1034 std::size_t begin_slot,
1035 std::size_t end_slot,
1036 std::int64_t target,
1037 std::int64_t end_excess)
const {
1038 const std::int64_t node_base_excess =
1039 end_excess - prefix_excess_at(node, end_slot);
1040 for (std::size_t slot_end = end_slot; slot_end > begin_slot;) {
1041 const std::size_t lane_count =
1042 std::min(vector_lane_count<Node>(), slot_end - begin_slot);
1043 const std::size_t slot = slot_end - lane_count;
1044 const std::uint32_t mask = matching_chunk_mask(
1045 node, slot, lane_count, target - node_base_excess,
true);
1047 const std::size_t lane =
1048 static_cast<std::size_t
>(std::bit_width(mask) - 1);
1049 const std::size_t matched_slot = slot + lane;
1050 const std::int64_t relative_target =
1051 target - child_start_excess(node, node_base_excess, matched_slot);
1052 const bool interior_match =
1053 node.min_excess[matched_slot] <= relative_target &&
1054 relative_target <= node.max_excess[matched_slot];
1055 return {
true, !interior_match,
1056 parent * fanout_to_parent(child_level) + matched_slot,
1057 child_start_excess(node, node_base_excess, matched_slot)};
1074 template <
class Node>
1075 static std::int64_t prefix_excess_at(
const Node& node, std::size_t slot) {
1079 return prefix_through(node, slot - 1);
1092 template <
class Node>
1093 static std::int64_t child_start_excess(
const Node& node,
1094 std::int64_t node_base_excess,
1096 return node_base_excess + prefix_excess_at(node, slot);
1107 template <
class Node>
1108 static std::int64_t prefix_through(
const Node& node, std::size_t slot) {
1109 return static_cast<std::int64_t
>(node.prefix_excess[slot]);
1121 template <
class Node>
1122 static std::int64_t child_excess(
const Node& node, std::size_t slot) {
1123 return prefix_through(node, slot) - prefix_excess_at(node, slot);
1132 template <
class Node>
1133 static constexpr std::size_t vector_lane_count() {
1134 if constexpr (std::is_same_v<typename Node::ExcessType, std::int16_t>) {
1154 template <
class Node>
1155 static std::uint32_t matching_chunk_mask(
const Node& node,
1157 std::size_t lane_count,
1158 std::int64_t target_in_node,
1159 bool include_zero_boundary) {
1160#ifdef PIXIE_AVX2_SUPPORT
1161 if constexpr (std::is_same_v<typename Node::ExcessType, std::int16_t>) {
1162 if (lane_count == 16 &&
1163 target_in_node >= std::numeric_limits<std::int16_t>::min() &&
1164 target_in_node <= std::numeric_limits<std::int16_t>::max()) {
1165 alignas(32) std::int16_t prefix_before[16]{};
1166 fill_prefix_before(node, slot, prefix_before);
1167 return rmm_btree_match_mask_i16x16(
1168 prefix_before, node.min_excess.data() + slot,
1169 node.max_excess.data() + slot,
1170 static_cast<std::int16_t
>(target_in_node), include_zero_boundary);
1172 }
else if constexpr (std::is_same_v<
typename Node::ExcessType,
1174 if (lane_count == 4) {
1175 alignas(32) std::int64_t prefix_before[4]{};
1176 fill_prefix_before(node, slot, prefix_before);
1177 return rmm_btree_match_mask_i64x4(
1178 prefix_before, node.min_excess.data() + slot,
1179 node.max_excess.data() + slot, target_in_node,
1180 include_zero_boundary);
1184 std::uint32_t result = 0;
1185 for (std::size_t lane = 0; lane < lane_count; ++lane) {
1186 const std::int64_t rel =
1187 target_in_node - prefix_excess_at(node, slot + lane);
1188 const bool found = (include_zero_boundary && rel == 0) ||
1189 (node.min_excess[slot + lane] <= rel &&
1190 rel <= node.max_excess[slot + lane]);
1192 result |= std::uint32_t{1} << lane;
1210 template <
class Node>
1211 static void fill_prefix_before(
const Node& node,
1213 typename Node::ExcessType* out) {
1216 for (std::size_t lane = 1; lane < vector_lane_count<Node>(); ++lane) {
1217 out[lane] = node.prefix_excess[lane - 1];
1221 for (std::size_t lane = 0; lane < vector_lane_count<Node>(); ++lane) {
1222 out[lane] = node.prefix_excess[slot + lane - 1];
1234 static bool contains_fwd(
const Summary& summary,
1235 std::int64_t relative_target) {
1236 return summary.min_excess <= relative_target &&
1237 relative_target <= summary.max_excess;
1249 static bool contains_bwd(
const Summary& summary,
1250 std::int64_t relative_target) {
1251 return relative_target == 0 || contains_fwd(summary, relative_target);
1254 std::int64_t prefix_excess_impl(std::size_t end_position)
const {
1255 return 2 *
static_cast<std::int64_t
>(rank1_impl(end_position)) -
1256 static_cast<std::int64_t
>(end_position);
1265 bool has_parent_level(std::size_t level)
const {
1266 return level + 1 < total_levels() && level_count(level + 1) != 0;
1275 std::size_t total_levels()
const {
return level_counts_.size(); }
1283 std::size_t level_count(std::size_t level)
const {
1284 return level < level_counts_.size() ? level_counts_[level] : 0;
1295 Summary summary_at(std::size_t level, std::size_t index)
const {
1296 if (level + 1 >= total_levels()) {
1297 return top_summary_;
1299 const std::size_t parent_level = level + 1;
1300 const std::size_t fanout = fanout_to_parent(level);
1301 const std::size_t parent = index / fanout;
1302 const std::size_t slot = index % fanout;
1304 if (parent_level == 1) {
1305 const LowNode& node = low_levels_[0][parent];
1306 summary.block_excess = child_excess(node, slot);
1307 summary.min_excess = node.min_excess[slot];
1308 summary.max_excess = node.max_excess[slot];
1309 summary.min_count = node.min_count[slot];
1311 const HighNode& node = high_levels_[parent_level - 2][parent];
1312 summary.block_excess = child_excess(node, slot);
1313 summary.min_excess = node.min_excess[slot];
1314 summary.max_excess = node.max_excess[slot];
1315 summary.min_count = node.min_count[slot];
1327 static std::size_t fanout_to_parent(std::size_t level) {
1328 return level == 0 ? kLowFanout : kHighFanout;
1339 static std::size_t mul_clamped(std::size_t lhs, std::size_t rhs) {
1340 if (lhs != 0 && rhs > std::numeric_limits<std::size_t>::max() / lhs) {
1341 return std::numeric_limits<std::size_t>::max();
1353 static std::size_t level_span_bits(std::size_t level) {
1354 std::size_t span = kBlockBits;
1356 span = mul_clamped(span, kLowFanout);
1359 for (std::size_t i = 2; i <= level; ++i) {
1360 span = mul_clamped(span, kHighFanout);
1374 std::size_t node_start_bit(std::size_t level, std::size_t index)
const {
1375 const std::size_t span = level_span_bits(level);
1376 if (span != 0 && index > std::numeric_limits<std::size_t>::max() / span) {
1379 return std::min(bit_count_, index * span);
1390 std::size_t node_size_bits(std::size_t level, std::size_t index)
const {
1391 const std::size_t start = node_start_bit(level, index);
1392 if (start >= bit_count_) {
1395 return std::min(level_span_bits(level), bit_count_ - start);
1406 std::size_t node_end_bit(std::size_t level, std::size_t index)
const {
1407 const std::size_t start = node_start_bit(level, index);
1408 const std::uint64_t size = node_size_bits(level, index);
1409 if (size > std::numeric_limits<std::size_t>::max() - start) {
1412 return std::min(bit_count_, start +
static_cast<std::size_t
>(size));
1416 std::size_t level = 0;
1417 std::size_t index = 0;
// Fixed capacity of the node list gathered by collect_cover for a range query.
// NOTE(review): CoverList::push only stores a node while size < capacity;
// confirm a cover can never exceed this bound for supported bit counts.
static constexpr std::size_t kMaxCoverNodes{512};
1423 std::int64_t block_excess = 0;
1424 std::int64_t min_value = std::numeric_limits<std::int64_t>::max();
1425 std::int64_t max_value = std::numeric_limits<std::int64_t>::min();
1426 std::uint64_t min_count = 0;
1427 std::size_t min_position = npos;
1428 std::size_t max_position = npos;
1431 struct RangeExtremeResult {
1432 std::size_t position = npos;
1433 std::int64_t value = 0;
1436 struct RangeMinStats {
1437 std::int64_t value = 0;
1438 std::uint64_t count = 0;
1442 std::array<NodeRef, kMaxCoverNodes> nodes{};
1443 std::size_t size = 0;
1451 void push(NodeRef node) {
1452 if (size < nodes.size()) {
1453 nodes[size++] = node;
1467 std::size_t range_extreme_query_pos(std::size_t range_begin,
1468 std::size_t range_end,
1469 bool find_min)
const {
1470 return range_extreme_query(range_begin, range_end, find_min).position;
1482 int range_extreme_query_val(std::size_t range_begin,
1483 std::size_t range_end,
1484 bool find_min)
const {
1485 return static_cast<int>(
1486 range_extreme_query(range_begin, range_end, find_min).value);
1499 RangeExtremeResult range_extreme_query(std::size_t range_begin,
1500 std::size_t range_end,
1501 bool find_min)
const {
1502 std::int64_t value = 0;
1503 std::int64_t best = find_min ? std::numeric_limits<std::int64_t>::max()
1504 : std::numeric_limits<std::int64_t>::min();
1505 std::size_t best_position = npos;
1507 std::int64_t prefix_at_best_node = 0;
1508 bool best_is_node =
false;
1510 auto consider_point = [&](std::int64_t candidate, std::size_t position) {
1511 if ((find_min && candidate < best) || (!find_min && candidate > best)) {
1513 best_position = position;
1514 best_is_node =
false;
1518 const std::size_t range_end_exclusive = range_end + 1;
1519 const std::size_t first_full_block =
1520 (range_begin + kBlockBits - 1) / kBlockBits;
1521 const std::size_t full_begin =
1522 std::min(range_end_exclusive, first_full_block * kBlockBits);
1523 if (range_begin < full_begin) {
1524 const ScanResult scan = scan_range(range_begin, full_begin);
1525 consider_point(find_min ? scan.min_value : scan.max_value,
1526 find_min ? scan.min_position : scan.max_position);
1527 value += scan.block_excess;
1530 const std::size_t last_full_block_exclusive =
1531 range_end_exclusive / kBlockBits;
1532 const std::size_t middle_begin = full_begin;
1533 const std::size_t middle_end =
1534 std::max(middle_begin, last_full_block_exclusive * kBlockBits);
1535 if (middle_begin < middle_end) {
1537 collect_cover(middle_begin, middle_end, cover);
1538 for (std::size_t i = 0; i < cover.size; ++i) {
1539 const NodeRef& node = cover.nodes[i];
1540 Summary summary = summary_at(node.level, node.index);
1541 const std::int64_t candidate =
1542 value + (find_min ? summary.min_excess : summary.max_excess);
1543 if ((find_min && candidate < best) || (!find_min && candidate > best)) {
1546 prefix_at_best_node = value;
1547 best_is_node =
true;
1549 value += summary.block_excess;
1553 if (middle_end < range_end_exclusive) {
1554 const ScanResult scan = scan_range(middle_end, range_end_exclusive);
1555 const std::int64_t candidate =
1556 value + (find_min ? scan.min_value : scan.max_value);
1557 consider_point(candidate,
1558 find_min ? scan.min_position : scan.max_position);
1563 descend_first_extreme(best_node.level, best_node.index,
1564 best - prefix_at_best_node, find_min);
1566 return {best_position,
1567 best == std::numeric_limits<std::int64_t>::max() ||
1568 best == std::numeric_limits<std::int64_t>::min()
1581 RangeMinStats range_min_stats(std::size_t range_begin,
1582 std::size_t range_end)
const {
1583 std::int64_t value = 0;
1584 std::int64_t best = std::numeric_limits<std::int64_t>::max();
1585 std::uint64_t count = 0;
1587 auto consider = [&](std::int64_t candidate, std::uint64_t candidate_count) {
1588 if (candidate < best) {
1590 count = candidate_count;
1591 }
else if (candidate == best) {
1592 count += candidate_count;
1596 const std::size_t range_end_exclusive = range_end + 1;
1597 const std::size_t first_full_block =
1598 (range_begin + kBlockBits - 1) / kBlockBits;
1599 const std::size_t full_begin =
1600 std::min(range_end_exclusive, first_full_block * kBlockBits);
1601 if (range_begin < full_begin) {
1602 const ScanResult scan = scan_range(range_begin, full_begin);
1603 consider(scan.min_value, scan.min_count);
1604 value += scan.block_excess;
1607 const std::size_t last_full_block_exclusive =
1608 range_end_exclusive / kBlockBits;
1609 const std::size_t middle_begin = full_begin;
1610 const std::size_t middle_end =
1611 std::max(middle_begin, last_full_block_exclusive * kBlockBits);
1612 if (middle_begin < middle_end) {
1614 collect_cover(middle_begin, middle_end, cover);
1615 for (std::size_t i = 0; i < cover.size; ++i) {
1616 const NodeRef& node = cover.nodes[i];
1617 Summary summary = summary_at(node.level, node.index);
1618 consider(value + summary.min_excess, summary.min_count);
1619 value += summary.block_excess;
1623 if (middle_end < range_end_exclusive) {
1624 const ScanResult scan = scan_range(middle_end, range_end_exclusive);
1625 consider(value + scan.min_value, scan.min_count);
1627 return {best == std::numeric_limits<std::int64_t>::max() ? 0 : best, count};
1641 std::size_t range_min_select(std::size_t range_begin,
1642 std::size_t range_end,
1643 std::int64_t target,
1644 std::uint64_t rank)
const {
1645 std::int64_t value = 0;
1646 const std::size_t range_end_exclusive = range_end + 1;
1647 const std::size_t first_full_block =
1648 (range_begin + kBlockBits - 1) / kBlockBits;
1649 const std::size_t full_begin =
1650 std::min(range_end_exclusive, first_full_block * kBlockBits);
1651 if (range_begin < full_begin) {
1652 const ScanResult scan = scan_range(range_begin, full_begin);
1653 if (scan.min_value == target) {
1654 if (rank <= scan.min_count) {
1655 return qth_min_in_range(range_begin, full_begin, target, rank);
1657 rank -= scan.min_count;
1659 value += scan.block_excess;
1662 const std::size_t last_full_block_exclusive =
1663 range_end_exclusive / kBlockBits;
1664 const std::size_t middle_begin = full_begin;
1665 const std::size_t middle_end =
1666 std::max(middle_begin, last_full_block_exclusive * kBlockBits);
1667 if (middle_begin < middle_end) {
1669 collect_cover(middle_begin, middle_end, cover);
1670 for (std::size_t i = 0; i < cover.size; ++i) {
1671 const NodeRef& node = cover.nodes[i];
1672 Summary summary = summary_at(node.level, node.index);
1673 const std::int64_t candidate = value + summary.min_excess;
1674 if (candidate == target) {
1675 if (rank <= summary.min_count) {
1676 return descend_qth_min(node.level, node.index, target - value,
1679 rank -= summary.min_count;
1681 value += summary.block_excess;
1685 if (middle_end < range_end_exclusive) {
1686 const ScanResult scan = scan_range(middle_end, range_end_exclusive);
1687 if (value + scan.min_value == target) {
1688 return qth_min_in_range(middle_end, range_end_exclusive, target - value,
1703 void collect_cover(std::size_t begin, std::size_t end, Cover& out)
const {
1704 if (begin >= end || total_levels() == 0 || (begin % kBlockBits) != 0 ||
1705 (end % kBlockBits) != 0) {
1710 std::size_t level = 0;
1711 std::size_t left = begin / kBlockBits;
1712 std::size_t right = end / kBlockBits;
1714 while (left < right) {
1715 if (!has_parent_level(level)) {
1716 for (std::size_t index = left; index < right; ++index) {
1717 out.push({level, index});
1722 const std::size_t fanout = fanout_to_parent(level);
1723 while (left < right && (left % fanout) != 0) {
1724 out.push({level, left});
1727 while (left < right && (right % fanout) != 0) {
1729 right_cover.push({level, right});
1736 while (right_cover.size > 0) {
1737 out.push(right_cover.nodes[--right_cover.size]);
1751 std::size_t descend_first_extreme(std::size_t level,
1753 std::int64_t target,
1754 bool find_min)
const {
1756 const std::size_t child_level = level - 1;
1757 const std::size_t fanout = fanout_to_parent(child_level);
1758 const std::size_t child_begin = index * fanout;
1759 const std::size_t child_end =
1760 std::min(level_count(child_level), child_begin + fanout);
1761 std::int64_t prefix = 0;
1763 for (std::size_t child = child_begin; child < child_end; ++child) {
1764 Summary summary = summary_at(child_level, child);
1765 const std::int64_t candidate =
1766 prefix + (find_min ? summary.min_excess : summary.max_excess);
1767 if (candidate == target) {
1769 level = child_level;
1774 prefix += summary.block_excess;
1780 return first_prefix_in_block(index, target);
1792 std::size_t first_prefix_in_block(std::size_t block_index,
1793 std::int64_t target)
const {
1794 if (full_block_has_words(block_index) && target >= -512 && target <= 512) {
1795 std::uint64_t out[kBlockWords];
1796 excess_positions_512(bits_.data() + block_index * kBlockWords,
1797 static_cast<int>(target), out);
1798 for (std::size_t word = 0; word < kBlockWords; ++word) {
1799 const std::uint64_t mask = out[word];
1801 return block_index * kBlockBits + word * 64 + std::countr_zero(mask);
1807 const std::size_t begin = block_index * kBlockBits;
1808 const std::size_t length = block_size(block_index);
1809 std::int64_t current = 0;
1810 for (std::size_t offset = 0; offset < length; ++offset) {
1811 current += bit(begin + offset) ? 1 : -1;
1812 if (current == target) {
1813 return begin + offset;
1829 std::size_t descend_qth_min(std::size_t level,
1831 std::int64_t target,
1832 std::uint64_t rank)
const {
1834 const std::size_t child_level = level - 1;
1835 const std::size_t fanout = fanout_to_parent(child_level);
1836 const std::size_t child_begin = index * fanout;
1837 const std::size_t child_end =
1838 std::min(level_count(child_level), child_begin + fanout);
1839 std::int64_t prefix = 0;
1841 for (std::size_t child = child_begin; child < child_end; ++child) {
1842 Summary summary = summary_at(child_level, child);
1843 const std::int64_t candidate = prefix + summary.min_excess;
1844 if (candidate == target) {
1845 if (rank <= summary.min_count) {
1847 level = child_level;
1852 rank -= summary.min_count;
1854 prefix += summary.block_excess;
1860 return qth_min_in_range(index * kBlockBits,
1861 index * kBlockBits + block_size(index), target,
1875 std::size_t qth_min_in_range(std::size_t begin,
1877 std::int64_t target,
1878 std::uint64_t rank)
const {
1879 std::int64_t current = 0;
1880 for (std::size_t position = begin; position < end; ++position) {
1881 current += bit(position) ? 1 : -1;
1882 if (current == target && --rank == 0) {
1898 ScanResult scan_range(std::size_t begin, std::size_t end)
const {
1900 const auto& lut = byte_lut();
1901 while (begin < end && (begin & 7) != 0) {
1902 append_scanned_bit(result, begin);
1905 while (begin + 8 <= end) {
1906 const ByteAgg&
byte = lut[get_byte(begin)];
1907 const std::int64_t min_candidate = result.block_excess +
byte.min_excess;
1908 if (min_candidate < result.min_value) {
1909 result.min_value = min_candidate;
1910 result.min_count =
byte.min_count;
1911 result.min_position = begin +
byte.pos_first_min;
1912 }
else if (min_candidate == result.min_value) {
1913 result.min_count +=
byte.min_count;
1915 const std::int64_t max_candidate = result.block_excess +
byte.max_excess;
1916 if (max_candidate > result.max_value) {
1917 result.max_value = max_candidate;
1918 result.max_position = begin +
byte.pos_first_max;
1920 result.block_excess +=
byte.block_excess;
1923 while (begin < end) {
1924 append_scanned_bit(result, begin);
1937 void append_scanned_bit(ScanResult& result, std::size_t position)
const {
1938 result.block_excess += bit(position) ? 1 : -1;
1939 if (result.block_excess < result.min_value) {
1940 result.min_value = result.block_excess;
1941 result.min_count = 1;
1942 result.min_position = position;
1943 }
else if (result.block_excess == result.min_value) {
1946 if (result.block_excess > result.max_value) {
1947 result.max_value = result.block_excess;
1948 result.max_position = position;
1960 std::uint8_t get_byte(std::size_t bit_position)
const {
1961 return static_cast<std::uint8_t
>(
1962 (bits_[bit_position >> 6] >> (bit_position & 63)) & 0xffu);
1971 std::uint8_t bit(std::size_t position)
const {
1972 return static_cast<std::uint8_t
>((bits_[position >> 6] >> (position & 63)) &
1976 std::span<const std::uint64_t> bits_;
1977 std::optional<BitVector> rank_index_;
1978 std::size_t bit_count_ = 0;
1979 std::size_t block_count_ = 0;
1980 Summary top_summary_;
1981 std::vector<std::size_t> level_counts_;
1982 std::vector<std::vector<LowNode>> low_levels_;
1983 std::vector<std::vector<HighNode>> high_levels_;