Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pretty repr to devices and result types #4649

Merged
merged 7 commits into from
Nov 11, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions cirq-core/cirq/contrib/quimb/mps_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,14 @@ def __str__(self) -> str:
final = self._final_simulator_state
return f'measurements: {samples}\noutput state: {final}'

def _repr_pretty_(self, p: Any, cycle: bool):
"""iPython (Jupyter) pretty print."""
if cycle:
# There should never be a cycle. This is just in case.
p.text('cirq.MPSTrialResult(...)')
else:
p.text(str(self))


class MPSSimulatorStepResult(simulator_base.StepResultBase['MPSState', 'MPSState']):
"""A `StepResult` that can perform measurements."""
Expand Down Expand Up @@ -201,6 +209,10 @@ def bitstring(vals):

return f'{measurements}{final}'

def _repr_pretty_(self, p: Any, cycle: bool):
"""iPython (Jupyter) pretty print."""
p.text("cirq.MPSSimulatorStepResult(...)" if cycle else self.__str__())

def _simulator_state(self):
return self.state

Expand Down
38 changes: 38 additions & 0 deletions cirq-core/cirq/contrib/quimb/mps_simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import cirq
import cirq.contrib.quimb as ccq
import cirq.experiments.google_v2_supremacy_circuit as supremacy_v2
import cirq.testing
from cirq import value


Expand Down Expand Up @@ -275,6 +276,29 @@ def test_trial_result_str():
)


def test_trial_result_repr_pretty():
q0 = cirq.LineQubit(0)
final_step_result = mock.Mock(cirq.StepResult)
final_step_result._simulator_state.return_value = ccq.mps_simulator.MPSState(
qubits=(q0,),
prng=value.parse_random_state(0),
simulation_options=ccq.mps_simulator.MPSOptions(),
)
result = ccq.mps_simulator.MPSTrialResult(
params=cirq.ParamResolver({}),
measurements={'m': np.array([[1]])},
final_step_result=final_step_result,
)
cirq.testing.assert_repr_pretty(
result,
"""measurements: m=1
output state: TensorNetwork([
Tensor(shape=(2,), inds=('i_0',), tags=set()),
])""",
)
cirq.testing.assert_repr_pretty(result, "cirq.MPSTrialResult(...)", cycle=True)


def test_empty_step_result():
q0 = cirq.LineQubit(0)
sim = ccq.mps_simulator.MPSSimulator()
Expand All @@ -288,6 +312,20 @@ def test_empty_step_result():
)


def test_step_result_repr_pretty():
q0 = cirq.LineQubit(0)
sim = ccq.mps_simulator.MPSSimulator()
step_result = next(sim.simulate_moment_steps(cirq.Circuit(cirq.measure(q0))))
cirq.testing.assert_repr_pretty(
step_result,
"""0=0
TensorNetwork([
Tensor(shape=(2,), inds=('i_0',), tags=set()),
])""",
)
cirq.testing.assert_repr_pretty(step_result, "cirq.MPSSimulatorStepResult(...)", cycle=True)


def test_state_equal():
q0, q1 = cirq.LineQubit.range(2)
state0 = ccq.mps_simulator.MPSState(
Expand Down
4 changes: 4 additions & 0 deletions cirq-core/cirq/ion/ion_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,10 @@ def __str__(self) -> str:

return diagram.render(horizontal_spacing=3, vertical_spacing=2, use_unicode_characters=True)

def _repr_pretty_(self, p: Any, cycle: bool):
"""iPython (Jupyter) pretty print."""
p.text("IonDevice(...)" if cycle else self.__str__())

def _value_equality_values_(self) -> Any:
return (
self._measurement_duration,
Expand Down
13 changes: 7 additions & 6 deletions cirq-core/cirq/ion/ion_device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import cirq
import cirq.ion as ci
import cirq.testing


def ion_device(chain_length: int, use_timedelta=False) -> ci.IonDevice:
Expand Down Expand Up @@ -183,12 +184,12 @@ def test_validate_circuit_repeat_measurement_keys():


def test_ion_device_str():
assert (
str(ion_device(3)).strip()
== """
0───1───2
""".strip()
)
assert str(ion_device(3)) == "0───1───2"


def test_ion_device_pretty_repr():
cirq.testing.assert_repr_pretty(ion_device(3), "0───1───2")
cirq.testing.assert_repr_pretty(ion_device(3), "IonDevice(...)", cycle=True)


def test_at():
Expand Down
4 changes: 4 additions & 0 deletions cirq-core/cirq/neutral_atoms/neutral_atom_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,3 +358,7 @@ def __str__(self) -> str:
diagram.grid_line(q.col, q.row, q2.col, q2.row)

return diagram.render(horizontal_spacing=3, vertical_spacing=2, use_unicode_characters=True)

def _repr_pretty_(self, p: Any, cycle: bool):
"""iPython (Jupyter) pretty print."""
p.text("cirq.NeutralAtomDevice(...)" if cycle else self.__str__())
14 changes: 14 additions & 0 deletions cirq-core/cirq/neutral_atoms/neutral_atom_devices_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import cirq
import cirq.neutral_atoms as neutral_atoms
import cirq.testing


def square_device(
Expand Down Expand Up @@ -266,5 +267,18 @@ def test_str():
)


def test_repr_pretty():
cirq.testing.assert_repr_pretty(
square_device(2, 2),
"""
(0, 0)───(0, 1)
│ │
│ │
(1, 0)───(1, 1)
""".strip(),
)
cirq.testing.assert_repr_pretty(square_device(2, 2), "cirq.NeutralAtomDevice(...)", cycle=True)


def test_qubit_set():
assert square_device(2, 2).qubit_set() == frozenset(cirq.GridQubit.square(2, 0, 0))
8 changes: 8 additions & 0 deletions cirq-core/cirq/sim/clifford/clifford_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,10 @@ def __str__(self) -> str:
final = self._final_simulator_state
return f'measurements: {samples}\noutput state: {final}'

def _repr_pretty_(self, p: Any, cycle: bool):
"""iPython (Jupyter) pretty print."""
p.text("cirq.CliffordTrialResult(...)" if cycle else self.__str__())


class CliffordSimulatorStepResult(
simulator_base.StepResultBase['clifford.CliffordState', 'clifford.ActOnStabilizerCHFormArgs']
Expand Down Expand Up @@ -168,6 +172,10 @@ def bitstring(vals):

return f'{measurements}{final}'

def _repr_pretty_(self, p, cycle):
"""iPython (Jupyter) pretty print."""
p.text("cirq.CliffordSimulatorStateResult(...)" if cycle else self.__str__())

@property
def state(self):
if self._clifford_state is None:
Expand Down
23 changes: 23 additions & 0 deletions cirq-core/cirq/sim/clifford/clifford_simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,20 @@ def test_clifford_trial_result_str():
)


def test_clifford_trial_result_repr_pretty():
q0 = cirq.LineQubit(0)
final_step_result = mock.Mock(cirq.CliffordSimulatorStepResult)
final_step_result._simulator_state.return_value = cirq.CliffordState(qubit_map={q0: 0})
result = cirq.CliffordTrialResult(
params=cirq.ParamResolver({}),
measurements={'m': np.array([[1]])},
final_step_result=final_step_result,
)

cirq.testing.assert_repr_pretty(result, "measurements: m=1\n" "output state: |0⟩")
cirq.testing.assert_repr_pretty(result, "cirq.CliffordTrialResult(...)", cycle=True)


def test_clifford_step_result_str():
q0 = cirq.LineQubit(0)
result = next(
Expand All @@ -252,6 +266,15 @@ def test_clifford_step_result_str():
assert str(result) == "m=0\n" "|0⟩"


def test_clifford_step_result_repr_pretty():
q0 = cirq.LineQubit(0)
result = next(
cirq.CliffordSimulator().simulate_moment_steps(cirq.Circuit(cirq.measure(q0, key='m')))
)
cirq.testing.assert_repr_pretty(result, "m=0\n" "|0⟩")
cirq.testing.assert_repr_pretty(result, "cirq.CliffordSimulatorStateResult(...)", cycle=True)


def test_clifford_step_result_no_measurements_str():
q0 = cirq.LineQubit(0)
result = next(cirq.CliffordSimulator().simulate_moment_steps(cirq.Circuit(cirq.I(q0))))
Expand Down
4 changes: 4 additions & 0 deletions cirq-core/cirq/sim/density_matrix_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,3 +465,7 @@ def __repr__(self) -> str:
f'params={self.params!r}, measurements={self.measurements!r}, '
f'final_simulator_state={self._final_simulator_state!r})'
)

def _repr_pretty_(self, p: Any, cycle: bool):
"""iPython (Jupyter) pretty print."""
p.text("cirq.DensityMatrixTrialResult(...)" if cycle else self.__str__())
23 changes: 23 additions & 0 deletions cirq-core/cirq/sim/density_matrix_simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import sympy

import cirq
import cirq.testing


class PlusGate(cirq.Gate):
Expand Down Expand Up @@ -1188,6 +1189,28 @@ def test_density_matrix_trial_result_str():
)


def test_density_matrix_trial_result_repr_pretty():
q0 = cirq.LineQubit(0)
final_step_result = mock.Mock(cirq.StepResult)
final_step_result._simulator_state.return_value = cirq.DensityMatrixSimulatorState(
density_matrix=np.ones((2, 2)) * 0.5, qubit_map={q0: 0}
)
result = cirq.DensityMatrixTrialResult(
params=cirq.ParamResolver({}), measurements={}, final_step_result=final_step_result
)

fake_printer = cirq.testing.FakePrinter()
result._repr_pretty_(fake_printer, cycle=False)
# numpy varies whitespace in its representation for different versions
# Eliminate whitespace to harden tests against this variation
result_no_whitespace = fake_printer.text_pretty.replace('\n', '').replace(' ', '')
assert result_no_whitespace == (
'measurements:(nomeasurements)finaldensitymatrix:[[0.50.5][0.50.5]]'
)

cirq.testing.assert_repr_pretty(result, "cirq.DensityMatrixTrialResult(...)", cycle=True)


def test_run_sweep_parameters_not_resolved():
a = cirq.LineQubit(0)
simulator = cirq.DensityMatrixSimulator()
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/sim/state_vector_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,8 @@ def __str__(self) -> str:
state_vector = str(final)
return f'measurements: {samples}\noutput vector: {state_vector}'

def _repr_pretty_(self, p: Any, cycle: bool) -> None:
"""Text output in Jupyter."""
def _repr_pretty_(self, p: Any, cycle: bool):
"""iPython (Jupyter) pretty print."""
if cycle:
# There should never be a cycle. This is just in case.
p.text('StateVectorTrialResult(...)')
Expand Down
41 changes: 25 additions & 16 deletions cirq-core/cirq/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,6 @@
assert_all_implemented_act_on_effects_match_unitary,
)

from cirq.testing.consistent_phase_by import (
assert_phase_by_is_consistent_with_unitary,
)

from cirq.testing.consistent_controlled_gate_op import (
assert_controlled_and_controlled_by_identical,
)
Expand All @@ -44,6 +40,10 @@
assert_pauli_expansion_is_consistent_with_unitary,
)

from cirq.testing.consistent_phase_by import (
assert_phase_by_is_consistent_with_unitary,
)

from cirq.testing.consistent_protocols import (
assert_eigengate_implements_consistent_protocols,
assert_has_consistent_trace_distance_bound,
Expand All @@ -63,6 +63,10 @@
assert_specifies_has_unitary_if_unitary,
)

from cirq.testing.deprecation import (
assert_deprecated,
)

from cirq.testing.devices import (
ValidatingTestDevice,
)
Expand All @@ -71,11 +75,18 @@
EqualsTester,
)

from cirq.testing.equivalent_basis_map import (
assert_equivalent_computational_basis_map,
)

from cirq.testing.equivalent_repr_eval import (
assert_equivalent_repr,
)

from cirq.testing.equivalent_basis_map import assert_equivalent_computational_basis_map
from cirq.testing.gate_features import (
TwoQubitGate,
ThreeQubitGate,
)

from cirq.testing.json import (
assert_json_roundtrip_works,
Expand All @@ -95,15 +106,14 @@
assert_logs,
)

from cirq.testing.gate_features import (
TwoQubitGate,
ThreeQubitGate,
)

from cirq.testing.no_identifier_qubit import (
NoIdentifierQubit,
)

from cirq.testing.op_tree import (
assert_equivalent_op_tree,
)

from cirq.testing.order_tester import (
OrderTester,
)
Expand All @@ -114,12 +124,11 @@
random_two_qubit_circuit_with_czs,
)

from cirq.testing.sample_circuits import (
nonoptimal_toffoli_circuit,
from cirq.testing.repr_pretty_tester import (
assert_repr_pretty,
FakePrinter,
)

from cirq.testing.deprecation import (
assert_deprecated,
from cirq.testing.sample_circuits import (
nonoptimal_toffoli_circuit,
)

from cirq.testing.op_tree import assert_equivalent_op_tree
5 changes: 3 additions & 2 deletions cirq-core/cirq/testing/deprecation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
import os
from typing import Optional

from cirq.testing import assert_logs

ALLOW_DEPRECATION_IN_TEST = 'ALLOW_DEPRECATION_IN_TEST'


Expand All @@ -42,6 +40,9 @@ def __enter__(self):
os.environ.get(ALLOW_DEPRECATION_IN_TEST, None),
)
os.environ[ALLOW_DEPRECATION_IN_TEST] = 'True'
# Avoid circular import.
from cirq.testing import assert_logs

self.assert_logs = assert_logs(
*(msgs + (deadline,)),
min_level=logging.WARNING,
Expand Down
Loading