{-# LANGUAGE MultiParamTypeClasses, FlexibleInstances #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Crypto.Nettle.ChaChaPoly1305
-- Copyright   :  (c) 2013 Stefan Bühler
-- License     :  MIT-style (see the file COPYING)
-- 
-- Maintainer  :  stbuehler@web.de
-- Stability   :  experimental
-- Portability :  portable
--
-- This module exports the ChaCha-Poly1305 AEAD cipher supported by nettle:
--   <http://www.lysator.liu.se/~nisse/nettle/>
--
-- Both ChaCha (the underlying cipher) and Poly1305 (the keyed hash) were
-- designed by D. J. Bernstein.
--
-----------------------------------------------------------------------------

module Crypto.Nettle.ChaChaPoly1305 (
	-- * ChaCha-Poly1305
	--
	-- No streaming interface is provided, as it would violate the spirit of
	-- AEAD being simple to use: decrypted data must only be used after it
	-- has been successfully verified.

	  chaChaPoly1305Encrypt
	, chaChaPoly1305Decrypt
	) where

import qualified Data.ByteString as B
import qualified Data.ByteString.Internal as B
import Data.SecureMem

import Crypto.Nettle.Ciphers.ForeignImports
import Nettle.Utils

{-|
Encrypt plain text and create a verification tag covering both the encrypted
text and some additional data.

@key@ must be exactly 32 bytes and @nonce@ exactly 12 bytes. Any other length
makes forcing the result fail with 'error' (\"Invalid key length\" or
\"Invalid nonce length\"; the key is checked first).

@key@ and @nonce@ must not be reused together.

The returned ciphertext has the same length as @plain@. The returned tag is
16 bytes long, but may be shortened for verification (losing security).
-}
chaChaPoly1305Encrypt
  :: B.ByteString                 -- ^ @key@ (must be 32 bytes)
  -> B.ByteString                 -- ^ @nonce@ (must be 12 bytes)
  -> B.ByteString                 -- ^ @aad@ additional data to be verified
  -> B.ByteString                 -- ^ @plain@ data to encrypt
  -> (B.ByteString, B.ByteString) -- ^ returns (@cipher@, @tag@) ciphertext and verification tag
chaChaPoly1305Encrypt key nonce aad plain = unsafeDupablePerformIO $ do
  state     <- allocateSecureMem c_chacha_poly1305_ctx_size
  outTag    <- blankBuffer 16
  outCipher <- blankBuffer (B.length plain)
  -- The two output ByteStrings are written in place through their raw
  -- pointers below; they are only returned once all nettle calls are done.
  withByteStringPtr plain $ \plainLen plainPtr ->
    withByteStringPtr aad $ \aadLen aadPtr ->
    withByteStringPtr outCipher $ \_ cipherPtr ->
    withByteStringPtr outTag $ \_ tagPtr ->
    withSecureMemPtr state $ \statePtr ->
    withSecureMemPtrSz (toSecureMem key) $ \keyLen keyPtr ->
      if keyLen /= 32
        then error "Invalid key length"
        else withSecureMemPtrSz (toSecureMem nonce) $ \nonceLen noncePtr ->
          if nonceLen /= 12
            then error "Invalid nonce length"
            else do
              c_chacha_poly1305_set_key statePtr keyPtr
              c_chacha_poly1305_set_nonce statePtr noncePtr
              -- The additional data is fed in before the message itself.
              c_chacha_poly1305_update statePtr aadLen aadPtr
              c_chacha_poly1305_encrypt statePtr plainLen cipherPtr plainPtr
              c_chacha_poly1305_digest statePtr 16 tagPtr
  return (outCipher, outTag)
  where
    -- Allocate a ByteString of the given size without initialising it;
    -- its contents are filled in later through its pointer.
    blankBuffer n = B.create n (\_ -> return ())

{-|
Decrypt cipher text and verify a (possibly shortened) tag covering both the
encrypted text and some additional data.

@verifytag@ may be a prefix of the 16-byte tag produced by
'chaChaPoly1305Encrypt'; shorter tags give weaker authentication. It must
contain at least one byte, and an empty tag is always rejected. Without that
check, an empty tag would trivially "match" any ciphertext. Tags longer than
16 bytes never match.

The tag is compared using 'SecureMem' equality, which the securemem package
documents as constant-time. This avoids leaking the position of the first
mismatching byte.

@key@ must be exactly 32 bytes and @nonce@ exactly 12 bytes. Any other length
makes forcing the result fail with 'error' (\"Invalid key length\" or
\"Invalid nonce length\"; the key is checked first). They must be the same
@key@ and @nonce@ that were used for encryption.

Returns @Just plain@ when the tag verifies and 'Nothing' otherwise.
-}
chaChaPoly1305Decrypt
  :: B.ByteString       -- ^ @key@ (must be 32 bytes)
  -> B.ByteString       -- ^ @nonce@ (must be 12 bytes)
  -> B.ByteString       -- ^ @aad@ additional data to be verified
  -> B.ByteString       -- ^ @cipher@ data to decrypt
  -> B.ByteString       -- ^ @verifytag@ (possibly shortened, non-empty) tag to check
  -> Maybe B.ByteString -- ^ decrypted data, or 'Nothing' if verification fails
chaChaPoly1305Decrypt key nonce aad cipher verifytag = unsafeDupablePerformIO $ do
  ctx   <- allocateSecureMem c_chacha_poly1305_ctx_size
  tag   <- B.create 16 (\_ -> return ())
  plain <- B.create (B.length cipher) (\_ -> return ())
  -- @tag@ and @plain@ are written in place through their raw pointers below.
  withByteStringPtr cipher $ \csize cptr ->
    withByteStringPtr aad $ \aadsize aadptr ->
    withByteStringPtr plain $ \_ plainptr ->
    withByteStringPtr tag $ \_ tagptr ->
    withSecureMemPtr ctx $ \ctxptr ->
    withSecureMemPtrSz (toSecureMem key) $ \ksize kptr ->
      if ksize /= 32
        then error "Invalid key length"
        else withSecureMemPtrSz (toSecureMem nonce) $ \nsize nptr ->
          if nsize /= 12
            then error "Invalid nonce length"
            else do
              c_chacha_poly1305_set_key ctxptr kptr
              c_chacha_poly1305_set_nonce ctxptr nptr
              c_chacha_poly1305_update ctxptr aadsize aadptr
              c_chacha_poly1305_decrypt ctxptr csize plainptr cptr
              c_chacha_poly1305_digest ctxptr 16 tagptr
  return $! if tagMatches tag then Just plain else Nothing
  where
    -- Compare the caller's tag against the same-length prefix of the computed
    -- tag. An empty tag is rejected outright, because comparing zero bytes
    -- would accept any input. A SecureMem comparison is used instead of
    -- ByteString (==), which stops at the first differing byte.
    tagMatches :: B.ByteString -> Bool
    tagMatches computed =
      not (B.null verifytag)
        && toSecureMem (B.take (B.length verifytag) computed) == toSecureMem verifytag