Skip to content

Commit 3a930a6

Browse files
committed
Address comments
1 parent 998464d commit 3a930a6

File tree

3 files changed

+137
-120
lines changed

3 files changed

+137
-120
lines changed

cirq-core/cirq/transformers/gauge_compiling/multi_moment_cphase_gauge.py

Lines changed: 86 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -19,69 +19,67 @@
1919
from typing import cast
2020

2121
import numpy as np
22+
from attrs import field, frozen
2223

2324
from cirq import circuits, ops
2425
from cirq.transformers.gauge_compiling.multi_moment_gauge_compiling import (
2526
MultiMomentGaugeTransformer,
2627
)
2728

29+
_PAULIS: np.ndarray = np.array((ops.I, ops.X, ops.Y, ops.Z), dtype=object)
30+
_COMMUTING_GATES = {ops.I, ops.Z} # I,Z Commute with ZPowGate and CZPowGate; X,Y anti-commute.
2831

32+
33+
def _merge_pauliandzpow(left: _PauliAndZPow, right: _PauliAndZPow) -> _PauliAndZPow:
34+
# 1. Commute left.zpow and right.pauli:
35+
# ─left.pauli─left.zpow─right.pauli─right.zpow─
36+
# ==> ─left.pauli─right.pauli─(+/-left.zpow─right.zpow)─
37+
if right.pauli in _COMMUTING_GATES:
38+
new_zpow_exp = left.zpow.exponent + right.zpow.exponent
39+
else:
40+
new_zpow_exp = -left.zpow.exponent + right.zpow.exponent
41+
42+
# 2. Merge left.pauli and right.pauli
43+
new_pauli = left.pauli
44+
if right.pauli is not ops.I:
45+
if new_pauli is ops.I:
46+
new_pauli = right.pauli
47+
else:
48+
# left.pauli * right.pauli
49+
new_pauli = cast(ops.Pauli, new_pauli).phased_pauli_product(
50+
cast(ops.Pauli, right.pauli)
51+
)[1]
52+
53+
return _PauliAndZPow(pauli=new_pauli, zpow=ops.ZPowGate(exponent=new_zpow_exp))
54+
55+
56+
@frozen
2957
class _PauliAndZPow:
30-
"""In pulling through, one qubit gate can be represented by a Pauli and an Rz gate.
31-
The order is --Pauli--ZPowGate--.
58+
"""A gate represented by a Pauli followed by a ZPowGate.
59+
60+
The order is ─Pauli──ZPowGate─.
61+
62+
Attributes:
63+
pauli: The Pauli gate.
64+
zpow: The ZPowGate.
3265
"""
3366

3467
pauli: ops.Pauli | ops.IdentityGate = ops.I
3568
zpow: ops.ZPowGate = ops.ZPowGate(exponent=0)
3669

37-
commuting_gates = {ops.I, ops.Z} # I,Z Commute with ZPowGate and CZPowGate; X,Y anti-commute.
38-
39-
def __init__(
40-
self,
41-
pauli: ops.Pauli | ops.IdentityGate = ops.I,
42-
zpow: ops.ZPowGate = ops.ZPowGate(exponent=0),
43-
) -> None:
44-
self.pauli = pauli
45-
self.zpow = zpow
46-
47-
def _merge_left_zpow(self, left: ops.ZPowGate):
48-
"""Merges ZPowGate from left."""
49-
if self.pauli in self.commuting_gates:
50-
self.zpow = ops.ZPowGate(exponent=left.exponent + self.zpow.exponent)
51-
else:
52-
self.zpow = ops.ZPowGate(exponent=-left.exponent + self.zpow.exponent)
70+
def merge_left(self, left: _PauliAndZPow) -> _PauliAndZPow:
71+
"""Merges another `_PauliAndZPow` from the left.
5372
54-
def _merge_right_zpow(self, right: ops.ZPowGate):
55-
"""Merges ZPowGate from right."""
56-
self.zpow = ops.ZPowGate(exponent=right.exponent + self.zpow.exponent)
73+
Calculates `─left─self─` and returns a new `_PauliAndZPow` instance.
74+
"""
75+
return _merge_pauliandzpow(left, self)
5776

58-
def _merge_left_pauli(self, left: ops.Pauli):
59-
"""Merges --left_pauli--self--."""
60-
if self.pauli == ops.I:
61-
self.pauli = left
62-
else:
63-
self.pauli = left.phased_pauli_product(self.pauli)[1]
77+
def merge_right(self, right: _PauliAndZPow) -> _PauliAndZPow:
78+
"""Merges another `_PauliAndZPow` from the right.
6479
65-
def _merge_right_pauli(self, right: ops.Pauli):
66-
"""Merges --self--right_pauli--."""
67-
if self.pauli == ops.I:
68-
self.pauli = right
69-
else:
70-
self.pauli = right.phased_pauli_product(self.pauli)[1]
71-
if right not in self.commuting_gates:
72-
self.zpow = ops.ZPowGate(exponent=-self.zpow.exponent)
73-
74-
def merge_left(self, left: _PauliAndZPow) -> None:
75-
"""Inplace merge other from left."""
76-
self._merge_left_zpow(left.zpow)
77-
if left.pauli != ops.I:
78-
self._merge_left_pauli(cast(ops.Pauli, left.pauli))
79-
80-
def merge_right(self, right: _PauliAndZPow) -> None:
81-
"""Inplace merge other from right."""
82-
if right.pauli != ops.I:
83-
self._merge_right_pauli(cast(ops.Pauli, right.pauli))
84-
self._merge_right_zpow(right.zpow)
80+
Calculates `─self─right─` and returns a new `_PauliAndZPow` instance.
81+
"""
82+
return _merge_pauliandzpow(self, right)
8583

8684
def after_cphase(
8785
self, cphase: ops.CZPowGate
@@ -92,8 +90,8 @@ def after_cphase(
9290
A tuple of
9391
(updated cphase gate, pull_through of this qubit, pull_through of the other qubit).
9492
"""
95-
if self.pauli in self.commuting_gates:
96-
return cphase, self, _PauliAndZPow()
93+
if self.pauli in _COMMUTING_GATES:
94+
return cphase, _PauliAndZPow(self.pauli, self.zpow), _PauliAndZPow()
9795
else:
9896
# Taking self.pauli==X gate as an example:
9997
# 0: ─X─Z^t──@────── 0: ─X──@─────Z^t─ 0: ─@──────X──Z^t──
@@ -103,29 +101,29 @@ def after_cphase(
103101
# add an extra Rz rotation on the other qubit.
104102
return (
105103
cast(ops.CZPowGate, cphase**-1),
106-
self,
104+
_PauliAndZPow(self.pauli, self.zpow),
107105
_PauliAndZPow(zpow=ops.ZPowGate(exponent=cphase.exponent)),
108106
)
109107

110108
def after_pauli(self, pauli: ops.Pauli | ops.IdentityGate) -> _PauliAndZPow:
111109
"""Calculates ─self─pauli─ ==> ─pauli─output─."""
112-
if pauli in self.commuting_gates:
110+
if pauli in _COMMUTING_GATES:
113111
return _PauliAndZPow(self.pauli, self.zpow)
114112
else:
115113
return _PauliAndZPow(self.pauli, ops.ZPowGate(exponent=-self.zpow.exponent))
116114

117115
def after_zpow(self, zpow: ops.ZPowGate) -> tuple[ops.ZPowGate, _PauliAndZPow]:
118-
"""Calculates ─self─zpow─ ==> ─zpow'─output─."""
119-
if self.pauli in self.commuting_gates:
120-
return zpow, self
116+
"""Calculates ─self─zpow─ ==> ─+/-zpow─output─."""
117+
if self.pauli in _COMMUTING_GATES:
118+
return zpow, _PauliAndZPow(self.pauli, self.zpow)
121119
else:
122120
return ops.ZPowGate(exponent=-zpow.exponent), self
123121

124122
def __str__(self) -> str:
125123
return f"─{self.pauli}──{self.zpow}─"
126124

127125
def to_single_qubit_gate(self) -> ops.PhasedXZGate | ops.ZPowGate | ops.IdentityGate:
128-
"""Converts the _PhasedXYAndRz to a single-qubit gate."""
126+
"""Converts the _PauliAndZPow to a single-qubit gate."""
129127
exp = self.zpow.exponent
130128
match self.pauli:
131129
case ops.I:
@@ -137,23 +135,17 @@ def to_single_qubit_gate(self) -> ops.PhasedXZGate | ops.ZPowGate | ops.Identity
137135
case ops.Y:
138136
return ops.PhasedXZGate(x_exponent=1, z_exponent=exp - 1, axis_phase_exponent=0)
139137
case _: # ops.Z
140-
if (exp + 1) % 2 == 0:
141-
return ops.I
142138
return ops.ZPowGate(exponent=1 + exp)
143139

144140

145141
def _pull_through_single_cphase(
146142
cphase: ops.CZPowGate, input0: _PauliAndZPow, input1: _PauliAndZPow
147143
) -> tuple[ops.CZPowGate, _PauliAndZPow, _PauliAndZPow]:
148144
"""Pulls input0 and input1 through a CZPowGate.
149-
Input:
150-
0: ─(input0)─@─────
151-
152-
1: ─(input1)─@^exp─
153-
Output:
154-
0: ─@────────(output0)─
155-
156-
1: ─@^+/-exp─(output1)─
145+
Input: Output:
146+
0: ─(input0)─@───── 0: ─@────────(output0)─
147+
│ ==> │
148+
1: ─(input1)─@^exp─ 1: ─@^+/-exp─(output1)─
157149
"""
158150

159151
# Step 1; pull input0 through CZPowGate.
@@ -167,37 +159,45 @@ def _pull_through_single_cphase(
167159
# ==> │ ==> │
168160
# 1: ─@^+/-exp───pulled1────output1─ 1: ─@^+/-exp─output1─
169161
output_cphase, pulled1, pulled0 = input1.after_cphase(output_cphase)
170-
output0.merge_left(pulled0)
171-
output1.merge_left(pulled1)
162+
output0 = output0.merge_left(pulled0)
163+
output1 = output1.merge_left(pulled1)
172164

173165
return output_cphase, output0, output1
174166

175167

176168
_TARGET_GATESET: ops.Gateset = ops.Gateset(ops.CZPowGate)
177-
_SUPPORTED_GATESET: ops.Gateset = ops.Gateset(ops.Pauli, ops.IdentityGate, ops.Rz, ops.ZPowGate)
169+
_SUPPORTED_GATESET: ops.Gateset = ops.Gateset(ops.Pauli, ops.IdentityGate, ops.ZPowGate)
178170

179171

172+
@frozen
180173
class CPhaseGaugeTransformerMM(MultiMomentGaugeTransformer):
174+
"""A gauge transformer for the cphase gate."""
181175

182-
def __init__(self, supported_gates=_SUPPORTED_GATESET):
183-
super().__init__(target=_TARGET_GATESET, supported_gates=supported_gates)
176+
target: ops.GateFamily | ops.Gateset = field(default=_TARGET_GATESET, init=False)
177+
supported_gates: ops.GateFamily | ops.Gateset = field(default=_SUPPORTED_GATESET)
184178

185179
def sample_left_moment(
186-
self, active_qubits: frozenset[ops.Qid], rng: np.random.Generator = np.random.default_rng()
180+
self, active_qubits: frozenset[ops.Qid], rng: np.random.Generator
187181
) -> circuits.Moment:
188-
return circuits.Moment(
189-
[
190-
rng.choice(
191-
np.array([ops.I, ops.X, ops.Y, ops.Z], dtype=ops.Gate),
192-
p=[0.25, 0.25, 0.25, 0.25],
193-
).on(q)
194-
for q in active_qubits
195-
]
196-
)
182+
"""Samples a random single-qubit moment to be inserted before the target block."""
183+
return circuits.Moment([cast(ops.Gate, rng.choice(_PAULIS)).on(q) for q in active_qubits])
197184

198-
def gauge_on_moments(self, moments_to_gauge) -> list[circuits.Moment]:
185+
def gauge_on_moments(
186+
self,
187+
moments_to_gauge: list[circuits.Moment],
188+
prng: np.random.Generator = np.random.default_rng(),
189+
) -> list[circuits.Moment]:
190+
"""Gauges a block of moments that contains at least a cphase gate in each of the moment.
191+
192+
Args:
193+
moments_to_gauge: A list of moments to be gauged.
194+
prng: A pseudorandom number generator.
195+
196+
Returns:
197+
A list of moments after gauging.
198+
"""
199199
active_qubits = circuits.Circuit.from_moments(*moments_to_gauge).all_qubits()
200-
left_moment = self.sample_left_moment(active_qubits)
200+
left_moment = self.sample_left_moment(active_qubits, prng)
201201
pulled: dict[ops.Qid, _PauliAndZPow] = {
202202
op.qubits[0]: _PauliAndZPow(pauli=cast(ops.Pauli | ops.IdentityGate, op.gate))
203203
for op in left_moment
@@ -233,10 +233,10 @@ def gauge_on_moments(self, moments_to_gauge) -> list[circuits.Moment]:
233233
ops_at_updated_moment.append(new_zpow.on(q))
234234
case _:
235235
raise ValueError(f"Gate type {type(op.gate)} is not supported.")
236-
# Keep the other ops of prev
237-
for q, gate in prev.items():
238-
if q not in pulled:
239-
pulled[q] = gate
236+
# Keep the other ops of prev
237+
for q, gate in prev.items():
238+
if q not in pulled:
239+
pulled[q] = gate
240240
ret.append(circuits.Moment(ops_at_updated_moment))
241241
last_moment = circuits.Moment(
242242
[gate.to_single_qubit_gate().on(q) for q, gate in pulled.items()]

cirq-core/cirq/transformers/gauge_compiling/multi_moment_cphase_gauge_test.py

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from __future__ import annotations
1616

1717
from copy import deepcopy
18-
from unittest.mock import patch
1918

2019
import numpy as np
2120
import pytest
@@ -42,17 +41,27 @@ def test_gauge_on_single_cphase():
4241
q0, q1 = cirq.LineQubit.range(2)
4342

4443
input_circuit = cirq.Circuit(cirq.Moment(cirq.CZ(q0, q1) ** 0.2))
45-
cphase_transformer = CPhaseGaugeTransformerMM()
44+
45+
class _TestCPhaseGaugeTransformerMM(CPhaseGaugeTransformerMM):
46+
def sample_left_moment(
47+
self,
48+
active_qubits: frozenset[cirq.Qid],
49+
rng: np.random.Generator = np.random.default_rng(),
50+
) -> cirq.Moment:
51+
return cirq.Moment(g1(q0), g2(q1))
4652

4753
for g1 in [X, Y, Z, I]:
4854
for g2 in [X, Y, Z, I]: # Test with all possible samples of the left moment.
49-
with patch.object(
50-
cphase_transformer, "sample_left_moment", return_value=[g1(q0), g2(q1)]
51-
):
52-
output_circuit = cphase_transformer(input_circuit)
53-
cirq.testing.assert_circuits_have_same_unitary_given_final_permutation(
54-
input_circuit, output_circuit, {q: q for q in input_circuit.all_qubits()}
55-
)
55+
cphase_transformer = _TestCPhaseGaugeTransformerMM()
56+
output_circuit = cphase_transformer(input_circuit)
57+
import logging
58+
59+
logging.info(f"\n{input_circuit}")
60+
logging.info(f"g1: {g1}, g2: {g2}")
61+
logging.info(f"\n{output_circuit}")
62+
cirq.testing.assert_circuits_have_same_unitary_given_final_permutation(
63+
input_circuit, output_circuit, {q: q for q in input_circuit.all_qubits()}
64+
)
5665

5766

5867
def test_gauge_on_cz_moments():
@@ -90,6 +99,10 @@ def test_gauge_on_cz_moments():
9099
transformer = CPhaseGaugeTransformerMM()
91100

92101
output_circuit = transformer(input_circuit)
102+
import logging
103+
104+
logging.info(f"\n{input_circuit}")
105+
logging.info(f"\n{output_circuit}")
93106
cirq.testing.assert_circuits_have_same_unitary_given_final_permutation(
94107
input_circuit, output_circuit, {q: q for q in input_circuit.all_qubits()}
95108
)
@@ -203,10 +216,8 @@ def test_pauli_and_phxz_util_gate_merges():
203216
for right_pauli in [X, Y, Z, I]:
204217
left = _PauliAndZPow(pauli=left_pauli, zpow=ZPowGate(exponent=0.2))
205218
right = _PauliAndZPow(pauli=right_pauli, zpow=ZPowGate(exponent=0.6))
206-
merge1 = deepcopy(right)
207-
merge1.merge_left(left)
208-
merge2 = deepcopy(left)
209-
merge2.merge_right(right)
219+
merge1 = right.merge_left(left)
220+
merge2 = left.merge_right(right)
210221

211222
assert np.allclose(
212223
cirq.unitary(merge1.to_single_qubit_gate()),
@@ -240,3 +251,10 @@ def test_deep_not_supported():
240251
with pytest.raises(ValueError, match="GaugeTransformer cannot be used with deep=True"):
241252
t = CPhaseGaugeTransformerMM()
242253
t(cirq.Circuit(), context=cirq.TransformerContext(deep=True))
254+
255+
256+
def test_gate_type_not_supported():
257+
with pytest.raises(ValueError, match="Gate type .* is not supported."):
258+
t = CPhaseGaugeTransformerMM()
259+
q0, q1, q2 = cirq.LineQubit.range(3)
260+
t.gauge_on_moments([cirq.Moment(cirq.CZ(q0, q1), cirq.measure(q2))])

0 commit comments

Comments
 (0)