-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
Copy pathtransformer_primitives.py
720 lines (624 loc) · 30.8 KB
/
transformer_primitives.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
# Copyright 2021 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines primitives for common transformer patterns."""
from collections import defaultdict
import bisect
import dataclasses
from typing import (
cast,
Callable,
Dict,
Hashable,
List,
Optional,
Sequence,
Set,
Union,
Tuple,
TYPE_CHECKING,
)
from cirq import circuits, ops, protocols
from cirq.circuits.circuit import CIRCUIT_TYPE
if TYPE_CHECKING:
import cirq
# Tag that `map_operations` applies to the `cirq.CircuitOperation` wrapping a multi-moment
# mapping result. The `unroll_circuit_op*` functions use it as their default `tags_to_check`,
# so by default they unroll only circuit operations carrying this tag.
MAPPED_CIRCUIT_OP_TAG = '<mapped_circuit_op>'
def _to_target_circuit_type(
    circuit: circuits.AbstractCircuit, target_circuit: CIRCUIT_TYPE
) -> CIRCUIT_TYPE:
    """Returns `circuit` converted to the same mutability as `target_circuit`.

    If `target_circuit` is a mutable `cirq.Circuit`, the result is
    `circuit.unfreeze(copy=False)`; otherwise it is `circuit.freeze()`.
    """
    wants_mutable = isinstance(target_circuit, circuits.Circuit)
    converted: circuits.AbstractCircuit
    if wants_mutable:
        converted = circuit.unfreeze(copy=False)
    else:
        converted = circuit.freeze()
    return cast(CIRCUIT_TYPE, converted)
def _create_target_circuit_type(ops: ops.OP_TREE, target_circuit: CIRCUIT_TYPE) -> CIRCUIT_TYPE:
    """Builds a new circuit from `ops` with the same mutability as `target_circuit`.

    Produces a `cirq.Circuit` when `target_circuit` is a `cirq.Circuit`, and a
    `cirq.FrozenCircuit` otherwise.
    """
    # NOTE: inside this body the parameter `ops` shadows the `ops` module; the annotation in the
    # signature is evaluated at definition time and therefore still refers to the module.
    new_circuit: circuits.AbstractCircuit
    if isinstance(target_circuit, circuits.Circuit):
        new_circuit = circuits.Circuit(ops)
    else:
        new_circuit = circuits.FrozenCircuit(ops)
    return cast(CIRCUIT_TYPE, new_circuit)
def map_moments(
    circuit: CIRCUIT_TYPE,
    map_func: Callable[[circuits.Moment, int], Union[circuits.Moment, Sequence[circuits.Moment]]],
    *,
    tags_to_ignore: Sequence[Hashable] = (),
    deep: bool = False,
) -> CIRCUIT_TYPE:
    """Applies a local transformation to every moment via `map_func(moment, moment_index)`.

    Args:
        circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
        map_func: Mapping function from (cirq.Moment, moment_index) to a moment or a sequence
            of moments; all returned moments replace the original moment in order.
        tags_to_ignore: Tagged circuit operations marked with any of `tags_to_ignore` will be
            ignored when recursively applying the transformer primitive to sub-circuits, given
            deep=True.
        deep: If true, `map_func` will be recursively applied to circuits wrapped inside
            any circuit operations contained within `circuit`. Sub-circuits are mapped before
            the moments of `circuit` itself.

    Returns:
        Copy of input circuit with mapped moments, of the same (mutable / frozen) type as
        `circuit`.
    """
    if deep:
        replacements = []
        circuit_op_positions = circuit.findall_operations(
            lambda o: isinstance(o.untagged, circuits.CircuitOperation)
        )
        for moment_index, old_op in circuit_op_positions:
            # Circuit operations carrying an ignored tag are left as they are.
            if not set(old_op.tags).isdisjoint(tags_to_ignore):
                continue
            sub_op = cast(circuits.CircuitOperation, old_op.untagged)
            mapped_sub_circuit = map_moments(
                sub_op.circuit, map_func, tags_to_ignore=tags_to_ignore, deep=deep
            )
            new_op = sub_op.replace(circuit=mapped_sub_circuit).with_tags(*old_op.tags)
            replacements.append((moment_index, old_op, new_op))
        # Work on a copy so that the input circuit is never mutated.
        source = circuit.unfreeze(copy=True)
        source.batch_replace(replacements)
    else:
        # Read-only access, so no copy is needed.
        source = circuit.unfreeze(copy=False)
    mapped = (map_func(moment, index) for index, moment in enumerate(source))
    return _create_target_circuit_type(mapped, circuit)
def map_operations(
    circuit: CIRCUIT_TYPE,
    map_func: Callable[[ops.Operation, int], ops.OP_TREE],
    *,
    deep: bool = False,
    raise_if_add_qubits: bool = True,
    tags_to_ignore: Sequence[Hashable] = (),
) -> CIRCUIT_TYPE:
    """Applies local transformations, by calling `map_func(op, moment_index)` for each operation.

    By default, the function assumes `issubset(qubit_set(map_func(op, moment_index)), op.qubits)` is
    True.

    Args:
        circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
        map_func: Mapping function from (cirq.Operation, moment_index) to a cirq.OP_TREE. If the
            resulting optree spans more than 1 moment, it's inserted in-place in the same moment as
            `cirq.CircuitOperation(cirq.FrozenCircuit(op_tree)).with_tags(MAPPED_CIRCUIT_OP_TAG)`
            to preserve moment structure. Utility methods like `cirq.unroll_circuit_op` can
            subsequently be used to unroll the mapped circuit operation.
        deep: If true, `map_func` will be recursively applied to circuits wrapped inside
            any circuit operations contained within `circuit`.
        raise_if_add_qubits: Set to True by default. If True, raises ValueError if
            `map_func(op, idx)` adds operations on qubits outside of `op.qubits`.
        tags_to_ignore: Sequence of tags which should be ignored while applying `map_func` on
            tagged operations -- i.e. `map_func(op, idx)` will be called only for operations that
            satisfy `set(op.tags).isdisjoint(tags_to_ignore)`.

    Raises:
        ValueError if `issubset(qubit_set(map_func(op, idx)), op.qubits) is False` and
            `raise_if_add_qubits is True`.

    Returns:
        Copy of input circuit with mapped operations. Mapping results spanning more than one
        moment are wrapped in a circuit operation tagged with `MAPPED_CIRCUIT_OP_TAG`.
    """

    def apply_map(op: ops.Operation, idx: int) -> ops.OP_TREE:
        # Operations carrying any ignored tag are passed through untouched.
        if not set(op.tags).isdisjoint(tags_to_ignore):
            return op
        c = circuits.FrozenCircuit(map_func(op, idx))
        if raise_if_add_qubits and not c.all_qubits().issubset(op.qubits):
            # `all_operations()` returns a lazy iterator; materialize it so the message shows the
            # actual operations rather than a generator object's repr.
            raise ValueError(
                f"Mapped operations {list(c.all_operations())} should act on a subset "
                f"of qubits of the original operation {op}"
            )
        if len(c) <= 1:
            # Either empty circuit or all operations act in the same moment;
            # So, we don't need to wrap them in a circuit_op.
            return c[0].operations if c else []
        circuit_op = circuits.CircuitOperation(c).with_tags(MAPPED_CIRCUIT_OP_TAG)
        return circuit_op

    def map_moment(m: circuits.Moment, i: int) -> Sequence[circuits.Moment]:
        # An empty mapping result would otherwise drop the moment; keep an empty placeholder so
        # the moment structure of the input is preserved.
        return circuits.Circuit(apply_map(op, i) for op in m.operations).moments or [
            circuits.Moment()
        ]

    return map_moments(
        circuit,
        map_moment,
        deep=deep,
        tags_to_ignore=tags_to_ignore,
    )
def map_operations_and_unroll(
    circuit: CIRCUIT_TYPE,
    map_func: Callable[[ops.Operation, int], ops.OP_TREE],
    *,
    deep: bool = False,
    raise_if_add_qubits=True,
    tags_to_ignore: Sequence[Hashable] = (),
) -> CIRCUIT_TYPE:
    """Maps operations with `cirq.map_operations`, then unrolls the resulting circuit ops.

    This is `cirq.unroll_circuit_op(cirq.map_operations(...))`; see those two functions for the
    details of each stage.

    Args:
        circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
        map_func: Mapping function from (cirq.Operation, moment_index) to a cirq.OP_TREE.
        deep: If true, `map_func` will be recursively applied to circuits wrapped inside
            any circuit operations contained within `circuit`; unrolling is also applied
            recursively.
        raise_if_add_qubits: Set to True by default. If True, raises ValueError if
            `map_func(op, idx)` adds operations on qubits outside `op.qubits`.
        tags_to_ignore: Sequence of tags which should be ignored while applying `map_func` on
            tagged operations -- i.e. `map_func(op, idx)` will be called only for operations that
            satisfy `set(op.tags).isdisjoint(tags_to_ignore)`.

    Returns:
        Copy of input circuit with mapped operations, unrolled in a moment preserving way.
    """
    mapped_circuit = map_operations(
        circuit,
        map_func,
        deep=deep,
        raise_if_add_qubits=raise_if_add_qubits,
        tags_to_ignore=tags_to_ignore,
    )
    return unroll_circuit_op(mapped_circuit, deep=deep)
@dataclasses.dataclass
class _MergedCircuit:
"""An optimized internal representation of a circuit, tailored for `cirq.merge_operations`
Attributes:
qubit_indexes: Mapping from qubits to (sorted) list of moment indexes containing operations
acting on the qubit.
mkey_indexes: Mapping from measurement keys to (sorted) list of moment indexes containing
measurement operations with the same key.
ckey_indexes: Mapping from measurement keys to (sorted) list of moment indexes containing
classically controlled operations controlled on the same key.
ops_by_index: List of circuit moments containing operations. We use a dictionary instead
of a set to store operations to preserve insertion order.
"""
qubit_indexes: Dict['cirq.Qid', List[int]] = dataclasses.field(
default_factory=lambda: defaultdict(lambda: [-1])
)
mkey_indexes: Dict['cirq.MeasurementKey', List[int]] = dataclasses.field(
default_factory=lambda: defaultdict(lambda: [-1])
)
ckey_indexes: Dict['cirq.MeasurementKey', List[int]] = dataclasses.field(
default_factory=lambda: defaultdict(lambda: [-1])
)
ops_by_index: List[Dict['cirq.Operation', int]] = dataclasses.field(default_factory=list)
def append_empty_moment(self) -> None:
self.ops_by_index.append({})
def add_op_to_moment(self, moment_index: int, op: 'cirq.Operation') -> None:
self.ops_by_index[moment_index][op] = 0
for q in op.qubits:
if moment_index > self.qubit_indexes[q][-1]:
self.qubit_indexes[q].append(moment_index)
else:
bisect.insort(self.qubit_indexes[q], moment_index)
for mkey in protocols.measurement_key_objs(op):
bisect.insort(self.mkey_indexes[mkey], moment_index)
for ckey in protocols.control_keys(op):
bisect.insort(self.ckey_indexes[ckey], moment_index)
def remove_op_from_moment(self, moment_index: int, op: 'cirq.Operation') -> None:
self.ops_by_index[moment_index].pop(op)
for q in op.qubits:
if self.qubit_indexes[q][-1] == moment_index:
self.qubit_indexes[q].pop()
else:
self.qubit_indexes[q].remove(moment_index)
for mkey in protocols.measurement_key_objs(op):
self.mkey_indexes[mkey].remove(moment_index)
for ckey in protocols.control_keys(op):
self.ckey_indexes[ckey].remove(moment_index)
def get_mergeable_ops(
self, op: 'cirq.Operation', op_qs: Set['cirq.Qid']
) -> Tuple[int, List['cirq.Operation']]:
# Find the index of previous moment which can be merged with `op`.
idx = max([self.qubit_indexes[q][-1] for q in op_qs], default=-1)
idx = max([idx] + [self.mkey_indexes[ckey][-1] for ckey in protocols.control_keys(op)])
idx = max(
[idx] + [self.ckey_indexes[mkey][-1] for mkey in protocols.measurement_key_objs(op)]
)
# Return the set of overlapping ops in moment with index `idx`.
if idx == -1:
return idx, []
return idx, [
left_op for left_op in self.ops_by_index[idx] if not op_qs.isdisjoint(left_op.qubits)
]
def get_cirq_circuit(self) -> 'cirq.Circuit':
return circuits.Circuit(circuits.Moment(m.keys()) for m in self.ops_by_index)
def merge_operations(
    circuit: CIRCUIT_TYPE,
    merge_func: Callable[[ops.Operation, ops.Operation], Optional[ops.Operation]],
    *,
    tags_to_ignore: Sequence[Hashable] = (),
    deep: bool = False,
) -> CIRCUIT_TYPE:
    """Merges operations in a circuit by calling `merge_func` iteratively on operations.

    Two operations op1 and op2 are merge-able if
        - There are no other operations between op1 and op2 in the circuit
        - is_subset(op1.qubits, op2.qubits) or is_subset(op2.qubits, op1.qubits)

    The `merge_func` is a callable which, given two merge-able operations
    op1 and op2, decides whether they should be merged into a single operation
    or not. If not, it should return None, else it should return the single merged
    operation `op`.

    The method iterates on the input circuit moment-by-moment from left to right and attempts
    to repeatedly merge each operation in the latest moment with all the corresponding merge-able
    operations to its left.

    If op1 and op2 are merged, both op1 and op2 are deleted from the circuit and
    the resulting `merged_op` is inserted at the index corresponding to the larger
    of op1/op2. If both op1 and op2 act on the same number of qubits, `merged_op` is
    inserted in the smaller moment index to minimize circuit depth.

    The number of calls to `merge_func` is O(N), where N = Total no. of operations, because:
        - Every time the `merge_func` returns a new operation, the number of operations in the
          circuit reduces by 1 and hence this can happen at most O(N) times
        - Every time the `merge_func` returns None, the current operation is inserted into the
          frontier and we go on to process the next operation, which can also happen at most
          O(N) times.

    Args:
        circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
        merge_func: Callable to determine whether two merge-able operations in the circuit should
            be merged. If the operations can be merged, the callable should return the merged
            operation, else None.
        tags_to_ignore: Sequence of tags which should be ignored while applying `merge_func` on
            tagged operations -- i.e. `merge_func(op1, op2)` will be called only if both `op1` and
            `op2` satisfy `set(op.tags).isdisjoint(tags_to_ignore)`.
        deep: If true, the transformer primitive will be recursively applied to all circuits
            wrapped inside circuit operations.

    Returns:
        Copy of input circuit with merged operations. Tags of operations that are not merged
        keep their original order.

    Raises:
        ValueError if the merged operation acts on new qubits outside the set of qubits
            corresponding to the original operations to be merged.
    """
    # Internal marker for circuit operations whose sub-circuits were already merged recursively
    # (deep=True). Because it is in `tags_to_ignore_set`, such ops are never merged at this
    # level; the marker is stripped again before returning.
    _circuit_op_tag = "_internal_tag_to_mark_circuit_ops_in_circuit"
    tags_to_ignore_set = set(tags_to_ignore) | {_circuit_op_tag}

    def apply_merge_func(op1: ops.Operation, op2: ops.Operation) -> Optional[ops.Operation]:
        # Never merge when either operation carries an ignored (or the internal) tag.
        if not all(tags_to_ignore_set.isdisjoint(op.tags) for op in [op1, op2]):
            return None
        new_op = merge_func(op1, op2)
        qubit_set = frozenset(op1.qubits + op2.qubits)
        if new_op is not None and not qubit_set.issuperset(new_op.qubits):
            raise ValueError(
                f"Merged operation {new_op} must act on a subset of qubits of "
                f"original operations {op1} and {op2}"
            )
        return new_op

    merged_circuit = _MergedCircuit()
    for moment_idx, current_moment in enumerate(cast(List['cirq.Moment'], circuit)):
        merged_circuit.append_empty_moment()
        # Visit the moment's operations ordered by their qubit tuples.
        for op in sorted(current_moment.operations, key=lambda op: op.qubits):
            if (
                deep
                and isinstance(op.untagged, circuits.CircuitOperation)
                and tags_to_ignore_set.isdisjoint(op.tags)
            ):
                op_untagged = cast(circuits.CircuitOperation, op.untagged)
                merged_circuit.add_op_to_moment(
                    moment_idx,
                    op_untagged.replace(
                        circuit=merge_operations(
                            op_untagged.circuit,
                            merge_func,
                            tags_to_ignore=tags_to_ignore,
                            deep=True,
                        )
                    ).with_tags(*op.tags, _circuit_op_tag),
                )
                continue

            op_qs = set(op.qubits)
            left_idx, left_ops = merged_circuit.get_mergeable_ops(op, op_qs)
            if len(left_ops) == 1 and op_qs.issubset(left_ops[0].qubits):
                # Case-1: Try to merge op with the larger operation on the left.
                new_op = apply_merge_func(left_ops[0], op)
                if new_op is not None:
                    merged_circuit.remove_op_from_moment(left_idx, left_ops[0])
                    merged_circuit.add_op_to_moment(left_idx, new_op)
                else:
                    merged_circuit.add_op_to_moment(moment_idx, op)
                continue

            while left_ops and op_qs:
                # Case-2: left_ops will merge right into `op` whenever possible.
                for left_op in left_ops:
                    is_merged = False
                    if op_qs.issuperset(left_op.qubits):
                        # Try to merge left_op into op
                        new_op = apply_merge_func(left_op, op)
                        if new_op is not None:
                            merged_circuit.remove_op_from_moment(left_idx, left_op)
                            op, is_merged = new_op, True
                    if not is_merged:
                        # Qubits of an unmerged left op are dropped from `op_qs`, so the search
                        # does not look further left on those qubits.
                        op_qs -= frozenset(left_op.qubits)
                left_idx, left_ops = merged_circuit.get_mergeable_ops(op, op_qs)
            merged_circuit.add_op_to_moment(moment_idx, op)

    ret_circuit = merged_circuit.get_cirq_circuit()
    if deep:
        # Strip only the internal marker tag. Filtering the tuple (instead of rebuilding it from a
        # set) keeps every other tag in its original order, so untouched operations compare
        # equal to their inputs.
        ret_circuit = map_operations(
            ret_circuit,
            lambda o, _: o.untagged.with_tags(*(t for t in o.tags if t != _circuit_op_tag)),
            deep=True,
        )
    return _to_target_circuit_type(ret_circuit, circuit)
def merge_operations_to_circuit_op(
    circuit: CIRCUIT_TYPE,
    can_merge: Callable[[Sequence['cirq.Operation'], Sequence['cirq.Operation']], bool],
    *,
    tags_to_ignore: Sequence[Hashable] = (),
    merged_circuit_op_tag: str = "Merged connected component",
    deep: bool = False,
) -> CIRCUIT_TYPE:
    """Merges connected components of operations and wraps each component into a circuit operation.

    Uses `cirq.merge_operations` to identify connected components of operations. Moment structure
    is preserved for operations that do not participate in merging. For merged operations, the
    newly created circuit operations are constructed by inserting operations using EARLIEST
    strategy.

    If you need more control on moment structure of newly created circuit operations, consider
    using `cirq.merge_operations` directly with a custom `merge_func`.

    Args:
        circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
        can_merge: Callable `can_merge(left_ops, right_ops)` returning a boolean that decides
            whether the operations `right_ops` can be merged into the existing connected
            component of operations `left_ops`. An operation that is not yet part of a component
            is passed as a single-element list.
        tags_to_ignore: Tagged operations marked with any of `tags_to_ignore` will not be
            considered as potential candidates for any connected component.
        merged_circuit_op_tag: Tag to be applied on circuit operations wrapping valid connected
            components.
        deep: If true, the transformer primitive will be recursively applied to all circuits
            wrapped inside circuit operations.

    Returns:
        Copy of input circuit with valid connected components wrapped in tagged circuit operations.
    """

    def _component_ops(op: 'cirq.Operation') -> List['cirq.Operation']:
        # A circuit op produced by an earlier merge is expanded back into its operations;
        # anything else stands alone.
        inner = op.untagged
        if isinstance(inner, circuits.CircuitOperation) and merged_circuit_op_tag in op.tags:
            return list(cast(circuits.CircuitOperation, inner).circuit.all_operations())
        return [op]

    def merge_func(op1: 'cirq.Operation', op2: 'cirq.Operation') -> Optional['cirq.Operation']:
        left_ops = _component_ops(op1)
        right_ops = _component_ops(op2)
        if can_merge(left_ops, right_ops):
            component = circuits.FrozenCircuit(left_ops, right_ops)
            return circuits.CircuitOperation(component).with_tags(merged_circuit_op_tag)
        return None

    return merge_operations(circuit, merge_func, tags_to_ignore=tags_to_ignore, deep=deep)
def merge_k_qubit_unitaries_to_circuit_op(
    circuit: CIRCUIT_TYPE,
    k: int,
    *,
    tags_to_ignore: Sequence[Hashable] = (),
    merged_circuit_op_tag: Optional[str] = None,
    deep: bool = False,
) -> CIRCUIT_TYPE:
    """Merges connected components of operations, acting on <= k qubits, into circuit operations.

    Uses `cirq.merge_operations_to_circuit_op` to identify and merge connected components of
    unitary operations acting on at most k qubits. Moment structure is preserved for operations
    that do not participate in merging. For merged operations, the newly created circuit operations
    are constructed by inserting operations using EARLIEST strategy.

    Args:
        circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
        k: Merge-able operations acting on <= k qubits are merged into a connected component.
        tags_to_ignore: Tagged operations marked with any of `tags_to_ignore` will not be
            considered as potential candidates for any connected component.
        merged_circuit_op_tag: Tag to be applied on circuit operations wrapping valid connected
            components. A default tag is applied if left None.
        deep: If true, the transformer primitive will be recursively applied to all circuits
            wrapped inside circuit operations.

    Returns:
        Copy of input circuit with valid connected components wrapped in tagged circuit operations.
    """

    def can_merge(ops1: Sequence['cirq.Operation'], ops2: Sequence['cirq.Operation']) -> bool:
        # Every operation on both sides must be unitary and act on at most `k` qubits.
        for op_list in (ops1, ops2):
            for op in op_list:
                if protocols.num_qubits(op) > k or not protocols.has_unitary(op):
                    return False
        return True

    component_tag = merged_circuit_op_tag or f"Merged {k}q unitary connected component."
    return merge_operations_to_circuit_op(
        circuit,
        can_merge,
        tags_to_ignore=tags_to_ignore,
        merged_circuit_op_tag=component_tag,
        deep=deep,
    )
def merge_moments(
    circuit: CIRCUIT_TYPE,
    merge_func: Callable[[circuits.Moment, circuits.Moment], Optional[circuits.Moment]],
    *,
    tags_to_ignore: Sequence[Hashable] = (),
    deep: bool = False,
) -> CIRCUIT_TYPE:
    """Merges adjacent moments, one by one from left to right, by calling `merge_func(m1, m2)`.

    Args:
        circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
        merge_func: Callable to determine whether two adjacent moments in the circuit should be
            merged. If the moments can be merged, the callable should return the merged moment,
            else None.
        tags_to_ignore: Tagged circuit operations marked with any of `tags_to_ignore` will be
            ignored when recursively applying the transformer primitive to sub-circuits, given
            deep=True. This applies at every nesting depth.
        deep: If true, the transformer primitive will be recursively applied to all circuits
            wrapped inside circuit operations.

    Returns:
        Copy of input circuit with merged moments.
    """
    if not circuit:
        return circuit

    if deep:

        def merge_nested(op: 'cirq.Operation', _: int) -> 'cirq.Operation':
            op_untagged = op.untagged
            if not isinstance(op_untagged, circuits.CircuitOperation):
                return op
            # Forward `tags_to_ignore` so circuit operations nested deeper than one level that
            # carry an ignored tag are skipped too (consistent with `map_moments`).
            merged_sub_circuit = merge_moments(
                op_untagged.circuit, merge_func, tags_to_ignore=tags_to_ignore, deep=True
            )
            return op_untagged.replace(circuit=merged_sub_circuit).with_tags(*op.tags)

        circuit = map_operations(circuit, merge_nested, tags_to_ignore=tags_to_ignore)

    merged_moments: List[circuits.Moment] = [circuit[0]]
    for current_moment in circuit[1:]:
        merged_moment = merge_func(merged_moments[-1], current_moment)
        if merged_moment is None:
            merged_moments.append(current_moment)
        else:
            # Keep merging into the running moment so chains of merges collapse left-to-right.
            merged_moments[-1] = merged_moment
    return _create_target_circuit_type(merged_moments, circuit)
def unroll_circuit_op(
    circuit: CIRCUIT_TYPE,
    *,
    deep: bool = False,
    tags_to_check: Optional[Sequence[Hashable]] = (MAPPED_CIRCUIT_OP_TAG,),
) -> CIRCUIT_TYPE:
    """Unrolls (tagged) `cirq.CircuitOperation`s while preserving the moment structure.

    Each moment is replaced by the moments obtained by zipping, side by side, one circuit per
    operation in that moment: the unrolled sub-circuit for a matching circuit operation, and a
    single-operation circuit for everything else.

    Args:
        circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
        deep: If true, the transformer primitive will be recursively applied to all circuits
            wrapped inside circuit operations.
        tags_to_check: If specified, only circuit operations tagged with one of the `tags_to_check`
            are unrolled. If None, every circuit operation is unrolled.

    Returns:
        Copy of input circuit with (Tagged) CircuitOperation's expanded in a moment preserving way.
    """

    def unroll_moment(moment: circuits.Moment, _: int):
        pieces: List['cirq.AbstractCircuit'] = []
        for op in moment:
            inner = op.untagged
            if not isinstance(inner, circuits.CircuitOperation):
                pieces.append(circuits.Circuit(op))
                continue
            if deep:
                nested = unroll_circuit_op(inner.circuit, deep=deep, tags_to_check=tags_to_check)
                inner = inner.replace(circuit=nested)
            if tags_to_check is None or set(tags_to_check).intersection(op.tags):
                pieces.append(inner.mapped_circuit())
            else:
                # Not selected for unrolling: keep it as a single op, restoring its tags.
                pieces.append(circuits.Circuit(inner.with_tags(*op.tags)))
        # NOTE(review): for a moment with no operations `Circuit.zip` receives no circuits;
        # confirm whether that empty moment is meant to be kept in the output.
        return circuits.Circuit.zip(*pieces).moments

    return map_moments(circuit, unroll_moment)
def unroll_circuit_op_greedy_earliest(
    circuit: CIRCUIT_TYPE,
    *,
    deep: bool = False,
    tags_to_check: Optional[Sequence[Hashable]] = (MAPPED_CIRCUIT_OP_TAG,),
) -> CIRCUIT_TYPE:
    """Unrolls (tagged) `cirq.CircuitOperation`s by inserting operations using EARLIEST strategy.

    Each matching `cirq.CircuitOperation` is removed and its underlying operations are inserted
    back at the same moment index with `cirq.Circuit.batch_insert`. The greedy approach attempts
    to minimize the depth of the resulting circuit.

    Args:
        circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
        deep: If true, the transformer primitive will be recursively applied to all circuits
            wrapped inside circuit operations.
        tags_to_check: If specified, only circuit operations tagged with one of the `tags_to_check`
            are unrolled. If None, every circuit operation is unrolled.

    Returns:
        Copy of input circuit with (Tagged) CircuitOperation's expanded using EARLIEST strategy.
    """
    to_replace = []
    to_remove = []
    to_insert = []
    circuit_ops = circuit.findall_operations(
        lambda o: isinstance(o.untagged, circuits.CircuitOperation)
    )
    for moment_index, op in circuit_ops:
        sub_op = cast(circuits.CircuitOperation, op.untagged)
        if deep:
            nested = unroll_circuit_op_greedy_earliest(
                sub_op.circuit, deep=deep, tags_to_check=tags_to_check
            )
            sub_op = sub_op.replace(circuit=nested)
        selected = tags_to_check is None or set(tags_to_check).intersection(op.tags)
        if selected:
            to_remove.append((moment_index, op))
            to_insert.append((moment_index, sub_op.mapped_circuit().all_operations()))
        elif deep:
            # Not unrolled here, but its sub-circuit may have changed; swap it in with its tags.
            to_replace.append((moment_index, op, sub_op.with_tags(*op.tags)))
    # All edits are applied to a copy, in the order replace -> remove -> insert.
    result = circuit.unfreeze(copy=True)
    result.batch_replace(to_replace)
    result.batch_remove(to_remove)
    result.batch_insert(to_insert)
    return _to_target_circuit_type(result, circuit)
def unroll_circuit_op_greedy_frontier(
    circuit: CIRCUIT_TYPE,
    *,
    deep: bool = False,
    tags_to_check: Optional[Sequence[Hashable]] = (MAPPED_CIRCUIT_OP_TAG,),
) -> CIRCUIT_TYPE:
    """Unrolls (tagged) `cirq.CircuitOperation`s by inserting operations inline at qubit frontier.

    Each matching `cirq.CircuitOperation` is replaced by inserting underlying operations using the
    `circuit.insert_at_frontier` method. The greedy approach attempts to reuse any available space
    in existing moments on the right of circuit_op before inserting new moments.

    Args:
        circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
        deep: If true, the transformer primitive will be recursively applied to all circuits
            wrapped inside circuit operations.
        tags_to_check: If specified, only circuit operations tagged with one of the `tags_to_check`
            are unrolled. If None, every circuit operation is unrolled.

    Returns:
        Copy of input circuit with (Tagged) CircuitOperation's expanded inline at qubit frontier.
    """
    # All mutations below happen on a copy; the input circuit is left untouched.
    unrolled_circuit = circuit.unfreeze(copy=True)
    # Per-qubit moment index returned/updated by `insert_at_frontier`; unseen qubits read as 0.
    # Presumably the first moment at which each qubit is free again -- confirm against the
    # `insert_at_frontier` contract.
    frontier: Dict['cirq.Qid', int] = defaultdict(lambda: 0)
    idx = 0
    # The length is re-read on every iteration because the circuit is mutated inside the loop.
    while idx < len(unrolled_circuit):
        # `unrolled_circuit[idx].operations` is fetched once, before the mutations below.
        for op in unrolled_circuit[idx].operations:
            # Only circuit operations are candidates for unrolling.
            if not isinstance(op.untagged, circuits.CircuitOperation):
                continue
            # Don't touch stuff inserted by unrolling previous circuit ops: skip any op whose
            # qubits the frontier has already advanced past `idx`.
            if any(frontier[q] > idx for q in op.qubits):
                continue
            op_untagged = op.untagged
            if deep:
                op_untagged = op_untagged.replace(
                    circuit=unroll_circuit_op_greedy_frontier(
                        op_untagged.circuit, deep=deep, tags_to_check=tags_to_check
                    )
                )
            if tags_to_check is None or set(tags_to_check).intersection(op.tags):
                # Remove the circuit op from moment `idx`, then insert its (mapped) contents
                # starting at `idx`, threading the frontier through successive calls.
                unrolled_circuit.clear_operations_touching(op.qubits, [idx])
                frontier = unrolled_circuit.insert_at_frontier(
                    op_untagged.mapped_circuit().all_operations(), idx, frontier
                )
            elif deep:
                # NOTE(review): if an earlier op in this same moment was unrolled and
                # `insert_at_frontier` inserted new moments, confirm `idx` still addresses the
                # moment that holds `op` before this replacement.
                unrolled_circuit.batch_replace([(idx, op, op_untagged.with_tags(*op.tags))])
        idx += 1
    return _to_target_circuit_type(unrolled_circuit, circuit)
def toggle_tags(circuit: CIRCUIT_TYPE, tags: Sequence[Hashable], *, deep: bool = False):
    """Toggles tags applied on each operation in the circuit, via `op.tags ^= tags`.

    For every operation `op` in the input circuit, the tags on `op` are replaced by the symmetric
    difference of `op.tags` and `tags` -- this is useful in scenarios where you mark a small subset
    of operations with a specific tag and then toggle the set of marked operations s.t. every
    marked operation is now unmarked and vice versa.

    Often used in transformer workflows to apply a transformer on a small subset of operations.

    The resulting tag order is deterministic: surviving tags of `op` keep their original relative
    order, followed by the newly added tags in the order they appear in `tags`. Duplicate tags
    are collapsed, as with set semantics.

    Args:
        circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
        tags: Sequence of tags s.t. `op.tags ^= tags` is done for every operation `op` in circuit.
        deep: If true, tags will be recursively toggled for operations in circuits wrapped inside
            any circuit operations contained within `circuit`.

    Returns:
        Copy of transformed input circuit with operation sets marked with `tags` toggled.
    """
    tags_to_xor = set(tags)
    # Order-preserving, de-duplicated copy of `tags`, used when appending newly added tags.
    ordered_tags_to_xor = list(dict.fromkeys(tags))

    def map_func(op: 'cirq.Operation', _) -> 'cirq.Operation':
        if deep and isinstance(op, circuits.CircuitOperation):
            return op
        # Build the symmetric difference without going through set iteration order, which is
        # hash-dependent and would make the resulting tag tuple (and op equality) unstable.
        current_tags = list(dict.fromkeys(op.tags))
        current_tag_set = set(current_tags)
        kept = [t for t in current_tags if t not in tags_to_xor]
        added = [t for t in ordered_tags_to_xor if t not in current_tag_set]
        return op.untagged.with_tags(*kept, *added)

    return map_operations(circuit, map_func, deep=deep)