@@ -270,7 +270,7 @@ applyLinLam :: Emits o => SLam i -> SubstReaderT AtomSubstVal TangentM i o (Atom
270
270
applyLinLam (LamExpr bs body) = do
271
271
TangentArgs args <- liftSubstReaderT $ getTangentArgs
272
272
extendSubst (bs @@> ((Rename . atomVarName) <$> args)) do
273
- substM body >>= emit
273
+ substM body >>= emitLin
274
274
275
275
-- === actual linearization passs ===
276
276
@@ -299,7 +299,7 @@ linearizeTopLam (TopLam False _ (LamExpr bs body)) actives = do
299
299
ts <- getUnpacked $ toAtom $ sink $ binderVar bTangent
300
300
let substFrag = bsRecon @@> map (SubstVal . sink) xs
301
301
<.> bsTangent @@> map (SubstVal . sink) ts
302
- emit =<< applySubst substFrag tangentBody
302
+ emitLin =<< applySubst substFrag tangentBody
303
303
return $ LamExpr (bs' >>> BinaryNest bResidual bTangent) tangentBody'
304
304
return (primalFun, tangentFun)
305
305
(,) <$> asTopLam primalFun <*> asTopLam tangentFun
@@ -358,7 +358,7 @@ linearizeDecls (Nest (Let b (DeclBinding ann expr)) rest) cont = do
358
358
WithTangent pRest tfRest <- linearizeDecls rest cont
359
359
return $ WithTangent pRest do
360
360
t <- tf
361
- vt <- emitDecl (getNameHint b) ann (Atom t)
361
+ vt <- emitDecl (getNameHint b) LinearLet (Atom t)
362
362
extendTangentArgs vt $
363
363
tfRest
364
364
@@ -410,7 +410,7 @@ linearizeExpr expr = case expr of
410
410
(primal, residualss) <- fromPair result
411
411
resultTangentType <- tangentType resultTy'
412
412
return $ WithTangent primal do
413
- buildCase (sink residualss) (sink resultTangentType) \ i residuals -> do
413
+ emitLin =<< buildCase' (sink residualss) (sink resultTangentType) \ i residuals -> do
414
414
ObligateRecon _ (Abs bs linLam) <- return $ sinkList recons !! i
415
415
residuals' <- unpackTelescope bs residuals
416
416
withSubstReaderT $ extendSubst (bs @@> (SubstVal <$> residuals')) do
@@ -613,13 +613,13 @@ linearizeHof hof = case hof of
613
613
TrivialRecon linLam' ->
614
614
return $ WithTangent primalsAux do
615
615
Abs ib'' linLam'' <- sinkM (Abs ib' linLam')
616
- withSubstReaderT $ buildFor noHint d (sink ixTy) \ i' -> do
616
+ withSubstReaderT $ emitLin =<< mkFor noHint d (sink ixTy) \ i' -> do
617
617
extendSubst (ib''@> Rename (atomVarName i')) $ applyLinLam linLam''
618
618
ReconWithData reconAbs -> do
619
619
primals <- buildMap primalsAux getFst
620
620
return $ WithTangent primals do
621
621
Abs ib'' (Abs bs linLam') <- sinkM (Abs ib' reconAbs)
622
- withSubstReaderT $ buildFor noHint d (sink ixTy) \ i' -> do
622
+ withSubstReaderT $ emitLin =<< mkFor noHint d (sink ixTy) \ i' -> do
623
623
extendSubst (ib''@> Rename (atomVarName i')) do
624
624
residuals' <- tabApp (sink primalsAux) (toAtom i') >>= getSnd >>= unpackTelescope bs
625
625
extendSubst (bs @@> (SubstVal <$> residuals')) $
@@ -636,7 +636,7 @@ linearizeHof hof = case hof of
636
636
tanEffLam <- buildEffLam noHint tt \ h ref ->
637
637
extendTangentArgss [h, ref] do
638
638
withSubstReaderT $ applyLinLam $ sink linLam
639
- emitHof $ RunReader rLin' tanEffLam
639
+ emitHofLin $ RunReader rLin' tanEffLam
640
640
RunState Nothing sInit lam -> do
641
641
WithTangent sInit' sLin <- linearizeAtom sInit
642
642
(lam', recon) <- linearizeEffectFun State lam
@@ -649,7 +649,7 @@ linearizeHof hof = case hof of
649
649
tanEffLam <- buildEffLam noHint tt \ h ref ->
650
650
extendTangentArgss [h, ref] do
651
651
withSubstReaderT $ applyLinLam $ sink linLam
652
- emitHof $ RunState Nothing sLin' tanEffLam
652
+ emitHofLin $ RunState Nothing sLin' tanEffLam
653
653
RunWriter Nothing bm lam -> do
654
654
-- TODO: check it's actually the 0/+ monoid (or should we just build that in?)
655
655
bm' <- renameM bm
@@ -663,7 +663,7 @@ linearizeHof hof = case hof of
663
663
tanEffLam <- buildEffLam noHint tt \ h ref ->
664
664
extendTangentArgss [h, ref] do
665
665
withSubstReaderT $ applyLinLam $ sink linLam
666
- emitHof $ RunWriter Nothing bm'' tanEffLam
666
+ emitHofLin $ RunWriter Nothing bm'' tanEffLam
667
667
RunIO body -> do
668
668
(body', recon) <- linearizeExprDefunc body
669
669
primalAux <- emitHof $ RunIO body'
0 commit comments