Commit eeeabef

Address NumPy 2 data type promotion warnings
One of the changes in NumPy 2 is to the [behavior of type promotion](https://numpy.org/devdocs/numpy_2_0_migration_guide.html#changes-to-numpy-data-type-promotion). A possible negative impact of the changes is that some operations involving scalar types can lead to lower precision, or even overflow. For example, `uint8(100) + 200` previously (in NumPy < 2.0) produced a `uint16` value, but now results in a `uint8` value and an overflow _warning_ (not error). This can have an impact on Cirq. For example, in Cirq, simulator measurement result values are `uint8`s, and in some places, arrays of values are summed; this leads to overflows if the sum exceeds 255, the largest value a `uint8` can hold. It would not be appropriate to change measurement values to be larger than `uint8`, so in cases like this, the proper solution is probably to make sure that where values are summed or otherwise numerically manipulated, `uint16` or larger values are used.

NumPy 2 offers a new option (`np._set_promotion_state("weak_and_warn")`) to produce warnings where data types are changed. Commit 6cf50eb adds a new command-line option to our pytest framework, such that running

```bash
check/pytest --warn-numpy-data-promotion
```

will turn on this NumPy setting. Running `check/pytest` with this option enabled revealed quite a lot of warnings. The present commit changes code in places where those warnings were raised, in an effort to eliminate as many of them as possible.

Certainly not all of the type promotion warnings are meaningful. Unfortunately, I sometimes found it difficult to be sure which ones _are_ meaningful, in part because Cirq's code has many layers and uses ndarrays a lot, and it was hard for me to understand the impact of a type demotion (say, from `float64` to `float32`). In view of this, I wanted to err on the side of caution and try to avoid losses of precision.
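The summation overflow described above can be reproduced outside Cirq. The following is a minimal sketch, not Cirq code: the array `measurements` is a hypothetical stand-in for simulator output, and the widening step mirrors the `.astype(np.uint16)` calls added to the tests in this commit.

```python
import numpy as np

# Hypothetical stand-in for simulator output: 300 one-bit results stored as uint8.
measurements = np.ones((300, 1), dtype=np.uint8)

# The built-in sum() adds uint8 rows to each other, so the running total
# stays uint8 and silently wraps modulo 256.
wrapped_total = sum(measurements)[0]

# Widening the array first keeps the running total in a type that can hold it.
widened_total = sum(measurements.astype(np.uint16))[0]

print(wrapped_total, widened_total)
```

Here the true count of 300 does not fit in a `uint8`, so the first total wraps to 44, while the widened total is 300.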
The principles followed in the changes are roughly the following:

* Don't worry about warnings about changes from `complex64` to `complex128`, as this obviously does not reduce precision.
* If a warning involves an operation using an ndarray, change the code to get the actual data type of the elements in the array rather than use a specific data type. This is the reason some of the changes look like the following, where the code reaches into an ndarray to get the dtype of an element and later uses that dtype's `.type` scalar constructor to cast some other value:

  ```python
  dtype = args.target_tensor.flat[0].dtype
  .....
  args.target_tensor[subspace] *= dtype.type(x)
  ```

* In cases where the above was not possible, or where it was obvious what the type must always be, the changes add type casts with explicit types like `complex(x)` or `np.float64(x)`.

It is likely that this approach resulted in some unnecessary up-promotion of values and may have impacted run-time performance. Some simple overall timing of `check/pytest` did not reveal a glaring negative impact of the changes, but that doesn't mean real applications won't be impacted. Perhaps a future review can evaluate whether speedups are possible.
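The `dtype.type(x)` pattern from the second principle can be tried in isolation. This is a small sketch under assumptions: `target_tensor` here is a plain ndarray created for illustration, not a real `ApplyUnitaryArgs` field, and the phase value is arbitrary.

```python
import numpy as np

# Illustrative tensor; in Cirq this would come from args.target_tensor.
target_tensor = np.ones((2, 2), dtype=np.complex64)

# Read the element dtype from the array itself instead of hard-coding it.
dtype = target_tensor.flat[0].dtype

# dtype.type is the NumPy scalar class (np.complex64 here); calling it
# casts the Python complex value to the same precision as the array.
phase = dtype.type(1j ** 0.5)

# The in-place multiply now involves two complex64 operands.
target_tensor[0] *= phase
```

For a plain ndarray, `target_tensor.dtype` should give the same dtype and also works for empty arrays, so it may be a simpler choice where the array is known to be an ndarray.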
1 parent 37899df commit eeeabef

22 files changed, +80 -60 lines changed

cirq-core/cirq/contrib/quantum_volume/quantum_volume.py

+3 -1
@@ -86,7 +86,9 @@ def compute_heavy_set(circuit: cirq.Circuit) -> List[int]:
     # The output wave function is a vector from the result value (big-endian) to
     # the probability of that bit-string. Return all of the bit-string
     # values that have a probability greater than the median.
-    return [idx for idx, amp in enumerate(results.state_vector()) if np.abs(amp**2) > median]
+    results_vector = results.state_vector()
+    return [idx for idx, amp in enumerate(results_vector)
+            if np.abs(np.square(amp)) > median]


 @dataclass
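The heavy-set selection in this hunk can be sketched on its own. The function below is a hypothetical stand-alone version, not the Cirq implementation: it assumes the median is taken over the same probabilities, which the hunk does not show.

```python
import numpy as np


def heavy_set(state_vector):
    """Return indices whose probability exceeds the median probability."""
    # |amp**2| equals |amp|**2, the probability of each bit-string.
    probabilities = np.abs(np.square(state_vector))
    # Assumption: the median is computed over these same probabilities.
    median = np.median(probabilities)
    return [idx for idx, prob in enumerate(probabilities) if prob > median]


# Four amplitudes with probabilities 0.1, 0.2, 0.3, 0.4 and arbitrary phases.
amplitudes = np.sqrt(np.array([0.1, 0.2, 0.3, 0.4])) * np.exp(1j * np.arange(4))
heavy = heavy_set(amplitudes)
```

The phases do not affect the result, since only the squared magnitudes are compared with the median.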

cirq-core/cirq/devices/grid_qubit.py

+4 -2
@@ -152,7 +152,8 @@ def __add__(self, other: Union[Tuple[int, int], Self]) -> Self:
                 'Can only add integer tuples of length 2 to '
                 f'{type(self).__name__}. Instead was {other}'
             )
-        return self._with_row_col(row=self._row + other[0], col=self._col + other[1])
+        return self._with_row_col(row=np.int64(self._row) + other[0],
+                                  col=np.int64(self._col) + other[1])

     def __sub__(self, other: Union[Tuple[int, int], Self]) -> Self:
         if isinstance(other, _BaseGridQid):
@@ -171,7 +172,8 @@ def __sub__(self, other: Union[Tuple[int, int], Self]) -> Self:
                 "Can only subtract integer tuples of length 2 to "
                 f"{type(self).__name__}. Instead was {other}"
             )
-        return self._with_row_col(row=self._row - other[0], col=self._col - other[1])
+        return self._with_row_col(row=np.int64(self._row) - other[0],
+                                  col=np.int64(self._col) - other[1])

     def __radd__(self, other: Tuple[int, int]) -> Self:
         return self + other

cirq-core/cirq/experiments/fidelity_estimation.py

+2 -1
@@ -192,7 +192,8 @@ def xeb_fidelity(
         output_probabilities = state_vector_to_probabilities(output_state)
         bitstring_probabilities = output_probabilities[bitstrings].tolist()
     else:
-        bitstring_probabilities = [abs(amplitudes[bitstring]) ** 2 for bitstring in bitstrings]
+        bitstring_probabilities = [np.abs(amplitudes[bitstring], dtype=float) ** 2
+                                   for bitstring in bitstrings]
     return estimator(dim, bitstring_probabilities)

cirq-core/cirq/linalg/transformations.py

+5 -3
@@ -92,7 +92,7 @@ def dephase(v):
         if r == 0:
             return 1j if i < 0 else -1j

-        return np.exp(-1j * np.arctan2(i, r))
+        return np.exp(-1j * complex(np.arctan2(i, r)), dtype=v.dtype)

     # Zero the phase at this entry in both matrices.
     return a * dephase(a[k]), b * dephase(b[k])
@@ -237,13 +237,14 @@ def _build_from_slices(
     """
     d = len(source.shape)
     out[...] = 0
+    dtype = source.flat[0].dtype
     for arg in args:
         source_slice: List[Any] = [slice(None)] * d
         target_slice: List[Any] = [slice(None)] * d
         for sleis in arg.slices:
             source_slice[sleis.axis] = sleis.source_index
             target_slice[sleis.axis] = sleis.target_index
-        out[tuple(target_slice)] += arg.scale * source[tuple(source_slice)]
+        out[tuple(target_slice)] += dtype.type(arg.scale) * source[tuple(source_slice)]
     return out


@@ -564,7 +565,8 @@ def sub_state_vector(
     best_candidate = max(candidates, key=lambda c: float(np.linalg.norm(c, 2)))
     best_candidate = best_candidate / np.linalg.norm(best_candidate)
     left = np.conj(best_candidate.reshape((keep_dims,))).T
-    coherence_measure = sum([abs(np.dot(left, c.reshape((keep_dims,)))) ** 2 for c in candidates])
+    coherence_measure = sum([np.abs(np.dot(left, c.reshape((keep_dims,))), dtype=float) ** 2
+                             for c in candidates])

     if protocols.approx_eq(coherence_measure, 1, atol=atol):
         return np.exp(2j * np.pi * np.random.random()) * best_candidate.reshape(ret_shape)
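The first hunk in this file changes `dephase` so that the returned phase factor has the same dtype as its input. The sketch below is a hypothetical stand-alone version: only the `r == 0` branch and the final `return` appear in the hunk, and the `i == 0` branch is an assumption added so the example is complete.

```python
import numpy as np


def dephase(v):
    """Return a unit phase factor that rotates v onto the positive real axis."""
    r = np.real(v)
    i = np.imag(v)
    # Assumed branch (not shown in the hunk): purely real input.
    if i == 0:
        return -1 if r < 0 else 1
    # Purely imaginary input, as in the hunk.
    if r == 0:
        return 1j if i < 0 else -1j
    # General case: exp(-i * arg(v)), computed in the dtype of v.
    return np.exp(-1j * complex(np.arctan2(i, r)), dtype=v.dtype)


v = np.complex64(1 + 1j)
factor = dephase(v)
```

Multiplying `v` by `factor` should give a real, non-negative number (approximately `sqrt(2)` here), and `factor` stays `complex64` rather than being promoted to `complex128`.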

cirq-core/cirq/ops/clifford_gate.py

+1 -1
@@ -673,7 +673,7 @@ def pauli_tuple(self, pauli: Pauli) -> Tuple[Pauli, bool]:
             to = x_to * z_to  # Y = iXZ
             to._coefficient *= 1j
         # pauli_mask returns a value between 0 and 4 for [I, X, Y, Z].
-        to_gate = Pauli._XYZ[to.pauli_mask[0] - 1]
+        to_gate = Pauli._XYZ[to.pauli_mask[0] - np.uint8(1)]
         return (to_gate, bool(to.coefficient != 1.0))

     def dense_pauli_string(self, pauli: Pauli) -> 'cirq.DensePauliString':

cirq-core/cirq/ops/common_gates.py

+10 -8
@@ -426,8 +426,8 @@ def _apply_unitary_(self, args: 'protocols.ApplyUnitaryArgs') -> Optional[np.nda
             return NotImplemented
         zero = args.subspace_index(0)
         one = args.subspace_index(1)
-        args.available_buffer[zero] = -1j * args.target_tensor[one]
-        args.available_buffer[one] = 1j * args.target_tensor[zero]
+        args.available_buffer[zero] = np.complex128(-1j) * args.target_tensor[one]
+        args.available_buffer[one] = np.complex128(1j) * args.target_tensor[zero]
         p = 1j ** (2 * self._exponent * self._global_shift)
         if p != 1:
             args.available_buffer *= p
@@ -542,7 +542,7 @@ def __init__(self, *, rads: value.TParamVal):
             rads: Radians to rotate about the Y axis of the Bloch sphere.
         """
         self._rads = rads
-        super().__init__(exponent=rads / _pi(rads), global_shift=-0.5)
+        super().__init__(exponent=rads / _pi(rads), global_shift=float(-0.5))

     def _with_exponent(self, exponent: value.TParamVal) -> 'Ry':
         return Ry(rads=exponent * _pi(exponent))
@@ -638,10 +638,11 @@ def _apply_unitary_(self, args: 'protocols.ApplyUnitaryArgs') -> Optional[np.nda
         if protocols.is_parameterized(self):
             return None

+        dtype = args.target_tensor.flat[0].dtype
         for i in range(1, self._dimension):
             subspace = args.subspace_index(i)
             c = 1j ** (self._exponent * 4 * i / self._dimension)
-            args.target_tensor[subspace] *= c
+            args.target_tensor[subspace] *= dtype.type(c)
         p = 1j ** (2 * self._exponent * self._global_shift)
         if p != 1:
             args.target_tensor *= p
@@ -991,11 +992,12 @@ def _apply_unitary_(self, args: 'protocols.ApplyUnitaryArgs') -> Optional[np.nda

         zero = args.subspace_index(0)
         one = args.subspace_index(1)
+        dtype = args.target_tensor.flat[0].dtype
         args.target_tensor[one] -= args.target_tensor[zero]
-        args.target_tensor[one] *= -0.5
+        args.target_tensor[one] *= -dtype.type(0.5)
         args.target_tensor[zero] -= args.target_tensor[one]
         p = 1j ** (2 * self._exponent * self._global_shift)
-        args.target_tensor *= np.sqrt(2) * p
+        args.target_tensor *= np.sqrt(2, dtype=dtype) * dtype.type(p)
         return args.target_tensor

     def _decompose_(self, qubits):
@@ -1005,7 +1007,6 @@ def _decompose_(self, qubits):
             yield cirq.Y(q) ** 0.5
             yield cirq.XPowGate(global_shift=-0.25 + self.global_shift).on(q)
             return
-
         yield YPowGate(exponent=0.25).on(q)
         yield XPowGate(exponent=self._exponent, global_shift=self.global_shift).on(q)
         yield YPowGate(exponent=-0.25).on(q)
@@ -1097,8 +1098,9 @@ def _apply_unitary_(
         if protocols.is_parameterized(self):
             return NotImplemented

-        c = 1j ** (2 * self._exponent)
         one_one = args.subspace_index(0b11)
+        dtype = args.target_tensor[one_one].dtype
+        c = dtype.type(1j ** (2 * self._exponent))
         args.target_tensor[one_one] *= c
         p = 1j ** (2 * self._exponent * self._global_shift)
         if p != 1:

cirq-core/cirq/ops/dense_pauli_string.py

+3 -3
@@ -277,7 +277,7 @@ def __mul__(self, other):
         if split is not None:
             p, i = split
             mask = np.copy(self.pauli_mask)
-            mask[i] ^= p
+            mask[i] ^= np.int64(p)
             return concrete_class(
                 pauli_mask=mask,
                 coefficient=self.coefficient * _vectorized_pauli_mul_phase(self.pauli_mask[i], p),
@@ -293,7 +293,7 @@ def __rmul__(self, other):
         if split is not None:
             p, i = split
             mask = np.copy(self.pauli_mask)
-            mask[i] ^= p
+            mask[i] ^= np.int64(p)
             return type(self)(
                 pauli_mask=mask,
                 coefficient=self.coefficient * _vectorized_pauli_mul_phase(p, self.pauli_mask[i]),
@@ -552,7 +552,7 @@ def __imul__(self, other):
         if split is not None:
             p, i = split
             self._coefficient *= _vectorized_pauli_mul_phase(self.pauli_mask[i], p)
-            self.pauli_mask[i] ^= p
+            self.pauli_mask[i] ^= np.int64(p)
             return self

         return NotImplemented

cirq-core/cirq/ops/fourier_transform.py

+2 -1
@@ -138,9 +138,10 @@ def _apply_unitary_(self, args: 'cirq.ApplyUnitaryArgs'):
             return NotImplemented

         n = int(np.prod([args.target_tensor.shape[k] for k in args.axes], dtype=np.int64))
+        dtype = args.target_tensor.flat[0].dtype
         for i in range(n):
             p = 1j ** (4 * i / n * self.exponent)
-            args.target_tensor[args.subspace_index(big_endian_bits_int=i)] *= p
+            args.target_tensor[args.subspace_index(big_endian_bits_int=i)] *= dtype.type(p)

         return args.target_tensor

cirq-core/cirq/ops/parity_gates.py

+2 -1
@@ -335,7 +335,8 @@ def _apply_unitary_(self, args: 'protocols.ApplyUnitaryArgs') -> Optional[np.nda
         if global_phase != 1:
             args.target_tensor *= global_phase

-        relative_phase = 1j ** (2 * self.exponent)
+        dtype = args.target_tensor.flat[0].dtype
+        relative_phase = dtype.type(1j ** (2 * self.exponent))
         zo = args.subspace_index(0b01)
         oz = args.subspace_index(0b10)
         args.target_tensor[oz] *= relative_phase

cirq-core/cirq/ops/pauli_string.py

+1 -1
@@ -740,7 +740,7 @@ def _expectation_from_density_matrix_no_validation(
         while any(result.shape):
             result = np.trace(result, axis1=0, axis2=len(result.shape) // 2)

-        return float(np.real(result * self.coefficient))
+        return float(np.real(result * result.dtype.type(self.coefficient)))

     def zip_items(
         self, other: 'cirq.PauliString[TKey]'

cirq-core/cirq/ops/pauli_string_test.py

+1 -1
@@ -980,7 +980,7 @@ def test_expectation_from_state_vector_invalid_input():
         rho_or_wf = 0.5 * np.ones((2, 2), dtype=np.complex64)
         _ = ps.expectation_from_state_vector(rho_or_wf, q_map)

-    wf = np.arange(16, dtype=np.complex64) / np.linalg.norm(np.arange(16))
+    wf = np.arange(16, dtype=np.complex64) / np.linalg.norm(np.arange(16, dtype=np.complex64))
     with pytest.raises(ValueError, match='shape'):
         ps.expectation_from_state_vector(wf.reshape((16, 1)), q_map_2)
     with pytest.raises(ValueError, match='shape'):

cirq-core/cirq/ops/swap_gates.py

+2 -2
@@ -244,8 +244,8 @@ def _apply_unitary_(self, args: 'protocols.ApplyUnitaryArgs') -> Optional[np.nda
         args.available_buffer[zo] = args.target_tensor[zo]
         args.target_tensor[zo] = args.target_tensor[oz]
         args.target_tensor[oz] = args.available_buffer[zo]
-        args.target_tensor[zo] *= 1j
-        args.target_tensor[oz] *= 1j
+        args.target_tensor[zo] *= args.target_tensor[zo].dtype.type(1j)
+        args.target_tensor[oz] *= args.target_tensor[oz].dtype.type(1j)
         p = 1j ** (2 * self._exponent * self._global_shift)
         if p != 1:
             args.target_tensor *= p

cirq-core/cirq/qis/measures.py

+2 -2
@@ -234,7 +234,7 @@ def _numpy_arrays_to_state_vectors_or_density_matrices(
 def _fidelity_state_vectors_or_density_matrices(state1: np.ndarray, state2: np.ndarray) -> float:
     if state1.ndim == 1 and state2.ndim == 1:
         # Both state vectors
-        return np.abs(np.vdot(state1, state2)) ** 2
+        return np.abs(np.vdot(state1, state2), dtype=float) ** 2
     elif state1.ndim == 1 and state2.ndim == 2:
         # state1 is a state vector and state2 is a density matrix
         return np.real(np.conjugate(state1) @ state2 @ state1)
@@ -245,7 +245,7 @@ def _fidelity_state_vectors_or_density_matrices(state1: np.ndarray, state2: np.n
         # Both density matrices
         state1_sqrt = _sqrt_positive_semidefinite_matrix(state1)
         eigs = linalg.eigvalsh(state1_sqrt @ state2 @ state1_sqrt)
-        trace = np.sum(np.sqrt(np.abs(eigs)))
+        trace = np.sum(np.sqrt(np.abs(eigs, dtype=float)))
         return trace**2
     raise ValueError(
         'The given arrays must be one- or two-dimensional. '
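The state-vector branch of the fidelity calculation can be sketched separately. The function name `state_vector_fidelity` and the sample states are illustrative only; the body follows the `np.abs(..., dtype=float)` form used in the hunk, which asks NumPy to compute the magnitude in double precision even for `complex64` inputs.

```python
import numpy as np


def state_vector_fidelity(state1, state2):
    """Fidelity |<state1|state2>|**2 between two pure states."""
    # np.vdot conjugates its first argument; dtype=float requests a
    # float64 magnitude before it is squared.
    return float(np.abs(np.vdot(state1, state2), dtype=float) ** 2)


# |0> and |+> stored in single precision; the divisor is cast to float32
# so the division itself does not promote the array.
zero = np.array([1, 0], dtype=np.complex64)
plus = np.array([1, 1], dtype=np.complex64) / np.float32(np.sqrt(2))

overlap = state_vector_fidelity(zero, plus)
```

The overlap of these two states should be close to 0.5, up to single-precision rounding in the stored amplitudes.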

cirq-core/cirq/qis/states.py

+4 -3
@@ -609,8 +609,8 @@ def bloch_vector_from_state_vector(
     """
     rho = density_matrix_from_state_vector(state_vector, [index], qid_shape=qid_shape)
     v = np.zeros(3, dtype=np.float32)
-    v[0] = 2 * np.real(rho[0][1])
-    v[1] = 2 * np.imag(rho[1][0])
+    v[0] = np.float32(2) * np.real(rho[0][1])
+    v[1] = np.float32(2) * np.imag(rho[1][0])
     v[2] = np.real(rho[0][0] - rho[1][1])

     return v
@@ -738,7 +738,8 @@ def dirac_notation(
     ket = "|{}⟩"
     for x in range(len(perm_list)):
         format_str = "({:." + str(decimals) + "g})"
-        val = round(state_vector[x].real, decimals) + 1j * round(state_vector[x].imag, decimals)
+        val = (round(state_vector[x].real, decimals)
+               + np.complex128(1j) * round(state_vector[x].imag, decimals))

         if round(val.real, decimals) == 0 and round(val.imag, decimals) != 0:
             val = val.imag
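The Bloch vector arithmetic from the first hunk can be checked on small density matrices. This sketch uses a hypothetical helper, `bloch_vector_from_density_matrix`, that takes a 2x2 density matrix directly instead of calling `density_matrix_from_state_vector`; the three component formulas are the ones in the hunk.

```python
import numpy as np


def bloch_vector_from_density_matrix(rho):
    """Bloch vector (x, y, z) of a single-qubit density matrix."""
    v = np.zeros(3, dtype=np.float32)
    # The factor 2 is given the same precision as the float32 output buffer.
    v[0] = np.float32(2) * np.real(rho[0][1])
    v[1] = np.float32(2) * np.imag(rho[1][0])
    v[2] = np.real(rho[0][0] - rho[1][1])
    return v


# Density matrices of |+> and of (|0> + i|1>)/sqrt(2).
rho_plus = np.full((2, 2), 0.5, dtype=np.complex64)
rho_y = np.array([[0.5, -0.5j], [0.5j, 0.5]], dtype=np.complex64)

x_axis = bloch_vector_from_density_matrix(rho_plus)
y_axis = bloch_vector_from_density_matrix(rho_y)
```

The two states should map to the unit vectors along the x and y axes of the Bloch sphere, respectively.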

cirq-core/cirq/sim/clifford/clifford_simulator_test.py

+7 -7
@@ -23,25 +23,25 @@ def test_run_no_repetitions():
     simulator = cirq.CliffordSimulator()
     circuit = cirq.Circuit(cirq.H(q0), cirq.measure(q0))
     result = simulator.run(circuit, repetitions=0)
-    assert sum(result.measurements['q(0)']) == 0
+    assert sum(result.measurements['q(0)'].astype(np.uint16)) == 0


 def test_run_hadamard():
     q0 = cirq.LineQubit(0)
     simulator = cirq.CliffordSimulator()
     circuit = cirq.Circuit(cirq.H(q0), cirq.measure(q0))
     result = simulator.run(circuit, repetitions=100)
-    assert sum(result.measurements['q(0)'])[0] < 80
-    assert sum(result.measurements['q(0)'])[0] > 20
+    assert sum(result.measurements['q(0)'].astype(np.uint16))[0] < 80
+    assert sum(result.measurements['q(0)'].astype(np.uint16))[0] > 20


 def test_run_GHZ():
     (q0, q1) = (cirq.LineQubit(0), cirq.LineQubit(1))
     simulator = cirq.CliffordSimulator()
     circuit = cirq.Circuit(cirq.H(q0), cirq.H(q1), cirq.measure(q0))
     result = simulator.run(circuit, repetitions=100)
-    assert sum(result.measurements['q(0)'])[0] < 80
-    assert sum(result.measurements['q(0)'])[0] > 20
+    assert sum(result.measurements['q(0)'].astype(np.uint16))[0] < 80
+    assert sum(result.measurements['q(0)'].astype(np.uint16))[0] > 20


 def test_run_correlations():
@@ -392,8 +392,8 @@ def test_clifford_circuit_2(qubits, split):
     circuit.append(cirq.measure(qubits[0]))
     result = cirq.CliffordSimulator(split_untangled_states=split).run(circuit, repetitions=100)

-    assert sum(result.measurements['q(0)'])[0] < 80
-    assert sum(result.measurements['q(0)'])[0] > 20
+    assert sum(result.measurements['q(0)'].astype(np.uint16))[0] < 80
+    assert sum(result.measurements['q(0)'].astype(np.uint16))[0] > 20


 @pytest.mark.parametrize('split', [True, False])

cirq-core/cirq/sim/sparse_simulator_test.py

+5 -5
@@ -115,7 +115,7 @@ def test_run_repetitions_terminal_measurement_stochastic():
     q = cirq.LineQubit(0)
     c = cirq.Circuit(cirq.H(q), cirq.measure(q, key='q'))
     results = cirq.Simulator().run(c, repetitions=10000)
-    assert 1000 <= sum(v[0] for v in results.measurements['q']) < 9000
+    assert 1000 <= sum(np.int64(v[0]) for v in results.measurements['q']) < 9000


 @pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@@ -255,7 +255,7 @@ def test_run_mixture(dtype: Type[np.complexfloating], split: bool):
     simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split)
     circuit = cirq.Circuit(cirq.bit_flip(0.5)(q0), cirq.measure(q0))
     result = simulator.run(circuit, repetitions=100)
-    assert 20 < sum(result.measurements['q(0)'])[0] < 80
+    assert 20 < sum(result.measurements['q(0)'].astype(np.uint16))[0] < 80


 @pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@@ -265,8 +265,8 @@ def test_run_mixture_with_gates(dtype: Type[np.complexfloating], split: bool):
     simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split, seed=23)
     circuit = cirq.Circuit(cirq.H(q0), cirq.phase_flip(0.5)(q0), cirq.H(q0), cirq.measure(q0))
     result = simulator.run(circuit, repetitions=100)
-    assert sum(result.measurements['q(0)'])[0] < 80
-    assert sum(result.measurements['q(0)'])[0] > 20
+    assert sum(result.measurements['q(0)'].astype(np.uint16))[0] < 80
+    assert sum(result.measurements['q(0)'].astype(np.uint16))[0] > 20


 @pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@@ -1385,7 +1385,7 @@ def test_noise_model():
     simulator = cirq.Simulator(noise=noise_model)
     result = simulator.run(circuit, repetitions=100)

-    assert 20 <= sum(result.measurements['q(0)'])[0] < 80
+    assert 20 <= sum(result.measurements['q(0)'].astype(np.uint16))[0] < 80


 def test_separated_states_str_does_not_merge():
def test_separated_states_str_does_not_merge():

cirq-core/cirq/sim/state_vector_simulation_state.py

+2 -2
@@ -230,7 +230,7 @@ def prepare_into_buffer(k: int):

             for index in range(len(kraus_tensors)):
                 prepare_into_buffer(index)
-                weight = float(np.linalg.norm(self._buffer) ** 2)
+                weight = float(np.linalg.norm(self._buffer)) ** 2

                 if weight > fallback_weight:
                     fallback_weight_index = index
@@ -248,7 +248,7 @@ def prepare_into_buffer(k: int):
             weight = fallback_weight
             index = fallback_weight_index

-        self._buffer /= np.sqrt(weight, dtype=self._buffer.dtype) if False else np.sqrt(weight)
+        self._buffer /= np.sqrt(weight, dtype=self._buffer.dtype)
         self._swap_target_tensor_for(self._buffer)
         return index
