Skip to content

Commit c8e5270

Browse files
authored
Merge pull request #92 from atcoder/patch/issue33
#33: relax the constraints of floor_sum
2 parents 3974354 + 77e3a0b commit c8e5270

File tree

5 files changed

+58
-26
lines changed

5 files changed

+58
-26
lines changed

atcoder/internal_math.hpp

+29
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,35 @@ constexpr int primitive_root_constexpr(int m) {
177177
}
178178
template <int m> constexpr int primitive_root = primitive_root_constexpr(m);
179179

180+
// @param n `n < 2^32`
181+
// @param m `1 <= m < 2^32`
182+
// @return sum_{i=0}^{n-1} floor((ai + b) / m) (mod 2^64)
183+
unsigned long long floor_sum_unsigned(unsigned long long n,
184+
unsigned long long m,
185+
unsigned long long a,
186+
unsigned long long b) {
187+
unsigned long long ans = 0;
188+
while (true) {
189+
if (a >= m) {
190+
ans += n * (n - 1) / 2 * (a / m);
191+
a %= m;
192+
}
193+
if (b >= m) {
194+
ans += n * (b / m);
195+
b %= m;
196+
}
197+
198+
unsigned long long y_max = a * n + b;
199+
if (y_max < m) break;
200+
// y_max < m * (n + 1)
201+
// floor(y_max / m) <= n
202+
n = (unsigned long long)(y_max / m);
203+
b = (unsigned long long)(y_max % m);
204+
std::swap(m, a);
205+
}
206+
return ans;
207+
}
208+
180209
} // namespace internal
181210

182211
} // namespace atcoder

atcoder/math.hpp

+12-12
Original file line numberDiff line numberDiff line change
@@ -80,20 +80,20 @@ std::pair<long long, long long> crt(const std::vector<long long>& r,
8080
}
8181

8282
long long floor_sum(long long n, long long m, long long a, long long b) {
83-
long long ans = 0;
84-
if (a >= m) {
85-
ans += (n - 1) * n / 2 * (a / m);
86-
a %= m;
83+
assert(0 <= n && n < (1LL << 32));
84+
assert(1 <= m && m < (1LL << 32));
85+
unsigned long long ans = 0;
86+
if (a < 0) {
87+
unsigned long long a2 = internal::safe_mod(a, m);
88+
ans -= 1ULL * n * (n - 1) / 2 * ((a2 - a) / m);
89+
a = a2;
8790
}
88-
if (b >= m) {
89-
ans += n * (b / m);
90-
b %= m;
91+
if (b < 0) {
92+
unsigned long long b2 = internal::safe_mod(b, m);
93+
ans -= 1ULL * n * ((b2 - b) / m);
94+
b = b2;
9195
}
92-
93-
long long y_max = a * n + b;
94-
if (y_max < m) return ans;
95-
ans += floor_sum(y_max / m, a, m, y_max % m);
96-
return ans;
96+
return ans + internal::floor_sum_unsigned(n, m, a, b);
9797
}
9898

9999
} // namespace atcoder

document_en/math.md

+8-5
Original file line numberDiff line numberDiff line change
@@ -65,17 +65,20 @@ $y, z$ $(0 \leq y < z = \mathrm{lcm}(m[i]))$. It returns this $(y, z)$ as a pair
6565
ll floor_sum(ll n, ll m, ll a, ll b)
6666
```
6767

68-
It returns $\sum_{i = 0}^{n - 1} \mathrm{floor}(\frac{a \times i + b}{m})$.
68+
It returns
69+
70+
$$\sum_{i = 0}^{n - 1} \left\lfloor \frac{a \times i + b}{m} \right\rfloor$$
71+
72+
It returns the answer in $\bmod 2^{\mathrm{64}}$, if overflowed.
6973

7074
**@{keyword.constraints}**
7175

72-
- $0 \leq n \leq 10^9$
73-
- $1 \leq m \leq 10^9$
74-
- $0 \leq a, b \lt m$
76+
- $0 \leq n \lt 2^{32}$
77+
- $1 \leq m \lt 2^{32}$
7578

7679
**@{keyword.complexity}**
7780

78-
- $O(\log{(n+m+a+b)})$
81+
- $O(\log{(m+a)})$
7982

8083
## @{keyword.examples}
8184

document_ja/math.md

+5-6
Original file line numberDiff line numberDiff line change
@@ -64,19 +64,18 @@ $$x \equiv r[i] \pmod{m[i]}, \forall i \in \lbrace 0,1,\cdots, n - 1 \rbrace$$
6464
ll floor_sum(ll n, ll m, ll a, ll b)
6565
```
6666

67-
$\sum_{i = 0}^{n - 1} \mathrm{floor}(\frac{a \times i + b}{m})$
67+
$$\sum_{i = 0}^{n - 1} \left\lfloor \frac{a \times i + b}{m} \right\rfloor$$
6868

69-
を返します。
69+
を返します。答えがオーバーフローしたならば $\bmod 2^{\mathrm{64}}$ で等しい値を返します。
7070

7171
**@{keyword.constraints}**
7272

73-
- $0 \leq n \leq 10^9$
74-
- $1 \leq m \leq 10^9$
75-
- $0 \leq a, b \lt m$
73+
- $0 \leq n \lt 2^{32}$
74+
- $1 \leq m \lt 2^{32}$
7675

7776
**@{keyword.complexity}**
7877

79-
- $O(\log{(n+m+a+b)})$
78+
- $O(\log{(m+a)})$
8079

8180
## @{keyword.examples}
8281

test/unittest/math_test.cpp

+4-3
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ ll pow_mod_naive(ll x, ull n, uint mod) {
2727
ll floor_sum_naive(ll n, ll m, ll a, ll b) {
2828
ll sum = 0;
2929
for (ll i = 0; i < n; i++) {
30-
sum += (a * i + b) / m;
30+
ll z = a * i + b;
31+
sum += (z - internal::safe_mod(z, m)) / m;
3132
}
3233
return sum;
3334
}
@@ -93,8 +94,8 @@ TEST(MathTest, InvModZero) {
9394
TEST(MathTest, FloorSum) {
9495
for (int n = 0; n < 20; n++) {
9596
for (int m = 1; m < 20; m++) {
96-
for (int a = 0; a < 20; a++) {
97-
for (int b = 0; b < 20; b++) {
97+
for (int a = -20; a < 20; a++) {
98+
for (int b = -20; b < 20; b++) {
9899
ASSERT_EQ(floor_sum_naive(n, m, a, b),
99100
floor_sum(n, m, a, b));
100101
}

0 commit comments

Comments
 (0)