-- | ECDSA signature generation and verification.
--
-- /WARNING:/ Signature operations are vulnerable to timing attacks and may
-- leak the private key. Signature verification should be safe.
module Crypto.PubKey.ECC.ECDSA
    ( module Crypto.Types.PubKey.ECDSA
    , signWith
    , sign
    , verify
    ) where

import Control.Monad
import Data.Bits (shiftR)
import Data.ByteString (ByteString)
import qualified Data.ByteString as B

import Crypto.Number.Generate
import Crypto.Number.ModArithmetic (inverse)
import Crypto.Number.Serialize
import Crypto.PubKey.ECC.Prim
import Crypto.PubKey.HashDescr
import Crypto.Random
import Crypto.Types.PubKey.ECC
import Crypto.Types.PubKey.ECDSA

-- | Sign a message with the private key, using a caller-supplied nonce @k@.
--
-- With @G@ the curve generator, @n@ its order, @d@ the private scalar and
-- @z@ the truncated message hash, this computes
--
-- > r = x(k·G) mod n
-- > s = k⁻¹ · (z + r·d) mod n
--
-- It returns 'Nothing' in three cases: @k·G@ is the point at infinity, @k@
-- has no inverse modulo @n@, or either @r@ or @s@ is zero. The caller is
-- expected to retry with another @k@.
--
-- /WARNING:/ Vulnerable to timing attacks.
signWith :: Integer         -- ^ k random number
         -> PrivateKey      -- ^ private key
         -> HashFunction    -- ^ hash function
         -> ByteString      -- ^ message to sign
         -> Maybe Signature
signWith k (PrivateKey curve d) hash msg = do
    r <- case pointMul curve k g of
             PointO    -> Nothing
             Point x _ -> Just (x `mod` n)
    kInv <- inverse k n
    let s = (kInv * (z + r * d)) `mod` n
    -- a zero component would make the signature degenerate; reject it
    guard (r /= 0 && s /= 0)
    Just (Signature r s)
  where
    cc = common_curve curve
    g  = ecc_g cc
    n  = ecc_n cc
    z  = tHash hash msg n

-- | Sign a message with the private key, drawing the nonce @k@ from the
-- supplied random generator.
--
-- @k@ is drawn between 1 and @n - 1@ (where @n@ is the curve order) via
-- 'generateBetween'. If 'signWith' rejects that nonce, signing is retried
-- with a fresh nonce drawn from the advanced generator. The signature is
-- returned together with the advanced generator.
--
-- /WARNING:/ Vulnerable to timing attacks.
sign :: CPRG g => g -> PrivateKey -> HashFunction -> ByteString -> (Signature, g)
sign rng pk hash msg =
    maybe retry (\sig -> (sig, rng')) (signWith k pk hash msg)
  where
    -- the nonce was unusable: try again with the advanced generator
    retry     = sign rng' pk hash msg
    order     = ecc_n . common_curve $ private_curve pk
    (k, rng') = generateBetween rng 1 (order - 1)

-- | Verify a signature over a bytestring using the public key.
--
-- The result is 'False' in any of these cases:
--
-- * the public point is the point at infinity;
-- * @r@ or @s@ lies outside @[1, n-1]@;
-- * @s@ has no inverse modulo @n@;
-- * @u1·G + u2·Q@ is the point at infinity.
--
-- Otherwise the signature is accepted exactly when the x-coordinate of
-- @u1·G + u2·Q@, reduced modulo @n@, equals @r@. Here @w = s⁻¹ mod n@,
-- @u1 = z·w mod n@ and @u2 = r·w mod n@.
verify :: HashFunction -> PublicKey -> Signature -> ByteString -> Bool
verify _ (PublicKey _ PointO) _ _ = False
verify hash pk@(PublicKey curve q) (Signature r s) msg
    | not (inRange r && inRange s) = False
    | otherwise                    = recovered == Just r
  where
    cc = common_curve (public_curve pk)
    n  = ecc_n cc
    g  = ecc_g cc
    inRange v = v >= 1 && v < n
    -- x-coordinate (mod n) of u1·G + u2·Q, if it is a finite point
    recovered = do
        w <- inverse s n
        let z  = tHash hash msg n
            u1 = (z * w) `mod` n
            u2 = (r * w) `mod` n
            -- TODO: Use Shamir's trick
            p  = pointAdd curve (pointMul curve u1 g) (pointMul curve u2 q)
        case p of
            PointO     -> Nothing
            Point x1 _ -> Just (x1 `mod` n)

-- | Hash a message and truncate the digest to the bit length of the group
-- order @n@.
--
-- The digest is read as a big-endian integer with 'os2ip'. When the digest
-- has more bits than @n@, only its leftmost @bitLength n@ bits are kept
-- (FIPS 186-4 §6.4, SEC 1 §4.1.3).
--
-- The number of bits to drop comes from the digest's /byte length/, not its
-- numeric value. A digest with leading zero bits therefore still loses
-- exactly the right number of low bits. Bit lengths are computed exactly with
-- integer shifts rather than a floating-point logarithm.
tHash :: HashFunction -> ByteString -> Integer -> Integer
tHash hash m n
    | excess > 0 = e `shiftR` excess
    | otherwise  = e
  where
    digest = hash m
    e      = os2ip digest
    -- how many digest bits exceed the bit length of n
    excess = 8 * B.length digest - bitLength n
    -- exact number of significant bits of a non-negative integer (0 for 0)
    bitLength :: Integer -> Int
    bitLength = length . takeWhile (> 0) . iterate (`shiftR` 1)