defuse_bitmap/
lib.rs

1mod b256;
2
3pub use self::b256::*;
4
5use core::ops::{DerefMut, RangeInclusive, Shl};
6
7use defuse_map_utils::{IterableMap, Map, cleanup::DefaultMap};
8use num_traits::{One, PrimInt, Zero};
9
/// Sparse bitmap keyed by a primitive integer type.
///
/// The wrapped map `M` stores one entry per word: the key is the word index
/// and the value holds that word's bits.
#[cfg_attr(
    feature = "borsh",
    derive(::borsh::BorshSerialize, ::borsh::BorshDeserialize)
)]
#[cfg_attr(
    all(feature = "borsh", feature = "abi"),
    derive(::borsh::BorshSchema)
)]
#[derive(Debug, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct BitMap<M>(M);
19
20impl<M> BitMap<M>
21where
22    M: DefaultMap<K = <M as Map>::V>,
23    M::K: PrimInt + Shl<M::K, Output = M::K>,
24{
25    #[allow(clippy::as_conversions)]
26    const BITS_FOR_BIT_POS: usize = (size_of::<M::K>() * 8).ilog2() as usize;
27
28    #[inline]
29    pub const fn new(map: M) -> Self {
30        Self(map)
31    }
32
33    /// Get the bit `n`
34    #[inline]
35    pub fn get_bit(&self, n: M::K) -> bool {
36        let (word, bit_mask) = Self::split_word_mask(n);
37        let Some(bitmap) = self.0.get(&word) else {
38            return false;
39        };
40        *bitmap & bit_mask != M::V::zero()
41    }
42
43    /// Set the bit `n` and return old value
44    #[inline]
45    pub fn set_bit(&mut self, n: M::K) -> bool {
46        let (mut bitmap, mask) = self.get_mut_with_mask(n);
47        let old = *bitmap & mask != M::V::zero();
48        *bitmap = *bitmap | mask;
49        old
50    }
51
52    /// Clear the bit `n` and return old value
53    #[inline]
54    pub fn clear_bit(&mut self, n: M::K) -> bool {
55        let (mut bitmap, mask) = self.get_mut_with_mask(n);
56        let old = *bitmap & mask != M::V::zero();
57        *bitmap = *bitmap & !mask;
58        old
59    }
60
61    /// Toggle the bit `n` and return old value
62    #[inline]
63    pub fn toggle_bit(&mut self, n: M::K) -> bool {
64        let (mut bitmap, mask) = self.get_mut_with_mask(n);
65        let old = *bitmap & mask != M::V::zero();
66        *bitmap = *bitmap ^ mask;
67        old
68    }
69
70    /// Set bit `n` to given value and return old value
71    #[inline]
72    pub fn set_bit_to(&mut self, n: M::K, v: bool) -> bool {
73        if v {
74            self.set_bit(n)
75        } else {
76            self.clear_bit(n)
77        }
78    }
79
80    /// Iterate over set bits
81    pub fn as_iter(&self) -> impl Iterator<Item = M::V> + '_
82    where
83        M: IterableMap,
84        RangeInclusive<M::V>: Iterator<Item = M::V>,
85    {
86        self.0.iter().flat_map(|(prefix, bitmap)| {
87            (M::V::zero()..=Self::bit_pos_mask())
88                .filter(|&bit_pos| {
89                    let bit_mask = M::V::one() << bit_pos;
90                    *bitmap & bit_mask != M::V::zero()
91                })
92                .map(|bit_pos| (*prefix << Self::BITS_FOR_BIT_POS) | bit_pos)
93        })
94    }
95
96    #[inline]
97    fn get_mut_with_mask(&mut self, n: M::K) -> (impl DerefMut<Target = M::V>, M::V) {
98        let (word, bit_mask) = Self::split_word_mask(n);
99        (self.0.entry_or_default(word), bit_mask)
100    }
101
102    /// Returns `(word, bit_pos_mask)`
103    #[inline]
104    fn split_word_mask(n: M::K) -> (M::K, M::V) {
105        let word = n >> Self::BITS_FOR_BIT_POS;
106        let bit_mask = M::V::one() << (n & Self::bit_pos_mask());
107        (word, bit_mask)
108    }
109
110    #[inline]
111    fn bit_pos_mask() -> M::V {
112        (M::V::one() << Self::BITS_FOR_BIT_POS) - M::V::one()
113    }
114}
115
#[cfg(test)]
mod tests {
    use std::{collections::BTreeMap, fmt::Debug};

    use rstest::rstest;

    use super::*;

    /// Keys at both ends of `T`'s range. These are where the word/bit
    /// splitting is most likely to go wrong.
    fn edge_keys<T: PrimInt>() -> [T; 4] {
        [
            T::zero(),
            T::one(),
            T::max_value() - T::one(),
            T::max_value(),
        ]
    }

    // `_n` only selects the integer type under test; its value is never read.
    #[allow(clippy::used_underscore_binding)]
    #[rstest]
    fn test<T>(#[values(0u8, 0u16, 0u32, 0u64, 0u128)] _n: T)
    where
        T: PrimInt + Shl<T, Output = T> + Default,
    {
        let mut m = BitMap::<BTreeMap<T, T>>::default();

        for n in edge_keys::<T>() {
            assert!(!m.get_bit(n));

            assert!(!m.set_bit(n));
            assert!(m.get_bit(n));
            assert!(m.set_bit(n));
            assert!(m.get_bit(n));

            assert!(m.clear_bit(n));
            assert!(!m.get_bit(n));
            assert!(!m.clear_bit(n));
            assert!(!m.get_bit(n));
        }
    }

    // `_n` only selects the integer type under test; its value is never read.
    #[allow(clippy::used_underscore_binding)]
    #[rstest]
    fn toggle_and_set_to<T>(#[values(0u8, 0u16, 0u32, 0u64, 0u128)] _n: T)
    where
        T: PrimInt + Shl<T, Output = T> + Default,
    {
        let mut m = BitMap::<BTreeMap<T, T>>::default();

        for n in edge_keys::<T>() {
            assert!(!m.toggle_bit(n));
            assert!(m.get_bit(n));
            assert!(m.toggle_bit(n));
            assert!(!m.get_bit(n));

            assert!(!m.set_bit_to(n, true));
            assert!(m.get_bit(n));
            assert!(m.set_bit_to(n, true));
            assert!(m.set_bit_to(n, false));
            assert!(!m.get_bit(n));
            assert!(!m.set_bit_to(n, false));
        }
    }

    // `_n` only selects the integer type under test; its value is never read.
    #[allow(clippy::used_underscore_binding)]
    #[rstest]
    fn neighbours_are_independent<T>(#[values(0u8, 0u16, 0u32, 0u64, 0u128)] _n: T)
    where
        T: PrimInt + Shl<T, Output = T> + Default,
    {
        let mut m = BitMap::<BTreeMap<T, T>>::default();
        let two = T::one() + T::one();

        // Bits 0 and 2 share a word with bit 1 for every primitive width.
        assert!(!m.set_bit(T::zero()));
        assert!(!m.set_bit(two));
        assert!(!m.get_bit(T::one()));

        assert!(m.clear_bit(T::zero()));
        assert!(m.get_bit(two));
        assert!(!m.get_bit(T::zero()));
    }

    #[rstest]
    fn as_iter<T>(
        #[values(
            Vec::<u8>::new(),
            vec![0u8],
            vec![3u16, 0, 2, 7, u16::MAX],
            vec![1000u32, 15, 23, 717, 999, u32::MAX],
        )]
        mut ns: Vec<T>,
    ) where
        RangeInclusive<T>: Iterator<Item = T>,
        T: PrimInt + Shl<T, Output = T> + Debug + Default,
    {
        let mut m = BitMap::<BTreeMap<T, T>>::default();
        for n in &ns {
            assert!(!m.set_bit(*n));
        }
        ns.sort();
        assert_eq!(m.as_iter().collect::<Vec<_>>(), ns);
    }
}