{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE DeriveDataTypeable #-}

module Network.DNS.Transport (
    Resolver(..)
  , resolve
  ) where

import Control.Concurrent.Async (async, cancel, waitAny, waitAnyCancel)
import Control.Exception as E
import qualified Data.ByteString.Char8 as BS
import qualified Data.List.NonEmpty as NE
import Network.Socket (AddrInfo(..), SockAddr(..), Family(AF_INET, AF_INET6), Socket, SocketType(Stream), close, socket, connect, defaultProtocol)
import System.IO.Error (annotateIOError)
import System.Timeout (timeout)

import Network.DNS.IO
import Network.DNS.Imports
import Network.DNS.Types
import Network.DNS.Types.Internal

-- | Does @resp@ answer the query we sent?  Both the message identifier and
-- the question section must equal what was sent.
--
-- If we ever do pipelined TCP, we'll need to handle out-of-order
-- responses.  See: https://tools.ietf.org/html/rfc7766#section-7
checkResp :: [Question] -> Identifier -> DNSMessage -> Bool
checkResp q seqno resp = idMatches && questionMatches
  where
    idMatches       = identifier (header resp) == seqno
    questionMatches = q == question resp

----------------------------------------------------------------

-- | Internal signal thrown by 'udpLookup' when a UDP reply has its
-- truncation flag set; 'udpTcpLookup' catches it and repeats the query
-- over TCP.
data TCPFallback = TCPFallback
    deriving (Show, Typeable)

instance Exception TCPFallback

-- | Trailing arguments shared by every resolver stage: the @ad@ flag that
-- is handed on to 'encodeQuestions', and the function used to read one
-- reply from a socket.
type Rslv0 = Bool
           -> (Socket -> IO DNSMessage)
           -> IO (Either DNSError DNSMessage)

-- | A per-server resolver: the question section, the EDNS records added to
-- UDP queries, the timeout (microseconds, as given to 'timeout') and the
-- number of UDP attempts.
type Rslv1 = [Question]
          -> [ResourceRecord]  -- EDNS
          -> Int               -- Timeout
          -> Int               -- Retry
          -> Rslv0

-- | A TCP lookup: query identifier, server, questions, timeout and the
-- @ad@ flag.
type TcpRslv = Identifier
            -> AddrInfo
            -> [Question]
            -> Int             -- Timeout
            -> Bool
            -> IO DNSMessage

-- | A UDP lookup: EDNS records, retry count and reply reader in front of
-- the same arguments a 'TcpRslv' takes.
type UdpRslv = [ResourceRecord]
            -> Int             -- Retry
            -> (Socket -> IO DNSMessage)
            -> TcpRslv

-- | Resolve a domain name of the given type using the resolver's
-- nameservers.
--
-- Illegal names (see 'isIllegal') are rejected up front with
-- 'IllegalDomain'.  The name is made fully qualified by appending a
-- trailing dot when it lacks one.  With a single nameserver the query goes
-- straight to it.  With several, they are queried either concurrently or
-- one after another, depending on 'resolvConcurrent'.
--
-- For each server, we try UDP until we get a response.  If the response is
-- truncated, we try TCP once, with no further UDP retries.
--
-- For now, we optimize for low-latency, high-availability caches
-- (e.g. running on a loopback interface), where TCP is cheap enough.  We
-- could try to finish the TCP lookup within the original time budget of the
-- truncated UDP query by wrapping both in a single 'timeout'.  It seems
-- saner, however, to give TCP a full opportunity to return results.  TCP
-- latency after a truncated UDP reply will be atypical.
--
-- Future improvements might also include support for TCP on the initial
-- query.
resolve :: Domain -> TYPE -> Resolver -> Rslv0
resolve dom typ rlv ad rcv
  | isIllegal dom = return (Left IllegalDomain)
  | single        = resolveOne (head servers) (head idGens) qs edns tmout retries ad rcv
  | inParallel    = resolveConcurrent servers idGens qs edns tmout retries ad rcv
  | otherwise     = resolveSequential servers idGens qs edns tmout retries ad rcv
  where
    -- Only forced after isIllegal has ruled out the empty name.
    fqdn
      | BS.last dom == '.' = dom
      | otherwise          = dom <> "."
    qs = [Question fqdn typ]

    idGens  = NE.toList (genIds rlv)
    seed    = resolvseed rlv
    servers = NE.toList (nameservers seed)
    single  = case servers of
                [_] -> True
                _   -> False

    conf       = resolvconf seed
    inParallel = resolvConcurrent conf
    tmout      = resolvTimeout conf
    retries    = resolvRetry conf
    edns       = resolvEDNS conf

-- | Try the nameservers one at a time, in list order, pairing each with
-- the identifier generator at the same position.  The first successful
-- ('Right') result is returned.  If every server fails, the last server's
-- 'Left' is returned.  The two lists are expected to have the same
-- non-zero length; otherwise this calls 'error'.
resolveSequential :: [AddrInfo] -> [IO Identifier] -> Rslv1
resolveSequential nss gs q edns tm retry ad rcv = go nss gs
  where
    attempt ai gen = resolveOne ai gen q edns tm retry ad rcv

    go (ai:ais) (gen:gens)
      | null ais && null gens = attempt ai gen
      | otherwise             = attempt ai gen >>= either (const (go ais gens)) (return . Right)
    go _ _ = error "resolveSequential:loop"

-- | Query every nameserver at once, each with the identifier generator at
-- the same position, and return the first successful ('Right') answer.
--
-- A server that fails (e.g. a network error or exhausted retries) does not
-- decide the result just by finishing first.  Its error is discarded as
-- long as other queries are still in flight.  Only when every query has
-- failed is an error returned: the one from whichever query finished last.
-- This matches 'resolveSequential', which also moves on to the next server
-- on error.  All queries are cancelled once a result has been chosen, or
-- if waiting is interrupted by an exception.
resolveConcurrent :: [AddrInfo] -> [IO Identifier] -> Rslv1
resolveConcurrent nss gens q edns tm retry ad rcv = do
    asyncs <- mapM mkAsync $ zip nss gens
    firstSuccess asyncs `E.finally` mapM_ cancel asyncs
  where
    mkAsync (ai,gen) = async $ resolveOne ai gen q edns tm retry ad rcv

    -- Wait for the next query to finish.  If it failed and others are
    -- still pending, keep waiting on those; otherwise its result is final.
    firstSuccess pending = do
        (done, res) <- waitAny pending
        let rest = filter (/= done) pending
        case res of
          Left _ | not (null rest) -> firstSuccess rest
          _                        -> return res

-- | Query a single nameserver: draw a fresh identifier from @gen@, run the
-- UDP lookup (with TCP fallback on truncation), and return any 'DNSError'
-- it throws as a 'Left'.
resolveOne :: AddrInfo -> IO Identifier -> Rslv1
resolveOne ai gen q edns tm retry ad rcv =
    gen >>= E.try . lookupWith
  where
    lookupWith ident = udpTcpLookup edns retry rcv ident ai q tm ad

----------------------------------------------------------------

-- | Query over UDP.  If 'udpLookup' reports a truncated reply by throwing
-- 'TCPFallback', ask the same question again over TCP, reusing the same
-- identifier.
udpTcpLookup :: UdpRslv
udpTcpLookup edns retry rcv ident ai q tm ad =
    E.handle toTcp (udpLookup edns retry rcv ident ai q tm ad)
  where
    toTcp TCPFallback = tcpLookup ident ai q tm ad

----------------------------------------------------------------

-- | Rethrow an 'IOError' as a 'NetworkFailure'.  Before wrapping, the error
-- is annotated with the shown 'AddrInfo' as its location and @tag@ in the
-- file-path slot of 'annotateIOError'.  The 'DNSMessage' result type only
-- lets this be used as a handler; it never returns normally.
ioErrorToDNSError :: AddrInfo -> String -> IOError -> IO DNSMessage
ioErrorToDNSError ai tag =
    throwIO . NetworkFailure . annotate
  where
    annotate ioe = annotateIOError ioe (show ai) Nothing (Just tag)

----------------------------------------------------------------

-- | Create a socket using the family, socket type and protocol recorded in
-- the 'AddrInfo' (presumably a datagram socket for UDP — confirm at the
-- call site), then 'connect' it to the server address so that replies are
-- read from, and queries sent to, that peer only.
--
-- If 'connect' throws, the freshly created socket is closed before the
-- exception propagates, so it does not leak.
udpOpen :: AddrInfo -> IO Socket
udpOpen ai =
    bracketOnError newSock close $ \sock -> do
        connect sock (addrAddress ai)
        return sock
  where
    newSock = socket (addrFamily ai) (addrSocketType ai) (addrProtocol ai)

-- | Send the query over UDP to one nameserver and wait for a matching
-- reply.
--
-- Each attempt sends the query and waits up to @tm@ microseconds for an
-- answer that passes 'checkResp'.  On timeout the query is resent, and
-- after @retry@ attempts 'RetryLimitExceeded' is thrown.  A reply with the
-- truncation flag set throws 'TCPFallback' so the caller can switch to TCP.
-- If the query carried EDNS records and the server answers 'FormatErr',
-- the query is resent once without EDNS; that resend does not count as an
-- extra attempt.  I/O errors are rethrown as 'NetworkFailure' tagged "UDP".
--
-- This throws DNSError or TCPFallback.
udpLookup :: UdpRslv
udpLookup edns retry rcv ident ai q tm ad = do
    let qry = encodeQuestions ident q edns ad
        ednsRetry = not $ null edns
    E.handle (ioErrorToDNSError ai "UDP") $
      bracket (udpOpen ai) close (loop qry ednsRetry 0 RetryLimitExceeded)
  where
    -- '>=' (not '==') so that a zero or negative retry count fails at once
    -- instead of retrying forever on repeated timeouts.
    loop qry ednsRetry cnt err sock
      | cnt >= retry = E.throwIO err
      | otherwise    = do
          mres <- timeout tm (send sock qry >> getAns sock)
          case mres of
              Nothing  -> loop qry ednsRetry (cnt + 1) RetryLimitExceeded sock
              Just res -> do
                      let flgs = flags $ header res
                          truncated = trunCation flgs
                          rc = rcode flgs
                      if truncated then
                          E.throwIO TCPFallback
                      else if ednsRetry && rc == FormatErr then
                          -- Resend without EDNS.  Passing False makes sure
                          -- this happens at most once per lookup.
                          let nonednsQuery = encodeQuestions ident q [] ad
                          in loop nonednsQuery False cnt RetryLimitExceeded sock
                      else
                          return res

    -- Closed UDP ports are occasionally re-used for a new query, with
    -- the nameserver returning an unexpected answer to the wrong socket.
    -- Such answers should simply be dropped, with the client continuing
    -- to wait for the right answer, without resending the question.
    -- Note that this eliminates sequence mismatch as a UDP error condition;
    -- instead we time out if no matching answer arrives.
    getAns sock = do
        mres <- rcv sock
        if checkResp q ident mres
        then return mres
        else getAns sock

----------------------------------------------------------------

-- | Create an unconnected TCP (stream) socket whose address family matches
-- the given peer address.  Only IPv4 and IPv6 addresses are supported; any
-- other kind of address throws 'ServerFailure'.
tcpOpen :: SockAddr -> IO Socket
tcpOpen SockAddrInet{}  = socket AF_INET Stream defaultProtocol
tcpOpen SockAddrInet6{} = socket AF_INET6 Stream defaultProtocol
tcpOpen _               = E.throwIO ServerFailure

-- | Perform a DNS query over TCP, provided the TCP socket could be created.
--
-- The query is encoded without EDNS records.  Connecting, sending and
-- receiving must together finish within @tm@ microseconds, or
-- 'TimeoutExpired' is thrown.  A reply whose identifier or question does
-- not match throws 'SequenceNumberMismatch'.  I/O errors are rethrown as
-- 'NetworkFailure' tagged "TCP".  The socket is always closed.
--
-- This throws DNSError only.
tcpLookup :: TcpRslv
tcpLookup ident ai q tm ad =
    E.handle (ioErrorToDNSError ai "TCP") (bracket (tcpOpen peer) close exchange)
  where
    peer  = addrAddress ai
    query = encodeQuestions ident q [] ad

    exchange vc =
        timeout tm (connect vc peer >> sendVC vc query >> receiveVC vc)
            >>= maybe (E.throwIO TimeoutExpired) validate

    validate res
      | checkResp q ident res = return res
      | otherwise             = E.throwIO SequenceNumberMismatch

----------------------------------------------------------------

-- | Is the textual domain name empty or too long?  A name with a trailing
-- dot may be up to 254 characters long; without one, the limit is 253.
badLength :: Domain -> Bool
badLength dom
    | BS.null dom = True
    | otherwise   = BS.length dom > maxLen
  where
    -- Only evaluated for non-empty names, so 'BS.last' is safe here.
    maxLen = if BS.last dom == '.' then 254 else 253

-- | Reject names this resolver will not query.  A name is rejected if it
-- has a bad length (see 'badLength'), contains no dot, contains a ':' or
-- '/', or has a dot-separated label longer than 63 characters.
isIllegal :: Domain -> Bool
isIllegal dom = or
    [ badLength dom
    , '.' `BS.notElem` dom
    , BS.any (\c -> c == ':' || c == '/') dom
    , any ((> 63) . BS.length) (BS.split '.' dom)
    ]