defuse_bitmap/
lib.rs

1use defuse_map_utils::{IterableMap, Map};
2use near_sdk::near;
3
/// 256-bit value represented as 32 raw bytes.
pub type U256 = [u8; 256 / 8];
/// 248-bit value represented as 31 raw bytes.
pub type U248 = [u8; 248 / 8];
6
7/// 256-bit map.  
8/// See [permit2 nonce schema](https://docs.uniswap.org/contracts/permit2/reference/signature-transfer#nonce-schema)
9#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
10#[near(serializers = [borsh, json])]
11#[derive(Debug, Clone, Default)]
12pub struct BitMap256<T: Map<K = U248, V = U256>>(T);
13
14impl<T> BitMap256<T>
15where
16    T: Map<K = U248, V = U256>,
17{
18    #[inline]
19    pub const fn new(map: T) -> Self {
20        Self(map)
21    }
22
23    /// Get the bit `n`
24    #[inline]
25    pub fn get_bit(&self, n: U256) -> bool {
26        let [word_pos @ .., bit_pos] = n;
27        let Some(bitmap) = self.0.get(&word_pos) else {
28            return false;
29        };
30        let byte = bitmap[usize::from(bit_pos / 8)];
31        let byte_mask = 1 << (bit_pos % 8);
32        byte & byte_mask != 0
33    }
34
35    #[inline]
36    fn get_mut_byte_with_mask(&mut self, n: U256) -> (&mut u8, u8) {
37        let [word_pos @ .., bit_pos] = n;
38        let bitmap = self.0.entry(word_pos).or_default();
39        let byte = &mut bitmap[usize::from(bit_pos / 8)];
40        let byte_mask = 1 << (bit_pos % 8);
41        (byte, byte_mask)
42    }
43
44    #[inline]
45    pub fn cleanup_by_prefix(&mut self, prefix: U248) -> bool {
46        self.0.remove(&prefix).is_some()
47    }
48
49    /// Set the bit `n` and return old value
50    #[inline]
51    pub fn set_bit(&mut self, n: U256) -> bool {
52        let (byte, mask) = self.get_mut_byte_with_mask(n);
53        let old = *byte & mask != 0;
54        *byte |= mask;
55        old
56    }
57
58    /// Clear the bit `n` and return old value
59    #[inline]
60    pub fn clear_bit(&mut self, n: U256) -> bool {
61        let (byte, mask) = self.get_mut_byte_with_mask(n);
62        let old = *byte & mask != 0;
63        *byte &= !mask;
64        old
65    }
66
67    /// Toggle the bit `n` and return old value
68    #[inline]
69    pub fn toggle_bit(&mut self, n: U256) -> bool {
70        let (byte, mask) = self.get_mut_byte_with_mask(n);
71        let old = *byte & mask != 0;
72        *byte ^= mask;
73        old
74    }
75
76    /// Set bit `n` to given value
77    #[inline]
78    pub fn set_bit_to(&mut self, n: U256, v: bool) -> bool {
79        if v {
80            self.set_bit(n)
81        } else {
82            self.clear_bit(n)
83        }
84    }
85
86    /// Iterate over set U256
87    #[inline]
88    pub fn as_iter(&self) -> impl Iterator<Item = U256> + '_
89    where
90        T: IterableMap,
91    {
92        self.0.iter().flat_map(|(prefix, bitmap)| {
93            (0..=u8::MAX)
94                .filter(|&bit_pos| {
95                    let byte = bitmap[usize::from(bit_pos / 8)];
96                    let byte_mask = 1 << (bit_pos % 8);
97                    byte & byte_mask != 0
98                })
99                .map(|bit_pos| {
100                    let mut nonce: U256 = [0; 32];
101                    nonce[..prefix.len()].copy_from_slice(prefix);
102                    nonce[prefix.len()..].copy_from_slice(&[bit_pos]);
103                    nonce
104                })
105        })
106    }
107}
108
#[cfg(test)]
mod tests {
    use std::collections::{HashMap, HashSet};

    use bnum::BUintD8;
    use hex_literal::hex;
    use rstest::rstest;

    use super::*;

    /// Bitmap backed by an in-memory `HashMap`, shared by the tests below.
    type TestBitMap = BitMap256<HashMap<U248, U256>>;

    /// Exercises get/set/clear on the lowest and highest representable indices.
    #[test]
    fn test() {
        type N = BUintD8<32>;

        let mut bitmap = TestBitMap::default();

        let boundaries: [U256; 4] = [N::ZERO, N::ONE, N::MAX - N::ONE, N::MAX].map(Into::into);
        for nonce in boundaries {
            // Untouched bits read as unset.
            assert!(!bitmap.get_bit(nonce));

            // Setting is idempotent and reports the previous value.
            assert!(!bitmap.set_bit(nonce));
            assert!(bitmap.get_bit(nonce));
            assert!(bitmap.set_bit(nonce));
            assert!(bitmap.get_bit(nonce));

            // Clearing is idempotent and reports the previous value.
            assert!(bitmap.clear_bit(nonce));
            assert!(!bitmap.get_bit(nonce));
            assert!(!bitmap.clear_bit(nonce));
            assert!(!bitmap.get_bit(nonce));
        }
    }

    /// Every bit that was set, and nothing else, must be yielded by `as_iter`.
    #[rstest]
    #[case(&[])]
    #[case(&[hex!("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")])]
    #[case(&[hex!("0000000000000000000000000000000000000000000000000000000000000000"), hex!("0000000000000000000000000000000000000000000000000000000000000001")])]
    #[case(&[hex!("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff00"), hex!("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")])]
    #[case(&[hex!("0000000000000000000000000000000000000000000000000000000000000000"), hex!("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")])]
    fn iter(#[case] nonces: &[U256]) {
        let mut bitmap = TestBitMap::default();
        for &nonce in nonces {
            assert!(!bitmap.set_bit(nonce));
        }

        let expected: HashSet<U256> = nonces.iter().copied().collect();
        let actual: HashSet<U256> = bitmap.as_iter().collect();
        assert_eq!(actual, expected);
    }
}