diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index 814dd659322..110856c2993 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -604,6 +604,7 @@ with_key_path, with_key_path_prefix, with_measurement_key_mapping, + with_rescoped_keys, ) from cirq.ion import ( diff --git a/cirq-core/cirq/circuits/circuit.py b/cirq-core/cirq/circuits/circuit.py index d61fb3b5be1..06a634e2291 100644 --- a/cirq-core/cirq/circuits/circuit.py +++ b/cirq-core/cirq/circuits/circuit.py @@ -913,6 +913,18 @@ def _with_key_path_prefix_(self, prefix: Tuple[str, ...]): [protocols.with_key_path_prefix(moment, prefix) for moment in self.moments] ) + def _with_rescoped_keys_( + self, + path: Tuple[str, ...], + bindable_keys: FrozenSet['cirq.MeasurementKey'], + ): + moments = [] + for moment in self.moments: + new_moment = protocols.with_rescoped_keys(moment, path, bindable_keys) + moments.append(new_moment) + bindable_keys |= protocols.measurement_key_objs(new_moment) + return self._with_sliced_moments(moments) + def _qid_shape_(self) -> Tuple[int, ...]: return self.qid_shape() @@ -1171,7 +1183,8 @@ def to_text_diagram_drawer( qubits = ops.QubitOrder.as_qubit_order(qubit_order).order_for(self.all_qubits()) cbits = tuple( sorted( - (key for op in self.all_operations() for key in protocols.control_keys(op)), key=str + set(key for op in self.all_operations() for key in protocols.control_keys(op)), + key=str, ) ) labels = qubits + cbits @@ -1524,6 +1537,10 @@ def factorize(self: CIRCUIT_TYPE) -> Iterable[CIRCUIT_TYPE]: self._with_sliced_moments([m[qubits] for m in self.moments]) for qubits in qubit_factors ) + def _control_keys_(self) -> FrozenSet[value.MeasurementKey]: + controls = frozenset(k for op in self.all_operations() for k in protocols.control_keys(op)) + return controls - protocols.measurement_key_objs(self) + def _overlap_collision_time( c1: Sequence['cirq.Moment'], c2: Sequence['cirq.Moment'], align: 'cirq.Alignment' diff --git a/cirq-core/cirq/circuits/circuit_operation.py b/cirq-core/cirq/circuits/circuit_operation.py index 5dc54b68dae..15c9ef4af10 100644 --- a/cirq-core/cirq/circuits/circuit_operation.py +++ b/cirq-core/cirq/circuits/circuit_operation.py @@ -22,11 +22,12 @@ AbstractSet, Callable, Dict, + FrozenSet, + Iterator, List, Optional, Tuple, Union, - Iterator, ) import dataclasses @@ -77,11 +78,18 @@ class CircuitOperation(ops.Operation): The keys and values should be unindexed (i.e. without repetition_ids). The values cannot contain the `MEASUREMENT_KEY_SEPARATOR`. param_resolver: Resolved values for parameters in the circuit. - parent_path: A tuple of identifiers for any parent CircuitOperations containing this one. repetition_ids: List of identifiers for each repetition of the CircuitOperation. If populated, the length should be equal to the repetitions. If not populated and abs(`repetitions`) > 1, it is initialized to strings for numbers in `range(repetitions)`. + parent_path: A tuple of identifiers for any parent CircuitOperations + containing this one. + extern_keys: The set of measurement keys defined at extern scope. The + values here are used by decomposition and simulation routines to + cache which external measurement keys exist as possible binding + targets for unbound `ClassicallyControlledOperation` keys. This + field is not intended to be set or changed manually, and should be + empty in circuits that aren't in the middle of decomposition. """ _hash: Optional[int] = dataclasses.field(default=None, init=False) @@ -96,6 +104,7 @@ class CircuitOperation(ops.Operation): param_resolver: study.ParamResolver = study.ParamResolver() repetition_ids: Optional[List[str]] = dataclasses.field(default=None) parent_path: Tuple[str, ...] = dataclasses.field(default_factory=tuple) + extern_keys: FrozenSet['cirq.MeasurementKey'] = dataclasses.field(default_factory=frozenset) def __post_init__(self): if not isinstance(self.circuit, circuits.FrozenCircuit): @@ -184,9 +193,7 @@ def _measurement_key_objs_(self) -> AbstractSet[value.MeasurementKey]: for repetition_id in self.repetition_ids for key in circuit_keys } - circuit_keys = { - protocols.with_key_path_prefix(key, self.parent_path) for key in circuit_keys - } + circuit_keys = {key.with_key_path_prefix(*self.parent_path) for key in circuit_keys} object.__setattr__( self, '_cached_measurement_key_objs', @@ -200,6 +207,11 @@ def _measurement_key_objs_(self) -> AbstractSet[value.MeasurementKey]: def _measurement_key_names_(self) -> AbstractSet[str]: return {str(key) for key in self._measurement_key_objs_()} + def _control_keys_(self) -> AbstractSet[value.MeasurementKey]: + if not protocols.control_keys(self.circuit): + return frozenset() + return protocols.control_keys(self.mapped_circuit()) + def _parameter_names_(self) -> AbstractSet[str]: return { name @@ -222,26 +234,28 @@ def mapped_circuit(self, deep: bool = False) -> 'cirq.Circuit': like `cirq.decompose(self)`, but preserving moment structure. """ circuit = self.circuit.unfreeze() - circuit = circuit.transform_qubits(lambda q: self.qubit_map.get(q, q)) + if self.qubit_map: + circuit = circuit.transform_qubits(lambda q: self.qubit_map.get(q, q)) if self.repetitions < 0: circuit = circuit ** -1 - has_measurements = protocols.is_measurement(circuit) - if has_measurements: + if self.measurement_key_map: circuit = protocols.with_measurement_key_mapping(circuit, self.measurement_key_map) - circuit = protocols.resolve_parameters(circuit, self.param_resolver, recursive=False) - if deep: - circuit = circuit.map_operations( - lambda op: op.mapped_circuit(deep=True) if isinstance(op, CircuitOperation) else op - ) + if self.param_resolver: + circuit = protocols.resolve_parameters(circuit, self.param_resolver, recursive=False) if self.repetition_ids: - if not has_measurements: + if not protocols.is_measurement(circuit): circuit = circuit * abs(self.repetitions) else: circuit = circuits.Circuit( - protocols.with_key_path_prefix(circuit, (rep,)) for rep in self.repetition_ids + protocols.with_rescoped_keys(circuit, (rep,)) for rep in self.repetition_ids ) - if self.parent_path: - circuit = protocols.with_key_path_prefix(circuit, self.parent_path) + circuit = protocols.with_rescoped_keys( + circuit, self.parent_path, bindable_keys=self.extern_keys + ) + if deep: + circuit = circuit.map_operations( + lambda op: op.mapped_circuit(deep=True) if isinstance(op, CircuitOperation) else op + ) return circuit def mapped_op(self, deep: bool = False) -> 'cirq.CircuitOperation': @@ -430,6 +444,21 @@ def _with_key_path_(self, path: Tuple[str, ...]): def _with_key_path_prefix_(self, prefix: Tuple[str, ...]): return dataclasses.replace(self, parent_path=prefix + self.parent_path) + def _with_rescoped_keys_( + self, + path: Tuple[str, ...], + bindable_keys: FrozenSet['cirq.MeasurementKey'], + ): + # The following line prevents binding to measurement keys in previous repeated subcircuits + # "just because their repetition ids matched". If we eventually decide to change that + # requirement and allow binding across subcircuits (possibly conditionally upon the key or + # the subcircuit having some 'allow_cross_circuit_binding' field set), this is the line to + # change or remove. + bindable_keys = frozenset(k for k in bindable_keys if len(k.path) <= len(path)) + bindable_keys |= {k.with_key_path_prefix(*path) for k in self.extern_keys} + path += self.parent_path + return dataclasses.replace(self, parent_path=path, extern_keys=bindable_keys) + def with_key_path(self, path: Tuple[str, ...]): return self._with_key_path_(path) @@ -518,14 +547,16 @@ def with_measurement_key_mapping(self, key_map: Dict[str, str]) -> 'CircuitOpera keys than this operation. """ new_map = {} - for k_obj in self.circuit.all_measurement_key_objs(): + for k_obj in protocols.measurement_keys_touched(self.circuit): k = k_obj.name k_new = self.measurement_key_map.get(k, k) k_new = key_map.get(k_new, k_new) if k_new != k: new_map[k] = k_new new_op = self.replace(measurement_key_map=new_map) - if len(new_op._measurement_key_objs_()) != len(self._measurement_key_objs_()): + if len(protocols.measurement_keys_touched(new_op)) != len( + protocols.measurement_keys_touched(self) + ): raise ValueError( f'Collision in measurement key map composition. Original map:\n' f'{self.measurement_key_map}\nApplied changes: {key_map}' diff --git a/cirq-core/cirq/circuits/circuit_operation_test.py b/cirq-core/cirq/circuits/circuit_operation_test.py index b2417e84b52..69b8a342407 100644 --- a/cirq-core/cirq/circuits/circuit_operation_test.py +++ b/cirq-core/cirq/circuits/circuit_operation_test.py @@ -647,8 +647,16 @@ def test_decompose_nested(): op2 = cirq.CircuitOperation(circuit2) circuit3 = cirq.FrozenCircuit( op2.with_params({exp1: exp_half}), - op2.with_params({exp1: exp_one}), - op2.with_params({exp1: exp_two}), + op2.with_params({exp1: exp_one}) + .with_measurement_key_mapping({'ma': 'ma1'}) + .with_measurement_key_mapping({'mb': 'mb1'}) + .with_measurement_key_mapping({'mc': 'mc1'}) + .with_measurement_key_mapping({'md': 'md1'}), + op2.with_params({exp1: exp_two}) + .with_measurement_key_mapping({'ma': 'ma2'}) + .with_measurement_key_mapping({'mb': 'mb2'}) + .with_measurement_key_mapping({'mc': 'mc2'}) + .with_measurement_key_mapping({'md': 'md2'}), ) op3 = cirq.CircuitOperation(circuit3) @@ -656,8 +664,16 @@ def test_decompose_nested(): expected_circuit1 = cirq.Circuit( op2.with_params({exp1: 0.5, exp_half: 0.5, exp_one: 1.0, exp_two: 2.0}), - op2.with_params({exp1: 1.0, exp_half: 0.5, exp_one: 1.0, exp_two: 2.0}), - op2.with_params({exp1: 2.0, exp_half: 0.5, exp_one: 1.0, exp_two: 2.0}), + op2.with_params({exp1: 1.0, exp_half: 0.5, exp_one: 1.0, exp_two: 2.0}) + .with_measurement_key_mapping({'ma': 'ma1'}) + .with_measurement_key_mapping({'mb': 'mb1'}) + .with_measurement_key_mapping({'mc': 'mc1'}) + .with_measurement_key_mapping({'md': 'md1'}), + op2.with_params({exp1: 2.0, exp_half: 0.5, exp_one: 1.0, exp_two: 2.0}) + .with_measurement_key_mapping({'ma': 'ma2'}) + .with_measurement_key_mapping({'mb': 'mb2'}) + .with_measurement_key_mapping({'mc': 'mc2'}) + .with_measurement_key_mapping({'md': 'md2'}), ) result_ops1 = cirq.decompose_once(final_op) @@ -673,21 +689,21 @@ def test_decompose_nested(): cirq.X(d) ** 0.5, cirq.measure(d, key='md'), cirq.X(a) ** 1.0, - cirq.measure(a, key='ma'), + cirq.measure(a, key='ma1'), cirq.X(b) ** 1.0, - cirq.measure(b, key='mb'), + cirq.measure(b, key='mb1'), cirq.X(c) ** 1.0, - cirq.measure(c, key='mc'), + cirq.measure(c, key='mc1'), cirq.X(d) ** 1.0, - cirq.measure(d, key='md'), + cirq.measure(d, key='md1'), cirq.X(a) ** 2.0, - cirq.measure(a, key='ma'), + cirq.measure(a, key='ma2'), cirq.X(b) ** 2.0, - cirq.measure(b, key='mb'), + cirq.measure(b, key='mb2'), cirq.X(c) ** 2.0, - cirq.measure(c, key='mc'), + cirq.measure(c, key='mc2'), cirq.X(d) ** 2.0, - cirq.measure(d, key='md'), + cirq.measure(d, key='md2'), ) assert cirq.Circuit(cirq.decompose(final_op)) == expected_circuit # Verify that mapped_circuit gives the same operations. @@ -816,4 +832,24 @@ def test_mapped_circuit_keeps_keys_under_parent_path(): assert cirq.measurement_key_names(op2.mapped_circuit()) == {'X:A', 'X:B', 'X:C', 'X:D'} +def test_keys_conflict_no_repetitions(): + q = cirq.LineQubit(0) + op1 = cirq.CircuitOperation( + cirq.FrozenCircuit( + cirq.measure(q, key='A'), + ) + ) + op2 = cirq.CircuitOperation(cirq.FrozenCircuit(op1, op1)) + with pytest.raises(ValueError, match='Conflicting measurement keys found: A'): + _ = op2.mapped_circuit(deep=True) + + +def test_keys_conflict_locally(): + q = cirq.LineQubit(0) + op1 = cirq.measure(q, key='A') + op2 = cirq.CircuitOperation(cirq.FrozenCircuit(op1, op1)) + with pytest.raises(ValueError, match='Conflicting measurement keys found: A'): + _ = op2.mapped_circuit() + + # TODO: Operation has a "gate" property. What is this for a CircuitOperation? diff --git a/cirq-core/cirq/circuits/circuit_test.py b/cirq-core/cirq/circuits/circuit_test.py index 1d1994cb14c..81ff00e6292 100644 --- a/cirq-core/cirq/circuits/circuit_test.py +++ b/cirq-core/cirq/circuits/circuit_test.py @@ -267,6 +267,72 @@ def test_append_multiple(): ) +def test_append_control_key_subcircuit(): + q0, q1 = cirq.LineQubit.range(2) + + c = cirq.Circuit() + c.append(cirq.measure(q0, key='a')) + c.append( + cirq.CircuitOperation( + cirq.FrozenCircuit(cirq.ClassicallyControlledOperation(cirq.X(q1), 'a')) + ) + ) + assert len(c) == 2 + + c = cirq.Circuit() + c.append(cirq.measure(q0, key='a')) + c.append( + cirq.CircuitOperation( + cirq.FrozenCircuit(cirq.ClassicallyControlledOperation(cirq.X(q1), 'b')) + ) + ) + assert len(c) == 1 + + c = cirq.Circuit() + c.append(cirq.measure(q0, key='a')) + c.append( + cirq.CircuitOperation( + cirq.FrozenCircuit(cirq.ClassicallyControlledOperation(cirq.X(q1), 'b')) + ).with_measurement_key_mapping({'b': 'a'}) + ) + assert len(c) == 2 + + c = cirq.Circuit() + c.append(cirq.CircuitOperation(cirq.FrozenCircuit(cirq.measure(q0, key='a')))) + c.append( + cirq.CircuitOperation( + cirq.FrozenCircuit(cirq.ClassicallyControlledOperation(cirq.X(q1), 'b')) + ).with_measurement_key_mapping({'b': 'a'}) + ) + assert len(c) == 2 + + c = cirq.Circuit() + c.append( + cirq.CircuitOperation( + cirq.FrozenCircuit(cirq.measure(q0, key='a')) + ).with_measurement_key_mapping({'a': 'c'}) + ) + c.append( + cirq.CircuitOperation( + cirq.FrozenCircuit(cirq.ClassicallyControlledOperation(cirq.X(q1), 'b')) + ).with_measurement_key_mapping({'b': 'c'}) + ) + assert len(c) == 2 + + c = cirq.Circuit() + c.append( + cirq.CircuitOperation( + cirq.FrozenCircuit(cirq.measure(q0, key='a')) + ).with_measurement_key_mapping({'a': 'b'}) + ) + c.append( + cirq.CircuitOperation( + cirq.FrozenCircuit(cirq.ClassicallyControlledOperation(cirq.X(q1), 'b')) + ).with_measurement_key_mapping({'b': 'a'}) + ) + assert len(c) == 1 + + def test_append_moments(): a = cirq.NamedQubit('a') b = cirq.NamedQubit('b') diff --git a/cirq-core/cirq/circuits/frozen_circuit.py b/cirq-core/cirq/circuits/frozen_circuit.py index 7deefe2f7e7..50f2fe9b4e2 100644 --- a/cirq-core/cirq/circuits/frozen_circuit.py +++ b/cirq-core/cirq/circuits/frozen_circuit.py @@ -76,6 +76,7 @@ def __init__( self._has_measurements: Optional[bool] = None self._all_measurement_key_objs: Optional[AbstractSet[value.MeasurementKey]] = None self._are_all_measurements_terminal: Optional[bool] = None + self._control_keys: Optional[FrozenSet[value.MeasurementKey]] = None @property def moments(self) -> Sequence['cirq.Moment']: @@ -133,6 +134,11 @@ def all_measurement_key_objs(self) -> AbstractSet[value.MeasurementKey]: def _measurement_key_objs_(self) -> AbstractSet[value.MeasurementKey]: return self.all_measurement_key_objs() + def _control_keys_(self) -> FrozenSet[value.MeasurementKey]: + if self._control_keys is None: + self._control_keys = super()._control_keys_() + return self._control_keys + def are_all_measurements_terminal(self) -> bool: if self._are_all_measurements_terminal is None: self._are_all_measurements_terminal = super().are_all_measurements_terminal() diff --git a/cirq-core/cirq/ops/classically_controlled_operation.py b/cirq-core/cirq/ops/classically_controlled_operation.py index 94114e742f1..d4da88fdcc2 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation.py +++ b/cirq-core/cirq/ops/classically_controlled_operation.py @@ -169,13 +169,32 @@ def not_zero(measurement): def _with_measurement_key_mapping_( self, key_map: Dict[str, str] ) -> 'ClassicallyControlledOperation': - keys = [protocols.with_measurement_key_mapping(k, key_map) for k in self._control_keys] - return self._sub_operation.with_classical_controls(*keys) + sub_operation = protocols.with_measurement_key_mapping(self._sub_operation, key_map) + sub_operation = self._sub_operation if sub_operation is NotImplemented else sub_operation + return sub_operation.with_classical_controls( + *[protocols.with_measurement_key_mapping(k, key_map) for k in self._control_keys] + ) def _with_key_path_prefix_(self, path: Tuple[str, ...]) -> 'ClassicallyControlledOperation': keys = [protocols.with_key_path_prefix(k, path) for k in self._control_keys] return self._sub_operation.with_classical_controls(*keys) + def _with_rescoped_keys_( + self, + path: Tuple[str, ...], + bindable_keys: FrozenSet['cirq.MeasurementKey'], + ) -> 'ClassicallyControlledOperation': + def map_key(key: value.MeasurementKey) -> value.MeasurementKey: + for i in range(len(path) + 1): + back_path = path[: len(path) - i] + new_key = key.with_key_path_prefix(*back_path) + if new_key in bindable_keys: + return new_key + return key + + sub_operation = protocols.with_rescoped_keys(self._sub_operation, path, bindable_keys) + return sub_operation.with_classical_controls(*[map_key(k) for k in self._control_keys]) + def _control_keys_(self) -> FrozenSet[value.MeasurementKey]: return frozenset(self._control_keys).union(protocols.control_keys(self._sub_operation)) diff --git a/cirq-core/cirq/ops/classically_controlled_operation_test.py b/cirq-core/cirq/ops/classically_controlled_operation_test.py index ff46dccb5fb..a9896dbed0d 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation_test.py +++ b/cirq-core/cirq/ops/classically_controlled_operation_test.py @@ -390,6 +390,215 @@ def test_str(): assert str(op) == 'X(0).with_classical_controls(a)' +def test_scope_local(): + q = cirq.LineQubit(0) + inner = cirq.Circuit( + cirq.measure(q, key='a'), + cirq.X(q).with_classical_controls('a'), + ) + middle = cirq.Circuit(cirq.CircuitOperation(inner.freeze(), repetitions=2)) + outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2) + circuit = outer_subcircuit.mapped_circuit(deep=True) + internal_control_keys = [ + str(condition) for op in circuit.all_operations() for condition in cirq.control_keys(op) + ] + assert internal_control_keys == ['0:0:a', '0:1:a', '1:0:a', '1:1:a'] + assert not cirq.control_keys(outer_subcircuit) + assert not cirq.control_keys(circuit) + cirq.testing.assert_has_diagram( + cirq.Circuit(outer_subcircuit), + """ + [ [ 0: ───M───X─── ] ] +0: ───[ 0: ───[ ║ ║ ]──────────── ]──────────── + [ [ a: ═══@═══^═══ ](loops=2) ](loops=2) +""", + use_unicode_characters=True, + ) + cirq.testing.assert_has_diagram( + circuit, + """ +0: ───────M───X───M───X───M───X───M───X─── + ║ ║ ║ ║ ║ ║ ║ ║ +0:0:a: ═══@═══^═══╬═══╬═══╬═══╬═══╬═══╬═══ + ║ ║ ║ ║ ║ ║ +0:1:a: ═══════════@═══^═══╬═══╬═══╬═══╬═══ + ║ ║ ║ ║ +1:0:a: ═══════════════════@═══^═══╬═══╬═══ + ║ ║ +1:1:a: ═══════════════════════════@═══^═══ +""", + use_unicode_characters=True, + ) + assert circuit == cirq.Circuit(cirq.decompose(outer_subcircuit)) + + +def test_scope_extern(): + q = cirq.LineQubit(0) + inner = cirq.Circuit( + cirq.measure(q, key='a'), + cirq.X(q).with_classical_controls('b'), + ) + middle = cirq.Circuit( + cirq.measure(q, key=cirq.MeasurementKey('b')), + cirq.CircuitOperation(inner.freeze(), repetitions=2), + ) + outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2) + circuit = outer_subcircuit.mapped_circuit(deep=True) + internal_control_keys = [ + str(condition) for op in circuit.all_operations() for condition in cirq.control_keys(op) + ] + assert internal_control_keys == ['0:b', '0:b', '1:b', '1:b'] + assert not cirq.control_keys(outer_subcircuit) + assert not cirq.control_keys(circuit) + cirq.testing.assert_has_diagram( + cirq.Circuit(outer_subcircuit), + """ + [ [ 0: ───M('a')───X─── ] ] + [ 0: ───M───[ ║ ]──────────── ] +0: ───[ ║ [ b: ════════════^═══ ](loops=2) ]──────────── + [ ║ ║ ] + [ b: ═══@═══╩══════════════════════════════════ ](loops=2) +""", + use_unicode_characters=True, + ) + cirq.testing.assert_has_diagram( + circuit, + """ +0: ─────M───M('0:0:a')───X───M('0:1:a')───X───M───M('1:0:a')───X───M('1:1:a')───X─── + ║ ║ ║ ║ ║ ║ +0:b: ═══@════════════════^════════════════^═══╬════════════════╬════════════════╬═══ + ║ ║ ║ +1:b: ═════════════════════════════════════════@════════════════^════════════════^═══ +""", + use_unicode_characters=True, + ) + assert circuit == cirq.Circuit(cirq.decompose(outer_subcircuit)) + + +def test_scope_extern_wrapping_with_non_repeating_subcircuits(): + def wrap(*ops): + return cirq.CircuitOperation(cirq.FrozenCircuit(*ops)) + + def wrap_frozen(*ops): + return cirq.FrozenCircuit(wrap(*ops)) + + q = cirq.LineQubit(0) + inner = wrap_frozen( + wrap(cirq.measure(q, key='a')), + wrap(cirq.X(q).with_classical_controls('b')), + ) + middle = wrap_frozen( + wrap(cirq.measure(q, key=cirq.MeasurementKey('b'))), + wrap(cirq.CircuitOperation(inner, repetitions=2)), + ) + outer_subcircuit = cirq.CircuitOperation(middle, repetitions=2) + circuit = outer_subcircuit.mapped_circuit(deep=True) + internal_control_keys = [ + str(condition) for op in circuit.all_operations() for condition in cirq.control_keys(op) + ] + assert internal_control_keys == ['0:b', '0:b', '1:b', '1:b'] + assert not cirq.control_keys(outer_subcircuit) + assert not cirq.control_keys(circuit) + cirq.testing.assert_has_diagram( + circuit, + """ +0: ─────M───M('0:0:a')───X───M('0:1:a')───X───M───M('1:0:a')───X───M('1:1:a')───X─── + ║ ║ ║ ║ ║ ║ +0:b: ═══@════════════════^════════════════^═══╬════════════════╬════════════════╬═══ + ║ ║ ║ +1:b: ═════════════════════════════════════════@════════════════^════════════════^═══ +""", + use_unicode_characters=True, + ) + assert circuit == cirq.Circuit(cirq.decompose(outer_subcircuit)) + + +def test_scope_root(): + q = cirq.LineQubit(0) + inner = cirq.Circuit( + cirq.measure(q, key='a'), + cirq.X(q).with_classical_controls('b'), + ) + middle = cirq.Circuit( + cirq.measure(q, key=cirq.MeasurementKey('c')), + cirq.CircuitOperation(inner.freeze(), repetitions=2), + ) + outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2) + circuit = outer_subcircuit.mapped_circuit(deep=True) + internal_control_keys = [ + str(condition) for op in circuit.all_operations() for condition in cirq.control_keys(op) + ] + assert internal_control_keys == ['b', 'b', 'b', 'b'] + assert cirq.control_keys(outer_subcircuit) == {cirq.MeasurementKey('b')} + assert cirq.control_keys(circuit) == {cirq.MeasurementKey('b')} + cirq.testing.assert_has_diagram( + cirq.Circuit(outer_subcircuit), + """ + [ [ 0: ───M('a')───X─── ] ] + [ 0: ───M('c')───[ ║ ]──────────── ] +0: ───[ [ b: ════════════^═══ ](loops=2) ]──────────── + [ ║ ] + [ b: ════════════╩══════════════════════════════════ ](loops=2) + ║ +b: ═══╩═════════════════════════════════════════════════════════════════ +""", + use_unicode_characters=True, + ) + cirq.testing.assert_has_diagram( + circuit, + """ +0: ───M('0:c')───M('0:0:a')───X───M('0:1:a')───X───M('1:c')───M('1:0:a')───X───M('1:1:a')───X─── + ║ ║ ║ ║ +b: ═══════════════════════════^════════════════^═══════════════════════════^════════════════^═══ +""", + use_unicode_characters=True, + ) + assert circuit == cirq.Circuit(cirq.decompose(outer_subcircuit)) + + +def test_scope_extern_mismatch(): + q = cirq.LineQubit(0) + inner = cirq.Circuit( + cirq.measure(q, key='a'), + cirq.X(q).with_classical_controls('b'), + ) + middle = cirq.Circuit( + cirq.measure(q, key=cirq.MeasurementKey('b', ('0',))), + cirq.CircuitOperation(inner.freeze(), repetitions=2), + ) + outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2) + circuit = outer_subcircuit.mapped_circuit(deep=True) + internal_control_keys = [ + str(condition) for op in circuit.all_operations() for condition in cirq.control_keys(op) + ] + assert internal_control_keys == ['b', 'b', 'b', 'b'] + assert cirq.control_keys(outer_subcircuit) == {cirq.MeasurementKey('b')} + assert cirq.control_keys(circuit) == {cirq.MeasurementKey('b')} + cirq.testing.assert_has_diagram( + cirq.Circuit(outer_subcircuit), + """ + [ [ 0: ───M('a')───X─── ] ] + [ 0: ───M('0:b')───[ ║ ]──────────── ] +0: ───[ [ b: ════════════^═══ ](loops=2) ]──────────── + [ ║ ] + [ b: ══════════════╩══════════════════════════════════ ](loops=2) + ║ +b: ═══╩═══════════════════════════════════════════════════════════════════ +""", + use_unicode_characters=True, + ) + cirq.testing.assert_has_diagram( + circuit, + """ +0: ───M('0:0:b')───M('0:0:a')───X───M('0:1:a')───X───M('1:0:b')───M('1:0:a')───X───M('1:1:a')───X─── + ║ ║ ║ ║ +b: ═════════════════════════════^════════════════^═════════════════════════════^════════════════^═══ +""", + use_unicode_characters=True, + ) + assert circuit == cirq.Circuit(cirq.decompose(outer_subcircuit)) + + def test_repr(): q0 = cirq.LineQubit(0) op = cirq.X(q0).with_classical_controls('a') @@ -416,3 +625,47 @@ def test_unmeasured_condition(): ), ): _ = cirq.Simulator().simulate(bad_circuit) + + +def test_layered_circuit_operations_with_controls_in_between(): + q = cirq.LineQubit(0) + outer_subcircuit = cirq.CircuitOperation( + cirq.Circuit( + cirq.CircuitOperation( + cirq.FrozenCircuit( + cirq.X(q), + cirq.Y(q), + ) + ).with_classical_controls('m') + ).freeze() + ) + circuit = outer_subcircuit.mapped_circuit(deep=True) + cirq.testing.assert_has_diagram( + cirq.Circuit(outer_subcircuit), + """ + [ 0: ───[ 0: ───X───Y─── ].with_classical_controls(m)─── ] +0: ───[ ║ ]─── + [ m: ═══╩═══════════════════════════════════════════════ ] + ║ +m: ═══╩════════════════════════════════════════════════════════════ +""", + use_unicode_characters=True, + ) + cirq.testing.assert_has_diagram( + circuit, + """ +0: ───[ 0: ───X───Y─── ].with_classical_controls(m)─── + ║ +m: ═══╩═══════════════════════════════════════════════ +""", + use_unicode_characters=True, + ) + cirq.testing.assert_has_diagram( + cirq.Circuit(cirq.decompose(outer_subcircuit)), + """ +0: ───X───Y─── + ║ ║ +m: ═══^═══^═══ +""", + use_unicode_characters=True, + ) diff --git a/cirq-core/cirq/ops/gate_operation.py b/cirq-core/cirq/ops/gate_operation.py index f981fc015aa..706d91d2f0c 100644 --- a/cirq-core/cirq/ops/gate_operation.py +++ b/cirq-core/cirq/ops/gate_operation.py @@ -108,6 +108,16 @@ def _with_key_path_prefix_(self, prefix: Tuple[str, ...]): return self return new_gate.on(*self.qubits) + def _with_rescoped_keys_( + self, + path: Tuple[str, ...], + bindable_keys: FrozenSet['cirq.MeasurementKey'], + ): + new_gate = protocols.with_rescoped_keys(self.gate, path, bindable_keys) + if new_gate is self.gate: + return self + return new_gate.on(*self.qubits) + def __repr__(self): if hasattr(self.gate, '_op_repr_'): result = self.gate._op_repr_(self.qubits) diff --git a/cirq-core/cirq/ops/kraus_channel.py b/cirq-core/cirq/ops/kraus_channel.py index 7b425e9f69e..402606e61c5 100644 --- a/cirq-core/cirq/ops/kraus_channel.py +++ b/cirq-core/cirq/ops/kraus_channel.py @@ -1,5 +1,5 @@ # pylint: disable=wrong-or-nonexistent-copyright-notice -from typing import Any, Dict, Iterable, Tuple, Union +from typing import Any, Dict, FrozenSet, Iterable, Tuple, Union import numpy as np from cirq import linalg, protocols, value @@ -96,6 +96,16 @@ def _with_key_path_prefix_(self, prefix: Tuple[str, ...]): kraus_ops=self._kraus_ops, key=protocols.with_key_path_prefix(self._key, prefix) ) + def _with_rescoped_keys_( + self, + path: Tuple[str, ...], + bindable_keys: FrozenSet[value.MeasurementKey], + ): + return KrausChannel( + kraus_ops=self._kraus_ops, + key=protocols.with_rescoped_keys(self._key, path, bindable_keys), + ) + def __str__(self): if self._key is not None: return f'KrausChannel({self._kraus_ops}, key={self._key})' diff --git a/cirq-core/cirq/ops/measurement_gate.py b/cirq-core/cirq/ops/measurement_gate.py index ffde4399ae0..58eecea0d3f 100644 --- a/cirq-core/cirq/ops/measurement_gate.py +++ b/cirq-core/cirq/ops/measurement_gate.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Iterable, Optional, Tuple, Sequence, TYPE_CHECKING, Union +from typing import Any, Dict, FrozenSet, Iterable, Optional, Tuple, Sequence, TYPE_CHECKING, Union import numpy as np @@ -98,6 +98,13 @@ def _with_key_path_(self, path: Tuple[str, ...]): def _with_key_path_prefix_(self, prefix: Tuple[str, ...]): return self.with_key(self.mkey._with_key_path_prefix_(prefix)) + def _with_rescoped_keys_( + self, + path: Tuple[str, ...], + bindable_keys: FrozenSet['cirq.MeasurementKey'], + ): + return self.with_key(protocols.with_rescoped_keys(self.mkey, path, bindable_keys)) + def _with_measurement_key_mapping_(self, key_map: Dict[str, str]): return self.with_key(protocols.with_measurement_key_mapping(self.mkey, key_map)) diff --git a/cirq-core/cirq/ops/mixed_unitary_channel.py b/cirq-core/cirq/ops/mixed_unitary_channel.py index 6eb9fb4ede4..eb03e78f5d6 100644 --- a/cirq-core/cirq/ops/mixed_unitary_channel.py +++ b/cirq-core/cirq/ops/mixed_unitary_channel.py @@ -1,5 +1,5 @@ # pylint: disable=wrong-or-nonexistent-copyright-notice -from typing import Any, Dict, Iterable, Tuple, Union +from typing import Any, Dict, FrozenSet, Iterable, Tuple, Union import numpy as np from cirq import linalg, protocols, value @@ -107,6 +107,16 @@ def _with_key_path_prefix_(self, prefix: Tuple[str, ...]): mixture=self._mixture, key=protocols.with_key_path_prefix(self._key, prefix) ) + def _with_rescoped_keys_( + self, + path: Tuple[str, ...], + bindable_keys: FrozenSet[value.MeasurementKey], + ): + return MixedUnitaryChannel( + mixture=self._mixture, + key=protocols.with_rescoped_keys(self._key, path, bindable_keys), + ) + def __str__(self): if self._key is not None: return f'MixedUnitaryChannel({self._mixture}, key={self._key})' diff --git a/cirq-core/cirq/ops/moment.py b/cirq-core/cirq/ops/moment.py index 69f1e22a5fb..92e4de47bef 100644 --- a/cirq-core/cirq/ops/moment.py +++ b/cirq-core/cirq/ops/moment.py @@ -254,6 +254,15 @@ def _with_key_path_prefix_(self, prefix: Tuple[str, ...]): for op in self.operations ) + def _with_rescoped_keys_( + self, + path: Tuple[str, ...], + bindable_keys: FrozenSet['cirq.MeasurementKey'], + ): + return Moment( + protocols.with_rescoped_keys(op, path, bindable_keys) for op in self.operations + ) + def __copy__(self): return type(self)(self.operations) diff --git a/cirq-core/cirq/ops/pauli_measurement_gate.py b/cirq-core/cirq/ops/pauli_measurement_gate.py index 43aa751f9a4..c383ff20030 100644 --- a/cirq-core/cirq/ops/pauli_measurement_gate.py +++ b/cirq-core/cirq/ops/pauli_measurement_gate.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Iterable, Tuple, Sequence, TYPE_CHECKING, Union - +from typing import Any, Dict, FrozenSet, Iterable, Tuple, Sequence, TYPE_CHECKING, Union from cirq import protocols, value from cirq.ops import ( @@ -85,6 +84,13 @@ def _with_key_path_(self, path: Tuple[str, ...]) -> 'PauliMeasurementGate': def _with_key_path_prefix_(self, prefix: Tuple[str, ...]) -> 'PauliMeasurementGate': return self.with_key(self.mkey._with_key_path_prefix_(prefix)) + def _with_rescoped_keys_( + self, + path: Tuple[str, ...], + bindable_keys: FrozenSet['cirq.MeasurementKey'], + ) -> 'PauliMeasurementGate': + return self.with_key(protocols.with_rescoped_keys(self.mkey, path, bindable_keys)) + def _with_measurement_key_mapping_(self, key_map: Dict[str, str]) -> 'PauliMeasurementGate': return self.with_key(protocols.with_measurement_key_mapping(self.mkey, key_map)) diff --git a/cirq-core/cirq/protocols/__init__.py b/cirq-core/cirq/protocols/__init__.py index 3595ca5107b..f51d36370e3 100644 --- a/cirq-core/cirq/protocols/__init__.py +++ b/cirq-core/cirq/protocols/__init__.py @@ -108,6 +108,7 @@ with_key_path, with_key_path_prefix, with_measurement_key_mapping, + with_rescoped_keys, SupportsMeasurementKey, ) from cirq.protocols.mixture_protocol import ( diff --git a/cirq-core/cirq/protocols/apply_mixture_protocol.py b/cirq-core/cirq/protocols/apply_mixture_protocol.py index 5db8744acac..c6f4b915fef 100644 --- a/cirq-core/cirq/protocols/apply_mixture_protocol.py +++ b/cirq-core/cirq/protocols/apply_mixture_protocol.py @@ -267,7 +267,6 @@ def err_str(buf_num_str): # Don't know how to apply mixture. Fallback to specified default behavior. # (STEP D) if default is not RaiseTypeErrorIfNotProvided: - print('HERE!') return default raise TypeError( "object of type '{}' has no _apply_mixture_, _apply_unitary_, " diff --git a/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.repr b/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.repr index bfc25256cff..bbc3a1dc22b 100644 --- a/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.repr +++ b/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.repr @@ -1 +1 @@ -cirq.ClassicallyControlledOperation(cirq.Y.on(cirq.NamedQubit('target')), [cirq.MeasurementKey('a'), cirq.MeasurementKey('b')]) \ No newline at end of file +cirq.ClassicallyControlledOperation(cirq.Y.on(cirq.NamedQubit('target')), [cirq.MeasurementKey('a'), cirq.MeasurementKey('b')]) diff --git a/cirq-core/cirq/protocols/measurement_key_protocol.py b/cirq-core/cirq/protocols/measurement_key_protocol.py index 09d8b60bdf3..fc398d2cf61 100644 --- a/cirq-core/cirq/protocols/measurement_key_protocol.py +++ b/cirq-core/cirq/protocols/measurement_key_protocol.py @@ -13,7 +13,7 @@ # limitations under the License. """Protocol for object that have measurement keys.""" -from typing import AbstractSet, Any, Dict, Optional, Tuple +from typing import AbstractSet, Any, Dict, FrozenSet, Optional, Tuple from typing_extensions import Protocol @@ -305,3 +305,26 @@ def with_key_path_prefix(val: Any, prefix: Tuple[str, ...]): """ getter = getattr(val, '_with_key_path_prefix_', None) return NotImplemented if getter is None else getter(prefix) + + +def with_rescoped_keys( + val: Any, + path: Tuple[str, ...], + bindable_keys: FrozenSet[value.MeasurementKey] = None, +): + """Rescopes any measurement and control keys to the provided path, given the existing keys. + + The path usually refers to an identifier or a list of identifiers from a subcircuit that + used to contain the target. Since a subcircuit can be repeated and reused, these paths help + differentiate the actual measurement keys. + + This function is generally for internal use in decomposing or iterating subcircuits. + + Args: + val: The value to rescope. + path: The prefix to apply to the value's path. + bindable_keys: The keys that can be bound to at the current scope. + """ + getter = getattr(val, '_with_rescoped_keys_', None) + result = NotImplemented if getter is None else getter(path, bindable_keys or frozenset()) + return result if result is not NotImplemented else val diff --git a/cirq-core/cirq/value/measurement_key.py b/cirq-core/cirq/value/measurement_key.py index 2454c1a5dbe..ee4c12bb051 100644 --- a/cirq-core/cirq/value/measurement_key.py +++ b/cirq-core/cirq/value/measurement_key.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Optional, Tuple +from typing import Dict, FrozenSet, Optional, Tuple import dataclasses @@ -107,13 +107,23 @@ def _with_key_path_(self, path: Tuple[str, ...]): def _with_key_path_prefix_(self, prefix: Tuple[str, ...]): return self._with_key_path_(path=prefix + self.path) - def with_key_path_prefix(self, path_component: str): + def with_key_path_prefix(self, *path_component: str): """Adds the input path component to the start of the path. Useful when constructing the path from inside to out (in case of nested subcircuits), recursively. """ - return self._with_key_path_prefix_((path_component,)) + return self.replace(path=path_component + self.path) + + def _with_rescoped_keys_( + self, + path: Tuple[str, ...], + bindable_keys: FrozenSet['MeasurementKey'], + ): + new_key = self.replace(path=path + self.path) + if new_key in bindable_keys: + raise ValueError(f'Conflicting measurement keys found: {new_key}') + return new_key def _with_measurement_key_mapping_(self, key_map: Dict[str, str]): if self.name not in key_map: