Skip to content

Commit f79849d

Browse files
authored
fix: Explicit target for identity gate (#38)
Qubit count validation was failing due to identity gate target not being counted.
1 parent 77a0e8b commit f79849d

File tree

6 files changed

+27
-8
lines changed

6 files changed

+27
-8
lines changed

src/braket/default_simulator/gate_operations.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,18 +49,21 @@ def from_braket_instruction(instruction) -> GateOperation:
4949
class Identity(GateOperation):
5050
"""Identity gate"""
5151

52+
def __init__(self, targets):
53+
self._targets = tuple(targets)
54+
5255
@property
5356
def matrix(self) -> np.ndarray:
5457
return np.eye(2)
5558

5659
@property
5760
def targets(self) -> Tuple[int, ...]:
58-
return ()
61+
return self._targets
5962

6063

6164
@from_braket_instruction.register(braket_instruction.I)
6265
def _i(instruction) -> Identity:
63-
return Identity()
66+
return Identity([instruction.target])
6467

6568

6669
class Hadamard(GateOperation):

src/braket/default_simulator/operation.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,10 @@ class Operation(ABC):
2929
def targets(self) -> Tuple[int, ...]:
3030
""" Tuple[int, ...]: The indices of the qubits the operation applies to.
3131
32-
For an index to be a target, the operation must have a nontrivial (i.e. non-identity) action
33-
on that index. For example, a tensor product observable with a Z factor on qubit j acts
34-
trivially on j, so j would not be a target.
32+
Note: For an index to be a target of an observable, the observable must have a nontrivial
33+
(i.e. non-identity) action on that index. For example, a tensor product observable with a
34+
Z factor on qubit j acts trivially on j, so j would not be a target. This does not apply to
35+
gate operations.
3536
"""
3637

3738

test/unit_tests/braket/default_simulator/test_gate_operations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from braket.ir.jaqcd import shared_models
2020

2121
testdata = [
22-
(instruction.I(target=4), (), gate_operations.Identity),
22+
(instruction.I(target=4), (4,), gate_operations.Identity),
2323
(instruction.H(target=13), (13,), gate_operations.Hadamard),
2424
(instruction.X(target=11), (11,), gate_operations.PauliX),
2525
(instruction.Y(target=10), (10,), gate_operations.PauliY),

test/unit_tests/braket/default_simulator/test_operation_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
invalid_hermitian_matrices = [(np.array([[1, 0], [0, 1j]])), (np.array([[1, 2], [3, 4]]))]
4343

4444
gate_testdata = [
45-
gate_operations.Identity(),
45+
gate_operations.Identity([0]),
4646
gate_operations.Hadamard([0]),
4747
gate_operations.PauliX([0]),
4848
gate_operations.PauliY([0]),

test/unit_tests/braket/default_simulator/test_simulation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
),
5252
([gate_operations.V([0])], 1, [0.5 + 0.5j, 0.5 - 0.5j], [0.5, 0.5],),
5353
([gate_operations.Vi([0])], 1, [0.5 - 0.5j, 0.5 + 0.5j], [0.5, 0.5],),
54-
([gate_operations.Identity()], 1, [1, 0], [1, 0]),
54+
([gate_operations.Identity([0])], 1, [1, 0], [1, 0]),
5555
([gate_operations.Unitary([0], [[0, 1], [1, 0]])], 1, [0, 1], [0, 1]),
5656
(
5757
[gate_operations.PauliX([0]), gate_operations.PhaseShift([0], 0.15)],

test/unit_tests/braket/default_simulator/test_simulator.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,21 @@ def test_simulator_run_bell_pair(bell_ir, batch_size):
101101
assert result.additionalMetadata == AdditionalMetadata(action=bell_ir)
102102

103103

104+
def test_simulator_identity():
105+
simulator = DefaultSimulator()
106+
shots_count = 1000
107+
result = simulator.run(
108+
Program.parse_raw(
109+
json.dumps({"instructions": [{"type": "i", "target": 0}, {"type": "i", "target": 1}]})
110+
),
111+
qubit_count=2,
112+
shots=shots_count,
113+
)
114+
counter = Counter(["".join(measurement) for measurement in result.measurements])
115+
assert counter.keys() == {"00"}
116+
assert counter["00"] == shots_count
117+
118+
104119
@pytest.mark.xfail(raises=ValueError)
105120
def test_simulator_run_no_results_no_shots(bell_ir):
106121
simulator = DefaultSimulator()

0 commit comments

Comments
 (0)