defuse_core/nonce/
salted.rs1use core::mem;
2use hex::FromHex;
3use near_sdk::{
4 IntoStorageKey,
5 borsh::{BorshDeserialize, BorshSerialize},
6 env::{self, sha256_array},
7 near,
8 store::{IterableMap, key::Identity},
9};
10use serde_with::{DeserializeFromStr, SerializeDisplay};
11use std::{
12 fmt::{self, Debug},
13 str::FromStr,
14};
15
16use crate::{DefuseError, Result};
17
18#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
19#[derive(PartialEq, PartialOrd, Ord, Eq, Copy, Clone, SerializeDisplay, DeserializeFromStr)]
20#[near(serializers = [borsh])]
21pub struct Salt([u8; 4]);
22
23impl Salt {
24 pub fn derive(num: u8) -> Self {
25 const SIZE: usize = size_of::<Salt>();
26
27 let seed = env::random_seed_array();
28 let mut input = [0u8; 33];
29 input[..32].copy_from_slice(&seed);
30 input[32] = num;
31
32 Self(
33 sha256_array(input)[..SIZE]
34 .try_into()
35 .unwrap_or_else(|_| unreachable!()),
36 )
37 }
38}
39
40impl fmt::Debug for Salt {
41 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
42 write!(f, "{}", hex::encode(self.0))
43 }
44}
45
46impl fmt::Display for Salt {
47 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
48 Debug::fmt(self, f)
49 }
50}
51
52impl FromStr for Salt {
53 type Err = hex::FromHexError;
54
55 fn from_str(s: &str) -> Result<Self, Self::Err> {
56 FromHex::from_hex(s).map(Self)
57 }
58}
59
#[cfg(all(feature = "abi", not(target_arch = "wasm32")))]
const _: () = {
    use near_sdk::schemars::{
        JsonSchema,
        r#gen::SchemaGenerator,
        schema::{InstanceType, Schema, SchemaObject},
    };

    /// ABI schema for [`Salt`]: an inline (non-referenceable) string
    /// annotated with `"contentEncoding": "hex"`, matching the hex text
    /// produced by its `Display` impl and accepted by its `FromStr` impl.
    impl JsonSchema for Salt {
        fn schema_name() -> String {
            String::schema_name()
        }

        fn is_referenceable() -> bool {
            false
        }

        fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
            let mut schema = SchemaObject {
                instance_type: Some(InstanceType::String.into()),
                ..Default::default()
            };
            schema
                .extensions
                .insert("contentEncoding".to_string(), "hex".into());
            schema.into()
        }
    }
};
90
91#[near(serializers = [borsh])]
94#[derive(Debug)]
95pub struct SaltRegistry {
96 previous: IterableMap<Salt, bool, Identity>,
97 current: Salt,
98}
99
100impl SaltRegistry {
101 #[inline]
103 pub fn new<S>(prefix: S) -> Self
104 where
105 S: IntoStorageKey,
106 {
107 Self {
108 previous: IterableMap::with_hasher(prefix),
109 current: Salt::derive(0),
110 }
111 }
112
113 fn derive_next_salt(&self) -> Result<Salt> {
114 (0..=u8::MAX)
115 .map(Salt::derive)
116 .find(|s| !self.is_used(*s))
117 .ok_or(DefuseError::SaltGenerationFailed)
118 }
119
120 #[inline]
122 pub fn set_new(&mut self) -> Result<Salt> {
123 let salt = self.derive_next_salt()?;
124
125 let previous = mem::replace(&mut self.current, salt);
126 self.previous.insert(previous, true);
127
128 Ok(previous)
129 }
130
131 #[inline]
133 pub fn invalidate(&mut self, salt: Salt) -> Result<()> {
134 if salt == self.current {
135 self.set_new()?;
136 }
137
138 self.previous
139 .get_mut(&salt)
140 .map(|v| *v = false)
141 .ok_or(DefuseError::InvalidSalt)
142 }
143
144 #[inline]
145 pub fn is_valid(&self, salt: Salt) -> bool {
146 salt == self.current || self.previous.get(&salt).is_some_and(|v| *v)
147 }
148
149 #[inline]
150 fn is_used(&self, salt: Salt) -> bool {
151 salt == self.current || self.previous.contains_key(&salt)
152 }
153
154 #[inline]
155 pub const fn current(&self) -> Salt {
156 self.current
157 }
158}
159
160#[derive(Clone, Debug, PartialEq, Eq, BorshSerialize, BorshDeserialize)]
161#[borsh(crate = "::near_sdk::borsh")]
162pub struct SaltedNonce<T>
163where
164 T: BorshSerialize + BorshDeserialize,
165{
166 pub salt: Salt,
167 pub nonce: T,
168}
169
170impl<T> SaltedNonce<T>
171where
172 T: BorshSerialize + BorshDeserialize,
173{
174 pub const fn new(salt: Salt, nonce: T) -> Self {
175 Self { salt, nonce }
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    use arbitrary::Unstructured;
    use defuse_test_utils::random::{Rng, random_bytes, rng};
    use near_sdk::{test_utils::VMContextBuilder, testing_env};

    use rstest::rstest;

    // Test-only conversion: the first four bytes of `value` become the salt.
    // Panics if `value` is shorter than four bytes.
    impl From<&[u8]> for Salt {
        fn from(value: &[u8]) -> Self {
            Self(value[..4].try_into().unwrap())
        }
    }

    /// Independent re-implementation of `Salt::derive` for a known seed:
    /// the first four bytes of `sha256(seed || attempts)`.
    fn seed_to_salt(seed: &[u8; 32], attempts: u8) -> Salt {
        let mut preimage = Vec::with_capacity(seed.len() + 1);
        preimage.extend_from_slice(seed);
        preimage.push(attempts);

        let digest = sha256_array(&preimage);
        digest[..4].into()
    }

    /// Installs a fresh random seed in the mocked VM context and returns it,
    /// so the test can predict what `Salt::derive` yields.
    fn set_random_seed(rng: &mut impl Rng) -> [u8; 32] {
        let seed: [u8; 32] = rng.random();
        let context = VMContextBuilder::new().random_seed(seed).build();
        testing_env!(context);

        seed
    }

    #[rstest]
    fn contains_salt_test(random_bytes: Vec<u8>) {
        let random_salt: Salt = Unstructured::new(&random_bytes).arbitrary().unwrap();
        let salts = SaltRegistry::new(random_bytes);

        // The current salt is always valid; an unrelated salt was never issued.
        assert!(salts.is_valid(salts.current()));
        assert!(!salts.is_valid(random_salt));
    }

    #[rstest]
    fn update_current_salt_test(random_bytes: Vec<u8>, mut rng: impl Rng) {
        let mut salts = SaltRegistry::new(random_bytes);
        let seed = set_random_seed(&mut rng);

        // Under the new seed, candidate 0 is free and becomes current.
        let first_rotated_out = salts.set_new().expect("should set new salt");
        assert!(salts.is_valid(seed_to_salt(&seed, 0)));
        assert!(salts.is_valid(first_rotated_out));

        // Candidate 0 is now taken, so the next rotation falls through to 1.
        let second_rotated_out = salts.set_new().expect("should set new salt");
        assert!(salts.is_valid(seed_to_salt(&seed, 1)));
        assert!(salts.is_valid(second_rotated_out));
    }

    #[rstest]
    fn reset_salt_test(random_bytes: Vec<u8>, mut rng: impl Rng) {
        let mut salts = SaltRegistry::new(random_bytes);
        let unknown_salt: Salt = rng.random::<[u8; 4]>().as_slice().into();

        let seed = set_random_seed(&mut rng);
        let current = seed_to_salt(&seed, 0);
        let rotated_out = salts.set_new().expect("should set new salt");

        // Invalidating a previous salt succeeds and makes it unusable.
        assert!(salts.invalidate(rotated_out).is_ok());
        assert!(!salts.is_valid(rotated_out));

        // A salt the registry never issued cannot be invalidated.
        assert!(matches!(
            salts.invalidate(unknown_salt).unwrap_err(),
            DefuseError::InvalidSalt
        ));

        // Invalidating the current salt rotates to one derived from a fresh seed.
        let seed = set_random_seed(&mut rng);
        let expected_next = seed_to_salt(&seed, 0);

        assert!(salts.invalidate(current).is_ok());
        assert!(!salts.is_valid(current));
        assert_eq!(salts.current(), expected_next);
    }

    #[rstest]
    fn derive_next_test(random_bytes: Vec<u8>) {
        let mut registry = SaltRegistry::new(random_bytes);

        let prev = registry.set_new().unwrap();
        registry.invalidate(prev).unwrap();
        registry.set_new().unwrap();

        // An invalidated salt stays "used", so it is never reissued.
        assert!(!registry.is_valid(prev));
        assert!(registry.is_used(prev));
    }
}