-- Decrypt.hs: OpenPGP (RFC4880) recursive packet decryption
-- Copyright © 2013-2018  Clint Adams
-- This software is released under the terms of the Expat license.
-- (See the LICENSE file).

{-# LANGUAGE FlexibleContexts #-}

module Data.Conduit.OpenPGP.Decrypt (
   conduitDecrypt
) where

import Control.Monad (when)
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.IO.Unlift (MonadUnliftIO)
import Control.Monad.Trans.Resource (MonadResource, MonadThrow)
import qualified Control.Monad.Trans.State.Lazy as S
import qualified Crypto.Hash as CH
import qualified Crypto.Hash.Algorithms as CHA
import qualified Data.ByteArray as BA
import qualified Data.ByteString.Lazy as BL
import Data.Conduit
import qualified Data.Conduit.Binary as CB
import Data.Conduit.Serialization.Binary (conduitGet)
import Data.Conduit.OpenPGP.Compression (conduitDecompress)
import qualified Data.Conduit.List as CL
import Data.Maybe (fromJust, isNothing)
import Data.Binary (get)

import Codec.Encryption.OpenPGP.S2K (skesk2Key)
import Codec.Encryption.OpenPGP.CFB (decrypt, decryptOpenPGPCfb)
import Codec.Encryption.OpenPGP.Types

-- | State threaded through 'conduitDecrypt'' while it walks a packet stream.
data RecursorState = RecursorState
    { _depth     :: Int                -- ^ nesting depth of the stream being processed (0 for the outermost one); checked by the recursion guard
    , _lastPKESK :: Maybe PKESK        -- ^ most recent PKESK packet; NOTE(review): not updated by any code in this module
    , _lastSKESK :: Maybe SKESK        -- ^ most recent SKESK packet, whose parameters derive the key for encrypted data
    , _lastLDP   :: Maybe LiteralData  -- ^ most recent literal data packet produced by decryption, compared against MDC packets
    } deriving (Eq, Show)

-- | Starting state: depth 0, with no key-material or literal-data packets
-- remembered yet.
def :: RecursorState
def = RecursorState
    { _depth     = 0
    , _lastPKESK = Nothing
    , _lastSKESK = Nothing
    , _lastLDP   = Nothing
    }

-- | Callback used to request input, given a prompt string; its answer is
-- returned as a lazy 'BL.ByteString'.  In this module it is only asked for
-- the passphrase used to derive a symmetric key.
type InputCallback m = String {- prompt -} -> m BL.ByteString

-- | Decrypt a stream of OpenPGP packets.  Symmetrically encrypted data
-- packets are replaced by their decrypted, decompressed (and recursively
-- decrypted) contents.  The callback is asked for a passphrase every time
-- an encrypted data packet has to be decrypted.
conduitDecrypt :: (MonadUnliftIO m, MonadResource m, MonadThrow m) => InputCallback IO -> ConduitT Pkt Pkt m ()
conduitDecrypt cb = conduitDecrypt' outermost cb
    where
        outermost = 0

-- | Worker for 'conduitDecrypt'.  @depth@ is the nesting depth of the stream
-- being processed: 0 for the outermost stream, and one more for each layer
-- of encrypted data that had to be decrypted to reach it.  Once the depth
-- exceeds 42, processing fails, which bounds recursion on self-containing
-- ("quine") messages.
--
-- Packets pass through unchanged, except that:
--
-- * SKESK packets are remembered (and not emitted), so that the next
--   encrypted data packet can be decrypted with them;
-- * SymEncData and SymEncIntegrityProtectedData packets are replaced by
--   their decrypted, decompressed and recursively decrypted contents;
-- * ModificationDetectionCode packets are checked against the most recent
--   literal data packet before being emitted.
conduitDecrypt' :: (MonadUnliftIO m, MonadResource m, MonadThrow m) => Int -> InputCallback IO -> ConduitT Pkt Pkt m ()
conduitDecrypt' depth cb = CL.concatMapAccumM push def { _depth = depth }  -- FIXME: this depth stuff is convoluted
    where
        push :: (MonadUnliftIO m, MonadResource m, MonadThrow m) => Pkt -> RecursorState -> m (RecursorState, [Pkt])
        push i s
            | _depth s > 42 = fail "I think we've been quine-attacked"
            | otherwise = case i of
                SKESKPkt{} -> return (s { _lastSKESK = Just (fromPkt i) }, [])
                SymEncDataPkt bs -> do
                    skesk <- requireSKESK s
                    -- The decrypted contents sit one level deeper than this
                    -- stream.  Without the increment, the depth guard above
                    -- could never trigger.
                    d <- decryptSEDP (succ (_depth s)) cb skesk bs
                    return (processLDPs s d, d)
                SymEncIntegrityProtectedDataPkt _ bs -> do
                    skesk <- requireSKESK s
                    d <- decryptSEIPDP (succ (_depth s)) cb skesk bs
                    return (processLDPs s d, d)
                m@(ModificationDetectionCodePkt mdc) -> case _lastLDP s of
                    Nothing -> fail "MDC with no referent"
                    Just ldp -> do
                        -- NOTE(review): RFC 4880 section 5.13 defines the MDC
                        -- as SHA-1 over the whole decrypted plaintext (prefix
                        -- included), whereas this hashes only the literal data
                        -- payload.  Confirm this matches how packets are decoded.
                        when (sha1 (_literalDataPayload ldp) /= mdc) $ fail "MDC indicates tampering"
                        return (s, [m])
                p -> return (s, [p])
        -- Encrypted data can only be decrypted with a previously seen SKESK.
        -- Report its absence instead of crashing inside 'fromJust'.
        requireSKESK s = maybe (fail "Encrypted data with no preceding SKESK") return (_lastSKESK s)
        -- SHA-1 digest of a lazy ByteString, as a lazy ByteString.
        sha1 = BL.fromStrict . BA.convert . (CH.hashlazy :: BL.ByteString -> CH.Digest CHA.SHA1)
        -- Remember the last literal data packet among the decrypted packets;
        -- the state is left unchanged if there is none.
        processLDPs s ds = S.execState (mapM_ ldpCheck ds) s
        ldpCheck l@LiteralDataPkt{} = S.get >>= \o -> S.put o { _lastLDP = Just . fromPkt $ l }
        ldpCheck _ = return ()

-- | Decrypt the body of a Symmetrically Encrypted Data packet with
-- 'decryptOpenPGPCfb'.  The key is derived from @skesk@ and a passphrase
-- requested from @cb@; the plaintext is then parsed, decompressed and
-- recursively decrypted at nesting depth @depth@.  A decryption error is
-- raised with 'fail' in @m@ at this point, rather than as a lazy 'error'
-- inside the packet source.
decryptSEDP :: (MonadUnliftIO m, MonadIO m, MonadThrow m) => Int -> InputCallback IO -> SKESK -> BL.ByteString -> m [Pkt]
decryptSEDP depth cb skesk bs = do -- FIXME: this shouldn't pass the whole SKESK
    passphrase <- liftIO $ cb "Input the passphrase I want"
    let key = skesk2Key skesk passphrase
    either fail (expandPlaintext depth cb . BL.fromStrict) $
        decryptOpenPGPCfb (_skeskSymmetricAlgorithm skesk) (BL.toStrict bs) key

-- | Decrypt the body of a Symmetrically Encrypted Integrity Protected Data
-- packet with 'decrypt', otherwise exactly like 'decryptSEDP': the key comes
-- from @skesk@ plus a passphrase from @cb@, and the plaintext is expanded at
-- nesting depth @depth@.  A decryption error is raised with 'fail'.
decryptSEIPDP :: (MonadUnliftIO m, MonadIO m, MonadThrow m) => Int -> InputCallback IO -> SKESK -> BL.ByteString -> m [Pkt]
decryptSEIPDP depth cb skesk bs = do -- FIXME: this shouldn't pass the whole SKESK
    passphrase <- liftIO $ cb "Input the passphrase I want"
    let key = skesk2Key skesk passphrase
    either fail (expandPlaintext depth cb . BL.fromStrict) $
        decrypt (_skeskSymmetricAlgorithm skesk) (BL.toStrict bs) key

-- | Parse a decrypted plaintext as a packet stream, decompress it and run
-- it back through 'conduitDecrypt'' at nesting depth @depth@.  Returns the
-- resulting packets as a list.
expandPlaintext :: (MonadUnliftIO m, MonadIO m, MonadThrow m) => Int -> InputCallback IO -> BL.ByteString -> m [Pkt]
expandPlaintext depth cb plaintext =
    runConduitRes $ CB.sourceLbs plaintext .| conduitGet get .| conduitDecompress .| conduitDecrypt' depth cb .| CL.consume