Skip to content

Commit

Permalink
1qb unitary gates
Browse files Browse the repository at this point in the history
  • Loading branch information
WrathfulSpatula committed Apr 23, 2022
1 parent 2b4c09a commit 146833c
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 18 deletions.
23 changes: 20 additions & 3 deletions pennylane_qrack/qrack_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,11 @@ def apply_operations(self, operations):
elif isinstance(op, BasisState):
self._apply_basis_state(op)
elif isinstance(op, QubitUnitary):
raise DeviceError(
"Operation {} is not supported on a {} device.".format(op.name, self.short_name)
)
if len(op.wires) > 1:
raise DeviceError(
"Operation {} is not supported on a {} device, except for single wires.".format(op.name, self.short_name)
)
self._apply_qubit_unitary(op)
else:
self._apply_gate(op)

Expand Down Expand Up @@ -235,6 +237,21 @@ def _apply_gate(self, op):
"Operation {} is not supported on a {} device.".format(op.name, self.short_name)
)

def _apply_qubit_unitary(self, op):
"""Apply unitary to state"""
# translate op wire labels to consecutive wire labels used by the device
device_wires = self.map_wires(op.wires)
par = op.parameters

if len(par[0]) != 2 ** len(device_wires):
raise ValueError("Unitary matrix must be of shape (2**wires, 2**wires).")

if op.inverse:
par[0] = par[0].conj().T

matrix = par[0].flatten().tolist()
self._state.mtrx(matrix, device_wires.labels[0])

def analytic_probability(self, wires=None):
"""Return the (marginal) analytic probability of each computational basis state."""
if self._state is None:
Expand Down
30 changes: 15 additions & 15 deletions tests/test_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,21 +243,21 @@ def test_two_qubit_no_parameters(self, init_state, op, mat, tol):
expected = mat @ state
assert np.allclose(res, expected, tol)

# @pytest.mark.parametrize("mat", [U, U2])
# def test_qubit_unitary(self, init_state, mat, tol):
# """Test QubitUnitary application"""
#
# N = int(np.log2(len(mat)))
# dev = QrackDevice(N)
# state = init_state(N)
#
# op = qml.QubitUnitary(mat, wires=list(range(N)))
# dev.apply([qml.QubitStateVector(state, wires=list(range(N))), op])
# dev._obs_queue = []
#
# res = dev.state
# expected = mat @ state
# assert np.allclose(res, expected, tol)
@pytest.mark.parametrize("mat", [U])
def test_qubit_unitary(self, init_state, mat, tol):
"""Test QubitUnitary application"""

N = int(np.log2(len(mat)))
dev = QrackDevice(N)
state = init_state(N)

op = qml.QubitUnitary(mat, wires=list(range(N)))
dev.apply([qml.QubitStateVector(state, wires=list(range(N))), op])
dev._obs_queue = []

res = dev.state
expected = mat @ state
assert np.allclose(res, expected, tol)

def test_invalid_qubit_state_unitary(self):
"""Test that an exception is raised if the
Expand Down

0 comments on commit 146833c

Please sign in to comment.