{-# LANGUAGE CPP #-}
module Crypto.PubKey.ECIES.Conduit
  ( encrypt
  , decrypt
  ) where

import           Control.Monad.Catch                  (MonadThrow, throwM)
import           Control.Monad.Trans.Class            (lift)
import qualified Crypto.Cipher.ChaCha                 as ChaCha
import qualified Crypto.Cipher.ChaChaPoly1305.Conduit as ChaCha
import qualified Crypto.ECC                           as ECC
import qualified Crypto.Error                         as CE
import           Crypto.Hash                          (SHA512 (..), hashWith)
import           Crypto.PubKey.ECIES                  (deriveDecrypt,
                                                       deriveEncrypt)
import           Crypto.Random                        (MonadRandom,
                                                       drgNewTest,
                                                       withDRG)
import qualified Data.ByteArray                       as BA
import           Data.ByteString                      (ByteString)
import qualified Data.ByteString                      as B
import qualified Data.ByteString.Lazy                 as BL
import           Data.Conduit                         (ConduitM, yield)
import qualified Data.Conduit.Binary                  as CB
import           Data.Proxy                           (Proxy (..))
import           System.IO.Unsafe                     (unsafePerformIO)

-- | Derive the symmetric @(nonce, key)@ pair used by the ChaCha20-Poly1305
-- stream from an ECDH shared secret.
--
-- The shared secret is hashed with SHA-512. The first 40 bytes of the digest
-- seed a 'ChaCha.initializeSimple' generator. The generator's output supplies,
-- in order, a 12-byte nonce and then a 32-byte key.
getNonceKey :: ECC.SharedSecret -> (ByteString, ByteString)
getNonceKey shared =
  let state1 =
        ChaCha.initializeSimple $ B.take 40 $ BA.convert $ hashWith SHA512 shared
      (nonce, state2) = ChaCha.generateSimple state1 12
      (key, _) = ChaCha.generateSimple state2 32
   in (nonce, key)

-- | The elliptic curve (P-256, cryptonite's 'ECC.Curve_P256R1') used for every
-- ECIES operation in this module.
type Curve = ECC.Curve_P256R1

-- | Value-level witness for 'Curve'. It is passed to the cryptonite ECC and
-- ECIES functions to select the curve at the type level.
proxy :: Proxy Curve
proxy = Proxy

-- | Length in bytes of a 'Curve' point as produced by 'ECC.encodePoint'.
--
-- 'decrypt' reads exactly this many bytes from the front of its input to
-- recover the ephemeral public point that 'encrypt' yields first.
--
-- The length is measured by encoding a throw-away public key. That key comes
-- from a fixed-seed deterministic generator ('drgNewTest' run via 'withDRG').
-- This makes the binding an ordinary pure value: no 'unsafePerformIO', no
-- system entropy consumed, and the same sample key every time. The sample key
-- is used for nothing except measuring the encoded size.
pointBinarySize :: Int
pointBinarySize = B.length (ECC.encodePoint proxy samplePoint :: ByteString)
  where
    samplePoint :: ECC.Point Curve
    samplePoint =
      ECC.keypairGetPublic . fst $
        withDRG (drgNewTest (0, 0, 0, 0, 1)) (ECC.curveGenerateKeyPair proxy)

-- | Turn a 'CE.CryptoFailable' result into an effect in any 'MonadThrow'.
-- A 'CE.CryptoPassed' value is returned, and a 'CE.CryptoFailed' error is
-- rethrown with 'throwM'.
throwOnFail :: MonadThrow m => CE.CryptoFailable a -> m a
throwOnFail (CE.CryptoPassed a) = pure a
throwOnFail (CE.CryptoFailed e) = throwM e


-- | Encrypt a stream of 'ByteString's for the holder of the secret scalar
-- matching @point@.
--
-- An ephemeral ECIES key agreement is performed against @point@ via
-- 'deriveEncrypt'. The output stream is:
--
--   1. the ephemeral public point, encoded with 'ECC.encodePoint';
--   2. the output of 'ChaCha.encrypt', driven by the nonce and key that
--      'getNonceKey' derives from the shared secret.
--
-- If the key agreement reports a failure, the resulting 'CE.CryptoError' is
-- thrown via 'MonadThrow'.
encrypt
  :: (MonadThrow m, MonadRandom m)
  => ECC.Point Curve
  -> ConduitM ByteString ByteString m ()
encrypt point = do
  (point', shared) <- lift (deriveEncryptCompat proxy point) >>= throwOnFail
  let (nonce, key) = getNonceKey shared
  -- Header: the ephemeral public point, which the recipient needs in order
  -- to recompute the shared secret.
  yield $ ECC.encodePoint proxy point'
  ChaCha.encrypt nonce key
  where
#if MIN_VERSION_cryptonite(0,23,999)
    -- Newer cryptonite: deriveEncrypt already returns a CryptoFailable.
    deriveEncryptCompat prx p = deriveEncrypt prx p
#else
    -- Older cryptonite: deriveEncrypt cannot fail, so wrap its result.
    deriveEncryptCompat prx p = CE.CryptoPassed <$> deriveEncrypt prx p
#endif

-- | Decrypt a stream produced by 'encrypt', using the secret scalar that
-- corresponds to the public point the stream was encrypted for.
--
-- The first 'pointBinarySize' bytes are decoded as the sender's ephemeral
-- point. The shared secret is recomputed from that point and @scalar@, and
-- the remainder of the stream is handed to 'ChaCha.decrypt' with the derived
-- key. The derived nonce is not passed here; presumably 'ChaCha.decrypt'
-- recovers it from the stream written by 'ChaCha.encrypt' (TODO confirm).
--
-- A 'CE.CryptoError' is thrown via 'MonadThrow' in two cases: if the header
-- bytes do not decode as a point, or if the key agreement reports a failure.
decrypt
  :: (MonadThrow m)
  => ECC.Scalar Curve
  -> ConduitM ByteString ByteString m ()
decrypt scalar = do
  -- CB.take yields a lazy ByteString; decodePoint wants a strict one.
  pointBS <- BL.toStrict <$> CB.take pointBinarySize
  point <- throwOnFail (ECC.decodePoint proxy pointBS)
  shared <- throwOnFail (deriveDecryptCompat proxy point scalar)
  let (_nonce, key) = getNonceKey shared
  ChaCha.decrypt key
  where
#if MIN_VERSION_cryptonite(0,23,999)
    -- Newer cryptonite: deriveDecrypt already returns a CryptoFailable.
    deriveDecryptCompat prx p s = deriveDecrypt prx p s
#else
    -- Older cryptonite: deriveDecrypt cannot fail, so wrap its result.
    deriveDecryptCompat prx p s = CE.CryptoPassed (deriveDecrypt prx p s)
#endif