Skip to content

Commit 38624c9

Browse files
committed
fix: relax the constraints of floor_sum
1 parent d2b35ac commit 38624c9

File tree

2 files changed

+64
-15
lines changed

2 files changed

+64
-15
lines changed

src/internal_math.rs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,41 @@ pub(crate) fn primitive_root(m: i32) -> i32 {
235235
// omitted
236236
// template <int m> constexpr int primitive_root = primitive_root_constexpr(m);
237237

238+
/// # Arguments
239+
/// * `n` `n < 2^32`
240+
/// * `m` `1 <= m < 2^32`
241+
///
242+
/// # Returns
243+
/// `sum_{i=0}^{n-1} floor((ai + b) / m) (mod 2^64)`
244+
/* const */
245+
#[allow(clippy::many_single_char_names)]
246+
pub(crate) fn floor_sum_unsigned(mut n: u64, mut m: u64, mut a: u64, mut b: u64) -> u64 {
247+
let mut ans = 0;
248+
loop {
249+
if a >= m {
250+
if n > 0 {
251+
ans += n * (n - 1) / 2 * (a / m);
252+
}
253+
a %= m;
254+
}
255+
if b >= m {
256+
ans += n * (b / m);
257+
b %= m;
258+
}
259+
260+
let y_max = a * n + b;
261+
if y_max < m {
262+
break;
263+
}
264+
// y_max < m * (n + 1)
265+
// floor(y_max / m) <= n
266+
n = y_max / m;
267+
b = y_max % m;
268+
std::mem::swap(&mut m, &mut a);
269+
}
270+
return ans;
271+
}
272+
238273
#[cfg(test)]
239274
mod tests {
240275
#![allow(clippy::unreadable_literal)]

src/math.rs

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -186,24 +186,20 @@ pub fn crt(r: &[i64], m: &[i64]) -> (i64, i64) {
186186
/// assert_eq!(math::floor_sum(6, 5, 4, 3), 13);
187187
/// ```
188188
pub fn floor_sum(n: i64, m: i64, mut a: i64, mut b: i64) -> i64 {
189+
assert!(0 <= n && n < 1i64 << 32);
190+
assert!(1 <= m && m < 1i64 << 32);
189191
let mut ans = 0;
190-
if a >= m {
191-
ans += (n - 1) * n * (a / m) / 2;
192-
a %= m;
192+
if a < 0 {
193+
let a2 = internal_math::safe_mod(a, m);
194+
ans -= n * (n - 1) / 2 * ((a2 - a) / m);
195+
a = a2;
193196
}
194-
if b >= m {
195-
ans += n * (b / m);
196-
b %= m;
197+
if b < 0 {
198+
let b2 = internal_math::safe_mod(b, m);
199+
ans -= n * ((b2 - b) / m);
200+
b = b2;
197201
}
198-
199-
let y_max = (a * n + b) / m;
200-
let x_max = y_max * m - b;
201-
if y_max == 0 {
202-
return ans;
203-
}
204-
ans += (n - (x_max + a - 1) / a) * y_max;
205-
ans += floor_sum(y_max, a, m, (a - x_max % a) % a);
206-
ans
202+
ans + internal_math::floor_sum_unsigned(n as u64, m as u64, a as u64, b as u64) as i64
207203
}
208204

209205
#[cfg(test)]
@@ -306,5 +302,23 @@ mod tests {
306302
499_999_999_500_000_000
307303
);
308304
assert_eq!(floor_sum(332955, 5590132, 2231, 999423), 22014575);
305+
for n in 0..20 {
306+
for m in 1..20 {
307+
for a in -20..20 {
308+
for b in -20..20 {
309+
assert_eq!(floor_sum(n, m, a, b), floor_sum_naive(n, m, a, b));
310+
}
311+
}
312+
}
313+
}
314+
}
315+
316+
fn floor_sum_naive(n: i64, m: i64, a: i64, b: i64) -> i64 {
317+
let mut ans = 0;
318+
for i in 0..n {
319+
let z = a * i + b;
320+
ans += (z - internal_math::safe_mod(z, m)) / m;
321+
}
322+
ans
309323
}
310324
}

0 commit comments

Comments
 (0)