Pixie
Loading...
Searching...
No Matches
rmm_tree.h
1#pragma once
2#include <immintrin.h>
3#include <pixie/bits.h>
4#include <pixie/rmm_base.h>
5
6#include <algorithm>
7#include <array>
8#include <bit>
9#include <climits>
10#include <cstddef>
11#include <cstdint>
12#include <limits>
13#include <span>
14#include <stdexcept>
15#include <vector>
16
17namespace pixie {
38class RmMTree : public RmMBase<RmMTree> {
// Data members of the range min-max tree. The per-node arrays below are all
// indexed by heap-order node id (node 1 is the root).
39 // ------------ bitvector ------------
40 std::span<const std::uint64_t> bits; // 64-bit words, LSB-first; storage is externally owned
41 size_t num_bits = 0; // number of valid bits in `bits`
42
43 // ------------ blocking ------------
44 size_t block_bits = 64; // block size (bits); a leaf covers <= block_bits bits
45 size_t leaf_count = 0; // number of leaves = ceil(num_bits / block_bits)
46
47 // ------------ tree arrays (heap order: 1 is root) ------------
48 // segment_size_bits = size of the segment (in bits) covered by the node
49 // needed for: rank1/rank0, select1/select0, select10,
50 // excess, fwdsearch/bwdsearch/close/open/enclose,
51 // range_min_query/range_max_query, minselect.
52 std::vector<uint32_t> segment_size_bits;
53
54 // node_total_excess = total excess (+1 for '1', -1 for '0') over the node
55 // needed for: rank1/rank0, select1/select0, excess,
56 // fwdsearch/bwdsearch/close/open/enclose,
57 // range_min_query/range_max_query, mincount/minselect.
58 std::vector<int32_t> node_total_excess;
59
60 // node_min_prefix_excess = minimum prefix excess within the node (from 0)
61 // needed for: fwdsearch/bwdsearch/close/open/enclose, range_min_query,
62 // mincount/minselect.
63 std::vector<int32_t> node_min_prefix_excess;
64
65 // node_max_prefix_excess = maximum prefix excess within the node (from 0)
66 // needed for: fwdsearch/bwdsearch/close/open/enclose, range_max_query.
67 std::vector<int32_t> node_max_prefix_excess;
68
69 // node_min_count = number of positions where the minimum is attained
70 // needed for: mincount/minselect.
71 std::vector<uint32_t> node_min_count;
72
73 // node_pattern10_count = number of "10" pattern occurrences inside the node
74 // needed for: rank10, select10.
75 std::vector<uint32_t> node_pattern10_count;
76
77 // node_first_bit = first bit (0/1), node_last_bit = last bit (0/1) of the
78 // segment (used to detect a "10" that straddles two adjacent nodes)
79 // both needed for: rank10, select10.
80 std::vector<uint8_t> node_first_bit, node_last_bit;
81
81
82 public:
// Sentinel position: the query methods of this class return it when no
// answer exists or the arguments are rejected.
86 static constexpr size_t npos = std::numeric_limits<size_t>::max();
87
// Present only in DEBUG builds.
// NOTE(review): presumably records the space overhead measured while
// building the tree — confirm against build_from_words().
88#ifdef DEBUG
89 float built_overhead = 0.0;
90#endif
91
92 // --------- construction ----------
93
// Default construction leaves every member at its in-class initializer:
// an empty bit span, num_bits == 0 and empty tree arrays.
97 RmMTree() = default;
98
112 explicit RmMTree(std::span<const std::uint64_t> words,
113 size_t bit_count,
114 const size_t& leaf_block_bits /*0=auto*/ = 0,
115 const float& max_overhead /*<0=off*/ = -1.0) {
116 build_from_words(words, bit_count, leaf_block_bits, max_overhead);
117 }
118
119 size_t size_impl() const { return num_bits; }
120
121 // --------- queries: rank/select/excess ----------
122
127 size_t rank1_impl(const size_t& end_position) const {
128 if (end_position == 0) {
129 return 0;
130 }
131 const size_t block_index = block_of(end_position - 1);
132 size_t ones_count = 0;
133 if (block_index > 0) {
134 size_t nodes_buffer[64];
135 const size_t node_count =
136 cover_blocks_collect(0, block_index - 1, nodes_buffer);
137 for (size_t j = 0; j < node_count; ++j) {
138 ones_count += ones_in_node(nodes_buffer[j]);
139 }
140 }
141 const size_t block_begin = block_index * block_bits;
142 const size_t block_end = std::min(num_bits, block_begin + block_bits);
143 ones_count +=
144 rank1_in_block(block_begin, std::min(end_position, block_end));
145 return ones_count;
146 }
147
152 size_t rank0_impl(const size_t& end_position) const {
153 return end_position - rank1_impl(end_position);
154 }
155
160 size_t select1_impl(size_t target_one_rank) const {
161 if (target_one_rank == 0 || num_bits == 0) {
162 return npos;
163 }
164 size_t node_index = 1;
165 if (ones_in_node(node_index) < target_one_rank) {
166 return npos;
167 }
168 size_t segment_base = 0;
169 while (node_index < first_leaf_index) {
170 const size_t left_child = node_index << 1;
171 const size_t right_child = left_child | 1;
172 const uint32_t ones_in_left_child = ones_in_node(left_child);
173 if (ones_in_left_child >= target_one_rank) {
174 node_index = left_child;
175 } else {
176 target_one_rank -= ones_in_left_child;
177 segment_base += segment_size_bits[left_child];
178 node_index = right_child;
179 }
180 }
181 return select1_in_block(
182 segment_base,
183 std::min(segment_base + segment_size_bits[node_index], num_bits),
184 target_one_rank);
185 }
186
191 size_t select0_impl(size_t target_zero_rank) const {
192 if (target_zero_rank == 0 || num_bits == 0) {
193 return npos;
194 }
195 size_t node_index = 1;
196 const auto zeros_in_node = [&](const size_t& node) noexcept {
197 return segment_size_bits[node] - ones_in_node(node);
198 };
199 if (zeros_in_node(node_index) < target_zero_rank) {
200 return npos;
201 }
202 size_t segment_base = 0;
203 while (node_index < first_leaf_index) {
204 const size_t left_child = node_index << 1;
205 const size_t right_child = left_child | 1;
206 const size_t zeros_in_left_child = zeros_in_node(left_child);
207 if (zeros_in_left_child >= target_zero_rank) {
208 node_index = left_child;
209 } else {
210 target_zero_rank -= zeros_in_left_child;
211 segment_base += segment_size_bits[left_child];
212 node_index = right_child;
213 }
214 }
215 return select0_in_block(
216 segment_base,
217 std::min(segment_base + segment_size_bits[node_index], num_bits),
218 target_zero_rank);
219 }
220
226 size_t rank10_impl(const size_t& end_position) const {
227 if (end_position <= 1) {
228 return 0;
229 }
230 const size_t block_index = block_of(end_position - 1);
231 size_t pattern_count = 0;
232 int previous_last_bit = -1;
233
234 if (block_index > 0) {
235 const auto covered_nodes = cover_blocks(0, block_index - 1);
236 for (const size_t& node_index : covered_nodes) {
237 pattern_count += node_pattern10_count[node_index];
238 if (previous_last_bit != -1 && previous_last_bit == 1 &&
239 node_first_bit[node_index] == 0) {
240 ++pattern_count;
241 }
242 previous_last_bit = node_last_bit[node_index];
243 }
244 }
245 const size_t block_begin = block_index * block_bits;
246 pattern_count += rr_in_block(block_begin, end_position);
247 // boundary between the last full node and the leaf tail
248 if (block_index > 0 && end_position > block_begin &&
249 previous_last_bit == 1 && bit(block_begin) == 0) {
250 ++pattern_count;
251 }
252 return pattern_count;
253 }
254
259 size_t select10_impl(size_t target_pattern_rank) const {
260 if (target_pattern_rank == 0 || num_bits == 0) {
261 return npos;
262 }
263 size_t node_index = 1;
264 if (node_pattern10_count[node_index] < target_pattern_rank) {
265 return npos;
266 }
267 const size_t tree_size = segment_size_bits.size() - 1;
268 size_t segment_base = 0;
269 while (node_index < first_leaf_index) {
270 const size_t left_child = node_index << 1;
271 const size_t left_segment_size =
272 (left_child <= tree_size) ? segment_size_bits[left_child] : 0;
273 if (left_segment_size == 0) {
274 return npos;
275 }
276
277 const size_t left_count = node_pattern10_count[left_child];
278 if (left_count >= target_pattern_rank) {
279 node_index = left_child;
280 continue;
281 }
282
283 size_t remaining_rank = target_pattern_rank - left_count;
284 const size_t right_child = left_child | 1;
285 const bool has_right =
286 (right_child <= tree_size) && (segment_size_bits[right_child] != 0);
287 if (!has_right) {
288 return npos;
289 }
290
291 const size_t crossing_pattern =
292 (node_last_bit[left_child] == 1 && node_first_bit[right_child] == 0)
293 ? 1u
294 : 0u;
295 if (crossing_pattern) {
296 if (remaining_rank == 1) {
297 return segment_base + left_segment_size - 1;
298 }
299 --remaining_rank;
300 }
301 segment_base += left_segment_size;
302 node_index = right_child;
303 target_pattern_rank = remaining_rank;
304 }
305 return select10_in_block(
306 segment_base,
307 std::min(segment_base + segment_size_bits[node_index], num_bits),
308 target_pattern_rank);
309 }
310
314 inline int excess_impl(const size_t& end_position) const {
315 return int64_t(rank1_impl(end_position)) * 2 - int64_t(end_position);
316 }
317
324 size_t fwdsearch_impl(const size_t& start_position, const int& delta) const {
325 if (start_position >= num_bits) {
326 return npos;
327 }
328
329 // 1) scan the remainder of the current leaf
330 const size_t leaf_block_index = block_of(start_position);
331 const size_t block_begin = leaf_block_index * block_bits;
332 const size_t block_end = std::min(num_bits, block_begin + block_bits);
333 int leaf_delta = 0;
334 const size_t leaf_result = leaf_fwd_bp_simd(
335 leaf_block_index, block_begin, start_position, delta, leaf_delta);
336 if (leaf_result != npos) {
337 return leaf_result;
338 }
339
340 int remaining_delta = delta - leaf_delta;
341 size_t segment_base = block_end;
342 if (remaining_delta == 0) {
343 return segment_base;
344 }
345
346 // Tree-walk to the right:
347 // go up; whenever we come from a left child, try the right sibling subtree.
348 // If target is inside sibling -> descend; else skip it and continue up.
349 size_t node_index = leaf_index_of(block_begin);
350 const size_t tree_size = segment_size_bits.size() - 1;
351
352 // If we are already at/after the last leaf boundary, there's nothing to
353 // scan.
354 if (segment_base >= num_bits || leaf_block_index + 1 >= leaf_count) {
355 return npos;
356 }
357
358 while (node_index > 1) {
359 const bool is_left_child = ((node_index & 1u) == 0u);
360 if (is_left_child) {
361 const size_t sibling = node_index | 1u; // right sibling
362 if (sibling <= tree_size && segment_size_bits[sibling]) {
363 // Boundary at sibling start already handled above via
364 // remaining_delta==0.
365 if (node_min_prefix_excess[sibling] <= remaining_delta &&
366 remaining_delta <= node_max_prefix_excess[sibling]) {
367 return descend_fwd(sibling, remaining_delta, segment_base);
368 }
369 // Skip whole sibling subtree.
370 remaining_delta -= node_total_excess[sibling];
371 segment_base += segment_size_bits[sibling];
372 if (remaining_delta == 0) {
373 return segment_base;
374 }
375 }
376 }
377 node_index >>= 1;
378 }
379 return npos;
380 }
381
388 size_t bwdsearch_impl(const size_t& start_position, const int& delta) const {
389 if (start_position > num_bits || start_position == 0) {
390 return npos;
391 }
392
393 // 1) scan inside the block
394 const size_t leaf_block_index = block_of(start_position - 1);
395 const size_t block_begin = leaf_block_index * block_bits;
396 int leaf_delta =
397 0; // excess_impl(start_position) - excess_impl(block_begin)
398 const size_t leaf_result = leaf_bwd_bp_simd(
399 leaf_block_index, block_begin, start_position, delta, leaf_delta);
400 if (leaf_result != npos) {
401 return leaf_result;
402 }
403
404 // need = target - excess_impl(block_begin) = excess_impl(start_position) +
405 // delta - excess_impl(block_begin) = leaf_delta + delta
406 int remaining_delta = leaf_delta + delta;
407 size_t node_index = leaf_index_of(block_begin);
408 size_t segment_base = block_begin;
409 while (node_index > 1) {
410 if (node_index & 1) { // node_index is the right child
411 const size_t sibling_index = node_index ^ 1; // left sibling
412 const size_t sibling_border =
413 segment_base; // right border of the sibling (== start(node_index))
414 const int needed_inside_sibling =
415 remaining_delta +
416 node_total_excess[sibling_index]; // target in coordinates relative
417 // to the start of sibling
418 const bool allow_right_border =
419 (sibling_border != start_position); // j must be < start_position
420
421 // try inside the sibling, but return only if a position is found
422 if (needed_inside_sibling == 0 ||
423 (node_min_prefix_excess[sibling_index] <= needed_inside_sibling &&
424 needed_inside_sibling <= node_max_prefix_excess[sibling_index])) {
425 const size_t result = descend_bwd(
426 sibling_index, sibling_border - segment_size_bits[sibling_index],
427 needed_inside_sibling, sibling_border, allow_right_border);
428 if (result != npos) {
429 return result;
430 }
431 }
432 // junction between children is a separate branch (allowed only if < i)
433 if (needed_inside_sibling == node_total_excess[sibling_index] &&
434 sibling_border < start_position) {
435 return sibling_border;
436 }
437
438 // stepped over the sibling, shifted the zero point of the coordinates
439 remaining_delta += node_total_excess[sibling_index];
440 segment_base -= segment_size_bits[sibling_index];
441 }
442 node_index >>= 1;
443 }
444 return npos;
445 }
446
453 size_t range_min_query_pos_impl(const size_t& range_begin,
454 const size_t& range_end) const {
455 if (range_begin > range_end || range_end >= num_bits) {
456 return npos;
457 }
458
459 const size_t begin_block_index = block_of(range_begin);
460 const size_t begin_block_start = begin_block_index * block_bits;
461 const size_t begin_block_end =
462 std::min(num_bits, begin_block_start + block_bits);
463 const size_t end_block_index = block_of(range_end);
464 const size_t end_block_start = end_block_index * block_bits;
465
466 int best_value = INT_MAX;
467 size_t best_position = npos;
468 size_t chosen_node_index = 0;
469 int prefix_excess = 0, prefix_at_choice = 0;
470
471 // prefix
472 int min_prefix_first_chunk = INT_MAX;
473 size_t first_chunk_position = npos;
474 const size_t end_of_first_chunk = std::min(
475 range_end, (size_t)(begin_block_end ? begin_block_end - 1 : 0));
476 if (range_begin <= end_of_first_chunk) {
477 first_min_value_pos8(range_begin, end_of_first_chunk,
478 min_prefix_first_chunk, first_chunk_position);
479 prefix_excess =
480 (int64_t)rank1_in_block(range_begin, end_of_first_chunk + 1) * 2 -
481 int64_t(end_of_first_chunk + 1 - range_begin);
482 best_value = min_prefix_first_chunk;
483 best_position = first_chunk_position;
484 chosen_node_index = 0;
485 }
486
487 // middle
488 if (begin_block_index + 1 <= end_block_index - 1) {
489 size_t left_index = first_leaf_index + (begin_block_index + 1);
490 size_t right_index = first_leaf_index + (end_block_index - 1);
491 size_t right_nodes[64];
492 int right_nodes_count = 0;
493
494 while (left_index <= right_index) {
495 if (left_index & 1) {
496 const size_t node_index = left_index++;
497 const int candidate =
498 prefix_excess + node_min_prefix_excess[node_index];
499 if (candidate < best_value) {
500 best_value = candidate;
501 best_position = npos;
502 chosen_node_index = node_index;
503 prefix_at_choice = prefix_excess;
504 }
505 prefix_excess += node_total_excess[node_index];
506 }
507 if ((right_index & 1) == 0) {
508 right_nodes[right_nodes_count++] = right_index--;
509 }
510 left_index >>= 1;
511 right_index >>= 1;
512 }
513 while (right_nodes_count--) {
514 const size_t node_index = right_nodes[right_nodes_count];
515 const int candidate =
516 prefix_excess + node_min_prefix_excess[node_index];
517 if (candidate < best_value) {
518 best_value = candidate;
519 best_position = npos;
520 chosen_node_index = node_index;
521 prefix_at_choice = prefix_excess;
522 }
523 prefix_excess += node_total_excess[node_index];
524 }
525 }
526
527 // tail
528 if (end_block_index != begin_block_index) {
529 int min_prefix_last_chunk;
530 size_t last_chunk_position;
531 first_min_value_pos8(end_block_start, range_end, min_prefix_last_chunk,
532 last_chunk_position);
533 const int candidate = prefix_excess + min_prefix_last_chunk;
534 if (candidate < best_value) {
535 best_value = candidate;
536 best_position = last_chunk_position;
537 chosen_node_index = 0;
538 }
539 }
540
541 if (best_position != npos) {
542 return best_position;
543 }
544
545 return descend_first_min(chosen_node_index, best_value - prefix_at_choice,
546 node_base(chosen_node_index));
547 }
548
555 int range_min_query_val_impl(const size_t& range_begin,
556 const size_t& range_end) const {
557 if (range_begin > range_end || range_end >= num_bits) {
558 return 0;
559 }
560 size_t min_position = range_min_query_pos_impl(range_begin, range_end);
561 if (min_position == npos) {
562 return 0;
563 }
564 return excess_impl(min_position + 1) - excess_impl(range_begin);
565 }
566
573 size_t range_max_query_pos_impl(const size_t& range_begin,
574 const size_t& range_end) const {
575 if (range_begin > range_end || range_end >= num_bits) {
576 return npos;
577 }
578
579 const size_t begin_block_index = block_of(range_begin);
580 const size_t begin_block_start = begin_block_index * block_bits;
581 const size_t begin_block_end =
582 std::min(num_bits, begin_block_start + block_bits);
583 const size_t end_block_index = block_of(range_end);
584 const size_t end_block_start = end_block_index * block_bits;
585
586 int best_value = INT_MIN;
587 size_t best_position = npos;
588 size_t chosen_node_index = 0;
589 int prefix_excess = 0, prefix_at_choice = 0;
590
591 // prefix
592 int max_prefix_first_chunk = INT_MIN;
593 size_t first_chunk_position = npos;
594 const size_t end_of_first_chunk = std::min(
595 range_end, (size_t)(begin_block_end ? begin_block_end - 1 : 0));
596 if (range_begin <= end_of_first_chunk) {
597 first_max_value_pos8(range_begin, end_of_first_chunk,
598 max_prefix_first_chunk, first_chunk_position);
599 prefix_excess =
600 (int64_t)rank1_in_block(range_begin, end_of_first_chunk + 1) * 2 -
601 int64_t(end_of_first_chunk + 1 - range_begin);
602 best_value = max_prefix_first_chunk;
603 best_position = first_chunk_position;
604 chosen_node_index = 0;
605 }
606
607 // middle
608 if (begin_block_index + 1 <= end_block_index - 1) {
609 size_t left_index = first_leaf_index + (begin_block_index + 1);
610 size_t right_index = first_leaf_index + (end_block_index - 1);
611 size_t right_nodes[64];
612 int right_nodes_count = 0;
613
614 while (left_index <= right_index) {
615 if (left_index & 1) {
616 const size_t node_index = left_index++;
617 const int candidate =
618 prefix_excess + node_max_prefix_excess[node_index];
619 if (candidate > best_value) {
620 best_value = candidate;
621 best_position = npos;
622 chosen_node_index = node_index;
623 prefix_at_choice = prefix_excess;
624 }
625 prefix_excess += node_total_excess[node_index];
626 }
627 if ((right_index & 1) == 0) {
628 right_nodes[right_nodes_count++] = right_index--;
629 }
630 left_index >>= 1;
631 right_index >>= 1;
632 }
633 while (right_nodes_count--) {
634 const size_t node_index = right_nodes[right_nodes_count];
635 const int candidate =
636 prefix_excess + node_max_prefix_excess[node_index];
637 if (candidate > best_value) {
638 best_value = candidate;
639 best_position = npos;
640 chosen_node_index = node_index;
641 prefix_at_choice = prefix_excess;
642 }
643 prefix_excess += node_total_excess[node_index];
644 }
645 }
646
647 // tail
648 if (end_block_index != begin_block_index) {
649 int max_prefix_last_chunk;
650 size_t last_chunk_position;
651 first_max_value_pos8(end_block_start, range_end, max_prefix_last_chunk,
652 last_chunk_position);
653 const int candidate = prefix_excess + max_prefix_last_chunk;
654 if (candidate > best_value) {
655 best_value = candidate;
656 best_position = last_chunk_position;
657 chosen_node_index = 0;
658 }
659 }
660
661 if (best_position != npos) {
662 return best_position;
663 }
664
665 return descend_first_max(chosen_node_index, best_value - prefix_at_choice,
666 node_base(chosen_node_index));
667 }
668
673 int range_max_query_val_impl(const size_t& range_begin,
674 const size_t& range_end) const {
675 if (range_begin > range_end || range_end >= num_bits) {
676 return 0;
677 }
678 size_t max_position = range_max_query_pos_impl(range_begin, range_end);
679 if (max_position == npos) {
680 return 0;
681 }
682 return excess_impl(max_position + 1) - excess_impl(range_begin);
683 }
684
689 size_t mincount_impl(const size_t& range_begin,
690 const size_t& range_end) const {
691 if (range_begin > range_end || range_end >= num_bits) {
692 return 0;
693 }
694
695 const size_t begin_block_index = block_of(range_begin);
696 const size_t begin_block_start = begin_block_index * block_bits;
697 const size_t begin_block_end =
698 std::min(num_bits, begin_block_start + block_bits);
699 const size_t end_block_index = block_of(range_end);
700 const size_t end_block_start = end_block_index * block_bits;
701
702 int best_value = INT_MAX;
703 size_t min_count = 0;
704 int prefix_excess = 0;
705
706 // first chunk
707 {
708 int current_excess = 0, min_value = INT_MAX, local_count = 0;
709 const size_t end_of_first_chunk =
710 std::min(range_end, begin_block_end - 1);
711 for (size_t position = range_begin; position <= end_of_first_chunk;
712 ++position) {
713 current_excess += bit(position) ? +1 : -1;
714 if (current_excess < min_value) {
715 min_value = current_excess;
716 local_count = 1;
717 } else if (current_excess == min_value) {
718 ++local_count;
719 }
720 }
721 best_value = min_value;
722 min_count = local_count;
723 prefix_excess = current_excess; // offset toward the middle
724 }
725
726 // middle
727 if (begin_block_index + 1 <= end_block_index - 1) {
728 const auto middle_nodes =
729 cover_blocks(begin_block_index + 1, end_block_index - 1);
730 for (const size_t& node_index : middle_nodes) {
731 const int candidate =
732 prefix_excess + node_min_prefix_excess[node_index];
733 if (candidate < best_value) {
734 best_value = candidate;
735 min_count = node_min_count[node_index];
736 } else if (candidate == best_value) {
737 min_count += node_min_count[node_index];
738 }
739 prefix_excess += node_total_excess[node_index];
740 }
741 }
742
743 // last chunk
744 if (end_block_index != begin_block_index) {
745 int current_excess = 0, min_value = INT_MAX, local_count = 0;
746 for (size_t position = end_block_start; position <= range_end;
747 ++position) {
748 current_excess += bit(position) ? +1 : -1;
749 if (current_excess < min_value) {
750 min_value = current_excess;
751 local_count = 1;
752 } else if (current_excess == min_value) {
753 ++local_count;
754 }
755 }
756 const int candidate = prefix_excess + min_value;
757 if (candidate < best_value) {
758 best_value = candidate;
759 min_count = local_count;
760 } else if (candidate == best_value) {
761 min_count += local_count;
762 }
763 }
764 return min_count;
765 }
766
773 size_t minselect_impl(const size_t& range_begin,
774 const size_t& range_end,
775 size_t target_min_rank) const {
776 if (range_begin > range_end || range_end >= num_bits ||
777 target_min_rank == 0) {
778 return npos;
779 }
780
781 const size_t begin_block_index = block_of(range_begin);
782 const size_t begin_block_start = begin_block_index * block_bits;
783 const size_t begin_block_end =
784 std::min(num_bits, begin_block_start + block_bits);
785 const size_t end_block_index = block_of(range_end);
786 const size_t end_block_start = end_block_index * block_bits;
787
788 // prefix
789 const size_t end_of_first_chunk = std::min(range_end, begin_block_end - 1);
790 int current_first_chunk_excess = 0, min_first_chunk = 0;
791 uint32_t count_first_chunk = 0;
792
793 if (range_begin <= end_of_first_chunk) {
794 scan_range_min_count8(range_begin, end_of_first_chunk,
795 current_first_chunk_excess, min_first_chunk,
796 count_first_chunk);
797 } else {
798 current_first_chunk_excess = 0;
799 min_first_chunk = INT_MAX;
800 count_first_chunk = 0;
801 }
802
803 int best_value = (min_first_chunk == INT_MAX ? INT_MAX : min_first_chunk);
804 size_t total_count =
805 (min_first_chunk == INT_MAX ? 0u : (size_t)count_first_chunk);
806 int prefix_excess = current_first_chunk_excess; // offset for middle
807
808 size_t left_index = first_leaf_index + begin_block_index + 1;
809 size_t right_index = first_leaf_index + end_block_index - 1;
810 size_t right_nodes[64];
811 int right_nodes_count = 0;
812
813 // middle
814 if (begin_block_index + 1 <= end_block_index - 1) {
815 while (left_index <= right_index) {
816 if (left_index & 1) {
817 const int candidate =
818 prefix_excess + node_min_prefix_excess[left_index];
819 if (candidate < best_value) {
820 best_value = candidate;
821 total_count = node_min_count[left_index];
822 } else if (candidate == best_value) {
823 total_count += node_min_count[left_index];
824 }
825 prefix_excess += node_total_excess[left_index++];
826 }
827 if ((right_index & 1) == 0) {
828 right_nodes[right_nodes_count++] = right_index--;
829 }
830 left_index >>= 1;
831 right_index >>= 1;
832 }
833 while (right_nodes_count--) {
834 const size_t node_index = right_nodes[right_nodes_count];
835 const int candidate =
836 prefix_excess + node_min_prefix_excess[node_index];
837 if (candidate < best_value) {
838 best_value = candidate;
839 total_count = node_min_count[node_index];
840 } else if (candidate == best_value) {
841 total_count += node_min_count[node_index];
842 }
843 prefix_excess += node_total_excess[node_index];
844 }
845 }
846
847 // tail
848 int current_last_chunk_excess = 0, min_last_chunk = INT_MAX;
849 uint32_t count_last_chunk = 0;
850 if (end_block_index != begin_block_index) {
851 scan_range_min_count8(end_block_start, range_end,
852 current_last_chunk_excess, min_last_chunk,
853 count_last_chunk);
854 const int candidate = prefix_excess + min_last_chunk;
855 if (candidate < best_value) {
856 best_value = candidate;
857 total_count = count_last_chunk;
858 } else if (candidate == best_value) {
859 total_count += count_last_chunk;
860 }
861 }
862
863 if (target_min_rank > total_count) {
864 return npos;
865 }
866
867 // prefix
868 if (min_first_chunk == best_value && count_first_chunk) {
869 if (target_min_rank <= count_first_chunk) {
870 return qth_min_in_block(range_begin, end_of_first_chunk,
871 target_min_rank);
872 }
873 target_min_rank -= count_first_chunk;
874 }
875
876 // middle
877 prefix_excess = current_first_chunk_excess;
878 if (begin_block_index + 1 <= end_block_index - 1) {
879 left_index = first_leaf_index + (begin_block_index + 1);
880 right_index = first_leaf_index + (end_block_index - 1);
881 right_nodes_count = 0;
882 while (left_index <= right_index) {
883 if (left_index & 1) {
884 const size_t node_index = left_index++;
885 const int candidate =
886 prefix_excess + node_min_prefix_excess[node_index];
887 if (candidate == best_value) {
888 if (target_min_rank <= node_min_count[node_index]) {
889 return descend_qth_min(node_index, best_value - prefix_excess,
890 target_min_rank, node_base(node_index));
891 }
892 target_min_rank -= node_min_count[node_index];
893 }
894 prefix_excess += node_total_excess[node_index];
895 }
896 if (!(right_index & 1)) {
897 right_nodes[right_nodes_count++] = right_index--;
898 }
899 left_index >>= 1;
900 right_index >>= 1;
901 }
902 while (right_nodes_count--) {
903 const size_t node_index = right_nodes[right_nodes_count];
904 const int candidate =
905 prefix_excess + node_min_prefix_excess[node_index];
906 if (candidate == best_value) {
907 if (target_min_rank <= node_min_count[node_index]) {
908 return descend_qth_min(node_index, best_value - prefix_excess,
909 target_min_rank, node_base(node_index));
910 }
911 target_min_rank -= node_min_count[node_index];
912 }
913 prefix_excess += node_total_excess[node_index];
914 }
915 }
916
917 // tail
918 if (end_block_index != begin_block_index &&
919 (prefix_excess + min_last_chunk) == best_value) {
920 return qth_min_in_block(end_block_start, range_end, target_min_rank);
921 }
922
923 return npos;
924 }
925
926 // ----- parentheses navigation (BP) -----
927
933 inline size_t close_impl(const size_t& open_position) const {
934 if (open_position >= num_bits) {
935 return npos;
936 }
937 if (!bit(open_position)) {
938 return open_position;
939 }
940 return fwdsearch_impl(open_position, 0);
941 }
942
948 inline size_t open_impl(const size_t& close_position) const {
949 if (close_position >= num_bits) {
950 return npos;
951 }
952 if (bit(close_position)) {
953 return close_position;
954 }
955 return bwdsearch_impl(close_position + 1, 0);
956 }
957
963 inline size_t enclose_impl(const size_t& position) const {
964 if (position >= num_bits) {
965 return npos;
966 }
967 if (!bit(position)) {
968 return open_impl(position);
969 }
970 return bwdsearch_impl(position + 1, -2);
971 }
972
976 inline int bit(const size_t& position) const noexcept {
977 return (bits[position >> 6] >> (position & 63)) & 1u;
978 }
979
980 private:
986 static inline size_t pop10_in_slice64(const std::uint64_t& slice,
987 const int& length) noexcept {
988 if (length <= 1) {
989 return 0;
990 }
991 std::uint64_t pattern_mask = slice & ~(slice >> 1); // candidates for "10"
992 if (length < 64) {
993 pattern_mask &= ((std::uint64_t(1) << (length - 1)) - 1);
994 } else {
995 pattern_mask &= 0x7FFFFFFFFFFFFFFFull;
996 }
997 return (size_t)std::popcount(pattern_mask);
998 }
999
1004 size_t rank1_in_block(const size_t& block_begin,
1005 const size_t& block_end) const noexcept {
1006 if (block_end <= block_begin) {
1007 return 0;
1008 }
1009 size_t left_word_index = block_begin >> 6;
1010 const size_t right_word_index = block_end >> 6;
1011 size_t left_offset = block_begin & 63;
1012 const size_t right_offset = block_end & 63;
1013 size_t count = 0;
1014 if (left_word_index == right_word_index) {
1015 const std::uint64_t mask =
1016 ((right_offset == 0) ? 0 : ((std::uint64_t(1) << right_offset) - 1)) &
1017 (~std::uint64_t(0) << left_offset);
1018 return (size_t)std::popcount(bits[left_word_index] & mask);
1019 }
1020 if (left_offset) {
1021 count += (size_t)std::popcount(bits[left_word_index] &
1022 (~std::uint64_t(0) << left_offset));
1023 ++left_word_index;
1024 }
1025 while (left_word_index < right_word_index) {
1026 count += (size_t)std::popcount(bits[left_word_index]);
1027 ++left_word_index;
1028 }
1029 if (right_offset) {
1030 count += (size_t)std::popcount(bits[right_word_index] &
1031 ((std::uint64_t(1) << right_offset) - 1));
1032 }
1033 return count;
1034 }
1035
1040 size_t rr_in_block(const size_t& block_begin,
1041 const size_t& block_end) const noexcept {
1042 if (block_end <= block_begin + 1) {
1043 return 0;
1044 }
1045 size_t left_word_index = block_begin >> 6;
1046 const size_t right_word_index = (block_end - 1) >> 6;
1047 const int left_offset = block_begin & 63;
1048 const int right_offset = (block_end - 1) & 63;
1049 size_t count = 0;
1050
1051 if (left_word_index == right_word_index) {
1052 const int length = right_offset - left_offset + 1;
1053 const std::uint64_t slice = bits[left_word_index] >> left_offset;
1054 return pop10_in_slice64(slice, length);
1055 }
1056
1057 // prefix word
1058 {
1059 const int length = 64 - left_offset;
1060 const std::uint64_t slice = bits[left_word_index] >> left_offset;
1061 count += pop10_in_slice64(slice, length);
1062 }
1063 // full interior words
1064 for (size_t word_index = left_word_index + 1; word_index < right_word_index;
1065 ++word_index) {
1066 const std::uint64_t word = bits[word_index];
1067 count += pop10_in_slice64(word, 64);
1068 }
1069 // suffix word
1070 {
1071 const int length = right_offset + 1;
1072 const std::uint64_t mask = (length == 64)
1073 ? ~std::uint64_t(0)
1074 : ((std::uint64_t(1) << length) - 1);
1075 const std::uint64_t slice = bits[right_word_index] & mask;
1076 count += pop10_in_slice64(slice, length);
1077 }
1078 // cross-word boundaries (bit 63 of w and bit 0 of w+1)
1079 for (size_t word_index = left_word_index; word_index < right_word_index;
1080 ++word_index) {
1081 if (((bits[word_index] >> 63) & 1u) &&
1082 ((bits[word_index + 1] & 1u) == 0)) {
1083 ++count;
1084 }
1085 }
1086 return count;
1087 }
1088
1094 size_t select10_in_block(const size_t& block_begin,
1095 const size_t& block_end,
1096 size_t target_pattern_rank) const noexcept {
1097 if (block_end <= block_begin + 1) {
1098 return npos;
1099 }
1100 size_t left_word_index = block_begin >> 6;
1101 const size_t right_word_index = (block_end - 1) >> 6;
1102 const int left_offset = block_begin & 63;
1103 const int right_offset = (block_end - 1) & 63;
1104
1105 const auto select_in_masked_slice =
1106 [&](const std::uint64_t& slice, const int& length,
1107 const size_t& target_index) noexcept -> int {
1108 if (length <= 1) {
1109 return -1;
1110 }
1111 std::uint64_t pattern_mask = slice & ~(slice >> 1);
1112 if (length < 64) {
1113 pattern_mask &= ((std::uint64_t(1) << (length - 1)) - 1);
1114 } else {
1115 pattern_mask &= 0x7FFFFFFFFFFFFFFFull;
1116 }
1117 return select_in_word(pattern_mask, target_index);
1118 };
1119
1120 if (left_word_index == right_word_index) {
1121 const int length = right_offset - left_offset + 1;
1122 const std::uint64_t slice = bits[left_word_index] >> left_offset;
1123 const int offset =
1124 select_in_masked_slice(slice, length, target_pattern_rank);
1125 return offset >= 0 ? (block_begin + (size_t)offset) : npos;
1126 }
1127
1128 // prefix word
1129 {
1130 const int length = 64 - left_offset;
1131 const std::uint64_t slice = bits[left_word_index] >> left_offset;
1132 std::uint64_t pattern_mask = slice & ~(slice >> 1);
1133 pattern_mask &= ((std::uint64_t(1) << (length - 1)) - 1);
1134 const int count = std::popcount(pattern_mask);
1135 if (target_pattern_rank <= (size_t)count) {
1136 const int offset =
1137 select_in_masked_slice(slice, length, target_pattern_rank);
1138 return block_begin + (size_t)offset;
1139 }
1140 target_pattern_rank -= count;
1141 }
1142
1143 // walk interior boundaries and words
1144 for (size_t word_index = left_word_index; word_index + 1 < right_word_index;
1145 ++word_index) {
1146 // boundary between w and w+1
1147 if (((bits[word_index] >> 63) & 1u) &&
1148 ((bits[word_index + 1] & 1u) == 0)) {
1149 if (--target_pattern_rank == 0) {
1150 return (word_index << 6) + 63;
1151 }
1152 }
1153 // full word w+1 (positions 0..62)
1154 const std::uint64_t next_word = bits[word_index + 1];
1155 const std::uint64_t pattern_mask =
1156 (next_word & ~(next_word >> 1)) & 0x7FFFFFFFFFFFFFFFull;
1157 const int count = std::popcount(pattern_mask);
1158 if (target_pattern_rank <= (size_t)count) {
1159 const int offset = select_in_word(pattern_mask, target_pattern_rank);
1160 if (offset == -1) {
1161 return npos;
1162 }
1163 return ((word_index + 1) << 6) + (size_t)offset;
1164 }
1165 target_pattern_rank -= count;
1166 }
1167
1168 // boundary (w_r-1, w_r)
1169 if (((bits[right_word_index - 1] >> 63) & 1u) &&
1170 ((bits[right_word_index] & 1u) == 0)) {
1171 if (--target_pattern_rank == 0) {
1172 return ((right_word_index - 1) << 6) + 63;
1173 }
1174 }
1175
1176 // suffix word w_r: [0..off_r]
1177 {
1178 const int length = right_offset + 1;
1179 const std::uint64_t mask = (length == 64)
1180 ? ~std::uint64_t(0)
1181 : ((std::uint64_t(1) << length) - 1);
1182 const std::uint64_t slice = bits[right_word_index] & mask;
1183 const int offset =
1184 select_in_masked_slice(slice, length, target_pattern_rank);
1185 if (offset >= 0) {
1186 return (right_word_index << 6) + (size_t)offset;
1187 }
1188 }
1189 return npos;
1190 }
1191
// Summary of one byte read LSB-first as a run of +1 ('1') / -1 ('0') steps.
// "Prefix" values are the running excess after each bit, starting from 0.
struct ByteAgg {
  // Excess after all 8 bits.
  std::int8_t excess_total;
  // Smallest running excess reached inside the byte.
  std::int8_t min_prefix;
  // Largest running excess reached inside the byte.
  std::int8_t max_prefix;
  // How many bit positions reach min_prefix.
  std::uint8_t min_count;
  // Number of "10" patterns fully contained in the byte.
  std::uint8_t pattern10_count;
  // Bit 0 (LSB) of the byte.
  std::uint8_t first_bit;
  // Bit 7 (MSB) of the byte.
  std::uint8_t last_bit;
  // Lowest bit index at which min_prefix is reached.
  std::uint8_t pos_first_min;
  // Lowest bit index at which max_prefix is reached.
  std::uint8_t pos_first_max;
};
1203
1204 struct LUT8Tables {
1205 std::array<ByteAgg, 256> agg;
1210 std::array<std::array<int8_t, 17>, 256> fwd_pos;
1215 std::array<std::array<int8_t, 17>, 256> bwd_pos;
1216 };
1217
1221 static inline const LUT8Tables& LUT8_ALL() noexcept {
1222 static const LUT8Tables tables = [] {
1223 LUT8Tables lookup_tables{};
1224 for (int byte_value = 0; byte_value < 256; ++byte_value) {
1225 int current_excess = 0, min_prefix = INT_MAX, max_prefix = INT_MIN,
1226 min_count = 0, pattern10_count = 0;
1227 int first_min_position = 0, first_max_position = 0;
1228 int prefixes[8];
1229 const auto bit_at = [&](const int& bit_index) {
1230 return (byte_value >> bit_index) & 1;
1231 }; // LSB-first
1232 for (int bit_index = 0; bit_index < 8; ++bit_index) {
1233 int bit_value = bit_at(bit_index);
1234 if (bit_index + 1 < 8 && bit_value && bit_at(bit_index + 1) == 0) {
1235 ++pattern10_count;
1236 }
1237 current_excess += bit_value ? +1 : -1;
1238 prefixes[bit_index] = current_excess;
1239 if (current_excess < min_prefix) {
1240 min_prefix = current_excess;
1241 min_count = 1;
1242 first_min_position = bit_index;
1243 } else if (current_excess == min_prefix) {
1244 ++min_count;
1245 }
1246 if (current_excess > max_prefix) {
1247 max_prefix = current_excess;
1248 first_max_position = bit_index;
1249 }
1250 }
1251 ByteAgg aggregates{};
1252 aggregates.excess_total = current_excess;
1253 aggregates.min_prefix = (min_prefix == INT_MAX ? 0 : min_prefix);
1254 aggregates.max_prefix = (max_prefix == INT_MIN ? 0 : max_prefix);
1255 aggregates.min_count = min_count;
1256 aggregates.pattern10_count = pattern10_count;
1257 aggregates.first_bit = bit_at(0);
1258 aggregates.last_bit = bit_at(7);
1259 aggregates.pos_first_min = first_min_position;
1260 aggregates.pos_first_max = first_max_position;
1261 lookup_tables.agg[byte_value] = aggregates;
1262 auto& forward_positions = lookup_tables.fwd_pos[byte_value];
1263 auto& backward_positions = lookup_tables.bwd_pos[byte_value];
1264 forward_positions.fill(-1);
1265 backward_positions.fill(-1);
1266 for (int delta = -8; delta <= 8; ++delta) {
1267 for (int bit_index = 0; bit_index < 8; ++bit_index) {
1268 if (prefixes[bit_index] == delta) {
1269 forward_positions[delta + 8] = bit_index;
1270 break;
1271 }
1272 }
1273 for (int bit_index = 7; bit_index >= 0; --bit_index) {
1274 if (prefixes[bit_index] == delta) {
1275 backward_positions[delta + 8] = bit_index;
1276 break;
1277 }
1278 }
1279 }
1280 }
1281 return lookup_tables;
1282 }();
1283 return tables;
1284 }
1285
1289 static inline const std::array<ByteAgg, 256>& LUT8() noexcept {
1290 return LUT8_ALL().agg;
1291 }
1292
1296 static inline const std::array<std::array<int8_t, 17>, 256>&
1297 LUT8_FWD_POS() noexcept {
1298 return LUT8_ALL().fwd_pos;
1299 }
1300
1304 static inline const std::array<std::array<int8_t, 17>, 256>&
1305 LUT8_BWD_POS() noexcept {
1306 return LUT8_ALL().bwd_pos;
1307 }
1308
1312 inline uint16_t get_u16(const size_t& position) const noexcept {
1313 const size_t word_index = position >> 6;
1314 const unsigned offset = unsigned(position & 63);
1315 const std::uint64_t w0 =
1316 (word_index < bits.size()) ? bits[word_index] : 0ULL;
1317 if (offset == 0) {
1318 return uint16_t(w0 & 0xFFFFu);
1319 }
1320 const std::uint64_t w1 =
1321 (word_index + 1 < bits.size()) ? bits[word_index + 1] : 0ULL;
1322 const std::uint64_t v = (w0 >> offset) | (w1 << (64u - offset));
1323 return uint16_t(v & 0xFFFFu);
1324 }
1325
1326#if defined(PIXIE_AVX2_SUPPORT)
1327 static inline __m256i bit_masks_16x() noexcept {
1328 // 16 lanes: (1<<0), (1<<1), ... (1<<15)
1329 return _mm256_setr_epi16(0x0001, 0x0002, 0x0004, 0x0008, 0x0010, 0x0020,
1330 0x0040, 0x0080, 0x0100, 0x0200, 0x0400, 0x0800,
1331 0x1000, 0x2000, 0x4000, (int16_t)0x8000);
1332 }
1333
1334 static inline __m256i prefix_sum_16x_i16(__m256i v) noexcept {
1335 // Inclusive prefix sum within 128-bit lanes, then fix carry into the high
1336 // lane.
1337 __m256i x = v;
1338 __m256i t = _mm256_slli_si256(x, 2);
1339 x = _mm256_add_epi16(x, t);
1340 t = _mm256_slli_si256(x, 4);
1341 x = _mm256_add_epi16(x, t);
1342 t = _mm256_slli_si256(x, 8);
1343 x = _mm256_add_epi16(x, t);
1344
1345 __m128i lo = _mm256_extracti128_si256(x, 0);
1346 __m128i hi = _mm256_extracti128_si256(x, 1);
1347 const int16_t carry =
1348 (int16_t)_mm_extract_epi16(lo, 7); // sum of first 8 elems
1349 hi = _mm_add_epi16(hi, _mm_set1_epi16(carry));
1350
1351 __m256i out = _mm256_castsi128_si256(lo);
1352 out = _mm256_inserti128_si256(out, hi, 1);
1353 return out;
1354 }
1355
1356 static inline int16_t last_prefix_16x_i16(__m256i pref) noexcept {
1357 __m128i hi = _mm256_extracti128_si256(pref, 1);
1358 return (int16_t)_mm_extract_epi16(hi, 7); // lane 15
1359 }
1360
1367 inline size_t scan_leaf_fwd_simd(const size_t& start,
1368 const size_t& end,
1369 const int& required_delta,
1370 int* out_total) const noexcept {
1371 if (start >= end) {
1372 if (out_total) {
1373 *out_total = 0;
1374 }
1375 return npos;
1376 }
1377 if (required_delta < -32768 || required_delta > 32767) {
1378 if (out_total) {
1379 const int len = int(end - start);
1380 const int ones = int(rank1_in_block(start, end));
1381 *out_total = ones * 2 - len;
1382 }
1383 return npos;
1384 }
1385
1386 static const __m256i masks = bit_masks_16x();
1387 static const __m256i vzero = _mm256_setzero_si256();
1388 static const __m256i vallones = _mm256_cmpeq_epi16(vzero, vzero);
1389 static const __m256i vminus1 = _mm256_set1_epi16(-1);
1390 static const __m256i vtwo = _mm256_set1_epi16(2);
1391 const __m256i vtarget = _mm256_set1_epi16((int16_t)required_delta);
1392
1393 int cur = 0;
1394 size_t pos = start;
1395 while (pos + 16 <= end) {
1396 const uint16_t bits16 = get_u16(pos);
1397 const __m256i vb = _mm256_set1_epi16((int16_t)bits16);
1398 const __m256i m = _mm256_and_si256(vb, masks);
1399 const __m256i is_zero = _mm256_cmpeq_epi16(m, vzero);
1400 const __m256i is_set = _mm256_andnot_si256(is_zero, vallones);
1401 const __m256i steps =
1402 _mm256_add_epi16(vminus1, _mm256_and_si256(is_set, vtwo));
1403
1404 const __m256i pref_rel = prefix_sum_16x_i16(steps);
1405 const __m256i base = _mm256_set1_epi16((int16_t)cur);
1406 const __m256i pref = _mm256_add_epi16(pref_rel, base);
1407 const __m256i cmp = _mm256_cmpeq_epi16(pref, vtarget);
1408 const uint32_t mask = (uint32_t)_mm256_movemask_epi8(cmp);
1409 if (mask) {
1410 const int lane = int(std::countr_zero(mask)) >> 1;
1411 return pos + (size_t)lane;
1412 }
1413 cur += (int)last_prefix_16x_i16(pref_rel);
1414 pos += 16;
1415 }
1416 while (pos < end) {
1417 cur += bit(pos) ? +1 : -1;
1418 if (cur == required_delta) {
1419 return pos;
1420 }
1421 ++pos;
1422 }
1423 if (out_total) {
1424 *out_total = cur;
1425 }
1426 return npos;
1427 }
1428#endif // PIXIE_AVX2_SUPPORT
1429
1436 inline size_t scan_leaf_fwd_lut8_fast(const size_t& start,
1437 const size_t& end,
1438 const int& required_delta,
1439 int* out_total) const noexcept {
1440 if (start >= end) {
1441 if (out_total) {
1442 *out_total = 0;
1443 }
1444 return npos;
1445 }
1446 int cur = 0;
1447
1448 size_t pos = start;
1449 while (pos < end && (pos & 7)) {
1450 cur += bit(pos) ? +1 : -1;
1451 if (cur == required_delta) {
1452 if (out_total) {
1453 *out_total = cur;
1454 }
1455 return pos;
1456 }
1457 ++pos;
1458 }
1459
1460 // Byte-aligned fast path: read bytes directly.
1461 // bits are LSB-first; on little-endian x86 the in-memory byte order matches
1462 // bit groups [pos..pos+7].
1463 const uint8_t* bytep =
1464 reinterpret_cast<const uint8_t*>(bits.data()) + (pos >> 3);
1465 const auto& agg = LUT8();
1466 const auto& fwd = LUT8_FWD_POS();
1467
1468 while (pos + 8 <= end) {
1469 const uint8_t bv = *bytep++;
1470 const auto& a = agg[bv];
1471 const int need = required_delta - cur;
1472 // need must be in [-8..8] for a match inside one byte.
1473 if ((unsigned)(need + 8) <= 16u) {
1474 // min/max pruning first, then position lookup.
1475 if (need >= a.min_prefix && need <= a.max_prefix) {
1476 const int8_t off = fwd[bv][need + 8];
1477 if (off >= 0) {
1478 if (out_total) {
1479 *out_total = cur + a.excess_total; // not exact end, but caller
1480 // only uses when not found
1481 }
1482 return pos + (size_t)off;
1483 }
1484 }
1485 }
1486 cur += a.excess_total;
1487 pos += 8;
1488 }
1489
1490 while (pos < end) {
1491 cur += bit(pos) ? +1 : -1;
1492 if (cur == required_delta) {
1493 if (out_total) {
1494 *out_total = cur;
1495 }
1496 return pos;
1497 }
1498 ++pos;
1499 }
1500 if (out_total) {
1501 *out_total = cur;
1502 }
1503 return npos;
1504 }
1505
1513 inline size_t scan_leaf_fwd(const size_t& search_start,
1514 const size_t& search_end,
1515 const int& required_delta) const noexcept {
1516 if (search_start >= search_end) {
1517 return npos;
1518 }
1519 const auto& aggregates_table = LUT8();
1520 const auto& forward_lookup = LUT8_FWD_POS();
1521 int current_excess = 0;
1522 size_t position = search_start;
1523 while (position + 8 <= search_end) {
1524 const uint8_t byte_value = get_byte(position);
1525 const auto& byte_aggregate = aggregates_table[byte_value];
1526 const int local_need = required_delta - current_excess;
1527 if (local_need >= byte_aggregate.min_prefix &&
1528 local_need <= byte_aggregate.max_prefix && local_need >= -8 &&
1529 local_need <= 8) {
1530 const int8_t offset = forward_lookup[byte_value][local_need + 8];
1531 if (offset >= 0) {
1532 return position + size_t(offset);
1533 }
1534 }
1535 current_excess += byte_aggregate.excess_total;
1536 position += 8;
1537 }
1538
1539 while (position < search_end) {
1540 current_excess += bit(position) ? 1 : -1;
1541 if (current_excess == required_delta) {
1542 return position;
1543 }
1544 ++position;
1545 }
1546
1547 return npos;
1548 }
1549
1558 inline size_t scan_leaf_bwd(
1559 const size_t& block_begin,
1560 const size_t& block_end,
1561 const int& required_delta,
1562 const bool& allow_right_boundary,
1563 const size_t& global_right_border,
1564 const int& prefix_at_boundary_max /*=kNoPrefixOverride*/) const noexcept {
1565 // We scan bits in [block_begin, block_end) and look for the LAST boundary
1566 // where prefix == required_delta, with optional exclusion of the right
1567 // boundary.
1568 if (block_begin > block_end) {
1569 return npos;
1570 }
1571
1572 // Maximum allowed boundary (inclusive) inside this scan.
1573 // If right boundary is forbidden, we forbid boundary == block_end.
1574 size_t boundary_max = block_end;
1575 if (!allow_right_boundary && boundary_max > block_begin) {
1576 --boundary_max;
1577 }
1578
1579 // No bits to scan -> only possible answer is the left boundary.
1580 if (block_begin >= boundary_max) {
1581 if ((block_begin < global_right_border || allow_right_boundary) &&
1582 required_delta == 0) {
1583 return block_begin;
1584 }
1585 return npos;
1586 }
1587
1588 if (required_delta < -32768 || required_delta > 32767) {
1589 // Just in case. Should be impossible.
1590 if ((block_begin < global_right_border || allow_right_boundary) &&
1591 required_delta == 0) {
1592 return block_begin;
1593 }
1594 return npos;
1595 }
1596
1597#if defined(PIXIE_AVX2_SUPPORT)
1598 // Fast reverse scan with early exit:
1599 // find the RIGHTMOST boundary j in (block_begin..boundary_max] such that
1600 // prefix(j) == required_delta.
1601
1602 // prefix_end = prefix(boundary_max) relative to block_begin
1603 int prefix_end = prefix_at_boundary_max;
1604 if (prefix_end == kNoPrefixOverride) {
1605 const int len = int(boundary_max - block_begin);
1606 const int ones = int(rank1_in_block(block_begin, boundary_max));
1607 prefix_end = ones * 2 - len;
1608 }
1609
1610 const __m256i masks = bit_masks_16x();
1611 const __m256i vzero = _mm256_setzero_si256();
1612 const __m256i vallones = _mm256_cmpeq_epi16(vzero, vzero);
1613 const __m256i vminus1 = _mm256_set1_epi16(-1);
1614 const __m256i vtwo = _mm256_set1_epi16(2);
1615 const __m256i vtarget = _mm256_set1_epi16((int16_t)required_delta);
1616
1617 size_t pos_end = boundary_max; // boundary (not bit index)
1618 int cur_end = prefix_end; // prefix(pos_end)
1619
1620 // Vector chunks: process 16 bits ending at pos_end.
1621 while (pos_end >= block_begin + 16) {
1622 const size_t pos = pos_end - 16; // bit index of the chunk start
1623 const uint16_t bits16 = get_u16(pos);
1624 const __m256i vb = _mm256_set1_epi16((int16_t)bits16);
1625 const __m256i m = _mm256_and_si256(vb, masks);
1626 const __m256i is_zero = _mm256_cmpeq_epi16(m, vzero);
1627 const __m256i is_set = _mm256_andnot_si256(is_zero, vallones);
1628 const __m256i steps =
1629 _mm256_add_epi16(vminus1, _mm256_and_si256(is_set, vtwo));
1630
1631 const __m256i pref_rel =
1632 prefix_sum_16x_i16(steps); // prefix after each bit (relative)
1633 const int16_t sum16 =
1634 last_prefix_16x_i16(pref_rel); // total sum on this 16-bit chunk
1635 const int cur_start =
1636 cur_end - (int)sum16; // prefix at boundary pos (chunk start)
1637
1638 const __m256i base = _mm256_set1_epi16((int16_t)cur_start);
1639 const __m256i pref = _mm256_add_epi16(
1640 pref_rel, base); // prefix at boundaries (pos+1..pos+16)
1641 const __m256i cmp = _mm256_cmpeq_epi16(pref, vtarget);
1642 const uint32_t mask = (uint32_t)_mm256_movemask_epi8(cmp);
1643 if (mask) {
1644 const int bit_i = 31 - int(std::countl_zero(mask));
1645 const int lane = bit_i >> 1;
1646 const size_t boundary = pos + (size_t)lane + 1;
1647 if (boundary < global_right_border || allow_right_boundary) {
1648 return boundary;
1649 }
1650 return npos;
1651 }
1652
1653 cur_end = cur_start;
1654 pos_end = pos;
1655 }
1656
1657 while (pos_end > block_begin) {
1658 // boundary pos_end corresponds to prefix cur_end
1659 if (cur_end == required_delta) {
1660 if (pos_end < global_right_border || allow_right_boundary) {
1661 return pos_end;
1662 }
1663 return npos;
1664 }
1665 const size_t bit_pos = pos_end - 1;
1666 cur_end -= bit(bit_pos) ? +1 : -1; // move one bit to the left
1667 pos_end = bit_pos;
1668 }
1669#else
1670 size_t last_boundary = npos;
1671 int cur = 0;
1672 for (size_t pos = block_begin; pos < boundary_max; ++pos) {
1673 cur += bit(pos) ? +1 : -1;
1674 if (cur == required_delta) {
1675 last_boundary = pos + 1;
1676 }
1677 }
1678 if (last_boundary != npos) {
1679 if (last_boundary < global_right_border || allow_right_boundary) {
1680 return last_boundary;
1681 }
1682 return npos;
1683 }
1684#endif
1685
1686 // Left boundary (prefix == 0) is always the final candidate.
1687 if ((block_begin < global_right_border || allow_right_boundary) &&
1688 required_delta == 0) {
1689 return block_begin;
1690 }
1691 return npos;
1692 }
1693
1697 inline uint8_t get_byte(const size_t& position) const noexcept {
1698 const size_t word_index = position >> 6;
1699 const size_t offset = position & 63;
1700 const std::uint64_t lower_word = bits[word_index] >> offset;
1701 if (offset <= 56) {
1702 return uint8_t(lower_word & 0xFFu);
1703 }
1704 const std::uint64_t higher_word =
1705 (word_index + 1 < bits.size()) ? bits[word_index + 1] : 0;
1706 const std::uint64_t byte_value =
1707 (lower_word | (higher_word << (64 - offset))) & 0xFFu;
1708 return uint8_t(byte_value);
1709 }
1710
1719 size_t descend_first_max(size_t node_index,
1720 int target_prefix,
1721 size_t segment_base) const noexcept {
1722 while (node_index < first_leaf_index) {
1723 const size_t left_child = node_index << 1, right_child = left_child | 1;
1724 const int left_max = node_max_prefix_excess[left_child];
1725 const int right_max =
1726 node_total_excess[left_child] + node_max_prefix_excess[right_child];
1727 if (left_max >= right_max && left_max == target_prefix) {
1728 node_index = left_child;
1729 } else if (right_max == target_prefix) {
1730 segment_base += segment_size_bits[left_child];
1731 target_prefix -= node_total_excess[left_child];
1732 node_index = right_child;
1733 } else {
1734 return npos;
1735 }
1736 }
1737
1738 const size_t segment_begin = segment_base;
1739 const size_t segment_end =
1740 std::min(segment_base + segment_size_bits[node_index], num_bits);
1741 int max_value;
1742 size_t position;
1743
1744 first_max_value_pos8(segment_begin,
1745 segment_end ? (segment_end - 1) : segment_begin,
1746 max_value, position);
1747 return (max_value == target_prefix ? position : npos);
1748 }
1749
// Heap index of the first leaf: nodes with a smaller index are internal, and
// the leaf for block b lives at first_leaf_index + b.
size_t first_leaf_index{1};
1754
// Sentinel for scan_leaf_bwd's prefix_at_boundary_max argument meaning "no
// precomputed prefix supplied"; equals the smallest int.
static constexpr int kNoPrefixOverride = INT_MIN;
1760
1764 size_t block_of(const size_t& position) const noexcept {
1765 return position / block_bits;
1766 }
1767
1771 size_t leaf_index_of(const size_t& block_start) const noexcept {
1772 return first_leaf_index + block_of(block_start);
1773 }
1774
1779 size_t node_base(size_t node_index) const noexcept {
1780 if (node_index >= first_leaf_index) {
1781 return (node_index - first_leaf_index) * block_bits;
1782 }
1783
1784 size_t base = 0;
1785 for (; node_index > 1; node_index >>= 1) {
1786 if (node_index & 1) {
1787 base += segment_size_bits[node_index - 1];
1788 }
1789 }
1790 return base;
1791 }
1792
1798 std::vector<size_t> cover_blocks(const size_t& block_begin_index,
1799 const size_t& block_end_index) const {
1800 size_t left_index = first_leaf_index + block_begin_index;
1801 size_t right_index = first_leaf_index + block_end_index;
1802 std::vector<size_t> left_nodes, right_nodes;
1803 while (left_index <= right_index) {
1804 if ((left_index & 1) == 1) {
1805 left_nodes.push_back(left_index++);
1806 }
1807 if ((right_index & 1) == 0) {
1808 right_nodes.push_back(right_index--);
1809 }
1810 left_index >>= 1;
1811 right_index >>= 1;
1812 }
1813 std::reverse(right_nodes.begin(), right_nodes.end());
1814 left_nodes.insert(left_nodes.end(), right_nodes.begin(), right_nodes.end());
1815 return left_nodes;
1816 }
1817
1822 size_t descend_fwd(size_t node_index,
1823 int required_delta,
1824 size_t segment_base) const noexcept {
1825 while (node_index < first_leaf_index) {
1826 const size_t left_child = node_index << 1;
1827 const size_t right_child = left_child | 1;
1828 if (node_min_prefix_excess[left_child] <= required_delta &&
1829 required_delta <= node_max_prefix_excess[left_child]) {
1830 node_index = left_child;
1831 } else {
1832 required_delta -= node_total_excess[left_child];
1833 segment_base += segment_size_bits[left_child];
1834 node_index = right_child;
1835 }
1836 }
1837 const size_t seg_end =
1838 std::min(segment_base + segment_size_bits[node_index], num_bits);
1839#if defined(PIXIE_AVX2_SUPPORT)
1840 return scan_leaf_fwd_simd(segment_base, seg_end, required_delta, nullptr);
1841#else
1842 return scan_leaf_fwd(segment_base, seg_end, required_delta);
1843#endif
1844 }
1845
1855 size_t descend_bwd(size_t node_index,
1856 const size_t& segment_base,
1857 const int& required_delta,
1858 const size_t& global_right_border,
1859 const bool& allow_right_boundary) const noexcept {
1860 while (node_index < first_leaf_index) {
1861 const size_t left_child = node_index << 1;
1862 const size_t right_child = left_child | 1;
1863 const int required_in_right =
1864 required_delta - node_total_excess[left_child];
1865
1866 // 1) try the right child first (to capture the rightmost j)
1867 if (node_min_prefix_excess[right_child] <= required_in_right &&
1868 required_in_right <= node_max_prefix_excess[right_child]) {
1869 const size_t result = descend_bwd(
1870 right_child, segment_base + segment_size_bits[left_child],
1871 required_in_right, global_right_border, allow_right_boundary);
1872 if (result != npos) {
1873 return result;
1874 }
1875 }
1876
1877 // 2) junction between children (end of the left child)
1878 const size_t junction = segment_base + segment_size_bits[left_child];
1879 if (required_delta == node_total_excess[left_child] &&
1880 (junction < global_right_border || allow_right_boundary)) {
1881 return junction;
1882 }
1883
1884 // 3) can we move left within the range?
1885 if (node_min_prefix_excess[left_child] <= required_delta &&
1886 required_delta <= node_max_prefix_excess[left_child]) {
1887 node_index = left_child;
1888 continue;
1889 }
1890
1891 // None of (1)-(3) worked. The only possible point is the left border of
1892 // the node.
1893 if (required_delta == 0 &&
1894 (segment_base < global_right_border || allow_right_boundary)) {
1895 return segment_base;
1896 }
1897
1898 return npos;
1899 }
1900
1901 const size_t seg_end =
1902 std::min(segment_base + segment_size_bits[node_index], num_bits);
1903 const size_t block_end = std::min(global_right_border, seg_end);
1904
1905 int prefix_override = kNoPrefixOverride;
1906 // If we scan the full leaf up to its end boundary, we know
1907 // prefix(block_end) from node_total_excess[leaf]. If the right boundary is
1908 // forbidden, we can still derive prefix(block_end-1) cheaply.
1909 if (block_end == seg_end) {
1910 const int total = node_total_excess[node_index];
1911 if (!allow_right_boundary && block_end > segment_base) {
1912 prefix_override = total - (bit(block_end - 1) ? +1 : -1);
1913 } else {
1914 prefix_override = total;
1915 }
1916 }
1917
1918 return scan_leaf_bwd(segment_base, block_end, required_delta,
1919 allow_right_boundary, global_right_border,
1920 prefix_override);
1921 }
1922
1927 size_t descend_first_min(size_t node_index,
1928 int target_prefix,
1929 size_t segment_base) const noexcept {
1930 while (node_index < first_leaf_index) {
1931 const size_t left_child = node_index << 1, right_child = left_child | 1;
1932 const int left_min = node_min_prefix_excess[left_child];
1933 const int right_min =
1934 node_total_excess[left_child] + node_min_prefix_excess[right_child];
1935 if (left_min <= right_min && left_min == target_prefix) {
1936 node_index = left_child;
1937 } else if (right_min == target_prefix) {
1938 segment_base += segment_size_bits[left_child];
1939 target_prefix -= node_total_excess[left_child];
1940 node_index = right_child;
1941 } else {
1942 return npos;
1943 }
1944 }
1945
1946 const size_t segment_begin = segment_base;
1947 const size_t segment_end =
1948 std::min(segment_base + segment_size_bits[node_index], num_bits);
1949 int min_value;
1950 size_t position;
1951
1952 first_min_value_pos8(segment_begin,
1953 segment_end ? (segment_end - 1) : segment_begin,
1954 min_value, position);
1955 return (min_value == target_prefix ? position : npos);
1956 }
1957
1962 size_t descend_qth_min(size_t node_index,
1963 int target_prefix,
1964 size_t target_min_rank,
1965 size_t segment_base) const noexcept {
1966 while (node_index < first_leaf_index) {
1967 const size_t left_child = node_index << 1;
1968 const size_t right_child = left_child | 1;
1969 const int left_min = node_min_prefix_excess[left_child];
1970 const int right_min =
1971 node_total_excess[left_child] + node_min_prefix_excess[right_child];
1972 if (left_min == target_prefix) {
1973 if (node_min_count[left_child] >= target_min_rank) {
1974 node_index = left_child;
1975 continue;
1976 }
1977 target_min_rank -= node_min_count[left_child];
1978 }
1979 if (right_min == target_prefix) {
1980 segment_base += segment_size_bits[left_child];
1981 target_prefix -= node_total_excess[left_child];
1982 node_index = right_child;
1983 continue;
1984 }
1985 return npos;
1986 }
1987 return qth_min_in_block(
1988 segment_base,
1989 std::min(segment_base + segment_size_bits[node_index], num_bits) - 1,
1990 target_min_rank);
1991 }
1992
1997 size_t select1_in_block(const size_t& block_begin,
1998 const size_t& block_end,
1999 size_t target_one_rank) const noexcept {
2000 size_t left_word_index = block_begin >> 6;
2001 const size_t right_word_index = (block_end >> 6);
2002 const size_t left_offset = block_begin & 63;
2003 const std::uint64_t left_mask =
2004 (left_offset ? (~std::uint64_t(0) << left_offset) : ~std::uint64_t(0));
2005 if (left_word_index == right_word_index) {
2006 const std::uint64_t word =
2007 bits[left_word_index] & left_mask &
2008 ((block_end & 63) ? ((std::uint64_t(1) << (block_end & 63)) - 1)
2009 : ~std::uint64_t(0));
2010 return block_begin + select_in_word(word, target_one_rank);
2011 }
2012 // prefix
2013 if (left_offset) {
2014 const std::uint64_t word = bits[left_word_index] & left_mask;
2015 const int count = std::popcount(word);
2016 if (target_one_rank <= (size_t)count) {
2017 return block_begin + select_in_word(word, target_one_rank);
2018 }
2019 target_one_rank -= count;
2020 left_word_index++;
2021 }
2022 // full words
2023 while (left_word_index < right_word_index) {
2024 const std::uint64_t word = bits[left_word_index];
2025 const int count = std::popcount(word);
2026 if (target_one_rank <= (size_t)count) {
2027 return (left_word_index << 6) + select_in_word(word, target_one_rank);
2028 }
2029 target_one_rank -= count;
2030 ++left_word_index;
2031 }
2032 // tail
2033 const size_t right_offset = block_end & 63;
2034 if (right_offset) {
2035 const std::uint64_t word =
2036 bits[left_word_index] & ((std::uint64_t(1) << right_offset) - 1);
2037 const int count = std::popcount(word);
2038 if (target_one_rank <= (size_t)count) {
2039 return (left_word_index << 6) + select_in_word(word, target_one_rank);
2040 }
2041 }
2042 return npos;
2043 }
2044
2049 size_t select0_in_block(const size_t& block_begin,
2050 const size_t& block_end,
2051 size_t target_zero_rank) const noexcept {
2052 if (block_end <= block_begin) {
2053 return npos;
2054 }
2055
2056 size_t left_word_index = block_begin >> 6;
2057 const size_t right_word_index = block_end >> 6;
2058 const size_t left_offset = block_begin & 63;
2059
2060 if (left_word_index == right_word_index) {
2061 const std::uint64_t left_mask =
2062 (left_offset ? (~std::uint64_t(0) << left_offset)
2063 : ~std::uint64_t(0));
2064 const std::uint64_t right_mask =
2065 ((block_end & 63) ? ((std::uint64_t(1) << (block_end & 63)) - 1)
2066 : ~std::uint64_t(0));
2067 const std::uint64_t word =
2068 (~bits[left_word_index]) & left_mask & right_mask;
2069 const int offset = select_in_word(word, target_zero_rank);
2070 return (offset >= 0) ? (block_begin + (size_t)offset) : npos;
2071 }
2072
2073 // prefix
2074 if (left_offset) {
2075 const std::uint64_t word =
2076 (~bits[left_word_index]) & (~std::uint64_t(0) << left_offset);
2077 const int count = std::popcount(word);
2078 if (target_zero_rank <= (size_t)count) {
2079 const int offset = select_in_word(word, target_zero_rank);
2080 return (offset >= 0) ? (block_begin + (size_t)offset) : npos;
2081 }
2082 target_zero_rank -= count;
2083 ++left_word_index;
2084 }
2085
2086 // full words
2087 while (left_word_index < right_word_index) {
2088 const std::uint64_t word = ~bits[left_word_index];
2089 const int count = std::popcount(word);
2090 if (target_zero_rank <= (size_t)count) {
2091 const int offset = select_in_word(word, target_zero_rank);
2092 return (offset >= 0) ? ((left_word_index << 6) + (size_t)offset) : npos;
2093 }
2094 target_zero_rank -= count;
2095 ++left_word_index;
2096 }
2097
2098 // tail
2099 const size_t right_offset = block_end & 63;
2100 if (right_offset) {
2101 const std::uint64_t word =
2102 (~bits[left_word_index]) & ((std::uint64_t(1) << right_offset) - 1);
2103 const int count = std::popcount(word);
2104 if (target_zero_rank <= (size_t)count) {
2105 const int offset = select_in_word(word, target_zero_rank);
2106 return (offset >= 0) ? ((left_word_index << 6) + (size_t)offset) : npos;
2107 }
2108 }
2109 return npos;
2110 }
2111
2116 static inline int select_in_word(std::uint64_t word,
2117 size_t target_rank) noexcept {
2118 while (word) {
2119 if (--target_rank == 0) {
2120 return std::countr_zero(word);
2121 }
2122 word &= (word - 1);
2123 }
2124 return -1;
2125 }
2126
// Integer division rounded up: the smallest q with q * denominator >=
// numerator. denominator must be non-zero.
//
// Written as quotient plus a remainder test, so it cannot wrap around the
// way `numerator + denominator - 1` does for numerators near the top of the
// size_t range.
static inline size_t ceil_div(const size_t& numerator,
                              const size_t& denominator) noexcept {
  const size_t quotient = numerator / denominator;
  const bool has_remainder = (numerator % denominator) != 0;
  return quotient + (has_remainder ? 1 : 0);
}
2134
2140 static inline size_t nodeslots_for(const size_t& bit_count,
2141 const size_t& block_size_pow2) noexcept {
2142 if (bit_count == 0) {
2143 return 0;
2144 }
2145 size_t leaf_node_count = ceil_div(bit_count, block_size_pow2);
2146 return std::bit_ceil(std::max<size_t>(1, leaf_node_count)) +
2147 leaf_node_count;
2148 }
2149
2153 static inline float overhead_for(const size_t& bit_count,
2154 const size_t& block_size_pow2) noexcept {
2155 static constexpr size_t AUX_SLOT_BYTES =
2156 sizeof(uint32_t) + sizeof(int32_t) + sizeof(int32_t) + sizeof(int32_t) +
2157 sizeof(uint32_t) + sizeof(uint32_t) + sizeof(uint8_t) + sizeof(uint8_t);
2158
2159 size_t bitvector_bytes = ceil_div(bit_count, 64) * 8;
2160 if (bitvector_bytes == 0) {
2161 return 0;
2162 }
2163 size_t slot_count = nodeslots_for(bit_count, block_size_pow2);
2164 size_t aux_bytes = slot_count * AUX_SLOT_BYTES;
2165 return ((float)aux_bytes) / ((float)bitvector_bytes);
2166 }
2167
2174 static inline size_t choose_block_bits_for_overhead(
2175 const size_t& bit_count,
2176 const float& overhead_cap) noexcept {
2177 if (overhead_cap < 0.f) {
2178 return 64;
2179 }
2180
2181 const size_t max_block_bits = std::min<size_t>(bit_count, 16384);
2182 size_t candidate_block_bits = 64;
2183 while (candidate_block_bits < max_block_bits) {
2184 if (overhead_for(bit_count, candidate_block_bits) <= overhead_cap) {
2185 break;
2186 }
2187 candidate_block_bits <<= 1;
2188 }
2189 return candidate_block_bits;
2190 }
2191
2199 void build_from_words(std::span<const std::uint64_t> words,
2200 const size_t& bit_count,
2201 const size_t& leaf_block_bits = 0,
2202 const float& max_overhead = -1.0) {
2203 bits = words;
2204 num_bits = bit_count;
2205 if (bits.size() * 64 < num_bits) {
2206 throw std::invalid_argument(
2207 "RmMTree bit_count exceeds the provided word span");
2208 }
2209 build(leaf_block_bits, max_overhead);
2210 }
2211
2216 inline uint32_t ones_in_node(const size_t& node_index) const noexcept {
2217 return ((int64_t)segment_size_bits[node_index] +
2218 (int64_t)node_total_excess[node_index]) >>
2219 1;
2220 }
2221
2229 inline void scan_range_min_count8(size_t range_begin,
2230 const size_t& range_end,
2231 int& current_excess,
2232 int& min_value,
2233 uint32_t& count) const noexcept {
2234 current_excess = 0;
2235 min_value = INT_MAX;
2236 count = 0;
2237 if (range_end < range_begin) {
2238 min_value = 0;
2239 return;
2240 }
2241 // to byte alignment
2242 while (range_begin <= range_end && (range_begin & 7)) {
2243 current_excess += bit(range_begin) ? +1 : -1;
2244 if (current_excess < min_value) {
2245 min_value = current_excess;
2246 count = 1;
2247 } else if (current_excess == min_value) {
2248 ++count;
2249 }
2250 ++range_begin;
2251 }
2252 // full bytes
2253 const auto& aggregates_table = LUT8();
2254 while (range_begin + 7 <= range_end) {
2255 const auto& byte_aggregate = aggregates_table[get_byte(range_begin)];
2256 const int candidate = current_excess + byte_aggregate.min_prefix;
2257 if (candidate < min_value) {
2258 min_value = candidate;
2259 count = byte_aggregate.min_count;
2260 } else if (candidate == min_value) {
2261 count += byte_aggregate.min_count;
2262 }
2263 current_excess += byte_aggregate.excess_total;
2264 range_begin += 8;
2265 }
2266 // tail
2267 while (range_begin <= range_end) {
2268 current_excess += bit(range_begin) ? +1 : -1;
2269 if (current_excess < min_value) {
2270 min_value = current_excess;
2271 count = 1;
2272 } else if (current_excess == min_value) {
2273 ++count;
2274 }
2275 ++range_begin;
2276 }
2277 if (min_value == INT_MAX) {
2278 min_value = count = 0;
2279 }
2280 }
2281
2288 inline size_t cover_blocks_collect(const size_t& block_begin_index,
2289 const size_t& block_end_index,
2290 size_t (&out_nodes)[64]) const noexcept {
2291 if (leaf_count == 0 || block_begin_index > block_end_index) {
2292 return 0;
2293 }
2294 size_t left_index = first_leaf_index + block_begin_index;
2295 size_t right_index = first_leaf_index + block_end_index;
2296 size_t left_nodes[32];
2297 size_t right_nodes[32];
2298 size_t left_count = 0, right_count = 0;
2299 while (left_index <= right_index) {
2300 if (left_index & 1) {
2301 left_nodes[left_count++] = left_index++;
2302 }
2303 if ((right_index & 1) == 0) {
2304 right_nodes[right_count++] = right_index--;
2305 }
2306 left_index >>= 1;
2307 right_index >>= 1;
2308 }
2309 size_t out_count = 0;
2310 for (size_t i = 0; i < left_count; ++i) {
2311 out_nodes[out_count++] = left_nodes[i];
2312 }
2313 while (right_count > 0) {
2314 out_nodes[out_count++] = right_nodes[--right_count];
2315 }
2316 return out_count;
2317 }
2318
2325 inline size_t qth_min_in_block(const size_t& range_begin,
2326 const size_t& range_end,
2327 size_t target_min_rank) const noexcept {
2328 if (range_end < range_begin || target_min_rank == 0) {
2329 return npos;
2330 }
2331
2332 const auto& aggregates_table = LUT8();
2333
2334 int current_excess = 0, min_value = INT_MAX;
2335 size_t position = range_begin;
2336
2337 while (position <= range_end && (position & 7)) {
2338 current_excess += bit(position) ? +1 : -1;
2339 if (current_excess < min_value) {
2340 min_value = current_excess;
2341 }
2342 ++position;
2343 }
2344 while (position + 7 <= range_end) {
2345 const auto& byte_aggregate = aggregates_table[get_byte(position)];
2346 min_value =
2347 std::min(min_value, current_excess + byte_aggregate.min_prefix);
2348 current_excess += byte_aggregate.excess_total;
2349 position += 8;
2350 }
2351 while (position <= range_end) {
2352 current_excess += bit(position) ? +1 : -1;
2353 if (current_excess < min_value) {
2354 min_value = current_excess;
2355 }
2356 ++position;
2357 }
2358
2359 current_excess = 0;
2360 position = range_begin;
2361
2362 // to byte alignment
2363 while (position <= range_end && (position & 7)) {
2364 current_excess += bit(position) ? +1 : -1;
2365 if (current_excess == min_value) {
2366 if (--target_min_rank == 0) {
2367 return position;
2368 }
2369 }
2370 ++position;
2371 }
2372
2373 // full bytes
2374 while (position + 7 <= range_end) {
2375 const uint8_t byte_value = get_byte(position);
2376 const auto& byte_aggregate = aggregates_table[byte_value];
2377 const int candidate = current_excess + byte_aggregate.min_prefix;
2378 if (candidate == min_value) {
2379 int prefix_sum = 0;
2380 for (int k = 0; k < 8; ++k) {
2381 prefix_sum += ((byte_value >> k) & 1u) ? +1 : -1;
2382 if (prefix_sum == byte_aggregate.min_prefix) {
2383 if (--target_min_rank == 0) {
2384 return position + k;
2385 }
2386 }
2387 }
2388 }
2389 current_excess += byte_aggregate.excess_total;
2390 position += 8;
2391 }
2392
2393 // tail
2394 while (position <= range_end) {
2395 current_excess += bit(position) ? +1 : -1;
2396 if (current_excess == min_value) {
2397 if (--target_min_rank == 0) {
2398 return position;
2399 }
2400 }
2401 ++position;
2402 }
2403
2404 return npos;
2405 }
2406
2417 inline size_t leaf_fwd_bp_simd(const size_t& leaf_index,
2418 const size_t& leaf_block_begin,
2419 const size_t& start_position,
2420 const int& delta,
2421 int& leaf_delta) const noexcept {
2422 const size_t leaf_length = segment_size_bits[first_leaf_index + leaf_index];
2423 const size_t leaf_end = std::min(num_bits, leaf_block_begin + leaf_length);
2424 if (start_position >= leaf_end) {
2425 leaf_delta = 0;
2426 return npos;
2427 }
2428#if defined(PIXIE_AVX2_SUPPORT)
2429 int total = 0;
2430 // Heuristic: for tiny tails and tiny deltas, LUT8 wins because AVX2 setup
2431 // is heavy...
2432 // At least I think so...
2433 const size_t tail_len = leaf_end - start_position;
2434 size_t res = npos;
2435 if (delta >= -8 && delta <= 8 && tail_len <= 256) {
2436 res = scan_leaf_fwd_lut8_fast(start_position, leaf_end, delta, &total);
2437 } else {
2438 res = scan_leaf_fwd_simd(start_position, leaf_end, delta, &total);
2439 }
2440 if (res != npos) {
2441 return res;
2442 }
2443 leaf_delta = total;
2444 return npos;
2445#else
2446 const size_t res = scan_leaf_fwd(start_position, leaf_end, delta);
2447 if (res != npos) {
2448 return res;
2449 }
2450 const int len = int(leaf_end - start_position);
2451 const int ones = int(rank1_in_block(start_position, leaf_end));
2452 leaf_delta = ones * 2 - len;
2453 return npos;
2454#endif
2455 }
2456
2465 inline size_t leaf_bwd_bp_simd(const size_t& leaf_index,
2466 const size_t& leaf_block_begin,
2467 const size_t& start_position,
2468 const int& delta,
2469 int& leaf_delta) const noexcept {
2470 const size_t leaf_length = segment_size_bits[first_leaf_index + leaf_index];
2471 const size_t leaf_end = std::min(num_bits, leaf_block_begin + leaf_length);
2472 if (start_position < leaf_block_begin || start_position > leaf_end) {
2473 leaf_delta = 0;
2474 return npos;
2475 }
2476
2477 // leaf_delta = excess_impl(start_position) - excess_impl(leaf_block_begin)
2478 // = 2*rank1_impl([leaf_begin, start)) - (start - leaf_begin)
2479 const int len = int(start_position - leaf_block_begin);
2480 const int ones = int(rank1_in_block(leaf_block_begin, start_position));
2481 leaf_delta = ones * 2 - len;
2482
2483 // Must find a boundary strictly < start_position (so
2484 // boundary==start_position forbidden).
2485 if (start_position == leaf_block_begin) {
2486 return npos;
2487 }
2488 const int target_prefix = leaf_delta + delta;
2489 return scan_leaf_bwd(
2490 leaf_block_begin,
2491 start_position, // do not look to the right of start_position
2492 target_prefix,
2493 false, // right boundary (=start_position) forbidden
2494 start_position,
2495 // prefix at boundary_max = start_position-1:
2496 // leaf_delta is prefix at start_position, subtract last step
2497 (start_position > leaf_block_begin
2498 ? (leaf_delta - (bit(start_position - 1) ? +1 : -1))
2499 : 0));
2500 }
2501
2508 inline void first_min_value_pos8(size_t range_begin,
2509 const size_t& range_end,
2510 int& min_value_out,
2511 size_t& first_position) const noexcept {
2512 const auto& aggregates_table = LUT8();
2513 int current_excess = 0;
2514 int min_value = INT_MAX;
2515 first_position = npos;
2516
2517 // to byte allignment
2518 while (range_begin <= range_end && (range_begin & 7)) {
2519 current_excess += bit(range_begin) ? +1 : -1;
2520 if (current_excess < min_value) {
2521 min_value = current_excess;
2522 first_position = range_begin;
2523 }
2524 ++range_begin;
2525 }
2526
2527 // full bytes
2528 while (range_begin + 7 <= range_end) {
2529 const auto& byte_aggregate = aggregates_table[get_byte(range_begin)];
2530 const int candidate = current_excess + byte_aggregate.min_prefix;
2531 if (candidate < min_value) {
2532 min_value = candidate;
2533 first_position = range_begin + byte_aggregate.pos_first_min;
2534 }
2535 current_excess += byte_aggregate.excess_total;
2536 range_begin += 8;
2537 }
2538
2539 // tail
2540 while (range_begin <= range_end) {
2541 current_excess += bit(range_begin) ? +1 : -1;
2542 if (current_excess < min_value) {
2543 min_value = current_excess;
2544 first_position = range_begin;
2545 }
2546 ++range_begin;
2547 }
2548
2549 min_value_out = (min_value == INT_MAX ? 0 : min_value);
2550 }
2551
2558 inline void first_max_value_pos8(size_t range_begin,
2559 const size_t& range_end,
2560 int& max_value_out,
2561 size_t& first_position) const noexcept {
2562 const auto& aggregates_table = LUT8();
2563 int current_excess = 0;
2564 int max_value = INT_MIN;
2565 first_position = npos;
2566
2567 while (range_begin <= range_end && (range_begin & 7)) {
2568 current_excess += bit(range_begin) ? +1 : -1;
2569 if (current_excess > max_value) {
2570 max_value = current_excess;
2571 first_position = range_begin;
2572 }
2573 ++range_begin;
2574 }
2575
2576 while (range_begin + 7 <= range_end) {
2577 const auto& byte_aggregate = aggregates_table[get_byte(range_begin)];
2578 const int candidate = current_excess + byte_aggregate.max_prefix;
2579 if (candidate > max_value) {
2580 max_value = candidate;
2581 first_position = range_begin + byte_aggregate.pos_first_max;
2582 }
2583 current_excess += byte_aggregate.excess_total;
2584 range_begin += 8;
2585 }
2586
2587 while (range_begin <= range_end) {
2588 current_excess += bit(range_begin) ? +1 : -1;
2589 if (current_excess > max_value) {
2590 max_value = current_excess;
2591 first_position = range_begin;
2592 }
2593 ++range_begin;
2594 }
2595
2596 max_value_out = (max_value == INT_MIN ? 0 : max_value);
2597 }
2598
2605 void build(const size_t& leaf_block_bits, const float& max_overhead) {
2606 // the lower clamp depends on the desired overhead fraction; otherwise use
2607 // 64
2608 const size_t clamp_by_overhead =
2609 (max_overhead >= 0.0
2610 ? choose_block_bits_for_overhead(num_bits, max_overhead)
2611 : size_t(64));
2612
2613 // chosen block_bits: honor an explicit request, but not below
2614 // clamp_by_overhead
2615 if (leaf_block_bits == 0) {
2616 block_bits =
2617 std::max(clamp_by_overhead,
2618 std::bit_ceil<size_t>(
2619 (num_bits <= 1) ? 1 : std::bit_width(num_bits - 1)));
2620 } else {
2621 block_bits =
2622 std::max(clamp_by_overhead,
2623 std::bit_ceil(std::max<size_t>(1, leaf_block_bits)));
2624 }
2625
2626#ifdef DEBUG
2627 // finalizes the achieved overhead percentage
2628 built_overhead = overhead_for(num_bits, block_bits);
2629#endif
2630
2631 leaf_count = ceil_div(num_bits, block_bits);
2632 first_leaf_index = std::bit_ceil(std::max<size_t>(1, leaf_count));
2633 const size_t tree_size = first_leaf_index + leaf_count - 1;
2634 segment_size_bits.assign(tree_size + 1, 0);
2635 node_total_excess.assign(tree_size + 1, 0);
2636 node_min_prefix_excess.assign(tree_size + 1, 0);
2637 node_max_prefix_excess.assign(tree_size + 1, 0);
2638 node_min_count.assign(tree_size + 1, 0);
2639 node_pattern10_count.assign(tree_size + 1, 0);
2640 node_first_bit.assign(tree_size + 1, 0);
2641 node_last_bit.assign(tree_size + 1, 0);
2642
2643 // leaves
2644 for (size_t leaf_block_index = 0; leaf_block_index < leaf_count;
2645 ++leaf_block_index) {
2646 const size_t leaf_node_index = first_leaf_index + leaf_block_index;
2647 const size_t segment_begin = leaf_block_index * block_bits;
2648 const size_t segment_end = std::min(num_bits, segment_begin + block_bits);
2649 segment_size_bits[leaf_node_index] = segment_end - segment_begin;
2650
2651 if (segment_begin < segment_end) {
2652 node_first_bit[leaf_node_index] = bit(segment_begin);
2653 }
2654
2655 const auto& aggregates_table = LUT8();
2656
2657 int current_excess = 0, min_value = INT_MAX, max_value = INT_MIN;
2658 uint32_t min_count = 0;
2659 uint32_t pattern10_count = 0;
2660
2661 uint8_t previous_bit = 0;
2662
2663 size_t position = segment_begin;
2664
2665 // Full bytes
2666 while (position + 8 <= segment_end) {
2667 const uint8_t byte_value = get_byte(position);
2668 const auto& byte_aggregate = aggregates_table[byte_value];
2669
2670 // internal "10" inside the byte
2671 pattern10_count += byte_aggregate.pattern10_count;
2672 // stitching across the boundary between the previous and current byte
2673 // (within the segment)
2674 if (previous_bit == 1 && byte_aggregate.first_bit == 0) {
2675 pattern10_count++;
2676 }
2677
2678 // prefix min/max accounting for the current offset
2679 const int candidate_min = current_excess + byte_aggregate.min_prefix;
2680 if (candidate_min < min_value) {
2681 min_value = candidate_min;
2682 min_count = byte_aggregate.min_count;
2683 } else if (candidate_min == min_value) {
2684 min_count += byte_aggregate.min_count;
2685 }
2686
2687 max_value =
2688 std::max(max_value, current_excess + byte_aggregate.max_prefix);
2689 current_excess += byte_aggregate.excess_total;
2690 previous_bit = byte_aggregate.last_bit;
2691 position += 8;
2692 }
2693
2694 // Tail < 8 bits
2695 while (position < segment_end) {
2696 const uint8_t bit_value = bit(position);
2697 if (previous_bit == 1 && bit_value == 0) {
2698 pattern10_count++;
2699 }
2700 const int step = bit_value ? +1 : -1;
2701 current_excess += step;
2702 if (current_excess < min_value) {
2703 min_value = current_excess;
2704 min_count = 1;
2705 } else if (current_excess == min_value) {
2706 ++min_count;
2707 }
2708 if (current_excess > max_value) {
2709 max_value = current_excess;
2710 }
2711
2712 previous_bit = bit_value;
2713 ++position;
2714 }
2715
2716 if (segment_begin < segment_end) {
2717 node_last_bit[leaf_node_index] = previous_bit;
2718 }
2719
2720 node_total_excess[leaf_node_index] = current_excess;
2721 node_min_prefix_excess[leaf_node_index] =
2722 (segment_size_bits[leaf_node_index] == 0 ? 0 : min_value);
2723 node_max_prefix_excess[leaf_node_index] =
2724 (segment_size_bits[leaf_node_index] == 0 ? 0 : max_value);
2725 node_min_count[leaf_node_index] = min_count;
2726 node_pattern10_count[leaf_node_index] = (uint32_t)pattern10_count;
2727 }
2728 // internal nodes
2729 for (size_t node_index = first_leaf_index - 1; node_index >= 1;
2730 --node_index) {
2731 const size_t left_child = node_index << 1;
2732 const size_t right_child = left_child | 1;
2733 const bool has_left =
2734 (left_child <= tree_size) && segment_size_bits[left_child];
2735 const bool has_right =
2736 (right_child <= tree_size) && segment_size_bits[right_child];
2737 if (!has_left && !has_right) {
2738 segment_size_bits[node_index] = 0;
2739 continue;
2740 }
2741 if (has_left && !has_right) {
2742 segment_size_bits[node_index] = segment_size_bits[left_child];
2743 node_total_excess[node_index] = node_total_excess[left_child];
2744 node_min_prefix_excess[node_index] = node_min_prefix_excess[left_child];
2745 node_max_prefix_excess[node_index] = node_max_prefix_excess[left_child];
2746 node_min_count[node_index] = node_min_count[left_child];
2747 node_pattern10_count[node_index] = node_pattern10_count[left_child];
2748 node_first_bit[node_index] = node_first_bit[left_child];
2749 node_last_bit[node_index] = node_last_bit[left_child];
2750 } else if (!has_left && has_right) {
2751 segment_size_bits[node_index] = segment_size_bits[right_child];
2752 node_total_excess[node_index] = node_total_excess[right_child];
2753 node_min_prefix_excess[node_index] =
2754 node_min_prefix_excess[right_child];
2755 node_max_prefix_excess[node_index] =
2756 node_max_prefix_excess[right_child];
2757 node_min_count[node_index] = node_min_count[right_child];
2758 node_pattern10_count[node_index] = node_pattern10_count[right_child];
2759 node_first_bit[node_index] = node_first_bit[right_child];
2760 node_last_bit[node_index] = node_last_bit[right_child];
2761 } else {
2762 segment_size_bits[node_index] =
2763 segment_size_bits[left_child] + segment_size_bits[right_child];
2764 node_total_excess[node_index] =
2765 node_total_excess[left_child] + node_total_excess[right_child];
2766 const int right_min_candidate =
2767 node_total_excess[left_child] + node_min_prefix_excess[right_child];
2768 const int right_max_candidate =
2769 node_total_excess[left_child] + node_max_prefix_excess[right_child];
2770 node_min_prefix_excess[node_index] =
2771 std::min(node_min_prefix_excess[left_child], right_min_candidate);
2772 node_max_prefix_excess[node_index] =
2773 std::max(node_max_prefix_excess[left_child], right_max_candidate);
2774 node_min_count[node_index] =
2775 (node_min_prefix_excess[left_child] ==
2776 node_min_prefix_excess[node_index]
2777 ? node_min_count[left_child]
2778 : 0) +
2779 (right_min_candidate == node_min_prefix_excess[node_index]
2780 ? node_min_count[right_child]
2781 : 0);
2782 node_pattern10_count[node_index] = node_pattern10_count[left_child] +
2783 node_pattern10_count[right_child] +
2784 ((node_last_bit[left_child] == 1 &&
2785 node_first_bit[right_child] == 0)
2786 ? 1u
2787 : 0u);
2788 node_first_bit[node_index] = node_first_bit[left_child];
2789 node_last_bit[node_index] = node_last_bit[right_child];
2790 }
2791 if (node_index == 1) {
2792 break;
2793 }
2794 }
2795 }
2796};
2797
2798} // namespace pixie
CRTP facade for rank/select and range min-max tree operations.
Definition rmm_base.h:18
int range_max_query_val_impl(const size_t &range_begin, const size_t &range_end) const
Value of the maximum prefix excess on [range_begin, range_end] relative to range_begin.
Definition rmm_tree.h:673
size_t fwdsearch_impl(const size_t &start_position, const int &delta) const
Forward search: first position p ≥ start_position where excess_impl(p) = excess_impl(start_position) + delta.
Definition rmm_tree.h:324
size_t minselect_impl(const size_t &range_begin, const size_t &range_end, size_t target_min_rank) const
Position of the target_min_rank-th (1-based) occurrence of the minimum on [range_begin, range_end].
Definition rmm_tree.h:773
size_t rank0_impl(const size_t &end_position) const
Number of zeros in prefix [0, end_position).
Definition rmm_tree.h:152
int bit(const size_t &position) const noexcept
Read bit at position position (LSB-first across words).
Definition rmm_tree.h:976
size_t close_impl(const size_t &open_position) const
close_impl(open_position): matching ')' for '(' at open_position.
Definition rmm_tree.h:933
int excess_impl(const size_t &end_position) const
Prefix excess on [0, end_position): +1 for '1', −1 for '0'.
Definition rmm_tree.h:314
size_t bwdsearch_impl(const size_t &start_position, const int &delta) const
Backward search: last position p ≤ start_position where excess_impl(p) = excess_impl(start_position) + delta.
Definition rmm_tree.h:388
RmMTree()=default
Construct empty structure.
size_t rank1_impl(const size_t &end_position) const
Number of ones in prefix [0, end_position).
Definition rmm_tree.h:127
RmMTree(std::span< const std::uint64_t > words, size_t bit_count, const size_t &leaf_block_bits=0, const float &max_overhead=-1.0)
Build from a non-owning view of 64-bit words (LSB-first).
Definition rmm_tree.h:112
size_t enclose_impl(const size_t &position) const
enclose_impl(position): opening '(' that strictly encloses position.
Definition rmm_tree.h:963
int range_min_query_val_impl(const size_t &range_begin, const size_t &range_end) const
Value of the minimum prefix excess on [range_begin, range_end] relative to range_begin.
Definition rmm_tree.h:555
size_t open_impl(const size_t &close_position) const
open_impl(close_position): matching '(' for ')' at close_position.
Definition rmm_tree.h:948
size_t rank10_impl(const size_t &end_position) const
Rank of the pattern "10" (starts) within [0, end_position).
Definition rmm_tree.h:226
size_t mincount_impl(const size_t &range_begin, const size_t &range_end) const
How many times the minimum prefix excess occurs on [range_begin, range_end].
Definition rmm_tree.h:689
size_t range_min_query_pos_impl(const size_t &range_begin, const size_t &range_end) const
Position of the first minimum of excess on [range_begin, range_end] (inclusive).
Definition rmm_tree.h:453
size_t select0_impl(size_t target_zero_rank) const
1-based select of the target_zero_rank-th zero.
Definition rmm_tree.h:191
size_t select10_impl(size_t target_pattern_rank) const
1-based select of the target_pattern_rank-th "10" start.
Definition rmm_tree.h:259
static constexpr size_t npos
Sentinel for "not found".
Definition rmm_tree.h:86
size_t select1_impl(size_t target_one_rank) const
1-based select of the target_one_rank-th one.
Definition rmm_tree.h:160
size_t range_max_query_pos_impl(const size_t &range_begin, const size_t &range_end) const
Position of the first maximum of excess on [range_begin, range_end] (inclusive).
Definition rmm_tree.h:573