// defuse_core/nonce/salted.rs

1use core::mem;
2use hex::FromHex;
3use near_sdk::{
4    IntoStorageKey,
5    borsh::{BorshDeserialize, BorshSerialize},
6    env::{self, sha256_array},
7    near,
8    store::{IterableMap, key::Identity},
9};
10use serde_with::{DeserializeFromStr, SerializeDisplay};
11use std::{
12    fmt::{self, Debug},
13    str::FromStr,
14};
15
16use crate::{DefuseError, Result};
17
/// A 4-byte salt.
///
/// Borsh-serialized as its raw bytes. Its text form (`Display`/`FromStr`, and
/// therefore serde via `SerializeDisplay`/`DeserializeFromStr`) is a lowercase
/// hex string. `Debug` is implemented by hand below and prints the same hex.
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[derive(PartialEq, PartialOrd, Ord, Eq, Copy, Clone, SerializeDisplay, DeserializeFromStr)]
#[near(serializers = [borsh])]
pub struct Salt([u8; 4]);
22
23impl Salt {
24    pub fn derive(num: u8) -> Self {
25        const SIZE: usize = size_of::<Salt>();
26
27        let seed = env::random_seed_array();
28        let mut input = [0u8; 33];
29        input[..32].copy_from_slice(&seed);
30        input[32] = num;
31
32        Self(
33            sha256_array(input)[..SIZE]
34                .try_into()
35                .unwrap_or_else(|_| unreachable!()),
36        )
37    }
38}
39
40impl fmt::Debug for Salt {
41    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
42        write!(f, "{}", hex::encode(self.0))
43    }
44}
45
46impl fmt::Display for Salt {
47    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
48        Debug::fmt(self, f)
49    }
50}
51
52impl FromStr for Salt {
53    type Err = hex::FromHexError;
54
55    fn from_str(s: &str) -> Result<Self, Self::Err> {
56        FromHex::from_hex(s).map(Self)
57    }
58}
59
#[cfg(all(feature = "abi", not(target_arch = "wasm32")))]
const _: () = {
    use near_sdk::{
        schemars::{
            JsonSchema,
            r#gen::SchemaGenerator,
            schema::{InstanceType, Schema, SchemaObject},
        },
        serde_json,
    };

    // ABI schema for `Salt`: a JSON string annotated with
    // `contentEncoding: "hex"`, matching its `Display`/`FromStr` text form.
    impl JsonSchema for Salt {
        fn schema_name() -> String {
            // The JSON shape is a plain string, so reuse `String`'s schema name.
            String::schema_name()
        }

        fn is_referenceable() -> bool {
            // Inline the schema at each use site rather than emitting a `$ref`.
            false
        }

        fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
            let mut schema = SchemaObject {
                instance_type: Some(InstanceType::String.into()),
                ..Default::default()
            };
            schema.extensions.insert(
                "contentEncoding".to_string(),
                serde_json::Value::from("hex"),
            );
            schema.into()
        }
    }
};
93
/// Tracks the current salt and every salt that was previously current.
///
/// The `current` salt is always valid. Each previous salt carries a flag:
/// `true` while it is still accepted and `false` once it has been invalidated.
/// Previous salts are never removed, so a salt is never issued twice.
// NOTE: Borsh serializes fields in declaration order, so reordering the
// fields below changes the stored layout.
#[near(serializers = [borsh])]
#[derive(Debug)]
pub struct SaltRegistry {
    /// Salts rotated out of `current`, mapped to their validity flag.
    /// Uses the `Identity` key hasher, so the salt bytes form the storage key directly.
    previous: IterableMap<Salt, bool, Identity>,
    /// The active salt.
    current: Salt,
}
102
103impl SaltRegistry {
104    /// There can be only one valid salt at the beginning
105    #[inline]
106    pub fn new<S>(prefix: S) -> Self
107    where
108        S: IntoStorageKey,
109    {
110        Self {
111            previous: IterableMap::with_hasher(prefix),
112            current: Salt::derive(0),
113        }
114    }
115
116    fn derive_next_salt(&self) -> Result<Salt> {
117        (0..=u8::MAX)
118            .map(Salt::derive)
119            .find(|s| !self.is_used(*s))
120            .ok_or(DefuseError::SaltGenerationFailed)
121    }
122
123    /// Rotates the current salt, making it previous and keeping it valid.
124    #[inline]
125    pub fn set_new(&mut self) -> Result<Salt> {
126        let salt = self.derive_next_salt()?;
127
128        let previous = mem::replace(&mut self.current, salt);
129        self.previous.insert(previous, true);
130
131        Ok(previous)
132    }
133
134    /// Deactivates the previous salt, making it invalid.
135    #[inline]
136    pub fn invalidate(&mut self, salt: Salt) -> Result<()> {
137        if salt == self.current {
138            self.set_new()?;
139        }
140
141        self.previous
142            .get_mut(&salt)
143            .map(|v| *v = false)
144            .ok_or(DefuseError::InvalidSalt)
145    }
146
147    #[inline]
148    pub fn is_valid(&self, salt: Salt) -> bool {
149        salt == self.current || self.previous.get(&salt).is_some_and(|v| *v)
150    }
151
152    #[inline]
153    fn is_used(&self, salt: Salt) -> bool {
154        salt == self.current || self.previous.contains_key(&salt)
155    }
156
157    #[inline]
158    pub const fn current(&self) -> Salt {
159        self.current
160    }
161}
162
163#[derive(Clone, Debug, PartialEq, Eq, BorshSerialize, BorshDeserialize)]
164#[borsh(crate = "::near_sdk::borsh")]
165pub struct SaltedNonce<T>
166where
167    T: BorshSerialize + BorshDeserialize,
168{
169    pub salt: Salt,
170    pub nonce: T,
171}
172
173impl<T> SaltedNonce<T>
174where
175    T: BorshSerialize + BorshDeserialize,
176{
177    pub const fn new(salt: Salt, nonce: T) -> Self {
178        Self { salt, nonce }
179    }
180}
181
// Unit tests for `Salt` derivation and `SaltRegistry` rotation/invalidation.
#[cfg(test)]
mod tests {
    use super::*;

    use arbitrary::Unstructured;
    use defuse_test_utils::random::{Rng, random_bytes, rng};
    use near_sdk::{test_utils::VMContextBuilder, testing_env};

    use rstest::rstest;

    // Test-only conversion: takes the first 4 bytes of `value`.
    // Panics if `value` has fewer than 4 bytes.
    impl From<&[u8]> for Salt {
        fn from(value: &[u8]) -> Self {
            let mut result = [0u8; 4];
            result.copy_from_slice(&value[..4]);
            Self(result)
        }
    }

    // Independent re-computation of `Salt::derive`: the first 4 bytes of
    // sha256(seed || attempts). A `u8` is one byte, so `to_be_bytes` is just `[attempts]`.
    fn seed_to_salt(seed: &[u8; 32], attempts: u8) -> Salt {
        let seed = [seed, attempts.to_be_bytes().as_ref()].concat();
        let hash = sha256_array(&seed);

        hash[..4].into()
    }

    // Installs a new mocked VM context whose random seed is drawn from `rng`,
    // and returns that seed so expected salts can be computed with `seed_to_salt`.
    fn set_random_seed(rng: &mut impl Rng) -> [u8; 32] {
        let seed = rng.random();
        let context = VMContextBuilder::new().random_seed(seed).build();
        testing_env!(context);

        seed
    }

    // A fresh registry accepts its current salt and rejects an arbitrary one.
    // NOTE(review): assumes `random_salt` != current salt; a ~1-in-2^32 collision
    // would make this fail spuriously.
    #[rstest]
    fn contains_salt_test(random_bytes: Vec<u8>) {
        let random_salt: Salt = Unstructured::new(&random_bytes).arbitrary().unwrap();
        let salts = SaltRegistry::new(random_bytes);

        assert!(salts.is_valid(salts.current));
        assert!(!salts.is_valid(random_salt));
    }

    // Rotation keeps the outgoing salt valid. The registry is built before the
    // seed is replaced, so the first `set_new` under the new seed picks `num = 0`,
    // and the second picks `num = 1` because `num = 0` is now current.
    #[rstest]
    fn update_current_salt_test(random_bytes: Vec<u8>, mut rng: impl Rng) {
        let mut salts = SaltRegistry::new(random_bytes);

        let seed = set_random_seed(&mut rng);
        let previous_salt = salts.set_new().expect("should set new salt");

        assert!(salts.is_valid(seed_to_salt(&seed, 0)));
        assert!(salts.is_valid(previous_salt));

        let previous_salt = salts.set_new().expect("should set new salt");
        assert!(salts.is_valid(seed_to_salt(&seed, 1)));
        assert!(salts.is_valid(previous_salt));
    }

    // Covers three cases:
    // - invalidating a previous salt,
    // - rejecting an unknown salt with `InvalidSalt`,
    // - invalidating the current salt, which first rotates to a salt derived
    //   from the newly installed seed.
    #[rstest]
    fn reset_salt_test(random_bytes: Vec<u8>, mut rng: impl Rng) {
        let mut salts = SaltRegistry::new(random_bytes);
        let random_salt = rng.random::<[u8; 4]>().as_slice().into();

        let seed = set_random_seed(&mut rng);
        // Expected salt after the `set_new` below.
        let current = seed_to_salt(&seed, 0);
        let previous_salt = salts.set_new().expect("should set new salt");

        assert!(salts.invalidate(previous_salt).is_ok());
        assert!(!salts.is_valid(previous_salt));
        assert!(matches!(
            salts.invalidate(random_salt).unwrap_err(),
            DefuseError::InvalidSalt
        ));

        // Swap the seed so the rotation triggered by invalidating `current`
        // derives a predictable replacement.
        let seed = set_random_seed(&mut rng);
        let new_salt = seed_to_salt(&seed, 0);

        assert!(salts.invalidate(current).is_ok());
        assert!(!salts.is_valid(current));
        assert_eq!(salts.current(), new_salt);
    }

    // An invalidated salt still counts as used, so later rotations never re-issue it.
    #[rstest]
    fn derive_next_test(random_bytes: Vec<u8>) {
        let mut salt_registry = SaltRegistry::new(random_bytes);

        let prev = salt_registry.set_new().unwrap();

        salt_registry.invalidate(prev).unwrap();
        salt_registry.set_new().unwrap();

        assert!(!salt_registry.is_valid(prev));
        assert!(salt_registry.is_used(prev));
    }
}
275}