{-# LANGUAGE OverloadedStrings #-}
module Network.TLS.Handshake.State13
( getTxState
, getRxState
, setTxState
, setRxState
, clearTxState
, clearRxState
, setHelloParameters13
, transcriptHash
, wrapAsMessageHash13
, PendingAction(..)
, setPendingActions
, popPendingAction
) where
import Control.Concurrent.MVar
import Control.Monad.State
import qualified Data.ByteString as B
import Data.IORef
import Data.List (foldl')
import Network.TLS.Cipher
import Network.TLS.Compression
import Network.TLS.Context.Internal
import Network.TLS.Crypto
import Network.TLS.Handshake.State
import Network.TLS.Imports
import Network.TLS.KeySchedule (hkdfExpandLabel)
import Network.TLS.Record.State
import Network.TLS.Struct
import Network.TLS.Util
-- | Hash, cipher and secret of the transmit-side record state
-- ('ctxTxState'), as reported by 'getXState'.
getTxState :: Context -> IO (Hash, Cipher, ByteString)
getTxState = flip getXState ctxTxState
-- | Hash, cipher and secret of the receive-side record state
-- ('ctxRxState'), as reported by 'getXState'.
getRxState :: Context -> IO (Hash, Cipher, ByteString)
getRxState = flip getXState ctxRxState
-- | Read the record state held in the 'MVar' selected by @func@ and return
-- the hash of its cipher, the cipher itself, and the secret stored in the
-- 'cstMacSecret' field of its crypt state.
--
-- If no cipher is installed (for instance after 'clearXState'), forcing the
-- returned hash or cipher raises an error labelled with this function's
-- name. The original irrefutable pattern failed without saying where.
getXState :: Context
          -> (Context -> MVar RecordState)
          -> IO (Hash, Cipher, ByteString)
getXState ctx func = do
    tx <- readMVar (func ctx)
    -- Deliberately lazy, like the original pattern binding: callers that
    -- ignore the cipher must not fail just because none is installed.
    let usedCipher = fromJust "getXState: cipher" (stCipher tx)
        usedHash   = cipherHash usedCipher
        secret     = cstMacSecret (stCryptState tx)
    return (usedHash, usedCipher, secret)
-- | Install a new transmit-side record state ('ctxTxState') whose bulk
-- cipher is initialised for 'BulkEncrypt'; see 'setXState'.
setTxState :: Context -> Hash -> Cipher -> ByteString -> IO ()
setTxState ctx h cipher secret =
    setXState ctxTxState BulkEncrypt ctx h cipher secret
-- | Install a new receive-side record state ('ctxRxState') whose bulk
-- cipher is initialised for 'BulkDecrypt'; see 'setXState'.
setRxState :: Context -> Hash -> Cipher -> ByteString -> IO ()
setRxState ctx h cipher secret =
    setXState ctxRxState BulkDecrypt ctx h cipher secret
-- | Replace the record state in the 'MVar' selected by @func@ with a fresh
-- one derived from @secret@.
--
-- The key and IV are obtained with 'hkdfExpandLabel' over hash @h@, using
-- the labels @\"key\"@ and @\"iv\"@ and an empty context. The IV is never
-- shorter than 8 bytes. The bulk state is initialised for @encOrDec@, the
-- sequence number restarts at 0, and compression is 'nullCompression'.
-- @secret@ itself is kept in 'cstMacSecret', which is where 'getXState'
-- reads it back. The previous contents of the 'MVar' are discarded.
setXState :: (Context -> MVar RecordState) -> BulkDirection
          -> Context -> Hash -> Cipher -> ByteString
          -> IO ()
setXState func encOrDec ctx h cipher secret =
    modifyMVar_ (func ctx) (const (return freshState))
  where
    bulkAlg  = cipherBulk cipher
    keyLen   = bulkKeySize bulkAlg
    -- lower bound of 8 bytes on the IV length
    ivLen    = max 8 (bulkIVSize bulkAlg + bulkExplicitIV bulkAlg)
    writeKey = hkdfExpandLabel h secret "key" "" keyLen
    writeIV  = hkdfExpandLabel h secret "iv" "" ivLen
    crypt    = CryptState
        { cstKey       = bulkInit bulkAlg encOrDec writeKey
        , cstIV        = writeIV
        , cstMacSecret = secret
        }
    freshState = RecordState
        { stCryptState  = crypt
        , stMacState    = MacState { msSequence = 0 }
        , stCipher      = Just cipher
        , stCompression = nullCompression
        }
-- | Drop the cipher from the transmit-side record state ('ctxTxState').
clearTxState :: Context -> IO ()
clearTxState ctx = clearXState ctxTxState ctx
-- | Drop the cipher from the receive-side record state ('ctxRxState').
clearRxState :: Context -> IO ()
clearRxState ctx = clearXState ctxRxState ctx
-- | Set 'stCipher' to 'Nothing' in the record state held by the 'MVar'
-- selected by @func@. Every other field is left unchanged.
clearXState :: (Context -> MVar RecordState) -> Context -> IO ()
clearXState func ctx = modifyMVar_ (func ctx) (return . withoutCipher)
  where
    withoutCipher st = st { stCipher = Nothing }
-- | Record the cipher chosen in a TLS 1.3 hello.
--
-- On the first call, when no cipher is pending yet:
--
-- * the cipher becomes the pending cipher;
-- * pending compression is reset to 'nullCompression';
-- * the buffered handshake messages are hashed into a fresh hash context
--   for the cipher's hash algorithm.
--
-- On any later call, such as after a hello retry, the cipher must equal the
-- pending one. Otherwise an 'IllegalParameter' protocol error is returned
-- as 'Left'.
setHelloParameters13 :: Cipher -> HandshakeM (Either TLSError ())
setHelloParameters13 cipher = do
    hst <- get
    case hstPendingCipher hst of
        Nothing -> do
            put hst
                { hstPendingCipher      = Just cipher
                , hstPendingCompression = nullCompression
                , hstHandshakeDigest    = updateDigest $ hstHandshakeDigest hst
                }
            return $ Right ()
        Just oldcipher
            | cipher == oldcipher -> return $ Right ()
            | otherwise ->
                return $ Left $ Error_Protocol
                    ("TLS 1.3 cipher changed after hello retry", True, IllegalParameter)
  where
    hashAlg = cipherHash cipher
    -- The list is reversed before hashing, presumably because messages are
    -- accumulated newest-first. foldl' keeps a thunk chain of hashUpdate
    -- applications from building up.
    updateDigest (HandshakeMessages bytes) =
        HandshakeDigestContext $ foldl' hashUpdate (hashInit hashAlg) $ reverse bytes
    updateDigest (HandshakeDigestContext _) =
        error "cannot initialize digest with another digest"
-- | Fold the handshake digest, via 'foldHandshakeDigest' and the pending
-- cipher's hash, through a function that wraps a digest as a synthetic
-- @message_hash@ handshake message. The wrapped form is:
--
-- * type byte 254;
-- * a 3-byte length field of the form (0, 0, digest length);
-- * the digest itself.
--
-- This matches the HelloRetryRequest transcript rule in RFC 8446,
-- section 4.4.1.
wrapAsMessageHash13 :: HandshakeM ()
wrapAsMessageHash13 =
    getPendingCipher >>= \cipher ->
        foldHandshakeDigest (cipherHash cipher) messageHash
  where
    messageHash digest =
        B.pack [254, 0, 0, fromIntegral (B.length digest)] <> digest
-- | Finalise the current handshake hash context and return the digest.
--
-- Fails (through the labelled 'fromJust') when the context has no handshake
-- state. Raises an error when the transcript is still a raw message list,
-- i.e. no hash context has been set up yet.
transcriptHash :: MonadIO m => Context -> m ByteString
transcriptHash ctx = do
    digest <- hstHandshakeDigest . fromJust "HState" <$> getHState ctx
    case digest of
        HandshakeDigestContext hashCtx -> return (hashFinal hashCtx)
        HandshakeMessages _            -> error "un-initialized handshake digest"
-- | Replace the context's list of pending actions with the given one.
setPendingActions :: Context -> [PendingAction] -> IO ()
setPendingActions ctx actions = writeIORef (ctxPendingActions ctx) actions
-- | Remove and return the first pending action, or 'Nothing' when the list
-- is empty.
--
-- NOTE(review): the read and the write are separate IORef operations, so
-- this pop is not atomic. Presumably callers serialise access; confirm.
popPendingAction :: Context -> IO (Maybe PendingAction)
popPendingAction ctx = readIORef ref >>= pop
  where
    ref = ctxPendingActions ctx
    pop []            = return Nothing
    pop (next : rest) = do
        writeIORef ref rest
        return (Just next)