{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE PatternSynonyms #-}
{-# OPTIONS_HADDOCK show-extensions #-}
module Internal
(
newWanted
, newGiven
, newDerived
, evByFiat
, lookupModule
, lookupName
, tracePlugin
, flattenGivens
, mkSubst
, mkSubst'
, substType
, substCt
)
where
import GHC.Tc.Plugin (TcPluginM, lookupOrig, tcPluginTrace)
import qualified GHC.Tc.Plugin as TcPlugin
(newDerived, newWanted, getTopEnv, tcPluginIO, findImportedModule)
import GHC.Tc.Types (TcPlugin(..), TcPluginResult(..))
import Control.Arrow (first, second)
import Data.Function (on)
import Data.List (groupBy, partition, sortOn)
import GHC.Tc.Utils.TcType (TcType)
import Data.Maybe (mapMaybe)
import GhcApi.Constraint (Ct(..), CtEvidence(..), CtLoc)
import GhcApi.GhcPlugins
import Internal.Type (substType)
import Internal.Constraint (newGiven, flatToCt, mkSubst, overEvidencePredType)
import Internal.Evidence (evByFiat)
{-# ANN fr_mod "HLint: ignore Use camelCase" #-}
-- | Unidirectional pattern synonym over 'FindResult': matches a 'Found'
-- result and binds only its 'Module', ignoring the first field of 'Found'.
-- Any other 'FindResult' constructor fails to match.
pattern FoundModule :: Module -> FindResult
pattern FoundModule a <- Found _ a
-- | Identity function. It is applied to the 'Module' bound by the
-- @FoundModule@ pattern in 'lookupModule' and returns it unchanged.
fr_mod :: a -> a
fr_mod = id
-- | Produce 'CtEvidence' for the given predicate at the given location.
--
-- This is a direct alias for @newWanted@ from "GHC.Tc.Plugin", re-exported
-- under this module's name.
newWanted :: CtLoc -> PredType -> TcPluginM CtEvidence
newWanted = TcPlugin.newWanted
-- | Produce 'CtEvidence' for the given predicate at the given location.
--
-- This is a direct alias for @newDerived@ from "GHC.Tc.Plugin", re-exported
-- under this module's name.
newDerived :: CtLoc -> PredType -> TcPluginM CtEvidence
newDerived = TcPlugin.newDerived
-- | Resolve a 'ModuleName' to a 'Module'.
--
-- The lookup is attempted in two stages:
--
-- 1. 'findPluginModule' is run against the top-level 'HscEnv'.
-- 2. If that does not yield a @Found@ result, 'TcPlugin.findImportedModule'
--    is tried with the package qualifier @"this"@.
--
-- If neither stage finds the module, this panics via 'panicDoc' with the
-- message @"Couldn't find module"@ followed by the module name.
lookupModule :: ModuleName -- ^ Name of the module to find
             -> FastString -- ^ Package name; this argument is ignored
             -> TcPluginM Module
lookupModule mod_nm _pkg = do
  hsc_env <- TcPlugin.getTopEnv
  found_module <- TcPlugin.tcPluginIO $ findPluginModule hsc_env mod_nm
  case found_module of
    FoundModule h -> return (fr_mod h)
    _ -> do
      -- First lookup failed: retry as an imported module of package "this".
      found_module' <- TcPlugin.findImportedModule mod_nm $ Just $ fsLit "this"
      case found_module' of
        FoundModule h -> return (fr_mod h)
        _ -> panicDoc "Couldn't find module" (ppr mod_nm)
-- | Find the 'Name' for an 'OccName' within a 'Module'.
--
-- This is a direct alias for 'lookupOrig' from "GHC.Tc.Plugin".
lookupName :: Module -> OccName -> TcPluginM Name
lookupName = lookupOrig
-- | Wrap a 'TcPlugin' so that its initialisation, every solver run, and its
-- shutdown each emit a message through 'tcPluginTrace'.
--
-- The first argument is a label appended to every trace header, so output
-- from different plugins can be told apart. The wrapped plugin's own
-- actions run unchanged, and the solver's result is returned as-is.
tracePlugin :: String -> TcPlugin -> TcPlugin
tracePlugin s TcPlugin{..} = TcPlugin { tcPluginInit  = traceInit
                                      , tcPluginSolve = traceSolve
                                      , tcPluginStop  = traceStop
                                      }
  where
    -- No local type signatures here: the plugin state type is bound by the
    -- record pattern above and cannot be named in a signature.

    -- Run 'initializeStaticFlags', trace, then run the wrapped initialiser.
    traceInit = do
      initializeStaticFlags
      tcPluginTrace ("tcPluginInit " ++ s) empty >> tcPluginInit

    -- Trace, then run the wrapped stop action on the plugin state.
    traceStop z = tcPluginTrace ("tcPluginStop " ++ s) empty >> tcPluginStop z

    -- Trace the three input constraint lists, run the wrapped solver, then
    -- trace whichever result it produced before handing it back.
    traceSolve z given derived wanted = do
      tcPluginTrace ("tcPluginSolve start " ++ s)
                    (text "given =" <+> ppr given
                  $$ text "derived =" <+> ppr derived
                  $$ text "wanted =" <+> ppr wanted)
      r <- tcPluginSolve z given derived wanted
      case r of
        TcPluginOk solved new
          -> tcPluginTrace ("tcPluginSolve ok " ++ s)
                           (text "solved =" <+> ppr solved
                         $$ text "new =" <+> ppr new)
        TcPluginContradiction bad
          -> tcPluginTrace
               ("tcPluginSolve contradiction " ++ s)
               (text "bad =" <+> ppr bad)
      return r
-- | Hook run at the start of plugin initialisation by 'tracePlugin'.
-- In this version of the module it does nothing.
initializeStaticFlags :: TcPluginM ()
initializeStaticFlags = return ()
-- | Flatten a list of constraints using the substitution derived from them.
--
-- A substitution is built from the input with @mkSubst'@, then sorted and
-- grouped by the substituted type variable:
--
-- * groups with two or more entries for the same variable are handed to
--   'flatToCt'; each @Just@ result becomes a constraint in the output;
-- * the remaining single-entry groups form the substitution that is
--   applied to every input constraint with 'substCt'.
--
-- The result lists the 'flatToCt' constraints first, followed by the
-- substituted input constraints in their original order.
flattenGivens :: [Ct] -> [Ct]
flattenGivens givens =
  mapMaybe flatToCt flat ++ map (substCt subst') givens
  where
    subst = mkSubst' givens
    -- 'groupBy' only merges adjacent elements, hence the 'sortOn' first.
    (flat, subst')
      = second (map fst . concat)
      $ partition ((>= 2) . length)
      $ groupBy ((==) `on` (fst . fst))
      $ sortOn (fst . fst) subst
-- | Build a substitution from a list of constraints, with the entries
-- applied to one another's right-hand sides.
--
-- Each constraint is passed through 'mkSubst'; constraints for which it
-- returns 'Nothing' are dropped. The surviving entries are combined with a
-- right fold, so every entry keeps the constraint it came from.
mkSubst' :: [Ct] -> [((TcTyVar,TcType),Ct)]
mkSubst' = foldr substSubst [] . mapMaybe mkSubst
  where
    -- Add one entry in front of the substitution accumulated so far:
    -- the new entry's type has the accumulated substitution applied to it,
    -- and each accumulated type has the new (unsubstituted) entry applied.
    substSubst :: ((TcTyVar,TcType),Ct)
               -> [((TcTyVar,TcType),Ct)]
               -> [((TcTyVar,TcType),Ct)]
    substSubst ((tv,t),ct) s = ((tv,substType (map fst s) t),ct)
                             : map (first (second (substType [(tv,t)]))) s
-- | Apply a substitution to a constraint: 'substType' with the given
-- substitution is run over the constraint via 'overEvidencePredType'.
substCt :: [(TcTyVar, TcType)] -> Ct -> Ct
substCt subst = overEvidencePredType (substType subst)