Skip to content

Commit 5501e08

Browse files
committed
#33: relax the constraints of floor_sum
1 parent c9838af commit 5501e08

File tree

5 files changed

+61
-27
lines changed

5 files changed

+61
-27
lines changed

atcoder/internal_math.hpp

+31-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ struct barrett {
4444
// -> im * m = 2^64 + r (0 <= r < m)
4545
// let z = a*b = c*m + d (0 <= c, d < m)
4646
// a*b * im = (c*m + d) * im = c*(im*m) + d*im = c*2^64 + c*r + d*im
47-
// c*r + d*im < m * m + m * im < m * m + 2^64 + m <= 2^64 + m * (m + 1) < 2^64 * 2
47+
// c*r + d*im < m * m + m * im < m * m + 2^64 + m <= 2^64 + m * (m + 1)
48+
// < 2^64 * 2
4849
// ((ab * im) >> 64) == c or c + 1
4950
unsigned long long z = a;
5051
z *= b;
@@ -177,6 +178,35 @@ constexpr int primitive_root_constexpr(int m) {
177178
}
178179
template <int m> constexpr int primitive_root = primitive_root_constexpr(m);
179180

181+
// @param n `n < 2^32`
182+
// @param m `1 <= m < 2^32`
183+
// @return sum_{i=0}^{n-1} floor((ai + b) / m) (mod 2^64)
184+
unsigned long long floor_sum_unsigned(unsigned long long n,
185+
unsigned long long m,
186+
unsigned long long a,
187+
unsigned long long b) {
188+
unsigned long long ans = 0;
189+
while (true) {
190+
if (a >= m) {
191+
ans += n * (n - 1) / 2 * (a / m);
192+
a %= m;
193+
}
194+
if (b >= m) {
195+
ans += n * (b / m);
196+
b %= m;
197+
}
198+
199+
unsigned long long y_max = a * n + b;
200+
if (y_max < m) break;
201+
// y_max < m * (n + 1)
202+
// floor(y_max / m) <= n
203+
n = (unsigned long long)(y_max / m);
204+
b = (unsigned long long)(y_max % m);
205+
std::swap(m, a);
206+
}
207+
return ans;
208+
}
209+
180210
} // namespace internal
181211

182212
} // namespace atcoder

atcoder/math.hpp

+12-12
Original file line numberDiff line numberDiff line change
@@ -80,20 +80,20 @@ std::pair<long long, long long> crt(const std::vector<long long>& r,
8080
}
8181

8282
long long floor_sum(long long n, long long m, long long a, long long b) {
83-
long long ans = 0;
84-
if (a >= m) {
85-
ans += (n - 1) * n / 2 * (a / m);
86-
a %= m;
83+
assert(0 <= n && n < (1LL << 32));
84+
assert(1 <= m && m < (1LL << 32));
85+
unsigned long long ans = 0;
86+
if (a < 0) {
87+
unsigned long long a2 = internal::safe_mod(a, m);
88+
ans -= 1ULL * n * (n - 1) / 2 * ((a2 - a) / m);
89+
a = a2;
8790
}
88-
if (b >= m) {
89-
ans += n * (b / m);
90-
b %= m;
91+
if (b < 0) {
92+
unsigned long long b2 = internal::safe_mod(b, m);
93+
ans -= 1ULL * n * ((b2 - b) / m);
94+
b = b2;
9195
}
92-
93-
long long y_max = a * n + b;
94-
if (y_max < m) return ans;
95-
ans += floor_sum(y_max / m, a, m, y_max % m);
96-
return ans;
96+
return ans + internal::floor_sum_unsigned(n, m, a, b);
9797
}
9898

9999
} // namespace atcoder

document_en/math.md

+8-5
Original file line numberDiff line numberDiff line change
@@ -65,17 +65,20 @@ $y, z$ $(0 \leq y < z = \mathrm{lcm}(m[i]))$. It returns this $(y, z)$ as a pair
6565
ll floor_sum(ll n, ll m, ll a, ll b)
6666
```
6767

68-
It returns $\sum_{i = 0}^{n - 1} \mathrm{floor}(\frac{a \times i + b}{m})$.
68+
It returns
69+
70+
$$\sum_{i = 0}^{n - 1} \left\lfloor \frac{a \times i + b}{m} \right\rfloor$$
71+
72+
It returns the answer in $\bmod 2^{\mathrm{64}}$, if overflowed.
6973

7074
**@{keyword.constraints}**
7175

72-
- $0 \leq n \leq 10^9$
73-
- $1 \leq m \leq 10^9$
74-
- $0 \leq a, b \lt m$
76+
- $0 \leq n \leq 2^{32}$
77+
- $1 \leq m \leq 2^{32}$
7578

7679
**@{keyword.complexity}**
7780

78-
- $O(\log{(n+m+a+b)})$
81+
- $O(\log{(m+a)})$
7982

8083
## @{keyword.examples}
8184

document_ja/math.md

+6-6
Original file line numberDiff line numberDiff line change
@@ -64,19 +64,19 @@ $$x \equiv r[i] \pmod{m[i]}, \forall i \in \lbrace 0,1,\cdots, n - 1 \rbrace$$
6464
ll floor_sum(ll n, ll m, ll a, ll b)
6565
```
6666

67-
$\sum_{i = 0}^{n - 1} \mathrm{floor}(\frac{a \times i + b}{m})$
67+
$$\sum_{i = 0}^{n - 1} \left\lfloor \frac{a \times i + b}{m} \right\rfloor$$
68+
69+
を返します。答えがオーバーフローしたならば $\bmod 2^{\mathrm{64}}$ で等しい値を返します。
6870

69-
を返します。
7071

7172
**@{keyword.constraints}**
7273

73-
- $0 \leq n \leq 10^9$
74-
- $1 \leq m \leq 10^9$
75-
- $0 \leq a, b \lt m$
74+
- $0 \leq n \leq 2^{32}$
75+
- $1 \leq m \leq 2^{32}$
7676

7777
**@{keyword.complexity}**
7878

79-
- $O(\log{(n+m+a+b)})$
79+
- $O(\log{(m+a)})$
8080

8181
## @{keyword.examples}
8282

test/unittest/math_test.cpp

+4-3
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ ll pow_mod_naive(ll x, ull n, uint mod) {
2727
ll floor_sum_naive(ll n, ll m, ll a, ll b) {
2828
ll sum = 0;
2929
for (ll i = 0; i < n; i++) {
30-
sum += (a * i + b) / m;
30+
ll z = a * i + b;
31+
sum += (z - internal::safe_mod(z, m)) / m;
3132
}
3233
return sum;
3334
}
@@ -93,8 +94,8 @@ TEST(MathTest, InvModZero) {
9394
TEST(MathTest, FloorSum) {
9495
for (int n = 0; n < 20; n++) {
9596
for (int m = 1; m < 20; m++) {
96-
for (int a = 0; a < 20; a++) {
97-
for (int b = 0; b < 20; b++) {
97+
for (int a = -20; a < 20; a++) {
98+
for (int b = -20; b < 20; b++) {
9899
ASSERT_EQ(floor_sum_naive(n, m, a, b),
99100
floor_sum(n, m, a, b));
100101
}

0 commit comments

Comments
 (0)