Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add repeat-until functionality to subcircuits #5018

Merged
merged 21 commits into from
Mar 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 68 additions & 22 deletions cirq-core/cirq/circuits/circuit_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,16 @@
component operations in order, including any nested CircuitOperations.
"""
from typing import (
TYPE_CHECKING,
AbstractSet,
Callable,
cast,
Dict,
FrozenSet,
Iterator,
List,
Optional,
Tuple,
TYPE_CHECKING,
Union,
)

Expand Down Expand Up @@ -94,6 +95,12 @@ class CircuitOperation(ops.Operation):
will have its path prepended with the repetition id for each
repetition. When False, this will not happen and the measurement
key will be repeated.
repeat_until: A condition that will be tested after each iteration of
the subcircuit. The subcircuit will repeat until condition returns
True, but will always run at least once, and the measurement key
need not be defined prior to the subcircuit (but must be defined in
a measurement within the subcircuit). This field is incompatible
with repetitions or repetition_ids.
"""

_hash: Optional[int] = dataclasses.field(default=None, init=False)
Expand All @@ -103,6 +110,9 @@ class CircuitOperation(ops.Operation):
_cached_control_keys: Optional[AbstractSet['cirq.MeasurementKey']] = dataclasses.field(
default=None, init=False
)
_cached_mapped_single_loop: Optional['cirq.Circuit'] = dataclasses.field(
default=None, init=False
)

circuit: 'cirq.FrozenCircuit'
repetitions: int = 1
Expand All @@ -113,6 +123,7 @@ class CircuitOperation(ops.Operation):
parent_path: Tuple[str, ...] = dataclasses.field(default_factory=tuple)
extern_keys: FrozenSet['cirq.MeasurementKey'] = dataclasses.field(default_factory=frozenset)
use_repetition_ids: bool = True
repeat_until: Optional['cirq.Condition'] = dataclasses.field(default=None)

def __post_init__(self):
if not isinstance(self.circuit, circuits.FrozenCircuit):
Expand Down Expand Up @@ -148,6 +159,14 @@ def __post_init__(self):
if q_new.dimension != q.dimension:
raise ValueError(f'Qid dimension conflict.\nFrom qid: {q}\nTo qid: {q_new}')

if self.repeat_until:
if self.use_repetition_ids or self.repetitions != 1:
raise ValueError('Cannot use repetitions with repeat_until')
if protocols.measurement_key_objs(self._mapped_single_loop()).isdisjoint(
self.repeat_until.keys
):
raise ValueError('Infinite loop: condition is not modified in subcircuit.')

# Ensure that param_resolver is converted to an actual ParamResolver.
object.__setattr__(self, 'param_resolver', study.ParamResolver(self.param_resolver))

Expand All @@ -174,6 +193,7 @@ def __eq__(self, other) -> bool:
and self.repetition_ids == other.repetition_ids
and self.parent_path == other.parent_path
and self.use_repetition_ids == other.use_repetition_ids
and self.repeat_until == other.repeat_until
)

# Methods for getting post-mapping properties of the contained circuit.
Expand Down Expand Up @@ -223,6 +243,8 @@ def _control_keys_(self) -> AbstractSet['cirq.MeasurementKey']:
if not protocols.control_keys(self.circuit)
else protocols.control_keys(self.mapped_circuit())
)
if self.repeat_until is not None:
keys |= frozenset(self.repeat_until.keys) - self._measurement_key_objs_()
object.__setattr__(self, '_cached_control_keys', keys)
return self._cached_control_keys # type: ignore

Expand All @@ -235,6 +257,27 @@ def _parameter_names_(self) -> AbstractSet[str]:
)
}

def _mapped_single_loop(self, repetition_id: Optional[str] = None) -> 'cirq.Circuit':
if self._cached_mapped_single_loop is None:
circuit = self.circuit.unfreeze()
if self.qubit_map:
circuit = circuit.transform_qubits(lambda q: self.qubit_map.get(q, q))
if self.repetitions < 0:
circuit = circuit ** -1
if self.measurement_key_map:
circuit = protocols.with_measurement_key_mapping(circuit, self.measurement_key_map)
if self.param_resolver:
circuit = protocols.resolve_parameters(
circuit, self.param_resolver, recursive=False
)
object.__setattr__(self, '_cached_mapped_single_loop', circuit)
circuit = cast(circuits.Circuit, self._cached_mapped_single_loop)
if repetition_id:
circuit = protocols.with_rescoped_keys(circuit, (repetition_id,))
return protocols.with_rescoped_keys(
circuit, self.parent_path, bindable_keys=self.extern_keys
)

def mapped_circuit(self, deep: bool = False) -> 'cirq.Circuit':
"""Applies all maps to the contained circuit and returns the result.

Expand All @@ -249,24 +292,12 @@ def mapped_circuit(self, deep: bool = False) -> 'cirq.Circuit':
"""
if self.repetitions == 0:
return circuits.Circuit()
circuit = self.circuit.unfreeze()
if self.qubit_map:
circuit = circuit.transform_qubits(lambda q: self.qubit_map.get(q, q))
if self.repetitions < 0:
circuit = circuit ** -1
if self.measurement_key_map:
circuit = protocols.with_measurement_key_mapping(circuit, self.measurement_key_map)
if self.param_resolver:
circuit = protocols.resolve_parameters(circuit, self.param_resolver, recursive=False)
if self.repetition_ids is not None:
if not self.use_repetition_ids or not protocols.is_measurement(circuit):
circuit = circuit * abs(self.repetitions)
else:
circuit = circuits.Circuit(
protocols.with_rescoped_keys(circuit, (rep,)) for rep in self.repetition_ids
)
circuit = protocols.with_rescoped_keys(
circuit, self.parent_path, bindable_keys=self.extern_keys
circuit = (
circuits.Circuit(self._mapped_single_loop(rep) for rep in self.repetition_ids)
if self.repetition_ids is not None
and self.use_repetition_ids
and protocols.is_measurement(self.circuit)
else self._mapped_single_loop() * abs(self.repetitions)
)
if deep:
circuit = circuit.map_operations(
Expand All @@ -282,8 +313,16 @@ def _decompose_(self) -> Iterator['cirq.Operation']:
return self.mapped_circuit(deep=False).all_operations()

def _act_on_(self, args: 'cirq.OperationTarget') -> bool:
for op in self._decompose_():
protocols.act_on(op, args)
if self.repeat_until:
circuit = self._mapped_single_loop()
while True:
for op in circuit.all_operations():
protocols.act_on(op, args)
if self.repeat_until.resolve(args.classical_data):
break
else:
for op in self._decompose_():
protocols.act_on(op, args)
return True

# Methods for string representation of the operation.
Expand All @@ -305,6 +344,8 @@ def __repr__(self):
args += f'repetition_ids={proper_repr(self.repetition_ids)},\n'
if not self.use_repetition_ids:
args += 'use_repetition_ids=False,\n'
if self.repeat_until:
args += f'repeat_until={self.repeat_until!r},\n'
indented_args = args.replace('\n', '\n ')
return f'cirq.CircuitOperation({indented_args[:-4]})'

Expand Down Expand Up @@ -337,6 +378,8 @@ def dict_str(d: Dict) -> str:
args.append(f'loops={self.repetitions}')
if not self.use_repetition_ids:
args.append('no_rep_ids')
if self.repeat_until:
args.append(f'until={self.repeat_until}')
if not args:
return circuit_msg
return f'{circuit_msg}({", ".join(args)})'
Expand Down Expand Up @@ -375,6 +418,8 @@ def _json_dict_(self):
}
if not self.use_repetition_ids:
resp['use_repetition_ids'] = False
if self.repeat_until:
resp['repeat_until'] = self.repeat_until
return resp

@classmethod
Expand All @@ -388,10 +433,11 @@ def _from_json_dict_(
repetition_ids,
parent_path=(),
use_repetition_ids=True,
repeat_until=None,
**kwargs,
):
return (
cls(circuit, use_repetition_ids=use_repetition_ids)
cls(circuit, use_repetition_ids=use_repetition_ids, repeat_until=repeat_until)
.with_qubit_mapping(dict(qubit_map))
.with_measurement_key_mapping(measurement_key_map)
.with_params(param_resolver)
Expand Down
121 changes: 121 additions & 0 deletions cirq-core/cirq/circuits/circuit_operation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,24 @@ def test_string_format():
use_repetition_ids=False,
)"""
)
op7 = cirq.CircuitOperation(
cirq.FrozenCircuit(cirq.measure(x, key='a')),
use_repetition_ids=False,
repeat_until=cirq.KeyCondition(cirq.MeasurementKey('a')),
)
assert (
repr(op7)
== """\
cirq.CircuitOperation(
circuit=cirq.FrozenCircuit([
cirq.Moment(
cirq.measure(cirq.LineQubit(0), key=cirq.MeasurementKey(name='a')),
),
]),
use_repetition_ids=False,
repeat_until=cirq.KeyCondition(cirq.MeasurementKey(name='a')),
)"""
)


def test_json_dict():
Expand Down Expand Up @@ -977,4 +995,107 @@ def test_simulate_no_repetition_ids_inner(sim):
assert result.records['1:a'].shape == (1, 2, 1)


@pytest.mark.parametrize('sim', ALL_SIMULATORS)
def test_repeat_until(sim):
q = cirq.LineQubit(0)
key = cirq.MeasurementKey('m')
c = cirq.Circuit(
cirq.X(q),
cirq.CircuitOperation(
cirq.FrozenCircuit(
cirq.X(q),
cirq.measure(q, key=key),
),
use_repetition_ids=False,
repeat_until=cirq.KeyCondition(key),
),
)
measurements = sim.run(c).records['m'][0]
assert len(measurements) == 2
assert measurements[0] == (0,)
assert measurements[1] == (1,)


@pytest.mark.parametrize('sim', ALL_SIMULATORS)
def test_repeat_until_sympy(sim):
q1, q2 = cirq.LineQubit.range(2)
circuitop = cirq.CircuitOperation(
cirq.FrozenCircuit(
cirq.X(q2),
cirq.measure(q2, key='b'),
),
use_repetition_ids=False,
repeat_until=cirq.SympyCondition(sympy.Eq(sympy.Symbol('a'), sympy.Symbol('b'))),
)
c = cirq.Circuit(
cirq.measure(q1, key='a'),
circuitop,
)
# Validate commutation
assert len(c) == 2
assert cirq.control_keys(circuitop) == {cirq.MeasurementKey('a')}
measurements = sim.run(c).records['b'][0]
assert len(measurements) == 2
assert measurements[0] == (1,)
assert measurements[1] == (0,)


@pytest.mark.parametrize('sim', [cirq.Simulator(), cirq.DensityMatrixSimulator()])
def test_post_selection(sim):
q = cirq.LineQubit(0)
key = cirq.MeasurementKey('m')
c = cirq.Circuit(
cirq.CircuitOperation(
cirq.FrozenCircuit(
cirq.X(q) ** 0.2,
cirq.measure(q, key=key),
),
use_repetition_ids=False,
repeat_until=cirq.KeyCondition(key),
),
)
result = sim.run(c)
assert result.records['m'][0][-1] == (1,)
for i in range(len(result.records['m'][0]) - 1):
assert result.records['m'][0][i] == (0,)


def test_repeat_until_diagram():
q = cirq.LineQubit(0)
key = cirq.MeasurementKey('m')
c = cirq.Circuit(
cirq.CircuitOperation(
cirq.FrozenCircuit(
cirq.X(q) ** 0.2,
cirq.measure(q, key=key),
),
use_repetition_ids=False,
repeat_until=cirq.KeyCondition(key),
),
)
cirq.testing.assert_has_diagram(
c,
"""
0: ───[ 0: ───X^0.2───M('m')─── ](no_rep_ids, until=m)───
""",
use_unicode_characters=True,
)


def test_repeat_until_error():
q = cirq.LineQubit(0)
with pytest.raises(ValueError, match='Cannot use repetitions with repeat_until'):
cirq.CircuitOperation(
cirq.FrozenCircuit(),
use_repetition_ids=True,
repeat_until=cirq.KeyCondition(cirq.MeasurementKey('a')),
)
with pytest.raises(ValueError, match='Infinite loop'):
cirq.CircuitOperation(
cirq.FrozenCircuit(cirq.measure(q, key='m')),
use_repetition_ids=False,
repeat_until=cirq.KeyCondition(cirq.MeasurementKey('a')),
)


# TODO: Operation has a "gate" property. What is this for a CircuitOperation?
26 changes: 26 additions & 0 deletions cirq-core/cirq/protocols/json_test_data/CircuitOperation.json
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,32 @@
"parent_path": [],
"repetition_ids": null,
"use_repetition_ids": false
},
{
"cirq_type": "CircuitOperation",
"circuit": {
"cirq_type": "_SerializedKey",
"key": 1
},
"repetitions": 1,
"qubit_map": [],
"measurement_key_map": {},
"param_resolver": {
"cirq_type": "ParamResolver",
"param_dict": []
},
"parent_path": [],
"repetition_ids": null,
"use_repetition_ids": false,
"repeat_until": {
"cirq_type": "KeyCondition",
"key": {
"cirq_type": "MeasurementKey",
"name": "0,1,2,3,4",
"path": []
},
"index": -1
}
}
]
]
Expand Down
17 changes: 16 additions & 1 deletion cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,19 @@ cirq.CircuitOperation(circuit=cirq.FrozenCircuit([
),
]),
param_resolver={sympy.Symbol('theta'): 1.5},
use_repetition_ids=False)]
use_repetition_ids=False),
cirq.CircuitOperation(circuit=cirq.FrozenCircuit([
cirq.Moment(
cirq.H(cirq.LineQubit(0)),
cirq.H(cirq.LineQubit(1)),
cirq.H(cirq.LineQubit(2)),
cirq.H(cirq.LineQubit(3)),
cirq.H(cirq.LineQubit(4)),
),
cirq.Moment(
cirq.MeasurementGate(5, '0,1,2,3,4', ()).on(cirq.LineQubit(0), cirq.LineQubit(1), cirq.LineQubit(2), cirq.LineQubit(3), cirq.LineQubit(4)),
),
]),
use_repetition_ids=False,
repeat_until=cirq.KeyCondition(key=cirq.MeasurementKey('0,1,2,3,4')),
)]