Skip to content

Commit 1b2d252

Browse files
committed
Add some missing linearity annotations.
We really need to build the linearity checker.
1 parent 9672aa3 commit 1b2d252

File tree

5 files changed

+28
-31
lines changed

5 files changed

+28
-31
lines changed

src/lib/Builder.hs

+11-8
Original file line numberDiff line numberDiff line change
@@ -765,23 +765,23 @@ mkTypedHof hof = do
765765
effTy <- effTyOfHof hof
766766
return $ TypedHof effTy hof
767767

768-
buildForAnn
769-
:: (Emits n, ScopableBuilder r m)
768+
mkFor
769+
:: (ScopableBuilder r m)
770770
=> NameHint -> ForAnn -> IxType r n
771771
-> (forall l. (Emits l, DExt n l) => AtomVar r l -> m l (Atom r l))
772-
-> m n (Atom r n)
773-
buildForAnn hint ann (IxType iTy ixDict) body = do
772+
-> m n (Expr r n)
773+
mkFor hint ann (IxType iTy ixDict) body = do
774774
lam <- withFreshBinder hint iTy \b -> do
775775
let v = binderVar b
776776
body' <- buildBlock $ body $ sink v
777777
return $ LamExpr (UnaryNest b) body'
778-
emitHof $ For ann (IxType iTy ixDict) lam
778+
liftM toExpr $ mkTypedHof $ For ann (IxType iTy ixDict) lam
779779

780780
buildFor :: (Emits n, ScopableBuilder r m)
781781
=> NameHint -> Direction -> IxType r n
782782
-> (forall l. (Emits l, DExt n l) => AtomVar r l -> m l (Atom r l))
783783
-> m n (Atom r n)
784-
buildFor hint dir ty body = buildForAnn hint dir ty body
784+
buildFor hint ann ty body = mkFor hint ann ty body >>= emit
785785

786786
buildMap :: (Emits n, ScopableBuilder SimpIR m)
787787
=> SAtom n
@@ -853,6 +853,10 @@ emitLin e = case toExpr e of
853853
expr -> liftM toAtom $ emitDecl noHint LinearLet $ peepholeExpr expr
854854
{-# INLINE emitLin #-}
855855

856+
emitHofLin :: (Builder r m, Emits n) => Hof r n -> m n (Atom r n)
857+
emitHofLin hof = mkTypedHof hof >>= emitLin
858+
{-# INLINE emitHofLin #-}
859+
856860
zeroAt :: (Emits n, SBuilder m) => SType n -> m n (SAtom n)
857861
zeroAt ty = liftEmitBuilder $ go ty where
858862
go :: Emits n => SType n -> BuilderM SimpIR n (SAtom n)
@@ -1100,9 +1104,8 @@ mkApplyMethod d i xs = do
11001104
mkInstanceDict :: EnvReader m => InstanceName n -> [CAtom n] -> m n (CDict n)
11011105
mkInstanceDict instanceName args = do
11021106
instanceDef@(InstanceDef className _ _ _ _) <- lookupInstanceDef instanceName
1103-
sourceName <- getSourceName <$> lookupClassDef className
11041107
PairE (ListE params) _ <- instantiate instanceDef args
1105-
let ty = toType $ DictType sourceName className params
1108+
ty <- toType <$> dictType className params
11061109
return $ toDict $ InstanceDict ty instanceName args
11071110

11081111
mkCase :: (EnvReader m, IRRep r) => Atom r n -> Type r n -> [Alt r n] -> m n (Expr r n)

src/lib/Inference.hs

+1-1
Original file line numberDiff line numberDiff line change
@@ -2013,7 +2013,7 @@ generalizeDict ty dict = do
20132013
result <- liftEnvReaderT $ liftInfererM $ generalizeDictRec ty dict
20142014
case result of
20152015
Failure e -> error $ "Failed to generalize " ++ pprint dict
2016-
++ " to " ++ pprint ty ++ " because " ++ pprint e
2016+
++ " to " ++ show ty ++ " because " ++ pprint e
20172017
Success ans -> return ans
20182018

20192019
generalizeDictRec :: CType n -> CDict n -> InfererM i n (CDict n)

src/lib/Linearize.hs

+9-9
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ applyLinLam :: Emits o => SLam i -> SubstReaderT AtomSubstVal TangentM i o (Atom
270270
applyLinLam (LamExpr bs body) = do
271271
TangentArgs args <- liftSubstReaderT $ getTangentArgs
272272
extendSubst (bs @@> ((Rename . atomVarName) <$> args)) do
273-
substM body >>= emit
273+
substM body >>= emitLin
274274

275275
-- === actual linearization passs ===
276276

@@ -299,7 +299,7 @@ linearizeTopLam (TopLam False _ (LamExpr bs body)) actives = do
299299
ts <- getUnpacked $ toAtom $ sink $ binderVar bTangent
300300
let substFrag = bsRecon @@> map (SubstVal . sink) xs
301301
<.> bsTangent @@> map (SubstVal . sink) ts
302-
emit =<< applySubst substFrag tangentBody
302+
emitLin =<< applySubst substFrag tangentBody
303303
return $ LamExpr (bs' >>> BinaryNest bResidual bTangent) tangentBody'
304304
return (primalFun, tangentFun)
305305
(,) <$> asTopLam primalFun <*> asTopLam tangentFun
@@ -358,7 +358,7 @@ linearizeDecls (Nest (Let b (DeclBinding ann expr)) rest) cont = do
358358
WithTangent pRest tfRest <- linearizeDecls rest cont
359359
return $ WithTangent pRest do
360360
t <- tf
361-
vt <- emitDecl (getNameHint b) ann (Atom t)
361+
vt <- emitDecl (getNameHint b) LinearLet (Atom t)
362362
extendTangentArgs vt $
363363
tfRest
364364

@@ -410,7 +410,7 @@ linearizeExpr expr = case expr of
410410
(primal, residualss) <- fromPair result
411411
resultTangentType <- tangentType resultTy'
412412
return $ WithTangent primal do
413-
buildCase (sink residualss) (sink resultTangentType) \i residuals -> do
413+
emitLin =<< buildCase' (sink residualss) (sink resultTangentType) \i residuals -> do
414414
ObligateRecon _ (Abs bs linLam) <- return $ sinkList recons !! i
415415
residuals' <- unpackTelescope bs residuals
416416
withSubstReaderT $ extendSubst (bs @@> (SubstVal <$> residuals')) do
@@ -613,13 +613,13 @@ linearizeHof hof = case hof of
613613
TrivialRecon linLam' ->
614614
return $ WithTangent primalsAux do
615615
Abs ib'' linLam'' <- sinkM (Abs ib' linLam')
616-
withSubstReaderT $ buildFor noHint d (sink ixTy) \i' -> do
616+
withSubstReaderT $ emitLin =<< mkFor noHint d (sink ixTy) \i' -> do
617617
extendSubst (ib''@>Rename (atomVarName i')) $ applyLinLam linLam''
618618
ReconWithData reconAbs -> do
619619
primals <- buildMap primalsAux getFst
620620
return $ WithTangent primals do
621621
Abs ib'' (Abs bs linLam') <- sinkM (Abs ib' reconAbs)
622-
withSubstReaderT $ buildFor noHint d (sink ixTy) \i' -> do
622+
withSubstReaderT $ emitLin =<< mkFor noHint d (sink ixTy) \i' -> do
623623
extendSubst (ib''@> Rename (atomVarName i')) do
624624
residuals' <- tabApp (sink primalsAux) (toAtom i') >>= getSnd >>= unpackTelescope bs
625625
extendSubst (bs @@> (SubstVal <$> residuals')) $
@@ -636,7 +636,7 @@ linearizeHof hof = case hof of
636636
tanEffLam <- buildEffLam noHint tt \h ref ->
637637
extendTangentArgss [h, ref] do
638638
withSubstReaderT $ applyLinLam $ sink linLam
639-
emitHof $ RunReader rLin' tanEffLam
639+
emitHofLin $ RunReader rLin' tanEffLam
640640
RunState Nothing sInit lam -> do
641641
WithTangent sInit' sLin <- linearizeAtom sInit
642642
(lam', recon) <- linearizeEffectFun State lam
@@ -649,7 +649,7 @@ linearizeHof hof = case hof of
649649
tanEffLam <- buildEffLam noHint tt \h ref ->
650650
extendTangentArgss [h, ref] do
651651
withSubstReaderT $ applyLinLam $ sink linLam
652-
emitHof $ RunState Nothing sLin' tanEffLam
652+
emitHofLin $ RunState Nothing sLin' tanEffLam
653653
RunWriter Nothing bm lam -> do
654654
-- TODO: check it's actually the 0/+ monoid (or should we just build that in?)
655655
bm' <- renameM bm
@@ -663,7 +663,7 @@ linearizeHof hof = case hof of
663663
tanEffLam <- buildEffLam noHint tt \h ref ->
664664
extendTangentArgss [h, ref] do
665665
withSubstReaderT $ applyLinLam $ sink linLam
666-
emitHof $ RunWriter Nothing bm'' tanEffLam
666+
emitHofLin $ RunWriter Nothing bm'' tanEffLam
667667
RunIO body -> do
668668
(body', recon) <- linearizeExprDefunc body
669669
primalAux <- emitHof $ RunIO body'

src/lib/Simplify.hs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1056,7 +1056,7 @@ exceptToMaybeExpr expr = case expr of
10561056
return $ JustAtom ty x'
10571057
PrimOp (Hof (TypedHof _ (For ann ixTy' (UnaryLamExpr b body)))) -> do
10581058
ixTy <- substM ixTy'
1059-
maybes <- buildForAnn (getNameHint b) ann ixTy \i -> do
1059+
maybes <- buildFor (getNameHint b) ann ixTy \i -> do
10601060
extendSubst (b@>Rename (atomVarName i)) $ exceptToMaybeExpr body
10611061
catMaybesE maybes
10621062
PrimOp (MiscOp (ThrowException _)) -> do

src/lib/Transpose.hs

+6-12
Original file line numberDiff line numberDiff line change
@@ -80,16 +80,12 @@ data TransposeSubstVal c n where
8080

8181
type TransposeM a = SubstReaderT TransposeSubstVal (BuilderM SimpIR) a
8282

83-
-- TODO: it might make sense to replace substNonlin/isLin
84-
-- with a single `trySubtNonlin :: e i -> Maybe (e o)`.
85-
-- But for that we need a way to traverse names, like a monadic
86-
-- version of `substE`.
87-
substNonlin :: (SinkableE e, RenameE e, HasCallStack) => e i -> TransposeM i o (e o)
83+
substNonlin :: (PrettyE e, SinkableE e, RenameE e, HasCallStack) => e i -> TransposeM i o (e o)
8884
substNonlin e = do
8985
subst <- getSubst
9086
fmapRenamingM (\v -> case subst ! v of
9187
RenameNonlin v' -> v'
92-
_ -> error "not a nonlinear expression") e
88+
_ -> error $ "not a nonlinear expression: " ++ pprint e) e
9389

9490
withAccumulator
9591
:: Emits o
@@ -113,7 +109,7 @@ withAccumulator ty cont = do
113109
emitCTToRef :: (Emits n, Builder SimpIR m) => SAtom n -> SAtom n -> m n ()
114110
emitCTToRef ref ct = do
115111
baseMonoid <- tangentBaseMonoidFor (getType ct)
116-
void $ emit $ RefOp ref $ MExtend baseMonoid ct
112+
void $ emitLin $ RefOp ref $ MExtend baseMonoid ct
117113

118114
-- === actual pass ===
119115

@@ -190,7 +186,7 @@ transposeOp op ct = case op of
190186
DAMOp _ -> error "unreachable" -- TODO: rule out statically
191187
RefOp refArg m -> do
192188
refArg' <- substNonlin refArg
193-
let emitEff = emit . RefOp refArg'
189+
let emitEff = emitLin . RefOp refArg'
194190
case m of
195191
MAsk -> do
196192
baseMonoid <- tangentBaseMonoidFor (getType ct)
@@ -251,9 +247,7 @@ transposeAtom atom ct = case atom of
251247
PtrVar _ _ -> notTangent
252248
Var v -> do
253249
lookupSubstM (atomVarName v) >>= \case
254-
RenameNonlin _ ->
255-
-- XXX: we seem to need this case, but it feels like it should be an error!
256-
return ()
250+
RenameNonlin _ -> error "nonlinear"
257251
LinRef ref -> emitCTToRef ref ct
258252
LinTrivial -> return ()
259253
StuckProject _ _ -> error "not linear"
@@ -266,7 +260,7 @@ transposeHof hof ct = case hof of
266260
For ann ixTy' lam -> do
267261
UnaryLamExpr b body <- return lam
268262
ixTy <- substNonlin ixTy'
269-
void $ buildForAnn (getNameHint b) (flipDir ann) ixTy \i -> do
263+
void $ emitLin =<< mkFor (getNameHint b) (flipDir ann) ixTy \i -> do
270264
ctElt <- tabApp (sink ct) (toAtom i)
271265
extendSubst (b@>RenameNonlin (atomVarName i)) $ transposeExpr body ctElt
272266
return UnitVal

0 commit comments

Comments
 (0)