Skip to content

Commit

Permalink
Make qubits a memoized property of Moment (#6894)
Browse files Browse the repository at this point in the history
Moment already has a _qubit_to_op dict. Maintaining a separate _qubits frozenset is inefficient when constructing circuits, as it needs copied each time an op is added to a moment.

In particular, for wide moments, this change seems to result in about a 3x speedup. On my laptop, creating a moment with X gates on 10_000 qubits takes 4s before this change, and 1.3s after.
  • Loading branch information
daxfohl authored Jan 14, 2025
1 parent 7f66b42 commit f596c43
Showing 1 changed file with 11 additions and 13 deletions.
24 changes: 11 additions & 13 deletions cirq-core/cirq/circuits/moment.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""A simplified time-slice of operations within a sequenced circuit."""

import itertools
from functools import cached_property
from types import NotImplementedType
from typing import (
AbstractSet,
Expand Down Expand Up @@ -113,7 +114,6 @@ def __init__(self, *contents: 'cirq.OP_TREE', _flatten_contents: bool = True) ->
raise ValueError(f'Overlapping operations: {self.operations}')
self._qubit_to_op[q] = op

self._qubits = frozenset(self._qubit_to_op.keys())
self._measurement_key_objs: Optional[FrozenSet['cirq.MeasurementKey']] = None
self._control_keys: Optional[FrozenSet['cirq.MeasurementKey']] = None

Expand All @@ -135,9 +135,9 @@ def from_ops(cls, *ops: 'cirq.Operation') -> 'cirq.Moment':
def operations(self) -> Tuple['cirq.Operation', ...]:
return self._operations

@property
@cached_property
def qubits(self) -> FrozenSet['cirq.Qid']:
return self._qubits
return frozenset(self._qubit_to_op)

def operates_on_single_qubit(self, qubit: 'cirq.Qid') -> bool:
"""Determines if the moment has operations touching the given qubit.
Expand All @@ -157,7 +157,7 @@ def operates_on(self, qubits: Iterable['cirq.Qid']) -> bool:
Returns:
Whether this moment has operations involving the qubits.
"""
return not self._qubits.isdisjoint(qubits)
return not self._qubit_to_op.keys().isdisjoint(qubits)

def operation_at(self, qubit: raw_types.Qid) -> Optional['cirq.Operation']:
"""Returns the operation on a certain qubit for the moment.
Expand Down Expand Up @@ -185,14 +185,13 @@ def with_operation(self, operation: 'cirq.Operation') -> 'cirq.Moment':
Raises:
ValueError: If the operation given overlaps a current operation in the moment.
"""
if any(q in self._qubits for q in operation.qubits):
if any(q in self._qubit_to_op for q in operation.qubits):
raise ValueError(f'Overlapping operations: {operation}')

# Use private variables to facilitate a quick copy.
m = Moment(_flatten_contents=False)
m._operations = self._operations + (operation,)
m._sorted_operations = None
m._qubits = self._qubits.union(operation.qubits)
m._qubit_to_op = {**self._qubit_to_op, **{q: operation for q in operation.qubits}}

m._measurement_key_objs = self._measurement_key_objs_().union(
Expand Down Expand Up @@ -222,14 +221,11 @@ def with_operations(self, *contents: 'cirq.OP_TREE') -> 'cirq.Moment':
m = Moment(_flatten_contents=False)
# Use private variables to facilitate a quick copy.
m._qubit_to_op = self._qubit_to_op.copy()
qubits = set(self._qubits)
for op in flattened_contents:
if any(q in qubits for q in op.qubits):
if any(q in m._qubit_to_op for q in op.qubits):
raise ValueError(f'Overlapping operations: {op}')
qubits.update(op.qubits)
for q in op.qubits:
m._qubit_to_op[q] = op
m._qubits = frozenset(qubits)

m._operations = self._operations + flattened_contents
m._sorted_operations = None
Expand Down Expand Up @@ -450,7 +446,9 @@ def expand_to(self, qubits: Iterable['cirq.Qid']) -> 'cirq.Moment':
@_compat.cached_method()
def _has_kraus_(self) -> bool:
"""Returns True if self has a Kraus representation and self uses <= 10 qubits."""
return all(protocols.has_kraus(op) for op in self.operations) and len(self.qubits) <= 10
return (
all(protocols.has_kraus(op) for op in self.operations) and len(self._qubit_to_op) <= 10
)

def _kraus_(self) -> Sequence[np.ndarray]:
r"""Returns Kraus representation of self.
Expand All @@ -475,7 +473,7 @@ def _kraus_(self) -> Sequence[np.ndarray]:
if not self._has_kraus_():
return NotImplemented

qubits = sorted(self.qubits)
qubits = sorted(self._qubit_to_op)
n = len(qubits)
if n < 1:
return (np.array([[1 + 0j]]),)
Expand Down Expand Up @@ -602,7 +600,7 @@ def to_text_diagram(
"""

# Figure out where to place everything.
qs = set(self.qubits) | set(extra_qubits)
qs = self._qubit_to_op.keys() | set(extra_qubits)
points = {xy_breakdown_func(q) for q in qs}
x_keys = sorted({pt[0] for pt in points}, key=_SortByValFallbackToType)
y_keys = sorted({pt[1] for pt in points}, key=_SortByValFallbackToType)
Expand Down

0 comments on commit f596c43

Please sign in to comment.