// dryoc/protected.rs

1//! # Memory protection utilities
2//!
3//! Provides access to the memory locking system calls, such as `mlock()` and
4//! `mprotect()` on UNIX-like systems, `VirtualLock()` and `VirtualProtect()` on
5//! Windows. Similar to libsodium's `sodium_mlock` and `sodium_mprotect_*`
6//! functions.
7//!
8//! On Linux, sets `MADV_DONTDUMP` with `madvise()` on locked regions.
9//!
10//! The protected memory features leverage Rust's [`Allocator`] API, which
11//! requires nightly Rust. This crate must be built with the `nightly` feature
12//! flag enabled to activate these features.
13//!
14//! For details on the [`Allocator`] API, see:
15//! <https://github.com/rust-lang/rust/issues/32838>
16//!
17//! If the `serde` feature is enabled, the [`serde::Deserialize`] and
18//! [`serde::Serialize`] traits will be implemented for [`HeapBytes`] and
19//! [`HeapByteArray`].
20//!
21//! ## Example
22//!
23//! ```
24//! use dryoc::protected::*;
25//!
26//! // Create a read-only, locked region of memory
27//! let readonly_locked = HeapBytes::from_slice_into_readonly_locked(b"some locked bytes")
28//!     .expect("failed to get locked bytes");
29//!
30//! // ... now do stuff with `readonly_locked` ...
31//! println!("{:?}", readonly_locked.as_slice());
32//! ```
33//!
34//! ## Protection features
35//!
36//! The type safe API uses traits to guard against misuse of protected memory.
37//! For example, memory that is set as read-only can be accessed with immutable
38//! accessors (such as `.as_slice()` or `.as_array()`), but not with mutable
39//! accessors like `.as_mut_slice()` or `.as_mut_array()`.
40//!
41//! ```compile_fail
42//! use dryoc::protected::*;
43//!
44//! // Create a read-only, locked region of memory
45//! let readonly_locked = HeapBytes::from_slice_into_readonly_locked(b"some locked bytes")
46//!     .expect("failed to get locked bytes");
47//!
48//! // Try to access the memory mutably
49//! println!("{:?}", readonly_locked.as_mut_slice()); // fails to compile, cannot access mutably
50//! ```
51//!
52//! Memory that has been protected as read-only or no-access will cause the
53//! process to crash if you attempt to access the memory improperly. To test
54//! this, try the following code (which requires an `unsafe` block):
55//!
56//! ```should_panic
57//! use dryoc::protected::*;
58//!
59//! // Create a read-only, locked region of memory
60//! let readonly_locked = HeapBytes::from_slice_into_readonly_locked(b"some locked bytes")
61//!     .expect("failed to get locked bytes");
62//!
63//! // Write to a protected region of memory, causing a crash.
64//! unsafe {
65//!     std::ptr::write(readonly_locked.as_slice().as_ptr() as *mut u8, 0) // <- crash happens here
66//! };
67//! ```
68//!
//! Running the code above produces a `signal: 10, SIGBUS: access to undefined
//! memory` panic.
71use std::alloc::{AllocError, Allocator, Layout};
72use std::marker::PhantomData;
73use std::ptr;
74
75use lazy_static::lazy_static;
76use zeroize::{Zeroize, ZeroizeOnDrop};
77
78use crate::error;
79use crate::rng::copy_randombytes;
80pub use crate::types::*;
81
mod int {
    //! Internal runtime state for protected regions. These flags mirror the
    //! compile-time typestate markers in [`super::traits`].

    /// Runtime record of whether the region is currently locked in memory.
    #[derive(Clone, Debug, PartialEq, Eq)]
    pub(super) enum LockMode {
        Locked,
        Unlocked,
    }

    /// Runtime record of the current page-protection mode.
    #[derive(Clone, Debug, PartialEq, Eq)]
    pub(super) enum ProtectMode {
        ReadOnly,
        ReadWrite,
        NoAccess,
    }

    /// The protected value `a` together with its current lock and protect
    /// state.
    #[derive(Clone)]
    pub(super) struct InternalData<A> {
        pub(super) a: A,
        pub(super) lm: LockMode,
        pub(super) pm: ProtectMode,
    }
}
103
/// Marker traits and types used to track the protection and lock state of a
/// [`Protected`] region at compile time (the "typestate" pattern). These
/// types carry no data; they exist only as type parameters.
pub mod traits {
    /// Marker trait implemented by the protect-mode marker types.
    pub trait ProtectMode {}
    /// Marker type: memory is readable but not writable.
    pub struct ReadOnly {}
    /// Marker type: memory is readable and writable.
    pub struct ReadWrite {}
    /// Marker type: memory may be neither read nor written.
    pub struct NoAccess {}

    impl ProtectMode for ReadOnly {}
    impl ProtectMode for ReadWrite {}
    impl ProtectMode for NoAccess {}

    /// Marker trait implemented by the lock-mode marker types.
    pub trait LockMode {}
    /// Marker type: memory is locked into physical RAM.
    pub struct Locked {}
    /// Marker type: memory is not locked.
    pub struct Unlocked {}
    impl LockMode for Locked {}
    impl LockMode for Unlocked {}
}
121
/// A region of memory that can be locked, but is not yet protected. In order to
/// lock the memory, it may require making a copy.
pub trait Lockable<A: Zeroize + Bytes> {
    /// Consumes `self`, creates a new protected region of memory, and returns
    /// the result in a heap-allocated, page-aligned region of memory. The
    /// memory is locked with `mlock()` on UNIX, or `VirtualLock()` on
    /// Windows. By default, the protect mode is set to ReadWrite (i.e., no
    /// exec) using `mprotect()` on UNIX, or `VirtualProtect()` on Windows.
    /// On Linux, it will also set `MADV_DONTDUMP` using `madvise()`.
    ///
    /// Returns an [`std::io::Error`] if the underlying system call fails.
    fn mlock(self) -> Result<Protected<A, traits::ReadWrite, traits::Locked>, std::io::Error>;
}
133
/// Protected region of memory that can be locked (i.e., is currently
/// unlocked).
pub trait Lock<A: Zeroize + Bytes, PM: traits::ProtectMode> {
    /// Locks a region of memory, using `mlock()` on UNIX, or `VirtualLock()` on
    /// Windows. By default, the protect mode is set to ReadWrite (i.e., no
    /// exec) using `mprotect()` on UNIX, or `VirtualProtect()` on Windows.
    /// On Linux, it will also set `MADV_DONTDUMP` using `madvise()`.
    fn mlock(self) -> Result<Protected<A, PM, traits::Locked>, std::io::Error>;
}
142
/// Protected region of memory that can be unlocked (i.e., is currently
/// locked).
pub trait Unlock<A: Zeroize + Bytes, PM: traits::ProtectMode> {
    /// Unlocks a region of memory, using `munlock()` on UNIX, or
    /// `VirtualUnlock()` on Windows.
    fn munlock(self) -> Result<Protected<A, PM, traits::Unlocked>, std::io::Error>;
}
149
/// Protected region of memory that can be set as read-only.
pub trait ProtectReadOnly<A: Zeroize + Bytes, PM: traits::ProtectMode, LM: traits::LockMode> {
    /// Protects a region of memory as read-only (and no exec), using
    /// `mprotect()` on UNIX, or `VirtualProtect()` on Windows. The lock mode
    /// `LM` is preserved.
    fn mprotect_readonly(self) -> Result<Protected<A, traits::ReadOnly, LM>, std::io::Error>;
}
156
/// Protected region of memory that can be set as read-write.
pub trait ProtectReadWrite<A: Zeroize + Bytes, PM: traits::ProtectMode, LM: traits::LockMode> {
    /// Protects a region of memory as read-write (and no exec), using
    /// `mprotect()` on UNIX, or `VirtualProtect()` on Windows. The lock mode
    /// `LM` is preserved.
    fn mprotect_readwrite(self) -> Result<Protected<A, traits::ReadWrite, LM>, std::io::Error>;
}
163
/// Protected region of memory that can be set as no-access. Must be unlocked.
pub trait ProtectNoAccess<A: Zeroize + Bytes, PM: traits::ProtectMode> {
    /// Protects an unlocked region of memory as no-access (and no exec), using
    /// `mprotect()` on UNIX, or `VirtualProtect()` on Windows.
    fn mprotect_noaccess(
        self,
    ) -> Result<Protected<A, traits::NoAccess, traits::Unlocked>, std::io::Error>;
}
172
/// Bytes which can be allocated and protected.
pub trait NewLocked<A: Zeroize + NewBytes + Lockable<A>> {
    /// Returns a new locked byte array.
    fn new_locked() -> Result<Protected<A, traits::ReadWrite, traits::Locked>, std::io::Error>;
    /// Returns a new read-only, locked byte array.
    fn new_readonly_locked()
    -> Result<Protected<A, traits::ReadOnly, traits::Locked>, std::io::Error>;
    /// Returns a new locked byte array, filled with random data.
    fn gen_locked() -> Result<Protected<A, traits::ReadWrite, traits::Locked>, std::io::Error>;
    /// Returns a new read-only, locked byte array, filled with random data.
    fn gen_readonly_locked()
    -> Result<Protected<A, traits::ReadOnly, traits::Locked>, std::io::Error>;
}
186
/// Create a new region of protected memory from a slice.
pub trait NewLockedFromSlice<A: Zeroize + NewBytes + Lockable<A>> {
    /// Returns a new locked region of memory from `src`.
    fn from_slice_into_locked(
        src: &[u8],
    ) -> Result<Protected<A, traits::ReadWrite, traits::Locked>, crate::error::Error>;
    /// Returns a new read-only locked region of memory from `src`.
    fn from_slice_into_readonly_locked(
        src: &[u8],
    ) -> Result<Protected<A, traits::ReadOnly, traits::Locked>, crate::error::Error>;
}
198
/// Holds Protected region of memory. Does not derive traits such as
/// [Copy], [Clone], or [std::fmt::Debug] (though some concrete typestates,
/// such as [`ptypes::Locked`], provide `Clone` via dedicated impls below).
///
/// `PM` and `LM` are zero-sized markers (see [`traits`]) recording the
/// current protect mode and lock mode in the type.
pub struct Protected<A: Zeroize + Bytes, PM: traits::ProtectMode, LM: traits::LockMode> {
    // Internal data; `None` only transiently, while ownership is being
    // swapped into a new typestate.
    i: Option<int::InternalData<A>>,
    p: PhantomData<PM>,
    l: PhantomData<LM>,
}
206
/// Short-hand type aliases for protected types.
pub mod ptypes {
    /// Locked, read-write, page-aligned memory region type alias
    pub type Locked<T> = super::Protected<T, super::traits::ReadWrite, super::traits::Locked>;
    /// Locked, read-only, page-aligned memory region type alias
    pub type LockedRO<T> = super::Protected<T, super::traits::ReadOnly, super::traits::Locked>;
    /// Unlocked, no-access, page-aligned memory region type alias
    pub type NoAccess<T> = super::Protected<T, super::traits::NoAccess, super::traits::Unlocked>;
    /// Unlocked, read-write, page-aligned memory region type alias
    pub type Unlocked<T> = super::Protected<T, super::traits::ReadWrite, super::traits::Unlocked>;
    /// Unlocked, read-only, page-aligned memory region type alias
    pub type UnlockedRO<T> = super::Protected<T, super::traits::ReadOnly, super::traits::Unlocked>;
    /// Locked, read-write, page-aligned bytes type alias (the most common
    /// protected-bytes type)
    pub type LockedBytes = Locked<super::HeapBytes>;
}
222
223impl<T: Zeroize + NewBytes + ResizableBytes + Lockable<T> + NewLocked<T>> Clone for Locked<T> {
224    fn clone(&self) -> Self {
225        let mut cloned = T::new_locked().expect("unable to create new locked instance");
226        cloned.resize(self.len(), 0);
227        cloned.as_mut_slice().copy_from_slice(self.as_slice());
228        cloned
229    }
230}
231
232impl<T: Zeroize + NewBytes + ResizableBytes + Lockable<T> + NewLocked<T>> Clone for LockedRO<T> {
233    fn clone(&self) -> Self {
234        let mut cloned = T::new_locked().expect("unable to create new locked instance");
235        cloned.resize(self.len(), 0);
236        cloned.as_mut_slice().copy_from_slice(self.as_slice());
237        cloned
238            .mprotect_readonly()
239            .expect("unable to protect readonly")
240    }
241}
242
243impl<T: Zeroize + Bytes + Clone> Clone for Unlocked<T> {
244    fn clone(&self) -> Self {
245        Self::new_with(self.i.as_ref().unwrap().a.clone())
246    }
247}
248
249impl<T: Zeroize + NewBytes + Clone> Clone for UnlockedRO<T> {
250    fn clone(&self) -> Self {
251        Unlocked::<T>::new_with(self.i.as_ref().unwrap().a.clone())
252            .mprotect_readonly()
253            .expect("unable to create new readonly instance")
254    }
255}
256
257pub use ptypes::*;
258
/// Locks `data`'s pages into physical memory, via `mlock()` on UNIX or
/// `VirtualLock()` on Windows. On Linux, additionally excludes the region
/// from core dumps with `MADV_DONTDUMP`. Empty slices are a no-op.
fn dryoc_mlock(data: &[u8]) -> Result<(), std::io::Error> {
    if data.is_empty() {
        // no-op
        return Ok(());
    }
    #[cfg(unix)]
    {
        #[cfg(target_os = "linux")]
        {
            // tell the kernel not to include this memory in a core dump
            use libc::{MADV_DONTDUMP, madvise};
            unsafe {
                // best-effort: the madvise() result is deliberately ignored
                madvise(data.as_ptr() as *mut c_void, data.len(), MADV_DONTDUMP);
            }
        }

        // NOTE: `c_void` is also used by the madvise() call above; `use`
        // items are visible throughout the enclosing block regardless of
        // order, so this resolves.
        use libc::{c_void, mlock as c_mlock};
        let ret = unsafe { c_mlock(data.as_ptr() as *const c_void, data.len()) };
        match ret {
            // mlock() returns 0 on success
            0 => Ok(()),
            _ => Err(std::io::Error::last_os_error()),
        }
    }
    #[cfg(windows)]
    {
        use winapi::shared::minwindef::LPVOID;
        use winapi::um::memoryapi::VirtualLock;

        let res = unsafe { VirtualLock(data.as_ptr() as LPVOID, data.len()) };
        match res {
            // VirtualLock() returns nonzero on success
            1 => Ok(()),
            _ => Err(std::io::Error::last_os_error()),
        }
    }
}
294
/// Unlocks `data`'s pages, via `munlock()` on UNIX or `VirtualUnlock()` on
/// Windows. On Linux, also re-enables core dumps for the region
/// (`MADV_DODUMP`). Empty slices are a no-op.
fn dryoc_munlock(data: &[u8]) -> Result<(), std::io::Error> {
    if data.is_empty() {
        // no-op
        return Ok(());
    }
    #[cfg(unix)]
    {
        #[cfg(target_os = "linux")]
        {
            // undo MADV_DONTDUMP
            use libc::{MADV_DODUMP, madvise};
            unsafe {
                // best-effort: the madvise() result is deliberately ignored
                madvise(data.as_ptr() as *mut c_void, data.len(), MADV_DODUMP);
            }
        }

        // `c_void` is also used by the madvise() call above; `use` items are
        // visible throughout the enclosing block regardless of order.
        use libc::{c_void, munlock as c_munlock};
        let ret = unsafe { c_munlock(data.as_ptr() as *const c_void, data.len()) };
        match ret {
            // munlock() returns 0 on success
            0 => Ok(()),
            _ => Err(std::io::Error::last_os_error()),
        }
    }
    #[cfg(windows)]
    {
        use winapi::shared::minwindef::LPVOID;
        use winapi::um::memoryapi::VirtualUnlock;

        let res = unsafe { VirtualUnlock(data.as_ptr() as LPVOID, data.len()) };
        match res {
            // VirtualUnlock() returns nonzero on success
            1 => Ok(()),
            _ => Err(std::io::Error::last_os_error()),
        }
    }
}
330
/// Marks `data`'s pages read-only (no write, no exec), via `mprotect()` on
/// UNIX or `VirtualProtect()` on Windows. Empty slices are a no-op.
fn dryoc_mprotect_readonly(data: &[u8]) -> Result<(), std::io::Error> {
    if data.is_empty() {
        // no-op
        return Ok(());
    }
    #[cfg(unix)]
    {
        use libc::{PROT_READ, c_void, mprotect as c_mprotect};
        // NOTE(review): the length passed is `len - 1`, presumably so the
        // range never reaches into the trailing guard page when the region
        // ends exactly on a page boundary — confirm against the allocator's
        // layout before changing.
        let ret = unsafe { c_mprotect(data.as_ptr() as *mut c_void, data.len() - 1, PROT_READ) };
        match ret {
            // mprotect() returns 0 on success
            0 => Ok(()),
            _ => Err(std::io::Error::last_os_error()),
        }
    }
    #[cfg(windows)]
    {
        use winapi::shared::minwindef::{DWORD, LPVOID};
        use winapi::um::memoryapi::VirtualProtect;
        use winapi::um::winnt::PAGE_READONLY;

        // VirtualProtect requires an out-param for the previous protection
        let mut old: DWORD = 0;

        let res = unsafe {
            VirtualProtect(
                data.as_ptr() as LPVOID,
                data.len() - 1,
                PAGE_READONLY,
                &mut old,
            )
        };
        match res {
            // VirtualProtect() returns nonzero on success
            1 => Ok(()),
            _ => Err(std::io::Error::last_os_error()),
        }
    }
}
367
/// Marks `data`'s pages read-write (no exec), via `mprotect()` on UNIX or
/// `VirtualProtect()` on Windows. Empty slices are a no-op.
fn dryoc_mprotect_readwrite(data: &[u8]) -> Result<(), std::io::Error> {
    if data.is_empty() {
        // no-op
        return Ok(());
    }
    #[cfg(unix)]
    {
        use libc::{PROT_READ, PROT_WRITE, c_void, mprotect as c_mprotect};
        // NOTE(review): length is `len - 1` — see the matching comment in
        // `dryoc_mprotect_readonly`.
        let ret = unsafe {
            c_mprotect(
                data.as_ptr() as *mut c_void,
                data.len() - 1,
                PROT_READ | PROT_WRITE,
            )
        };
        match ret {
            // mprotect() returns 0 on success
            0 => Ok(()),
            _ => Err(std::io::Error::last_os_error()),
        }
    }
    #[cfg(windows)]
    {
        use winapi::shared::minwindef::{DWORD, LPVOID};
        use winapi::um::memoryapi::VirtualProtect;
        use winapi::um::winnt::PAGE_READWRITE;

        // VirtualProtect requires an out-param for the previous protection
        let mut old: DWORD = 0;

        let res = unsafe {
            VirtualProtect(
                data.as_ptr() as LPVOID,
                data.len() - 1,
                PAGE_READWRITE,
                &mut old,
            )
        };
        match res {
            // VirtualProtect() returns nonzero on success
            1 => Ok(()),
            _ => Err(std::io::Error::last_os_error()),
        }
    }
}
410
/// Marks `data`'s pages as no-access (no read, write, or exec), via
/// `mprotect()` on UNIX or `VirtualProtect()` on Windows. Empty slices are a
/// no-op.
fn dryoc_mprotect_noaccess(data: &[u8]) -> Result<(), std::io::Error> {
    if data.is_empty() {
        // no-op
        return Ok(());
    }
    #[cfg(unix)]
    {
        use libc::{PROT_NONE, c_void, mprotect as c_mprotect};
        // NOTE(review): length is `len - 1` — see the matching comment in
        // `dryoc_mprotect_readonly`.
        let ret = unsafe { c_mprotect(data.as_ptr() as *mut c_void, data.len() - 1, PROT_NONE) };
        match ret {
            // mprotect() returns 0 on success
            0 => Ok(()),
            _ => Err(std::io::Error::last_os_error()),
        }
    }
    #[cfg(windows)]
    {
        use winapi::shared::minwindef::{DWORD, LPVOID};
        use winapi::um::memoryapi::VirtualProtect;
        use winapi::um::winnt::PAGE_NOACCESS;

        // VirtualProtect requires an out-param for the previous protection
        let mut old: DWORD = 0;

        let res = unsafe {
            VirtualProtect(
                data.as_ptr() as LPVOID,
                data.len() - 1,
                PAGE_NOACCESS,
                &mut old,
            )
        };
        match res {
            // VirtualProtect() returns nonzero on success
            1 => Ok(()),
            _ => Err(std::io::Error::last_os_error()),
        }
    }
}
447
impl<A: Zeroize + Bytes, PM: traits::ProtectMode, LM: traits::LockMode> Protected<A, PM, LM> {
    /// Creates an empty shell with no internal data. Used as the swap target
    /// when transitioning to a new typestate in [`Self::swap_some_or_err`].
    fn new() -> Self {
        Self {
            i: None,
            p: PhantomData,
            l: PhantomData,
        }
    }

    /// Wraps `a`, recording its runtime state as unlocked and read-write.
    fn new_with(a: A) -> Self {
        Self {
            i: Some(int::InternalData {
                a,
                lm: int::LockMode::Unlocked,
                pm: int::ProtectMode::ReadWrite,
            }),
            p: PhantomData,
            l: PhantomData,
        }
    }

    /// Applies `f` to the internal data, then moves that data out of `self`
    /// and into the (possibly differently-typed) `Protected` value returned
    /// by `f`. Fails if `f` fails, or if the internal data is unexpectedly
    /// missing.
    fn swap_some_or_err<F, OPM: traits::ProtectMode, OLM: traits::LockMode>(
        &mut self,
        f: F,
    ) -> Result<Protected<A, OPM, OLM>, std::io::Error>
    where
        F: Fn(&mut int::InternalData<A>) -> Result<Protected<A, OPM, OLM>, std::io::Error>,
    {
        match &mut self.i {
            Some(d) => {
                // `f` mutates the runtime state and supplies the empty shell
                let mut new = f(d)?;
                // swap into new struct
                std::mem::swap(&mut new.i, &mut self.i);
                Ok(new)
            }
            _ => Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "unexpected empty internal struct",
            )),
        }
    }
}
490
impl<A: Zeroize + Bytes, PM: traits::ProtectMode, LM: traits::LockMode> Unlock<A, PM>
    for Protected<A, PM, LM>
{
    /// Unlocks the region and transitions to the `Unlocked` typestate,
    /// preserving the protect mode `PM`.
    fn munlock(mut self) -> Result<Protected<A, PM, traits::Unlocked>, std::io::Error> {
        self.swap_some_or_err(|old| {
            dryoc_munlock(old.a.as_slice())?;
            // update internal state
            old.lm = int::LockMode::Unlocked;
            Ok(Protected::<A, PM, traits::Unlocked>::new())
        })
    }
}
503
impl<A: Zeroize + Bytes + Default, PM: traits::ProtectMode> Lock<A, PM>
    for Protected<A, PM, traits::Unlocked>
{
    /// Locks the region and transitions to the `Locked` typestate, preserving
    /// the protect mode `PM`. Only available on unlocked regions.
    fn mlock(mut self) -> Result<Protected<A, PM, traits::Locked>, std::io::Error> {
        self.swap_some_or_err(|old| {
            dryoc_mlock(old.a.as_slice())?;
            // update internal state
            old.lm = int::LockMode::Locked;
            Ok(Protected::<A, PM, traits::Locked>::new())
        })
    }
}
516
impl<A: Zeroize + Bytes, PM: traits::ProtectMode, LM: traits::LockMode> ProtectReadOnly<A, PM, LM>
    for Protected<A, PM, LM>
{
    /// Marks the region read-only and transitions to the `ReadOnly`
    /// typestate, preserving the lock mode `LM`.
    fn mprotect_readonly(mut self) -> Result<Protected<A, traits::ReadOnly, LM>, std::io::Error> {
        self.swap_some_or_err(|old| {
            dryoc_mprotect_readonly(old.a.as_slice())?;
            // update internal state
            old.pm = int::ProtectMode::ReadOnly;
            Ok(Protected::<A, traits::ReadOnly, LM>::new())
        })
    }
}
529
impl<A: Zeroize + Bytes, PM: traits::ProtectMode, LM: traits::LockMode> ProtectReadWrite<A, PM, LM>
    for Protected<A, PM, LM>
{
    /// Marks the region read-write and transitions to the `ReadWrite`
    /// typestate, preserving the lock mode `LM`.
    fn mprotect_readwrite(mut self) -> Result<Protected<A, traits::ReadWrite, LM>, std::io::Error> {
        self.swap_some_or_err(|old| {
            dryoc_mprotect_readwrite(old.a.as_slice())?;
            // update internal state
            old.pm = int::ProtectMode::ReadWrite;
            Ok(Protected::<A, traits::ReadWrite, LM>::new())
        })
    }
}
542
impl<A: Zeroize + Bytes, PM: traits::ProtectMode> ProtectNoAccess<A, PM>
    for Protected<A, PM, traits::Unlocked>
{
    /// Marks the region inaccessible and transitions to the `NoAccess`
    /// typestate. Only available on unlocked regions.
    fn mprotect_noaccess(
        mut self,
    ) -> Result<Protected<A, traits::NoAccess, traits::Unlocked>, std::io::Error> {
        self.swap_some_or_err(|old| {
            dryoc_mprotect_noaccess(old.a.as_slice())?;
            // update internal state
            old.pm = int::ProtectMode::NoAccess;
            Ok(Protected::<A, traits::NoAccess, traits::Unlocked>::new())
        })
    }
}
557
558impl<A: Zeroize + Bytes + AsRef<[u8]>, LM: traits::LockMode> AsRef<[u8]>
559    for Protected<A, traits::ReadOnly, LM>
560{
561    fn as_ref(&self) -> &[u8] {
562        self.i.as_ref().unwrap().a.as_ref()
563    }
564}
565
566impl<A: Zeroize + Bytes + AsRef<[u8]>, LM: traits::LockMode> AsRef<[u8]>
567    for Protected<A, traits::ReadWrite, LM>
568{
569    fn as_ref(&self) -> &[u8] {
570        self.i.as_ref().unwrap().a.as_ref()
571    }
572}
573
574impl<A: Zeroize + MutBytes + AsMut<[u8]>, LM: traits::LockMode> AsMut<[u8]>
575    for Protected<A, traits::ReadWrite, LM>
576{
577    fn as_mut(&mut self) -> &mut [u8] {
578        self.i.as_mut().unwrap().a.as_mut()
579    }
580}
581
582impl<A: Zeroize + Bytes, LM: traits::LockMode> Bytes for Protected<A, traits::ReadOnly, LM> {
583    #[inline]
584    fn as_slice(&self) -> &[u8] {
585        self.i.as_ref().unwrap().a.as_slice()
586    }
587
588    #[inline]
589    fn len(&self) -> usize {
590        self.i.as_ref().unwrap().a.len()
591    }
592
593    #[inline]
594    fn is_empty(&self) -> bool {
595        self.i.as_ref().unwrap().a.is_empty()
596    }
597}
598
599impl<A: Zeroize + Bytes, LM: traits::LockMode> Bytes for Protected<A, traits::ReadWrite, LM> {
600    #[inline]
601    fn as_slice(&self) -> &[u8] {
602        self.i.as_ref().unwrap().a.as_slice()
603    }
604
605    #[inline]
606    fn len(&self) -> usize {
607        self.i.as_ref().unwrap().a.len()
608    }
609
610    #[inline]
611    fn is_empty(&self) -> bool {
612        self.i.as_ref().unwrap().a.is_empty()
613    }
614}
615
616impl<const LENGTH: usize> From<StackByteArray<LENGTH>> for HeapByteArray<LENGTH> {
617    fn from(other: StackByteArray<LENGTH>) -> Self {
618        let mut r = HeapByteArray::<LENGTH>::new_byte_array();
619        let mut s = other;
620        r.copy_from_slice(s.as_slice());
621        s.zeroize();
622        r
623    }
624}
625
626impl<const LENGTH: usize> StackByteArray<LENGTH> {
627    /// Locks a [StackByteArray], consuming it, and returning a [Protected]
628    /// wrapper.
629    pub fn mlock(
630        self,
631    ) -> Result<Protected<HeapByteArray<LENGTH>, traits::ReadWrite, traits::Locked>, std::io::Error>
632    {
633        Protected::<HeapByteArray<LENGTH>, traits::ReadWrite, traits::Unlocked>::new_with(
634            self.into(),
635        )
636        .mlock()
637    }
638}
639
640impl<const LENGTH: usize> StackByteArray<LENGTH> {
641    /// Returns a readonly protected [StackByteArray].
642    pub fn mprotect_readonly(
643        self,
644    ) -> Result<Protected<HeapByteArray<LENGTH>, traits::ReadOnly, traits::Unlocked>, std::io::Error>
645    {
646        Protected::<HeapByteArray<LENGTH>, traits::ReadWrite, traits::Unlocked>::new_with(
647            self.into(),
648        )
649        .mprotect_readonly()
650    }
651}
652
653impl<const LENGTH: usize> Lockable<HeapByteArray<LENGTH>> for HeapByteArray<LENGTH> {
654    /// Locks a [HeapByteArray], and returns a [Protected] wrapper.
655    fn mlock(
656        self,
657    ) -> Result<Protected<HeapByteArray<LENGTH>, traits::ReadWrite, traits::Locked>, std::io::Error>
658    {
659        Protected::<HeapByteArray<LENGTH>, traits::ReadWrite, traits::Unlocked>::new_with(self)
660            .mlock()
661    }
662}
663
664impl Lockable<HeapBytes> for HeapBytes {
665    /// Locks a [HeapBytes], and returns a [Protected] wrapper.
666    fn mlock(
667        self,
668    ) -> Result<Protected<HeapBytes, traits::ReadWrite, traits::Locked>, std::io::Error> {
669        Protected::<HeapBytes, traits::ReadWrite, traits::Unlocked>::new_with(self).mlock()
670    }
671}
672
#[derive(Clone)]
/// Custom page-aligned allocator implementation. Creates blocks of page-aligned
/// heap-allocated memory regions, with no-access pages before and after the
/// allocated region of memory.
///
/// Used as the backing allocator for [`HeapBytes`] and [`HeapByteArray`].
pub struct PageAlignedAllocator;
678
// System page size, queried once on first use: `sysconf(_SC_PAGE_SIZE)` on
// UNIX, `GetSystemInfo` on Windows.
lazy_static! {
    static ref PAGESIZE: usize = {
        #[cfg(unix)]
        {
            use libc::{_SC_PAGE_SIZE, sysconf};
            unsafe { sysconf(_SC_PAGE_SIZE) as usize }
        }
        #[cfg(windows)]
        {
            use winapi::um::sysinfoapi::{GetSystemInfo, SYSTEM_INFO};
            let mut si = SYSTEM_INFO::default();
            unsafe { GetSystemInfo(&mut si) };
            si.dwPageSize as usize
        }
    };
}
695
/// Rounds `size` up to the next multiple of `pagesize`. When `size` is
/// already an exact multiple, a full extra page is still added (matching the
/// allocator's guard-page layout math in both allocate and deallocate).
fn _page_round(size: usize, pagesize: usize) -> usize {
    let remainder = size % pagesize;
    size - remainder + pagesize
}
699
unsafe impl Allocator for PageAlignedAllocator {
    /// Allocates `layout.size()` usable bytes starting one page into a
    /// page-aligned region, with a no-access guard page immediately before
    /// and after the usable span.
    #[inline]
    fn allocate(&self, layout: Layout) -> Result<ptr::NonNull<[u8]>, AllocError> {
        let pagesize = *PAGESIZE;
        // full pages for the payload, plus one guard page at each end
        let size = _page_round(layout.size(), pagesize) + 2 * pagesize;
        #[cfg(unix)]
        let out = {
            use libc::posix_memalign;
            let mut out = ptr::null_mut();

            // allocate full pages, in addition to an extra page at the start and
            // end which will remain locked with no access permitted.
            let ret = unsafe { posix_memalign(&mut out, pagesize, size) };
            if ret != 0 {
                return Err(AllocError);
            }

            out
        };
        #[cfg(windows)]
        let out = {
            use winapi::um::memoryapi::VirtualAlloc;
            use winapi::um::winnt::{MEM_COMMIT, MEM_RESERVE, PAGE_READWRITE};
            // NOTE(review): unlike the UNIX path, the VirtualAlloc return
            // value is not checked for null before use — confirm intended.
            unsafe {
                VirtualAlloc(
                    ptr::null_mut(),
                    size,
                    MEM_COMMIT | MEM_RESERVE,
                    PAGE_READWRITE,
                )
            }
        };

        // make the guard page at the fore of the region inaccessible;
        // best-effort: mprotect failures are logged and otherwise ignored
        let fore_protected_region =
            unsafe { std::slice::from_raw_parts_mut(out as *mut u8, pagesize) };
        dryoc_mprotect_noaccess(fore_protected_region)
            .map_err(|err| eprintln!("mprotect error = {:?}, in allocator", err))
            .ok();

        // make the guard page at the aft of the region inaccessible
        let aft_protected_region_offset = pagesize + _page_round(layout.size(), pagesize);
        let aft_protected_region = unsafe {
            std::slice::from_raw_parts_mut(
                out.add(aft_protected_region_offset) as *mut u8,
                pagesize,
            )
        };
        dryoc_mprotect_noaccess(aft_protected_region)
            .map_err(|err| eprintln!("mprotect error = {:?}, in allocator", err))
            .ok();

        // the caller-visible slice begins one (guard) page into the region
        let slice =
            unsafe { std::slice::from_raw_parts_mut(out.add(pagesize) as *mut u8, layout.size()) };

        dryoc_mprotect_readwrite(slice)
            .map_err(|err| eprintln!("mprotect error = {:?}, in allocator", err))
            .ok();

        unsafe { Ok(ptr::NonNull::new_unchecked(slice)) }
    }

    /// Restores read-write access to both guard pages, then frees the entire
    /// region (guard pages included) with the matching platform call.
    #[inline]
    unsafe fn deallocate(&self, ptr: ptr::NonNull<u8>, layout: Layout) {
        let pagesize = *PAGESIZE;

        // step back over the fore guard page to the true allocation start
        let ptr = ptr.as_ptr().offset(-(pagesize as isize));

        // unlock the fore protected region
        let fore_protected_region = std::slice::from_raw_parts_mut(ptr, pagesize);
        dryoc_mprotect_readwrite(fore_protected_region)
            .map_err(|err| eprintln!("mprotect error = {:?}", err))
            .ok();

        // unlock the aft protected region (same offset math as allocate)
        let aft_protected_region_offset = pagesize + _page_round(layout.size(), pagesize);
        let aft_protected_region =
            std::slice::from_raw_parts_mut(ptr.add(aft_protected_region_offset), pagesize);

        dryoc_mprotect_readwrite(aft_protected_region)
            .map_err(|err| eprintln!("mprotect error = {:?}", err))
            .ok();

        #[cfg(unix)]
        {
            libc::free(ptr as *mut libc::c_void);
        }
        #[cfg(windows)]
        {
            use winapi::shared::minwindef::LPVOID;
            use winapi::um::memoryapi::VirtualFree;
            use winapi::um::winnt::MEM_RELEASE;
            VirtualFree(ptr as LPVOID, 0, MEM_RELEASE);
        }
    }
}
796
/// Provides a heap-allocated, fixed-length, page-aligned memory region.
///
/// This struct provides a heap-allocated fixed-length byte array, using the
/// [page-aligned allocator](PageAlignedAllocator). Required for working with
/// protected memory regions. Wraps a [`Vec`] with custom [`Allocator`]
/// implementation. Contents are zeroized on drop.
#[derive(Zeroize, ZeroizeOnDrop, Debug, PartialEq, Eq, Clone)]
pub struct HeapByteArray<const LENGTH: usize>(Vec<u8, PageAlignedAllocator>);
805
/// Provides a heap-allocated, resizable memory region.
///
/// This struct provides heap-allocated resizable byte array, using the
/// [page-aligned allocator](PageAlignedAllocator). Required for working with
/// protected memory regions. Wraps a [`Vec`] with custom [`Allocator`]
/// implementation. Contents are zeroized on drop.
#[derive(Zeroize, ZeroizeOnDrop, Debug, PartialEq, Eq, Clone)]
pub struct HeapBytes(Vec<u8, PageAlignedAllocator>);
814
815impl<A: Zeroize + NewBytes + Lockable<A>> NewLocked<A> for A {
816    fn new_locked() -> Result<Protected<Self, traits::ReadWrite, traits::Locked>, std::io::Error> {
817        Self::new_bytes().mlock()
818    }
819
820    fn new_readonly_locked()
821    -> Result<Protected<Self, traits::ReadOnly, traits::Locked>, std::io::Error> {
822        Self::new_bytes()
823            .mlock()
824            .and_then(|p| p.mprotect_readonly())
825    }
826
827    fn gen_locked() -> Result<Protected<Self, traits::ReadWrite, traits::Locked>, std::io::Error> {
828        let mut res = Self::new_bytes().mlock()?;
829        copy_randombytes(res.as_mut_slice());
830        Ok(res)
831    }
832
833    fn gen_readonly_locked()
834    -> Result<Protected<Self, traits::ReadOnly, traits::Locked>, std::io::Error> {
835        Self::gen_locked().and_then(|s| s.mprotect_readonly())
836    }
837}
838
impl<A: Zeroize + NewBytes + ResizableBytes + Lockable<A>> NewLockedFromSlice<A> for A {
    /// Returns a new locked region, resized to match `src` and filled with a
    /// copy of its contents.
    fn from_slice_into_locked(
        src: &[u8],
    ) -> Result<Protected<Self, traits::ReadWrite, traits::Locked>, crate::error::Error> {
        let mut res = Self::new_bytes().mlock()?;
        // resize first so the lengths match and the copy below cannot panic
        res.resize(src.len(), 0);
        res.as_mut_slice().copy_from_slice(src);
        Ok(res)
    }

    /// Same as [`from_slice_into_locked`](Self::from_slice_into_locked), but
    /// additionally marks the region read-only before returning it.
    fn from_slice_into_readonly_locked(
        src: &[u8],
    ) -> Result<Protected<Self, traits::ReadOnly, traits::Locked>, crate::error::Error> {
        Self::from_slice_into_locked(src)
            .and_then(|s| s.mprotect_readonly().map_err(|err| err.into()))
    }
}
860
impl<const LENGTH: usize> NewLockedFromSlice<HeapByteArray<LENGTH>> for HeapByteArray<LENGTH> {
    /// Returns a new locked byte array copied from `other`. Returns an error
    /// (rather than panicking) if `other.len()` doesn't equal `LENGTH`.
    fn from_slice_into_locked(
        other: &[u8],
    ) -> Result<Protected<Self, traits::ReadWrite, traits::Locked>, crate::error::Error> {
        // length is checked up front, so the copy below cannot panic
        if other.len() != LENGTH {
            return Err(dryoc_error!(format!(
                "slice length {} doesn't match expected {}",
                other.len(),
                LENGTH
            )));
        }
        let mut res = Self::new_bytes().mlock()?;
        res.as_mut_slice().copy_from_slice(other);
        Ok(res)
    }

    /// Same as [`from_slice_into_locked`](Self::from_slice_into_locked), but
    /// additionally marks the region read-only before returning it.
    fn from_slice_into_readonly_locked(
        other: &[u8],
    ) -> Result<Protected<Self, traits::ReadOnly, traits::Locked>, crate::error::Error> {
        Self::from_slice_into_locked(other)
            .and_then(|s| s.mprotect_readonly().map_err(|err| err.into()))
    }
}
886
887impl<const LENGTH: usize> Bytes for HeapByteArray<LENGTH> {
888    #[inline]
889    fn as_slice(&self) -> &[u8] {
890        &self.0
891    }
892
893    #[inline]
894    fn len(&self) -> usize {
895        self.0.len()
896    }
897
898    #[inline]
899    fn is_empty(&self) -> bool {
900        self.0.is_empty()
901    }
902}
903
904impl Bytes for HeapBytes {
905    #[inline]
906    fn as_slice(&self) -> &[u8] {
907        &self.0
908    }
909
910    #[inline]
911    fn len(&self) -> usize {
912        self.0.len()
913    }
914
915    #[inline]
916    fn is_empty(&self) -> bool {
917        self.0.is_empty()
918    }
919}
920
921impl<const LENGTH: usize> MutBytes for HeapByteArray<LENGTH> {
922    #[inline]
923    fn as_mut_slice(&mut self) -> &mut [u8] {
924        self.0.as_mut_slice()
925    }
926
927    fn copy_from_slice(&mut self, other: &[u8]) {
928        self.0.copy_from_slice(other)
929    }
930}
931
932impl NewBytes for HeapBytes {
933    fn new_bytes() -> Self {
934        Self::default()
935    }
936}
937
938impl MutBytes for HeapBytes {
939    #[inline]
940    fn as_mut_slice(&mut self) -> &mut [u8] {
941        self.0.as_mut_slice()
942    }
943
944    fn copy_from_slice(&mut self, other: &[u8]) {
945        self.0.copy_from_slice(other)
946    }
947}
948
impl ResizableBytes for HeapBytes {
    /// Resizes the underlying vector to `new_len`, filling any newly added
    /// bytes with `value`.
    fn resize(&mut self, new_len: usize, value: u8) {
        self.0.resize(new_len, value);
    }
}
954
impl<A: Zeroize + NewBytes + ResizableBytes + Lockable<A>> ResizableBytes
    for Protected<A, traits::ReadWrite, traits::Locked>
{
    /// Resizes the locked region to `new_len`, filling newly exposed bytes
    /// with `value`. A locked region isn't resized in place: a fresh region
    /// is allocated and locked, the overlapping prefix is copied over, and
    /// the regions are swapped.
    ///
    /// Panics if the internal state is invalid or if locking the new region
    /// fails.
    fn resize(&mut self, new_len: usize, value: u8) {
        match &mut self.i {
            Some(d) => {
                // because it's locked, we'll do a swaparoo here instead of a plain resize
                let mut new = A::new_bytes();
                // resize the new array
                new.resize(new_len, value);
                // need to actually lock the memory now, because it was previously locked
                let mut locked = new.mlock().expect("unable to lock on resize");
                // copy only the prefix common to both old and new lengths
                let len_to_copy = std::cmp::min(new_len, d.a.as_slice().len());
                locked.i.as_mut().unwrap().a.as_mut_slice()[..len_to_copy]
                    .copy_from_slice(&d.a.as_slice()[..len_to_copy]);
                // `locked` now owns the old region after the swap
                std::mem::swap(&mut locked.i, &mut self.i);
                // when dropped, the old region will unlock automatically in
                // Drop
            }
            None => panic!("invalid array"),
        }
    }
}
978
979impl<A: Zeroize + NewBytes + ResizableBytes + Lockable<A>> ResizableBytes
980    for Protected<A, traits::ReadWrite, traits::Unlocked>
981{
982    fn resize(&mut self, new_len: usize, value: u8) {
983        match &mut self.i {
984            Some(d) => d.a.resize(new_len, value),
985            None => panic!("invalid array"),
986        }
987    }
988}
989
990impl<A: Zeroize + MutBytes, LM: traits::LockMode> MutBytes for Protected<A, traits::ReadWrite, LM> {
991    #[inline]
992    fn as_mut_slice(&mut self) -> &mut [u8] {
993        match &mut self.i {
994            Some(d) => d.a.as_mut_slice(),
995            None => panic!("invalid array"),
996        }
997    }
998
999    fn copy_from_slice(&mut self, other: &[u8]) {
1000        match &mut self.i {
1001            Some(d) => d.a.copy_from_slice(other),
1002            None => panic!("invalid array"),
1003        }
1004    }
1005}
1006
impl<const LENGTH: usize> std::convert::AsRef<[u8; LENGTH]> for HeapByteArray<LENGTH> {
    /// Borrows the contents as a fixed-size array reference.
    fn as_ref(&self) -> &[u8; LENGTH] {
        // SAFETY: relies on the invariant that the backing Vec always holds
        // exactly LENGTH bytes (`Default` sizes it to LENGTH) — NOTE(review):
        // confirm no constructor can produce a different length.
        let arr = self.0.as_ptr() as *const [u8; LENGTH];
        unsafe { &*arr }
    }
}
1013
impl<const LENGTH: usize> std::convert::AsMut<[u8; LENGTH]> for HeapByteArray<LENGTH> {
    /// Mutably borrows the contents as a fixed-size array reference.
    fn as_mut(&mut self) -> &mut [u8; LENGTH] {
        // SAFETY: the backing Vec always holds exactly LENGTH bytes (see
        // `Default`); the pointer comes from `as_mut_ptr`, so writing through
        // it is permitted.
        let arr = self.0.as_mut_ptr() as *mut [u8; LENGTH];
        unsafe { &mut *arr }
    }
}
1020
1021impl<const LENGTH: usize> std::convert::AsRef<[u8]> for HeapByteArray<LENGTH> {
1022    fn as_ref(&self) -> &[u8] {
1023        self.0.as_ref()
1024    }
1025}
1026
1027impl std::convert::AsRef<[u8]> for HeapBytes {
1028    fn as_ref(&self) -> &[u8] {
1029        self.0.as_ref()
1030    }
1031}
1032
1033impl<const LENGTH: usize> std::convert::AsMut<[u8]> for HeapByteArray<LENGTH> {
1034    fn as_mut(&mut self) -> &mut [u8] {
1035        self.0.as_mut()
1036    }
1037}
1038
1039impl std::convert::AsMut<[u8]> for HeapBytes {
1040    fn as_mut(&mut self) -> &mut [u8] {
1041        self.0.as_mut()
1042    }
1043}
1044
1045impl<const LENGTH: usize> std::ops::Deref for HeapByteArray<LENGTH> {
1046    type Target = [u8];
1047
1048    fn deref(&self) -> &Self::Target {
1049        &self.0
1050    }
1051}
1052
1053impl<const LENGTH: usize> std::ops::DerefMut for HeapByteArray<LENGTH> {
1054    fn deref_mut(&mut self) -> &mut Self::Target {
1055        &mut self.0
1056    }
1057}
1058
1059impl std::ops::Deref for HeapBytes {
1060    type Target = [u8];
1061
1062    fn deref(&self) -> &Self::Target {
1063        &self.0
1064    }
1065}
1066
1067impl std::ops::DerefMut for HeapBytes {
1068    fn deref_mut(&mut self) -> &mut Self::Target {
1069        &mut self.0
1070    }
1071}
1072
1073impl<A: Bytes + Zeroize, LM: traits::LockMode> std::ops::Deref
1074    for Protected<A, traits::ReadOnly, LM>
1075{
1076    type Target = [u8];
1077
1078    fn deref(&self) -> &Self::Target {
1079        self.i.as_ref().unwrap().a.as_slice()
1080    }
1081}
1082
1083impl<A: Bytes + Zeroize, LM: traits::LockMode> std::ops::Deref
1084    for Protected<A, traits::ReadWrite, LM>
1085{
1086    type Target = [u8];
1087
1088    fn deref(&self) -> &Self::Target {
1089        self.i.as_ref().unwrap().a.as_slice()
1090    }
1091}
1092
1093impl<A: MutBytes + Zeroize, LM: traits::LockMode> std::ops::DerefMut
1094    for Protected<A, traits::ReadWrite, LM>
1095{
1096    fn deref_mut(&mut self) -> &mut Self::Target {
1097        self.i.as_mut().unwrap().a.as_mut_slice()
1098    }
1099}
1100
1101impl<const LENGTH: usize> std::ops::Index<usize> for HeapByteArray<LENGTH> {
1102    type Output = u8;
1103
1104    #[inline]
1105    fn index(&self, index: usize) -> &Self::Output {
1106        &self.0[index]
1107    }
1108}
1109impl<const LENGTH: usize> std::ops::IndexMut<usize> for HeapByteArray<LENGTH> {
1110    #[inline]
1111    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
1112        &mut self.0[index]
1113    }
1114}
1115
// Generates `Index`/`IndexMut` impls for one range type over
// `HeapByteArray`, delegating to the underlying vector's slice indexing
// (panics on out-of-range, as slices do).
macro_rules! impl_index_heapbytearray {
    ($range:ty) => {
        impl<const LENGTH: usize> std::ops::Index<$range> for HeapByteArray<LENGTH> {
            type Output = [u8];

            #[inline]
            fn index(&self, index: $range) -> &Self::Output {
                &self.0[index]
            }
        }
        impl<const LENGTH: usize> std::ops::IndexMut<$range> for HeapByteArray<LENGTH> {
            #[inline]
            fn index_mut(&mut self, index: $range) -> &mut Self::Output {
                &mut self.0[index]
            }
        }
    };
}
1134
// Cover every standard range type for `HeapByteArray` slicing.
impl_index_heapbytearray!(std::ops::Range<usize>);
impl_index_heapbytearray!(std::ops::RangeFull);
impl_index_heapbytearray!(std::ops::RangeFrom<usize>);
impl_index_heapbytearray!(std::ops::RangeInclusive<usize>);
impl_index_heapbytearray!(std::ops::RangeTo<usize>);
impl_index_heapbytearray!(std::ops::RangeToInclusive<usize>);
1141
1142impl<const LENGTH: usize> Default for HeapByteArray<LENGTH> {
1143    fn default() -> Self {
1144        let mut v = Vec::new_in(PageAlignedAllocator);
1145        v.resize(LENGTH, 0);
1146        Self(v)
1147    }
1148}
1149
impl<A: Zeroize + NewBytes + Lockable<A> + NewLocked<A>> Default
    for Protected<A, traits::ReadWrite, traits::Locked>
{
    /// Returns a new zero-filled, locked, read-write region. Panics if
    /// `mlock()` fails.
    fn default() -> Self {
        A::new_locked().expect("mlock failed")
    }
}
1157
1158impl std::ops::Index<usize> for HeapBytes {
1159    type Output = u8;
1160
1161    #[inline]
1162    fn index(&self, index: usize) -> &Self::Output {
1163        &self.0[index]
1164    }
1165}
1166impl std::ops::IndexMut<usize> for HeapBytes {
1167    #[inline]
1168    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
1169        &mut self.0[index]
1170    }
1171}
1172
// Generates `Index`/`IndexMut` impls for one range type over `HeapBytes`,
// delegating to the underlying vector's slice indexing (panics on
// out-of-range, as slices do).
macro_rules! impl_index_heapbytes {
    ($range:ty) => {
        impl std::ops::Index<$range> for HeapBytes {
            type Output = [u8];

            #[inline]
            fn index(&self, index: $range) -> &Self::Output {
                &self.0[index]
            }
        }
        impl std::ops::IndexMut<$range> for HeapBytes {
            #[inline]
            fn index_mut(&mut self, index: $range) -> &mut Self::Output {
                &mut self.0[index]
            }
        }
    };
}
1191
// Cover every standard range type for `HeapBytes` slicing.
impl_index_heapbytes!(std::ops::Range<usize>);
impl_index_heapbytes!(std::ops::RangeFull);
impl_index_heapbytes!(std::ops::RangeFrom<usize>);
impl_index_heapbytes!(std::ops::RangeInclusive<usize>);
impl_index_heapbytes!(std::ops::RangeTo<usize>);
impl_index_heapbytes!(std::ops::RangeToInclusive<usize>);
1198
1199impl Default for HeapBytes {
1200    fn default() -> Self {
1201        Self(Vec::new_in(PageAlignedAllocator))
1202    }
1203}
1204
1205impl<const LENGTH: usize> From<&[u8; LENGTH]> for HeapByteArray<LENGTH> {
1206    fn from(src: &[u8; LENGTH]) -> Self {
1207        let mut arr = Self::default();
1208        arr.0.copy_from_slice(src);
1209        arr
1210    }
1211}
1212
1213impl<const LENGTH: usize> From<[u8; LENGTH]> for HeapByteArray<LENGTH> {
1214    fn from(mut src: [u8; LENGTH]) -> Self {
1215        let ret = Self::from(&src);
1216        // need to zeroize this input
1217        src.zeroize();
1218        ret
1219    }
1220}
1221
1222impl<const LENGTH: usize> TryFrom<&[u8]> for HeapByteArray<LENGTH> {
1223    type Error = error::Error;
1224
1225    fn try_from(src: &[u8]) -> Result<Self, Self::Error> {
1226        if src.len() != LENGTH {
1227            Err(dryoc_error!(format!(
1228                "Invalid size: expected {} found {}",
1229                LENGTH,
1230                src.len()
1231            )))
1232        } else {
1233            let mut arr = Self::default();
1234            arr.0.copy_from_slice(src);
1235            Ok(arr)
1236        }
1237    }
1238}
1239
1240impl From<&[u8]> for HeapBytes {
1241    fn from(src: &[u8]) -> Self {
1242        let mut arr = Self::default();
1243        arr.0.copy_from_slice(src);
1244        arr
1245    }
1246}
1247
impl<const LENGTH: usize> ByteArray<LENGTH> for HeapByteArray<LENGTH> {
    /// Borrows the contents as a fixed-size array reference.
    #[inline]
    fn as_array(&self) -> &[u8; LENGTH] {
        // this is safe for fixed-length arrays
        // SAFETY: the backing Vec always holds exactly LENGTH bytes (see
        // `Default`) — NOTE(review): confirm no constructor breaks this.
        let ptr = self.0.as_ptr() as *const [u8; LENGTH];
        unsafe { &*ptr }
    }
}
1256
1257impl<const LENGTH: usize> NewBytes for HeapByteArray<LENGTH> {
1258    fn new_bytes() -> Self {
1259        Self::default()
1260    }
1261}
1262
1263impl NewBytes for Protected<HeapBytes, traits::ReadWrite, traits::Locked> {
1264    fn new_bytes() -> Self {
1265        match HeapBytes::new_locked() {
1266            Ok(r) => r,
1267            Err(err) => panic!("Error creating locked bytes: {:?}", err),
1268        }
1269    }
1270}
1271
1272impl<const LENGTH: usize> NewBytes
1273    for Protected<HeapByteArray<LENGTH>, traits::ReadWrite, traits::Locked>
1274{
1275    fn new_bytes() -> Self {
1276        match HeapByteArray::<LENGTH>::new_locked() {
1277            Ok(r) => r,
1278            Err(err) => panic!("Error creating locked bytes: {:?}", err),
1279        }
1280    }
1281}
1282
1283impl<const LENGTH: usize> NewByteArray<LENGTH>
1284    for Protected<HeapByteArray<LENGTH>, traits::ReadWrite, traits::Locked>
1285{
1286    fn new_byte_array() -> Self {
1287        match HeapByteArray::<LENGTH>::new_locked() {
1288            Ok(r) => r,
1289            Err(err) => panic!("Error creating locked bytes: {:?}", err),
1290        }
1291    }
1292
1293    fn gen() -> Self {
1294        match HeapByteArray::<LENGTH>::new_locked() {
1295            Ok(mut r) => {
1296                copy_randombytes(r.as_mut_slice());
1297                r
1298            }
1299            Err(err) => panic!("Error creating locked bytes: {:?}", err),
1300        }
1301    }
1302}
1303
1304impl<const LENGTH: usize> NewByteArray<LENGTH> for HeapByteArray<LENGTH> {
1305    fn new_byte_array() -> Self {
1306        Self::default()
1307    }
1308
1309    /// Returns a new byte array filled with random data.
1310    fn gen() -> Self {
1311        let mut res = Self::default();
1312        copy_randombytes(&mut res.0);
1313        res
1314    }
1315}
1316
1317impl<const LENGTH: usize> MutByteArray<LENGTH> for HeapByteArray<LENGTH> {
1318    fn as_mut_array(&mut self) -> &mut [u8; LENGTH] {
1319        // this is safe for fixed-length arrays
1320        let ptr = self.0.as_ptr() as *mut [u8; LENGTH];
1321        unsafe { &mut *ptr }
1322    }
1323}
1324
1325impl<const LENGTH: usize> ByteArray<LENGTH>
1326    for Protected<HeapByteArray<LENGTH>, traits::ReadOnly, traits::Unlocked>
1327{
1328    #[inline]
1329    fn as_array(&self) -> &[u8; LENGTH] {
1330        match &self.i {
1331            Some(d) => d.a.as_array(),
1332            None => panic!("invalid array"),
1333        }
1334    }
1335}
1336
1337impl<const LENGTH: usize> ByteArray<LENGTH>
1338    for Protected<HeapByteArray<LENGTH>, traits::ReadOnly, traits::Locked>
1339{
1340    #[inline]
1341    fn as_array(&self) -> &[u8; LENGTH] {
1342        match &self.i {
1343            Some(d) => d.a.as_array(),
1344            None => panic!("invalid array"),
1345        }
1346    }
1347}
1348
1349impl<const LENGTH: usize> ByteArray<LENGTH>
1350    for Protected<HeapByteArray<LENGTH>, traits::ReadWrite, traits::Unlocked>
1351{
1352    #[inline]
1353    fn as_array(&self) -> &[u8; LENGTH] {
1354        match &self.i {
1355            Some(d) => d.a.as_array(),
1356            None => panic!("invalid array"),
1357        }
1358    }
1359}
1360
1361impl<const LENGTH: usize> ByteArray<LENGTH>
1362    for Protected<HeapByteArray<LENGTH>, traits::ReadWrite, traits::Locked>
1363{
1364    #[inline]
1365    fn as_array(&self) -> &[u8; LENGTH] {
1366        match &self.i {
1367            Some(d) => d.a.as_array(),
1368            None => panic!("invalid array"),
1369        }
1370    }
1371}
1372
1373impl<const LENGTH: usize> MutByteArray<LENGTH>
1374    for Protected<HeapByteArray<LENGTH>, traits::ReadWrite, traits::Locked>
1375{
1376    #[inline]
1377    fn as_mut_array(&mut self) -> &mut [u8; LENGTH] {
1378        match &mut self.i {
1379            Some(d) => d.a.as_mut_array(),
1380            None => panic!("invalid array"),
1381        }
1382    }
1383}
1384
1385impl<const LENGTH: usize> MutByteArray<LENGTH>
1386    for Protected<HeapByteArray<LENGTH>, traits::ReadWrite, traits::Unlocked>
1387{
1388    #[inline]
1389    fn as_mut_array(&mut self) -> &mut [u8; LENGTH] {
1390        match &mut self.i {
1391            Some(d) => d.a.as_mut_array(),
1392            None => panic!("invalid array"),
1393        }
1394    }
1395}
1396
1397impl<const LENGTH: usize> AsMut<[u8; LENGTH]>
1398    for Protected<HeapByteArray<LENGTH>, traits::ReadWrite, traits::Locked>
1399{
1400    fn as_mut(&mut self) -> &mut [u8; LENGTH] {
1401        match &mut self.i {
1402            Some(d) => d.a.as_mut(),
1403            None => panic!("invalid array"),
1404        }
1405    }
1406}
1407
1408impl<const LENGTH: usize> AsMut<[u8; LENGTH]>
1409    for Protected<HeapByteArray<LENGTH>, traits::ReadWrite, traits::Unlocked>
1410{
1411    fn as_mut(&mut self) -> &mut [u8; LENGTH] {
1412        match &mut self.i {
1413            Some(d) => d.a.as_mut(),
1414            None => panic!("invalid array"),
1415        }
1416    }
1417}
1418
impl<A: Zeroize + Bytes, PM: traits::ProtectMode, LM: traits::LockMode> Drop
    for Protected<A, PM, LM>
{
    /// Wipes (and, as required, unprotects/unlocks) the region on drop via
    /// the `Zeroize` impl below.
    fn drop(&mut self) {
        self.zeroize()
    }
}
1426
impl<A: Zeroize + Bytes, PM: traits::ProtectMode, LM: traits::LockMode> Zeroize
    for Protected<A, PM, LM>
{
    /// Best-effort wipe of the protected region: restores write access if
    /// needed, zeroizes the contents, then unlocks locked regions. Syscall
    /// errors are reported to stderr but never propagated, since this runs
    /// from `Drop`.
    fn zeroize(&mut self) {
        if let Some(d) = &mut self.i {
            if !d.a.as_slice().is_empty() {
                // make the region writable first; zeroizing a read-only (or
                // no-access) region would fault
                if d.pm != int::ProtectMode::ReadWrite {
                    dryoc_mprotect_readwrite(d.a.as_slice())
                        .map_err(|err| eprintln!("mprotect_readwrite error on drop = {:?}", err))
                        .ok();
                }
                d.a.zeroize();
                // unlock only after the contents have been wiped
                if d.lm == int::LockMode::Locked {
                    dryoc_munlock(d.a.as_slice())
                        .map_err(|err| eprintln!("dryoc_munlock error on drop = {:?}", err))
                        .ok();
                }
            }
        }
    }
}
1448
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lock_unlock() {
        use crate::dryocstream::Key;

        // Round-trip: lock then unlock, and verify the contents survive.
        let key = Key::gen();
        let key_clone = key.clone();

        let locked_key = key.mlock().expect("lock failed");

        let unlocked_key = locked_key.munlock().expect("unlock failed");

        assert_eq!(unlocked_key.as_slice(), key_clone.as_slice());
    }

    #[test]
    fn test_protect_unprotect() {
        use crate::dryocstream::Key;

        // Round-trip: read-only, then back to read-write; contents intact
        // at each step.
        let key = Key::gen();
        let key_clone = key.clone();

        let readonly_key = key.mprotect_readonly().expect("mprotect failed");
        assert_eq!(readonly_key.as_slice(), key_clone.as_slice());

        let mut readwrite_key = readonly_key.mprotect_readwrite().expect("mprotect failed");
        assert_eq!(readwrite_key.as_slice(), key_clone.as_slice());

        // should be able to write now without blowing up
        readwrite_key.as_mut_slice()[0] = 0;
    }

    #[test]
    fn test_allocator() {
        // Exercise the page-aligned allocator through Vec growth,
        // reallocation, and shrinking.
        let mut vec: Vec<i32, _> = Vec::new_in(PageAlignedAllocator);

        vec.push(1);
        vec.push(2);
        vec.push(3);

        // force several reallocations past the initial capacity
        for i in 0..5000 {
            vec.push(i);
        }

        vec.resize(5, 0);

        assert_eq!([1, 2, 3, 0, 1], vec.as_slice());
    }

    // #[test]
    // fn test_crash() {
    //     use crate::protected::*;

    //     // Create a read-only, locked region of memory
    //     let readonly_locked =
    // HeapBytes::from_slice_into_readonly_locked(b"some locked bytes")
    //         .expect("failed to get locked bytes");

    //     // Write to a protected region of memory, causing a crash.
    //     unsafe {
    //         ptr::write(readonly_locked.as_slice().as_ptr() as *mut u8, 0) //
    // <- crash happens here     };
    // }
}