From 8f4dfdc1aaf5cc66b0a572a72f8d6159c8a09830 Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Sat, 29 Jan 2022 01:37:40 +0530 Subject: [PATCH] Support additional `**kwargs` with default arguments in Transformer API. (#4890) * Support additional **kwargs with default arguments in Transformer API. * Make context a keyword argument and use it's default value from signature, if provided * Make context an optional parameter with default value as None --- cirq/transformers/transformer_api.py | 68 ++++++++++--- cirq/transformers/transformer_api_test.py | 111 ++++++++++++++++++---- 2 files changed, 144 insertions(+), 35 deletions(-) diff --git a/cirq/transformers/transformer_api.py b/cirq/transformers/transformer_api.py index 484e731ef42..f7aa06b48e7 100644 --- a/cirq/transformers/transformer_api.py +++ b/cirq/transformers/transformer_api.py @@ -15,6 +15,7 @@ """Defines the API for circuit transformers in Cirq.""" import dataclasses +import inspect import enum import functools import textwrap @@ -24,6 +25,7 @@ Hashable, List, overload, + Optional, Type, TYPE_CHECKING, TypeVar, @@ -200,7 +202,7 @@ def show(self, level: LogLevel = LogLevel.INFO) -> None: pass -@dataclasses.dataclass() +@dataclasses.dataclass(frozen=True) class TransformerContext: """Stores common configurable options for transformers. @@ -220,7 +222,7 @@ class TransformerContext: class TRANSFORMER(Protocol): def __call__( - self, circuit: 'cirq.AbstractCircuit', context: TransformerContext + self, circuit: 'cirq.AbstractCircuit', *, context: Optional[TransformerContext] = None ) -> 'cirq.AbstractCircuit': ... @@ -248,7 +250,7 @@ def transformer(cls_or_func: Any) -> Any: >>> @cirq.transformer >>> def convert_to_cz( - >>> circuit: cirq.AbstractCircuit, context: cirq.TransformerContext + >>> circuit: cirq.AbstractCircuit, *, context: Optional[cirq.TransformerContext] = None >>> ) -> cirq.Circuit: >>> ... @@ -259,10 +261,27 @@ def transformer(cls_or_func: Any) -> Any: >>> def __init__(self): >>> ... >>> def __call__( - >>> self, circuit: cirq.AbstractCircuit, context: cirq.TransformerContext + >>> self, + >>> circuit: cirq.AbstractCircuit, + >>> *, + >>> context: Optional[cirq.TransformerContext] = None, >>> ) -> cirq.Circuit: >>> ... + Note that transformers which take additional parameters as `**kwargs`, with default values + specified for each keyword argument, are also supported. For example: + + >>> @cirq.transformer + >>> def convert_to_sqrt_iswap( + >>> circuit: cirq.AbstractCircuit, + >>> *, + >>> context: Optional[cirq.TransformerContext] = None, + >>> atol: float = 1e-8, + >>> sqrt_iswap_gate: cirq.ISwapPowGate = cirq.SQRT_ISWAP_INV, + >>> cleanup_operations: bool = True, + >>> ) -> cirq.Circuit: + >>> pass + Args: cls_or_func: The callable class or function to be decorated. @@ -272,16 +291,18 @@ def transformer(cls_or_func: Any) -> Any: if isinstance(cls_or_func, type): cls = cls_or_func method = cls.__call__ + default_context = _get_default_context(method) @functools.wraps(method) def method_with_logging( - self, circuit: 'cirq.AbstractCircuit', context: TransformerContext + self, circuit: 'cirq.AbstractCircuit', **kwargs ) -> 'cirq.AbstractCircuit': return _transform_and_log( - lambda circuit, context: method(self, circuit, context), + lambda circuit, **kwargs: method(self, circuit, **kwargs), cls.__name__, circuit, - context, + kwargs.get('context', default_context), + **kwargs, ) setattr(cls, '__call__', method_with_logging) @@ -289,24 +310,41 @@ def method_with_logging( else: assert callable(cls_or_func) func = cls_or_func + default_context = _get_default_context(func) @functools.wraps(func) - def func_with_logging( - circuit: 'cirq.AbstractCircuit', context: TransformerContext - ) -> 'cirq.AbstractCircuit': - return _transform_and_log(func, func.__name__, circuit, context) + def func_with_logging(circuit: 'cirq.AbstractCircuit', **kwargs) -> 'cirq.AbstractCircuit': + return _transform_and_log( + func, + func.__name__, + circuit, + kwargs.get('context', default_context), + **kwargs, + ) return func_with_logging +def _get_default_context(func: TRANSFORMER): + sig = inspect.signature(func) + default_context = sig.parameters["context"].default + assert ( + default_context != inspect.Parameter.empty + ), "`context` argument must have a default value specified." + return default_context + + def _transform_and_log( func: TRANSFORMER, transformer_name: str, circuit: 'cirq.AbstractCircuit', - context: TransformerContext, + extracted_context: Optional[TransformerContext], + **kwargs, ) -> 'cirq.AbstractCircuit': """Helper to log initial and final circuits before and after calling the transformer.""" - context.logger.register_initial(circuit, transformer_name) - transformed_circuit = func(circuit, context) - context.logger.register_final(transformed_circuit, transformer_name) + if extracted_context: + extracted_context.logger.register_initial(circuit, transformer_name) + transformed_circuit = func(circuit, **kwargs) + if extracted_context: + extracted_context.logger.register_final(transformed_circuit, transformer_name) return transformed_circuit diff --git a/cirq/transformers/transformer_api_test.py b/cirq/transformers/transformer_api_test.py index 0434ef08b46..1b9428b3d67 100644 --- a/cirq/transformers/transformer_api_test.py +++ b/cirq/transformers/transformer_api_test.py @@ -13,6 +13,7 @@ # limitations under the License. from unittest import mock +from typing import Optional import cirq from cirq.transformers.transformer_api import LogLevel @@ -21,29 +22,75 @@ @cirq.transformer -class MockTransformerClassCircuit: +class MockTransformerClass: def __init__(self): self.mock = mock.Mock() def __call__( - self, circuit: cirq.AbstractCircuit, context: cirq.TransformerContext + self, circuit: cirq.AbstractCircuit, *, context: Optional[cirq.TransformerContext] = None ) -> cirq.Circuit: self.mock(circuit, context) return circuit.unfreeze() -def make_circuit_transformer_func() -> cirq.TRANSFORMER: +@cirq.value_equality +class CustomArg: + def __init__(self, x: int = 0): + self._x = x + + def _value_equality_values_(self): + return (self._x,) + + +@cirq.transformer +class MockTransformerClassWithDefaults: + def __init__(self): + self.mock = mock.Mock() + + def __call__( + self, + circuit: cirq.AbstractCircuit, + *, + context: Optional[cirq.TransformerContext] = cirq.TransformerContext(), + atol: float = 1e-4, + custom_arg: CustomArg = CustomArg(), + ) -> cirq.AbstractCircuit: + self.mock(circuit, context, atol, custom_arg) + return circuit[::-1] + + +def make_transformer_func_with_defaults() -> cirq.TRANSFORMER: my_mock = mock.Mock() @cirq.transformer - def func(circuit: cirq.AbstractCircuit, context: cirq.TransformerContext) -> cirq.Circuit: - my_mock(circuit, context) - return circuit.unfreeze() + def func( + circuit: cirq.AbstractCircuit, + *, + context: Optional[cirq.TransformerContext] = cirq.TransformerContext(), + atol: float = 1e-4, + custom_arg: CustomArg = CustomArg(), + ) -> cirq.FrozenCircuit: + my_mock(circuit, context, atol, custom_arg) + return circuit.freeze() func.mock = my_mock # type: ignore return func +def make_transformer_func() -> cirq.TRANSFORMER: + my_mock = mock.Mock() + + @cirq.transformer + def mock_tranformer_func( + circuit: cirq.AbstractCircuit, *, context: Optional[cirq.TransformerContext] = None + ) -> cirq.Circuit: + my_mock(circuit, context) + return circuit.unfreeze() + + mock_tranformer_func.mock = my_mock # type: ignore + return mock_tranformer_func + + @pytest.mark.parametrize( 'context', [ @@ -53,14 +100,11 @@ def func(circuit: cirq.AbstractCircuit, context: cirq.TransformerContext) -> cir ) @pytest.mark.parametrize( 'transformer', - [ - MockTransformerClassCircuit(), - make_circuit_transformer_func(), - ], + [MockTransformerClass(), make_transformer_func()], ) def test_transformer_decorator(context, transformer): circuit = cirq.Circuit(cirq.X(cirq.NamedQubit("a"))) - transformer(circuit, context) + transformer(circuit, context=context) transformer.mock.assert_called_with(circuit, context) if not isinstance(context.logger, cirq.TransformerLogger): transformer_name = ( @@ -70,11 +114,32 @@ def test_transformer_decorator(context, transformer): context.logger.register_final.assert_called_with(circuit, transformer_name) +@pytest.mark.parametrize( + 'transformer', + [ + MockTransformerClassWithDefaults(), + make_transformer_func_with_defaults(), + ], +) +def test_transformer_decorator_with_defaults(transformer): + circuit = cirq.Circuit(cirq.X(cirq.NamedQubit("a"))) + context = cirq.TransformerContext(ignore_tags=("tags", "to", "ignore")) + transformer(circuit) + transformer.mock.assert_called_with(circuit, cirq.TransformerContext(), 1e-4, CustomArg()) + transformer(circuit, context=context, atol=1e-3) + transformer.mock.assert_called_with(circuit, context, 1e-3, CustomArg()) + transformer(circuit, context=context, custom_arg=CustomArg(10)) + transformer.mock.assert_called_with(circuit, context, 1e-4, CustomArg(10)) + transformer(circuit, context=context, atol=1e-2, custom_arg=CustomArg(12)) + transformer.mock.assert_called_with(circuit, context, 1e-2, CustomArg(12)) + + @cirq.transformer class T1: def __call__( - self, circuit: cirq.AbstractCircuit, context: cirq.TransformerContext + self, circuit: cirq.AbstractCircuit, context: Optional[cirq.TransformerContext] = None ) -> cirq.AbstractCircuit: + assert context is not None context.logger.log("First Verbose Log", "of T1", level=LogLevel.DEBUG) context.logger.log("Second INFO Log", "of T1", level=LogLevel.INFO) context.logger.log("Third WARNING Log", "of T1", level=LogLevel.WARNING) @@ -85,19 +150,25 @@ def __call__( @cirq.transformer -def t2(circuit: cirq.AbstractCircuit, context: cirq.TransformerContext) -> cirq.FrozenCircuit: +def t2( + circuit: cirq.AbstractCircuit, context: Optional[cirq.TransformerContext] = None +) -> cirq.FrozenCircuit: + assert context is not None context.logger.log("First INFO Log", "of T2 Start") - circuit = t1(circuit, context) + circuit = t1(circuit, context=context) context.logger.log("Second INFO Log", "of T2 End") return circuit[::2].freeze() @cirq.transformer -def t3(circuit: cirq.AbstractCircuit, context: cirq.TransformerContext) -> cirq.Circuit: +def t3( + circuit: cirq.AbstractCircuit, context: Optional[cirq.TransformerContext] = None +) -> cirq.Circuit: + assert context is not None context.logger.log("First INFO Log", "of T3 Start") - circuit = t1(circuit, context) + circuit = t1(circuit, context=context) context.logger.log("Second INFO Log", "of T3 Middle") - circuit = t2(circuit, context) + circuit = t2(circuit, context=context) context.logger.log("Third INFO Log", "of T3 End") return circuit.unfreeze() @@ -123,7 +194,7 @@ def test_transformer_stats_logger_show_levels(capfd): q = cirq.LineQubit.range(2) context = cirq.TransformerContext(logger=cirq.TransformerLogger()) initial_circuit = cirq.Circuit(cirq.H.on_each(*q), cirq.CNOT(*q)) - _ = t1(initial_circuit, context) + _ = t1(initial_circuit, context=context) info_line = 'LogLevel.INFO Second INFO Log of T1' debug_line = 'LogLevel.DEBUG First Verbose Log of T1' warning_line = 'LogLevel.WARNING Third WARNING Log of T1' @@ -150,8 +221,8 @@ def test_transformer_stats_logger_linear_and_nested(capfd): q = cirq.LineQubit.range(2) circuit = cirq.Circuit(cirq.H.on_each(*q), cirq.CNOT(*q)) context = cirq.TransformerContext(logger=cirq.TransformerLogger()) - circuit = t1(circuit, context) - circuit = t3(circuit, context) + circuit = t1(circuit, context=context) + circuit = t3(circuit, context=context) context.logger.show(LogLevel.ALL) out, _ = capfd.readouterr() assert (