defuse_bitmap/
b256.rs

1use defuse_map_utils::{IterableMap, Map};
2use near_sdk::near;
3
/// A 256-bit value represented as 32 raw bytes.
pub type U256 = [u8; 256 / 8];
/// A 248-bit value represented as 31 raw bytes.
pub type U248 = [u8; 248 / 8];
6
7/// 256-bit map.  
8/// See [permit2 nonce schema](https://docs.uniswap.org/contracts/permit2/reference/signature-transfer#nonce-schema)
9#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
10#[near(serializers = [borsh, json])]
11#[derive(Debug, Clone, Default)]
12#[repr(transparent)]
13pub struct BitMap256<T: Map<K = U248, V = U256>>(T);
14
15impl<T> BitMap256<T>
16where
17    T: Map<K = U248, V = U256>,
18{
19    #[inline]
20    pub const fn new(map: T) -> Self {
21        Self(map)
22    }
23
24    /// Get the bit `n`
25    #[inline]
26    pub fn get_bit(&self, n: U256) -> bool {
27        let [word_pos @ .., bit_pos] = n;
28        let Some(bitmap) = self.0.get(&word_pos) else {
29            return false;
30        };
31        let byte = bitmap[usize::from(bit_pos / 8)];
32        let byte_mask = 1 << (bit_pos % 8);
33        byte & byte_mask != 0
34    }
35
36    #[inline]
37    fn get_mut_byte_with_mask(&mut self, n: U256) -> (&mut u8, u8) {
38        let [word_pos @ .., bit_pos] = n;
39        let bitmap = self.0.entry(word_pos).or_default();
40        let byte = &mut bitmap[usize::from(bit_pos / 8)];
41        let byte_mask = 1 << (bit_pos % 8);
42        (byte, byte_mask)
43    }
44
45    #[inline]
46    pub fn cleanup_by_prefix(&mut self, prefix: U248) -> bool {
47        self.0.remove(&prefix).is_some()
48    }
49
50    /// Set the bit `n` and return old value
51    #[inline]
52    pub fn set_bit(&mut self, n: U256) -> bool {
53        let (byte, mask) = self.get_mut_byte_with_mask(n);
54        let old = *byte & mask != 0;
55        *byte |= mask;
56        old
57    }
58
59    /// Clear the bit `n` and return old value
60    #[inline]
61    pub fn clear_bit(&mut self, n: U256) -> bool {
62        let (byte, mask) = self.get_mut_byte_with_mask(n);
63        let old = *byte & mask != 0;
64        *byte &= !mask;
65        old
66    }
67
68    /// Toggle the bit `n` and return old value
69    #[inline]
70    pub fn toggle_bit(&mut self, n: U256) -> bool {
71        let (byte, mask) = self.get_mut_byte_with_mask(n);
72        let old = *byte & mask != 0;
73        *byte ^= mask;
74        old
75    }
76
77    /// Set bit `n` to given value
78    #[inline]
79    pub fn set_bit_to(&mut self, n: U256, v: bool) -> bool {
80        if v {
81            self.set_bit(n)
82        } else {
83            self.clear_bit(n)
84        }
85    }
86
87    /// Iterate over set U256
88    #[inline]
89    pub fn as_iter(&self) -> impl Iterator<Item = U256> + '_
90    where
91        T: IterableMap,
92    {
93        self.0.iter().flat_map(|(prefix, bitmap)| {
94            (0..=u8::MAX)
95                .filter(|&bit_pos| {
96                    let byte = bitmap[usize::from(bit_pos / 8)];
97                    let byte_mask = 1 << (bit_pos % 8);
98                    byte & byte_mask != 0
99                })
100                .map(|bit_pos| {
101                    let mut nonce: U256 = [0; 32];
102                    nonce[..prefix.len()].copy_from_slice(prefix);
103                    nonce[prefix.len()..].copy_from_slice(&[bit_pos]);
104                    nonce
105                })
106        })
107    }
108}
109
#[cfg(test)]
mod tests {
    use std::collections::{HashMap, HashSet};

    use bnum::BUintD8;
    use hex_literal::hex;
    use rstest::rstest;

    use super::*;

    type TestMap = BitMap256<HashMap<U248, U256>>;

    /// Set/clear round-trips on extreme values of a 256-bit integer.
    #[test]
    fn set_get_clear() {
        type N = BUintD8<32>;

        let mut m = TestMap::default();

        for n in [N::ZERO, N::ONE, N::MAX - N::ONE, N::MAX].map(Into::into) {
            assert!(!m.get_bit(n));

            assert!(!m.set_bit(n));
            assert!(m.get_bit(n));
            assert!(m.set_bit(n));
            assert!(m.get_bit(n));

            assert!(m.clear_bit(n));
            assert!(!m.get_bit(n));
            assert!(!m.clear_bit(n));
            assert!(!m.get_bit(n));
        }
    }

    /// `toggle_bit` and `set_bit_to` flip/set the bit and report the old value.
    #[test]
    fn toggle_and_set_bit_to() {
        let mut m = TestMap::default();
        let mut n: U256 = [0; 32];
        n[0] = 0x07;
        n[31] = 0x2a;

        assert!(!m.toggle_bit(n));
        assert!(m.get_bit(n));
        assert!(m.toggle_bit(n));
        assert!(!m.get_bit(n));

        assert!(!m.set_bit_to(n, true));
        assert!(m.get_bit(n));
        assert!(m.set_bit_to(n, true));
        assert!(m.set_bit_to(n, false));
        assert!(!m.get_bit(n));
        assert!(!m.set_bit_to(n, false));
        assert!(!m.get_bit(n));
    }

    /// Bits on either side of a byte boundary within one word do not interfere.
    #[test]
    fn bits_in_same_word_are_independent() {
        let mut m = TestMap::default();
        let mut a: U256 = [0; 32];
        a[31] = 7;
        let mut b = a;
        b[31] = 8;

        assert!(!m.set_bit(a));
        assert!(!m.get_bit(b));
        assert!(!m.set_bit(b));
        assert!(m.clear_bit(a));
        assert!(!m.get_bit(a));
        assert!(m.get_bit(b));
    }

    /// `cleanup_by_prefix` drops every bit of one word and leaves other words alone.
    #[test]
    fn cleanup_by_prefix_clears_whole_word() {
        let mut m = TestMap::default();
        let prefix: U248 = [0xab; 31];
        let mut low: U256 = [0; 32];
        low[..31].copy_from_slice(&prefix);
        let mut high = low;
        high[31] = u8::MAX;
        let other: U256 = [0; 32];

        assert!(!m.cleanup_by_prefix(prefix));

        assert!(!m.set_bit(low));
        assert!(!m.set_bit(high));
        assert!(!m.set_bit(other));

        assert!(m.cleanup_by_prefix(prefix));
        assert!(!m.get_bit(low));
        assert!(!m.get_bit(high));
        assert!(m.get_bit(other));
        assert!(!m.cleanup_by_prefix(prefix));

        assert_eq!(m.as_iter().collect::<Vec<_>>(), vec![other]);
    }

    #[rstest]
    #[case(&[])]
    #[case(&[hex!("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")])]
    #[case(&[hex!("0000000000000000000000000000000000000000000000000000000000000000"), hex!("0000000000000000000000000000000000000000000000000000000000000001")])]
    #[case(&[hex!("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff00"), hex!("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")])]
    #[case(&[hex!("0000000000000000000000000000000000000000000000000000000000000000"), hex!("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")])]
    fn iter(#[case] nonces: &[U256]) {
        let mut m = TestMap::default();
        for n in nonces {
            assert!(!m.set_bit(*n));
        }

        let all: HashSet<_> = m.as_iter().collect();
        assert_eq!(all, nonces.iter().copied().collect());
    }
}