diff --git a/amaranth/sim/_async.py b/amaranth/sim/_async.py index dd38e1206..2f41dd3d5 100644 --- a/amaranth/sim/_async.py +++ b/amaranth/sim/_async.py @@ -749,10 +749,18 @@ def reset(self): self.critical = not self.background self.waits_on = None self.coroutine = self.constructor(self.context) + self.first_await = True def run(self): try: self.waits_on = self.coroutine.send(None) + # Special case to make combination logic replacement work correctly: ensure that + # a process looping over `changed()` always gets awakened at least once at time 0, + # to see the initial values. + if self.first_await and self.waits_on.initial_eligible(): + self.waits_on.compute_result() + self.waits_on = self.coroutine.send(None) + self.first_await = False except StopIteration: self.critical = False self.waits_on = None diff --git a/amaranth/sim/pysim.py b/amaranth/sim/pysim.py index 3c1275a77..2cc646e6e 100644 --- a/amaranth/sim/pysim.py +++ b/amaranth/sim/pysim.py @@ -555,7 +555,7 @@ def activate(self): else: self._broken = True - def run(self): + def compute_result(self): result = [] for trigger in self._combination._triggers: if isinstance(trigger, (SampleTrigger, ChangedTrigger)): @@ -570,12 +570,20 @@ def run(self): assert False # :nocov: self._result = tuple(result) + def run(self): + self.compute_result() self._combination._process.runnable = True self._combination._process.waits_on = None self._triggers_hit.clear() for waker, interval_fs in self._delay_wakers.items(): self._engine.state.set_delay_waker(interval_fs, waker) + def initial_eligible(self): + return not self._oneshot and any( + isinstance(trigger, ChangedTrigger) + for trigger in self._combination._triggers + ) + def __await__(self): self._result = None if self._broken: diff --git a/tests/test_sim.py b/tests/test_sim.py index b68168380..612068d65 100644 --- a/tests/test_sim.py +++ b/tests/test_sim.py @@ -1506,6 +1506,22 @@ def test_comb_clock_conflict(self): r"^Clock signal is already driven by combinational logic$"): sim.add_clock(1e-6) + def test_initial(self): + a = Signal(4, init=3) + m = Module() + sim = Simulator(m) + fired = 0 + + async def process(ctx): + nonlocal fired + async for val_a, in ctx.changed(a): + self.assertEqual(val_a, 3) + fired += 1 + + sim.add_process(process) + sim.run() + self.assertEqual(fired, 1) + def test_sample(self): m = Module() m.domains.sync = cd_sync = ClockDomain()