Skip to content

Commit

Permalink
Support serialization of GateFamilies (quantumlib#4715)
Browse files Browse the repository at this point in the history
Exactly what it says: this PR uses the new type-serialization behavior to allow GateFamilies to be serialized.

Additional work is still required for Gateset serialization - that behavior is not included in this PR.
  • Loading branch information
95-martin-orion authored Dec 1, 2021
1 parent 196318f commit 39ffd16
Show file tree
Hide file tree
Showing 13 changed files with 126 additions and 4 deletions.
4 changes: 4 additions & 0 deletions cirq/json_resolver_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ def _parallel_gate_op(gate, qubits):
import sympy

return {
'AnyIntegerPowerGateFamily': cirq.AnyIntegerPowerGateFamily,
'AmplitudeDampingChannel': cirq.AmplitudeDampingChannel,
'AnyUnitaryGateFamily': cirq.AnyUnitaryGateFamily,
'AsymmetricDepolarizingChannel': cirq.AsymmetricDepolarizingChannel,
'BitFlipChannel': cirq.BitFlipChannel,
'BitstringAccumulator': cirq.work.BitstringAccumulator,
Expand Down Expand Up @@ -81,6 +83,7 @@ def _parallel_gate_op(gate, qubits):
'MutableDensePauliString': cirq.MutableDensePauliString,
'MutablePauliString': cirq.MutablePauliString,
'ObservableMeasuredResult': cirq.work.ObservableMeasuredResult,
'GateFamily': cirq.GateFamily,
'GateOperation': cirq.GateOperation,
'GeneralizedAmplitudeDampingChannel': cirq.GeneralizedAmplitudeDampingChannel,
'GlobalPhaseOperation': cirq.GlobalPhaseOperation,
Expand Down Expand Up @@ -115,6 +118,7 @@ def _parallel_gate_op(gate, qubits):
'_PauliZ': cirq.ops.pauli_gates._PauliZ,
'ParamResolver': cirq.ParamResolver,
'ParallelGate': cirq.ParallelGate,
'ParallelGateFamily': cirq.ParallelGateFamily,
'PauliMeasurementGate': cirq.PauliMeasurementGate,
'PauliString': cirq.PauliString,
'PhaseDampingChannel': cirq.PhaseDampingChannel,
Expand Down
32 changes: 32 additions & 0 deletions cirq/ops/common_gate_families.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@ def __repr__(self) -> str:
def _value_equality_values_(self) -> Any:
return self._num_qubits

def _json_dict_(self):
return {'num_qubits': self._num_qubits}

@classmethod
def _from_json_dict_(cls, num_qubits, **kwargs):
return cls(num_qubits)


class AnyIntegerPowerGateFamily(gateset.GateFamily):
"""GateFamily which accepts instances of a given `cirq.EigenGate`, raised to integer power."""
Expand Down Expand Up @@ -87,6 +94,15 @@ def __repr__(self) -> str:
def _value_equality_values_(self) -> Any:
return self.gate

def _json_dict_(self):
return {'gate': self._gate_json()}

@classmethod
def _from_json_dict_(cls, gate, **kwargs):
if isinstance(gate, str):
gate = protocols.cirq_type_from_json(gate)
return cls(gate)


class ParallelGateFamily(gateset.GateFamily):
"""GateFamily which accepts instances of `cirq.ParallelGate` and it's sub_gate.
Expand Down Expand Up @@ -175,3 +191,19 @@ def __repr__(self) -> str:
def _value_equality_values_(self) -> Any:
# `isinstance` is used to ensure the a gate type and gate instance is not compared.
return super()._value_equality_values_() + (self._max_parallel_allowed,)

def _json_dict_(self):
return {
'gate': self._gate_json(),
'name': self.name,
'description': self.description,
'max_parallel_allowed': self._max_parallel_allowed,
}

@classmethod
def _from_json_dict_(cls, gate, name, description, max_parallel_allowed, **kwargs):
if isinstance(gate, str):
gate = protocols.cirq_type_from_json(gate)
return cls(
gate, name=name, description=description, max_parallel_allowed=max_parallel_allowed
)
19 changes: 19 additions & 0 deletions cirq/ops/gateset.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ def __init__(
def _gate_str(self, gettr: Callable[[Any], str] = str) -> str:
return _gate_str(self.gate, gettr)

def _gate_json(self) -> Union[raw_types.Gate, str]:
return self.gate if not isinstance(self.gate, type) else protocols.json_cirq_type(self.gate)

def _default_name(self) -> str:
family_type = 'Instance' if isinstance(self.gate, raw_types.Gate) else 'Type'
return f'{family_type} GateFamily: {self._gate_str()}'
Expand Down Expand Up @@ -167,6 +170,22 @@ def _value_equality_values_(self) -> Any:
self._ignore_global_phase,
)

def _json_dict_(self):
return {
'gate': self._gate_json(),
'name': self.name,
'description': self.description,
'ignore_global_phase': self._ignore_global_phase,
}

@classmethod
def _from_json_dict_(cls, gate, name, description, ignore_global_phase, **kwargs):
if isinstance(gate, str):
gate = protocols.cirq_type_from_json(gate)
return cls(
gate, name=name, description=description, ignore_global_phase=ignore_global_phase
)


@value.value_equality()
class Gateset:
Expand Down
8 changes: 8 additions & 0 deletions cirq/ops/gateset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,14 @@ def test_gate_family_repr_and_str(gate, name, description):
assert g.description in str(g)


@pytest.mark.parametrize('gate', [cirq.X, cirq.XPowGate(), cirq.XPowGate])
@pytest.mark.parametrize('name, description', [(None, None), ('custom_name', 'custom_description')])
def test_gate_family_json(gate, name, description):
g = cirq.GateFamily(gate, name=name, description=description)
g_json = cirq.to_json(g)
assert cirq.read_json(json_text=g_json) == g


def test_gate_family_eq():
eq = cirq.testing.EqualsTester()
eq.add_equality_group(cirq.GateFamily(CustomX))
Expand Down
4 changes: 4 additions & 0 deletions cirq/protocols/json_test_data/AnyIntegerPowerGateFamily.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"cirq_type": "AnyIntegerPowerGateFamily",
"gate": "XPowGate"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
cirq.AnyIntegerPowerGateFamily(cirq.ops.common_gates.XPowGate)
4 changes: 4 additions & 0 deletions cirq/protocols/json_test_data/AnyUnitaryGateFamily.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"cirq_type": "AnyUnitaryGateFamily",
"num_qubits": 2
}
1 change: 1 addition & 0 deletions cirq/protocols/json_test_data/AnyUnitaryGateFamily.repr
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
cirq.AnyUnitaryGateFamily(num_qubits = 2)
20 changes: 20 additions & 0 deletions cirq/protocols/json_test_data/GateFamily.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
[
{
"cirq_type": "GateFamily",
"gate": "XPowGate",
"name": "Type GateFamily: cirq.ops.common_gates.XPowGate",
"description": "Accepts `cirq.Gate` instances `g` s.t. `isinstance(g, cirq.ops.common_gates.XPowGate)`",
"ignore_global_phase": true
},
{
"cirq_type": "GateFamily",
"gate": {
"cirq_type": "_PauliX",
"exponent": 1.0,
"global_shift": 0.0
},
"name": "XFamily",
"description": "Just the X gate.",
"ignore_global_phase": false
}
]
4 changes: 4 additions & 0 deletions cirq/protocols/json_test_data/GateFamily.repr
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[
cirq.GateFamily(gate=cirq.ops.common_gates.XPowGate, ignore_global_phase=True),
cirq.GateFamily(gate=cirq.X, name="XFamily", description="Just the X gate.", ignore_global_phase=False)
]
20 changes: 20 additions & 0 deletions cirq/protocols/json_test_data/ParallelGateFamily.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
[
{
"cirq_type": "ParallelGateFamily",
"gate": "XPowGate",
"name": "INF Parallel Type GateFamily: cirq.ops.common_gates.XPowGate",
"description": "Accepts\n1. `cirq.Gate` instances `g` s.t. `isinstance(g, cirq.ops.common_gates.XPowGate)` OR\n2. `cirq.ParallelGate` instance `g` s.t. `g.sub_gate` satisfies 1. and `cirq.num_qubits(g) <= INF` OR\n3. `cirq.Operation` instance `op` s.t. `op.gate` satisfies 1. or 2.",
"max_parallel_allowed": null
},
{
"cirq_type": "ParallelGateFamily",
"gate": {
"cirq_type": "_PauliX",
"exponent": 1.0,
"global_shift": 0.0
},
"name": "ParallelXFamily",
"description": "Up to 4 parallel X gates",
"max_parallel_allowed": 4
}
]
9 changes: 9 additions & 0 deletions cirq/protocols/json_test_data/ParallelGateFamily.repr
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
[
cirq.ParallelGateFamily(gate=cirq.ops.common_gates.XPowGate, max_parallel_allowed=None),
cirq.ParallelGateFamily(
gate=cirq.X,
name="ParallelXFamily",
description=r'''Up to 4 parallel X gates''',
max_parallel_allowed=4
)
]
4 changes: 0 additions & 4 deletions cirq/protocols/json_test_data/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@
resolver_cache=_class_resolver_dictionary(),
not_yet_serializable=[
'Alignment',
'AnyIntegerPowerGateFamily',
'AnyUnitaryGateFamily',
'AxisAngleDecomposition',
'CircuitDag',
'CircuitDiagramInfo',
Expand All @@ -39,7 +37,6 @@
'DensityMatrixStepResult',
'DensityMatrixTrialResult',
'ExpressionMap',
'GateFamily',
'Gateset',
'InsertStrategy',
'IonDevice',
Expand All @@ -50,7 +47,6 @@
'ListSweep',
'DiagonalGate',
'NeutralAtomDevice',
'ParallelGateFamily',
'PauliInteractionGate',
'PauliStringPhasor',
'PauliSum',
Expand Down

0 comments on commit 39ffd16

Please sign in to comment.