diff --git a/atcoder/internal_math.hpp b/atcoder/internal_math.hpp index 2f45a70..fc71d5a 100644 --- a/atcoder/internal_math.hpp +++ b/atcoder/internal_math.hpp @@ -177,6 +177,35 @@ constexpr int primitive_root_constexpr(int m) { } template constexpr int primitive_root = primitive_root_constexpr(m); +// @param n `n < 2^32` +// @param m `1 <= m < 2^32` +// @return sum_{i=0}^{n-1} floor((ai + b) / m) (mod 2^64) +unsigned long long floor_sum_unsigned(unsigned long long n, + unsigned long long m, + unsigned long long a, + unsigned long long b) { + unsigned long long ans = 0; + while (true) { + if (a >= m) { + ans += n * (n - 1) / 2 * (a / m); + a %= m; + } + if (b >= m) { + ans += n * (b / m); + b %= m; + } + + unsigned long long y_max = a * n + b; + if (y_max < m) break; + // y_max < m * (n + 1) + // floor(y_max / m) <= n + n = (unsigned long long)(y_max / m); + b = (unsigned long long)(y_max % m); + std::swap(m, a); + } + return ans; +} + } // namespace internal } // namespace atcoder diff --git a/atcoder/math.hpp b/atcoder/math.hpp index 0b6e366..5de8736 100644 --- a/atcoder/math.hpp +++ b/atcoder/math.hpp @@ -80,20 +80,20 @@ std::pair crt(const std::vector& r, } long long floor_sum(long long n, long long m, long long a, long long b) { - long long ans = 0; - if (a >= m) { - ans += (n - 1) * n / 2 * (a / m); - a %= m; + assert(0 <= n && n < (1LL << 32)); + assert(1 <= m && m < (1LL << 32)); + unsigned long long ans = 0; + if (a < 0) { + unsigned long long a2 = internal::safe_mod(a, m); + ans -= 1ULL * n * (n - 1) / 2 * ((a2 - a) / m); + a = a2; } - if (b >= m) { - ans += n * (b / m); - b %= m; + if (b < 0) { + unsigned long long b2 = internal::safe_mod(b, m); + ans -= 1ULL * n * ((b2 - b) / m); + b = b2; } - - long long y_max = a * n + b; - if (y_max < m) return ans; - ans += floor_sum(y_max / m, a, m, y_max % m); - return ans; + return ans + internal::floor_sum_unsigned(n, m, a, b); } } // namespace atcoder diff --git a/document_en/math.md b/document_en/math.md index 8a51537..6154802 100644 --- a/document_en/math.md +++ b/document_en/math.md @@ -65,17 +65,20 @@ $y, z$ $(0 \leq y < z = \mathrm{lcm}(m[i]))$. It returns this $(y, z)$ as a pair ll floor_sum(ll n, ll m, ll a, ll b) ``` -It returns $\sum_{i = 0}^{n - 1} \mathrm{floor}(\frac{a \times i + b}{m})$. +It returns + +$$\sum_{i = 0}^{n - 1} \left\lfloor \frac{a \times i + b}{m} \right\rfloor$$ + +It returns the answer in $\bmod 2^{\mathrm{64}}$, if overflowed. **@{keyword.constraints}** -- $0 \leq n \leq 10^9$ -- $1 \leq m \leq 10^9$ -- $0 \leq a, b \lt m$ +- $0 \leq n \lt 2^{32}$ +- $1 \leq m \lt 2^{32}$ **@{keyword.complexity}** -- $O(\log{(n+m+a+b)})$ +- $O(\log{(m+a)})$ ## @{keyword.examples} diff --git a/document_ja/math.md b/document_ja/math.md index db116de..7fa4caa 100644 --- a/document_ja/math.md +++ b/document_ja/math.md @@ -64,19 +64,18 @@ $$x \equiv r[i] \pmod{m[i]}, \forall i \in \lbrace 0,1,\cdots, n - 1 \rbrace$$ ll floor_sum(ll n, ll m, ll a, ll b) ``` -$\sum_{i = 0}^{n - 1} \mathrm{floor}(\frac{a \times i + b}{m})$ +$$\sum_{i = 0}^{n - 1} \left\lfloor \frac{a \times i + b}{m} \right\rfloor$$ -を返します。 +を返します。答えがオーバーフローしたならば $\bmod 2^{\mathrm{64}}$ で等しい値を返します。 **@{keyword.constraints}** -- $0 \leq n \leq 10^9$ -- $1 \leq m \leq 10^9$ -- $0 \leq a, b \lt m$ +- $0 \leq n \lt 2^{32}$ +- $1 \leq m \lt 2^{32}$ **@{keyword.complexity}** -- $O(\log{(n+m+a+b)})$ +- $O(\log{(m+a)})$ ## @{keyword.examples} diff --git a/test/unittest/math_test.cpp b/test/unittest/math_test.cpp index c2d80a9..8cb2ee7 100644 --- a/test/unittest/math_test.cpp +++ b/test/unittest/math_test.cpp @@ -27,7 +27,8 @@ ll pow_mod_naive(ll x, ull n, uint mod) { ll floor_sum_naive(ll n, ll m, ll a, ll b) { ll sum = 0; for (ll i = 0; i < n; i++) { - sum += (a * i + b) / m; + ll z = a * i + b; + sum += (z - internal::safe_mod(z, m)) / m; } return sum; } @@ -93,8 +94,8 @@ TEST(MathTest, InvModZero) { TEST(MathTest, FloorSum) { for (int n = 0; n < 20; n++) { for (int m = 1; m < 20; m++) { - for (int a = 0; a < 20; a++) { - for (int b = 0; b < 20; b++) { + for (int a = -20; a < 20; a++) { + for (int b = -20; b < 20; b++) { ASSERT_EQ(floor_sum_naive(n, m, a, b), floor_sum(n, m, a, b)); }