Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add repeat-until functionality to subcircuits #5018

Merged
merged 21 commits into from
Mar 1, 2022
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 64 additions & 19 deletions cirq-core/cirq/circuits/circuit_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,16 @@
component operations in order, including any nested CircuitOperations.
"""
from typing import (
TYPE_CHECKING,
AbstractSet,
Callable,
cast,
Dict,
FrozenSet,
Iterator,
List,
Optional,
Tuple,
TYPE_CHECKING,
Union,
)

Expand Down Expand Up @@ -94,6 +95,8 @@ class CircuitOperation(ops.Operation):
will have its path prepended with the repetition id for each
repetition. When False, this will not happen and the measurement
key will be repeated.
repeat_until: A condition that will be tested after each iteration of
the circuit.
"""

_hash: Optional[int] = dataclasses.field(default=None, init=False)
Expand All @@ -103,6 +106,9 @@ class CircuitOperation(ops.Operation):
_cached_control_keys: Optional[AbstractSet['cirq.MeasurementKey']] = dataclasses.field(
default=None, init=False
)
_cached_mapped_single_loop: Optional['cirq.Circuit'] = dataclasses.field(
default=None, init=False
)

circuit: 'cirq.FrozenCircuit'
repetitions: int = 1
Expand All @@ -113,6 +119,7 @@ class CircuitOperation(ops.Operation):
parent_path: Tuple[str, ...] = dataclasses.field(default_factory=tuple)
extern_keys: FrozenSet['cirq.MeasurementKey'] = dataclasses.field(default_factory=frozenset)
use_repetition_ids: bool = True
repeat_until: Optional['cirq.Condition'] = dataclasses.field(default=None)

def __post_init__(self):
if not isinstance(self.circuit, circuits.FrozenCircuit):
Expand Down Expand Up @@ -148,6 +155,14 @@ def __post_init__(self):
if q_new.dimension != q.dimension:
raise ValueError(f'Qid dimension conflict.\nFrom qid: {q}\nTo qid: {q_new}')

if self.repeat_until:
if self.use_repetition_ids or self.repetitions != 1:
raise ValueError('Cannot use repetitions with repeat_until')
if protocols.measurement_key_objs(self._mapped_single_loop()).isdisjoint(
self.repeat_until.keys
):
raise ValueError('Infinite loop: condition is not modified in subcircuit.')

# Ensure that param_resolver is converted to an actual ParamResolver.
object.__setattr__(self, 'param_resolver', study.ParamResolver(self.param_resolver))

Expand All @@ -174,6 +189,7 @@ def __eq__(self, other) -> bool:
and self.repetition_ids == other.repetition_ids
and self.parent_path == other.parent_path
and self.use_repetition_ids == other.use_repetition_ids
and self.repeat_until == other.repeat_until
)

# Methods for getting post-mapping properties of the contained circuit.
Expand Down Expand Up @@ -223,6 +239,8 @@ def _control_keys_(self) -> AbstractSet['cirq.MeasurementKey']:
if not protocols.control_keys(self.circuit)
else protocols.control_keys(self.mapped_circuit())
)
if self.repeat_until is not None:
keys |= frozenset(self.repeat_until.keys)
object.__setattr__(self, '_cached_control_keys', keys)
return self._cached_control_keys # type: ignore

Expand All @@ -235,6 +253,28 @@ def _parameter_names_(self) -> AbstractSet[str]:
)
}

def _mapped_single_loop(self, repetition_id: Optional[str] = None) -> 'cirq.Circuit':
if self._cached_mapped_single_loop is None:
circuit = self.circuit.unfreeze()
if self.qubit_map:
circuit = circuit.transform_qubits(lambda q: self.qubit_map.get(q, q))
if self.repetitions < 0:
circuit = circuit ** -1
if self.measurement_key_map:
circuit = protocols.with_measurement_key_mapping(circuit, self.measurement_key_map)
if self.param_resolver:
circuit = protocols.resolve_parameters(
circuit, self.param_resolver, recursive=False
)
object.__setattr__(self, '_cached_mapped_single_loop', circuit)
circuit = cast(circuits.Circuit, self._cached_mapped_single_loop)
if repetition_id:
circuit = protocols.with_rescoped_keys(circuit, (repetition_id,))
circuit = protocols.with_rescoped_keys(
circuit, self.parent_path, bindable_keys=self.extern_keys
)
return circuit

def mapped_circuit(self, deep: bool = False) -> 'cirq.Circuit':
"""Applies all maps to the contained circuit and returns the result.

Expand All @@ -247,25 +287,15 @@ def mapped_circuit(self, deep: bool = False) -> 'cirq.Circuit':
qubit mapping, parameterization, etc.) applied to it. This behaves
like `cirq.decompose(self)`, but preserving moment structure.
"""
circuit = self.circuit.unfreeze()
if self.qubit_map:
circuit = circuit.transform_qubits(lambda q: self.qubit_map.get(q, q))
if self.repetitions < 0:
circuit = circuit ** -1
if self.measurement_key_map:
circuit = protocols.with_measurement_key_mapping(circuit, self.measurement_key_map)
if self.param_resolver:
circuit = protocols.resolve_parameters(circuit, self.param_resolver, recursive=False)
if self.repetition_ids:
if not self.use_repetition_ids or not protocols.is_measurement(circuit):
circuit = circuit * abs(self.repetitions)
if not self.use_repetition_ids or not protocols.is_measurement(self.circuit):
circuit = self._mapped_single_loop() * abs(self.repetitions)
else:
circuit = circuits.Circuit(
protocols.with_rescoped_keys(circuit, (rep,)) for rep in self.repetition_ids
self._mapped_single_loop(rep) for rep in self.repetition_ids
)
circuit = protocols.with_rescoped_keys(
circuit, self.parent_path, bindable_keys=self.extern_keys
)
else:
circuit = self._mapped_single_loop()
if deep:
circuit = circuit.map_operations(
lambda op: op.mapped_circuit(deep=True) if isinstance(op, CircuitOperation) else op
Expand All @@ -280,8 +310,16 @@ def _decompose_(self) -> Iterator['cirq.Operation']:
return self.mapped_circuit(deep=False).all_operations()

def _act_on_(self, args: 'cirq.OperationTarget') -> bool:
for op in self._decompose_():
protocols.act_on(op, args)
if self.repeat_until:
circuit = self._mapped_single_loop()
while True:
for op in circuit.all_operations():
protocols.act_on(op, args)
if self.repeat_until.resolve(args.classical_data):
break
else:
for op in self._decompose_():
protocols.act_on(op, args)
return True

# Methods for string representation of the operation.
Expand All @@ -303,6 +341,8 @@ def __repr__(self):
args += f'repetition_ids={proper_repr(self.repetition_ids)},\n'
if not self.use_repetition_ids:
args += 'use_repetition_ids=False,\n'
if self.repeat_until:
args += f'repeat_until={self.repeat_until!r},\n'
indented_args = args.replace('\n', '\n ')
return f'cirq.CircuitOperation({indented_args[:-4]})'

Expand Down Expand Up @@ -335,6 +375,8 @@ def dict_str(d: Dict) -> str:
args.append(f'loops={self.repetitions}')
if not self.use_repetition_ids:
args.append('no_rep_ids')
if self.repeat_until:
args.append(f'until={self.repeat_until}')
if not args:
return circuit_msg
return f'{circuit_msg}({", ".join(args)})'
Expand Down Expand Up @@ -373,6 +415,8 @@ def _json_dict_(self):
}
if not self.use_repetition_ids:
resp['use_repetition_ids'] = False
if self.repeat_until:
resp['repeat_until'] = self.repeat_until
return resp

@classmethod
Expand All @@ -386,10 +430,11 @@ def _from_json_dict_(
repetition_ids,
parent_path=(),
use_repetition_ids=True,
repeat_until=None,
**kwargs,
):
return (
cls(circuit, use_repetition_ids=use_repetition_ids)
cls(circuit, use_repetition_ids=use_repetition_ids, repeat_until=repeat_until)
.with_qubit_mapping(dict(qubit_map))
.with_measurement_key_mapping(measurement_key_map)
.with_params(param_resolver)
Expand Down
78 changes: 78 additions & 0 deletions cirq-core/cirq/circuits/circuit_operation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,24 @@ def test_string_format():
use_repetition_ids=False,
)"""
)
op7 = cirq.CircuitOperation(
cirq.FrozenCircuit(cirq.measure(x, key='a')),
use_repetition_ids=False,
repeat_until=cirq.KeyCondition(cirq.MeasurementKey('a')),
)
assert (
repr(op7)
== """\
cirq.CircuitOperation(
circuit=cirq.FrozenCircuit([
cirq.Moment(
cirq.measure(cirq.LineQubit(0), key=cirq.MeasurementKey(name='a')),
),
]),
use_repetition_ids=False,
repeat_until=cirq.KeyCondition(cirq.MeasurementKey(name='a')),
)"""
)


def test_json_dict():
Expand Down Expand Up @@ -928,4 +946,64 @@ def test_simulate_no_repetition_ids_inner(sim):
assert result.records['1:a'].shape == (1, 2, 1)


@pytest.mark.parametrize('sim', [cirq.Simulator(), cirq.DensityMatrixSimulator()])
def test_repeat_until(sim):
q = cirq.LineQubit(0)
key = cirq.MeasurementKey('m')
c = cirq.Circuit(
cirq.CircuitOperation(
cirq.FrozenCircuit(
cirq.X(q) ** 0.2,
cirq.measure(q, key='m'),
),
use_repetition_ids=False,
repeat_until=cirq.KeyCondition(key),
),
)
result = sim.run(c)
assert result.records['m'][0][-1] == (1,)
for i in range(len(result.records['m'][0]) - 1):
assert result.records['m'][0][i] == (0,)


def test_repeat_until_diagram():
q = cirq.LineQubit(0)
key = cirq.MeasurementKey('m')
c = cirq.Circuit(
cirq.CircuitOperation(
cirq.FrozenCircuit(
cirq.X(q) ** 0.2,
cirq.measure(q, key='m'),
),
use_repetition_ids=False,
repeat_until=cirq.KeyCondition(key),
),
)
cirq.testing.assert_has_diagram(
c,
"""
0: ───[ 0: ───X^0.2───M('m')─── ](no_rep_ids, until=m)───
m: ═══╩══════════════════════════════════════════════════
""",
use_unicode_characters=True,
)


def test_repeat_until_error():
q = cirq.LineQubit(0)
with pytest.raises(ValueError, match='Cannot use repetitions with repeat_until'):
cirq.CircuitOperation(
cirq.FrozenCircuit(),
use_repetition_ids=True,
repeat_until=cirq.KeyCondition(cirq.MeasurementKey('a')),
)
with pytest.raises(ValueError, match='Infinite loop'):
cirq.CircuitOperation(
cirq.FrozenCircuit(cirq.measure(q, key='m')),
use_repetition_ids=False,
repeat_until=cirq.KeyCondition(cirq.MeasurementKey('a')),
)


# TODO: Operation has a "gate" property. What is this for a CircuitOperation?
26 changes: 26 additions & 0 deletions cirq-core/cirq/protocols/json_test_data/CircuitOperation.json
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,32 @@
"parent_path": [],
"repetition_ids": null,
"use_repetition_ids": false
},
{
"cirq_type": "CircuitOperation",
"circuit": {
"cirq_type": "_SerializedKey",
"key": 1
},
"repetitions": 1,
"qubit_map": [],
"measurement_key_map": {},
"param_resolver": {
"cirq_type": "ParamResolver",
"param_dict": []
},
"parent_path": [],
"repetition_ids": null,
"use_repetition_ids": false,
"repeat_until": {
"cirq_type": "KeyCondition",
"key": {
"cirq_type": "MeasurementKey",
"name": "0,1,2,3,4",
"path": []
},
"index": -1
}
}
]
]
Expand Down
17 changes: 16 additions & 1 deletion cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,19 @@ cirq.CircuitOperation(circuit=cirq.FrozenCircuit([
),
]),
param_resolver={sympy.Symbol('theta'): 1.5},
use_repetition_ids=False)]
use_repetition_ids=False),
cirq.CircuitOperation(circuit=cirq.FrozenCircuit([
cirq.Moment(
cirq.H(cirq.LineQubit(0)),
cirq.H(cirq.LineQubit(1)),
cirq.H(cirq.LineQubit(2)),
cirq.H(cirq.LineQubit(3)),
cirq.H(cirq.LineQubit(4)),
),
cirq.Moment(
cirq.MeasurementGate(5, '0,1,2,3,4', ()).on(cirq.LineQubit(0), cirq.LineQubit(1), cirq.LineQubit(2), cirq.LineQubit(3), cirq.LineQubit(4)),
),
]),
use_repetition_ids=False,
repeat_until=cirq.KeyCondition(key=cirq.MeasurementKey('0,1,2,3,4')),
)]