Skip to content

Commit

Permalink
sim: make driving parts of a signal from distinct modules possible.
Browse files Browse the repository at this point in the history
Fixes (part of) #1454.
  • Loading branch information
wanda-phi authored and whitequark committed Jul 22, 2024
1 parent a154b64 commit c78806f
Show file tree
Hide file tree
Showing 6 changed files with 178 additions and 17 deletions.
86 changes: 79 additions & 7 deletions amaranth/hdl/_xfrm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"FragmentTransformer",
"TransformedElaboratable",
"DomainCollector", "DomainRenamer", "DomainLowerer",
"LHSMaskCollector",
"ResetInserter", "EnableInserter"]


Expand Down Expand Up @@ -601,6 +602,71 @@ def on_fragment(self, fragment):
return super().on_fragment(fragment)


class LHSMaskCollector:
def __init__(self):
self.lhs = SignalDict()

def visit_stmt(self, stmt):
if type(stmt) is Assign:
self.visit_value(stmt.lhs, ~0)
elif type(stmt) is Switch:
for (_, substmt, _) in stmt.cases:
self.visit_stmt(substmt)
elif type(stmt) in (Property, Print):
pass
elif isinstance(stmt, Iterable):
for substmt in stmt:
self.visit_stmt(substmt)
else:
assert False # :nocov:

def visit_value(self, value, mask):
if type(value) in (Signal, ClockSignal, ResetSignal):
mask &= (1 << len(value)) - 1
self.lhs.setdefault(value, 0)
self.lhs[value] |= mask
elif type(value) is Operator:
assert value.operator in ("s", "u")
self.visit_value(value.operands[0], mask)
elif type(value) is Slice:
slice_mask = (1 << value.stop) - (1 << value.start)
mask <<= value.start
mask &= slice_mask
self.visit_value(value.value, mask)
elif type(value) is Part:
# Could be more accurate, but if you're relying on such details, you're not seeing
# the Light of Heaven.
self.visit_value(value.value, ~0)
elif type(value) is Concat:
for part in value.parts:
self.visit_value(part, mask)
mask >>= len(part)
elif type(value) is SwitchValue:
for (_, subvalue) in value.cases:
self.visit_value(subvalue, mask)
else:
assert False # :nocov:

def chunks(self):
for signal, mask in self.lhs.items():
if mask == (1 << len(signal)) - 1:
yield signal, 0, None
else:
start = 0
while start < len(signal):
if ((mask >> start) & 1) == 0:
start += 1
else:
stop = start
while stop < len(signal) and ((mask >> stop) & 1) == 1:
stop += 1
yield (signal, start, stop)
start = stop

def masks(self):
yield from self.lhs.items()


class _ControlInserter(FragmentTransformer):
def __init__(self, controls):
self.src_loc = None
Expand All @@ -615,10 +681,9 @@ def on_fragment(self, fragment):
for domain, statements in fragment.statements.items():
if domain == "comb" or domain not in self.controls:
continue
signals = SignalSet()
for stmt in statements:
signals |= stmt._lhs_signals()
self._insert_control(new_fragment, domain, signals)
lhs_masks = LHSMaskCollector()
lhs_masks.visit_stmt(statements)
self._insert_control(new_fragment, domain, lhs_masks)
return new_fragment

def _insert_control(self, fragment, domain, signals):
Expand All @@ -630,13 +695,20 @@ def __call__(self, value, *, src_loc_at=0):


class ResetInserter(_ControlInserter):
def _insert_control(self, fragment, domain, signals):
stmts = [s.eq(Const(s.init, s.shape())) for s in signals if not s.reset_less]
def _insert_control(self, fragment, domain, lhs_masks):
stmts = []
for signal, start, stop in lhs_masks.chunks():
if signal.reset_less:
continue
if start == 0 and stop is None:
stmts.append(signal.eq(Const(signal.init, signal.shape())))
else:
stmts.append(signal[start:stop].eq(Const(signal.init, signal.shape())[start:stop]))
fragment.add_statements(domain, Switch(self.controls[domain], [(1, stmts, None)], src_loc=self.src_loc))


class EnableInserter(_ControlInserter):
def _insert_control(self, fragment, domain, signals):
def _insert_control(self, fragment, domain, _lhs_masks):
if domain in fragment.statements:
fragment.statements[domain] = _StatementList([Switch(
self.controls[domain],
Expand Down
2 changes: 1 addition & 1 deletion amaranth/sim/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class BaseSignalState:
curr = NotImplemented
next = NotImplemented

def update(self, value):
def update(self, value, mask=~0):
raise NotImplementedError # :nocov:


Expand Down
19 changes: 11 additions & 8 deletions amaranth/sim/_pyrtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from ..hdl import *
from ..hdl._ast import SignalSet, _StatementList, Property
from ..hdl._xfrm import ValueVisitor, StatementVisitor
from ..hdl._xfrm import ValueVisitor, StatementVisitor, LHSMaskCollector
from ..hdl._mem import MemoryInstance
from ._base import BaseProcess
from ._pyeval import value_to_string
Expand Down Expand Up @@ -487,19 +487,20 @@ def __call__(self, fragment):
for domain_name in domains:
domain_stmts = fragment.statements.get(domain_name, _StatementList())
domain_process = PyRTLProcess(is_comb=domain_name == "comb")
domain_signals = domain_stmts._lhs_signals()
lhs_masks = LHSMaskCollector()
lhs_masks.visit_stmt(domain_stmts)

if isinstance(fragment, MemoryInstance):
for port in fragment._read_ports:
if port._domain == domain_name:
domain_signals.update(port._data._lhs_signals())
lhs_masks.visit_value(port._data, ~0)

emitter = _PythonEmitter()
emitter.append(f"def run():")
emitter._level += 1

if domain_name == "comb":
for signal in domain_signals:
for (signal, _) in lhs_masks.masks():
signal_index = self.state.get_signal(signal)
self.state.slots[signal_index].is_comb = True
emitter.append(f"next_{signal_index} = {signal.init}")
Expand Down Expand Up @@ -533,7 +534,7 @@ def __call__(self, fragment):
if domain.async_reset and domain.rst is not None:
self.state.add_signal_waker(domain.rst, edge_waker(domain_process, 1))

for signal in domain_signals:
for (signal, _) in lhs_masks.masks():
signal_index = self.state.get_signal(signal)
emitter.append(f"next_{signal_index} = slots[{signal_index}].next")

Expand All @@ -546,7 +547,7 @@ def __call__(self, fragment):
emitter.append(f"if {rst}:")
with emitter.indent():
emitter.append("pass")
for signal in domain_signals:
for (signal, _) in lhs_masks.masks():
if not signal.reset_less:
signal_index = self.state.get_signal(signal)
emitter.append(f"next_{signal_index} = {signal.init}")
Expand Down Expand Up @@ -592,9 +593,11 @@ def __call__(self, fragment):

lhs(port._data)(data)

for signal in domain_signals:
for (signal, mask) in lhs_masks.masks():
if signal.shape().signed and (mask & 1 << (len(signal) - 1)):
mask |= -1 << len(signal)
signal_index = self.state.get_signal(signal)
emitter.append(f"slots[{signal_index}].update(next_{signal_index})")
emitter.append(f"slots[{signal_index}].update(next_{signal_index}, {mask})")

# There shouldn't be any exceptions raised by the generated code, but if there are
# (almost certainly due to a bug in the code generator), use this environment variable
Expand Down
3 changes: 2 additions & 1 deletion amaranth/sim/pysim.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,8 @@ def add_waker(self, waker):
assert waker not in self.wakers
self.wakers.append(waker)

def update(self, value):
def update(self, value, mask=~0):
value = (self.next & ~mask) | (value & mask)
if self.next != value:
self.next = value
self.pending.add(self)
Expand Down
63 changes: 63 additions & 0 deletions tests/test_hdl_xfrm.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ def setUp(self):
self.s1 = Signal()
self.s2 = Signal(init=1)
self.s3 = Signal(init=1, reset_less=True)
self.s4 = Signal(8, init=0x3a)
self.c1 = Signal()

def test_reset_default(self):
Expand Down Expand Up @@ -281,6 +282,40 @@ def test_reset_value(self):
)
""")

def test_reset_mask(self):
f = Fragment()
f.add_statements("sync", self.s4[2:4].eq(0))

f = ResetInserter(self.c1)(f)
self.assertRepr(f.statements["sync"], """
(
(eq (slice (sig s4) 2:4) (const 1'd0))
(switch (sig c1)
(case 1 (eq (slice (sig s4) 2:4) (slice (const 8'd58) 2:4)))
)
)
""")

f = Fragment()
f.add_statements("sync", self.s4[2:4].eq(0))
f.add_statements("sync", self.s4[3:5].eq(0))
f.add_statements("sync", self.s4[6:10].eq(0))

f = ResetInserter(self.c1)(f)
self.assertRepr(f.statements["sync"], """
(
(eq (slice (sig s4) 2:4) (const 1'd0))
(eq (slice (sig s4) 3:5) (const 1'd0))
(eq (slice (sig s4) 6:8) (const 1'd0))
(switch (sig c1)
(case 1
(eq (slice (sig s4) 2:5) (slice (const 8'd58) 2:5))
(eq (slice (sig s4) 6:8) (slice (const 8'd58) 6:8))
)
)
)
""")

def test_reset_less(self):
f = Fragment()
f.add_statements("sync", self.s3.eq(0))
Expand Down Expand Up @@ -423,3 +458,31 @@ def test_composition(self):
)
)
""")

class LHSMaskCollectorTestCase(FHDLTestCase):
def test_slice(self):
s = Signal(8)
lhs = LHSMaskCollector()
lhs.visit_value(s[2:5], ~0)
self.assertEqual(lhs.lhs[s], 0x1c)

def test_slice_slice(self):
s = Signal(8)
lhs = LHSMaskCollector()
lhs.visit_value(s[2:7][1:3], ~0)
self.assertEqual(lhs.lhs[s], 0x18)

def test_slice_concat(self):
s1 = Signal(8)
s2 = Signal(8)
lhs = LHSMaskCollector()
lhs.visit_value(Cat(s1, s2)[4:11], ~0)
self.assertEqual(lhs.lhs[s1], 0xf0)
self.assertEqual(lhs.lhs[s2], 0x07)

def test_slice_part(self):
s = Signal(8)
idx = Signal(8)
lhs = LHSMaskCollector()
lhs.visit_value(s.bit_select(idx, 5)[1:3], ~0)
self.assertEqual(lhs.lhs[s], 0xff)
22 changes: 22 additions & 0 deletions tests/test_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -1413,6 +1413,28 @@ async def testbench(ctx):
with self.assertSimulation(Module(), traces=[mem1, mem2, mem3]) as sim:
sim.add_testbench(testbench)

def test_multiple_modules(self):
m = Module()
m.submodules.m1 = m1 = Module()
m.submodules.m2 = m2 = Module()
a = Signal(8)
b = Signal(8)
m1.d.comb += b[0:2].eq(a[0:2])
m1.d.comb += b[4:6].eq(a[4:6])
m2.d.comb += b[2:4].eq(a[2:4])
m2.d.comb += b[6:8].eq(a[6:8])
with self.assertSimulation(m) as sim:
async def testbench(ctx):
ctx.set(a, 0)
self.assertEqual(ctx.get(b), 0)
ctx.set(a, 0x12)
self.assertEqual(ctx.get(b), 0x12)
ctx.set(a, 0x34)
self.assertEqual(ctx.get(b), 0x34)
ctx.set(a, 0xdb)
self.assertEqual(ctx.get(b), 0xdb)
sim.add_testbench(testbench)


class SimulatorTracesTestCase(FHDLTestCase):
def assertDef(self, traces, flat_traces):
Expand Down

0 comments on commit c78806f

Please sign in to comment.