Skip to content

Commit 4f8f06a

Browse files
committed
Add Shr to u256
Float division requires some shift operations on big integers; implement right shift here.
1 parent 4797774 commit 4f8f06a

File tree

2 files changed

+114
-1
lines changed

2 files changed

+114
-1
lines changed

src/int/big.rs

+41-1
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ macro_rules! impl_common {
9393
type Output = Self;
9494

9595
fn shl(self, rhs: u32) -> Self::Output {
96-
todo!()
96+
unimplemented!("only used to meet trait bounds")
9797
}
9898
}
9999
};
@@ -102,6 +102,46 @@ macro_rules! impl_common {
102102
impl_common!(i256);
103103
impl_common!(u256);
104104

105+
impl ops::Shr<u32> for u256 {
106+
type Output = Self;
107+
108+
fn shr(self, rhs: u32) -> Self::Output {
109+
debug_assert!(rhs < Self::BITS, "attempted to shift right with overflow");
110+
111+
if rhs >= Self::BITS {
112+
// Only happens when not in debug mode
113+
return Self::ZERO;
114+
}
115+
116+
if rhs == 0 {
117+
return self;
118+
}
119+
120+
let mut ret = self;
121+
let byte_shift = rhs / 64;
122+
let bit_shift = rhs % 64;
123+
124+
for idx in 0..4 {
125+
let base_idx = idx + byte_shift as usize;
126+
127+
let Some(base) = ret.0.get(base_idx) else {
128+
ret.0[idx] = 0;
129+
continue;
130+
};
131+
132+
let mut new_val = base >> bit_shift;
133+
134+
if let Some(new) = ret.0.get(base_idx + 1) {
135+
new_val |= new.overflowing_shl(64 - bit_shift).0;
136+
}
137+
138+
ret.0[idx] = new_val;
139+
}
140+
141+
ret
142+
}
143+
}
144+
105145
macro_rules! word {
106146
(1, $val:expr) => {
107147
(($val >> (32 * 3)) & Self::from(WORD_LO_MASK)) as u64

testcrate/tests/big.rs

+73
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,76 @@ fn widen_mul_u128() {
5959
}
6060
assert!(errors.is_empty());
6161
}
62+
63+
#[test]
64+
fn not_u128() {
65+
assert_eq!(!u256::ZERO, u256::MAX);
66+
}
67+
68+
#[test]
69+
fn shr_u128() {
70+
let only_low = [
71+
1,
72+
u16::MAX.into(),
73+
u32::MAX.into(),
74+
u64::MAX.into(),
75+
u128::MAX,
76+
];
77+
78+
let mut errors = Vec::new();
79+
80+
for a in only_low {
81+
for perturb in 0..10 {
82+
let a = a.saturating_add(perturb);
83+
for shift in 0..128 {
84+
let res = a.widen() >> shift;
85+
let expected = (a >> shift).widen();
86+
if res != expected {
87+
errors.push((a.widen(), shift, res, expected));
88+
}
89+
}
90+
}
91+
}
92+
93+
let check = [
94+
(
95+
u256::MAX,
96+
1,
97+
u256([u64::MAX, u64::MAX, u64::MAX, u64::MAX >> 1]),
98+
),
99+
(
100+
u256::MAX,
101+
5,
102+
u256([u64::MAX, u64::MAX, u64::MAX, u64::MAX >> 5]),
103+
),
104+
(u256::MAX, 63, u256([u64::MAX, u64::MAX, u64::MAX, 1])),
105+
(u256::MAX, 64, u256([u64::MAX, u64::MAX, u64::MAX, 0])),
106+
(u256::MAX, 65, u256([u64::MAX, u64::MAX, u64::MAX >> 1, 0])),
107+
(u256::MAX, 127, u256([u64::MAX, u64::MAX, 1, 0])),
108+
(u256::MAX, 128, u256([u64::MAX, u64::MAX, 0, 0])),
109+
(u256::MAX, 129, u256([u64::MAX, u64::MAX >> 1, 0, 0])),
110+
(u256::MAX, 191, u256([u64::MAX, 1, 0, 0])),
111+
(u256::MAX, 192, u256([u64::MAX, 0, 0, 0])),
112+
(u256::MAX, 193, u256([u64::MAX >> 1, 0, 0, 0])),
113+
(u256::MAX, 191, u256([u64::MAX, 1, 0, 0])),
114+
(u256::MAX, 254, u256([0b11, 0, 0, 0])),
115+
(u256::MAX, 255, u256([1, 0, 0, 0])),
116+
];
117+
118+
for (input, shift, expected) in check {
119+
let res = input >> shift;
120+
if res != expected {
121+
errors.push((input, shift, res, expected));
122+
}
123+
}
124+
125+
for (a, b, res, expected) in &errors {
126+
eprintln!(
127+
"FAILURE: {} >> {b} = {} got {}",
128+
hexu(*a),
129+
hexu(*expected),
130+
hexu(*res),
131+
);
132+
}
133+
assert!(errors.is_empty());
134+
}

0 commit comments

Comments
 (0)