Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Educational: built-in map primitive #1093

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/lib/AbstractSyntax.hs
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,10 @@ expr = propagateSrcE expr' where
UApp (mkApp (ns $ fromString rangeName) (ns UHole)) lim
expr' (CLambda args body) =
dropSrcE <$> liftM2 buildLam (concat <$> mapM argument args) (block body)
expr' (CMap fun array) = do
fun' <- expr fun
array' <- expr array
return $ UMap fun' array'
expr' (CFor KView indices body) =
dropSrcE <$> (buildTabLam <$> mapM patOptAnn indices <*> block body)
expr' (CFor kind indices body) = do
Expand Down
8 changes: 8 additions & 0 deletions src/lib/CheckType.hs
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,14 @@ typeCheckPrimOp op = case op of

typeCheckPrimHof :: Typer m => PrimHof (Atom i) -> m i o (Type o)
typeCheckPrimHof hof = addContext ("Checking HOF:\n" ++ pprint hof) case hof of
Map fun array -> do
Pi (PiType (PiBinder b argTy PlainArrow) Pure resEltTy) <- getTypeE fun
let resEltTy' = ignoreHoistFailure $ hoist b resEltTy
TabPi (TabPiType binder argEltTy) <- getTypeE array
let argEltTy' = ignoreHoistFailure $ hoist binder argEltTy
checkAlphaEq argTy argEltTy'
refreshAbs (Abs binder UnitE) \binder' _ ->
return $ TabPi $ TabPiType binder' (sink resEltTy')
For _ ixDict f -> do
ixTy <- ixTyFromDict =<< substM ixDict
Pi (PiType (PiBinder b argTy PlainArrow) eff eltTy) <- getTypeE f
Expand Down
10 changes: 10 additions & 0 deletions src/lib/ConcreteSyntax.hs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ data Group'
| CPrefix SourceName Group -- covers unary - and unary + among others
| CPostfix SourceName Group
| CLambda [Group] CBlock -- The arguments do not have Juxtapose at the top level
| CMap Group Group -- unary fun, array
| CFor ForKind [Group] CBlock -- also for_, rof, rof_, view
| CCase Group [(Group, CBlock)] -- scrutinee, alternatives
| CIf Group CBlock (Maybe CBlock)
Expand Down Expand Up @@ -559,6 +560,14 @@ cLam = do
body <- cBlock
return $ CLambda bs body

cMap :: Parser Group'
cMap = do
keyWord MapKW
fun <- cGroupNoJuxt
keyWord OverKW
array <- cGroup
return $ CMap fun array

cFor :: Parser Group'
cFor = do
kw <- forKW
Expand Down Expand Up @@ -704,6 +713,7 @@ leafGroupNoBrackets = do
_ | isDigit next -> ( CNat <$> natLit
<|> CFloat <$> doubleLit)
'\\' -> cLam
'm' -> cMap <|> CIdentifier <$> anyName
-- For exprs include view, for, rof, for_, rof_
'v' -> cFor <|> CIdentifier <$> anyName
'f' -> cFor <|> CIdentifier <$> anyName
Expand Down
13 changes: 13 additions & 0 deletions src/lib/Imp.hs
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,19 @@ toImpHof :: Emits o => Maybe (Dest o) -> PrimHof (Atom i) -> SubstImpM i o (Atom
toImpHof maybeDest hof = do
resultTy <- getTypeSubst (Hof hof)
case hof of
Map (Lam (LamExpr b body)) array -> do
rDest <- allocDest maybeDest resultTy
TabPi (TabPiType (_:>ixTy) _) <- getTypeSubst array
array' <- substM array
n <- indexSetSizeImp ixTy
emitLoop noHint Fwd n \i -> do
idx <- unsafeFromOrdinalImp (sink ixTy) i
ithArg <- dropSubst $ translateExpr Nothing $
TabApp (sink array') $ idx :| []
ithDest <- destGet (sink rDest) idx
void $ extendSubst (b @> SubstVal ithArg) $
translateBlock (Just ithDest) body
destToAtom rDest
For d ixDict (Lam (LamExpr b body)) -> do
ixTy <- ixTyFromDict =<< substM ixDict
n <- indexSetSizeImp ixTy
Expand Down
31 changes: 31 additions & 0 deletions src/lib/Inference.hs
Original file line number Diff line number Diff line change
Expand Up @@ -918,6 +918,20 @@ getImplicitArg (PiBinder _ argTy arr) = case arr of
return $ Just $ Con $ DictHole (AlwaysEqual ctx) argTy
_ -> return Nothing

etaExpand :: EmitsInf n => Atom n -> InfererM i n (Atom n)
etaExpand fun = do
ty <- getType fun
case ty of
Pi (PiType (PiBinder b argTy arr) eff _) -> do
case fun of
Lam _ -> pure fun
_ -> buildLamInf noHint arr argTy
(\b' -> applySubst (b @> b') eff)
(\x -> do
Distinct <- getDistinct
app (sink fun) (Var x))
_ -> error "atom must have pi type"

checkOrInferRho :: forall i o. EmitsBoth o
=> UExpr i -> RequiredTy RhoType o -> InfererM i o (Atom o)
checkOrInferRho (WithSrcE pos expr) reqTy = do
Expand All @@ -943,6 +957,23 @@ checkOrInferRho (WithSrcE pos expr) reqTy = do
Infer -> inferULam Pure uLamExpr
ixTy <- asIxType ty'
matchRequirement $ TabLam $ TabLamExpr (b':>ixTy) body'
UMap fun array -> do
array' <- inferRho array
arrayTy <- getType array'
case arrayTy of
TabPi (TabPiType (b:>_) argElemTy) -> do
argElemTy' <- case hoist b argElemTy of
HoistSuccess ty -> return ty
HoistFailure _ -> throw TypeErr "expected non-dependent array type"
resElemVar <- liftM Var $ freshInferenceName (TC TypeKind)
funTy <- naryNonDepPiType PlainArrow Pure [argElemTy'] resElemVar
fun' <- checkOrInferRho fun (Check funTy)
-- Eta-expand `fun'` into a `Lam`. Later on we make use of the invariant
-- that the first argument of `Map` is a `Lam`.
fun'' <- etaExpand fun'
result <- liftM Var $ emit $ Hof $ Map fun'' array'
matchRequirement result
_ -> throw TypeErr "expected array type"
UFor dir (UForExpr b body) -> do
allowedEff <- getAllowedEffects
let uLamExpr = ULamExpr PlainArrow b body
Expand Down
3 changes: 3 additions & 0 deletions src/lib/Lexing.hs
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,14 @@ data KeyWord = DefKW | ForKW | For_KW | RofKW | Rof_KW | CaseKW | OfKW
| ViewKW | ImportKW | ForeignKW | NamedInstanceKW
| EffectKW | HandlerKW | JmpKW | CtlKW | ReturnKW | ResumeKW
| CustomLinearizationKW | CustomLinearizationSymbolicKW
| MapKW | OverKW
deriving (Enum)

keyWordToken :: KeyWord -> String
keyWordToken = \case
DefKW -> "def"
MapKW -> "map_"
OverKW -> "over_"
ForKW -> "for"
RofKW -> "rof"
For_KW -> "for_"
Expand Down
2 changes: 2 additions & 0 deletions src/lib/PPrint.hs
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,8 @@ instance PrettyPrec (UExpr' n) where
<+> nest 2 (pLowest body)
where kw = case dir of Fwd -> "for"
Rev -> "rof"
UMap fun array -> atPrec LowestPrec $ "map_" <+> nest 2 (pLowest fun)
<+> "over_" <+> nest 2 (pLowest array)
UPi piType -> prettyPrec piType
UTabPi piType -> prettyPrec piType
UDecl declExpr -> prettyPrec declExpr
Expand Down
7 changes: 7 additions & 0 deletions src/lib/QueryType.hs
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,12 @@ getTypePrimHof hof = addContext ("Checking HOF:\n" ++ pprint hof) case hof of
Pi (PiType (PiBinder b _ _) _ eltTy) <- getTypeE f
ixTy <- ixTyFromDict =<< substM dict
return $ TabTy (b:>ixTy) eltTy
Map fun array -> do
Pi (PiType (PiBinder b _ _) _ resEltTy) <- getTypeE fun
let resEltTy' = ignoreHoistFailure $ hoist b resEltTy
TabPi (TabPiType binder _) <- getTypeE array
refreshAbs (Abs binder UnitE) \binder' _ ->
return $ TabPi $ TabPiType binder' (sink resEltTy')
While _ -> return UnitTy
Linearize f -> do
Pi (PiType (PiBinder binder a PlainArrow) Pure b) <- getTypeE f
Expand Down Expand Up @@ -797,6 +803,7 @@ exprEffects expr = case expr of
_ -> return Pure
Hof hof -> case hof of
For _ _ f -> functionEffs f
Map _ _ -> return Pure
While body -> functionEffs body
Linearize _ -> return Pure -- Body has to be a pure function
Transpose _ -> return Pure -- Body has to be a pure function
Expand Down
13 changes: 13 additions & 0 deletions src/lib/Simplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,19 @@ projectDictMethod d i = do

simplifyHof :: Emits o => Hof i -> SimplifyM i o (Atom o)
simplifyHof hof = case hof of
Map fun array -> do
(fun', Abs b recon) <- simplifyLam fun
array' <- simplifyAtom array
ans <- liftM Var $ emit $ Hof $ Map fun' array'
case recon of
IdentityRecon -> return ans
LamRecon reconAbs -> do
TabPi (TabPiType (_:>ixTy) _) <- getType array'
buildTabLam noHint ixTy \i -> do
locals <- tabApp (sink ans) $ Var i
ithArg <- emitAtomToName =<< (tabApp (sink array') $ Var i)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of emitting to name, you could b @> SubstVal ithArg below

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks.

reconAbs' <- applySubst (b @> ithArg) reconAbs
applyReconAbs reconAbs' locals
For d ixDict lam -> do
ixTy@(IxType _ ixDict') <- ixTyFromDict =<< substM ixDict
(lam', Abs b recon) <- simplifyLam lam
Expand Down
2 changes: 2 additions & 0 deletions src/lib/SourceRename.hs
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ instance SourceRenamableE UExpr' where
UDecl (UDeclExpr decl rest) ->
sourceRenameB decl \decl' ->
UDecl <$> UDeclExpr decl' <$> sourceRenameE rest
UMap fun array -> UMap <$> sourceRenameE fun
<*> sourceRenameE array
UFor d (UForExpr pat body) ->
sourceRenameB pat \pat' ->
UFor d <$> UForExpr pat' <$> sourceRenameE body
Expand Down
1 change: 1 addition & 0 deletions src/lib/Types/Primitives.hs
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ traversePrimOp = inline traverse

data PrimHof e =
For ForAnn e e -- ix dict, body lambda
| Map e e -- lambda, array
| While e
| RunReader e e
| RunWriter (Maybe e) (BaseMonoidP e) e
Expand Down
1 change: 1 addition & 0 deletions src/lib/Types/Source.hs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ data UExpr' (n::S) =
| UTabPi (UTabPiExpr n)
| UTabApp (UExpr n) (UExpr n)
| UDecl (UDeclExpr n)
| UMap (UExpr n) (UExpr n)
| UFor Direction (UForExpr n)
| UCase (UExpr n) [UAlt n]
| UHole
Expand Down
22 changes: 22 additions & 0 deletions tests/mymap.dx
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
def my_map {a:Type} {b:Type} {n:Type} [Ix n]
(f:a -> b)
(x:n => a) : n => b =
for i:n. f x.i

x0 = [1, 2, 3, 4, 5]

my_map (\x. x+x) x0
map_ (\x. x+x) over_ x0

my_map (\x. 2*x) x0
map_ (\x. 2*x) over_ x0

my_map (\x. 2*x) (x0 + x0)
map_ (\x. 2*x) over_ (x0 + x0)
-- The following is also parsed as `map_ (\x. 2*x) over_ (x0 + x0)` ... not
-- intentionally so.
map_ (\x. 2*x) over_ x0 + x0

my_map (\x. 2*x) x0 + x0
(map_ (\x. 2*x) over_ x0) + x0