-- |A small selection of utilities that might be of use to others working with bytestring/number combinations.
module Crypto.Util where

import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L
import Data.ByteString.Unsafe (unsafeIndex, unsafeUseAsCStringLen)
import Data.Bits (shiftL, shiftR)
import Data.Bits (xor, setBit, shiftR, shiftL)
import Control.Exception (Exception, throw)
import Data.Tagged
import System.IO.Unsafe
import Foreign.C.Types
import Foreign.Ptr

-- |@incBS bs@ computes the value @i2bs (8 * B.length bs) (bs2i bs + 1)@,
-- i.e. it increments @bs@ viewed as a big-endian unsigned number.
--
-- The increment wraps around: a string consisting only of @0xFF@ bytes
-- becomes the same number of zero bytes, and the empty string maps to
-- itself.  Runs in time linear in the length of @bs@.
incBS :: B.ByteString -> B.ByteString
incBS bs
  | B.null prefix = B.replicate carried 0
  | otherwise =
      B.concat [ B.init prefix
               , B.singleton (B.last prefix + 1)
               , B.replicate carried 0
               ]
  where
    -- Split off the run of trailing 0xFF bytes: each of them rolls over to
    -- 0x00, and the carry lands on the last byte of 'prefix', which is
    -- guaranteed to be below 0xFF, so adding 1 to it cannot overflow.
    (prefix, trailing) = B.spanEnd (== 0xFF) bs
    carried = B.length trailing
{-# INLINE incBS #-}


-- |@i2bs bitLen i@ converts @i@ to a 'ByteString' of @bitLen@ bits (must be a multiple of 8).
--
-- The output is big-endian: the first byte carries the most significant
-- 8 of the requested bits and the last byte carries bits 0..7 of @i@.
-- Any bits of @i@ above @bitLen@ are dropped by the byte truncation.
i2bs :: Int -> Integer -> B.ByteString
i2bs bitLen n = B.unfoldr emit (bitLen - 8)
  where
    -- Produce the byte that starts at bit offset @off@, then step 8 bits
    -- towards the least significant end; stop once the offset is negative.
    emit off
      | off < 0   = Nothing
      | otherwise = Just (fromIntegral (n `shiftR` off), off - 8)
{-# INLINE i2bs #-}

-- |@i2bs_unsized i@ converts @i@ to a 'ByteString' of sufficient bytes to express the integer.
-- The integer must be non-negative and a zero will be encoded in one byte.
--
-- The encoding is big-endian with no leading zero bytes (except for the
-- single byte used for zero).  A negative argument yields the empty
-- 'ByteString'.
i2bs_unsized :: Integer -> B.ByteString
i2bs_unsized 0 = B.singleton 0
i2bs_unsized n = B.pack (reverse (littleEndian n))
  where
    -- Peel off the least significant byte until nothing positive remains.
    littleEndian k
      | k <= 0    = []
      | otherwise = fromIntegral k : littleEndian (k `shiftR` 8)
{-# INLINE i2bs_unsized #-}

-- | Useful utility to extract the result of a generator operation
-- and translate error results to exceptions.
--
-- @Right a@ yields @a@; @Left e@ becomes an imprecise exception raised
-- with 'throw' once the result is forced.
throwLeft :: Exception e => Either e a -> a
throwLeft = either throw id

-- |Obtain a tagged value for a particular instantiated type.
--
-- The second argument is never inspected; it only fixes the type @a@,
-- which selects the 'Tagged' value to unwrap.
for :: Tagged a b -> a -> b
for = const . unTagged

-- |Infix 'for' operator.
--
-- @tagged .::. witness@ unwraps @tagged@; @witness@ is only used for its
-- type and is never evaluated.
(.::.) :: Tagged a b -> a -> b
tagged .::. witness = for tagged witness

-- | Checks two bytestrings for equality without breaches for
-- timing attacks.
--
-- Semantically, @constTimeEq = (==)@.  However, @x == y@ takes less
-- time when the first byte is different than when the first byte
-- is equal.  This side channel allows an attacker to mount a
-- timing attack.  On the other hand, @constTimeEq@ always takes the
-- same time regardless of the bytestrings' contents, unless they are
-- of different sizes (the lengths are compared first and are not
-- treated as secret).
--
-- You should always use @constTimeEq@ when comparing secrets,
-- otherwise you may leave a significant security hole
-- (cf. <http://codahale.com/a-lesson-in-timing-attacks/>).
constTimeEq :: B.ByteString -> B.ByteString -> Bool
constTimeEq s1 s2
  | B.length s1 /= B.length s2 = False
  | otherwise =
      -- NOTE(review): unsafePerformIO is used on the assumption that
      -- c_constTimeEq only reads the two buffers (kept alive for the
      -- duration of unsafeUseAsCStringLen) and that its result depends only
      -- on their contents, which makes the call referentially transparent.
      -- Confirm against the C implementation.
      unsafePerformIO $
        unsafeUseAsCStringLen s1 $ \(p1, len) ->
        unsafeUseAsCStringLen s2 $ \(p2, _) ->
          (== 0) `fmap` compareChunks p1 p2 len 0
  where
    -- Largest byte count representable in the C routine's CInt length
    -- argument.  Passing @fromIntegral len@ directly would wrap for buffers
    -- of 2 GiB or more on 64-bit platforms, so only a prefix (or nothing)
    -- would be compared and unequal strings could be reported as equal.
    maxChunk :: Int
    maxChunk = fromIntegral (maxBound :: CInt)

    -- Compare the buffers piece by piece, always visiting every piece, and
    -- keep a non-zero result if any piece differed (0 means "equal").
    compareChunks p1 p2 remaining acc
      | remaining <= 0 = return acc
      | otherwise = do
          let n = min remaining maxChunk
          r <- c_constTimeEq p1 p2 (fromIntegral n)
          compareChunks (p1 `plusPtr` n) (p2 `plusPtr` n) (remaining - n)
                        (if r == 0 then acc else r)

-- | Raw C comparison routine backing 'constTimeEq'.
--
-- The symbol comes from C code that is not visible here.  'constTimeEq'
-- interprets a return value of 0 as "equal".  Presumably it compares the
-- first @n@ bytes of both buffers without an early exit (TODO: confirm
-- against the C source).
--
-- Declared @unsafe@: the call must not call back into Haskell.
foreign import ccall unsafe
   c_constTimeEq :: Ptr CChar -> Ptr CChar -> CInt -> IO CInt

-- |Helper function to convert bytestrings to integers
--
-- The bytes are read as an unsigned big-endian number; the empty string
-- maps to 0.
bs2i :: B.ByteString -> Integer
bs2i = B.foldl' step 0
  where
    -- Shift the accumulator up by one byte and add the next byte.
    step acc byte = acc * 256 + fromIntegral byte
{-# INLINE bs2i #-}

-- |zipWith xor + Pack
--
-- XORs two strict bytestrings byte by byte; the result is as long as the
-- shorter input.  Through bytestring's rewrite rules, this should be
-- optimised at compile time to use the library's 'B.zipWith'' function
-- rather than building an intermediate list.
zwp' :: B.ByteString -> B.ByteString -> B.ByteString
zwp' a b = B.pack (B.zipWith xor a b)
{-# INLINE zwp' #-}

-- |zipWith xor + Pack
--
-- XORs two lazy bytestrings byte by byte; the result is as long as the
-- shorter input.
--
-- This is written intentionally to take advantage of the bytestring
-- library's 'B.zipWith'' rewrite rule, at the extra cost of the resulting
-- lazy bytestring being more fragmented than either of the two inputs:
-- a new chunk boundary is introduced wherever either input has one.
zwp :: L.ByteString -> L.ByteString -> L.ByteString
zwp x y = L.fromChunks (xorChunks (L.toChunks x) (L.toChunks y))
  where
    -- XOR the overlapping prefix of the two head chunks, then push any
    -- unconsumed remainder back onto its list and continue.
    xorChunks (c:cs) (d:ds) =
      let n        = min (B.length c) (B.length d)
          (c1, c2) = B.splitAt n c
          (d1, d2) = B.splitAt n d
          pushBack rest more
            | B.null rest = more
            | otherwise   = rest : more
      in zwp' c1 d1 : xorChunks (pushBack c2 cs) (pushBack d2 ds)
    xorChunks _ _ = []
{-# INLINEABLE zwp #-}

-- | Gather the first @i@ bytes from a list of bytestrings.
--
-- Returns the leading chunks that together hold @i@ bytes, truncating the
-- last one as needed.  If the chunks hold fewer than @i@ bytes in total,
-- all of them are returned.  A non-positive count yields @[]@.
collect :: Int -> [B.ByteString] -> [B.ByteString]
collect i _ | i <= 0 = []
collect _ [] = []
collect i (b:bs)
  | len < i   = b : collect (i - len) bs
  | otherwise = [B.take i b]
  where
    len = B.length b
{-# INLINE collect #-}