diff --git a/amaranth/hdl/_xfrm.py b/amaranth/hdl/_xfrm.py index 92d070594..d04bbbad6 100644 --- a/amaranth/hdl/_xfrm.py +++ b/amaranth/hdl/_xfrm.py @@ -17,6 +17,7 @@ "FragmentTransformer", "TransformedElaboratable", "DomainCollector", "DomainRenamer", "DomainLowerer", + "LHSMaskCollector", "ResetInserter", "EnableInserter"] @@ -601,6 +602,71 @@ def on_fragment(self, fragment): return super().on_fragment(fragment) +class LHSMaskCollector: + def __init__(self): + self.lhs = SignalDict() + + def visit_stmt(self, stmt): + if type(stmt) is Assign: + self.visit_value(stmt.lhs, ~0) + elif type(stmt) is Switch: + for (_, substmt, _) in stmt.cases: + self.visit_stmt(substmt) + elif type(stmt) in (Property, Print): + pass + elif isinstance(stmt, Iterable): + for substmt in stmt: + self.visit_stmt(substmt) + else: + assert False # :nocov: + + def visit_value(self, value, mask): + if type(value) in (Signal, ClockSignal, ResetSignal): + mask &= (1 << len(value)) - 1 + self.lhs.setdefault(value, 0) + self.lhs[value] |= mask + elif type(value) is Operator: + assert value.operator in ("s", "u") + self.visit_value(value.operands[0], mask) + elif type(value) is Slice: + slice_mask = (1 << value.stop) - (1 << value.start) + mask <<= value.start + mask &= slice_mask + self.visit_value(value.value, mask) + elif type(value) is Part: + # Could be more accurate, but if you're relying on such details, you're not seeing + # the Light of Heaven. + self.visit_value(value.value, ~0) + elif type(value) is Concat: + for part in value.parts: + self.visit_value(part, mask) + mask >>= len(part) + elif type(value) is SwitchValue: + for (_, subvalue) in value.cases: + self.visit_value(subvalue, mask) + else: + assert False # :nocov: + + def chunks(self): + for signal, mask in self.lhs.items(): + if mask == (1 << len(signal)) - 1: + yield signal, 0, None + else: + start = 0 + while start < len(signal): + if ((mask >> start) & 1) == 0: + start += 1 + else: + stop = start + while stop < len(signal) and ((mask >> stop) & 1) == 1: + stop += 1 + yield (signal, start, stop) + start = stop + + def masks(self): + yield from self.lhs.items() + + class _ControlInserter(FragmentTransformer): def __init__(self, controls): self.src_loc = None @@ -615,10 +681,9 @@ def on_fragment(self, fragment): for domain, statements in fragment.statements.items(): if domain == "comb" or domain not in self.controls: continue - signals = SignalSet() - for stmt in statements: - signals |= stmt._lhs_signals() - self._insert_control(new_fragment, domain, signals) + lhs_masks = LHSMaskCollector() + lhs_masks.visit_stmt(statements) + self._insert_control(new_fragment, domain, lhs_masks) return new_fragment def _insert_control(self, fragment, domain, signals): @@ -630,13 +695,20 @@ def __call__(self, value, *, src_loc_at=0): class ResetInserter(_ControlInserter): - def _insert_control(self, fragment, domain, signals): - stmts = [s.eq(Const(s.init, s.shape())) for s in signals if not s.reset_less] + def _insert_control(self, fragment, domain, lhs_masks): + stmts = [] + for signal, start, stop in lhs_masks.chunks(): + if signal.reset_less: + continue + if start == 0 and stop is None: + stmts.append(signal.eq(Const(signal.init, signal.shape()))) + else: + stmts.append(signal[start:stop].eq(Const(signal.init, signal.shape())[start:stop])) fragment.add_statements(domain, Switch(self.controls[domain], [(1, stmts, None)], src_loc=self.src_loc)) class EnableInserter(_ControlInserter): - def _insert_control(self, fragment, domain, signals): + def _insert_control(self, fragment, domain, _lhs_masks): if domain in fragment.statements: fragment.statements[domain] = _StatementList([Switch( self.controls[domain], diff --git a/amaranth/sim/_base.py b/amaranth/sim/_base.py index c63b95d81..7e58112a4 100644 --- a/amaranth/sim/_base.py +++ b/amaranth/sim/_base.py @@ -23,7 +23,7 @@ class BaseSignalState: curr = NotImplemented next = NotImplemented - def update(self, value): + def update(self, value, mask=~0): raise NotImplementedError # :nocov: diff --git a/amaranth/sim/_pyrtl.py b/amaranth/sim/_pyrtl.py index be624fa6b..0c94421ee 100644 --- a/amaranth/sim/_pyrtl.py +++ b/amaranth/sim/_pyrtl.py @@ -5,7 +5,7 @@ from ..hdl import * from ..hdl._ast import SignalSet, _StatementList, Property -from ..hdl._xfrm import ValueVisitor, StatementVisitor +from ..hdl._xfrm import ValueVisitor, StatementVisitor, LHSMaskCollector from ..hdl._mem import MemoryInstance from ._base import BaseProcess from ._pyeval import value_to_string @@ -487,19 +487,20 @@ def __call__(self, fragment): for domain_name in domains: domain_stmts = fragment.statements.get(domain_name, _StatementList()) domain_process = PyRTLProcess(is_comb=domain_name == "comb") - domain_signals = domain_stmts._lhs_signals() + lhs_masks = LHSMaskCollector() + lhs_masks.visit_stmt(domain_stmts) if isinstance(fragment, MemoryInstance): for port in fragment._read_ports: if port._domain == domain_name: - domain_signals.update(port._data._lhs_signals()) + lhs_masks.visit_value(port._data, ~0) emitter = _PythonEmitter() emitter.append(f"def run():") emitter._level += 1 if domain_name == "comb": - for signal in domain_signals: + for (signal, _) in lhs_masks.masks(): signal_index = self.state.get_signal(signal) self.state.slots[signal_index].is_comb = True emitter.append(f"next_{signal_index} = {signal.init}") @@ -533,7 +534,7 @@ def __call__(self, fragment): if domain.async_reset and domain.rst is not None: self.state.add_signal_waker(domain.rst, edge_waker(domain_process, 1)) - for signal in domain_signals: + for (signal, _) in lhs_masks.masks(): signal_index = self.state.get_signal(signal) emitter.append(f"next_{signal_index} = slots[{signal_index}].next") @@ -546,7 +547,7 @@ def __call__(self, fragment): emitter.append(f"if {rst}:") with emitter.indent(): emitter.append("pass") - for signal in domain_signals: + for (signal, _) in lhs_masks.masks(): if not signal.reset_less: signal_index = self.state.get_signal(signal) emitter.append(f"next_{signal_index} = {signal.init}") @@ -592,9 +593,11 @@ def __call__(self, fragment): lhs(port._data)(data) - for signal in domain_signals: + for (signal, mask) in lhs_masks.masks(): + if signal.shape().signed and (mask & 1 << (len(signal) - 1)): + mask |= -1 << len(signal) signal_index = self.state.get_signal(signal) - emitter.append(f"slots[{signal_index}].update(next_{signal_index})") + emitter.append(f"slots[{signal_index}].update(next_{signal_index}, {mask})") # There shouldn't be any exceptions raised by the generated code, but if there are # (almost certainly due to a bug in the code generator), use this environment variable diff --git a/amaranth/sim/pysim.py b/amaranth/sim/pysim.py index 2cc646e6e..d89ddca82 100644 --- a/amaranth/sim/pysim.py +++ b/amaranth/sim/pysim.py @@ -369,7 +369,8 @@ def add_waker(self, waker): assert waker not in self.wakers self.wakers.append(waker) - def update(self, value): + def update(self, value, mask=~0): + value = (self.next & ~mask) | (value & mask) if self.next != value: self.next = value self.pending.add(self) diff --git a/tests/test_hdl_xfrm.py b/tests/test_hdl_xfrm.py index 0ab2eeeaa..389629655 100644 --- a/tests/test_hdl_xfrm.py +++ b/tests/test_hdl_xfrm.py @@ -227,6 +227,7 @@ def setUp(self): self.s1 = Signal() self.s2 = Signal(init=1) self.s3 = Signal(init=1, reset_less=True) + self.s4 = Signal(8, init=0x3a) self.c1 = Signal() def test_reset_default(self): @@ -281,6 +282,40 @@ def test_reset_value(self): ) """) + def test_reset_mask(self): + f = Fragment() + f.add_statements("sync", self.s4[2:4].eq(0)) + + f = ResetInserter(self.c1)(f) + self.assertRepr(f.statements["sync"], """ + ( + (eq (slice (sig s4) 2:4) (const 1'd0)) + (switch (sig c1) + (case 1 (eq (slice (sig s4) 2:4) (slice (const 8'd58) 2:4))) + ) + ) + """) + + f = Fragment() + f.add_statements("sync", self.s4[2:4].eq(0)) + f.add_statements("sync", self.s4[3:5].eq(0)) + f.add_statements("sync", self.s4[6:10].eq(0)) + + f = ResetInserter(self.c1)(f) + self.assertRepr(f.statements["sync"], """ + ( + (eq (slice (sig s4) 2:4) (const 1'd0)) + (eq (slice (sig s4) 3:5) (const 1'd0)) + (eq (slice (sig s4) 6:8) (const 1'd0)) + (switch (sig c1) + (case 1 + (eq (slice (sig s4) 2:5) (slice (const 8'd58) 2:5)) + (eq (slice (sig s4) 6:8) (slice (const 8'd58) 6:8)) + ) + ) + ) + """) + def test_reset_less(self): f = Fragment() f.add_statements("sync", self.s3.eq(0)) @@ -423,3 +458,31 @@ def test_composition(self): ) ) """) + +class LHSMaskCollectorTestCase(FHDLTestCase): + def test_slice(self): + s = Signal(8) + lhs = LHSMaskCollector() + lhs.visit_value(s[2:5], ~0) + self.assertEqual(lhs.lhs[s], 0x1c) + + def test_slice_slice(self): + s = Signal(8) + lhs = LHSMaskCollector() + lhs.visit_value(s[2:7][1:3], ~0) + self.assertEqual(lhs.lhs[s], 0x18) + + def test_slice_concat(self): + s1 = Signal(8) + s2 = Signal(8) + lhs = LHSMaskCollector() + lhs.visit_value(Cat(s1, s2)[4:11], ~0) + self.assertEqual(lhs.lhs[s1], 0xf0) + self.assertEqual(lhs.lhs[s2], 0x07) + + def test_slice_part(self): + s = Signal(8) + idx = Signal(8) + lhs = LHSMaskCollector() + lhs.visit_value(s.bit_select(idx, 5)[1:3], ~0) + self.assertEqual(lhs.lhs[s], 0xff) diff --git a/tests/test_sim.py b/tests/test_sim.py index db9ea8723..e3a6a121d 100644 --- a/tests/test_sim.py +++ b/tests/test_sim.py @@ -1413,6 +1413,28 @@ async def testbench(ctx): with self.assertSimulation(Module(), traces=[mem1, mem2, mem3]) as sim: sim.add_testbench(testbench) + def test_multiple_modules(self): + m = Module() + m.submodules.m1 = m1 = Module() + m.submodules.m2 = m2 = Module() + a = Signal(8) + b = Signal(8) + m1.d.comb += b[0:2].eq(a[0:2]) + m1.d.comb += b[4:6].eq(a[4:6]) + m2.d.comb += b[2:4].eq(a[2:4]) + m2.d.comb += b[6:8].eq(a[6:8]) + with self.assertSimulation(m) as sim: + async def testbench(ctx): + ctx.set(a, 0) + self.assertEqual(ctx.get(b), 0) + ctx.set(a, 0x12) + self.assertEqual(ctx.get(b), 0x12) + ctx.set(a, 0x34) + self.assertEqual(ctx.get(b), 0x34) + ctx.set(a, 0xdb) + self.assertEqual(ctx.get(b), 0xdb) + sim.add_testbench(testbench) + class SimulatorTracesTestCase(FHDLTestCase): def assertDef(self, traces, flat_traces):