Pixie
Loading...
Searching...
No Matches
rmm_tree_sdsl.h
1#pragma once
2
3#ifndef SDSL_SUPPORT
4#error "pixie/rmm_tree_sdsl.h requires SDSL_SUPPORT"
5#endif
6
7#include <pixie/rmm_base.h>
8
9#include <algorithm>
10#include <bit>
11#include <cstddef>
12#include <cstdint>
13#include <sdsl/bit_vectors.hpp>
14
15// SDSL keeps the generic excess-search primitives private and exposes only
16// navigation wrappers such as find_close/find_open. The Pixie comparison
17// backend needs direct fwdsearch/bwdsearch, so expose them in this optional
18// adapter instead of benchmarking a naive fallback.
19#define private public
20#include <sdsl/bp_support_sada.hpp>
21#undef private
22
23#include <span>
24#include <utility>
25#include <vector>
26
27namespace pixie {
28
29class SdslRmMTree : public RmMBase<SdslRmMTree> {
30 public:
31 using BpSupport = sdsl::bp_support_sada<>;
32 static constexpr std::size_t npos = RmMBase<SdslRmMTree>::npos;
33
34 SdslRmMTree() = default;
35
36 SdslRmMTree(std::span<const std::uint64_t> words,
37 std::size_t bit_count,
38 std::size_t _ = 0) {
39 size_ = bit_count;
40 const std::size_t valid_words = (bit_count + 63) / 64;
41 for (std::size_t i = 0; i < valid_words; ++i) {
42 std::uint64_t word = words[i];
43 if (i + 1 == valid_words && (bit_count & 63) != 0) {
44 word &= (std::uint64_t{1} << (bit_count & 63)) - 1;
45 }
46 ones_ += std::popcount(word);
47 }
48 zeros_ = size_ - ones_;
49
50 bits_ = sdsl::bit_vector(size_);
51 prefix_excess_.assign(size_ + 1, 0);
52 int current_excess = 0;
53 for (std::size_t i = 0; i < size_; ++i) {
54 const bool bit = (words[i >> 6] >> (i & 63)) & 1u;
55 bits_[i] = bit;
56 current_excess += bit ? 1 : -1;
57 prefix_excess_[i + 1] = current_excess;
58 max_excess_ = std::max(max_excess_, current_excess);
59 }
60 build_excess_bounds();
61 tree_ = BpSupport(&bits_);
62 }
63
64 SdslRmMTree(const SdslRmMTree& other)
65 : size_(other.size_),
66 ones_(other.ones_),
67 zeros_(other.zeros_),
68 max_excess_(other.max_excess_),
69 prefix_excess_(other.prefix_excess_),
70 prefix_min_excess_(other.prefix_min_excess_),
71 prefix_max_excess_(other.prefix_max_excess_),
72 suffix_min_excess_(other.suffix_min_excess_),
73 suffix_max_excess_(other.suffix_max_excess_),
74 bits_(other.bits_) {
75 reset_support();
76 }
77
78 SdslRmMTree& operator=(const SdslRmMTree& other) {
79 if (this == &other) {
80 return *this;
81 }
82 size_ = other.size_;
83 ones_ = other.ones_;
84 zeros_ = other.zeros_;
85 max_excess_ = other.max_excess_;
86 prefix_excess_ = other.prefix_excess_;
87 prefix_min_excess_ = other.prefix_min_excess_;
88 prefix_max_excess_ = other.prefix_max_excess_;
89 suffix_min_excess_ = other.suffix_min_excess_;
90 suffix_max_excess_ = other.suffix_max_excess_;
91 bits_ = other.bits_;
92 reset_support();
93 return *this;
94 }
95
96 SdslRmMTree(SdslRmMTree&& other) noexcept
97 : size_(other.size_),
98 ones_(other.ones_),
99 zeros_(other.zeros_),
100 max_excess_(other.max_excess_),
101 prefix_excess_(std::move(other.prefix_excess_)),
102 prefix_min_excess_(std::move(other.prefix_min_excess_)),
103 prefix_max_excess_(std::move(other.prefix_max_excess_)),
104 suffix_min_excess_(std::move(other.suffix_min_excess_)),
105 suffix_max_excess_(std::move(other.suffix_max_excess_)),
106 bits_(std::move(other.bits_)) {
107 reset_support();
108 }
109
110 SdslRmMTree& operator=(SdslRmMTree&& other) noexcept {
111 if (this == &other) {
112 return *this;
113 }
114 size_ = other.size_;
115 ones_ = other.ones_;
116 zeros_ = other.zeros_;
117 max_excess_ = other.max_excess_;
118 prefix_excess_ = std::move(other.prefix_excess_);
119 prefix_min_excess_ = std::move(other.prefix_min_excess_);
120 prefix_max_excess_ = std::move(other.prefix_max_excess_);
121 suffix_min_excess_ = std::move(other.suffix_min_excess_);
122 suffix_max_excess_ = std::move(other.suffix_max_excess_);
123 bits_ = std::move(other.bits_);
124 reset_support();
125 return *this;
126 }
127
128 std::size_t size_impl() const { return size_; }
129
130 std::size_t rank1_impl(std::size_t end_position) const {
131 if (size_ == 0 || end_position == 0) {
132 return 0;
133 }
134 return tree_.rank(std::min(end_position, size_) - 1);
135 }
136
137 std::size_t rank0_impl(std::size_t end_position) const {
138 if (size_ == 0 || end_position == 0) {
139 return 0;
140 }
141 if (end_position >= size_) {
142 return zeros_;
143 }
144 return tree_.preceding_closing_parentheses(end_position);
145 }
146
147 std::size_t select1_impl(std::size_t rank) const {
148 if (rank == 0 || rank > ones_) {
149 return npos;
150 }
151 return tree_.select(rank);
152 }
153
154 std::size_t select0_impl(std::size_t) const { return npos; }
155 std::size_t rank10_impl(std::size_t) const { return 0; }
156 std::size_t select10_impl(std::size_t) const { return npos; }
157
158 int excess_impl(std::size_t end_position) const {
159 if (size_ == 0 || end_position == 0) {
160 return 0;
161 }
162 return tree_.excess(std::min(end_position, size_) - 1);
163 }
164
165 std::size_t fwdsearch_impl(std::size_t start_position, int delta) const {
166 if (start_position >= size_) {
167 return npos;
168 }
169 const int target = prefix_excess_[start_position] + delta;
170 if (target > max_excess_) {
171 return npos;
172 }
173
174 if (start_position == 0) {
175 const int first_excess = bits_[0] ? 1 : -1;
176 if (first_excess == delta) {
177 return 0;
178 }
179 if (!suffix_contains(2, target)) {
180 return npos;
181 }
182 const std::size_t position =
183 tree_.fwd_excess(0, static_cast<typename BpSupport::difference_type>(
184 delta - first_excess));
185 return position < size_ ? position : npos;
186 }
187
188 if (!suffix_contains(start_position + 1, target)) {
189 return npos;
190 }
191 const std::size_t position = tree_.fwd_excess(
192 start_position - 1,
193 static_cast<typename BpSupport::difference_type>(delta));
194 return position < size_ ? position : npos;
195 }
196
197 std::size_t bwdsearch_impl(std::size_t start_position, int delta) const {
198 if (start_position == 0 || start_position > size_) {
199 return npos;
200 }
201
202 const std::size_t anchor = start_position - 1;
203 const int target = prefix_excess_[start_position] + delta;
204 if (target > max_excess_) {
205 return npos;
206 }
207 if (prefix_excess_[anchor] == target) {
208 return anchor;
209 }
210 if (anchor == 0) {
211 return npos;
212 }
213 if (!prefix_contains(anchor - 1, target)) {
214 return npos;
215 }
216
217 const std::size_t position = tree_.bwd_excess(
218 anchor, static_cast<typename BpSupport::difference_type>(delta));
219 if (position == static_cast<std::size_t>(-1)) {
220 return target == 0 ? 0 : npos;
221 }
222 return position < size_ ? position + 1 : npos;
223 }
224
225 std::size_t range_min_query_pos_impl(std::size_t range_begin,
226 std::size_t range_end) const {
227 if (size_ == 0) {
228 return 0;
229 }
230 return tree_.rmq(range_begin, range_end);
231 }
232
233 int range_min_query_val_impl(std::size_t range_begin,
234 std::size_t range_end) const {
235 if (size_ == 0) {
236 return 0;
237 }
238 const auto min_position = tree_.rmq(range_begin, range_end);
239 if (min_position >= size_) {
240 return 0;
241 }
242 const auto base_excess =
243 (range_begin == 0 ? 0 : tree_.excess(range_begin - 1));
244 return tree_.excess(min_position) - base_excess;
245 }
246
247 std::size_t range_max_query_pos_impl(std::size_t, std::size_t) const {
248 return npos;
249 }
250 int range_max_query_val_impl(std::size_t, std::size_t) const { return 0; }
251 std::size_t mincount_impl(std::size_t, std::size_t) const { return 0; }
252 std::size_t minselect_impl(std::size_t, std::size_t, std::size_t) const {
253 return npos;
254 }
255
256 std::size_t close_impl(std::size_t open_position) const {
257 if (size_ == 0) {
258 return 0;
259 }
260 const std::size_t position = tree_.find_close(open_position);
261 return position < size_ ? position : npos;
262 }
263
264 std::size_t open_impl(std::size_t close_position) const {
265 if (size_ == 0) {
266 return 0;
267 }
268 const std::size_t position = tree_.find_open(close_position);
269 return position < size_ ? position : npos;
270 }
271
272 std::size_t enclose_impl(std::size_t open_position) const {
273 if (size_ == 0) {
274 return 0;
275 }
276 const std::size_t position = tree_.enclose(open_position);
277 return position < size_ ? position : npos;
278 }
279
280 int bit_impl(const size_t& position) const noexcept {
281 return (bits_[position >> 6] >> (position & 63)) & 1u;
282 }
283
284 private:
285 void reset_support() { tree_ = BpSupport(&bits_); }
286
287 void build_excess_bounds() {
288 prefix_min_excess_.resize(size_ + 1);
289 prefix_max_excess_.resize(size_ + 1);
290 suffix_min_excess_.resize(size_ + 1);
291 suffix_max_excess_.resize(size_ + 1);
292
293 prefix_min_excess_[0] = prefix_excess_[0];
294 prefix_max_excess_[0] = prefix_excess_[0];
295 for (std::size_t i = 1; i <= size_; ++i) {
296 prefix_min_excess_[i] =
297 std::min(prefix_min_excess_[i - 1], prefix_excess_[i]);
298 prefix_max_excess_[i] =
299 std::max(prefix_max_excess_[i - 1], prefix_excess_[i]);
300 }
301
302 suffix_min_excess_[size_] = prefix_excess_[size_];
303 suffix_max_excess_[size_] = prefix_excess_[size_];
304 for (std::size_t i = size_; i > 0;) {
305 --i;
306 suffix_min_excess_[i] =
307 std::min(prefix_excess_[i], suffix_min_excess_[i + 1]);
308 suffix_max_excess_[i] =
309 std::max(prefix_excess_[i], suffix_max_excess_[i + 1]);
310 }
311 }
312
313 bool suffix_contains(std::size_t boundary_begin, int target) const {
314 return boundary_begin <= size_ &&
315 suffix_min_excess_[boundary_begin] <= target &&
316 target <= suffix_max_excess_[boundary_begin];
317 }
318
319 bool prefix_contains(std::size_t boundary_end, int target) const {
320 return prefix_min_excess_[boundary_end] <= target &&
321 target <= prefix_max_excess_[boundary_end];
322 }
323
324 std::size_t size_{};
325 std::size_t ones_{};
326 std::size_t zeros_{};
327 int max_excess_{};
328 std::vector<int> prefix_excess_;
329 std::vector<int> prefix_min_excess_;
330 std::vector<int> prefix_max_excess_;
331 std::vector<int> suffix_min_excess_;
332 std::vector<int> suffix_max_excess_;
333 sdsl::bit_vector bits_;
334 BpSupport tree_;
335};
336
337} // namespace pixie
CRTP facade for rank/select and range min-max tree operations.
Definition rmm_base.h:18
int bit(const size_t &position) const noexcept
Definition rmm_base.h:199
static constexpr std::size_t npos
Sentinel returned by position queries when no valid answer exists.
Definition rmm_base.h:23