From 90e3d08a75d9bc4185703672954082b4520dbff0 Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Thu, 6 Apr 2023 02:00:39 +0530 Subject: [PATCH 1/2] Update cirq.contrib.svg to escape < and > characters --- cirq-core/cirq/contrib/svg/svg.py | 4 ++++ cirq-core/cirq/contrib/svg/svg_test.py | 19 +++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/cirq-core/cirq/contrib/svg/svg.py b/cirq-core/cirq/contrib/svg/svg.py index e79b3b438c3..b08c8a71edd 100644 --- a/cirq-core/cirq/contrib/svg/svg.py +++ b/cirq-core/cirq/contrib/svg/svg.py @@ -23,6 +23,10 @@ def fixup_text(text: str): if '[cirq.VirtualTag()]' in text: # https://github.com/quantumlib/Cirq/issues/2905 return text.replace('[cirq.VirtualTag()]', '') + if '<' in text: + return text.replace('<', '<') + if '>' in text: + return text.replace('>', '>') return text diff --git a/cirq-core/cirq/contrib/svg/svg_test.py b/cirq-core/cirq/contrib/svg/svg_test.py index 5f3d38cf859..1cdf3af64ac 100644 --- a/cirq-core/cirq/contrib/svg/svg_test.py +++ b/cirq-core/cirq/contrib/svg/svg_test.py @@ -57,3 +57,22 @@ def test_empty_moments(): svg_2 = circuit_to_svg(cirq.Circuit(cirq.Moment())) assert '' in svg_2 + + +def test_gate_with_less_greater_str(): + class CustomGate(cirq.Gate): + def _num_qubits_(self) -> int: + return 4 + + def _circuit_diagram_info_(self, _) -> cirq.CircuitDiagramInfo: + return cirq.CircuitDiagramInfo(wire_symbols=['c', '>=d']) + + circuit = cirq.Circuit(CustomGate().on(*cirq.LineQubit.range(4))) + svg = circuit_to_svg(circuit) + import IPython.display + + _ = IPython.display.SVG(svg) + assert '<a' in svg + assert '<=b' in svg + assert '>c' in svg + assert '>=d' in svg From 5b28795413fb3ec8e80867e3e374841ba9859218 Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Thu, 6 Apr 2023 03:01:55 +0530 Subject: [PATCH 2/2] Add type: ignore for mypy --- cirq-core/cirq/contrib/svg/svg_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cirq-core/cirq/contrib/svg/svg_test.py b/cirq-core/cirq/contrib/svg/svg_test.py index 1cdf3af64ac..ad9bfd472ca 100644 --- a/cirq-core/cirq/contrib/svg/svg_test.py +++ b/cirq-core/cirq/contrib/svg/svg_test.py @@ -69,7 +69,7 @@ def _circuit_diagram_info_(self, _) -> cirq.CircuitDiagramInfo: circuit = cirq.Circuit(CustomGate().on(*cirq.LineQubit.range(4))) svg = circuit_to_svg(circuit) - import IPython.display + import IPython.display # type: ignore _ = IPython.display.SVG(svg) assert '<a' in svg