Skip to content

Commit

Permalink
Add a StuckTabApp case to Stuck
Browse files Browse the repository at this point in the history
  • Loading branch information
dougalm committed Oct 23, 2023
1 parent de88bf8 commit d80f318
Show file tree
Hide file tree
Showing 13 changed files with 86 additions and 26 deletions.
12 changes: 12 additions & 0 deletions src/lib/Builder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,18 @@ buildScopedAssumeNoDecls cont = do
_ -> error "Expected no decl emissions"
{-# INLINE buildScopedAssumeNoDecls #-}

withReducibleEmissions
:: (ScopableBuilder r m, Builder r m, HasNamesE e, SubstE AtomSubstVal e)
=> String
-> (forall o' . (Emits o', DExt o o') => m o' (e o'))
-> m o (e o)
withReducibleEmissions msg cont = do
withDecls <- buildScoped cont
reduceWithDecls withDecls >>= \case
Just t -> return t
_ -> throw TypeErr msg
{-# INLINE withReducibleEmissions #-}

-- === "Hoisting" top-level builder class ===

-- `emitHoistedEnv` lets you emit top env fragments, like cache entries or
Expand Down
24 changes: 23 additions & 1 deletion src/lib/CheapReduction.hs
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,14 @@ reduceExprM = \case
case (ty, val) of
(BaseTy (Scalar Word32Type), Con (Lit (Word64Lit v))) -> return $ Con $ Lit $ Word32Lit $ fromIntegral v
_ -> empty
TabApp ty tab xs -> do
ty' <- substM ty
xs' <- mapM substM xs
tab' <- substM tab
case tab' of
Stuck tab'' -> return $ Stuck $ StuckTabApp ty' tab'' xs'
_ -> error "not a table" -- what about RepVal?
TopApp _ _ _ -> empty
TabApp _ _ _ -> empty
Case _ _ _ -> empty
TabCon _ _ _ -> empty
PrimOp _ -> empty
Expand Down Expand Up @@ -188,6 +194,11 @@ typeOfApp (Pi piTy) xs = withSubstReaderT $
withInstantiated piTy xs \(EffTy _ ty) -> substM ty
typeOfApp _ _ = error "expected a pi type"

typeOfTabApp :: (IRRep r, EnvReader m) => Type r n -> [Atom r n] -> m n (Type r n)
typeOfTabApp (TabPi piTy) xs = withSubstReaderT $
withInstantiated piTy xs \ty -> substM ty
typeOfTabApp _ _ = error "expected a TabPi type"

repValAtom :: EnvReader m => SRepVal n -> m n (SAtom n)
repValAtom (RepVal ty tree) = case ty of
ProdTy ts -> case tree of
Expand Down Expand Up @@ -220,6 +231,13 @@ reduceUnwrapM = \case
_ -> error "expected a newtype"
_ -> empty

reduceTabAppM :: IRRep r => Atom r o -> [Atom r o] -> ReducerM i o (Atom r o)
reduceTabAppM tab xs = case tab of
Stuck tab' -> do
ty <- typeOfTabApp (getType tab') xs
return $ Stuck $ StuckTabApp ty tab' xs
_ -> error $ "not a table" ++ pprint tab

unwrapNewtypeType :: EnvReader m => NewtypeTyCon n -> m n (NewtypeCon n, Type CoreIR n)
unwrapNewtypeType = \case
Nat -> return (NatCon, IdxRepTy)
Expand Down Expand Up @@ -616,6 +634,10 @@ reduceStuck = \case
StuckUnwrap _ x -> do
x' <- reduceStuck x
dropSubst $ reduceUnwrapM x'
StuckTabApp _ f xs -> do
f' <- reduceStuck f
xs' <- mapM substM xs
dropSubst $ reduceTabAppM f' xs'
InstantiatedGiven _ f xs -> do
xs' <- mapM substM xs
f' <- reduceStuck f
Expand Down
7 changes: 7 additions & 0 deletions src/lib/CheckType.hs
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,13 @@ instance IRRep r => CheckableE r (Stuck r) where
StuckProject resultTy i x -> do
Project resultTy' i' (Stuck x') <- checkWithEffects Pure $ Project resultTy i (Stuck x)
return $ StuckProject resultTy' i' x'
StuckTabApp reqTy f xs -> do
reqTy' <- reqTy |: TyKind
(f', tabTy) <- checkAndGetType f
xs' <- mapM checkE xs
ty' <- checkTabApp tabTy xs'
checkTypesEq reqTy' ty'
return $ StuckTabApp reqTy' f' xs'
InstantiatedGiven resultTy given args -> do
resultTy' <- resultTy |: TyKind
(given', Pi piTy) <- checkAndGetType given
Expand Down
6 changes: 6 additions & 0 deletions src/lib/Imp.hs
Original file line number Diff line number Diff line change
Expand Up @@ -865,9 +865,15 @@ atomToRepVal x = RepVal (getType x) <$> go x where
Stuck (StuckVar v) -> lookupAtomName (atomVarName v) >>= \case
TopDataBound (RepVal _ tree) -> return tree
_ -> error "should only have pointer and data atom names left"
-- TODO: I think we want to be able to rule this one out by insisting that
-- RepValAtom is itself part of Stuck and it can't represent a product.
Stuck (StuckProject _ i val) -> do
Branch ts <- go $ Stuck val
return $ ts !! i
Stuck (StuckTabApp _ f xs) -> do
f' <- atomToRepVal $ Stuck f
RepVal _ t <- naryIndexRepVal f' (toList xs)
return t

-- XXX: We used to have a function called `destToAtom` which loaded the value
-- from the dest. This version is not that. It just lifts a dest into an atom of
Expand Down
11 changes: 0 additions & 11 deletions src/lib/Inference.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1089,17 +1089,6 @@ checkSigmaDependent e@(WithSrcE ctx _) ty = addSrcContext ctx $
"Dependent functions can only be applied to fully evaluated expressions. " ++
"Bind the argument to a name before you apply the function."

withReducibleEmissions
:: Zonkable e
=> String
-> (forall o' . (Emits o', DExt o o') => InfererM i o' (e o'))
-> InfererM i o (e o)
withReducibleEmissions msg cont = do
withDecls <- buildScoped cont
reduceWithDecls withDecls >>= \case
Just t -> return t
_ -> throw TypeErr msg

-- === sorting case alternatives ===

data IndexedAlt n = IndexedAlt CaseAltIndex (Alt CoreIR n)
Expand Down
2 changes: 2 additions & 0 deletions src/lib/Linearize.hs
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,9 @@ linearizeAtom atom = case atom of
activePrimalIdx v' >>= \case
Nothing -> withZeroT $ return (Var v')
Just idx -> return $ WithTangent (Var v') $ getTangentArg idx
-- TODO: buildScoped and reduce the results so we keep expression in non-ANF for type checking purposes
Stuck (StuckProject ty i x) -> linearizeExpr $ Project ty i (Stuck x)
Stuck (StuckTabApp t f xs) -> linearizeExpr $ TabApp t (Stuck f) xs
RepValAtom _ -> emitZeroT
where emitZeroT = withZeroT $ renameM atom

Expand Down
16 changes: 11 additions & 5 deletions src/lib/OccAnalysis.hs
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,11 @@ instance HasOCC SStuck where
ty' <- occTy ty
return $ StuckVar (AtomVar n ty')
StuckProject t i x -> StuckProject <$> occ a t <*> pure i <*> occ a x
StuckTabApp t array ixs -> do
t' <- occTy t
(a', ixs') <- occIdxs a ixs
array' <- occ a' array
return $ StuckTabApp t' array' ixs'

instance HasOCC SType where
occ a ty = runOCCMVisitor a $ visitTypePartial ty
Expand Down Expand Up @@ -360,7 +365,7 @@ instance HasOCC SExpr where
return $ Block effTy' (Abs decls' ans')
TabApp t array ixs -> do
t' <- occTy t
(a', ixs') <- go a ixs
(a', ixs') <- occIdxs a ixs
array' <- occ a' array
return $ TabApp t' array' ixs'
Case scrut alts (EffTy effs ty) -> do
Expand All @@ -376,10 +381,11 @@ instance HasOCC SExpr where
ref' <- occ a ref
PrimOp . RefOp ref' <$> occ a op
expr -> occGeneric a expr
where
go acc [] = return (acc, [])
go acc (ix:ixs) = do
(acc', ixs') <- go acc ixs

occIdxs :: Access n -> [SAtom n] -> OCCM n (Access n, [SAtom n])
occIdxs acc [] = return (acc, [])
occIdxs acc (ix:ixs) = do
(acc', ixs') <- occIdxs acc ixs
(summ, ix') <- occurrenceAndSummary ix
return (location summ acc', ix':ixs')

Expand Down
1 change: 1 addition & 0 deletions src/lib/PPrint.hs
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ instance IRRep r => PrettyPrec (Stuck r n) where
prettyPrec = \case
StuckVar v -> atPrec ArgPrec $ p v
StuckProject _ i v -> atPrec LowestPrec $ "StuckProject" <+> p i <+> p v
StuckTabApp _ f xs -> atPrec AppPrec $ pArg f <> "." <> pArg xs
StuckUnwrap _ v -> atPrec LowestPrec $ "StuckUnwrap" <+> p v
InstantiatedGiven _ v args -> atPrec LowestPrec $ "Given" <+> p v <+> p (toList args)
SuperclassProj _ d' i -> atPrec LowestPrec $ "SuperclassProj" <+> p d' <+> p i
Expand Down
1 change: 1 addition & 0 deletions src/lib/QueryTypePure.hs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ instance IRRep r => HasType r (Stuck r) where
getType = \case
StuckVar (AtomVar _ t) -> t
StuckProject t _ _ -> t
StuckTabApp t _ _ -> t
StuckUnwrap t _ -> t
InstantiatedGiven t _ _ -> t
SuperclassProj t _ _ -> t
Expand Down
8 changes: 8 additions & 0 deletions src/lib/Simplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,14 @@ forceStuck stuck cont = withDistinct case stuck of
DepPair l r _ -> forceConstructor ([l, r]!!i) cont
_ -> error "Can't project stuck term"
_ -> error "Can't project stuck term"
StuckTabApp _ f xs -> do
ty <- substM $ getType stuck
xs' <- forM xs \x -> toDataAtomIgnoreRecon =<< substM x
forceStuck f \case
CCSimpInCore (LiftSimp _ f') -> do
result <- naryTabApp f' (sink<$>xs')
cont $ CCSimpInCore $ LiftSimp (sink ty) result
_ -> error "not a table" -- what about table lambda?
StuckUnwrap _ x -> forceStuck x \case
CCCon (WithSubst s con) -> withSubst s case con of
NewtypeCon _ x' -> forceConstructor x' cont
Expand Down
4 changes: 2 additions & 2 deletions src/lib/Transpose.hs
Original file line number Diff line number Diff line change
Expand Up @@ -313,8 +313,8 @@ transposeAtom atom ct = case atom of
return ()
LinRef ref -> emitCTToRef ref ct
LinTrivial -> return ()
Stuck (StuckProject _ _ _) -> undefined
-- Stuck (StuckProject _ i' x') -> do
Stuck (StuckProject _ _ _) -> error "not implemented"
Stuck (StuckTabApp _ _ _) -> error "not implemented"
-- let (idxs, v) = asNaryProj i' x'
-- lookupSubstM (atomVarName v) >>= \case
-- RenameNonlin _ -> error "an error, probably"
Expand Down
18 changes: 11 additions & 7 deletions src/lib/Types/Core.hs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ data Type (r::IR) (n::S) where
data Stuck (r::IR) (n::S) where
StuckVar :: AtomVar r n -> Stuck r n
StuckProject :: Type r n -> Int -> Stuck r n -> Stuck r n
StuckTabApp :: Type r n -> Stuck r n -> [Atom r n] -> Stuck r n
StuckUnwrap :: CType n -> CStuck n -> Stuck CoreIR n
InstantiatedGiven :: CType n -> CStuck n -> [CAtom n] -> Stuck CoreIR n
SuperclassProj :: CType n -> Int -> CStuck n -> Stuck CoreIR n
Expand Down Expand Up @@ -1552,26 +1553,29 @@ instance IRRep r => AlphaHashableE (Atom r)
instance IRRep r => RenameE (Atom r)

instance IRRep r => GenericE (Stuck r) where
type RepE (Stuck r) = EitherE5
type RepE (Stuck r) = EitherE6
{- StuckVar -} (AtomVar r)
{- StuckProject -} (Type r `PairE` LiftE Int `PairE` Stuck r)
{- StuckTabApp -} (Type r `PairE` Stuck r `PairE` ListE (Atom r))
{- StuckUnwrap -} (WhenCore r (CType `PairE` CStuck))
{- InstantiatedGiven -} (WhenCore r (CType `PairE` CStuck `PairE` ListE CAtom))
{- SuperclassProj -} (WhenCore r (CType `PairE` LiftE Int `PairE` CStuck))
fromE = \case
StuckVar v -> Case0 v
StuckProject t i e -> Case1 $ t `PairE` LiftE i `PairE` e
StuckUnwrap t e -> Case2 $ WhenIRE $ t `PairE` e
InstantiatedGiven t e xs -> Case3 $ WhenIRE $ t `PairE` e `PairE` ListE xs
SuperclassProj t i e -> Case4 $ WhenIRE $ t `PairE` LiftE i `PairE` e
StuckTabApp t f x -> Case2 $ t `PairE` f `PairE` ListE x
StuckUnwrap t e -> Case3 $ WhenIRE $ t `PairE` e
InstantiatedGiven t e xs -> Case4 $ WhenIRE $ t `PairE` e `PairE` ListE xs
SuperclassProj t i e -> Case5 $ WhenIRE $ t `PairE` LiftE i `PairE` e
{-# INLINE fromE #-}

toE = \case
Case0 v -> StuckVar v
Case1 (t `PairE` LiftE i `PairE` e) -> StuckProject t i e
Case2 (WhenIRE (t `PairE` e)) -> StuckUnwrap t e
Case3 (WhenIRE (t `PairE` e `PairE` ListE xs)) -> InstantiatedGiven t e xs
Case4 (WhenIRE (t `PairE` LiftE i `PairE` e)) -> SuperclassProj t i e
Case2 (t `PairE` f `PairE` ListE x) -> StuckTabApp t f x
Case3 (WhenIRE (t `PairE` e)) -> StuckUnwrap t e
Case4 (WhenIRE (t `PairE` e `PairE` ListE xs)) -> InstantiatedGiven t e xs
Case5 (WhenIRE (t `PairE` LiftE i `PairE` e)) -> SuperclassProj t i e
_ -> error "impossible"
{-# INLINE toE #-}

Expand Down
2 changes: 2 additions & 0 deletions src/lib/Vectorize.hs
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,8 @@ vectorizeAtom atom = addVectErrCtx "vectorizeAtom" ("Atom:\n" ++ pprint atom) do
_ -> throwVectErr "Invalid projection"
x'' <- reduceProj i x'
return $ VVal ov x''
-- TODO: think about this case
StuckTabApp _ _ _ -> throwVectErr $ "Cannot vectorize atom: " ++ pprint atom
Con (Lit l) -> return $ VVal Uniform $ Con $ Lit l
_ -> do
subst <- getSubst
Expand Down

0 comments on commit d80f318

Please sign in to comment.