@@ -29,12 +29,13 @@ import IRVariants
29
29
import MTL1
30
30
import Subst
31
31
import Name
32
+ import PeepholeOptimize
32
33
import QueryType
33
34
import Types.Core
34
35
import Types.Imp
35
36
import Types.Primitives
36
37
import Types.Source
37
- import Util (enumerate , transitiveClosureM , bindM2 , toSnocList , (...) )
38
+ import Util (enumerate , transitiveClosureM , bindM2 , toSnocList )
38
39
39
40
-- === Ordinary (local) builder class ===
40
41
@@ -66,50 +67,31 @@ emitDecl _ _ (Atom (Stuck _ (Var n))) = return n
66
67
emitDecl hint ann expr = rawEmitDecl hint ann expr
67
68
{-# INLINE emitDecl #-}
68
69
69
- emitInline :: (Builder r m , Emits n ) => Atom r n -> m n (AtomVar r n )
70
- emitInline atom = emitDecl noHint InlineLet $ Atom atom
71
- {-# INLINE emitInline #-}
72
-
73
- emitHinted :: (Builder r m , Emits n ) => NameHint -> Expr r n -> m n (AtomVar r n )
74
- emitHinted hint expr = emitDecl hint PlainLet expr
75
- {-# INLINE emitHinted #-}
76
-
77
70
emit :: (Builder r m , ToExpr e r , Emits n ) => e n -> m n (Atom r n )
78
71
emit e = case toExpr e of
79
72
Atom x -> return x
80
73
Block _ block -> emitDecls block >>= emit
81
- expr -> toAtom <$> emitToVar expr
74
+ expr -> do
75
+ v <- emitDecl noHint PlainLet $ peepholeExpr expr
76
+ return $ toAtom v
82
77
{-# INLINE emit #-}
83
78
84
79
emitToVar :: (Builder r m , ToExpr e r , Emits n ) => e n -> m n (AtomVar r n )
85
- emitToVar e = case toExpr e of
86
- Atom ( Stuck _ (Var v) ) -> return v
87
- expr -> emitDecl noHint PlainLet expr
80
+ emitToVar expr = emit expr >>= \ case
81
+ Stuck _ (Var v) -> return v
82
+ atom -> emitDecl noHint PlainLet (toExpr atom)
88
83
{-# INLINE emitToVar #-}
89
84
90
- emitHof :: (Builder r m , Emits n ) => Hof r n -> m n (Atom r n )
91
- emitHof hof = mkTypedHof hof >>= emit
92
-
93
- mkTypedHof :: (EnvReader m , IRRep r ) => Hof r n -> m n (TypedHof r n )
94
- mkTypedHof hof = do
95
- effTy <- effTyOfHof hof
96
- return $ TypedHof effTy hof
97
-
98
- emitUnOp :: (Builder r m , Emits n ) => UnOp -> Atom r n -> m n (Atom r n )
99
- emitUnOp op x = emit $ UnOp op x
100
- {-# INLINE emitUnOp #-}
101
-
102
85
emitDecls :: (Builder r m , Emits n , RenameE e , SinkableE e )
103
86
=> WithDecls r e n -> m n (e n )
104
- emitDecls (Abs decls result) = runSubstReaderT idSubst $ emitDecls' decls result
105
-
106
- emitDecls' :: (Builder r m , Emits o , RenameE e , SinkableE e )
107
- => Nest (Decl r ) i i' -> e i' -> SubstReaderT Name m i o (e o )
108
- emitDecls' Empty e = renameM e
109
- emitDecls' (Nest (Let b (DeclBinding ann expr)) rest) e = do
110
- expr' <- renameM expr
111
- AtomVar v _ <- emitDecl (getNameHint b) ann expr'
112
- extendSubst (b @> v) $ emitDecls' rest e
87
+ emitDecls (Abs decls result) = runSubstReaderT idSubst $ go decls result where
88
+ go :: (Builder r m , Emits o , RenameE e , SinkableE e )
89
+ => Nest (Decl r ) i i' -> e i' -> SubstReaderT Name m i o (e o )
90
+ go Empty e = renameM e
91
+ go (Nest (Let b (DeclBinding ann expr)) rest) e = do
92
+ expr' <- renameM expr
93
+ AtomVar v _ <- emitDecl (getNameHint b) ann expr'
94
+ extendSubst (b @> v) $ go rest e
113
95
114
96
buildScopedAssumeNoDecls :: (SinkableE e , ScopableBuilder r m )
115
97
=> (forall l . (Emits l , DExt n l ) => m l (e l ))
@@ -775,6 +757,14 @@ buildEffLam hint ty body = do
775
757
body' <- buildBlock $ body (sink hVar) $ sink ref
776
758
return $ LamExpr (BinaryNest h b) body'
777
759
760
+ emitHof :: (Builder r m , Emits n ) => Hof r n -> m n (Atom r n )
761
+ emitHof hof = mkTypedHof hof >>= emit
762
+
763
+ mkTypedHof :: (EnvReader m , IRRep r ) => Hof r n -> m n (TypedHof r n )
764
+ mkTypedHof hof = do
765
+ effTy <- effTyOfHof hof
766
+ return $ TypedHof effTy hof
767
+
778
768
buildForAnn
779
769
:: (Emits n , ScopableBuilder r m )
780
770
=> NameHint -> ForAnn -> IxType r n
@@ -940,70 +930,38 @@ symbolicTangentNonZero val = do
940
930
941
931
-- === builder versions of common local ops ===
942
932
943
- fLitLike :: (SBuilder m , Emits n ) => Double -> SAtom n -> m n (SAtom n )
944
- fLitLike x t = case getTyCon t of
945
- BaseType (Scalar Float64Type ) -> return $ toAtom $ Lit $ Float64Lit x
946
- BaseType (Scalar Float32Type ) -> return $ toAtom $ Lit $ Float32Lit $ realToFrac x
947
- _ -> error " Expected a floating point scalar"
948
-
949
933
neg :: (Builder r m , Emits n ) => Atom r n -> m n (Atom r n )
950
934
neg x = emit $ UnOp FNeg x
951
935
952
936
add :: (Builder r m , Emits n ) => Atom r n -> Atom r n -> m n (Atom r n )
953
937
add x y = emit $ BinOp FAdd x y
954
938
955
- -- TODO: Implement constant folding for fixed-width integer types as well!
956
- iadd :: (Builder r m , Emits n ) => Atom r n -> Atom r n -> m n (Atom r n )
957
- iadd (Con (Lit l)) y | getIntLit l == 0 = return y
958
- iadd x (Con (Lit l)) | getIntLit l == 0 = return x
959
- iadd x@ (Con (Lit _)) y@ (Con (Lit _)) = return $ applyIntBinOp (+) x y
960
- iadd x y = emit $ BinOp IAdd x y
961
-
962
939
mul :: (Builder r m , Emits n ) => Atom r n -> Atom r n -> m n (Atom r n )
963
940
mul x y = emit $ BinOp FMul x y
964
941
942
+ iadd :: (Builder r m , Emits n ) => Atom r n -> Atom r n -> m n (Atom r n )
943
+ iadd x y = emit $ BinOp IAdd x y
944
+
965
945
imul :: (Builder r m , Emits n ) => Atom r n -> Atom r n -> m n (Atom r n )
966
- imul (Con (Lit l)) y | getIntLit l == 1 = return y
967
- imul x (Con (Lit l)) | getIntLit l == 1 = return x
968
- imul x@ (Con (Lit _)) y@ (Con (Lit _)) = return $ applyIntBinOp (*) x y
969
946
imul x y = emit $ BinOp IMul x y
970
947
971
- sub :: (Builder r m , Emits n ) => Atom r n -> Atom r n -> m n (Atom r n )
972
- sub x y = emit $ BinOp FSub x y
973
-
974
- isub :: (Builder r m , Emits n ) => Atom r n -> Atom r n -> m n (Atom r n )
975
- isub x (Con (Lit l)) | getIntLit l == 0 = return x
976
- isub x@ (Con (Lit _)) y@ (Con (Lit _)) = return $ applyIntBinOp (-) x y
977
- isub x y = emit $ BinOp ISub x y
978
-
979
- select :: (Builder r m , Emits n ) => Atom r n -> Atom r n -> Atom r n -> m n (Atom r n )
980
- select (Con (Lit (Word8Lit p))) x y = return $ if p /= 0 then x else y
981
- select p x y = emit $ MiscOp $ Select p x y
982
-
983
948
div' :: (Builder r m , Emits n ) => Atom r n -> Atom r n -> m n (Atom r n )
984
949
div' x y = emit $ BinOp FDiv x y
985
950
986
- idiv :: (Builder r m , Emits n ) => Atom r n -> Atom r n -> m n (Atom r n )
987
- idiv x (Con (Lit l)) | getIntLit l == 1 = return x
988
- idiv x@ (Con (Lit _)) y@ (Con (Lit _)) = return $ applyIntBinOp div x y
989
- idiv x y = emit $ BinOp IDiv x y
990
-
991
- irem :: (Builder r m , Emits n ) => Atom r n -> Atom r n -> m n (Atom r n )
992
- irem x y = emit $ BinOp IRem x y
993
-
994
951
fpow :: (Builder r m , Emits n ) => Atom r n -> Atom r n -> m n (Atom r n )
995
952
fpow x y = emit $ BinOp FPow x y
996
953
954
+ sub :: (Builder r m , Emits n ) => Atom r n -> Atom r n -> m n (Atom r n )
955
+ sub x y = emit $ BinOp FSub x y
956
+
997
957
flog :: (Builder r m , Emits n ) => Atom r n -> m n (Atom r n )
998
958
flog x = emit $ UnOp Log x
999
959
1000
- ilt :: (Builder r m , Emits n ) => Atom r n -> Atom r n -> m n (Atom r n )
1001
- ilt x@ (Con (Lit _)) y@ (Con (Lit _)) = return $ applyIntCmpOp (<) x y
1002
- ilt x y = emit $ BinOp (ICmp Less ) x y
1003
-
1004
- ieq :: (Builder r m , Emits n ) => Atom r n -> Atom r n -> m n (Atom r n )
1005
- ieq x@ (Con (Lit _)) y@ (Con (Lit _)) = return $ applyIntCmpOp (==) x y
1006
- ieq x y = emit $ BinOp (ICmp Equal ) x y
960
+ fLitLike :: (SBuilder m , Emits n ) => Double -> SAtom n -> m n (SAtom n )
961
+ fLitLike x t = case getTyCon t of
962
+ BaseType (Scalar Float64Type ) -> return $ toAtom $ Lit $ Float64Lit x
963
+ BaseType (Scalar Float32Type ) -> return $ toAtom $ Lit $ Float32Lit $ realToFrac x
964
+ _ -> error " Expected a floating point scalar"
1007
965
1008
966
fromPair :: (Builder r m , Emits n ) => Atom r n -> m n (Atom r n , Atom r n )
1009
967
fromPair pair = do
@@ -1160,7 +1118,7 @@ app :: (CBuilder m, Emits n) => CAtom n -> CAtom n -> m n (CAtom n)
1160
1118
app x i = mkApp x [i] >>= emit
1161
1119
1162
1120
naryApp :: (CBuilder m , Emits n ) => CAtom n -> [CAtom n ] -> m n (CAtom n )
1163
- naryApp = naryAppHinted noHint
1121
+ naryApp f xs = mkApp f xs >>= emit
1164
1122
{-# INLINE naryApp #-}
1165
1123
1166
1124
naryTopApp :: (Builder SimpIR m , Emits n ) => TopFunName n -> [SAtom n ] -> m n (SAtom n )
@@ -1175,10 +1133,6 @@ naryTopAppInlined f xs = do
1175
1133
_ -> naryTopApp f xs
1176
1134
{-# INLINE naryTopAppInlined #-}
1177
1135
1178
- naryAppHinted :: (CBuilder m , Emits n )
1179
- => NameHint -> CAtom n -> [CAtom n ] -> m n (CAtom n )
1180
- naryAppHinted hint f xs = toAtom <$> (mkApp f xs >>= emitHinted hint)
1181
-
1182
1136
tabApp :: (Builder r m , Emits n ) => Atom r n -> Atom r n -> m n (Atom r n )
1183
1137
tabApp x i = mkTabApp x i >>= emit
1184
1138
@@ -1581,30 +1535,3 @@ visitDeclsEmits (Nest (Let b (DeclBinding _ expr)) decls) cont = do
1581
1535
x <- visitExprEmits expr
1582
1536
extendSubst (b@> SubstVal x) do
1583
1537
visitDeclsEmits decls cont
1584
-
1585
- -- === Helpers for function evaluation over fixed-width types ===
1586
-
1587
- applyIntBinOp' :: (forall a . (Eq a , Ord a , Num a , Integral a )
1588
- => (a -> Atom r n ) -> a -> a -> Atom r n ) -> Atom r n -> Atom r n -> Atom r n
1589
- applyIntBinOp' f x y = case (x, y) of
1590
- (Con (Lit (Int64Lit xv)), Con (Lit (Int64Lit yv))) -> f (Con . Lit . Int64Lit ) xv yv
1591
- (Con (Lit (Int32Lit xv)), Con (Lit (Int32Lit yv))) -> f (Con . Lit . Int32Lit ) xv yv
1592
- (Con (Lit (Word8Lit xv)), Con (Lit (Word8Lit yv))) -> f (Con . Lit . Word8Lit ) xv yv
1593
- (Con (Lit (Word32Lit xv)), Con (Lit (Word32Lit yv))) -> f (Con . Lit . Word32Lit ) xv yv
1594
- (Con (Lit (Word64Lit xv)), Con (Lit (Word64Lit yv))) -> f (Con . Lit . Word64Lit ) xv yv
1595
- _ -> error " Expected integer atoms"
1596
-
1597
- applyIntBinOp :: (forall a . (Num a , Integral a ) => a -> a -> a ) -> Atom r n -> Atom r n -> Atom r n
1598
- applyIntBinOp f x y = applyIntBinOp' (\ w -> w ... f) x y
1599
-
1600
- applyIntCmpOp :: (forall a . (Eq a , Ord a ) => a -> a -> Bool ) -> Atom r n -> Atom r n -> Atom r n
1601
- applyIntCmpOp f x y = applyIntBinOp' (\ _ -> (Con . Lit . Word8Lit . fromIntegral . fromEnum ) ... f) x y
1602
-
1603
- applyFloatBinOp :: (forall a . (Num a , Fractional a ) => a -> a -> a ) -> Atom r n -> Atom r n -> Atom r n
1604
- applyFloatBinOp f x y = case (x, y) of
1605
- (Con (Lit (Float64Lit xv)), Con (Lit (Float64Lit yv))) -> Con $ Lit $ Float64Lit $ f xv yv
1606
- (Con (Lit (Float32Lit xv)), Con (Lit (Float32Lit yv))) -> Con $ Lit $ Float32Lit $ f xv yv
1607
- _ -> error " Expected float atoms"
1608
-
1609
- _applyFloatUnOp :: (forall a . (Num a , Fractional a ) => a -> a ) -> Atom r n -> Atom r n
1610
- _applyFloatUnOp f x = applyFloatBinOp (\ _ -> f) (error " shouldn't be needed" ) x
0 commit comments