{-# LANGUAGE CPP #-}
#ifndef MIN_VERSION_integer_gmp
#define MIN_VERSION_integer_gmp(a,b,c) 0
#endif
#if MIN_VERSION_integer_gmp(0,5,1)
{-# LANGUAGE MagicHash, UnboxedTuples, BangPatterns #-}
#endif
-- |
-- Module      : Crypto.Number.Serialize
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : Good
--
-- fast serialization primitives for integer
module Crypto.Number.Serialize
    ( i2osp
    , os2ip
    , i2ospOf
    , i2ospOf_
    , lengthBytes
    ) where

import Data.ByteString (ByteString)
import qualified Data.ByteString.Internal as B
import qualified Data.ByteString as B
import Foreign.Ptr

#if MIN_VERSION_integer_gmp(0,5,1)
#if __GLASGOW_HASKELL__ >= 710
import Control.Monad (void)
#endif
import GHC.Integer.GMP.Internals
import GHC.Base
import GHC.Ptr
import System.IO.Unsafe
import Foreign.ForeignPtr
#else
import Foreign.Storable
import Data.Bits
#endif

#if !MIN_VERSION_integer_gmp(0,5,1)
-- | Split an integer into everything above its lowest byte and that lowest
-- byte, i.e. the quotient and remainder of a division by 256, computed with
-- a shift and a mask instead of an actual division.
divMod256 :: Integer -> (Integer, Integer)
divMod256 n = (upper, lowest)
  where
    upper  = shiftR n 8
    lowest = n .&. 255
{-# INLINE divMod256 #-}
#endif

-- | Convert a byte string into a non-negative integer (OS2IP).
--
-- The bytes are read most-significant first, so the first byte of the input
-- carries the highest weight. In the portable implementation an empty input
-- yields 0.
os2ip :: ByteString -> Integer
#if MIN_VERSION_integer_gmp(0,5,1)
-- GMP path: hand the raw bytes to importIntegerFromAddr in one call.
-- unsafePerformIO only wraps a read of the memory owned by the input byte
-- string; withForeignPtr keeps that memory alive for the duration of the read.
os2ip bs = unsafePerformIO $ withForeignPtr fptr $ \ptr ->
    -- skip the offset of the slice inside the underlying buffer
    let !(Ptr ad) = (ptr `plusPtr` ofs)
#if __GLASGOW_HASKELL__ >= 710
     in importIntegerFromAddr ad (int2Word# n) 1#
#else
     -- older integer-gmp exposes the primitive as a raw state-passing function
     in IO $ \s -> importIntegerFromAddr ad (int2Word# n) 1# s
#endif
  where !(fptr, ofs, !(I# n)) = B.toForeignPtr bs
{-# NOINLINE os2ip #-}
#else
-- Portable path: shift the accumulator by one byte and merge in the next byte.
os2ip = B.foldl' (\a b -> (256 * a) .|. (fromIntegral b)) 0
{-# INLINE os2ip #-}
#endif

-- | Convert a non-negative integer into a byte string (I2OSP).
--
-- The output is written most-significant byte first and uses as few bytes as
-- possible. Zero is encoded as a single zero byte in both implementations, so
-- the length of the result always agrees with 'lengthBytes'.
--
-- The portable implementation raises an error for a negative input.
i2osp :: Integer -> ByteString
#if MIN_VERSION_integer_gmp(0,5,1)
-- Zero is special-cased: it is answered directly instead of going through
-- the export primitive.
i2osp 0 = B.singleton 0
i2osp m = B.unsafeCreate (I# (word2Int# sz)) fillPtr
  where -- number of base-256 digits, i.e. the number of output bytes
        !sz = sizeInBaseInteger m 256#
#if __GLASGOW_HASKELL__ >= 710
        -- the export reports how many bytes it wrote; that count is not needed
        fillPtr (Ptr srcAddr) = void $ exportIntegerToAddr m srcAddr 1#
#else
        fillPtr (Ptr srcAddr) = IO $ \s -> case exportIntegerToAddr m srcAddr 1# s of
                                                (# s2, _ #) -> (# s2, () #)
#endif
{-# NOINLINE i2osp #-}
#else
i2osp m
    | m < 0     = error "i2osp: cannot convert a negative integer to a bytestring"
    -- unfoldr would stop immediately on 0 and produce an empty string; return
    -- the same single zero byte as the GMP implementation instead
    | m == 0    = B.singleton 0
    -- bytes are produced least-significant first, hence the final reverse
    | otherwise = B.reverse $ B.unfoldr fdivMod256 m
    where fdivMod256 0 = Nothing
          fdivMod256 n = Just (fromIntegral a,b) where (b,a) = divMod256 n
#endif


-- | Just like 'i2osp', but takes an extra parameter for the output size.
--
-- If the number is too big to fit in @len@ bytes, 'Nothing' is returned;
-- otherwise the encoding is left-padded with zero bytes up to the @len@
-- bytes requested.
--
-- FIXME: use unsafeCreate to fill the bytestring
i2ospOf :: Int -> Integer -> Maybe ByteString
#if MIN_VERSION_integer_gmp(0,5,1)
i2ospOf len m
    -- the size is known up front, so the padded writer can be used directly
    | sz <= len = Just $ i2ospOf_ len m
    | otherwise = Nothing
  where !sz = I# (word2Int# (sizeInBaseInteger m 256#))
#else
i2ospOf len m
    -- shorter than requested: prepend the missing zero bytes
    | lenbytes < len  = Just $ B.replicate (len - lenbytes) 0 `B.append` bytes
    | lenbytes == len = Just bytes
    | otherwise       = Nothing
  where lenbytes = B.length bytes
        bytes    = i2osp m
#endif

-- | Just like 'i2ospOf', except that no failure is expected: the caller
-- promises that the integer fits in the number of output bytes requested.
--
-- This is the case, for example, right after reducing the number modulo a
-- value of the target size (such as the RSA modulus n).
--
-- A request for zero (or fewer) output bytes raises an error in both
-- implementations. When the integer does not fit, the GMP implementation
-- raises an error, while the portable implementation keeps only the @len@
-- least-significant bytes.
i2ospOf_ :: Int -> Integer -> ByteString
#if MIN_VERSION_integer_gmp(0,5,1)
i2ospOf_ len m = unsafePerformIO $ B.create len fillPtr
  where !sz = (sizeInBaseInteger m 256#)
        isz = I# (word2Int# sz)
        fillPtr ptr
            | len < isz  = error "cannot compute i2ospOf_ with integer larger than output bytes"
            -- Exact fit: export straight into the buffer. Zero is excluded
            -- because the export is documented to write no bytes for 0 (which
            -- is also why i2osp special-cases it); taking this path would
            -- leave the freshly allocated byte uninitialised.
            | len == isz && m /= 0 =
                let !(Ptr srcAddr) = ptr in
#if __GLASGOW_HASKELL__ >= 710
                void (exportIntegerToAddr m srcAddr 1#)
#else
                IO $ \s -> case exportIntegerToAddr m srcAddr 1# s of
                                (# s2, _ #) -> (# s2, () #)
#endif
            -- Padding needed (or m is 0): zero the whole buffer, then export
            -- the number right-aligned, z bytes after the start.
            | otherwise = do
                let z = len-isz
                _ <- B.memset ptr 0 (fromIntegral len)
                let !(Ptr addr) = ptr `plusPtr` z
#if __GLASGOW_HASKELL__ >= 710
                void (exportIntegerToAddr m addr 1#)
#else
                IO $ \s -> case exportIntegerToAddr m addr 1# s of
                                (# s2, _ #) -> (# s2, () #)
#endif
{-# NOINLINE i2ospOf_ #-}
#else
i2ospOf_ len m
    -- The fill loop starts at the last byte and only stops once it reaches
    -- the first one; with an empty buffer it would start before the buffer
    -- and never meet its stop condition. Fail like the GMP implementation.
    | len <= 0  = error "cannot compute i2ospOf_ with integer larger than output bytes"
    | otherwise = B.unsafeCreate len fillPtr
    where fillPtr srcPtr = loop m (srcPtr `plusPtr` (len-1))
            where -- write the bytes of the number from the end of the buffer
                  -- towards its start, least-significant byte first
                  loop n ptr = do
                      let (nn,a) = divMod256 n
                      poke ptr (fromIntegral a)
                      if ptr == srcPtr
                          then return ()
                          else (if nn == 0 then fillerLoop else loop nn) (ptr `plusPtr` (-1))
                  -- the number is exhausted: pad the remaining leading bytes
                  fillerLoop ptr = do
                      poke ptr 0
                      if ptr == srcPtr
                          then return ()
                          else fillerLoop (ptr `plusPtr` (-1))
{-# INLINE i2ospOf_ #-}
#endif

-- | Return the number of bytes needed to store an integer with 'i2osp'.
--
-- The result is at least 1. Without the GMP primitives (for instance with
-- integer-simple) this function is really slow: it strips the number one
-- byte at a time.
lengthBytes :: Integer -> Int
#if MIN_VERSION_integer_gmp(0,5,1)
-- number of digits of n in base 256
lengthBytes n = I# (word2Int# (sizeInBaseInteger n 256#))
#else
lengthBytes n
    | n < 256   = 1
    | otherwise = 1 + lengthBytes (n `shiftR` 8)
#endif