Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 39 additions & 1 deletion src/irx/builders/llvmliteir.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def __init__(self) -> None:
self.named_values: dict[str, Any] = {}
self.function_protos: dict[str, astx.FunctionPrototype] = {}
self.result_stack: list[ir.Value | ir.Function] = []
self._fast_math_enabled = False

self.initialize()

Expand Down Expand Up @@ -420,7 +421,36 @@ def _emit_fma(
return builder.fma(lhs, rhs, addend, name="vfma")

fma_fn = self._get_fma_function(lhs.type)
return builder.call(fma_fn, [lhs, rhs, addend], name="vfma")
inst = builder.call(fma_fn, [lhs, rhs, addend], name="vfma")
self._apply_fast_math(inst)
return inst

def set_fast_math(self, enabled: bool) -> None:
    """Enable or disable fast-math flags for subsequent FP instructions.

    The setting is stored on ``self._fast_math_enabled``, which
    ``_apply_fast_math`` checks before tagging an instruction.

    Parameters
    ----------
    enabled:
        Truthy to turn fast-math tagging on, falsy to turn it off.
    """
    # Normalise to a strict bool so the attribute keeps the same type it
    # was initialised with, whatever truthy/falsy value a caller passes.
    self._fast_math_enabled = bool(enabled)

def _apply_fast_math(self, inst: ir.Instruction) -> None:
"""Attach fast-math flags when enabled and applicable."""
if not self._fast_math_enabled:
return
ty = inst.type
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`inst.flags` might be None, immutable, or absent, so this can raise an AttributeError or TypeError.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the feedback, @yuvimittal, and sorry for the delay. Should I handle it with a try/except block?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added defensive handling using getattr and try/except so it won't crash if flags is None or immutable.

if isinstance(ty, ir.VectorType):
if not is_fp_type(ty.element):
return
elif not is_fp_type(ty):
return

flags = getattr(inst, "flags", None)
if flags is None:
return

if "fast" in flags:
return

try:
flags.append("fast")
except (AttributeError, TypeError):
return

@dispatch.abstract
def visit(self, node: astx.AST) -> None:
Expand Down Expand Up @@ -616,6 +646,7 @@ def visit(self, node: astx.BinaryOp) -> None:
result = self._llvm.ir_builder.fadd(
llvm_lhs, llvm_rhs, name="vfaddtmp"
)
self._apply_fast_math(result)
else:
result = self._llvm.ir_builder.add(
llvm_lhs, llvm_rhs, name="vaddtmp"
Expand All @@ -625,6 +656,7 @@ def visit(self, node: astx.BinaryOp) -> None:
result = self._llvm.ir_builder.fsub(
llvm_lhs, llvm_rhs, name="vfsubtmp"
)
self._apply_fast_math(result)
else:
result = self._llvm.ir_builder.sub(
llvm_lhs, llvm_rhs, name="vsubtmp"
Expand All @@ -634,6 +666,7 @@ def visit(self, node: astx.BinaryOp) -> None:
result = self._llvm.ir_builder.fmul(
llvm_lhs, llvm_rhs, name="vfmultmp"
)
self._apply_fast_math(result)
else:
result = self._llvm.ir_builder.mul(
llvm_lhs, llvm_rhs, name="vmultmp"
Expand All @@ -643,6 +676,7 @@ def visit(self, node: astx.BinaryOp) -> None:
result = self._llvm.ir_builder.fdiv(
llvm_lhs, llvm_rhs, name="vfdivtmp"
)
self._apply_fast_math(result)
else:
unsigned = getattr(node, "unsigned", None)
if unsigned is None:
Expand Down Expand Up @@ -690,6 +724,7 @@ def visit(self, node: astx.BinaryOp) -> None:
result = self._llvm.ir_builder.fadd(
llvm_lhs, llvm_rhs, "addtmp"
)
self._apply_fast_math(result)
else:
# there's more conditions to be handled
result = self._llvm.ir_builder.add(
Expand All @@ -703,6 +738,7 @@ def visit(self, node: astx.BinaryOp) -> None:
result = self._llvm.ir_builder.fsub(
llvm_lhs, llvm_rhs, "subtmp"
)
self._apply_fast_math(result)
else:
# note: be careful you should handle this as INT32
result = self._llvm.ir_builder.sub(
Expand All @@ -717,6 +753,7 @@ def visit(self, node: astx.BinaryOp) -> None:
result = self._llvm.ir_builder.fmul(
llvm_lhs, llvm_rhs, "multmp"
)
self._apply_fast_math(result)
else:
# note: be careful you should handle this as INT32
result = self._llvm.ir_builder.mul(
Expand Down Expand Up @@ -782,6 +819,7 @@ def visit(self, node: astx.BinaryOp) -> None:
result = self._llvm.ir_builder.fdiv(
llvm_lhs, llvm_rhs, "divtmp"
)
self._apply_fast_math(result)
else:
# Assuming the division is signed by default. Use `udiv` for
# unsigned division.
Expand Down
20 changes: 20 additions & 0 deletions tests/test_llvmlite_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,23 @@ def test_emit_int_div_signed_and_unsigned() -> None:

assert getattr(signed, "opname", "") == "sdiv"
assert getattr(unsigned, "opname", "") == "udiv"


def test_set_fast_math_marks_float_ops() -> None:
    """Check that the ``fast`` flag on an fadd follows ``set_fast_math``."""
    visitor = LLVMLiteIRVisitor()
    _prime_builder(visitor)

    fp_type = visitor._llvm.FLOAT_TYPE
    one = ir.Constant(fp_type, 1.0)
    two = ir.Constant(fp_type, 2.0)

    # Exercise the enabled state first, then the disabled one; a fresh
    # fadd is built for each state so the flags are checked independently.
    for enabled in (True, False):
        visitor.set_fast_math(enabled)
        inst = visitor._llvm.ir_builder.fadd(one, two)
        visitor._apply_fast_math(inst)
        if enabled:
            assert "fast" in inst.flags
        else:
            assert "fast" not in inst.flags
Loading