Skip to content

Commit

Permalink
Support additional **kwargs with default arguments in Transformer A…
Browse files Browse the repository at this point in the history
…PI. (quantumlib#4890)

* Support additional **kwargs with default arguments in Transformer API.

* Make context a keyword argument and use it's default value from signature, if provided

* Make context an optional parameter with default value as None
  • Loading branch information
tanujkhattar authored Jan 28, 2022
1 parent 556ef42 commit 8f4dfdc
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 35 deletions.
68 changes: 53 additions & 15 deletions cirq/transformers/transformer_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Defines the API for circuit transformers in Cirq."""

import dataclasses
import inspect
import enum
import functools
import textwrap
Expand All @@ -24,6 +25,7 @@
Hashable,
List,
overload,
Optional,
Type,
TYPE_CHECKING,
TypeVar,
Expand Down Expand Up @@ -200,7 +202,7 @@ def show(self, level: LogLevel = LogLevel.INFO) -> None:
pass


@dataclasses.dataclass()
@dataclasses.dataclass(frozen=True)
class TransformerContext:
"""Stores common configurable options for transformers.
Expand All @@ -220,7 +222,7 @@ class TransformerContext:

class TRANSFORMER(Protocol):
def __call__(
self, circuit: 'cirq.AbstractCircuit', context: TransformerContext
self, circuit: 'cirq.AbstractCircuit', *, context: Optional[TransformerContext] = None
) -> 'cirq.AbstractCircuit':
...

Expand Down Expand Up @@ -248,7 +250,7 @@ def transformer(cls_or_func: Any) -> Any:
>>> @cirq.transformer
>>> def convert_to_cz(
>>> circuit: cirq.AbstractCircuit, context: cirq.TransformerContext
>>> circuit: cirq.AbstractCircuit, *, context: Optional[cirq.TransformerContext] = None
>>> ) -> cirq.Circuit:
>>> ...
Expand All @@ -259,10 +261,27 @@ def transformer(cls_or_func: Any) -> Any:
>>> def __init__(self):
>>> ...
>>> def __call__(
>>> self, circuit: cirq.AbstractCircuit, context: cirq.TransformerContext
>>> self,
>>> circuit: cirq.AbstractCircuit,
>>> *,
>>> context: Optional[cirq.TransformerContext] = None,
>>> ) -> cirq.Circuit:
>>> ...
Note that transformers which take additional parameters as `**kwargs`, with default values
specified for each keyword argument, are also supported. For example:
>>> @cirq.transformer
>>> def convert_to_sqrt_iswap(
>>> circuit: cirq.AbstractCircuit,
>>> *,
>>> context: Optional[cirq.TransformerContext] = None,
>>> atol: float = 1e-8,
>>> sqrt_iswap_gate: cirq.ISwapPowGate = cirq.SQRT_ISWAP_INV,
>>> cleanup_operations: bool = True,
>>> ) -> cirq.Circuit:
>>> pass
Args:
cls_or_func: The callable class or function to be decorated.
Expand All @@ -272,41 +291,60 @@ def transformer(cls_or_func: Any) -> Any:
if isinstance(cls_or_func, type):
cls = cls_or_func
method = cls.__call__
default_context = _get_default_context(method)

@functools.wraps(method)
def method_with_logging(
self, circuit: 'cirq.AbstractCircuit', context: TransformerContext
self, circuit: 'cirq.AbstractCircuit', **kwargs
) -> 'cirq.AbstractCircuit':
return _transform_and_log(
lambda circuit, context: method(self, circuit, context),
lambda circuit, **kwargs: method(self, circuit, **kwargs),
cls.__name__,
circuit,
context,
kwargs.get('context', default_context),
**kwargs,
)

setattr(cls, '__call__', method_with_logging)
return cls
else:
assert callable(cls_or_func)
func = cls_or_func
default_context = _get_default_context(func)

@functools.wraps(func)
def func_with_logging(
circuit: 'cirq.AbstractCircuit', context: TransformerContext
) -> 'cirq.AbstractCircuit':
return _transform_and_log(func, func.__name__, circuit, context)
def func_with_logging(circuit: 'cirq.AbstractCircuit', **kwargs) -> 'cirq.AbstractCircuit':
return _transform_and_log(
func,
func.__name__,
circuit,
kwargs.get('context', default_context),
**kwargs,
)

return func_with_logging


def _get_default_context(func: TRANSFORMER):
sig = inspect.signature(func)
default_context = sig.parameters["context"].default
assert (
default_context != inspect.Parameter.empty
), "`context` argument must have a default value specified."
return default_context


def _transform_and_log(
func: TRANSFORMER,
transformer_name: str,
circuit: 'cirq.AbstractCircuit',
context: TransformerContext,
extracted_context: Optional[TransformerContext],
**kwargs,
) -> 'cirq.AbstractCircuit':
"""Helper to log initial and final circuits before and after calling the transformer."""
context.logger.register_initial(circuit, transformer_name)
transformed_circuit = func(circuit, context)
context.logger.register_final(transformed_circuit, transformer_name)
if extracted_context:
extracted_context.logger.register_initial(circuit, transformer_name)
transformed_circuit = func(circuit, **kwargs)
if extracted_context:
extracted_context.logger.register_final(transformed_circuit, transformer_name)
return transformed_circuit
111 changes: 91 additions & 20 deletions cirq/transformers/transformer_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from unittest import mock
from typing import Optional

import cirq
from cirq.transformers.transformer_api import LogLevel
Expand All @@ -21,29 +22,75 @@


@cirq.transformer
class MockTransformerClassCircuit:
class MockTransformerClass:
def __init__(self):
self.mock = mock.Mock()

def __call__(
self, circuit: cirq.AbstractCircuit, context: cirq.TransformerContext
self, circuit: cirq.AbstractCircuit, *, context: Optional[cirq.TransformerContext] = None
) -> cirq.Circuit:
self.mock(circuit, context)
return circuit.unfreeze()


def make_circuit_transformer_func() -> cirq.TRANSFORMER:
@cirq.value_equality
class CustomArg:
def __init__(self, x: int = 0):
self._x = x

def _value_equality_values_(self):
return (self._x,)


@cirq.transformer
class MockTransformerClassWithDefaults:
def __init__(self):
self.mock = mock.Mock()

def __call__(
self,
circuit: cirq.AbstractCircuit,
*,
context: Optional[cirq.TransformerContext] = cirq.TransformerContext(),
atol: float = 1e-4,
custom_arg: CustomArg = CustomArg(),
) -> cirq.AbstractCircuit:
self.mock(circuit, context, atol, custom_arg)
return circuit[::-1]


def make_transformer_func_with_defaults() -> cirq.TRANSFORMER:
my_mock = mock.Mock()

@cirq.transformer
def func(circuit: cirq.AbstractCircuit, context: cirq.TransformerContext) -> cirq.Circuit:
my_mock(circuit, context)
return circuit.unfreeze()
def func(
circuit: cirq.AbstractCircuit,
*,
context: Optional[cirq.TransformerContext] = cirq.TransformerContext(),
atol: float = 1e-4,
custom_arg: CustomArg = CustomArg(),
) -> cirq.FrozenCircuit:
my_mock(circuit, context, atol, custom_arg)
return circuit.freeze()

func.mock = my_mock # type: ignore
return func


def make_transformer_func() -> cirq.TRANSFORMER:
my_mock = mock.Mock()

@cirq.transformer
def mock_tranformer_func(
circuit: cirq.AbstractCircuit, *, context: Optional[cirq.TransformerContext] = None
) -> cirq.Circuit:
my_mock(circuit, context)
return circuit.unfreeze()

mock_tranformer_func.mock = my_mock # type: ignore
return mock_tranformer_func


@pytest.mark.parametrize(
'context',
[
Expand All @@ -53,14 +100,11 @@ def func(circuit: cirq.AbstractCircuit, context: cirq.TransformerContext) -> cir
)
@pytest.mark.parametrize(
'transformer',
[
MockTransformerClassCircuit(),
make_circuit_transformer_func(),
],
[MockTransformerClass(), make_transformer_func()],
)
def test_transformer_decorator(context, transformer):
circuit = cirq.Circuit(cirq.X(cirq.NamedQubit("a")))
transformer(circuit, context)
transformer(circuit, context=context)
transformer.mock.assert_called_with(circuit, context)
if not isinstance(context.logger, cirq.TransformerLogger):
transformer_name = (
Expand All @@ -70,11 +114,32 @@ def test_transformer_decorator(context, transformer):
context.logger.register_final.assert_called_with(circuit, transformer_name)


@pytest.mark.parametrize(
'transformer',
[
MockTransformerClassWithDefaults(),
make_transformer_func_with_defaults(),
],
)
def test_transformer_decorator_with_defaults(transformer):
circuit = cirq.Circuit(cirq.X(cirq.NamedQubit("a")))
context = cirq.TransformerContext(ignore_tags=("tags", "to", "ignore"))
transformer(circuit)
transformer.mock.assert_called_with(circuit, cirq.TransformerContext(), 1e-4, CustomArg())
transformer(circuit, context=context, atol=1e-3)
transformer.mock.assert_called_with(circuit, context, 1e-3, CustomArg())
transformer(circuit, context=context, custom_arg=CustomArg(10))
transformer.mock.assert_called_with(circuit, context, 1e-4, CustomArg(10))
transformer(circuit, context=context, atol=1e-2, custom_arg=CustomArg(12))
transformer.mock.assert_called_with(circuit, context, 1e-2, CustomArg(12))


@cirq.transformer
class T1:
def __call__(
self, circuit: cirq.AbstractCircuit, context: cirq.TransformerContext
self, circuit: cirq.AbstractCircuit, context: Optional[cirq.TransformerContext] = None
) -> cirq.AbstractCircuit:
assert context is not None
context.logger.log("First Verbose Log", "of T1", level=LogLevel.DEBUG)
context.logger.log("Second INFO Log", "of T1", level=LogLevel.INFO)
context.logger.log("Third WARNING Log", "of T1", level=LogLevel.WARNING)
Expand All @@ -85,19 +150,25 @@ def __call__(


@cirq.transformer
def t2(circuit: cirq.AbstractCircuit, context: cirq.TransformerContext) -> cirq.FrozenCircuit:
def t2(
circuit: cirq.AbstractCircuit, context: Optional[cirq.TransformerContext] = None
) -> cirq.FrozenCircuit:
assert context is not None
context.logger.log("First INFO Log", "of T2 Start")
circuit = t1(circuit, context)
circuit = t1(circuit, context=context)
context.logger.log("Second INFO Log", "of T2 End")
return circuit[::2].freeze()


@cirq.transformer
def t3(circuit: cirq.AbstractCircuit, context: cirq.TransformerContext) -> cirq.Circuit:
def t3(
circuit: cirq.AbstractCircuit, context: Optional[cirq.TransformerContext] = None
) -> cirq.Circuit:
assert context is not None
context.logger.log("First INFO Log", "of T3 Start")
circuit = t1(circuit, context)
circuit = t1(circuit, context=context)
context.logger.log("Second INFO Log", "of T3 Middle")
circuit = t2(circuit, context)
circuit = t2(circuit, context=context)
context.logger.log("Third INFO Log", "of T3 End")
return circuit.unfreeze()

Expand All @@ -123,7 +194,7 @@ def test_transformer_stats_logger_show_levels(capfd):
q = cirq.LineQubit.range(2)
context = cirq.TransformerContext(logger=cirq.TransformerLogger())
initial_circuit = cirq.Circuit(cirq.H.on_each(*q), cirq.CNOT(*q))
_ = t1(initial_circuit, context)
_ = t1(initial_circuit, context=context)
info_line = 'LogLevel.INFO Second INFO Log of T1'
debug_line = 'LogLevel.DEBUG First Verbose Log of T1'
warning_line = 'LogLevel.WARNING Third WARNING Log of T1'
Expand All @@ -150,8 +221,8 @@ def test_transformer_stats_logger_linear_and_nested(capfd):
q = cirq.LineQubit.range(2)
circuit = cirq.Circuit(cirq.H.on_each(*q), cirq.CNOT(*q))
context = cirq.TransformerContext(logger=cirq.TransformerLogger())
circuit = t1(circuit, context)
circuit = t3(circuit, context)
circuit = t1(circuit, context=context)
circuit = t3(circuit, context=context)
context.logger.show(LogLevel.ALL)
out, _ = capfd.readouterr()
assert (
Expand Down

0 comments on commit 8f4dfdc

Please sign in to comment.