dryoc/
bytes_serde.rs

1use serde::de::{Error, SeqAccess, Visitor};
2use serde::{Deserialize, Deserializer, Serialize, Serializer};
3
4use crate::types::*;
5
6impl<const LENGTH: usize> Serialize for StackByteArray<LENGTH> {
7    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
8    where
9        S: Serializer,
10    {
11        serializer.serialize_bytes(self.as_slice())
12    }
13}
14
15impl<'de, const LENGTH: usize> Deserialize<'de> for StackByteArray<LENGTH> {
16    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
17    where
18        D: Deserializer<'de>,
19    {
20        struct ByteArrayVisitor<const LENGTH: usize>;
21
22        impl<'de, const LENGTH: usize> Visitor<'de> for ByteArrayVisitor<LENGTH> {
23            type Value = StackByteArray<LENGTH>;
24
25            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
26                write!(formatter, "bytes")
27            }
28
29            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
30            where
31                A: SeqAccess<'de>,
32            {
33                let mut arr = StackByteArray::<LENGTH>::new();
34                let mut idx: usize = 0;
35
36                while let Some(elem) = seq.next_element()? {
37                    if idx < LENGTH {
38                        arr[idx] = elem;
39                        idx += 1;
40                    } else {
41                        break;
42                    }
43                }
44
45                Ok(arr)
46            }
47
48            fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
49            where
50                E: Error,
51            {
52                if v.len() != LENGTH {
53                    return Err(Error::invalid_length(v.len(), &stringify!(LENGTH)));
54                }
55                let mut arr = StackByteArray::<LENGTH>::new();
56                arr.copy_from_slice(v);
57                Ok(arr)
58            }
59        }
60
61        deserializer.deserialize_bytes(ByteArrayVisitor::<LENGTH>)
62    }
63}
64
#[cfg(any(feature = "nightly", all(doc, not(doctest))))]
mod protected {
    //! Serde support for the heap-allocated / memory-locked byte types from
    //! `crate::protected`.
    //!
    //! Every type serializes as a single byte string. Deserialization accepts either
    //! a byte string or a sequence of `u8` values.

    use super::*;
    use crate::protected::*;

    /// Upper bound on how many bytes are pre-allocated from a sequence's size hint.
    ///
    /// The hint comes from the (possibly untrusted) input. Trusting it blindly
    /// would let a tiny payload request an arbitrarily large allocation.
    const MAX_PREALLOC_BYTES: usize = 1024 * 1024;

    /// Returns the initial buffer length for a byte sequence with the given size hint.
    fn initial_len(size_hint: Option<usize>) -> usize {
        size_hint.unwrap_or(0).min(MAX_PREALLOC_BYTES)
    }

    /// Returns the buffer length to grow to once all `filled` slots are in use.
    ///
    /// Growth is geometric, so filling `n` bytes takes O(log n) resizes instead of
    /// one resize per byte.
    fn grown_len(filled: usize) -> usize {
        filled.saturating_mul(2).max(filled + 1)
    }

    impl<const LENGTH: usize> Serialize for HeapByteArray<LENGTH> {
        /// Serializes the array as a single byte string.
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            serializer.serialize_bytes(self.as_slice())
        }
    }

    impl<const LENGTH: usize> Serialize for Locked<HeapByteArray<LENGTH>> {
        /// Serializes the locked array as a single byte string.
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            serializer.serialize_bytes(self.as_slice())
        }
    }

    impl Serialize for HeapBytes {
        /// Serializes the buffer as a single byte string.
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            serializer.serialize_bytes(self.as_slice())
        }
    }

    impl Serialize for LockedBytes {
        /// Serializes the locked buffer as a single byte string.
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            serializer.serialize_bytes(self.as_slice())
        }
    }

    impl Serialize for LockedRO<HeapBytes> {
        /// Serializes the read-only locked buffer as a single byte string.
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            serializer.serialize_bytes(self.as_slice())
        }
    }

    impl<'de> Deserialize<'de> for HeapBytes {
        /// Deserializes a variable-length byte string or `u8` sequence into `HeapBytes`.
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            struct BytesVisitor;

            impl<'de> Visitor<'de> for BytesVisitor {
                type Value = HeapBytes;

                fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                    write!(formatter, "bytes")
                }

                fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
                where
                    A: SeqAccess<'de>,
                {
                    let mut arr = HeapBytes::default();
                    let mut filled: usize = 0;
                    arr.resize(initial_len(seq.size_hint()), 0);

                    while let Some(byte) = seq.next_element::<u8>()? {
                        // The size hint is only a hint: grow whenever the buffer is full
                        // (`>=`, so `arr[filled]` below is always in bounds).
                        if filled >= arr.len() {
                            arr.resize(grown_len(filled), 0);
                        }
                        arr[filled] = byte;
                        filled += 1;
                    }

                    // Trim the unused tail left by an over-stated hint or by growth, so
                    // the result holds exactly the received bytes (an empty sequence
                    // yields an empty buffer).
                    arr.resize(filled, 0);
                    Ok(arr)
                }

                fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
                where
                    E: Error,
                {
                    Ok(HeapBytes::from(v))
                }
            }

            deserializer.deserialize_bytes(BytesVisitor)
        }
    }

    impl<'de> Deserialize<'de> for LockedBytes {
        /// Deserializes a variable-length byte string or `u8` sequence into locked memory.
        ///
        /// Failure to allocate or lock the buffer is reported as a deserialization
        /// error (`Error::custom`) instead of panicking.
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            struct BytesVisitor;

            impl<'de> Visitor<'de> for BytesVisitor {
                type Value = LockedBytes;

                fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                    write!(formatter, "bytes")
                }

                fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
                where
                    A: SeqAccess<'de>,
                {
                    let mut arr = HeapBytes::gen_locked().map_err(|err| {
                        <A::Error as Error>::custom(format!(
                            "couldn't create locked bytes: {}",
                            err
                        ))
                    })?;
                    let mut filled: usize = 0;
                    arr.resize(initial_len(seq.size_hint()), 0);

                    while let Some(byte) = seq.next_element::<u8>()? {
                        // The size hint is only a hint: grow whenever the buffer is full
                        // (`>=`, so `arr[filled]` below is always in bounds).
                        if filled >= arr.len() {
                            arr.resize(grown_len(filled), 0);
                        }
                        arr[filled] = byte;
                        filled += 1;
                    }

                    // Trim the unused tail so the result holds exactly the received bytes.
                    arr.resize(filled, 0);
                    Ok(arr)
                }

                fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
                where
                    E: Error,
                {
                    HeapBytes::from_slice_into_locked(v).map_err(|err| {
                        E::custom(format!("couldn't copy slice into locked bytes: {}", err))
                    })
                }
            }

            deserializer.deserialize_bytes(BytesVisitor)
        }
    }

    impl<'de, const LENGTH: usize> Deserialize<'de> for Locked<HeapByteArray<LENGTH>> {
        /// Deserializes exactly `LENGTH` bytes into a locked heap array.
        ///
        /// Accepts a byte string or a `u8` sequence. Any other length is rejected
        /// with `invalid_length`. Allocation/locking failures become `Error::custom`.
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            struct BytesVisitor<const LENGTH: usize>;

            impl<'de, const LENGTH: usize> Visitor<'de> for BytesVisitor<LENGTH> {
                type Value = Locked<HeapByteArray<LENGTH>>;

                fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                    // `&self` is used as the `Expected` value in length errors, so this
                    // text must carry the real numeric length.
                    write!(formatter, "a byte array of length {}", LENGTH)
                }

                fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
                where
                    A: SeqAccess<'de>,
                {
                    let mut arr = HeapByteArray::<LENGTH>::gen_locked().map_err(|err| {
                        <A::Error as Error>::custom(format!(
                            "couldn't create locked bytes: {}",
                            err
                        ))
                    })?;

                    // Count the actual elements instead of trusting the size hint, which
                    // may be absent (e.g. JSON arrays) or wrong.
                    for idx in 0..LENGTH {
                        match seq.next_element::<u8>()? {
                            Some(byte) => arr[idx] = byte,
                            None => return Err(Error::invalid_length(idx, &self)),
                        }
                    }

                    if seq.next_element::<u8>()?.is_some() {
                        // At least LENGTH + 1 elements; add whatever the format says remains.
                        let actual = LENGTH + 1 + seq.size_hint().unwrap_or(0);
                        return Err(Error::invalid_length(actual, &self));
                    }

                    Ok(arr)
                }

                fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
                where
                    E: Error,
                {
                    if v.len() != LENGTH {
                        return Err(Error::invalid_length(v.len(), &self));
                    }
                    HeapByteArray::<LENGTH>::from_slice_into_locked(v).map_err(|err| {
                        E::custom(format!("couldn't copy slice into locked bytes: {}", err))
                    })
                }
            }

            deserializer.deserialize_bytes(BytesVisitor::<LENGTH>)
        }
    }
}