Skip to content

Commit d10cfc5

Browse files
committed
Put peephole optimizations in one place and use them from Builder.
The reason to do this now is that I want to make AD linearity explicit for correctness reasons. I started doing that but realized I was going to need linear version of each of the helper functions in Builder `add`, `mul`. This cuts down on that boilerplate and it's a good idea anyway.
1 parent be61893 commit d10cfc5

10 files changed

+346
-323
lines changed

dex.cabal

+1
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ library
7777
, Occurrence
7878
, OccAnalysis
7979
, Optimize
80+
, PeepholeOptimize
8081
, PPrint
8182
, RawName
8283
, Runtime

src/lib/Algebra.hs

+8-5
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ newtype Polynomial (n::S) =
5050
-- us compute sums in closed form. This tries to compute
5151
-- `\sum_{i=0}^(lim-1) body`. `i`, `lim`, and `body` should all have type `Nat`.
5252
sumUsingPolys :: Emits n
53-
=> Atom SimpIR n -> Abs (Binder SimpIR) (Expr SimpIR) n -> BuilderM SimpIR n (Atom SimpIR n)
53+
=> SAtom n -> Abs (Binder SimpIR) (Expr SimpIR) n -> BuilderM SimpIR n (SAtom n)
5454
sumUsingPolys lim (Abs i body) = do
5555
sumAbs <- refreshAbs (Abs i body) \(i':>_) body' -> do
5656
exprAsPoly body' >>= \case
@@ -138,7 +138,7 @@ type BlockTraverserM i o a = SubstReaderT PolySubstVal (MaybeT1 (BuilderM SimpIR
138138
exprAsPoly :: (EnvExtender m, EnvReader m) => SExpr n -> m n (Maybe (Polynomial n))
139139
exprAsPoly expr = liftBuilder $ runMaybeT1 $ runSubstReaderT idSubst $ exprAsPolyRec expr
140140

141-
atomAsPoly :: Atom SimpIR i -> BlockTraverserM i o (Polynomial o)
141+
atomAsPoly :: SAtom i -> BlockTraverserM i o (Polynomial o)
142142
atomAsPoly = \case
143143
Stuck _ (Var v) -> atomVarAsPoly v
144144
Stuck _ (RepValAtom (RepVal _ (Leaf (IVar v' _)))) -> impNameAsPoly v'
@@ -190,7 +190,7 @@ blockAsPoly (Abs decls result) = case decls of
190190
-- coefficients. This is why we have to find the least common multiples and do the
191191
-- accumulation over numbers multiplied by that LCM. We essentially do fixed point
192192
-- fractional math here.
193-
emitPolynomial :: Emits n => Polynomial n -> BuilderM SimpIR n (Atom SimpIR n)
193+
emitPolynomial :: Emits n => Polynomial n -> BuilderM SimpIR n (SAtom n)
194194
emitPolynomial (Polynomial p) = do
195195
let constLCM = asAtom $ foldl lcm 1 $ fmap (denominator . snd) $ toList p
196196
monoAtoms <- flip traverse (toList p) $ \(m, c) -> do
@@ -204,7 +204,7 @@ emitPolynomial (Polynomial p) = do
204204
-- because it might be causing overflows due to all arithmetic being shifted.
205205
asAtom = IdxRepVal . fromInteger
206206

207-
emitMonomial :: Emits n => Monomial n -> BuilderM SimpIR n (Atom SimpIR n)
207+
emitMonomial :: Emits n => Monomial n -> BuilderM SimpIR n (SAtom n)
208208
emitMonomial (Monomial m) = do
209209
varAtoms <- forM (toList m) \(v, e) -> case v of
210210
LeftE v' -> do
@@ -215,9 +215,12 @@ emitMonomial (Monomial m) = do
215215
ipow atom e
216216
foldM imul (IdxRepVal 1) varAtoms
217217

218-
ipow :: Emits n => Atom SimpIR n -> Int -> BuilderM SimpIR n (Atom SimpIR n)
218+
ipow :: Emits n => SAtom n -> Int -> BuilderM SimpIR n (SAtom n)
219219
ipow x i = foldM imul (IdxRepVal 1) (replicate i x)
220220

221+
idiv :: Emits n => SAtom n -> SAtom n -> BuilderM SimpIR n (SAtom n)
222+
idiv = undefined
223+
221224
-- === instances ===
222225

223226
instance GenericE Monomial where

src/lib/Builder.hs

+36-109
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,13 @@ import IRVariants
2929
import MTL1
3030
import Subst
3131
import Name
32+
import PeepholeOptimize
3233
import QueryType
3334
import Types.Core
3435
import Types.Imp
3536
import Types.Primitives
3637
import Types.Source
37-
import Util (enumerate, transitiveClosureM, bindM2, toSnocList, (...))
38+
import Util (enumerate, transitiveClosureM, bindM2, toSnocList)
3839

3940
-- === Ordinary (local) builder class ===
4041

@@ -66,50 +67,31 @@ emitDecl _ _ (Atom (Stuck _ (Var n))) = return n
6667
emitDecl hint ann expr = rawEmitDecl hint ann expr
6768
{-# INLINE emitDecl #-}
6869

69-
emitInline :: (Builder r m, Emits n) => Atom r n -> m n (AtomVar r n)
70-
emitInline atom = emitDecl noHint InlineLet $ Atom atom
71-
{-# INLINE emitInline #-}
72-
73-
emitHinted :: (Builder r m, Emits n) => NameHint -> Expr r n -> m n (AtomVar r n)
74-
emitHinted hint expr = emitDecl hint PlainLet expr
75-
{-# INLINE emitHinted #-}
76-
7770
emit :: (Builder r m, ToExpr e r, Emits n) => e n -> m n (Atom r n)
7871
emit e = case toExpr e of
7972
Atom x -> return x
8073
Block _ block -> emitDecls block >>= emit
81-
expr -> toAtom <$> emitToVar expr
74+
expr -> do
75+
v <- emitDecl noHint PlainLet $ peepholeExpr expr
76+
return $ toAtom v
8277
{-# INLINE emit #-}
8378

8479
emitToVar :: (Builder r m, ToExpr e r, Emits n) => e n -> m n (AtomVar r n)
85-
emitToVar e = case toExpr e of
86-
Atom (Stuck _ (Var v)) -> return v
87-
expr -> emitDecl noHint PlainLet expr
80+
emitToVar expr = emit expr >>= \case
81+
Stuck _ (Var v) -> return v
82+
atom -> emitDecl noHint PlainLet (toExpr atom)
8883
{-# INLINE emitToVar #-}
8984

90-
emitHof :: (Builder r m, Emits n) => Hof r n -> m n (Atom r n)
91-
emitHof hof = mkTypedHof hof >>= emit
92-
93-
mkTypedHof :: (EnvReader m, IRRep r) => Hof r n -> m n (TypedHof r n)
94-
mkTypedHof hof = do
95-
effTy <- effTyOfHof hof
96-
return $ TypedHof effTy hof
97-
98-
emitUnOp :: (Builder r m, Emits n) => UnOp -> Atom r n -> m n (Atom r n)
99-
emitUnOp op x = emit $ UnOp op x
100-
{-# INLINE emitUnOp #-}
101-
10285
emitDecls :: (Builder r m, Emits n, RenameE e, SinkableE e)
10386
=> WithDecls r e n -> m n (e n)
104-
emitDecls (Abs decls result) = runSubstReaderT idSubst $ emitDecls' decls result
105-
106-
emitDecls' :: (Builder r m, Emits o, RenameE e, SinkableE e)
107-
=> Nest (Decl r) i i' -> e i' -> SubstReaderT Name m i o (e o)
108-
emitDecls' Empty e = renameM e
109-
emitDecls' (Nest (Let b (DeclBinding ann expr)) rest) e = do
110-
expr' <- renameM expr
111-
AtomVar v _ <- emitDecl (getNameHint b) ann expr'
112-
extendSubst (b @> v) $ emitDecls' rest e
87+
emitDecls (Abs decls result) = runSubstReaderT idSubst $ go decls result where
88+
go :: (Builder r m, Emits o, RenameE e, SinkableE e)
89+
=> Nest (Decl r) i i' -> e i' -> SubstReaderT Name m i o (e o)
90+
go Empty e = renameM e
91+
go (Nest (Let b (DeclBinding ann expr)) rest) e = do
92+
expr' <- renameM expr
93+
AtomVar v _ <- emitDecl (getNameHint b) ann expr'
94+
extendSubst (b @> v) $ go rest e
11395

11496
buildScopedAssumeNoDecls :: (SinkableE e, ScopableBuilder r m)
11597
=> (forall l. (Emits l, DExt n l) => m l (e l))
@@ -775,6 +757,14 @@ buildEffLam hint ty body = do
775757
body' <- buildBlock $ body (sink hVar) $ sink ref
776758
return $ LamExpr (BinaryNest h b) body'
777759

760+
emitHof :: (Builder r m, Emits n) => Hof r n -> m n (Atom r n)
761+
emitHof hof = mkTypedHof hof >>= emit
762+
763+
mkTypedHof :: (EnvReader m, IRRep r) => Hof r n -> m n (TypedHof r n)
764+
mkTypedHof hof = do
765+
effTy <- effTyOfHof hof
766+
return $ TypedHof effTy hof
767+
778768
buildForAnn
779769
:: (Emits n, ScopableBuilder r m)
780770
=> NameHint -> ForAnn -> IxType r n
@@ -940,70 +930,38 @@ symbolicTangentNonZero val = do
940930

941931
-- === builder versions of common local ops ===
942932

943-
fLitLike :: (SBuilder m, Emits n) => Double -> SAtom n -> m n (SAtom n)
944-
fLitLike x t = case getTyCon t of
945-
BaseType (Scalar Float64Type) -> return $ toAtom $ Lit $ Float64Lit x
946-
BaseType (Scalar Float32Type) -> return $ toAtom $ Lit $ Float32Lit $ realToFrac x
947-
_ -> error "Expected a floating point scalar"
948-
949933
neg :: (Builder r m, Emits n) => Atom r n -> m n (Atom r n)
950934
neg x = emit $ UnOp FNeg x
951935

952936
add :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n)
953937
add x y = emit $ BinOp FAdd x y
954938

955-
-- TODO: Implement constant folding for fixed-width integer types as well!
956-
iadd :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n)
957-
iadd (Con (Lit l)) y | getIntLit l == 0 = return y
958-
iadd x (Con (Lit l)) | getIntLit l == 0 = return x
959-
iadd x@(Con (Lit _)) y@(Con (Lit _)) = return $ applyIntBinOp (+) x y
960-
iadd x y = emit $ BinOp IAdd x y
961-
962939
mul :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n)
963940
mul x y = emit $ BinOp FMul x y
964941

942+
iadd :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n)
943+
iadd x y = emit $ BinOp IAdd x y
944+
965945
imul :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n)
966-
imul (Con (Lit l)) y | getIntLit l == 1 = return y
967-
imul x (Con (Lit l)) | getIntLit l == 1 = return x
968-
imul x@(Con (Lit _)) y@(Con (Lit _)) = return $ applyIntBinOp (*) x y
969946
imul x y = emit $ BinOp IMul x y
970947

971-
sub :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n)
972-
sub x y = emit $ BinOp FSub x y
973-
974-
isub :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n)
975-
isub x (Con (Lit l)) | getIntLit l == 0 = return x
976-
isub x@(Con (Lit _)) y@(Con (Lit _)) = return $ applyIntBinOp (-) x y
977-
isub x y = emit $ BinOp ISub x y
978-
979-
select :: (Builder r m, Emits n) => Atom r n -> Atom r n -> Atom r n -> m n (Atom r n)
980-
select (Con (Lit (Word8Lit p))) x y = return $ if p /= 0 then x else y
981-
select p x y = emit $ MiscOp $ Select p x y
982-
983948
div' :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n)
984949
div' x y = emit $ BinOp FDiv x y
985950

986-
idiv :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n)
987-
idiv x (Con (Lit l)) | getIntLit l == 1 = return x
988-
idiv x@(Con (Lit _)) y@(Con (Lit _)) = return $ applyIntBinOp div x y
989-
idiv x y = emit $ BinOp IDiv x y
990-
991-
irem :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n)
992-
irem x y = emit $ BinOp IRem x y
993-
994951
fpow :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n)
995952
fpow x y = emit $ BinOp FPow x y
996953

954+
sub :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n)
955+
sub x y = emit $ BinOp FSub x y
956+
997957
flog :: (Builder r m, Emits n) => Atom r n -> m n (Atom r n)
998958
flog x = emit $ UnOp Log x
999959

1000-
ilt :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n)
1001-
ilt x@(Con (Lit _)) y@(Con (Lit _)) = return $ applyIntCmpOp (<) x y
1002-
ilt x y = emit $ BinOp (ICmp Less) x y
1003-
1004-
ieq :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n)
1005-
ieq x@(Con (Lit _)) y@(Con (Lit _)) = return $ applyIntCmpOp (==) x y
1006-
ieq x y = emit $ BinOp (ICmp Equal) x y
960+
fLitLike :: (SBuilder m, Emits n) => Double -> SAtom n -> m n (SAtom n)
961+
fLitLike x t = case getTyCon t of
962+
BaseType (Scalar Float64Type) -> return $ toAtom $ Lit $ Float64Lit x
963+
BaseType (Scalar Float32Type) -> return $ toAtom $ Lit $ Float32Lit $ realToFrac x
964+
_ -> error "Expected a floating point scalar"
1007965

1008966
fromPair :: (Builder r m, Emits n) => Atom r n -> m n (Atom r n, Atom r n)
1009967
fromPair pair = do
@@ -1160,7 +1118,7 @@ app :: (CBuilder m, Emits n) => CAtom n -> CAtom n -> m n (CAtom n)
11601118
app x i = mkApp x [i] >>= emit
11611119

11621120
naryApp :: (CBuilder m, Emits n) => CAtom n -> [CAtom n] -> m n (CAtom n)
1163-
naryApp = naryAppHinted noHint
1121+
naryApp f xs= mkApp f xs >>= emit
11641122
{-# INLINE naryApp #-}
11651123

11661124
naryTopApp :: (Builder SimpIR m, Emits n) => TopFunName n -> [SAtom n] -> m n (SAtom n)
@@ -1175,10 +1133,6 @@ naryTopAppInlined f xs = do
11751133
_ -> naryTopApp f xs
11761134
{-# INLINE naryTopAppInlined #-}
11771135

1178-
naryAppHinted :: (CBuilder m, Emits n)
1179-
=> NameHint -> CAtom n -> [CAtom n] -> m n (CAtom n)
1180-
naryAppHinted hint f xs = toAtom <$> (mkApp f xs >>= emitHinted hint)
1181-
11821136
tabApp :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n)
11831137
tabApp x i = mkTabApp x i >>= emit
11841138

@@ -1581,30 +1535,3 @@ visitDeclsEmits (Nest (Let b (DeclBinding _ expr)) decls) cont = do
15811535
x <- visitExprEmits expr
15821536
extendSubst (b@>SubstVal x) do
15831537
visitDeclsEmits decls cont
1584-
1585-
-- === Helpers for function evaluation over fixed-width types ===
1586-
1587-
applyIntBinOp' :: (forall a. (Eq a, Ord a, Num a, Integral a)
1588-
=> (a -> Atom r n) -> a -> a -> Atom r n) -> Atom r n -> Atom r n -> Atom r n
1589-
applyIntBinOp' f x y = case (x, y) of
1590-
(Con (Lit (Int64Lit xv)), Con (Lit (Int64Lit yv))) -> f (Con . Lit . Int64Lit) xv yv
1591-
(Con (Lit (Int32Lit xv)), Con (Lit (Int32Lit yv))) -> f (Con . Lit . Int32Lit) xv yv
1592-
(Con (Lit (Word8Lit xv)), Con (Lit (Word8Lit yv))) -> f (Con . Lit . Word8Lit) xv yv
1593-
(Con (Lit (Word32Lit xv)), Con (Lit (Word32Lit yv))) -> f (Con . Lit . Word32Lit) xv yv
1594-
(Con (Lit (Word64Lit xv)), Con (Lit (Word64Lit yv))) -> f (Con . Lit . Word64Lit) xv yv
1595-
_ -> error "Expected integer atoms"
1596-
1597-
applyIntBinOp :: (forall a. (Num a, Integral a) => a -> a -> a) -> Atom r n -> Atom r n -> Atom r n
1598-
applyIntBinOp f x y = applyIntBinOp' (\w -> w ... f) x y
1599-
1600-
applyIntCmpOp :: (forall a. (Eq a, Ord a) => a -> a -> Bool) -> Atom r n -> Atom r n -> Atom r n
1601-
applyIntCmpOp f x y = applyIntBinOp' (\_ -> (Con . Lit . Word8Lit . fromIntegral . fromEnum) ... f) x y
1602-
1603-
applyFloatBinOp :: (forall a. (Num a, Fractional a) => a -> a -> a) -> Atom r n -> Atom r n -> Atom r n
1604-
applyFloatBinOp f x y = case (x, y) of
1605-
(Con (Lit (Float64Lit xv)), Con (Lit (Float64Lit yv))) -> Con $ Lit $ Float64Lit $ f xv yv
1606-
(Con (Lit (Float32Lit xv)), Con (Lit (Float32Lit yv))) -> Con $ Lit $ Float32Lit $ f xv yv
1607-
_ -> error "Expected float atoms"
1608-
1609-
_applyFloatUnOp :: (forall a. (Num a, Fractional a) => a -> a) -> Atom r n -> Atom r n
1610-
_applyFloatUnOp f x = applyFloatBinOp (\_ -> f) (error "shouldn't be needed") x

src/lib/Inference.hs

+4-1
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ inferTopUDecl (ULocalDecl (WithSrcB src decl)) result = addSrcContext src case d
122122
_ -> do
123123
PairE block recon <- liftInfererM $ buildBlockInfWithRecon do
124124
val <- checkMaybeAnnExpr tyAnn rhs
125-
v <- emitHinted (getNameHint p) $ Atom val
125+
v <- emitDecl (getNameHint p) PlainLet $ Atom val
126126
bindLetPat p v do
127127
renameM result
128128
(topBlock, _) <- asTopBlock block
@@ -1597,6 +1597,9 @@ bindLetPat (WithSrcB pos pat) v cont = addSrcContext pos $ case pat of
15971597
xs <- forM [0 .. n - 1] \i -> do
15981598
emitToVar =<< mkTabApp (toAtom v) (toAtom $ NewtypeCon (FinCon (NatVal n)) (NatVal $ fromIntegral i))
15991599
bindLetPats ps xs cont
1600+
where
1601+
emitInline :: Emits n => CAtom n -> InfererM i n (AtomVar CoreIR n)
1602+
emitInline atom = emitDecl noHint InlineLet $ Atom atom
16001603

16011604
checkUType :: UType i -> InfererM i o (CType o)
16021605
checkUType t = do

src/lib/Inline.hs

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import IRVariants
1414
import Name
1515
import Subst
1616
import Occurrence hiding (Var)
17-
import Optimize
17+
import PeepholeOptimize
1818
import Types.Core
1919
import Types.Primitives
2020

@@ -80,7 +80,7 @@ inlineDeclsSubst = \case
8080
s <- getSubst
8181
extendSubst (b @> SubstVal (SuspEx expr s)) $ inlineDeclsSubst rest
8282
else do
83-
expr' <- inlineExpr Stop expr >>= (liftEnvReaderM . peepholeExpr)
83+
expr' <- peepholeExpr <$> inlineExpr Stop expr
8484
-- If the inliner starts moving effectful expressions, it may become
8585
-- necessary to query the effects of the new expression here.
8686
let presInfo = resolveWorkConservation ann expr'

src/lib/Linearize.hs

+5-5
Original file line numberDiff line numberDiff line change
@@ -495,18 +495,18 @@ linearizeUnOp op x' = do
495495
let emitZeroT = withZeroT $ emit $ UnOp op x
496496
case op of
497497
Exp -> do
498-
y <- emitUnOp Exp x
498+
y <- emit $ UnOp Exp x
499499
return $ WithTangent y (bindM2 mul tx (sinkM y))
500500
Exp2 -> notImplemented
501-
Log -> withT (emitUnOp Log x) $ (tx >>= (`div'` sink x))
501+
Log -> withT (emit $ UnOp Log x) $ (tx >>= (`div'` sink x))
502502
Log2 -> notImplemented
503503
Log10 -> notImplemented
504504
Log1p -> notImplemented
505-
Sin -> withT (emitUnOp Sin x) $ bindM2 mul tx (emitUnOp Cos (sink x))
506-
Cos -> withT (emitUnOp Cos x) $ bindM2 mul tx (neg =<< emitUnOp Sin (sink x))
505+
Sin -> withT (emit $ UnOp Sin x) $ bindM2 mul tx (emit $ UnOp Cos (sink x))
506+
Cos -> withT (emit $ UnOp Cos x) $ bindM2 mul tx (neg =<< emit (UnOp Sin (sink x)))
507507
Tan -> notImplemented
508508
Sqrt -> do
509-
y <- emitUnOp Sqrt x
509+
y <- emit $ UnOp Sqrt x
510510
return $ WithTangent y do
511511
denominator <- bindM2 mul (2 `fLitLike` sink y) (sinkM y)
512512
bindM2 div' tx (pure denominator)

0 commit comments

Comments
 (0)