Skip to content

Commit

Permalink
Use native implementation for adjoints in (control) operations (#1063)
Browse files Browse the repository at this point in the history
### Before submitting

Please complete the following checklist when submitting a PR:

- [ ] All new features must include a unit test.
If you've fixed a bug or added code that should be tested, add a test to
the
      [`tests`](../tests) directory!

- [ ] All new functions and code must be clearly commented and
documented.
If you do make documentation changes, make sure that the docs build and
      render correctly by running `make docs`.

- [ ] Ensure that the test suite passes, by running `make test`.

- [ ] Add a new entry to the `.github/CHANGELOG.md` file, summarizing
the
      change, and including a link back to the PR.

- [ ] Ensure that code is properly formatted by running `make format`. 

When all the above are checked, delete everything above the dashed
line and fill in the pull request template.


------------------------------------------------------------------------------------------------------------

**Context:**
Currently in `_apply_lightning`, we check for whether an operation is
`Adjoint`, then we apply the operation with an adjoint (`inv_param`)
flag. However, in cases where we have:
- adjoint(s) within control - e.g. `control(adjoint(gate))`
- control within adjoint - e.g. `adjoint(control(gate))`,  

these are all applied as matrices.

**Description of the Change:**
`_apply_lightning` and `_apply_lightning_controlled` checks for adjoint
in an operation, and if it's an adjoint it applies the base operation
with an adjoint flag, instead of treating everything as a matrix.

So in effect we have:
`control(adjoint(gate))` -> `control(gate with adjoint)`
`adjoint(control(gate))` -> `control(gate with adjoint)`

which are implemented natively in C++ (if the `gate` is supported),
yielding better performance

**Benefits:**
adjoint(ctrl()) will see the most speedup, especially with large number
of control wires, since we use native control operation which contains
less wires than the equivalent matrix, and needs to be operated on less
wires. adjoint(ctrl()) will see some speed-up, since we are now able to
use the native named gate implementation in C++.


Example timing improvement: 
4 ctrl wires

LQ:

| LQ, 25 qubits, 500 repeats          | master  | branch |
|-------------------------------------|--------|-------|
| ctrl(adjoint(IsingXX))              |   9.6s |   6.0s |
| ctrl(adjoint(DoubleExcitationPlus)) | 27.6s          | 9.2s   |

| LQ, 25 qubits, 100 repeats          | master | branch |
|-------------------------------------|------------------|--------|
| adjoint(ctrl(IsingXX))               |   267s | 2.9s|
| adjoint(ctrl(DoubleExcitationPlus)) | 1002s|  3.6s   |


Baseline:
| LQ, 25 qubits, 500 repeats          | master | branch |
|-------------------------------------|--------|--------|
| ctrl(IsingXX)       |  6.1s |6.1s  |
| ctrl(DoubleExcitationPlus)|  9.1s | 9.1s   |

LG:

| LG, 31 qubits, 1000 repeats          | master | branch |
|-------------------------------------|--------|--------|
| ctrl(adjoint(IsingXX))              | 4.9s |  4.8s |
| ctrl(adjoint(DoubleExcitationPlus)) |  5.0s |  4.9s    |

| LG, 31 qubits, 1000 repeats             | master | branch |
|-------------------------------------|-------------|--------|
| adjoint(ctrl(IsingXX))               |  119s | 4.8s|
| adjoint(ctrl(DoubleExcitationPlus)) |  208s |  4.9s  |


Baseline:
| LG, 31 qubits, 1000 repeats              | master | branch |
|-------------------------------------|-------------------|--------|
| ctrl(IsingXX)       |  4.8s | 4.8s   |
| ctrl(DoubleExcitationPlus)|  4.9s | 4.9s   |



LK:

| LK, 25 qubits, 500 repeats          | master | branch |
|-------------------------------------|--------|--------|
| ctrl(adjoint(IsingXX))              |  8.5s |5.7s  |
| ctrl(adjoint(DoubleExcitationPlus)) | 24.5s | 7.6s   |

| LK, 25 qubits, 100 repeats          | master | branch |
|-------------------------------------|-----|--------|
| adjoint(ctrl(IsingXX))               |    235s | 2.6s |
| adjoint(ctrl(DoubleExcitationPlus)) | 867s | 2.9s  |


Baseline:
| LK, 25 qubits, 500 repeats          | master |branch |
|-------------------------------------|-------------|--------|
| ctrl(IsingXX)       |  5.6s |5.8s  |
| ctrl(DoubleExcitationPlus)|  7.7s | 7.6 s  |

**Possible Drawbacks:**

**Related GitHub Issues:**

[sc-79430]

---------

Co-authored-by: ringo-but-quantum <github-ringo-but-quantum@xanadu.ai>
Co-authored-by: Christina Lee <chrissie.c.l@gmail.com>
Co-authored-by: Amintor Dusko <87949283+AmintorDusko@users.noreply.github.com>
  • Loading branch information
4 people authored Feb 26, 2025
1 parent aad3e59 commit 98a9292
Show file tree
Hide file tree
Showing 9 changed files with 408 additions and 70 deletions.
3 changes: 3 additions & 0 deletions .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

### Improvements

* Use native C++ kernels for controlled-adjoint and adjoint-controlled of supported operations.
[(#1063)](https://github.com/PennyLaneAI/pennylane-lightning/pull/1063)

* In Lightning-Tensor, allow `qml.MPSPrep` to accept an MPS with `len(MPS) = n_wires-1`.
[(#1064)](https://github.com/PennyLaneAI/pennylane-lightning/pull/1064)

Expand Down
71 changes: 51 additions & 20 deletions pennylane_lightning/core/_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,8 +443,18 @@ def serialize_ops(self, tape: QuantumTape, wires_map: dict = None) -> Tuple[
uses_stateprep = False

def get_wires(operation, single_op):
if isinstance(operation, qml.ops.op_math.Controlled) and not isinstance(
operation,
# Serialize adjoint(op) and adjoint(ctrl(op))
if isinstance(operation, qml.ops.op_math.Adjoint):
inverse = True
op_base = operation.base
single_op_base = single_op.base
else:
inverse = False
op_base = operation
single_op_base = single_op

if isinstance(op_base, qml.ops.op_math.Controlled) and not isinstance(
op_base,
(
qml.CNOT,
qml.CY,
Expand All @@ -457,19 +467,41 @@ def get_wires(operation, single_op):
qml.CSWAP,
),
):
name = operation.base.name
wires_list = list(operation.target_wires)
controlled_wires_list = list(operation.control_wires)
control_values_list = operation.control_values
wires_list = list(op_base.target_wires)
controlled_wires_list = list(op_base.control_wires)
control_values_list = op_base.control_values
# Serialize ctrl(adjoint(op))
if isinstance(op_base.base, qml.ops.op_math.Adjoint):
ctrl_adjoint = True
name = op_base.base.base.name
else:
ctrl_adjoint = False
name = op_base.base.name

# Inside the controlled operation, if the base operation (of the adjoint)
# is supported natively, we apply the the base operation and invert the
# inverse flag; otherwise we apply the QubitUnitary of a matrix which
# contains the inverse and leave the inverse flag as is.
if not hasattr(self.sv_type, name):
single_op = QubitUnitary(matrix(single_op.base), single_op.base.wires)
name = single_op.name
single_op_base = QubitUnitary(
matrix(single_op_base.base), single_op_base.base.wires
)
name = single_op_base.name
else:
inverse ^= ctrl_adjoint
else:
name = single_op.name
wires_list = single_op.wires.tolist()
name = single_op_base.name
wires_list = single_op_base.wires.tolist()
controlled_wires_list = []
control_values_list = []
return single_op, name, list(wires_list), controlled_wires_list, control_values_list
return (
single_op_base,
name,
inverse,
list(wires_list),
controlled_wires_list,
control_values_list,
)

for operation in tape.operations:
if isinstance(operation, (BasisState, StatePrep)):
Expand All @@ -480,30 +512,29 @@ def get_wires(operation, single_op):
else:
op_list = [operation]

inverse = isinstance(operation, qml.ops.op_math.Adjoint)

for single_op in op_list:
(
single_op,
single_op_base,
name,
inverse,
wires_list,
controlled_wires_list,
controlled_values_list,
) = get_wires(operation, single_op)
inverses.append(inverse)
names.append(single_op.base.name if inverse else name)
names.append(name)
# QubitUnitary is a special case, it has a parameter which is not differentiable.
# We thus pass a dummy 0.0 parameter which will not be referenced
if isinstance(single_op, qml.QubitUnitary):
if isinstance(single_op_base, qml.QubitUnitary):
params.append([0.0])
mats.append(matrix(single_op))
mats.append(matrix(single_op_base))
else:
if hasattr(self.sv_type, single_op.base.name if inverse else name):
params.append(single_op.parameters)
if hasattr(self.sv_type, name):
params.append(single_op_base.parameters)
mats.append([])
else:
params.append([])
mats.append(matrix(single_op))
mats.append(matrix(single_op_base))

controlled_values.append(controlled_values_list)
controlled_wires.append(
Expand Down
4 changes: 3 additions & 1 deletion pennylane_lightning/core/_state_vector_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import numpy as np
from pennylane import BasisState, StatePrep
from pennylane.measurements import MidMeasureMP
from pennylane.ops import Controlled
from pennylane.tape import QuantumScript
from pennylane.wires import Wires

Expand Down Expand Up @@ -131,11 +132,12 @@ def _apply_basis_state(self, state, wires):
self._qubit_state.setBasisState(list(state), list(wires))

@abstractmethod
def _apply_lightning_controlled(self, operation):
def _apply_lightning_controlled(self, operation: Controlled, adjoint: bool):
"""Apply an arbitrary controlled operation to the state tensor.
Args:
operation (~pennylane.operation.Operation): controlled operation to apply
adjoint (bool): Apply the adjoint of the operation if True
Returns:
None
Expand Down
2 changes: 1 addition & 1 deletion pennylane_lightning/core/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@
Version number (major.minor.patch[-label])
"""

__version__ = "0.41.0-dev24"
__version__ = "0.41.0-dev25"
34 changes: 20 additions & 14 deletions pennylane_lightning/lightning_gpu/_state_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,34 +229,39 @@ def _apply_state_vector(self, state, device_wires, use_async: bool = False):
# set the state vector on GPU with provided state and their corresponding wires
self._qubit_state.setStateVector(state, list(device_wires), use_async)

def _apply_lightning_controlled(self, operation):
def _apply_lightning_controlled(self, operation, adjoint):
"""Apply an arbitrary controlled operation to the state tensor.
Args:
operation (~pennylane.operation.Operation): controlled operation to apply
adjoint (bool): Apply the adjoint of the operation if True
Returns:
None
"""
state = self.state_vector

basename = operation.base.name
method = getattr(state, f"{basename}", None)
if isinstance(operation.base, Adjoint):
base_operation = operation.base.base
adjoint = not adjoint
else:
base_operation = operation.base

method = getattr(state, f"{base_operation.name}", None)
control_wires = list(operation.control_wires)
control_values = operation.control_values
target_wires = list(operation.target_wires)
if method: # apply n-controlled specialized gate
inv = False
param = operation.parameters
method(control_wires, control_values, target_wires, inv, param)
method(control_wires, control_values, target_wires, adjoint, param)
else: # apply gate as an n-controlled matrix
method = getattr(state, "applyControlledMatrix")
method(
qml.matrix(operation.base),
qml.matrix(base_operation),
control_wires,
control_values,
target_wires,
False,
adjoint,
)

def _apply_lightning_midmeasure(
Expand Down Expand Up @@ -300,6 +305,7 @@ def _apply_lightning(
postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to
keep the same number of shots. Default is ``None``.
Returns:
None
"""
Expand All @@ -311,11 +317,12 @@ def _apply_lightning(
if isinstance(operation, qml.Identity):
continue
if isinstance(operation, Adjoint):
name = operation.base.name
op_adjoint_base = operation.base
invert_param = True
else:
name = operation.name
op_adjoint_base = operation
invert_param = False
name = op_adjoint_base.name
method = getattr(state, name, None)
wires = list(operation.wires)

Expand All @@ -330,13 +337,13 @@ def _apply_lightning(
param = operation.parameters
method(wires, invert_param, param)
elif (
isinstance(operation, qml.ops.Controlled) and not self._mpi_handler.use_mpi
isinstance(op_adjoint_base, qml.ops.Controlled) and not self._mpi_handler.use_mpi
): # MPI backend does not have native controlled gates support
self._apply_lightning_controlled(operation)
self._apply_lightning_controlled(op_adjoint_base, invert_param)
elif (
self._mpi_handler.use_mpi
and isinstance(operation, qml.ops.Controlled)
and isinstance(operation.base, qml.GlobalPhase)
and isinstance(op_adjoint_base, qml.ops.Controlled)
and isinstance(op_adjoint_base.base, qml.GlobalPhase)
):
# TODO: To move this line to the _apply_lightning_controlled method once the MPI backend supports controlled gates natively
raise DeviceError(
Expand All @@ -348,7 +355,6 @@ def _apply_lightning(
except AttributeError: # pragma: no cover
# To support older versions of PL
mat = operation.matrix

r_dtype = np.float32 if self.dtype == np.complex64 else np.float64
param = (
[[r_dtype(operation.hash)]]
Expand Down
33 changes: 19 additions & 14 deletions pennylane_lightning/lightning_kokkos/_state_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,34 +181,39 @@ def _apply_state_vector(self, state, device_wires: Wires):
# This operate on device
self._qubit_state.setStateVector(state, list(device_wires))

def _apply_lightning_controlled(self, operation):
def _apply_lightning_controlled(self, operation, adjoint):
"""Apply an arbitrary controlled operation to the state tensor.
Args:
operation (~pennylane.operation.Operation): controlled operation to apply
adjoint (bool): Apply the adjoint of the operation if True
Returns:
None
"""
state = self.state_vector

basename = operation.base.name
method = getattr(state, f"{basename}", None)
if isinstance(operation.base, Adjoint):
base_operation = operation.base.base
adjoint = not adjoint
else:
base_operation = operation.base

method = getattr(state, f"{base_operation.name}", None)
control_wires = list(operation.control_wires)
control_values = operation.control_values
target_wires = list(operation.target_wires)
inv = False # TODO: update to use recursive _apply_lightning to handle nested adjoint/ctrl
if method is not None: # apply n-controlled specialized gate
param = operation.parameters
method(control_wires, control_values, target_wires, inv, param)
method(control_wires, control_values, target_wires, adjoint, param)
else: # apply gate as an n-controlled matrix
method = getattr(state, "applyControlledMatrix")
method(
qml.matrix(operation.base),
qml.matrix(base_operation),
control_wires,
control_values,
target_wires,
inv,
adjoint,
)

def _apply_lightning_midmeasure(
Expand Down Expand Up @@ -262,11 +267,12 @@ def _apply_lightning(
if isinstance(operation, qml.Identity):
continue
if isinstance(operation, Adjoint):
name = operation.base.name
op_adjoint_base = operation.base
invert_param = True
else:
name = operation.name
op_adjoint_base = operation
invert_param = False
name = op_adjoint_base.name
method = getattr(state, name, None)
wires = list(operation.wires)

Expand All @@ -279,18 +285,17 @@ def _apply_lightning(
)
elif isinstance(operation, qml.PauliRot):
method = getattr(state, "applyPauliRot")
# pylint: disable=protected-access
paulis = operation._hyperparameters[
paulis = operation._hyperparameters[ # pylint: disable=protected-access
"pauli_word"
] # pylint: disable=protected-access
]
wires = [i for i, w in zip(wires, paulis) if w != "I"]
word = "".join(p for p in paulis if p != "I")
method(wires, invert_param, operation.parameters, word)
elif method is not None: # apply specialized gate
param = operation.parameters
method(wires, invert_param, param)
elif isinstance(operation, qml.ops.Controlled): # apply n-controlled gate
self._apply_lightning_controlled(operation)
elif isinstance(op_adjoint_base, qml.ops.Controlled): # apply n-controlled gate
self._apply_lightning_controlled(op_adjoint_base, invert_param)
else: # apply gate as a matrix
# Inverse can be set to False since qml.matrix(operation) is already in
# inverted form
Expand Down
Loading

0 comments on commit 98a9292

Please sign in to comment.