use super::{hash::Digest, password::PasswordSalt};
use argon2::Argon2;
use chacha20::{
cipher::{KeyIvInit, StreamCipher},
ChaCha20,
};
use generic_array::{sequence::GenericSequence, typenum::Unsigned};
use hex;
use rand::{rngs::OsRng, CryptoRng, Rng};
use serde::{de::Error as _, Deserialize, Deserializer, Serialize, Serializer};
use std::{fmt, sync::Arc};
use subtle::ConstantTimeEq;
use thiserror::Error;
use zeroize::{Zeroize, Zeroizing};
/// Nonce passed to [`SecretKey::encrypt_no_aead`] / [`SecretKey::decrypt_no_aead`].
pub(crate) type Nonce = [u8; NONCE_SIZE];
/// Nonce length in bytes, taken from the `chacha20` crate's nonce type so the two can never
/// drift apart.
pub(crate) const NONCE_SIZE: usize =
<<chacha20::Nonce as GenericSequence<_>>::Length as Unsigned>::USIZE;
/// Symmetric secret key of [`SecretKey::SIZE`] bytes (the ChaCha20 key length).
///
/// The bytes are stored in an `Arc`, so `clone()` only bumps a reference count and all clones
/// share one buffer. That buffer is wrapped in `Zeroizing`, so it is wiped when the last clone
/// is dropped.
#[derive(Clone)]
pub struct SecretKey(Arc<Zeroizing<[u8; Self::SIZE]>>);
impl SecretKey {
    /// Key length in bytes, taken from the `chacha20` crate's key type.
    pub const SIZE: usize = <<chacha20::Key as GenericSequence<_>>::Length as Unsigned>::USIZE;

    /// Parses a key from its hex representation.
    ///
    /// # Errors
    ///
    /// Returns [`hex::FromHexError`] if `hex_str` is not valid hex or does not decode to exactly
    /// [`Self::SIZE`] bytes.
    pub fn parse_hex(hex_str: &str) -> Result<Self, hex::FromHexError> {
        let mut key = Self::zero();
        // Decode straight into the zeroizing buffer. Decoding may fail after some bytes were
        // already written. A temporary stack array would keep that partial secret on the error
        // path, because `?` returns before any manual `zeroize()`. `key`, by contrast, is wiped
        // when `?` drops it.
        hex::decode_to_slice(hex_str, key.as_mut())?;
        Ok(key)
    }

    /// Generates a new key by filling it with bytes from `rng`.
    pub fn generate<R: Rng + CryptoRng + ?Sized>(rng: &mut R) -> Self {
        let mut key = Self::zero();
        rng.fill(key.as_mut());
        key
    }

    /// Generates a new key using the operating system's RNG ([`OsRng`]).
    pub fn random() -> Self {
        Self::generate(&mut OsRng)
    }

    /// Derives a sub-key from `master_key` and `nonce` with keyed BLAKE3.
    ///
    /// `master_key` is used as the BLAKE3 key and `nonce` as the hashed input.
    pub fn derive_from_key(master_key: &[u8; Self::SIZE], nonce: &[u8]) -> Self {
        let mut sub_key = Self::zero();
        let mut hasher = blake3::Hasher::new_keyed(master_key);
        hasher.update(nonce);
        hasher.finalize_into(sub_key.as_mut().into());
        sub_key
    }

    /// Generates a random salt for [`Self::derive_from_password`] using [`OsRng`].
    pub fn random_salt() -> PasswordSalt {
        OsRng.gen()
    }

    /// Derives a key from `user_password` and `salt` using Argon2 with its default parameters.
    ///
    /// # Panics
    ///
    /// Panics if Argon2 returns an error for these inputs.
    /// NOTE(review): presumably unreachable as long as `PasswordSalt` satisfies Argon2's salt
    /// length limits. Confirm.
    pub fn derive_from_password(user_password: &str, salt: &PasswordSalt) -> Self {
        let mut result = Self::zero();
        Argon2::default()
            .hash_password_into(user_password.as_ref(), salt.as_ref(), result.as_mut())
            .expect("failed to hash password");
        result
    }

    /// Encrypts `buffer` in place with ChaCha20 keyed by `self` and `nonce`.
    ///
    /// This is the bare stream cipher. No authentication tag is produced, so modifications of
    /// the ciphertext are not detected. A `nonce` must not be reused with the same key.
    pub(crate) fn encrypt_no_aead(&self, nonce: &Nonce, buffer: &mut [u8]) {
        let mut cipher = ChaCha20::new(self.as_ref().into(), nonce.into());
        cipher.apply_keystream(buffer)
    }

    /// Decrypts `buffer` in place. This is the inverse of [`Self::encrypt_no_aead`].
    pub(crate) fn decrypt_no_aead(&self, nonce: &Nonce, buffer: &mut [u8]) {
        // Applying the same keystream a second time undoes the encryption.
        self.encrypt_no_aead(nonce, buffer)
    }

    /// Borrows the raw key bytes.
    pub fn as_array(&self) -> &[u8; Self::SIZE] {
        &self.0
    }

    /// Creates an all-zero key, used as the starting buffer when constructing keys.
    fn zero() -> Self {
        Self(Arc::new(Zeroizing::new([0; Self::SIZE])))
    }

    /// Mutable access to the key bytes.
    ///
    /// # Panics
    ///
    /// Panics if the key is shared, i.e. has been cloned. Only call this on a key freshly
    /// created by [`Self::zero`], before it is handed out.
    fn as_mut(&mut self) -> &mut [u8] {
        let bytes = Arc::get_mut(&mut self.0)
            .expect("SecretKey::as_mut requires an unshared key");
        &mut **bytes
    }
}
impl TryFrom<&[u8]> for SecretKey {
    type Error = SecretKeyLengthError;

    /// Copies `slice` into a new key.
    ///
    /// # Errors
    ///
    /// Returns [`SecretKeyLengthError`] unless `slice` is exactly [`SecretKey::SIZE`] bytes long.
    fn try_from(slice: &[u8]) -> Result<Self, Self::Error> {
        // `copy_from_slice` panics if the lengths differ. Any mismatch, including a slice that
        // is too long, must therefore be rejected here instead of reaching the copy.
        if slice.len() != Self::SIZE {
            return Err(SecretKeyLengthError);
        }

        let mut key = Self::zero();
        key.as_mut().copy_from_slice(slice);
        Ok(key)
    }
}
impl AsRef<[u8]> for SecretKey {
    /// Views the key bytes as a plain byte slice.
    fn as_ref(&self) -> &[u8] {
        self.as_array()
    }
}
impl PartialEq for SecretKey {
    /// Compares two keys in constant time, so the comparison time does not depend on where the
    /// first differing byte is.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.as_array(), other.as_array());
        bool::from(lhs.ct_eq(rhs))
    }
}

// Equality is a plain byte-wise comparison, which is a total equivalence relation.
impl Eq for SecretKey {}
impl fmt::Debug for SecretKey {
    /// Prints a fixed placeholder so key material never shows up in logs or panic messages.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("****")
    }
}
impl Serialize for SecretKey {
    /// Serializes the key as a single byte string rather than a sequence of integers.
    ///
    /// Formats with a native bytes type (e.g. bincode, msgpack) can therefore encode it compactly.
    fn serialize<S>(&self, s: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let raw: &[u8] = self.as_ref();
        s.serialize_bytes(raw)
    }
}
impl<'de> Deserialize<'de> for SecretKey {
fn deserialize<D>(d: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let bytes: &serde_bytes::Bytes = Deserialize::deserialize(d)?;
if bytes.len() != Self::SIZE {
return Err(D::Error::invalid_length(
bytes.len(),
&format!("{}", Self::SIZE).as_str(),
));
}
let mut key = Self::zero();
key.as_mut().copy_from_slice(bytes);
Ok(key)
}
}
/// Error returned by the `TryFrom<&[u8]>` conversion for [`SecretKey`] when the input slice's
/// length is not accepted.
#[derive(Debug, Error)]
#[error("invalid secret key length")]
pub struct SecretKeyLengthError;
#[cfg(test)]
mod tests {
    use super::*;

    /// Raw bytes of the fixed key shared by the serialization round-trip tests.
    const KEY_BYTES: &[u8] = b"abcdefghijklmnopqrstuvwxyz012345";

    /// Builds the fixture key from [`KEY_BYTES`].
    fn sample_key() -> SecretKey {
        SecretKey::try_from(KEY_BYTES).unwrap()
    }

    #[test]
    fn serialize_deserialize_bincode() {
        let orig = sample_key();

        // The 32-byte length (0x20) is written as a little-endian u64, followed by the raw bytes.
        let serialized = bincode::serialize(&orig).unwrap();
        assert_eq!(
            hex::encode(&serialized),
            "20000000000000006162636465666768696a6b6c6d6e6f707172737475767778797a303132333435"
        );

        let round_tripped: SecretKey = bincode::deserialize(&serialized).unwrap();
        assert_eq!(round_tripped.as_ref(), orig.as_ref());
    }

    #[test]
    fn serialize_deserialize_msgpack() {
        let orig = sample_key();

        // A msgpack `bin 8` header (0xc4) with length 0x20, followed by the raw bytes.
        let serialized = rmp_serde::to_vec(&orig).unwrap();
        assert_eq!(
            hex::encode(&serialized),
            "c4206162636465666768696a6b6c6d6e6f707172737475767778797a303132333435"
        );

        let round_tripped: SecretKey = rmp_serde::from_slice(&serialized).unwrap();
        assert_eq!(round_tripped.as_ref(), orig.as_ref());
    }
}