Skip to content

Commit 7bf7ad9

Browse files
committed
Printing LLVM
1 parent f534427 commit 7bf7ad9

File tree

8 files changed

+337
-48
lines changed

8 files changed

+337
-48
lines changed

src/lib/LLVMFFI.hs

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
-- Copyright 2025 Google LLC
2+
--
3+
-- Use of this source code is governed by a BSD-style
4+
-- license that can be found in the LICENSE file or at
5+
-- https://developers.google.com/open-source/licenses/bsd
6+
7+
module LLVMFFI (LLVMContext, initializeLLVM, compileLLVM, getFunctionPtr,
8+
callEntryFun) where
9+
10+
import Data.Int
11+
import Util (BString)
12+
13+
foreign import ccall "doit_cpp" doit_cpp :: Int64 -> IO Int64
14+
15+
type FunctionPtr = ()
16+
type LLVMContext = ()
17+
type DataPtr = ()
18+
type DataListPtr = ()
19+
20+
initializeLLVM :: IO LLVMContext
21+
initializeLLVM = return undefined
22+
23+
compileLLVM :: LLVMContext -> BString -> IO ()
24+
compileLLVM _ _ = return undefined
25+
26+
getFunctionPtr :: LLVMContext -> BString -> IO FunctionPtr
27+
getFunctionPtr _ _ = return undefined
28+
29+
callEntryFun :: FunctionPtr -> [DataPtr] -> IO DataPtr
30+
callEntryFun _ _ = return undefined

src/lib/PPrint.hs

+8-4
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
{-# LANGUAGE NoFieldSelectors #-}
88

99
module PPrint (
10-
Pretty (..), indent, emitLine, hcat, hlist, pprint, app,
10+
Pretty (..), indent, emitLine, hcat, hlist, pprint, app, pprintStr,
1111
(<+>), BSBuilder, forceOneLine) where
1212

13+
import Data.ByteString.Internal (w2c)
1314
import Data.Int
1415
import Data.Word
1516
import Data.List (intersperse)
@@ -22,6 +23,12 @@ pprint :: Pretty a => a -> BString
2223
pprint x = runPrinter $ prLines x
2324
{-# SCC pprint #-}
2425

26+
pprintStr :: Pretty a => a -> String
27+
pprintStr x = bs2str $ pprint x
28+
29+
bs2str :: BString -> String
30+
bs2str s = map w2c $ BS.unpack s
31+
2532
-- === printing doc ===
2633

2734
type BString = BS.ByteString
@@ -32,11 +39,8 @@ data PrinterState = PrinterState {indent :: Indent, curString :: BS.Builder }
3239
newtype PrinterM a = PrinterM { inner :: State PrinterState a }
3340
deriving (Functor, Applicative, Monad)
3441

35-
-- Instances should define either `pr` (if they're expected to be one-liners
36-
-- most of the time) or `prLines`.
3742
class Pretty a where
3843
pr :: a -> BSBuilder
39-
pr x = forceOneLine $ prLines x
4044

4145
prLines :: a -> PrinterM ()
4246
prLines x = emitLine $ pr x

src/lib/ToLLVM.hs

+137
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
-- Copyright 2025 Google LLC
2+
--
3+
-- Use of this source code is governed by a BSD-style
4+
-- license that can be found in the LICENSE file or at
5+
-- https://developers.google.com/open-source/licenses/bsd
6+
7+
{-# LANGUAGE NoFieldSelectors #-}
8+
9+
module ToLLVM where
10+
11+
import Name
12+
import Control.Monad
13+
import Control.Monad.State.Strict hiding (state)
14+
import Data.String (fromString)
15+
import qualified Data.ByteString as BS
16+
import qualified Types.LLVM as L
17+
import Types.Simple
18+
import Types.Primitives
19+
import PPrint
20+
21+
import Debug.Trace
22+
import QueryTypePure
23+
import Util
24+
25+
-- === entrypoint ===
26+
27+
toLLVMEntryFun :: Monad m => L.Name -> TopLamExpr -> m L.Function
28+
toLLVMEntryFun fname fun = do
29+
finalState <- runTranslateM do
30+
toLLVMEntryFun' fun
31+
startNewBlock $ L.Name "__unused__"
32+
let blocks = reverse finalState.basicBlocks
33+
return $ L.Function fname [] blocks
34+
35+
-- === monad for the translation ===
36+
37+
data TranslateState i = TranslateState
38+
{ basicBlocks :: [L.BasicBlock] -- reverse order
39+
, instructions :: [L.Decl] -- reverse order
40+
, curBlockName :: L.Name
41+
, nameGen :: Int
42+
, subst :: TranslateSubst i}
43+
type TranslateSubst i = Subst (LiftE L.Operand) i VoidS
44+
45+
newtype TranslateM (i::S) (a:: *) =
46+
TranslateM { inner :: State (TranslateState i) a }
47+
deriving (Functor, Applicative, Monad)
48+
49+
runTranslateM :: Monad m => TranslateM VoidS a -> m (TranslateState VoidS)
50+
runTranslateM cont = do
51+
let initState = TranslateState [] [] (L.Name "__entry__") 0 voidSubst
52+
return $ execState cont.inner initState
53+
54+
emitInstr :: L.Type -> L.Instruction -> TranslateM i L.Operand
55+
emitInstr resultTy instr = do
56+
v <- newLName ""
57+
let decl = (Just v, resultTy, instr)
58+
TranslateM $ modify \s -> s {instructions = decl : s.instructions}
59+
return $ L.Operand (L.LocalOcc v) resultTy
60+
61+
emitStatement :: L.Instruction -> TranslateM i ()
62+
emitStatement instr = do
63+
let decl = (Nothing, L.VoidType, instr)
64+
TranslateM $ modify \s -> s {instructions = decl : s.instructions}
65+
66+
extendEnv :: NameBinder i i' -> L.Operand -> TranslateM i' a -> TranslateM i a
67+
extendEnv b x cont = TranslateM do
68+
prevState <- get
69+
let subst' = prevState.subst <>> (b @> LiftE x)
70+
let (ans, newState) = runState (cont.inner) $ updateSubst prevState subst'
71+
put $ updateSubst newState prevState.subst
72+
return ans
73+
74+
lookupEnv :: Name i -> TranslateM i L.Operand
75+
lookupEnv v = TranslateM do
76+
env <- gets (.subst)
77+
let LiftE x = env ! v
78+
return x
79+
80+
updateSubst :: TranslateState i -> TranslateSubst i' -> TranslateState i'
81+
updateSubst (TranslateState a b c d _) subst = TranslateState a b c d subst
82+
83+
newLName :: BString -> TranslateM i L.Name
84+
newLName hint = TranslateM do
85+
c <- gets (.nameGen)
86+
modify \s -> s {nameGen = s.nameGen + 1}
87+
return $ L.Name $ hint <> "_" <> fromString (show c)
88+
89+
startNewBlock :: L.Name -> TranslateM i ()
90+
startNewBlock blockName = TranslateM $ modify \state -> do
91+
let newBlock = L.BasicBlock state.curBlockName (reverse state.instructions)
92+
state {
93+
basicBlocks = newBlock : state.basicBlocks,
94+
curBlockName = blockName,
95+
instructions = []}
96+
97+
-- === translation itself ===
98+
99+
toLLVMEntryFun' :: TopLamExpr -> TranslateM VoidS ()
100+
toLLVMEntryFun' (TopLamExpr (Abs Empty body)) = do
101+
trExpr body
102+
return ()
103+
104+
trExpr :: Expr i -> TranslateM i L.Operand
105+
trExpr = \case
106+
Block resultTy block -> trBlock block
107+
PrimOp resultTy op -> do
108+
resultTy' <- trType resultTy
109+
op' <- forM op trAtom
110+
trPrimOp resultTy' op'
111+
112+
trType :: Type i -> TranslateM i L.Type
113+
trType = \case
114+
BaseType b -> return $ L.BaseType b
115+
ProdType [] -> return L.VoidType
116+
t -> error $ "not implemented: " ++ pprintStr t
117+
118+
trAtom :: Atom i -> TranslateM i L.Operand
119+
trAtom = \case
120+
Var v _ -> do
121+
val <- lookupEnv v
122+
return val
123+
Lit v -> return $ L.Operand (L.Lit v) (L.BaseType (litType v))
124+
125+
trBlock :: Block i -> TranslateM i L.Operand
126+
trBlock (Abs decls result) = case decls of
127+
Empty -> trExpr result
128+
Nest (Let b expr) rest -> do
129+
val <- trExpr expr
130+
extendEnv b val $ trBlock $ Abs rest result
131+
132+
trPrimOp :: L.Type -> PrimOp L.Operand -> TranslateM i L.Operand
133+
trPrimOp resultTy op = case op of
134+
BinOp b x y -> case b of
135+
FAdd -> emitInstr resultTy $ L.FAdd x y
136+
MiscOp op' -> case op' of
137+
DebugPrintInt x -> undefined

src/lib/TopLevel.hs

+2-2
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import PPrint
2525
import Simplify
2626
import LLVMFFI
2727
import ToLLVM
28-
import Types.LLVM
28+
import qualified Types.LLVM as L
2929
import Types.Complicated
3030
import Types.Primitives
3131
import Types.Source hiding (CTopDecl)
@@ -104,7 +104,7 @@ execUDecl decl = do
104104
CTopLet Nothing expr <- checkPass TypePass $ inferTopUDecl renamed
105105
simpFun <- simplifyTopFun (exprAsNullaryFun expr)
106106
logPass SimpPass simpFun
107-
let tempFunName = "main" -- TODO: need to get a name
107+
let tempFunName = L.Name "main" -- TODO: need to get a name
108108
llvmContext <- TopperM $ asks topperLLVMContext
109109
llvmFun <- toLLVMEntryFun tempFunName simpFun
110110
logPass LLVMPass llvmFun

src/lib/Types/LLVM.hs

+101
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
-- Copyright 2025 Google LLC
2+
--
3+
-- Use of this source code is governed by a BSD-style
4+
-- license that can be found in the LICENSE file or at
5+
-- https://developers.google.com/open-source/licenses/bsd
6+
7+
{-# LANGUAGE DuplicateRecordFields #-}
8+
9+
module Types.LLVM where
10+
11+
import Control.Monad
12+
import Data.ByteString (ByteString)
13+
import qualified Data.ByteString as BS
14+
import qualified Data.ByteString.Builder as BS
15+
16+
import qualified Types.Primitives as P
17+
import PPrint
18+
import Util (bs2str)
19+
20+
-- this string doesn't include the `@` or `%` prefixes
21+
newtype Name = Name { val :: ByteString }
22+
type Binder = (Name, Type)
23+
24+
data Module = Module { functions :: [Function] }
25+
26+
data Function = Function
27+
{ name :: Name
28+
, params :: [Binder]
29+
, body :: [BasicBlock] }
30+
31+
data BasicBlock = BasicBlock
32+
{ name :: Name
33+
, instructions :: [Decl]}
34+
35+
type Decl = (Maybe Name, Type, Instruction)
36+
data Instruction =
37+
FAdd Operand Operand
38+
| Return Operand
39+
40+
data Operand = Operand { val :: UntypedOperand, ty :: Type }
41+
data UntypedOperand =
42+
LocalOcc Name
43+
| Lit P.LitVal
44+
45+
data Type =
46+
BaseType P.BaseType
47+
| VoidType
48+
49+
-- === LLVM printing ===
50+
51+
-- This is load-bearing! We have to generate correct LLVM textual representation.
52+
53+
instance Pretty Function where
54+
prLines (Function name [] body) = do
55+
emitLine $ "define i32" <+> prTopName name <> "() {"
56+
forM_ body \block -> do
57+
emitLine ""
58+
prLines block
59+
emitLine "}"
60+
61+
prTopName :: Name -> BS.Builder
62+
prTopName name = "@" <> BS.byteString name.val
63+
64+
prLocalName :: Name -> BS.Builder
65+
prLocalName name = "%" <> BS.byteString name.val
66+
67+
prDecl :: Decl -> BS.Builder
68+
prDecl (Just v, resultTy, instr) = prLocalName v <> " = " <> prInstr resultTy instr
69+
70+
prInstr :: Type -> Instruction -> BS.Builder
71+
prInstr resultTy = \case
72+
FAdd x y -> "fadd " <> pr resultTy <+> pr x.val <> ", " <> pr y.val
73+
74+
instance Pretty BasicBlock where
75+
prLines (BasicBlock name decls) = do
76+
emitLine $ pr name <> ":"
77+
indent do
78+
forM_ decls \decl -> emitLine $ prDecl decl
79+
80+
instance Pretty Name where
81+
pr name = BS.byteString name.val
82+
83+
instance Pretty UntypedOperand where
84+
pr = \case
85+
LocalOcc v -> prLocalName v
86+
Lit v -> pr v
87+
88+
instance Pretty Type where
89+
pr = \case
90+
BaseType (P.Scalar b) -> case b of
91+
P.Float32Type -> "f32"
92+
VoidType -> "void"
93+
94+
95+
-- instance LLVMSer Operand where
96+
-- lpr x = cat [lpr (getType x), ", ", printOperandWithoutType x]
97+
98+
-- instance Pretty Type where
99+
-- pr = undefined
100+
101+

src/lib/Types/Simple.hs

+9-1
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,15 @@ instance GenericE Type where
146146
DepPairTy p -> Case4 $ p
147147
TabPi t -> Case5 $ t
148148

149-
instance Pretty (Type n)
149+
instance Pretty (Type n) where
150+
pr = \case
151+
BaseType b -> pr b
152+
ProdType _ -> undefined
153+
SumType _ -> undefined
154+
RefType _ -> undefined
155+
DepPairTy _ -> undefined
156+
TabPi _ -> undefined
157+
150158
instance SinkableE Type
151159
instance HoistableE Type
152160
instance RenameE Type

src/lib/Types/Source.hs

+3-3
Original file line numberDiff line numberDiff line change
@@ -588,7 +588,7 @@ data PrintBackend =
588588

589589
data OutFormat = Printed (Maybe PrintBackend) | RenderHtml deriving (Show, Eq, Generic)
590590

591-
data PassName = Parse | RenamePass | TypePass | SimpPass | ImpPass | JitPass | LLVMPass
591+
data PassName = Parse | RenamePass | TypePass | SimpPass | ImpPass | LLVMPass
592592
| LLVMOpt | AsmPass | JAXPass | JAXSimpPass | LLVMEval | LowerOptPass | LowerPass
593593
| ResultPass | JaxprAndHLO | EarlyOptPass | OptPass | VectPass | OccAnalysisPass
594594
| InlinePass
@@ -597,13 +597,13 @@ data PassName = Parse | RenamePass | TypePass | SimpPass | ImpPass | JitPass | L
597597
instance Show PassName where
598598
show p = case p of
599599
Parse -> "parse" ; RenamePass -> "rename"; TypePass -> "typed"
600-
SimpPass -> "simp" ; ImpPass -> "imp" ; JitPass -> "llvm"
600+
SimpPass -> "simp" ; ImpPass -> "imp"
601601
LLVMOpt -> "llvmopt" ; AsmPass -> "asm"
602602
JAXPass -> "jax" ; JAXSimpPass -> "jsimp"; ResultPass -> "result"
603603
LLVMEval -> "llvmeval" ; JaxprAndHLO -> "jaxprhlo";
604604
LowerOptPass -> "lower-opt"; LowerPass -> "lower"
605605
EarlyOptPass -> "early-opt"; OptPass -> "opt"; OccAnalysisPass -> "occ-analysis"
606-
VectPass -> "vect"; InlinePass -> "inline"
606+
VectPass -> "vect"; InlinePass -> "inline"; LLVMPass -> "llvm"
607607

608608
data EnvQuery =
609609
DumpSubst

0 commit comments

Comments
 (0)