Skip to content

Commit 6188948

Browse files
committed
Factor out half/full iterations
1 parent c4219eb commit 6188948

File tree

1 file changed

+88
-58
lines changed

1 file changed

+88
-58
lines changed

src/float/div.rs

+88-58
Original file line numberDiff line numberDiff line change
@@ -88,66 +88,47 @@ trait FloatDivision: Float
8888
where
8989
Self::Int: DInt,
9090
{
91-
/// Iterations that are done at half of the float's width, done for optimization.
92-
const HALF_ITERATIONS: usize;
91+
// /// Iterations that are done at half of the float's width, done for optimization.
92+
// const HALF_ITERATIONS: usize;
9393

94-
/// Iterations that are done at the full float's width. Must be at least one.
95-
const FULL_ITERATIONS: usize = 1;
94+
// /// Iterations that are done at the full float's width. Must be at least one.
95+
// const FULL_ITERATIONS: usize = 1;
9696

97-
const USE_NATIVE_FULL_ITERATIONS: bool = size_of::<Self>() < size_of::<*const ()>();
97+
// const USE_NATIVE_FULL_ITERATIONS: bool = size_of::<Self>() < size_of::<*const ()>();
9898

9999
/// C is (3/4 + 1/sqrt(2)) - 1 truncated to W0 fractional bits as UQ0.HW
100100
/// with W0 being either 16 or 32 and W0 <= HW.
101101
/// That is, C is the aforementioned 3/4 + 1/sqrt(2) constant (from which
102102
/// b/2 is subtracted to obtain x0) wrapped to [0, 1) range.
103103
const C_HW: HalfRep<Self>;
104104

105-
/// u_n for different precisions (with N-1 half-width iterations):
106-
/// W0 is the precision of C
107-
/// u_0 = (3/4 - 1/sqrt(2) + 2^-W0) * 2^HW
108-
///
109-
/// Estimated with bc:
110-
/// define half1(un) { return 2.0 * (un + un^2) / 2.0^hw + 1.0; }
111-
/// define half2(un) { return 2.0 * un / 2.0^hw + 2.0; }
112-
/// define full1(un) { return 4.0 * (un + 3.01) / 2.0^hw + 2.0 * (un + 3.01)^2 + 4.0; }
113-
/// define full2(un) { return 4.0 * (un + 3.01) / 2.0^hw + 8.0; }
114-
///
115-
/// | f32 (0 + 3) | f32 (2 + 1) | f64 (3 + 1) | f128 (4 + 1)
116-
/// u_0 | < 184224974 | < 2812.1 | < 184224974 | < 791240234244348797
117-
/// u_1 | < 15804007 | < 242.7 | < 15804007 | < 67877681371350440
118-
/// u_2 | < 116308 | < 2.81 | < 116308 | < 499533100252317
119-
/// u_3 | < 7.31 | | < 7.31 | < 27054456580
120-
/// u_4 | | | | < 80.4
121-
/// Final (U_N) | same as u_3 | < 72 | < 218 | < 13920
122-
///
123-
/// Add 2 to U_N due to final decrement.
124-
const RECIPROCAL_PRECISION: u16 = {
125-
// Do some related configuration validation
126-
if !Self::USE_NATIVE_FULL_ITERATIONS {
127-
if Self::FULL_ITERATIONS != 1 {
128-
panic!("Only a single emulated full iteration is supported");
129-
}
130-
if !(Self::HALF_ITERATIONS > 0) {
131-
panic!("Invalid number of half iterations");
132-
}
133-
}
134-
135-
if Self::FULL_ITERATIONS < 1 {
136-
panic!("Must have at least one full iteration");
137-
}
138-
139-
if Self::BITS == 32 && Self::HALF_ITERATIONS == 2 && Self::FULL_ITERATIONS == 1 {
140-
74u16
141-
} else if Self::BITS == 32 && Self::HALF_ITERATIONS == 0 && Self::FULL_ITERATIONS == 3 {
142-
10
143-
} else if Self::BITS == 64 && Self::HALF_ITERATIONS == 3 && Self::FULL_ITERATIONS == 1 {
144-
220
145-
} else if Self::BITS == 128 && Self::HALF_ITERATIONS == 4 && Self::FULL_ITERATIONS == 1 {
146-
13922
147-
} else {
148-
panic!("Invalid number of iterations")
149-
}
150-
};
105+
// const RECIPROCAL_PRECISION: u16 = {
106+
// // Do some related configuration validation
107+
// if !Self::USE_NATIVE_FULL_ITERATIONS {
108+
// if Self::FULL_ITERATIONS != 1 {
109+
// panic!("Only a single emulated full iteration is supported");
110+
// }
111+
// if !(Self::HALF_ITERATIONS > 0) {
112+
// panic!("Invalid number of half iterations");
113+
// }
114+
// }
115+
116+
// if Self::FULL_ITERATIONS < 1 {
117+
// panic!("Must have at least one full iteration");
118+
// }
119+
120+
// if Self::BITS == 32 && Self::HALF_ITERATIONS == 2 && Self::FULL_ITERATIONS == 1 {
121+
// 74u16
122+
// } else if Self::BITS == 32 && Self::HALF_ITERATIONS == 0 && Self::FULL_ITERATIONS == 3 {
123+
// 10
124+
// } else if Self::BITS == 64 && Self::HALF_ITERATIONS == 3 && Self::FULL_ITERATIONS == 1 {
125+
// 220
126+
// } else if Self::BITS == 128 && Self::HALF_ITERATIONS == 4 && Self::FULL_ITERATIONS == 1 {
127+
// 13922
128+
// } else {
129+
// panic!("Invalid number of iterations")
130+
// }
131+
// };
151132
}
152133

153134
/// Calculate the number of iterations required to reach the precision needed for a float type.
@@ -171,9 +152,57 @@ const fn calc_iterations<F: Float>() -> (usize, usize) {
171152
}
172153
}
173154

155+
/// u_n for different precisions (with N-1 half-width iterations):
156+
/// W0 is the precision of C
157+
/// u_0 = (3/4 - 1/sqrt(2) + 2^-W0) * 2^HW
158+
///
159+
/// Estimated with bc:
160+
/// define half1(un) { return 2.0 * (un + un^2) / 2.0^hw + 1.0; }
161+
/// define half2(un) { return 2.0 * un / 2.0^hw + 2.0; }
162+
/// define full1(un) { return 4.0 * (un + 3.01) / 2.0^hw + 2.0 * (un + 3.01)^2 + 4.0; }
163+
/// define full2(un) { return 4.0 * (un + 3.01) / 2.0^hw + 8.0; }
164+
///
165+
/// | f32 (0 + 3) | f32 (2 + 1) | f64 (3 + 1) | f128 (4 + 1)
166+
/// u_0 | < 184224974 | < 2812.1 | < 184224974 | < 791240234244348797
167+
/// u_1 | < 15804007 | < 242.7 | < 15804007 | < 67877681371350440
168+
/// u_2 | < 116308 | < 2.81 | < 116308 | < 499533100252317
169+
/// u_3 | < 7.31 | | < 7.31 | < 27054456580
170+
/// u_4 | | | | < 80.4
171+
/// Final (U_N) | same as u_3 | < 72 | < 218 | < 13920
172+
///
173+
/// Add 2 to U_N due to final decrement.
174+
const fn reciprocal_precision<F: Float>() -> u16 {
175+
let (half_iterations, full_iterations) = calc_iterations::<F>();
176+
177+
// if !Self::USE_NATIVE_FULL_ITERATIONS {
178+
// if Self::FULL_ITERATIONS != 1 {
179+
// panic!("Only a single emulated full iteration is supported");
180+
// }
181+
// if !(Self::HALF_ITERATIONS > 0) {
182+
// panic!("Invalid number of half iterations");
183+
// }
184+
// }
185+
186+
if full_iterations < 1 {
187+
panic!("Must have at least one full iteration");
188+
}
189+
190+
if F::BITS == 32 && half_iterations == 2 && full_iterations == 1 {
191+
74u16
192+
} else if F::BITS == 32 && half_iterations == 0 && full_iterations == 3 {
193+
10
194+
} else if F::BITS == 64 && half_iterations == 3 && full_iterations == 1 {
195+
220
196+
} else if F::BITS == 128 && half_iterations == 4 && full_iterations == 1 {
197+
13922
198+
} else {
199+
panic!("Invalid number of iterations")
200+
}
201+
}
202+
174203
impl FloatDivision for f32 {
175-
const HALF_ITERATIONS: usize = 0;
176-
const FULL_ITERATIONS: usize = 3;
204+
// const HALF_ITERATIONS: usize = 0;
205+
// const FULL_ITERATIONS: usize = 3;
177206

178207
/// Use 16-bit initial estimation in case we are using half-width iterations
179208
/// for float32 division. This is expected to be useful for some 16-bit
@@ -185,15 +214,15 @@ impl FloatDivision for f32 {
185214
}
186215

187216
impl FloatDivision for f64 {
188-
const HALF_ITERATIONS: usize = 3;
217+
// const HALF_ITERATIONS: usize = 3;
189218

190219
/// HW is at least 32. Shifting into the highest bits if needed.
191220
const C_HW: HalfRep<Self> = 0x7504F333 << (HalfRep::<Self>::BITS - 32);
192221
}
193222

194223
#[cfg(not(feature = "no-f16-f128"))]
195224
impl FloatDivision for f128 {
196-
const HALF_ITERATIONS: usize = 4;
225+
// const HALF_ITERATIONS: usize = 4;
197226

198227
// const C_HW: HalfRep<Self> = 0x7504F333 << (HalfRep::<Self>::BITS - 32);
199228
const C_HW: HalfRep<Self> = 0x7504f333f9de6108;
@@ -258,6 +287,7 @@ where
258287
let quiet_bit = implicit_bit >> 1;
259288
let qnan_rep = exponent_mask | quiet_bit;
260289
let (half_iterations, full_iterations) = calc_iterations::<F>();
290+
let recip_precision = reciprocal_precision::<F>();
261291

262292
let a_rep = a.repr();
263293
let b_rep = b.repr();
@@ -546,7 +576,7 @@ where
546576
if full_iterations > 1 {
547577
// Need to use concrete types since `F::Int::D` might not support math. So, restrict to
548578
// one type.
549-
assert!(F::BITS == 32, "native full iterations only supports f32");
579+
// assert!(F::BITS == 32, "native full iterations only supports f32");
550580

551581
for _ in 0..full_iterations {
552582
let corr_uq1: F::Int = zero.wrapping_sub(x_uq0.widen_mul(b_uq1).hi());
@@ -558,7 +588,7 @@ where
558588
x_uq0 = x_uq0.wrapping_sub(2.cast());
559589

560590
// Suppose 1/b - P * 2^-W < x < 1/b + P * 2^-W
561-
x_uq0 -= F::RECIPROCAL_PRECISION.cast();
591+
x_uq0 -= recip_precision.cast();
562592

563593
// Now 1/b - (2*P) * 2^-W < x < 1/b
564594
// FIXME Is x_UQ0 still >= 0.5?
@@ -640,7 +670,7 @@ where
640670
// conditionally turns the below LT comparison into LTE
641671
abs_result += u8::from(residual_lo > b_significand).into();
642672

643-
if F::BITS == 128 || (F::BITS == 32 && F::HALF_ITERATIONS > 0) {
673+
if F::BITS == 128 || (F::BITS == 32 && half_iterations > 0) {
644674
// Do not round Infinity to NaN
645675
abs_result +=
646676
u8::from(abs_result < inf_rep && residual_lo > (2 + 1).cast() * b_significand).into();

0 commit comments

Comments
 (0)