Skip to content

Commit

Permalink
Factor out the way Simplify handles ACase.
Browse files Browse the repository at this point in the history
  • Loading branch information
dougalm committed Oct 23, 2023
1 parent 0ff3233 commit de88bf8
Showing 1 changed file with 98 additions and 111 deletions.
209 changes: 98 additions & 111 deletions src/lib/Simplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,75 @@ tryAsDataAtom atom = do
where
notData = error $ "Not runtime-representable data: " ++ pprint atom

data WithSubst (e::E) (o::S) where
WithSubst :: Subst AtomSubstVal i o -> e i -> WithSubst e o

data ConcreteCAtom (n::S) =
CCCon (WithSubst CAtom n) -- can't be Stuck or SimpInCore
| CCSimpInCore (SimpInCore n) -- can't be ACase
| CCNoInlineFun (CAtomVar n) (CType n) (CAtom n)
| CCFFIFun (CorePiType n) (TopFunName n)

-- Yields to the continuation a term with a concrete CoreIR constructor,
-- or LiftSimpFun, liftSimp, or TabLam.
forceConstructor
:: Emits o
=> CAtom i
-> (forall o' i'. (DExt o o', Emits o') => ConcreteCAtom o'-> SimplifyM i' o' (CAtom o'))
-> SimplifyM i o (CAtom o)
forceConstructor atom cont = withDistinct case atom of
Stuck stuck -> forceStuck stuck cont
SimpInCore lifted -> case lifted of
ACase e alts resultTy -> do
e' <- substM e
resultTy' <- substM resultTy
defuncCase e' resultTy' \i x -> do
Abs b body <- return $ alts !! i
extendSubst (b@>SubstVal x) do
forceConstructor body cont
_ -> do
lifted' <- substM lifted
cont $ CCSimpInCore lifted'
_ -> do
Distinct <- getDistinct
subst <- getSubst
cont $ CCCon $ WithSubst subst atom

forceStuck
:: Emits o
=> CStuck i
-> (forall o' i'. (DExt o o', Emits o') => ConcreteCAtom o'-> SimplifyM i' o' (CAtom o'))
-> SimplifyM i o (CAtom o)
forceStuck stuck cont = withDistinct case stuck of
StuckVar v -> lookupSubstM (atomVarName v) >>= \case
SubstVal x -> dropSubst $ forceConstructor x cont
Rename v' -> lookupAtomName v' >>= \case
LetBound (DeclBinding _ (Atom x)) -> dropSubst $ forceConstructor x cont
NoinlineFun t f -> do
v'' <- toAtomVar v'
cont $ CCNoInlineFun v'' t f
FFIFunBound t f -> cont $ CCFFIFun t f
_ -> error "shouldn't have other CVars left"
-- TODO: figure out how to de-dup these cases with their Expr counterpart
StuckProject _ i x -> do
ty <- substM $ getType stuck
forceStuck x \case
CCSimpInCore (LiftSimp _ x') -> do
x'' <- proj i x'
cont $ CCSimpInCore $ LiftSimp (sink ty) x''
CCCon (WithSubst s con) -> withSubst s case con of
ProdVal xs -> forceConstructor (xs!!i) cont
DepPair l r _ -> forceConstructor ([l, r]!!i) cont
_ -> error "Can't project stuck term"
_ -> error "Can't project stuck term"
StuckUnwrap _ x -> forceStuck x \case
CCCon (WithSubst s con) -> withSubst s case con of
NewtypeCon _ x' -> forceConstructor x' cont
_ -> error "can't unwrap stuck term"
_ -> error "can't unwrap stuck term"
InstantiatedGiven _ _ _ -> error "shouldn't have this left"
SuperclassProj _ _ _ -> error "shouldn't have this left"

forceTabLam :: Emits n => TabLamExpr n -> SimplifyM i n (SAtom n)
forceTabLam (PairE ixTy (Abs b ab)) =
buildFor (getNameHint b) Fwd ixTy \v -> do
Expand Down Expand Up @@ -315,8 +384,7 @@ simplifyExpr expr = confuseGHC >>= \_ -> case expr of
simplifyApp ty' f xs'
TabApp _ f xs -> do
xs' <- mapM simplifyAtom xs
f' <- simplifyAtom f
simplifyTabApp f' xs'
simplifyTabApp f xs'
Atom x -> simplifyAtom x
PrimOp op -> simplifyOp op
ApplyMethod (EffTy _ ty) dict i xs -> do
Expand Down Expand Up @@ -379,6 +447,7 @@ defuncCaseCore scrut resultTy cont = do
let xCoreTy = altBinderTys !! i
x' <- liftSimpAtom (sink xCoreTy) x
cont i x'
-- TODO: we should use forceConstructor here
Nothing -> case trySelectBranch scrut of
Just (i, arg) -> getDistinct >>= \Distinct -> cont i arg
Nothing -> go scrut where
Expand Down Expand Up @@ -449,61 +518,21 @@ simplifyAlt split ty cont = do

simplifyApp :: forall i o. Emits o
=> CType o -> CAtom i -> [CAtom o] -> SimplifyM i o (CAtom o)
simplifyApp resultTy f xs = case f of
Lam (CoreLamExpr _ lam) -> fast lam
_ -> slow =<< simplifyAtomAndInline f
where
fast :: LamExpr CoreIR i' -> SimplifyM i' o (CAtom o)
fast lam = withInstantiated lam xs \body -> simplifyExpr body

slow :: CAtom o -> SimplifyM i o (CAtom o)
slow = \case
Lam (CoreLamExpr _ lam) -> dropSubst $ fast lam
SimpInCore (ACase e alts _) -> dropSubst do
defuncCase e resultTy \i x -> do
Abs b body <- return $ alts !! i
extendSubst (b@>SubstVal x) do
xs' <- mapM sinkM xs
simplifyApp (sink resultTy) body xs'
SimpInCore (LiftSimpFun _ lam) -> do
xs' <- mapM toDataAtomIgnoreRecon xs
result <- instantiate lam xs' >>= emitExpr
liftSimpAtom resultTy result
Var v -> do
lookupAtomName (atomVarName v) >>= \case
NoinlineFun _ _ -> simplifyTopFunApp v xs
FFIFunBound _ f' -> do
xs' <- mapM toDataAtomIgnoreRecon xs
liftSimpAtom resultTy =<< naryTopApp f' xs'
b -> error $ "Should only have noinline functions left " ++ pprint b
atom -> error $ "Unexpected function: " ++ pprint atom

-- | Like `simplifyAtom`, but will try to inline function definitions found
-- in the environment. The only exception is when we're going to differentiate
-- and the function has a custom derivative rule defined.
-- TODO(dougalm): do we still need this?
simplifyAtomAndInline :: CAtom i -> SimplifyM i o (CAtom o)
simplifyAtomAndInline atom = confuseGHC >>= \_ -> case atom of
Var v -> do
env <- getSubst
case env ! atomVarName v of
Rename v' -> doInline =<< toAtomVar v'
SubstVal (Var v') -> doInline v'
SubstVal x -> return x
-- This is a hack because we weren't normalize the unwrapping of
-- `unit_type_scale` in `plot.dx`. We need a better system for deciding how to
-- normalize and inline.
Stuck (StuckProject _ i x) -> do
x' <- simplifyStuck x >>= reduceProj i
dropSubst $ simplifyAtomAndInline x'
_ -> simplifyAtom atom >>= \case
Var v -> doInline v
ans -> return ans
where
doInline v = do
lookupAtomName (atomVarName v) >>= \case
LetBound (DeclBinding _ (Atom x)) -> dropSubst $ simplifyAtomAndInline x
_ -> return $ Var v
simplifyApp resultTy f xs = forceConstructor f \f' -> do
xs' <- mapM sinkM xs
case f' of
CCCon (WithSubst s (Lam (CoreLamExpr _ lam))) ->
withSubst s $ withInstantiated lam xs' \body ->
simplifyExpr body
CCSimpInCore (LiftSimpFun _ lam) -> do
xs'' <- mapM toDataAtomIgnoreRecon xs'
result <- instantiate lam xs'' >>= emitExpr
liftSimpAtom (sink resultTy) result
CCNoInlineFun v _ _ -> simplifyTopFunApp v xs'
CCFFIFun _ f'' -> do
xs'' <- mapM toDataAtomIgnoreRecon xs'
liftSimpAtom (sink resultTy) =<< naryTopApp f'' xs''
_ -> error "not a function"

simplifyTopFunApp :: Emits n => CAtomVar n -> [CAtom n] -> SimplifyM i n (CAtom n)
simplifyTopFunApp fName xs = do
Expand Down Expand Up @@ -547,33 +576,23 @@ specializedFunCoreDefinition (AppSpecialization f (Abs bs staticArgs)) = do
naryApp f' staticArgs'

simplifyTabApp :: forall i o. Emits o
=> CAtom o -> [CAtom o] -> SimplifyM i o (CAtom o)
simplifyTabApp f [] = return f
simplifyTabApp f@(SimpInCore sic) xs = case sic of
TabLam _ _ -> do
case fromNaryTabLam (length xs) f of
=> CAtom i -> [CAtom o] -> SimplifyM i o (CAtom o)
simplifyTabApp f [] = simplifyAtom f
simplifyTabApp f xs = forceConstructor f \case
CCSimpInCore sic@(TabLam _ _) -> do
case fromNaryTabLam (length xs) (SimpInCore sic) of
Just (bsCount, ab) -> do
let (xsPref, xsRest) = splitAt bsCount xs
(xsPref, xsRest) <- splitAt bsCount <$> mapM sinkM xs
xsPref' <- mapM toDataAtomIgnoreRecon xsPref
block' <- instantiate ab xsPref'
atom <- emitDecls block'
simplifyTabApp atom xsRest
dropSubst $ simplifyTabApp atom xsRest
Nothing -> error "should never happen"
ACase e alts ty -> dropSubst do
resultTy <- typeOfTabApp ty xs
defuncCase e resultTy \i x -> do
Abs b body <- return $ alts !! i
extendSubst (b@>SubstVal x) do
xs' <- mapM sinkM xs
body' <- substM body
simplifyTabApp body' xs'
LiftSimp _ f' -> do
fTy <- return $ getType f
resultTy <- typeOfTabApp fTy xs
xs' <- mapM toDataAtomIgnoreRecon xs
CCSimpInCore (LiftSimp fTy f') -> do
resultTy <- typeOfTabApp fTy (sink<$>xs)
xs' <- mapM (toDataAtomIgnoreRecon . sink) xs
liftSimpAtom resultTy =<< naryTabApp f' xs'
LiftSimpFun _ _ -> error "not implemented"
simplifyTabApp f _ = error $ "Unexpected table: " ++ pprint f
_ -> error "not a table"

simplifyIxType :: IxType CoreIR o -> SimplifyM i o (IxType SimpIR o)
simplifyIxType (IxType t ixDict) = do
Expand Down Expand Up @@ -625,40 +644,8 @@ ixMethodType method absDict = do
let allBs = extraArgBs >>> methodArgs
return $ PiType allBs (EffTy Pure resultTy)

-- TODO: do we even need this, or is it just a glorified `SubstM`?
simplifyAtom :: CAtom i -> SimplifyM i o (CAtom o)
simplifyAtom atom = confuseGHC >>= \_ -> case atom of
Stuck e -> simplifyStuck e
Lam _ -> substM atom
DepPair x y ty -> DepPair <$> simplifyAtom x <*> simplifyAtom y <*> substM ty
Con con -> Con <$> traverseOp con substM simplifyAtom (error "unexpected lambda")
Eff eff -> Eff <$> substM eff
PtrVar t v -> PtrVar t <$> substM v
DictCon _ -> substM atom
NewtypeCon _ _ -> substM atom
SimpInCore _ -> substM atom
TypeAsAtom _ -> substM atom

simplifyStuck :: CStuck i -> SimplifyM i o (CAtom o)
simplifyStuck = \case
StuckVar v -> simplifyVar v
StuckProject _ i x -> reduceProj i =<< simplifyStuck x
stuck -> substM (Stuck stuck)

simplifyVar :: AtomVar CoreIR i -> SimplifyM i o (CAtom o)
simplifyVar v = do
env <- getSubst
case env ! atomVarName v of
SubstVal x -> return x
Rename v' -> do
AtomNameBinding bindingInfo <- lookupEnv v'
let ty = getType bindingInfo
case bindingInfo of
-- Functions get inlined only at application sites
LetBound (DeclBinding _ _) | isFun -> return $ Var $ AtomVar v' ty
where isFun = case ty of Pi _ -> True; _ -> False
LetBound (DeclBinding _ (Atom x)) -> dropSubst $ simplifyAtom x
_ -> return $ Var $ AtomVar v' ty
simplifyAtom = substM

-- Assumes first order (args/results are "data", allowing newtypes), monormophic
simplifyLam
Expand Down

0 comments on commit de88bf8

Please sign in to comment.