diff --git a/src/modint.rs b/src/modint.rs index 1445f5b..ba72d48 100644 --- a/src/modint.rs +++ b/src/modint.rs @@ -891,8 +891,8 @@ impl_basic_traits! { macro_rules! impl_bin_ops { () => {}; - (for<$generic_param:ident : $generic_param_bound:tt> <$lhs_ty:ty> ~ <$rhs_ty:ty> -> $output:ty { { $lhs_body:expr } ~ { $rhs_body:expr } } $($rest:tt)*) => { - impl <$generic_param: $generic_param_bound> Add<$rhs_ty> for $lhs_ty { + (for<$($generic_param:ident : $generic_param_bound:tt),*> <$lhs_ty:ty> ~ <$rhs_ty:ty> -> $output:ty { { $lhs_body:expr } ~ { $rhs_body:expr } } $($rest:tt)*) => { + impl <$($generic_param: $generic_param_bound),*> Add<$rhs_ty> for $lhs_ty { type Output = $output; #[inline] @@ -901,7 +901,7 @@ macro_rules! impl_bin_ops { } } - impl <$generic_param: $generic_param_bound> Sub<$rhs_ty> for $lhs_ty { + impl <$($generic_param: $generic_param_bound),*> Sub<$rhs_ty> for $lhs_ty { type Output = $output; #[inline] @@ -910,7 +910,7 @@ macro_rules! impl_bin_ops { } } - impl <$generic_param: $generic_param_bound> Mul<$rhs_ty> for $lhs_ty { + impl <$($generic_param: $generic_param_bound),*> Mul<$rhs_ty> for $lhs_ty { type Output = $output; #[inline] @@ -919,7 +919,7 @@ macro_rules! impl_bin_ops { } } - impl <$generic_param: $generic_param_bound> Div<$rhs_ty> for $lhs_ty { + impl <$($generic_param: $generic_param_bound),*> Div<$rhs_ty> for $lhs_ty { type Output = $output; #[inline] @@ -934,29 +934,29 @@ macro_rules! impl_bin_ops { macro_rules! impl_assign_ops { () => {}; - (for<$generic_param:ident : $generic_param_bound:tt> <$lhs_ty:ty> ~= <$rhs_ty:ty> { _ ~= { $rhs_body:expr } } $($rest:tt)*) => { - impl <$generic_param: $generic_param_bound> AddAssign<$rhs_ty> for $lhs_ty { + (for<$($generic_param:ident : $generic_param_bound:tt),*> <$lhs_ty:ty> ~= <$rhs_ty:ty> { _ ~= { $rhs_body:expr } } $($rest:tt)*) => { + impl <$($generic_param: $generic_param_bound),*> AddAssign<$rhs_ty> for $lhs_ty { #[inline] fn add_assign(&mut self, rhs: $rhs_ty) { *self = *self + apply($rhs_body, rhs); } } - impl <$generic_param: $generic_param_bound> SubAssign<$rhs_ty> for $lhs_ty { + impl <$($generic_param: $generic_param_bound),*> SubAssign<$rhs_ty> for $lhs_ty { #[inline] fn sub_assign(&mut self, rhs: $rhs_ty) { *self = *self - apply($rhs_body, rhs); } } - impl <$generic_param: $generic_param_bound> MulAssign<$rhs_ty> for $lhs_ty { + impl <$($generic_param: $generic_param_bound),*> MulAssign<$rhs_ty> for $lhs_ty { #[inline] fn mul_assign(&mut self, rhs: $rhs_ty) { *self = *self * apply($rhs_body, rhs); } } - impl <$generic_param: $generic_param_bound> DivAssign<$rhs_ty> for $lhs_ty { + impl <$($generic_param: $generic_param_bound),*> DivAssign<$rhs_ty> for $lhs_ty { #[inline] fn div_assign(&mut self, rhs: $rhs_ty) { *self = *self / apply($rhs_body, rhs); @@ -981,6 +981,9 @@ impl_bin_ops! { for > ~ <&'_ DynamicModInt> -> DynamicModInt { { |x| x } ~ { |&x| x } } for <&'_ DynamicModInt> ~ > -> DynamicModInt { { |&x| x } ~ { |x| x } } for <&'_ DynamicModInt> ~ <&'_ DynamicModInt> -> DynamicModInt { { |&x| x } ~ { |&x| x } } + + for > ~ -> StaticModInt { { |x| x } ~ { StaticModInt::::new } } + for > ~ -> DynamicModInt { { |x| x } ~ { DynamicModInt::::new } } } impl_assign_ops! { @@ -988,6 +991,9 @@ impl_assign_ops! { for > ~= <&'_ StaticModInt > { _ ~= { |&x| x } } for > ~= > { _ ~= { |x| x } } for > ~= <&'_ DynamicModInt> { _ ~= { |&x| x } } + + for > ~= { _ ~= { StaticModInt::::new } } + for > ~= { _ ~= { DynamicModInt::::new } } } macro_rules! impl_folding { @@ -1108,4 +1114,29 @@ mod tests { assert_eq!(ModInt1000000007::new(-120), product(&[-1, 2, -3, 4, -5])); } + + #[test] + fn static_modint_binop_coercion() { + let f = ModInt1000000007::new; + let a = 10_293_812_usize; + let b = 9_083_240_982_usize; + assert_eq!(f(a) + f(b), f(a) + b); + assert_eq!(f(a) - f(b), f(a) - b); + assert_eq!(f(a) * f(b), f(a) * b); + assert_eq!(f(a) / f(b), f(a) / b); + } + + #[test] + fn static_modint_assign_coercion() { + let f = ModInt1000000007::new; + let a = f(10_293_812_usize); + let b = 9_083_240_982_usize; + let expected = (((a + b) * b) - b) / b; + let mut c = a; + c += b; + c *= b; + c -= b; + c /= b; + assert_eq!(expected, c); + } }