{-# LANGUAGE DeriveDataTypeable #-}
module Crypto.Cipher.ChaChaPoly1305.Conduit
  ( encrypt
  , decrypt
  , ChaChaException (..)
  ) where

import           Control.Exception            (assert)
import           Control.Monad.Catch          (Exception, MonadThrow, throwM)
import qualified Crypto.Cipher.ChaChaPoly1305 as Cha
import qualified Crypto.Error                 as CE
import qualified Crypto.MAC.Poly1305          as Poly1305
import qualified Data.ByteArray               as BA
import           Data.ByteString              (ByteString)
import qualified Data.ByteString              as B
import qualified Data.ByteString.Lazy         as BL
import           Data.Conduit                 (ConduitM, await, leftover, yield)
import qualified Data.Conduit.Binary          as CB
import           Data.Typeable                (Typeable)

-- | Lift a 'CE.CryptoFailable' result into any 'MonadThrow' monad.
--
-- A 'CE.CryptoPassed' value is returned unchanged. A 'CE.CryptoFailed'
-- error is wrapped with the supplied constructor and thrown with 'throwM',
-- so callers can tell which setup step failed.
cf :: MonadThrow m
   => (CE.CryptoError -> ChaChaException) -- ^ wraps the underlying error
   -> CE.CryptoFailable a                 -- ^ result of a cryptonite call
   -> m a
cf _ (CE.CryptoPassed x) = return x
cf f (CE.CryptoFailed e) = throwM (f e)

-- | Errors thrown (via 'throwM') by 'encrypt' and 'decrypt'.
data ChaChaException
  = EncryptNonceException !CE.CryptoError
    -- ^ 'encrypt' was given a nonce that 'Cha.nonce12' rejected
  | EncryptKeyException !CE.CryptoError
    -- ^ 'encrypt' was given a key that 'Cha.initialize' rejected
  | DecryptNonceException !CE.CryptoError
    -- ^ the nonce read from the head of the ciphertext stream was rejected
    -- by 'Cha.nonce12' (for instance, fewer than 12 bytes were available)
  | DecryptKeyException !CE.CryptoError
    -- ^ 'decrypt' was given a key that 'Cha.initialize' rejected
  | MismatchedAuth
    -- ^ the trailing authentication tag could not be parsed, or did not
    -- match the tag computed over the decrypted stream
  deriving (Show, Typeable)

instance Exception ChaChaException

-- | Stream-encrypt with ChaCha20-Poly1305, authenticating no associated data.
--
-- The output stream is:
--
--   1. the nonce, exactly as given;
--   2. one ciphertext chunk per upstream chunk;
--   3. the Poly1305 authentication tag, emitted once upstream is exhausted.
--
-- This is the layout that 'decrypt' consumes.
--
-- Throws 'EncryptNonceException' if 'Cha.nonce12' rejects the nonce, and
-- 'EncryptKeyException' if 'Cha.initialize' rejects the key.
--
-- NOTE: ChaCha20-Poly1305 (RFC 8439) requires that a nonce is never reused
-- with the same key.
encrypt
  :: MonadThrow m
  => ByteString -- ^ nonce (12 random bytes)
  -> ByteString -- ^ symmetric key (32 bytes)
  -> ConduitM ByteString ByteString m ()
encrypt nonceBS key = do
  nonce <- cf EncryptNonceException $ Cha.nonce12 nonceBS
  state0 <- cf EncryptKeyException $ Cha.initialize key nonce
  -- Send the nonce in the clear so the receiver can initialise its state.
  yield nonceBS
  let loop state1 = do
        mbs <- await
        case mbs of
          -- End of input: emit the tag covering everything encrypted so far.
          Nothing -> yield $ BA.convert $ Cha.finalize state1
          Just bs -> do
            let (bs', state2) = Cha.encrypt bs state1
            yield bs'
            loop state2
  -- No associated data is appended, so close the AAD phase immediately.
  loop $ Cha.finalizeAAD state0

-- | Stream-decrypt data produced by 'encrypt' under the same key.
--
-- Reads the 12-byte nonce from the head of the stream. It then decrypts and
-- yields everything except the final 16 bytes. Once upstream is exhausted,
-- those 16 bytes are checked as the Poly1305 authentication tag.
--
-- Throws:
--
--   * 'DecryptNonceException' if 'Cha.nonce12' rejects the leading nonce;
--   * 'DecryptKeyException' if 'Cha.initialize' rejects the key;
--   * 'MismatchedAuth' if the trailing tag cannot be parsed (e.g. fewer than
--     16 bytes remained) or does not match.
--
-- WARNING: plaintext is yielded as it is decrypted, /before/ the tag is
-- verified. Downstream must not act on the output (or must be able to
-- discard it) until this conduit has completed without throwing.
decrypt
  :: MonadThrow m
  => ByteString -- ^ symmetric key (32 bytes)
  -> ConduitM ByteString ByteString m ()
decrypt key = do
  nonceBS <- CB.take 12
  nonce <- cf DecryptNonceException $ Cha.nonce12 $ BL.toStrict nonceBS
  state0 <- cf DecryptKeyException $ Cha.initialize key nonce
  let loop state1 = do
        ebs <- awaitExcept16 id
        case ebs of
          -- End of stream: the held-back bytes must be the tag.
          Left final ->
            case Poly1305.authTag final of
              -- NOTE(review): relies on the Eq instance of Poly1305.Auth;
              -- cryptonite implements it as a constant-time comparison,
              -- confirm for the pinned version.
              CE.CryptoPassed final' | Cha.finalize state1 == final' -> return ()
              _ -> throwM MismatchedAuth
          Right bs -> do
            let (bs', state2) = Cha.decrypt bs state1
            yield bs'
            loop state2
  -- No associated data was authenticated by 'encrypt'.
  loop $ Cha.finalizeAAD state0
  where
    -- Await input while always holding back the last 16 bytes seen, which
    -- may turn out to be the tag.
    --
    -- Returns @Right chunk@ with bytes that are safe to decrypt, and pushes
    -- the held-back 16 bytes back upstream as a leftover. At end of input it
    -- returns @Left rest@ with whatever was held back.
    --
    -- @front@ prepends bytes buffered by earlier iterations that were too
    -- short to split.
    awaitExcept16 front = do
      mbs <- await
      case mbs of
        Nothing -> return $ Left $ front B.empty
        Just bs -> do
          let bs' = front bs
          if B.length bs' > 16
            then do
              let (x, y) = B.splitAt (B.length bs' - 16) bs'
              assert (B.length y == 16) leftover y
              return $ Right x
            -- Not enough to release anything yet: keep buffering.
            else awaitExcept16 (B.append bs')