defuse_bitmap/
b256.rs

1use defuse_map_utils::{IterableMap, Map};
2
/// A 256-bit value, stored as `256 / 8 = 32` bytes.
pub type U256 = [u8; 256 / 8];
/// A 248-bit value, stored as `248 / 8 = 31` bytes: the first 31 bytes of a [`U256`].
pub type U248 = [u8; 248 / 8];
5
6/// 256-bit map.
7/// See [permit2 nonce schema](https://docs.uniswap.org/contracts/permit2/reference/signature-transfer#nonce-schema)
8#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
9#[cfg_attr(
10    feature = "borsh",
11    derive(::borsh::BorshSerialize, ::borsh::BorshDeserialize),
12    cfg_attr(feature = "abi", derive(::borsh::BorshSchema))
13)]
14#[derive(Debug, Clone, Default)]
15#[repr(transparent)]
16pub struct BitMap256<T: Map<K = U248, V = U256>>(T);
17
18impl<T> BitMap256<T>
19where
20    T: Map<K = U248, V = U256>,
21{
22    #[inline]
23    pub const fn new(map: T) -> Self {
24        Self(map)
25    }
26
27    /// Get the bit `n`
28    #[inline]
29    pub fn get_bit(&self, n: U256) -> bool {
30        let [word_pos @ .., bit_pos] = n;
31        let Some(bitmap) = self.0.get(&word_pos) else {
32            return false;
33        };
34        let byte = bitmap[usize::from(bit_pos / 8)];
35        let byte_mask = 1 << (bit_pos % 8);
36        byte & byte_mask != 0
37    }
38
39    #[inline]
40    fn get_mut_byte_with_mask(&mut self, n: U256) -> (&mut u8, u8) {
41        let [word_pos @ .., bit_pos] = n;
42        let bitmap = self.0.entry(word_pos).or_default();
43        let byte = &mut bitmap[usize::from(bit_pos / 8)];
44        let byte_mask = 1 << (bit_pos % 8);
45        (byte, byte_mask)
46    }
47
48    #[inline]
49    pub fn cleanup_by_prefix(&mut self, prefix: U248) -> bool {
50        self.0.remove(&prefix).is_some()
51    }
52
53    /// Set the bit `n` and return old value
54    #[inline]
55    pub fn set_bit(&mut self, n: U256) -> bool {
56        let (byte, mask) = self.get_mut_byte_with_mask(n);
57        let old = *byte & mask != 0;
58        *byte |= mask;
59        old
60    }
61
62    /// Clear the bit `n` and return old value
63    #[inline]
64    pub fn clear_bit(&mut self, n: U256) -> bool {
65        let (byte, mask) = self.get_mut_byte_with_mask(n);
66        let old = *byte & mask != 0;
67        *byte &= !mask;
68        old
69    }
70
71    /// Toggle the bit `n` and return old value
72    #[inline]
73    pub fn toggle_bit(&mut self, n: U256) -> bool {
74        let (byte, mask) = self.get_mut_byte_with_mask(n);
75        let old = *byte & mask != 0;
76        *byte ^= mask;
77        old
78    }
79
80    /// Set bit `n` to given value
81    #[inline]
82    pub fn set_bit_to(&mut self, n: U256, v: bool) -> bool {
83        if v {
84            self.set_bit(n)
85        } else {
86            self.clear_bit(n)
87        }
88    }
89
90    /// Iterate over set U256
91    #[inline]
92    pub fn as_iter(&self) -> impl Iterator<Item = U256> + '_
93    where
94        T: IterableMap,
95    {
96        self.0.iter().flat_map(|(prefix, bitmap)| {
97            (0..=u8::MAX)
98                .filter(|&bit_pos| {
99                    let byte = bitmap[usize::from(bit_pos / 8)];
100                    let byte_mask = 1 << (bit_pos % 8);
101                    byte & byte_mask != 0
102                })
103                .map(|bit_pos| {
104                    let mut nonce: U256 = [0; 32];
105                    nonce[..prefix.len()].copy_from_slice(prefix);
106                    nonce[prefix.len()..].copy_from_slice(&[bit_pos]);
107                    nonce
108                })
109        })
110    }
111}
112
#[cfg(test)]
mod tests {
    use std::collections::{HashMap, HashSet};

    use bnum::BUintD8;
    use hex_literal::hex;
    use rstest::rstest;

    use super::*;

    /// Build `n` from a 31-byte word key and a bit index (its last byte).
    fn nonce(prefix: U248, bit_pos: u8) -> U256 {
        let mut n: U256 = [0; 32];
        n[..31].copy_from_slice(&prefix);
        n[31] = bit_pos;
        n
    }

    #[test]
    fn test() {
        type N = BUintD8<32>;

        let mut m = BitMap256::<HashMap<U248, U256>>::default();

        for n in [N::ZERO, N::ONE, N::MAX - N::ONE, N::MAX].map(Into::into) {
            assert!(!m.get_bit(n));

            assert!(!m.set_bit(n));
            assert!(m.get_bit(n));
            assert!(m.set_bit(n));
            assert!(m.get_bit(n));

            assert!(m.clear_bit(n));
            assert!(!m.get_bit(n));
            assert!(!m.clear_bit(n));
            assert!(!m.get_bit(n));
        }
    }

    #[test]
    fn toggle_and_set_bit_to() {
        let mut m = BitMap256::<HashMap<U248, U256>>::default();

        for n in [nonce([0; 31], 0), nonce([0; 31], 255), nonce([0xff; 31], 128)] {
            assert!(!m.toggle_bit(n));
            assert!(m.get_bit(n));
            assert!(m.toggle_bit(n));
            assert!(!m.get_bit(n));

            assert!(!m.set_bit_to(n, true));
            assert!(m.get_bit(n));
            assert!(m.set_bit_to(n, true));
            assert!(m.set_bit_to(n, false));
            assert!(!m.get_bit(n));
            assert!(!m.set_bit_to(n, false));
        }
    }

    #[test]
    fn bits_are_independent() {
        let mut m = BitMap256::<HashMap<U248, U256>>::default();
        let prefix = [7; 31];

        assert!(!m.set_bit(nonce(prefix, 9)));
        for bit_pos in (0..=u8::MAX).filter(|&b| b != 9) {
            assert!(!m.get_bit(nonce(prefix, bit_pos)));
        }
        // Same bit index under a different prefix lives in another word.
        assert!(!m.get_bit(nonce([6; 31], 9)));
    }

    #[test]
    fn cleanup() {
        let mut m = BitMap256::<HashMap<U248, U256>>::default();
        let prefix = [1; 31];
        let other = nonce([2; 31], 5);

        assert!(!m.cleanup_by_prefix(prefix));

        assert!(!m.set_bit(nonce(prefix, 0)));
        assert!(!m.set_bit(nonce(prefix, 200)));
        assert!(!m.set_bit(other));

        assert!(m.cleanup_by_prefix(prefix));
        assert!(!m.get_bit(nonce(prefix, 0)));
        assert!(!m.get_bit(nonce(prefix, 200)));
        assert!(m.get_bit(other));

        assert!(!m.cleanup_by_prefix(prefix));
    }

    #[rstest]
    #[case(&[])]
    #[case(&[hex!("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")])]
    #[case(&[hex!("0000000000000000000000000000000000000000000000000000000000000000"), hex!("0000000000000000000000000000000000000000000000000000000000000001")])]
    #[case(&[hex!("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff00"), hex!("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")])]
    #[case(&[hex!("0000000000000000000000000000000000000000000000000000000000000000"), hex!("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")])]
    fn iter(#[case] nonces: &[U256]) {
        let mut m = BitMap256::<HashMap<U248, U256>>::default();
        for n in nonces {
            assert!(!m.set_bit(*n));
        }

        let all: HashSet<_> = m.as_iter().collect();
        assert_eq!(all, nonces.iter().copied().collect());
    }
}