module E.Subst(
    doSubst,
    doSubst',
    eAp,
    litSMapM,
    subst,
    subst',
    substMap,
    substMap',
    substMap'',
    typeSubst,
    typeSubst'
    ) where

-- This is a little tricky.

{-
Consider the following example.

 fn = \x0 -> let x1 = 10+x0    -- x1 is only used once, let's inline it.
             in (\x0 -> x1+x0) -- x0 from the outer lambda isn't used.

Simply inlining x1 will give this erroneous result:

 fn = \x0 -> (\x0 -> (10+x0)+x0)

We solve this by renaming variables whenever they clash with the current scope:

 fn = \x0 -> (\x1 -> (10+x0)+x1)

Another solution would be to assign a globally unique id to each variable.
However, in a pure and lazy language like Haskell, renaming variables on the
fly is easier and quite fast.

New ids are currently generated by selecting pseudo random numbers and checking
if they're free. Another possibility would be to select the highest known id number + 1.
See Name.Id.newId for more information.
-}

import Control.Monad.Reader
import Data.Monoid
import qualified Data.Traversable as T
import List hiding(union,insert,delete)

import E.E
import E.FreeVars()
import Name.Id
import Name.Names (tc_Arrow)
import {-# SOURCE #-} E.Show
import Support.FreeVars
import GenUtil
import Util.SetLike as S
import Util.HasSize
import qualified Data.Set as Set

-- | Build an 'ELetRec' around @e@, first dropping every binding whose
-- variable has 'tvrIdent' 0. If no bindings remain, @e@ is returned as is.
-- (Not exported by this module.)
eLetRec :: [(TVr,E)] -> E -> E
eLetRec ds e = f (filter ((/= 0) . tvrIdent . fst) ds) where
    f [] = e
    f ds = ELetRec ds e

-- | Basic substitution routine
--
-- A variable with 'tvrIdent' 0 is never substituted; the term is returned
-- unchanged. Otherwise the work is done by 'doSubst'' with both flags off,
-- treating every free variable of the replacement and of the input term as
-- already taken when new binder names have to be chosen.
subst ::
    TVr   -- ^ Variable to substitute
    -> E  -- ^ What to substitute with
    -> E  -- ^ input term
    -> E  -- ^ output term
subst (TVr { tvrIdent = 0 }) _ e = e
subst (TVr { tvrIdent = i }) w e = doSubst' False False (msingleton i w) (\n -> n `member` (freeVars w `union` freeVars e :: IdSet)) e

-- | Identical to 'subst' except that it substitutes inside the local types
-- for variables in expressions. This should not be used because it breaks the
-- sharing of types between a binding site of a variable and its uses and can
-- lead to inconsistent terms. However, it is sometimes useful to create
-- transient terms for typechecking.

subst' :: TVr -> E -> E -> E
subst' (TVr { tvrIdent = 0 }) _ e = e
subst' (TVr { tvrIdent = (i) }) w e = doSubst' True False (msingleton i w) (\n -> n `member` (freeVars w `union` freeVars e :: IdSet)) e

-- | Monadic map over the expressions stored in a literal: @f@ is applied to
-- the type first and then to each constructor argument of a 'LitCons', or just
-- to the type of a 'LitInt'. The constructor name and alias are carried over
-- unchanged.
litSMapM f LitCons { litName = s, litArgs = es, litType = t, litAliasFor = af } = do
    t' <- f t
    es' <- mapM f es
    return $ LitCons s es' t' af
litSMapM f (LitInt n t) = do
    t' <- f t
    return $ LitInt n t'

-- | Apply a whole map of substitutions at once. The set of names considered
-- taken is seeded with the free variables of the input term and of every
-- replacement expression in the map.
substMap :: IdMap E -> E -> E
substMap im e = doSubst' False False im (\n -> n `member` (unions $ (freeVars e :: IdSet):map freeVars (melems im))) e

-- | doesn't seed with free variables.
-- Only the keys of the map itself are reported as taken.
substMap' :: IdMap E -> E -> E
substMap' im = doSubst' False False im (`mmember` im)

-- | doesn't seed with free variables.
-- Every key of @im@ is reported as taken, while the substitution actually
-- performed is @mapMaybeIdMap id im@ (presumably the entries holding a 'Just';
-- 'mapMaybeIdMap' is defined elsewhere -- TODO confirm).
substMap'' :: IdMap (Maybe E) -> E -> E
substMap'' im = doSubst' False False (mapMaybeIdMap id im) (`mmember` im)

-- Monadic code is so much nicer
-- | Wrapper around 'doSubst'' taking a map of optional replacements: the map
-- handed on is @mapMaybeIdMap id bm@, and the "taken" predicate is membership
-- in @bm@ (so all of its keys count, whatever their value).
doSubst :: Bool -> Bool -> IdMap (Maybe E) -> E -> E
doSubst substInVars allShadow bm e = doSubst' substInVars allShadow (mapMaybeIdMap id bm) (`mmember` bm) e

-- | The general substitution worker.
--
-- Arguments, in order:
--
--  * @substInVars@: when set, a variable occurrence that is not in the
--    substitution has its 'tvrType' rewritten as well.
--
--  * @allShadow@: passed on to 'mnv' (which then always picks a new id), and
--    lets lambda/pi binders that do not occur free in their body be replaced
--    by the id 0.
--
--  * @bm@: the initial substitution.
--
--  * @check@: predicate handed to 'mnv'; ids for which it holds are treated as
--    taken.
--
-- The traversal @f@ runs in the reader monad of functions from the
-- environment, a pair of a set of ids and the current substitution map. It
-- starts out as @(Set.empty, bm)@.
doSubst' :: Bool -> Bool -> IdMap E -> (Id -> Bool) -> E -> E
doSubst' substInVars allShadow bm check e = f e (Set.empty, bm) where
    f :: E -> (Set.Set Id, IdMap E) -> E
    -- A variable: replace it if the current map has an entry for it.
    f eo@(EVar tvr@(TVr { tvrIdent = i, tvrType = t })) = do
        (_,mp) <- ask
        case mlookup i mp of
          Just v -> return v
          _
            | substInVars -> f t >>= \t' -> return $ EVar (tvr { tvrType = t'})
            | otherwise -> return eo
    f (ELam tvr e) = lp ELam tvr e
    f (EPi tvr e) = lp EPi tvr e
    -- Applications are rebuilt with 'eAp' rather than the plain constructor.
    f (EAp a b) = liftM2 eAp (f a) (f b)
    f (EError x e) = liftM (EError x) (f e)
    f (EPrim x es e) = liftM2 (EPrim x) (mapM f es) (f e)
    -- The environment updates of all binders are in effect for both the
    -- definitions and the body.
    f ELetRec { eDefs = dl, eBody = e } = do
        (as,rs) <- mapMntvr (fsts dl)
        local (foldr (.) id rs) $ do
            ds <- mapM f (snds dl)
            e' <- f e
            return $ ELetRec (zip as ds) e'
    f (ELit l) = liftM ELit $ litSMapM f l
    f Unknown = return Unknown
    f e@(ESort {}) = return e
    -- The case binder's update @r@ is in effect for the default branch and
    -- for every alternative; the scrutinee and result type are processed
    -- outside of it.
    f ec@(ECase {}) = do
        e' <- f $ eCaseScrutinee ec
        (b',r) <- ntvr Set.empty $ eCaseBind ec
        d <- local r $ T.mapM f $ eCaseDefault ec
        let da (Alt lc@LitCons { litName = s, litArgs = vs, litType = t } e) = do
                t' <- f t
                (as,rs) <- mapMntvr vs
                e' <- local (foldr (.) id rs) $ f e
                return $ Alt lc { litArgs = as, litType = t' } e'
            da (Alt l e) = do
                l' <- T.mapM f l
                e' <- f e
                return $ Alt l' e'
        alts <- local r (mapM da $ eCaseAlts ec)
        nty <- f (eCaseType ec)
        return $ caseUpdate ec { eCaseScrutinee = e', eCaseDefault = d, eCaseBind = b', eCaseAlts = alts, eCaseType = nty }
    -- Lambda/pi binders. First case: the binder is already 0, or 'allShadow'
    -- is set and the binder is not free in the body. The binder is emitted
    -- with id 0, and the body is processed with the old id added to the set
    -- and removed from the substitution.
    lp lam tvr@(TVr { tvrIdent = n, tvrType = t}) e | n == 0 || (allShadow && n `notElem` freeVars e) = do
        t' <- f t
        e' <- local (\(s,m) -> (Set.insert n s, mdelete n m)) $ f e
        return $ lam (tvr { tvrIdent = 0, tvrType = t'}) e'
    -- Otherwise the binder goes through 'ntvr' and the body is processed under
    -- the environment update it returns.
    lp lam tvr e = do
        (tv,r) <- ntvr Set.empty tvr
        e' <- local r $ f e
        return $ lam tv e'
    -- Process a group of binders left to right, each one under the updates of
    -- the ones before it. The ids of the whole group are passed to 'ntvr' as
    -- @vs@. Returns the new binders and their environment updates.
    -- Note that the local @f@ below shadows the traversal @f@.
    mapMntvr ts = f ts [] where
        f [] xs = return $ unzip $ reverse xs
        f (t:ts) rs = do
            (t',r) <- ntvr vs t
            local r $ f ts ((t',r):rs)
        vs = Set.fromList [ tvrIdent x | x <- ts ]

    -- Process one binder, returning the new binder together with the
    -- environment update to use for its scope. A binder with id 0 only has
    -- its type rewritten and needs no update.
    ntvr xs tvr@(TVr { tvrIdent = 0, tvrType = t}) = do
        t' <- f t
        let nvr = (tvr { tvrType = t'})
        return (nvr,id)
    -- Otherwise 'mnv' picks the id @i'@ to use (possibly @i@ itself). The
    -- update records both @i@ and @i'@ in the set, removes any substitution
    -- for @i'@ and maps the old id @i@ to the new variable.
    ntvr xs tvr@(TVr {tvrIdent = i, tvrType = t}) = do
        t' <- f t
        (s,ss) <- ask
        let i' = mnv allShadow xs i check s ss
        let nvr = (tvr { tvrIdent = i', tvrType = t'})
        return (nvr,\(s,m) -> (Set.insert i' . Set.insert i $ s, minsert i (EVar nvr) . mdelete i' $ m))

-- | Choose the id to use for a binder whose current id is @i@.
--
-- @xs@ is the set of ids of the binder group being processed, @checkTaken@ a
-- caller-supplied predicate for ids that are taken, and @s@ and @ss@ the set
-- and the map of the current environment. An id counts as taken (@scheck@)
-- when it is a key of @ss@, a member of @s@, or accepted by @checkTaken@.
--
-- With @allShadow@ a new id is always requested from 'newId'. Otherwise @i@
-- is kept unless it is invalid or taken. The first argument given to 'newId'
-- is the combined size of @xs@, @s@ and @ss@; the second is the predicate an
-- acceptable id must satisfy. ('newId' is defined in Name.Id; its exact
-- contract is not visible here -- TODO confirm.)
mnv :: Bool -> Set.Set Id -> Id -> (Id -> Bool) -> Set.Set Id -> IdMap a -> Id
mnv allShadow xs i checkTaken s ss
    | allShadow = newId (Set.size xs + Set.size s + size ss) (not . scheck)
    | isInvalidId i || scheck i = newId (Set.size xs + Set.size s + size ss) (not . check)
    -- It is very important that we don't check for 'xs' membership in the guard above.
    | otherwise = i
    where scheck n = n `mmember` ss || n `member` s || checkTaken n
          check n = scheck n || n `member` xs

-- | Smart application constructor.
--
-- Applying a pi or lambda substitutes the argument for the binder in the body
-- (or just returns the body when the binder's id is 0). Applying a 'LitCons'
-- literal whose type is a pi appends the argument to the constructor's
-- arguments and instantiates the type. Applying a 'LitCons' literal that has
-- an alias applies the alias to all existing arguments followed by the new
-- one. Applying an 'EError' pushes the application into its second field.
-- Anything else becomes a plain 'EAp'.
eAp (EPi t b) e = if tvrIdent t == 0 then b else subst t e b
eAp (ELam t b) e = if tvrIdent t == 0 then b else subst t e b
--eAp (EPrim n es t@(EPi _ _)) b = EPrim n (es ++ [b]) (eAp t b) -- only apply if type is pi-like
eAp (ELit lc@LitCons { litArgs = es, litType = (EPi t r) }) b = ELit lc { litArgs = es ++ [b], litType = subst t b r }
eAp (ELit LitCons { litArgs = es, litAliasFor = Just af }) b = foldl eAp af (es ++ [b])
--eAp a@ELit {} b = error $ "very strange application: (" ++ prettyE a ++ ") (" ++ prettyE b ++ ")"
eAp (EError s t) b = EError s (eAp t b)
eAp a b = EAp a b

-- | Variant of 'typeSubst' that takes a plain term-level substitution and
-- computes the in-scope information itself. Returns the term untouched when
-- both substitutions are empty. Otherwise the term-level map given to
-- 'typeSubst' is @termSub@ (wrapped in 'Just') combined with one entry for
-- every free variable of the term and of the replacement expressions in
-- either map, each holding that variable's lookup result in @termSub@.
typeSubst' :: IdMap E -> IdMap E -> E -> E
typeSubst' termSub typeSub e | isEmpty termSub && isEmpty typeSub = e
--typeSubst' termSub typeSub e = typeSubst (Map.map Just termSub `Map.union` Map.fromAscList [ (x,Map.lookup x termSub) | x <- fvs]) typeSub e where
--    fvs = Set.toAscList (freeVars e `Set.union` fvmap termSub `Set.union` fvmap typeSub)
--    fvmap m = Set.unions (map freeVars (Map.elems m))
typeSubst' termSub typeSub e = typeSubst (fmap Just termSub `union` fmap ((`mlookup` termSub) . tvrIdent) fvs) typeSub e where
    fvs :: IdMap TVr
    fvs = (freeVars e `union` fvmap termSub `union` fvmap typeSub)
    fvmap m = unions (map freeVars (melems m))

-- | Substitute @e@ for the id @t@ at the type level only in @e'@, passing the
-- free variables of @e@ and @e'@ as the term-level argument of 'typeSubst'.
-- (Not exported by this module.)
substType t e e' = typeSubst (freeVars e `union` freeVars e') (msingleton t e) e'

-- | substitution routine that can substitute different values at the term and type level.
-- this is useful to enforce the invariant that let-bound variables must not occur at the type level, yet
-- non-atomic values (even typelike ones) cannot appear in argument positions at the term level.
--
-- Substitute separately at the term level and at the type level.
--
-- The traversal runs in the reader monad of functions from an environment
-- @(inTypeFlag, termMap, typeMap)@. The flag starts out 'False' and is switched
-- to 'True' by @inType@ for everything that is processed as a type (binder
-- types, literal types, case result types, the second field of 'EError' and
-- the last field of 'EPrim'). A variable occurrence is replaced from
-- @termMap@ while the flag is 'False' and from @typeMap@ while it is 'True'.
--
-- A 'Nothing' entry in the term map replaces nothing; it only marks the id as
-- taken so that binders get renamed away from it. Every key of the type-level
-- substitution is marked this way.
--
-- If both substitutions are empty the term is returned unchanged.
typeSubst :: IdMap (Maybe E)  -- ^ substitution to carry out at term level as well as a list of in-scope variables
          -> IdMap E          -- ^ substitution to carry out at type level
          -> (E -> E)         -- ^ the substitution function
typeSubst termSub typeSub e | isEmpty termSub && isEmpty typeSub = e
typeSubst termSub typeSub e = f e (False,termSub',typeSub) where
    -- Ids substituted at the type level count as taken at the term level too.
    termSub' = termSub `union` fmap (const Nothing) typeSub
    f :: E -> (Bool,IdMap (Maybe E),IdMap E) -> E
    f eo@(EVar (TVr { tvrIdent = i })) = do
        (wh,trm,tp) <- ask
        case (wh,mlookup i trm, mlookup i tp) of
          (False,(Just (Just v)),_) -> return v
          (True,_,(Just v)) -> return v
          _ -> return eo
    f (ELam tvr e) = lp ELam tvr e
    f (EPi tvr e) = lp EPi tvr e
    f (EAp a b) = liftM2 eAp (f a) (f b)
    f (EError x e) = liftM (EError x) (inType $ f e)
    f (EPrim x es e) = liftM2 (EPrim x) (mapM f es) (inType $ f e)
    f ELetRec { eDefs = dl, eBody = e } = do
        (as,rs) <- liftM unzip $ mapMntvr (fsts dl)
        local (foldr (.) id rs) $ do
            ds <- mapM f (snds dl)
            e' <- f e
            return $ ELetRec (zip as ds) e'
    f (ELit l) = liftM ELit $ litSMapM l
    f Unknown = return Unknown
    f e@(ESort {}) = return e
    f ec@(ECase {}) = do
        e' <- f $ eCaseScrutinee ec
        (b',r) <- ntvr Set.empty $ eCaseBind ec
        d <- local r $ T.mapM f $ eCaseDefault ec
        let da (Alt lc@LitCons { litArgs = vs, litType = t } e) = do
                t' <- inType $ f t
                (as,rs) <- liftM unzip $ mapMntvr vs
                e' <- local (foldr (.) id rs) $ f e
                return $ Alt lc { litArgs = as, litType = t' } e'
            da (Alt (LitInt n t) e) = do
                t' <- inType (f t)
                e' <- f e
                return $ Alt (LitInt n t') e'
        -- The case binder is in scope in every alternative, exactly as it is
        -- in the default branch, so the alternatives must be processed under
        -- its environment update as well. Without it, occurrences of the
        -- binder in the alternatives keep the old id when the binder is
        -- renamed, and can be hit by an outer substitution for the same id.
        alts <- local r (mapM da $ eCaseAlts ec)
        nty <- inType (f $ eCaseType ec)
        return $ caseUpdate ec { eCaseScrutinee = e', eCaseDefault = d, eCaseBind = b', eCaseAlts = alts, eCaseType = nty }
    -- A binder with id 0 binds nothing: only its type needs rewriting.
    lp lam tvr@(TVr { tvrIdent = 0, tvrType = t}) e = do
        t' <- inType (f t)
        e' <- f e
        return $ lam (tvr { tvrIdent = 0, tvrType = t'}) e'
    lp lam tvr e = do
        (tv,r) <- ntvr Set.empty tvr
        e' <- local r $ f e
        return $ lam tv e'
    -- Process a group of binders left to right, each under the updates of the
    -- ones before it. Returns (new binder, environment update) pairs in the
    -- original order. The local @go@ is kept separate from the traversal @f@.
    mapMntvr ts = go ts [] where
        go [] acc = return $ reverse acc
        go (t:rest) acc = do
            (t',r) <- ntvr vs t
            local r $ go rest ((t',r):acc)
        vs = Set.fromList [ tvrIdent x | x <- ts ]
    -- Run an action with the "inside a type" flag switched on.
    inType = local (\ (_,trm,typ) -> (True,trm,typ) )
    -- Environment updates: a 'Just' entry is recorded in both maps, a
    -- 'Nothing' entry only marks the id as taken in the term map.
    addMap i (Just x) (b,trm,typ) = (b,minsert i (Just x) trm, minsert i x typ)
    addMap i Nothing (b,trm,typ) = (b,minsert i Nothing trm, typ)
    -- Local counterpart of the top-level 'litSMapM' that processes the
    -- literal's type in type mode.
    litSMapM lc@LitCons { litArgs = es, litType = t } = do
        t' <- inType $ f t
        es' <- mapM f es
        return $ lc { litArgs = es', litType = t' }
    litSMapM (LitInt n t) = do
        t' <- inType $ f t
        return $ LitInt n t'
    -- Process one binder, returning the new binder and the environment update
    -- for its scope.
    ntvr _ tvr@(TVr { tvrIdent = 0, tvrType = t}) = do
        t' <- inType (f t)
        let nvr = (tvr { tvrType = t'})
        return (nvr,id)
    ntvr xs tvr@(TVr {tvrIdent = i, tvrType = t}) = do
        t' <- inType (f t)
        (_,trm,_) <- ask
        -- An id counts as taken here exactly when it is a key of the term map.
        let i' = mnv False xs i (\_ -> False) Set.empty trm
        let nvr = (tvr { tvrIdent = i', tvrType = t'})
        -- The old id always maps to the new variable; when the binder was
        -- renamed the new id is additionally marked as taken.
        if i == i'
            then return (nvr,addMap i (Just $ EVar nvr))
            else return (nvr,addMap i (Just $ EVar nvr) . addMap i' Nothing)