defuse_bitmap/
lib.rs

1mod b256;
2
3pub use self::b256::*;
4
5use core::ops::{DerefMut, RangeInclusive, Shl};
6
7use defuse_map_utils::{IterableMap, Map, cleanup::DefaultMap};
8use near_sdk::near;
9use num_traits::{One, PrimInt, Zero};
10
11/// Bitmap for primitive types
12#[near(serializers = [borsh, json])]
13#[derive(Debug, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
14#[repr(transparent)]
15pub struct BitMap<M>(M);
16
17impl<M> BitMap<M>
18where
19    M: DefaultMap<K = <M as Map>::V>,
20    M::K: PrimInt + Shl<M::K, Output = M::K>,
21{
22    #[allow(clippy::as_conversions)]
23    const BITS_FOR_BIT_POS: usize = (size_of::<M::K>() * 8).ilog2() as usize;
24
25    #[inline]
26    pub const fn new(map: M) -> Self {
27        Self(map)
28    }
29
30    /// Get the bit `n`
31    #[inline]
32    pub fn get_bit(&self, n: M::K) -> bool {
33        let (word, bit_mask) = Self::split_word_mask(n);
34        let Some(bitmap) = self.0.get(&word) else {
35            return false;
36        };
37        *bitmap & bit_mask != M::V::zero()
38    }
39
40    /// Set the bit `n` and return old value
41    #[inline]
42    pub fn set_bit(&mut self, n: M::K) -> bool {
43        let (mut bitmap, mask) = self.get_mut_with_mask(n);
44        let old = *bitmap & mask != M::V::zero();
45        *bitmap = *bitmap | mask;
46        old
47    }
48
49    /// Clear the bit `n` and return old value
50    #[inline]
51    pub fn clear_bit(&mut self, n: M::K) -> bool {
52        let (mut bitmap, mask) = self.get_mut_with_mask(n);
53        let old = *bitmap & mask != M::V::zero();
54        *bitmap = *bitmap & !mask;
55        old
56    }
57
58    /// Toggle the bit `n` and return old value
59    #[inline]
60    pub fn toggle_bit(&mut self, n: M::K) -> bool {
61        let (mut bitmap, mask) = self.get_mut_with_mask(n);
62        let old = *bitmap & mask != M::V::zero();
63        *bitmap = *bitmap ^ mask;
64        old
65    }
66
67    /// Set bit `n` to given value and return old value
68    #[inline]
69    pub fn set_bit_to(&mut self, n: M::K, v: bool) -> bool {
70        if v {
71            self.set_bit(n)
72        } else {
73            self.clear_bit(n)
74        }
75    }
76
77    /// Iterate over set bits
78    pub fn as_iter(&self) -> impl Iterator<Item = M::V> + '_
79    where
80        M: IterableMap,
81        RangeInclusive<M::V>: Iterator<Item = M::V>,
82    {
83        self.0.iter().flat_map(|(prefix, bitmap)| {
84            (M::V::zero()..=Self::bit_pos_mask())
85                .filter(|&bit_pos| {
86                    let bit_mask = M::V::one() << bit_pos;
87                    *bitmap & bit_mask != M::V::zero()
88                })
89                .map(|bit_pos| (*prefix << Self::BITS_FOR_BIT_POS) | bit_pos)
90        })
91    }
92
93    #[inline]
94    fn get_mut_with_mask(&mut self, n: M::K) -> (impl DerefMut<Target = M::V>, M::V) {
95        let (word, bit_mask) = Self::split_word_mask(n);
96        (self.0.entry_or_default(word), bit_mask)
97    }
98
99    /// Returns `(word, bit_pos_mask)`
100    #[inline]
101    fn split_word_mask(n: M::K) -> (M::K, M::V) {
102        let word = n >> Self::BITS_FOR_BIT_POS;
103        let bit_mask = M::V::one() << (n & Self::bit_pos_mask());
104        (word, bit_mask)
105    }
106
107    #[inline]
108    fn bit_pos_mask() -> M::V {
109        (M::V::one() << Self::BITS_FOR_BIT_POS) - M::V::one()
110    }
111}
112
#[cfg(test)]
mod tests {
    use std::{collections::BTreeMap, fmt::Debug};

    use rstest::rstest;

    use super::*;

    /// Exercises every single-bit operation at both ends of each unsigned
    /// primitive's range; `_n` only selects the type `T`.
    #[allow(clippy::used_underscore_binding)]
    #[rstest]
    fn test<T>(#[values(0u8, 0u16, 0u32, 0u64, 0u128)] _n: T)
    where
        T: PrimInt + Shl<T, Output = T> + Default,
    {
        let mut m = BitMap::<BTreeMap<T, T>>::default();

        for n in [
            T::zero(),
            T::one(),
            T::max_value() - T::one(),
            T::max_value(),
        ] {
            assert!(!m.get_bit(n));

            assert!(!m.set_bit(n));
            assert!(m.get_bit(n));
            assert!(m.set_bit(n));
            assert!(m.get_bit(n));

            assert!(m.clear_bit(n));
            assert!(!m.get_bit(n));
            assert!(!m.clear_bit(n));
            assert!(!m.get_bit(n));

            // `toggle_bit` flips the bit and reports the previous state.
            assert!(!m.toggle_bit(n));
            assert!(m.get_bit(n));
            assert!(m.toggle_bit(n));
            assert!(!m.get_bit(n));

            // `set_bit_to` dispatches to `set_bit` / `clear_bit`.
            assert!(!m.set_bit_to(n, true));
            assert!(m.get_bit(n));
            assert!(m.set_bit_to(n, false));
            assert!(!m.get_bit(n));
        }
    }

    /// Clearing a bit must leave other bits intact, both in the same word and
    /// in other words (guards against over-eager word cleanup).
    #[allow(clippy::used_underscore_binding)]
    #[rstest]
    fn bits_are_independent<T>(#[values(0u8, 0u16, 0u32, 0u64, 0u128)] _n: T)
    where
        T: PrimInt + Shl<T, Output = T> + Default,
    {
        let mut m = BitMap::<BTreeMap<T, T>>::default();
        let (low, next, far) = (T::zero(), T::one(), T::max_value());

        assert!(!m.set_bit(low));
        assert!(!m.set_bit(next));
        assert!(!m.set_bit(far));

        // `low` and `next` share a word; clearing one must keep the other.
        assert!(m.clear_bit(low));
        assert!(!m.get_bit(low));
        assert!(m.get_bit(next));
        assert!(m.get_bit(far));

        // Emptying that word must not affect bits stored under another key.
        assert!(m.clear_bit(next));
        assert!(!m.get_bit(next));
        assert!(m.get_bit(far));
    }

    /// `as_iter` yields exactly the set indices, in ascending order for `BTreeMap`.
    #[rstest]
    fn as_iter<T>(
        #[values(
            Vec::<u8>::new(),
            vec![0u8],
            vec![3u16, 0, 2, 7, u16::MAX],
            vec![1000u32, 15, 23, 717, 999, u32::MAX],
            vec![u128::MAX, 0, 127, 128, 64],
        )]
        mut ns: Vec<T>,
    ) where
        RangeInclusive<T>: Iterator<Item = T>,
        T: PrimInt + Shl<T, Output = T> + Debug + Default,
    {
        let mut m = BitMap::<BTreeMap<T, T>>::default();
        for n in &ns {
            assert!(!m.set_bit(*n));
        }
        ns.sort();
        assert_eq!(m.as_iter().collect::<Vec<_>>(), ns);
    }
}