-
Notifications
You must be signed in to change notification settings - Fork 4
/
RewriterPlugin.hs
165 lines (147 loc) · 6.84 KB
/
RewriterPlugin.hs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE RecordWildCards #-}
module RewriterPlugin ( plugin ) where
-- ghc
import qualified GHC.Plugins as GHC
( Plugin(..), defaultPlugin, purePlugin )
-- ghc-tcplugin-api
import qualified GHC.TcPlugin.API as API
import GHC.TcPlugin.API
( TcPluginErrorMessage(..) )
--------------------------------------------------------------------------------
-- Plugin definition and setup/finalisation.
-- N.B. The qualified imports here are for clarity of exposition only.
-- In practice, I would recommend importing 'GHC.TcPlugin.API' unqualified.
-- | The plugin entry point. A plugin module must define
-- @plugin :: GHC.Plugin@, in the same way that an executable must define
-- @main :: IO ()@.
--
-- Only the type-checker hook is set; any command-line arguments passed to the
-- plugin are ignored. The plugin is declared pure ('GHC.purePlugin') for the
-- purposes of recompilation checking.
plugin :: GHC.Plugin
plugin = GHC.defaultPlugin
  { GHC.tcPlugin        = const ( Just ( API.mkTcPlugin tcPlugin ) )
  , GHC.pluginRecompile = GHC.purePlugin
  }
-- | The type-checking plugin proper: each of its four stages is delegated to
-- a dedicated top-level function.
tcPlugin :: API.TcPlugin
tcPlugin = API.TcPlugin
  { API.tcPluginInit    = tcPluginInit     -- look up the definitions we need
  , API.tcPluginSolve   = tcPluginSolve    -- no constraint solving
  , API.tcPluginRewrite = tcPluginRewrite  -- rewrite 'Add' applications
  , API.tcPluginStop    = tcPluginStop     -- nothing to clean up
  }
-- | Definitions from "RewriterPlugin.Definitions", looked up once by
-- 'tcPluginInit' and handed to the solving, rewriting and stopping stages.
data PluginDefs = PluginDefs
  { natType          :: !API.TcType
    -- ^ The type @Nat@ (currently not referenced by the rewriting logic).
  , zeroTyCon        :: !API.TyCon
    -- ^ The promoted data constructor @Zero@.
  , succTyCon        :: !API.TyCon
    -- ^ The promoted data constructor @Succ@ (currently not referenced by the
    -- rewriting logic).
  , badNatTyCon      :: !API.TyCon
    -- ^ The promoted data constructor @BadNat@.
  , addTyCon         :: !API.TyCon
    -- ^ The type family @Add@, the only family this plugin rewrites.
  , cancellableClass :: !API.Class
    -- ^ The class @Cancellable@, used for the Wanteds emitted when a 'Zero'
    -- argument is cancelled.
  }
-- | Find a module by name, resolving it as an import without a package
-- qualifier.
--
-- Calls 'error' if the lookup yields several modules with that name, or no
-- module at all.
findModule :: API.MonadTcPlugin m => String -> m API.Module
findModule modName =
  API.resolveImport ghcModName Nothing
    >>= API.findImportedModule ghcModName
    >>= fromFindResult
  where
    ghcModName = API.mkModuleName modName

    -- Turn the finder's answer into a module, or abort with a message.
    fromFindResult result = case result of
      API.Found _ modl    -> pure modl
      API.FoundMultiple _ -> error $ "RewriterPlugin: found multiple modules named " <> modName <> "."
      _                   -> error $ "RewriterPlugin: could not find any module named " <> modName <> "."
-- | Initialise the plugin state by looking up, in the module
-- "RewriterPlugin.Definitions": the type @Nat@, the promoted constructors
-- @Zero@, @Succ@ and @BadNat@, the type family @Add@ and the class
-- @Cancellable@.
tcPluginInit :: API.TcPluginM API.Init PluginDefs
tcPluginInit = do
  defsModule <- findModule "RewriterPlugin.Definitions"
  let
    -- Resolve an occurrence name to its original name in 'defsModule'.
    orig occ = API.lookupOrig defsModule occ
    -- Look up a data constructor by name and promote it to the type level.
    promotedCon conName =
      API.promoteDataCon <$> ( API.tcLookupDataCon =<< orig ( API.mkDataOcc conName ) )
  natTyCon         <- API.tcLookupTyCon =<< orig ( API.mkTcOcc "Nat" )
  zeroTyCon        <- promotedCon "Zero"
  succTyCon        <- promotedCon "Succ"
  badNatTyCon      <- promotedCon "BadNat"
  addTyCon         <- API.tcLookupTyCon =<< orig ( API.mkTcOcc "Add" )
  cancellableClass <- API.tcLookupClass =<< orig ( API.mkClsOcc "Cancellable" )
  pure PluginDefs
    { natType = API.mkTyConApp natTyCon []
    , zeroTyCon
    , succTyCon
    , badNatTyCon
    , addTyCon
    , cancellableClass
    }
-- | Constraint-solving stage. It is deliberately a no-op that reports success
-- with nothing solved and no new constraints; this plugin only rewrites type
-- family applications (see 'tcPluginRewrite').
tcPluginSolve :: PluginDefs -> [ API.Ct ] -> [ API.Ct ] -> API.TcPluginM API.Solve API.TcPluginSolveResult
tcPluginSolve _defs _givens _wanteds =
  pure ( API.TcPluginOk [] [] )
-- | Shutdown stage: there is nothing to release, so this does nothing.
tcPluginStop :: PluginDefs -> API.TcPluginM API.Stop ()
tcPluginStop = const ( pure () )
--------------------------------------------------------------------------------
-- Simplification of type family applications.
-- | Table of type-family rewriters, keyed by type family 'API.TyCon'.
-- Only the 'Add' type family is rewritten; its rewriter is 'rewrite_add'.
tcPluginRewrite :: PluginDefs -> API.UniqFM API.TyCon API.TcPluginRewriter
tcPluginRewrite defs =
  API.listToUFM [ ( addTyCon defs, rewrite_add defs ) ]
-- | Rewriter for applications @Add a b@ of the 'Add' type family.
--
-- Rules, tried in this order:
--
--   * @Add Zero b = b@, emitting a Wanted @Cancellable b@;
--   * @Add a Zero = a@, emitting a Wanted @Cancellable a@;
--   * @Add BadNat b = BadNat@, also emitting a custom type error showing @a@;
--   * @Add a BadNat = BadNat@, also emitting a custom type error showing @b@.
--
-- Because the 'Zero' rules come first, they shadow the 'BadNat' rules. For
-- example, @Add BadNat Zero@ rewrites to @BadNat@ via the second rule and
-- emits @Cancellable BadNat@, but no BadNat type error.
-- NOTE(review): confirm this precedence is intended.
--
-- Any other application is left alone, including one whose argument list
-- does not have exactly two elements.
rewrite_add :: PluginDefs -> [ API.Ct ] -> [ API.TcType ] -> API.TcPluginM API.Rewrite API.TcPluginRewriteResult
rewrite_add pluginDefs@( PluginDefs { .. } ) _givens tys =
  case tys of
    [ a, b ]
      | isNullary zeroTyCon a   -> cancelTo b
      | isNullary zeroTyCon b   -> cancelTo a
      | isNullary badNatTyCon a ->
          throwTypeError badRedn $
            Txt "RewriterPlugin detected a BadNat in the first argument of (+):"
              :-:
            PrintType a
      | isNullary badNatTyCon b ->
          throwTypeError badRedn $
            Txt "RewriterPlugin detected a BadNat in the second argument of (+):"
              :-:
            PrintType b
    -- Unmatched guards fall through to here as well.
    _ -> pure API.TcPluginNoRewrite
  where
    -- Is the type exactly the given type constructor, with no arguments?
    isNullary :: API.TyCon -> API.TcType -> Bool
    isNullary tc ty = case API.splitTyConApp_maybe ty of
      Just ( tc', [] ) -> tc' == tc
      _                -> False

    -- Rewrite the whole 'Add' application to the given type, emitting a
    -- Wanted "Cancellable" constraint for that type.
    cancelTo :: API.TcType -> API.TcPluginM API.Rewrite API.TcPluginRewriteResult
    cancelTo res = do
      wanted <- mkCancellableWanted pluginDefs res
      pure $ API.TcPluginRewriteTo
        ( API.mkTyFamAppReduction "RewriterPlugin" API.Nominal addTyCon tys res )
        [ wanted ]

    -- Reduction of the whole 'Add' application to 'BadNat'.
    badRedn :: API.Reduction
    badRedn = API.mkTyFamAppReduction "RewriterPlugin" API.Nominal
      addTyCon tys ( API.mkTyConApp badNatTyCon [] )
-- | Build a Wanted constraint @Cancellable ty@ for the given type @ty@.
--
-- The constraint's location is the 'API.CtLoc' taken from the rewriter
-- environment ('API.rewriteEnvCtLoc'), with its depth bumped.
mkCancellableWanted :: PluginDefs -> API.TcType -> API.TcPluginM API.Rewrite API.Ct
mkCancellableWanted ( PluginDefs { cancellableClass } ) ty = do
  loc <- API.bumpCtLocDepth . API.rewriteEnvCtLoc <$> API.askRewriteEnv
  let predTy = API.mkTyConApp ( API.classTyCon cancellableClass ) [ ty ]
  API.mkNonCanonical <$> API.setCtLocM loc ( API.newWanted loc predTy )
-- | Rewrite using the given reduction, and additionally emit a Wanted
-- constraint whose predicate is the error type that 'API.mkTcPluginErrorTy'
-- builds from the message.
--
-- The new constraint's location is the rewriter environment's 'API.CtLoc',
-- with its depth bumped.
throwTypeError :: API.Reduction -> API.TcPluginErrorMessage -> API.TcPluginM API.Rewrite API.TcPluginRewriteResult
throwTypeError badRedn msg = do
  env     <- API.askRewriteEnv
  errorTy <- API.mkTcPluginErrorTy msg
  let loc = API.bumpCtLocDepth ( API.rewriteEnvCtLoc env )
  errorCt <- API.mkNonCanonical <$> API.setCtLocM loc ( API.newWanted loc errorTy )
  pure ( API.TcPluginRewriteTo badRedn [ errorCt ] )