Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support additional **kwargs with default arguments in Transformer API. #4890

Merged
Merged
68 changes: 53 additions & 15 deletions cirq-core/cirq/transformers/transformer_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Defines the API for circuit transformers in Cirq."""

import dataclasses
import inspect
import enum
import functools
import textwrap
Expand All @@ -24,6 +25,7 @@
Hashable,
List,
overload,
Optional,
Type,
TYPE_CHECKING,
TypeVar,
Expand Down Expand Up @@ -200,7 +202,7 @@ def show(self, level: LogLevel = LogLevel.INFO) -> None:
pass


@dataclasses.dataclass()
@dataclasses.dataclass(frozen=True)
class TransformerContext:
"""Stores common configurable options for transformers.

Expand All @@ -220,7 +222,7 @@ class TransformerContext:

class TRANSFORMER(Protocol):
def __call__(
self, circuit: 'cirq.AbstractCircuit', context: TransformerContext
self, circuit: 'cirq.AbstractCircuit', *, context: Optional[TransformerContext] = None
) -> 'cirq.AbstractCircuit':
...

Expand Down Expand Up @@ -248,7 +250,7 @@ def transformer(cls_or_func: Any) -> Any:

>>> @cirq.transformer
>>> def convert_to_cz(
>>> circuit: cirq.AbstractCircuit, context: cirq.TransformerContext
>>> circuit: cirq.AbstractCircuit, *, context: Optional[cirq.TransformerContext] = None
>>> ) -> cirq.Circuit:
>>> ...

Expand All @@ -259,10 +261,27 @@ def transformer(cls_or_func: Any) -> Any:
>>> def __init__(self):
>>> ...
>>> def __call__(
>>> self, circuit: cirq.AbstractCircuit, context: cirq.TransformerContext
>>> self,
>>> circuit: cirq.AbstractCircuit,
>>> *,
>>> context: Optional[cirq.TransformerContext] = None,
>>> ) -> cirq.Circuit:
>>> ...

Note that transformers which take additional parameters as `**kwargs`, with default values
specified for each keyword argument, are also supported. For example:

>>> @cirq.transformer
>>> def convert_to_sqrt_iswap(
>>> circuit: cirq.AbstractCircuit,
>>> *,
>>> context: Optional[cirq.TransformerContext] = None,
>>> atol: float = 1e-8,
>>> sqrt_iswap_gate: cirq.ISwapPowGate = cirq.SQRT_ISWAP_INV,
>>> cleanup_operations: bool = True,
>>> ) -> cirq.Circuit:
>>> pass

Args:
cls_or_func: The callable class or function to be decorated.

Expand All @@ -272,41 +291,60 @@ def transformer(cls_or_func: Any) -> Any:
if isinstance(cls_or_func, type):
cls = cls_or_func
method = cls.__call__
default_context = _get_default_context(method)

@functools.wraps(method)
def method_with_logging(
self, circuit: 'cirq.AbstractCircuit', context: TransformerContext
self, circuit: 'cirq.AbstractCircuit', **kwargs
) -> 'cirq.AbstractCircuit':
return _transform_and_log(
lambda circuit, context: method(self, circuit, context),
lambda circuit, **kwargs: method(self, circuit, **kwargs),
cls.__name__,
circuit,
context,
kwargs.get('context', default_context),
**kwargs,
)

setattr(cls, '__call__', method_with_logging)
return cls
else:
assert callable(cls_or_func)
func = cls_or_func
default_context = _get_default_context(func)

@functools.wraps(func)
def func_with_logging(
circuit: 'cirq.AbstractCircuit', context: TransformerContext
) -> 'cirq.AbstractCircuit':
return _transform_and_log(func, func.__name__, circuit, context)
def func_with_logging(circuit: 'cirq.AbstractCircuit', **kwargs) -> 'cirq.AbstractCircuit':
return _transform_and_log(
func,
func.__name__,
circuit,
kwargs.get('context', default_context),
**kwargs,
)

return func_with_logging


def _get_default_context(func: TRANSFORMER):
sig = inspect.signature(func)
default_context = sig.parameters["context"].default
assert (
default_context != inspect.Parameter.empty
), "`context` argument must have a default value specified."
return default_context


def _transform_and_log(
func: TRANSFORMER,
transformer_name: str,
circuit: 'cirq.AbstractCircuit',
context: TransformerContext,
extracted_context: Optional[TransformerContext],
**kwargs,
) -> 'cirq.AbstractCircuit':
"""Helper to log initial and final circuits before and after calling the transformer."""
context.logger.register_initial(circuit, transformer_name)
transformed_circuit = func(circuit, context)
context.logger.register_final(transformed_circuit, transformer_name)
if extracted_context:
extracted_context.logger.register_initial(circuit, transformer_name)
transformed_circuit = func(circuit, **kwargs)
if extracted_context:
extracted_context.logger.register_final(transformed_circuit, transformer_name)
return transformed_circuit
111 changes: 91 additions & 20 deletions cirq-core/cirq/transformers/transformer_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from unittest import mock
from typing import Optional

import cirq
from cirq.transformers.transformer_api import LogLevel
Expand All @@ -21,29 +22,75 @@


@cirq.transformer
class MockTransformerClassCircuit:
class MockTransformerClass:
def __init__(self):
self.mock = mock.Mock()

def __call__(
self, circuit: cirq.AbstractCircuit, context: cirq.TransformerContext
self, circuit: cirq.AbstractCircuit, *, context: Optional[cirq.TransformerContext] = None
) -> cirq.Circuit:
self.mock(circuit, context)
return circuit.unfreeze()


def make_circuit_transformer_func() -> cirq.TRANSFORMER:
@cirq.value_equality
class CustomArg:
def __init__(self, x: int = 0):
self._x = x

def _value_equality_values_(self):
return (self._x,)


@cirq.transformer
class MockTransformerClassWithDefaults:
def __init__(self):
self.mock = mock.Mock()

def __call__(
self,
circuit: cirq.AbstractCircuit,
*,
context: Optional[cirq.TransformerContext] = cirq.TransformerContext(),
atol: float = 1e-4,
custom_arg: CustomArg = CustomArg(),
) -> cirq.AbstractCircuit:
self.mock(circuit, context, atol, custom_arg)
return circuit[::-1]


def make_transformer_func_with_defaults() -> cirq.TRANSFORMER:
my_mock = mock.Mock()

@cirq.transformer
def func(circuit: cirq.AbstractCircuit, context: cirq.TransformerContext) -> cirq.Circuit:
my_mock(circuit, context)
return circuit.unfreeze()
def func(
circuit: cirq.AbstractCircuit,
*,
context: Optional[cirq.TransformerContext] = cirq.TransformerContext(),
atol: float = 1e-4,
custom_arg: CustomArg = CustomArg(),
) -> cirq.FrozenCircuit:
my_mock(circuit, context, atol, custom_arg)
return circuit.freeze()

func.mock = my_mock # type: ignore
return func


def make_transformer_func() -> cirq.TRANSFORMER:
my_mock = mock.Mock()

@cirq.transformer
def mock_tranformer_func(
circuit: cirq.AbstractCircuit, *, context: Optional[cirq.TransformerContext] = None
) -> cirq.Circuit:
my_mock(circuit, context)
return circuit.unfreeze()

mock_tranformer_func.mock = my_mock # type: ignore
return mock_tranformer_func


@pytest.mark.parametrize(
'context',
[
Expand All @@ -53,14 +100,11 @@ def func(circuit: cirq.AbstractCircuit, context: cirq.TransformerContext) -> cir
)
@pytest.mark.parametrize(
'transformer',
[
MockTransformerClassCircuit(),
make_circuit_transformer_func(),
],
[MockTransformerClass(), make_transformer_func()],
)
def test_transformer_decorator(context, transformer):
circuit = cirq.Circuit(cirq.X(cirq.NamedQubit("a")))
transformer(circuit, context)
transformer(circuit, context=context)
transformer.mock.assert_called_with(circuit, context)
if not isinstance(context.logger, cirq.TransformerLogger):
transformer_name = (
Expand All @@ -70,11 +114,32 @@ def test_transformer_decorator(context, transformer):
context.logger.register_final.assert_called_with(circuit, transformer_name)


@pytest.mark.parametrize(
'transformer',
[
MockTransformerClassWithDefaults(),
make_transformer_func_with_defaults(),
],
)
def test_transformer_decorator_with_defaults(transformer):
circuit = cirq.Circuit(cirq.X(cirq.NamedQubit("a")))
context = cirq.TransformerContext(ignore_tags=("tags", "to", "ignore"))
transformer(circuit)
transformer.mock.assert_called_with(circuit, cirq.TransformerContext(), 1e-4, CustomArg())
transformer(circuit, context=context, atol=1e-3)
transformer.mock.assert_called_with(circuit, context, 1e-3, CustomArg())
transformer(circuit, context=context, custom_arg=CustomArg(10))
transformer.mock.assert_called_with(circuit, context, 1e-4, CustomArg(10))
transformer(circuit, context=context, atol=1e-2, custom_arg=CustomArg(12))
transformer.mock.assert_called_with(circuit, context, 1e-2, CustomArg(12))


@cirq.transformer
class T1:
def __call__(
self, circuit: cirq.AbstractCircuit, context: cirq.TransformerContext
self, circuit: cirq.AbstractCircuit, context: Optional[cirq.TransformerContext] = None
) -> cirq.AbstractCircuit:
assert context is not None
context.logger.log("First Verbose Log", "of T1", level=LogLevel.DEBUG)
context.logger.log("Second INFO Log", "of T1", level=LogLevel.INFO)
context.logger.log("Third WARNING Log", "of T1", level=LogLevel.WARNING)
Expand All @@ -85,19 +150,25 @@ def __call__(


@cirq.transformer
def t2(circuit: cirq.AbstractCircuit, context: cirq.TransformerContext) -> cirq.FrozenCircuit:
def t2(
circuit: cirq.AbstractCircuit, context: Optional[cirq.TransformerContext] = None
) -> cirq.FrozenCircuit:
assert context is not None
context.logger.log("First INFO Log", "of T2 Start")
circuit = t1(circuit, context)
circuit = t1(circuit, context=context)
context.logger.log("Second INFO Log", "of T2 End")
return circuit[::2].freeze()


@cirq.transformer
def t3(circuit: cirq.AbstractCircuit, context: cirq.TransformerContext) -> cirq.Circuit:
def t3(
circuit: cirq.AbstractCircuit, context: Optional[cirq.TransformerContext] = None
) -> cirq.Circuit:
assert context is not None
context.logger.log("First INFO Log", "of T3 Start")
circuit = t1(circuit, context)
circuit = t1(circuit, context=context)
context.logger.log("Second INFO Log", "of T3 Middle")
circuit = t2(circuit, context)
circuit = t2(circuit, context=context)
context.logger.log("Third INFO Log", "of T3 End")
return circuit.unfreeze()

Expand All @@ -123,7 +194,7 @@ def test_transformer_stats_logger_show_levels(capfd):
q = cirq.LineQubit.range(2)
context = cirq.TransformerContext(logger=cirq.TransformerLogger())
initial_circuit = cirq.Circuit(cirq.H.on_each(*q), cirq.CNOT(*q))
_ = t1(initial_circuit, context)
_ = t1(initial_circuit, context=context)
info_line = 'LogLevel.INFO Second INFO Log of T1'
debug_line = 'LogLevel.DEBUG First Verbose Log of T1'
warning_line = 'LogLevel.WARNING Third WARNING Log of T1'
Expand All @@ -150,8 +221,8 @@ def test_transformer_stats_logger_linear_and_nested(capfd):
q = cirq.LineQubit.range(2)
circuit = cirq.Circuit(cirq.H.on_each(*q), cirq.CNOT(*q))
context = cirq.TransformerContext(logger=cirq.TransformerLogger())
circuit = t1(circuit, context)
circuit = t3(circuit, context)
circuit = t1(circuit, context=context)
circuit = t3(circuit, context=context)
context.logger.show(LogLevel.ALL)
out, _ = capfd.readouterr()
assert (
Expand Down