-----------------------------------------------------------------------------
--
-- Module      :  Test.Tasty.TH
-- Copyright   :  Oscar Finnsson, Benno Fünfstück
-- License     :  BSD3
--
-- Maintainer  :  Benno Fünfstück
-- Stability   :
-- Portability :
--
--
-----------------------------------------------------------------------------
{-# LANGUAGE TemplateHaskell #-}

-- | This module provides TemplateHaskell functions to automatically generate
-- tasty TestTrees from specially named functions. See the README of the package
-- for examples.
--
-- Important: due to the GHC staging restriction, you must put any uses of these
-- functions at the end of the file, or you may get errors due to missing definitions.
module Test.Tasty.TH
  ( testGroupGenerator
  , defaultMainGenerator
  , testGroupGeneratorFor
  , defaultMainGeneratorFor
  , extractTestFunctions
  , locationModule
  ) where

import Control.Applicative
import Control.Monad (join)
import Data.Data (gmapQ, Data)
import qualified Data.Foldable as F
import Data.List (nub, isPrefixOf, find)
import Data.Maybe
import Data.Typeable (cast)
import System.IO (IOMode (ReadMode), hGetContents, hSetEncoding, utf8, withFile)
import Prelude

import Language.Haskell.Exts (parseFileContentsWithMode)
import Language.Haskell.Exts.Parser (ParseResult(..), defaultParseMode, parseFilename)
import qualified Language.Haskell.Exts.Syntax as S
import Language.Haskell.TH
import Test.Tasty

-- | Convenience function that directly generates an `IO` action that may be used as the
-- main function. It applies 'defaultMain' to the 'TestTree' expression produced
-- by 'testGroupGenerator'.
--
-- Example usage:
--
-- @
-- -- properties, test cases, ....
--
-- main :: IO ()
-- main = $('defaultMainGenerator')
-- @
defaultMainGenerator :: ExpQ
defaultMainGenerator = varE 'defaultMain `appE` testGroupGenerator

-- | This function generates a 'TestTree' from functions in the current module.
-- The test tree is named after the current module.
--
-- The following definitions are collected by `testGroupGenerator`:
--
-- * a test_something definition in the current module creates a sub-testGroup with the name "something"
-- * a prop_something definition in the current module is added as a QuickCheck property named "something"
-- * a case_something definition leads to a HUnit-Assertion test with the name "something"
--
-- Example usage:
--
-- @
-- prop_example :: Int -> Int -> Bool
-- prop_example a b = a + b == b + a
--
-- tests :: 'TestTree'
-- tests = $('testGroupGenerator')
-- @
testGroupGenerator :: ExpQ
testGroupGenerator = join (liftA2 testGroupGeneratorFor currentModule discoveredTests)
 where
  -- Name of the module in which the splice is being expanded.
  currentModule = loc_module <$> location
  -- Test-like definition names found by scanning that module's source file.
  discoveredTests = do
    loc <- location
    runIO (extractTestFunctions (loc_filename loc))

-- | Retrieves all function names from the given file that would be discovered by 'testGroupGenerator'.
--
-- The result lists every @prop_@ name first, then every @case_@ name, then
-- every @test_@ name, with duplicates removed.
--
-- The file is read strictly and decoded as UTF-8, which is the encoding GHC
-- uses for source files. The result therefore does not depend on the locale
-- of the process running the splice, and the handle is closed before this
-- action returns.
extractTestFunctions :: FilePath -> IO [String]
extractTestFunctions filePath = do
  file <- readSourceFile
  -- we first try to parse the file using haskell-src-exts
  -- if that fails, we fall back to lexing each line, which is less
  -- accurate but more reliable (haskell-src-exts sometimes struggles
  -- with less-common GHC extensions).
  let functions = fromMaybe (lexed file) (parsed file)
      filtered pat = filter (pat `isPrefixOf`) functions
  return . nub $ concat [filtered "prop_", filtered "case_", filtered "test_"]
 where
  -- Read the whole file as UTF-8 and force it inside 'withFile'. Decoding
  -- errors surface here, and the handle is released deterministically
  -- instead of staying open under lazy I/O.
  readSourceFile = withFile filePath ReadMode $ \h -> do
    hSetEncoding h utf8
    contents <- hGetContents h
    length contents `seq` return contents

  -- Fallback: the first Haskell lexeme of every line.
  lexed = map fst . concatMap lex . lines

  -- Primary strategy: names bound by top-level declarations of the parsed
  -- module, or Nothing if haskell-src-exts cannot parse the file.
  parsed contents =
    case parseFileContentsWithMode (defaultParseMode { parseFilename = filePath }) contents of
      ParseOk parsedModule -> Just (declarations parsedModule)
      ParseFailed _ _ -> Nothing

  declarations (S.Module _ _ _ _ decls) = concatMap testFunName decls
  declarations _ = []

  -- Pattern bindings contribute every variable they bind; function
  -- bindings contribute the (deduplicated) name of their clauses.
  testFunName (S.PatBind _ pat _ _) = patternVariables pat
  testFunName (S.FunBind _ clauses) = nub (map clauseName clauses)
  testFunName _ = []

  clauseName (S.Match _ name _ _ _) = nameString name
  clauseName (S.InfixMatch _ _ name _ _ _) = nameString name

-- | Convert a haskell-src-exts 'S.Name' to a 'String'. Both identifier
-- names and symbol (operator) names yield their bare text.
nameString :: S.Name l -> String
nameString name = case name of
  S.Ident _ str -> str
  S.Symbol _ str -> str

-- | Find all variables that are bound in the given pattern.
--
-- A 'S.PVar' yields its own name. Any other pattern is searched generically:
-- every immediate child that is itself a pattern is visited recursively.
patternVariables :: Data l => S.Pat l -> [String]
patternVariables = collect
 where
  -- Kept local and without a signature, so its recursive use is
  -- monomorphic. That fixes the target type of 'cast' to this pattern type.
  collect pat = case pat of
    S.PVar _ name -> [nameString name]
    _ -> concat (gmapQ (F.foldMap collect . cast) pat)

-- | Extract the name of the current module, as a string-literal expression.
locationModule :: ExpQ
locationModule = LitE . StringL . loc_module <$> location

-- | Like 'testGroupGenerator', but generates a test group only including the specified function names.
-- The function names still need to follow the pattern of starting with one of @prop_@, @case_@ or @test_@.
--
-- Each accepted name @n@ becomes @runner "label" n@:
--
-- * @runner@ is @testProperty@, @testCase@ or @testGroup@, chosen by the prefix.
-- * @label@ is @n@ with everything up to its first underscore removed and the
--   remaining underscores turned into spaces.
--
-- Names without a recognised prefix are silently skipped.
testGroupGeneratorFor
  :: String   -- ^ The name of the test group itself
  -> [String] -- ^ The names of the functions which should be included in the test group
  -> ExpQ
testGroupGeneratorFor name functionNames =
  [| testGroup name $(listE (mapMaybe mkTest functionNames)) |]
 where
  -- Definition prefix paired with the name of the function that wraps it.
  -- These are bound with 'mkName', so they are resolved where the splice
  -- is expanded.
  runners = [("prop_", "testProperty"), ("case_", "testCase"), ("test_", "testGroup")]

  mkTest fname = do
    (_, runner) <- find (\(prefix, _) -> prefix `isPrefixOf` fname) runners
    pure $ varE (mkName runner) `appE` stringE (fixName fname) `appE` varE (mkName fname)

-- | Like 'defaultMainGenerator', but only includes the specific function names in the test group.
-- The function names still need to follow the pattern of starting with one of @prop_@, @case_@ or @test_@.
defaultMainGeneratorFor
  :: String   -- ^ The name of the top-level test group
  -> [String] -- ^ The names of the functions which should be included in the test group
  -> ExpQ
defaultMainGeneratorFor name fns = varE 'defaultMain `appE` testGroupGeneratorFor name fns

-- | Turn a test definition name into a human-readable label.
--
-- Everything up to and including the first underscore (the @prop_@ /
-- @case_@ / @test_@ prefix) is dropped, and the remaining underscores become
-- spaces, e.g. @"prop_reverse_twice"@ becomes @"reverse twice"@.
--
-- 'drop' is used rather than the partial 'tail', so a name containing no
-- underscore yields @""@ instead of throwing.
fixName :: String -> String
fixName = replace '_' ' ' . drop 1 . dropWhile (/= '_')

-- | @replace old new xs@ substitutes every element equal to @old@ with @new@,
-- leaving all other elements untouched.
replace :: Eq a => a -> a -> [a] -> [a]
replace old new xs = [if old == x then new else x | x <- xs]