Skip to content

Commit

Permalink
Add benchmarks for transformer primitives and json serialization (#5957)
Browse files Browse the repository at this point in the history
* Add benchmarks for transformer primitives and json serialization

* Track json size upto 3 decimal places
  • Loading branch information
tanujkhattar authored Dec 19, 2022
1 parent 30181d4 commit 7019adc
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 0 deletions.
53 changes: 53 additions & 0 deletions benchmarks/serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright 2022 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import cirq


def _human_size(num_bytes: int, mod: int = 0, units=(' bytes', 'KB', 'MB', 'GB', 'TB', 'PB')):
"""Returns a human readable string representation of bytes"""
return (
f'{num_bytes}.{mod}{units[0]}'
if num_bytes < 1024
else _human_size(num_bytes >> 10, num_bytes % 1024, units[1:])
)


class SerializeLargeExpandedCircuits:
param_names = ["num_qubits", "num_moments"]
params = ([100, 500, 1000], [100, 1000, 4000])
timeout = 600 # Change timeout to 2 minutes instead of default 60 seconds.

def setup(self, num_qubits: int, num_moments: int):
qubits = cirq.LineQubit.range(num_qubits)
one_q_x_moment = cirq.Moment(cirq.X(q) for q in qubits[::2])
one_q_y_moment = cirq.Moment(cirq.Y(q) for q in qubits[1::2])
two_q_cx_moment = cirq.Moment(
cirq.CNOT(q1, q2) for q1, q2 in zip(qubits[::4], qubits[1::4])
)
two_q_cz_moment = cirq.Moment(cirq.CZ(q1, q2) for q1, q2 in zip(qubits[::4], qubits[1::4]))
measurement_moment = cirq.Moment(cirq.measure_each(*qubits))
self.circuit = cirq.Circuit(
[one_q_x_moment, two_q_cx_moment, one_q_y_moment, two_q_cz_moment, measurement_moment]
* (num_moments // 5)
)

def time_json_serialization(self, *_):
_ = cirq.to_json(self.circuit)

def time_json_serialization_gzip(self, *_):
_ = cirq.to_json_gzip(self.circuit)

def track_json_serialization_gzip_size(self, *_):
return _human_size(len(cirq.to_json_gzip(self.circuit)))
65 changes: 65 additions & 0 deletions benchmarks/transformers/transformer_primitives.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright 2022 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import cirq


class MapLargeExpandedCircuit:
param_names = ["num_qubits", "num_moments"]
params = ([100, 500, 1000], [100, 1000, 4000])
timeout = 600 # Change timeout to 2 minutes instead of default 60 seconds.

def setup(self, num_qubits: int, num_moments: int):
qubits = cirq.LineQubit.range(num_qubits)
one_q_x_moment = cirq.Moment(cirq.X(q) for q in qubits[::2])
one_q_y_moment = cirq.Moment(cirq.Y(q) for q in qubits[1::2])
two_q_cx_moment = cirq.Moment(
cirq.CNOT(q1, q2) for q1, q2 in zip(qubits[::4], qubits[1::4])
)
two_q_cz_moment = cirq.Moment(cirq.CZ(q1, q2) for q1, q2 in zip(qubits[::4], qubits[1::4]))
self.circuit = cirq.Circuit(
[one_q_x_moment, two_q_cx_moment, one_q_y_moment, two_q_cz_moment] * (num_moments // 4)
)

def time_map_moments(self, num_qubits: int, _):
all_qubits = cirq.LineQubit.range(num_qubits)

def map_func(m: cirq.Moment, _) -> cirq.Moment:
new_ops = [op.with_tags("old op") for op in m.operations]
new_ops += [
cirq.Z(q).with_tags("new op")
for q in all_qubits
if not m.operates_on_single_qubit(q)
]
return cirq.Moment(new_ops)

_ = cirq.map_moments(circuit=self.circuit, map_func=map_func)

def time_map_operations_apply_tag(self, *_):
def map_func(op: cirq.Operation, _) -> cirq.Operation:
return op.with_tags("old op")

_ = cirq.map_operations(circuit=self.circuit, map_func=map_func)

def time_map_operations_to_optree(self, *_):
def map_func(op: cirq.Operation, _) -> cirq.OP_TREE:
return [op, op]

_ = cirq.map_operations(circuit=self.circuit, map_func=map_func)

def time_map_operations_to_optree_and_unroll(self, *_):
def map_func(op: cirq.Operation, _) -> cirq.OP_TREE:
return [op, op]

_ = cirq.map_operations_and_unroll(circuit=self.circuit, map_func=map_func)

0 comments on commit 7019adc

Please sign in to comment.