Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve type safety with generics on simulators #3818

Merged
merged 10 commits into from
Feb 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions cirq/contrib/quimb/mps_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,14 @@ class MPSOptions:
sum_prob_atol: float = 1e-3


class MPSSimulator(simulator.SimulatesSamples, simulator.SimulatesIntermediateState):
class MPSSimulator(
simulator.SimulatesSamples,
simulator.SimulatesIntermediateState[
'cirq.contrib.quimb.mps_simulator.MPSSimulatorStepResult',
'cirq.contrib.quimb.mps_simulator.MPSTrialResult',
'cirq.contrib.quimb.mps_simulator.MPSState',
],
):
"""An efficient simulator for MPS circuits."""

def __init__(
Expand Down Expand Up @@ -247,7 +254,7 @@ def __str__(self) -> str:
return f'measurements: {samples}\noutput state: {final}'


class MPSSimulatorStepResult(simulator.StepResult):
class MPSSimulatorStepResult(simulator.StepResult['MPSState']):
"""A `StepResult` that can perform measurements."""

def __init__(self, state, measurements):
Expand Down
4 changes: 1 addition & 3 deletions cirq/experiments/xeb_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
List,
Optional,
Sequence,
cast,
TYPE_CHECKING,
Dict,
Any,
Expand Down Expand Up @@ -73,8 +72,7 @@ def __call__(self, task: _Simulate2qXEBTask) -> List[Dict[str, Any]]:
if cycle_depth not in cycle_depths:
continue

psi = cast(sim.SparseSimulatorStep, step_result)
psi = psi.state_vector()
psi = step_result.state_vector()
pure_probs = np.abs(psi) ** 2

records += [
Expand Down
4 changes: 2 additions & 2 deletions cirq/experiments/xeb_simulation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
import multiprocessing
import time
from typing import Dict, Any, cast, Optional
from typing import Dict, Any, Optional
from typing import Sequence

import numpy as np
Expand Down Expand Up @@ -89,7 +89,7 @@ def _ref_simulate_2q_xeb_circuit(task: Dict[str, Any]):
tcircuit = cirq.resolve_parameters_once(tcircuit, param_resolver=param_resolver)

pure_sim = cirq.Simulator()
psi = cast(cirq.StateVectorTrialResult, pure_sim.simulate(tcircuit))
psi = pure_sim.simulate(tcircuit)
psi = psi.final_state_vector
pure_probs = np.abs(psi) ** 2

Expand Down
8 changes: 3 additions & 5 deletions cirq/google/calibration/engine_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
Sequence,
Tuple,
Union,
cast,
)

import numpy as np
Expand All @@ -30,8 +29,7 @@
Simulator,
SimulatesSamples,
SimulatesIntermediateStateVector,
SparseSimulatorStep,
StepResult,
StateVectorStepResult,
)
from cirq.study import ParamResolver
from cirq.value import RANDOM_STATE_OR_SEED_LIKE, parse_random_state
Expand Down Expand Up @@ -319,7 +317,7 @@ def create_from_characterizations_sqrt_iswap(

def final_state_vector(self, program: Circuit) -> np.array:
result = self.simulate(program)
return cast(SparseSimulatorStep, result).state_vector()
return result.state_vector()

def get_calibrations(
self, requests: Sequence[PhasedFSimCalibrationRequest]
Expand Down Expand Up @@ -402,7 +400,7 @@ def _base_iterator(
circuit: Circuit,
qubit_order: QubitOrderOrList,
initial_state: Any,
) -> Iterator[StepResult]:
) -> Iterator[StateVectorStepResult]:
converted = _convert_to_circuit_with_drift(self, circuit)
return self._simulator._base_iterator(converted, qubit_order, initial_state)

Expand Down
9 changes: 7 additions & 2 deletions cirq/sim/clifford/clifford_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,12 @@
from cirq.sim.simulator import check_all_resolved


class CliffordSimulator(simulator.SimulatesSamples, simulator.SimulatesIntermediateState):
class CliffordSimulator(
simulator.SimulatesSamples,
simulator.SimulatesIntermediateState[
'CliffordSimulatorStepResult', 'CliffordTrialResult', 'CliffordState'
],
):
"""An efficient simulator for Clifford circuits."""

def __init__(self, seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None):
Expand Down Expand Up @@ -168,7 +173,7 @@ def __str__(self) -> str:
return f'measurements: {samples}\noutput state: {final}'


class CliffordSimulatorStepResult(simulator.StepResult):
class CliffordSimulatorStepResult(simulator.StepResult['CliffordState']):
"""A `StepResult` that includes `StateVectorMixin` methods."""

def __init__(self, state, measurements):
Expand Down
9 changes: 7 additions & 2 deletions cirq/sim/density_matrix_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,12 @@ def __init__(self, num_qubits: int, tensor: np.ndarray):
self.buffers = [np.empty_like(tensor) for _ in range(3)]


class DensityMatrixSimulator(simulator.SimulatesSamples, simulator.SimulatesIntermediateState):
class DensityMatrixSimulator(
simulator.SimulatesSamples,
simulator.SimulatesIntermediateState[
'DensityMatrixStepResult', 'DensityMatrixTrialResult', 'DensityMatrixSimulatorState'
],
):
"""A simulator for density matrices and noisy quantum circuits.

This simulator can be applied on circuits that are made up of operations
Expand Down Expand Up @@ -320,7 +325,7 @@ def _create_simulator_trial_result(
)


class DensityMatrixStepResult(simulator.StepResult):
class DensityMatrixStepResult(simulator.StepResult['DensityMatrixSimulatorState']):
"""A single step in the simulation of the DensityMatrixSimulator.

Attributes:
Expand Down
12 changes: 6 additions & 6 deletions cirq/sim/mux.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from cirq import circuits, protocols, study, devices, ops, value
from cirq._doc import document
from cirq.sim import sparse_simulator, density_matrix_simulator, state_vector_simulator
from cirq.sim import sparse_simulator, density_matrix_simulator
from cirq.sim.clifford import clifford_simulator
from cirq._compat import deprecated

Expand Down Expand Up @@ -155,7 +155,7 @@ def final_state_vector(
param_resolver=param_resolver,
)

return cast(sparse_simulator.SparseSimulatorStep, result).state_vector()
return result.state_vector()


@deprecated(
Expand Down Expand Up @@ -273,16 +273,16 @@ def final_density_matrix(

if can_do_unitary_simulation:
# pure case: use SparseSimulator
result = sparse_simulator.Simulator(dtype=dtype, seed=seed).simulate(
sparse_result = sparse_simulator.Simulator(dtype=dtype, seed=seed).simulate(
program=circuit_like,
initial_state=initial_state,
qubit_order=qubit_order,
param_resolver=param_resolver,
)
return cast(state_vector_simulator.StateVectorTrialResult, result).density_matrix_of()
return sparse_result.density_matrix_of()
else:
# noisy case: use DensityMatrixSimulator with dephasing
result = density_matrix_simulator.DensityMatrixSimulator(
density_result = density_matrix_simulator.DensityMatrixSimulator(
dtype=dtype,
noise=noise,
seed=seed,
Expand All @@ -293,4 +293,4 @@ def final_density_matrix(
qubit_order=qubit_order,
param_resolver=param_resolver,
)
return cast(density_matrix_simulator.DensityMatrixTrialResult, result).final_density_matrix
return density_result.final_density_matrix
42 changes: 26 additions & 16 deletions cirq/sim/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
TYPE_CHECKING,
Set,
cast,
TypeVar,
Generic,
)

import abc
Expand All @@ -53,6 +55,11 @@
import cirq


TStepResult = TypeVar('TStepResult', bound='StepResult')
TSimulationTrialResult = TypeVar('TSimulationTrialResult', bound='SimulationTrialResult')
TSimulatorState = TypeVar('TSimulatorState')


class SimulatesSamples(work.Sampler, metaclass=abc.ABCMeta):
"""Simulator that mimics running on quantum hardware.

Expand Down Expand Up @@ -288,7 +295,7 @@ def simulate_expectation_values_sweep(
"""


class SimulatesFinalState(metaclass=abc.ABCMeta):
class SimulatesFinalState(Generic[TSimulationTrialResult], metaclass=abc.ABCMeta):
"""Simulator that allows access to the simulator's final state.

Implementors of this interface should implement the simulate_sweep
Expand All @@ -305,7 +312,7 @@ def simulate(
param_resolver: 'study.ParamResolverOrSimilarType' = None,
qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT,
initial_state: Any = None,
) -> 'SimulationTrialResult':
) -> TSimulationTrialResult:
"""Simulates the supplied Circuit.

This method returns a result which allows access to the entire
Expand Down Expand Up @@ -335,7 +342,7 @@ def simulate_sweep(
params: study.Sweepable,
qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT,
initial_state: Any = None,
) -> List['SimulationTrialResult']:
) -> List[TSimulationTrialResult]:
"""Simulates the supplied Circuit.

This method returns a result which allows access to the entire final
Expand All @@ -359,7 +366,11 @@ def simulate_sweep(
raise NotImplementedError()


class SimulatesIntermediateState(SimulatesFinalState, metaclass=abc.ABCMeta):
class SimulatesIntermediateState(
Generic[TStepResult, TSimulationTrialResult, TSimulatorState],
SimulatesFinalState[TSimulationTrialResult],
metaclass=abc.ABCMeta,
):
"""A SimulatesFinalState that simulates a circuit by moments.

Whereas a general SimulatesFinalState may return the entire simulator
Expand All @@ -379,7 +390,7 @@ def simulate_sweep(
params: study.Sweepable,
qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT,
initial_state: Any = None,
) -> List['SimulationTrialResult']:
) -> List[TSimulationTrialResult]:
"""Simulates the supplied Circuit.

This method returns a result which allows access to the entire
Expand Down Expand Up @@ -425,7 +436,7 @@ def simulate_moment_steps(
param_resolver: 'study.ParamResolverOrSimilarType' = None,
qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT,
initial_state: Any = None,
) -> Iterator:
) -> Iterator[TStepResult]:
"""Returns an iterator of StepResults for each moment simulated.

If the circuit being simulated is empty, a single step result should
Expand Down Expand Up @@ -456,7 +467,7 @@ def _simulator_iterator(
param_resolver: study.ParamResolver,
qubit_order: ops.QubitOrderOrList,
initial_state: Any,
) -> Iterator:
) -> Iterator[TStepResult]:
"""Iterator over StepResult from Moments of a Circuit.

If the initial state is an int, the state is set to the computational
Expand Down Expand Up @@ -493,7 +504,7 @@ def _base_iterator(
circuit: circuits.Circuit,
qubit_order: ops.QubitOrderOrList,
initial_state: Any,
) -> Iterator['StepResult']:
) -> Iterator[TStepResult]:
"""Iterator over StepResult from Moments of a Circuit.

Args:
Expand All @@ -512,13 +523,14 @@ def _base_iterator(
"""
raise NotImplementedError()

@abc.abstractmethod
def _create_simulator_trial_result(
self,
params: study.ParamResolver,
measurements: Dict[str, np.ndarray],
final_simulator_state: Any,
) -> 'SimulationTrialResult':
"""This method can be overridden to creation of a trial result.
final_simulator_state: TSimulatorState,
) -> TSimulationTrialResult:
"""This method can be implemented to create a trial result.

Args:
params: The ParamResolver for this trial.
Expand All @@ -529,12 +541,10 @@ def _create_simulator_trial_result(
Returns:
The SimulationTrialResult.
"""
return SimulationTrialResult(
params=params, measurements=measurements, final_simulator_state=final_simulator_state
)
raise NotImplementedError()


class StepResult(metaclass=abc.ABCMeta):
class StepResult(Generic[TSimulatorState], metaclass=abc.ABCMeta):
"""Results of a step of a SimulatesIntermediateState.

Attributes:
Expand All @@ -546,7 +556,7 @@ def __init__(self, measurements: Optional[Dict[str, List[int]]] = None) -> None:
self.measurements = measurements or collections.defaultdict(list)

@abc.abstractmethod
def _simulator_state(self) -> Any:
def _simulator_state(self) -> TSimulatorState:
"""Returns the simulator state of the simulator after this step.

This method starts with an underscore to indicate that it is private.
Expand Down
Loading