diff --git a/cirq-core/cirq/circuits/circuit_operation.py b/cirq-core/cirq/circuits/circuit_operation.py index 724e579d4d89..c5dc1ba8b041 100644 --- a/cirq-core/cirq/circuits/circuit_operation.py +++ b/cirq-core/cirq/circuits/circuit_operation.py @@ -18,15 +18,16 @@ component operations in order, including any nested CircuitOperations. """ from typing import ( - TYPE_CHECKING, AbstractSet, Callable, + cast, Dict, FrozenSet, Iterator, List, Optional, Tuple, + TYPE_CHECKING, Union, ) @@ -94,6 +95,12 @@ class CircuitOperation(ops.Operation): will have its path prepended with the repetition id for each repetition. When False, this will not happen and the measurement key will be repeated. + repeat_until: A condition that will be tested after each iteration of + the subcircuit. The subcircuit will repeat until condition returns + True, but will always run at least once, and the measurement key + need not be defined prior to the subcircuit (but must be defined in + a measurement within the subcircuit). This field is incompatible + with repetitions or repetition_ids. """ _hash: Optional[int] = dataclasses.field(default=None, init=False) @@ -103,6 +110,9 @@ class CircuitOperation(ops.Operation): _cached_control_keys: Optional[AbstractSet['cirq.MeasurementKey']] = dataclasses.field( default=None, init=False ) + _cached_mapped_single_loop: Optional['cirq.Circuit'] = dataclasses.field( + default=None, init=False + ) circuit: 'cirq.FrozenCircuit' repetitions: int = 1 @@ -113,6 +123,7 @@ class CircuitOperation(ops.Operation): parent_path: Tuple[str, ...] = dataclasses.field(default_factory=tuple) extern_keys: FrozenSet['cirq.MeasurementKey'] = dataclasses.field(default_factory=frozenset) use_repetition_ids: bool = True + repeat_until: Optional['cirq.Condition'] = dataclasses.field(default=None) def __post_init__(self): if not isinstance(self.circuit, circuits.FrozenCircuit): @@ -148,6 +159,14 @@ def __post_init__(self): if q_new.dimension != q.dimension: raise ValueError(f'Qid dimension conflict.\nFrom qid: {q}\nTo qid: {q_new}') + if self.repeat_until: + if self.use_repetition_ids or self.repetitions != 1: + raise ValueError('Cannot use repetitions with repeat_until') + if protocols.measurement_key_objs(self._mapped_single_loop()).isdisjoint( + self.repeat_until.keys + ): + raise ValueError('Infinite loop: condition is not modified in subcircuit.') + # Ensure that param_resolver is converted to an actual ParamResolver. object.__setattr__(self, 'param_resolver', study.ParamResolver(self.param_resolver)) @@ -174,6 +193,7 @@ def __eq__(self, other) -> bool: and self.repetition_ids == other.repetition_ids and self.parent_path == other.parent_path and self.use_repetition_ids == other.use_repetition_ids + and self.repeat_until == other.repeat_until ) # Methods for getting post-mapping properties of the contained circuit. @@ -223,6 +243,8 @@ def _control_keys_(self) -> AbstractSet['cirq.MeasurementKey']: if not protocols.control_keys(self.circuit) else protocols.control_keys(self.mapped_circuit()) ) + if self.repeat_until is not None: + keys |= frozenset(self.repeat_until.keys) - self._measurement_key_objs_() object.__setattr__(self, '_cached_control_keys', keys) return self._cached_control_keys # type: ignore @@ -235,6 +257,27 @@ def _parameter_names_(self) -> AbstractSet[str]: ) } + def _mapped_single_loop(self, repetition_id: Optional[str] = None) -> 'cirq.Circuit': + if self._cached_mapped_single_loop is None: + circuit = self.circuit.unfreeze() + if self.qubit_map: + circuit = circuit.transform_qubits(lambda q: self.qubit_map.get(q, q)) + if self.repetitions < 0: + circuit = circuit ** -1 + if self.measurement_key_map: + circuit = protocols.with_measurement_key_mapping(circuit, self.measurement_key_map) + if self.param_resolver: + circuit = protocols.resolve_parameters( + circuit, self.param_resolver, recursive=False + ) + object.__setattr__(self, '_cached_mapped_single_loop', circuit) + circuit = cast(circuits.Circuit, self._cached_mapped_single_loop) + if repetition_id: + circuit = protocols.with_rescoped_keys(circuit, (repetition_id,)) + return protocols.with_rescoped_keys( + circuit, self.parent_path, bindable_keys=self.extern_keys + ) + def mapped_circuit(self, deep: bool = False) -> 'cirq.Circuit': """Applies all maps to the contained circuit and returns the result. @@ -249,24 +292,12 @@ def mapped_circuit(self, deep: bool = False) -> 'cirq.Circuit': """ if self.repetitions == 0: return circuits.Circuit() - circuit = self.circuit.unfreeze() - if self.qubit_map: - circuit = circuit.transform_qubits(lambda q: self.qubit_map.get(q, q)) - if self.repetitions < 0: - circuit = circuit ** -1 - if self.measurement_key_map: - circuit = protocols.with_measurement_key_mapping(circuit, self.measurement_key_map) - if self.param_resolver: - circuit = protocols.resolve_parameters(circuit, self.param_resolver, recursive=False) - if self.repetition_ids is not None: - if not self.use_repetition_ids or not protocols.is_measurement(circuit): - circuit = circuit * abs(self.repetitions) - else: - circuit = circuits.Circuit( - protocols.with_rescoped_keys(circuit, (rep,)) for rep in self.repetition_ids - ) - circuit = protocols.with_rescoped_keys( - circuit, self.parent_path, bindable_keys=self.extern_keys + circuit = ( + circuits.Circuit(self._mapped_single_loop(rep) for rep in self.repetition_ids) + if self.repetition_ids is not None + and self.use_repetition_ids + and protocols.is_measurement(self.circuit) + else self._mapped_single_loop() * abs(self.repetitions) ) if deep: circuit = circuit.map_operations( @@ -282,8 +313,16 @@ def _decompose_(self) -> Iterator['cirq.Operation']: return self.mapped_circuit(deep=False).all_operations() def _act_on_(self, args: 'cirq.OperationTarget') -> bool: - for op in self._decompose_(): - protocols.act_on(op, args) + if self.repeat_until: + circuit = self._mapped_single_loop() + while True: + for op in circuit.all_operations(): + protocols.act_on(op, args) + if self.repeat_until.resolve(args.classical_data): + break + else: + for op in self._decompose_(): + protocols.act_on(op, args) return True # Methods for string representation of the operation. @@ -305,6 +344,8 @@ def __repr__(self): args += f'repetition_ids={proper_repr(self.repetition_ids)},\n' if not self.use_repetition_ids: args += 'use_repetition_ids=False,\n' + if self.repeat_until: + args += f'repeat_until={self.repeat_until!r},\n' indented_args = args.replace('\n', '\n ') return f'cirq.CircuitOperation({indented_args[:-4]})' @@ -337,6 +378,8 @@ def dict_str(d: Dict) -> str: args.append(f'loops={self.repetitions}') if not self.use_repetition_ids: args.append('no_rep_ids') + if self.repeat_until: + args.append(f'until={self.repeat_until}') if not args: return circuit_msg return f'{circuit_msg}({", ".join(args)})' @@ -375,6 +418,8 @@ def _json_dict_(self): } if not self.use_repetition_ids: resp['use_repetition_ids'] = False + if self.repeat_until: + resp['repeat_until'] = self.repeat_until return resp @classmethod @@ -388,10 +433,11 @@ def _from_json_dict_( repetition_ids, parent_path=(), use_repetition_ids=True, + repeat_until=None, **kwargs, ): return ( - cls(circuit, use_repetition_ids=use_repetition_ids) + cls(circuit, use_repetition_ids=use_repetition_ids, repeat_until=repeat_until) .with_qubit_mapping(dict(qubit_map)) .with_measurement_key_mapping(measurement_key_map) .with_params(param_resolver) diff --git a/cirq-core/cirq/circuits/circuit_operation_test.py b/cirq-core/cirq/circuits/circuit_operation_test.py index 6999bfbd8d3c..6f94e706beea 100644 --- a/cirq-core/cirq/circuits/circuit_operation_test.py +++ b/cirq-core/cirq/circuits/circuit_operation_test.py @@ -533,6 +533,24 @@ def test_string_format(): use_repetition_ids=False, )""" ) + op7 = cirq.CircuitOperation( + cirq.FrozenCircuit(cirq.measure(x, key='a')), + use_repetition_ids=False, + repeat_until=cirq.KeyCondition(cirq.MeasurementKey('a')), + ) + assert ( + repr(op7) + == """\ +cirq.CircuitOperation( + circuit=cirq.FrozenCircuit([ + cirq.Moment( + cirq.measure(cirq.LineQubit(0), key=cirq.MeasurementKey(name='a')), + ), + ]), + use_repetition_ids=False, + repeat_until=cirq.KeyCondition(cirq.MeasurementKey(name='a')), +)""" + ) def test_json_dict(): @@ -977,4 +995,107 @@ def test_simulate_no_repetition_ids_inner(sim): assert result.records['1:a'].shape == (1, 2, 1) +@pytest.mark.parametrize('sim', ALL_SIMULATORS) +def test_repeat_until(sim): + q = cirq.LineQubit(0) + key = cirq.MeasurementKey('m') + c = cirq.Circuit( + cirq.X(q), + cirq.CircuitOperation( + cirq.FrozenCircuit( + cirq.X(q), + cirq.measure(q, key=key), + ), + use_repetition_ids=False, + repeat_until=cirq.KeyCondition(key), + ), + ) + measurements = sim.run(c).records['m'][0] + assert len(measurements) == 2 + assert measurements[0] == (0,) + assert measurements[1] == (1,) + + +@pytest.mark.parametrize('sim', ALL_SIMULATORS) +def test_repeat_until_sympy(sim): + q1, q2 = cirq.LineQubit.range(2) + circuitop = cirq.CircuitOperation( + cirq.FrozenCircuit( + cirq.X(q2), + cirq.measure(q2, key='b'), + ), + use_repetition_ids=False, + repeat_until=cirq.SympyCondition(sympy.Eq(sympy.Symbol('a'), sympy.Symbol('b'))), + ) + c = cirq.Circuit( + cirq.measure(q1, key='a'), + circuitop, + ) + # Validate commutation + assert len(c) == 2 + assert cirq.control_keys(circuitop) == {cirq.MeasurementKey('a')} + measurements = sim.run(c).records['b'][0] + assert len(measurements) == 2 + assert measurements[0] == (1,) + assert measurements[1] == (0,) + + +@pytest.mark.parametrize('sim', [cirq.Simulator(), cirq.DensityMatrixSimulator()]) +def test_post_selection(sim): + q = cirq.LineQubit(0) + key = cirq.MeasurementKey('m') + c = cirq.Circuit( + cirq.CircuitOperation( + cirq.FrozenCircuit( + cirq.X(q) ** 0.2, + cirq.measure(q, key=key), + ), + use_repetition_ids=False, + repeat_until=cirq.KeyCondition(key), + ), + ) + result = sim.run(c) + assert result.records['m'][0][-1] == (1,) + for i in range(len(result.records['m'][0]) - 1): + assert result.records['m'][0][i] == (0,) + + +def test_repeat_until_diagram(): + q = cirq.LineQubit(0) + key = cirq.MeasurementKey('m') + c = cirq.Circuit( + cirq.CircuitOperation( + cirq.FrozenCircuit( + cirq.X(q) ** 0.2, + cirq.measure(q, key=key), + ), + use_repetition_ids=False, + repeat_until=cirq.KeyCondition(key), + ), + ) + cirq.testing.assert_has_diagram( + c, + """ +0: ───[ 0: ───X^0.2───M('m')─── ](no_rep_ids, until=m)─── +""", + use_unicode_characters=True, + ) + + +def test_repeat_until_error(): + q = cirq.LineQubit(0) + with pytest.raises(ValueError, match='Cannot use repetitions with repeat_until'): + cirq.CircuitOperation( + cirq.FrozenCircuit(), + use_repetition_ids=True, + repeat_until=cirq.KeyCondition(cirq.MeasurementKey('a')), + ) + with pytest.raises(ValueError, match='Infinite loop'): + cirq.CircuitOperation( + cirq.FrozenCircuit(cirq.measure(q, key='m')), + use_repetition_ids=False, + repeat_until=cirq.KeyCondition(cirq.MeasurementKey('a')), + ) + + # TODO: Operation has a "gate" property. What is this for a CircuitOperation? diff --git a/cirq-core/cirq/protocols/json_test_data/CircuitOperation.json b/cirq-core/cirq/protocols/json_test_data/CircuitOperation.json index fe87b08ad679..7fbb4421dbef 100644 --- a/cirq-core/cirq/protocols/json_test_data/CircuitOperation.json +++ b/cirq-core/cirq/protocols/json_test_data/CircuitOperation.json @@ -294,6 +294,32 @@ "parent_path": [], "repetition_ids": null, "use_repetition_ids": false + }, + { + "cirq_type": "CircuitOperation", + "circuit": { + "cirq_type": "_SerializedKey", + "key": 1 + }, + "repetitions": 1, + "qubit_map": [], + "measurement_key_map": {}, + "param_resolver": { + "cirq_type": "ParamResolver", + "param_dict": [] + }, + "parent_path": [], + "repetition_ids": null, + "use_repetition_ids": false, + "repeat_until": { + "cirq_type": "KeyCondition", + "key": { + "cirq_type": "MeasurementKey", + "name": "0,1,2,3,4", + "path": [] + }, + "index": -1 + } } ] ] diff --git a/cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr b/cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr index 791ebda07e8e..ee527416d26a 100644 --- a/cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr +++ b/cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr @@ -36,4 +36,19 @@ cirq.CircuitOperation(circuit=cirq.FrozenCircuit([ ), ]), param_resolver={sympy.Symbol('theta'): 1.5}, -use_repetition_ids=False)] \ No newline at end of file +use_repetition_ids=False), +cirq.CircuitOperation(circuit=cirq.FrozenCircuit([ + cirq.Moment( + cirq.H(cirq.LineQubit(0)), + cirq.H(cirq.LineQubit(1)), + cirq.H(cirq.LineQubit(2)), + cirq.H(cirq.LineQubit(3)), + cirq.H(cirq.LineQubit(4)), + ), + cirq.Moment( + cirq.MeasurementGate(5, '0,1,2,3,4', ()).on(cirq.LineQubit(0), cirq.LineQubit(1), cirq.LineQubit(2), cirq.LineQubit(3), cirq.LineQubit(4)), + ), +]), +use_repetition_ids=False, +repeat_until=cirq.KeyCondition(key=cirq.MeasurementKey('0,1,2,3,4')), +)] \ No newline at end of file