Skip to content

Commit 1420272

Browse files
committed
std::rand: full exponential & normal distributions
Complete the implementation of Exp and Normal started by Exp1 and StandardNormal by creating types implementing Sample & IndependentSample with the appropriate parameters.
1 parent 5aaef13 commit 1420272

File tree

1 file changed

+116
-20
lines changed

1 file changed

+116
-20
lines changed

src/libstd/rand/distributions.rs

Lines changed: 116 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -94,17 +94,6 @@ fn ziggurat<R:Rng>(rng: &mut R,
9494
///
9595
/// Note that this has to be unwrapped before use as an `f64` (using either
9696
/// `*` or `cast::transmute` is safe).
97-
///
98-
/// # Example
99-
///
100-
/// ```
101-
/// use std::rand::distributions::StandardNormal;
102-
///
103-
/// fn main() {
104-
/// let normal = 2.0 + (*rand::random::<StandardNormal>()) * 3.0;
105-
/// println!("{} is from a N(2, 9) distribution", normal)
106-
/// }
107-
/// ```
10897
pub struct StandardNormal(f64);
10998

11099
impl Rand for StandardNormal {
@@ -142,23 +131,52 @@ impl Rand for StandardNormal {
142131
}
143132
}
144133

145-
/// A wrapper around an `f64` to generate Exp(1) random numbers. Dividing by
146-
/// the desired rate `lambda` will give Exp(lambda) distributed random
147-
/// numbers.
148-
///
149-
/// Note that this has to be unwrapped before use as an `f64` (using either
150-
/// `*` or `cast::transmute` is safe).
134+
/// The `N(mean, std_dev**2)` distribution, i.e. samples from a normal
135+
/// distribution with mean `mean` and standard deviation `std_dev`.
151136
///
152137
/// # Example
153138
///
154139
/// ```
155-
/// use std::rand::distributions::Exp1;
140+
/// use std::rand;
141+
/// use std::rand::distributions::{Normal, IndependentSample};
156142
///
157143
/// fn main() {
158-
/// let exp2 = (*rand::random::<Exp1>()) * 0.5;
159-
/// println!("{} is from a Exp(2) distribution", exp2);
144+
/// let normal = Normal::new(2.0, 3.0);
145+
/// let v = normal.ind_sample(rand::task_rng());
146+
/// println!("{} is from a N(2, 9) distribution", v)
160147
/// }
161148
/// ```
149+
pub struct Normal {
150+
priv mean: f64,
151+
priv std_dev: f64
152+
}
153+
154+
impl Normal {
155+
/// Construct a new `Normal` distribution with the given mean and
156+
/// standard deviation. Fails if `std_dev < 0`.
157+
pub fn new(mean: f64, std_dev: f64) -> Normal {
158+
assert!(std_dev >= 0.0, "Normal::new called with `std_dev` < 0");
159+
Normal {
160+
mean: mean,
161+
std_dev: std_dev
162+
}
163+
}
164+
}
165+
impl Sample<f64> for Normal {
166+
fn sample<R: Rng>(&mut self, rng: &mut R) -> f64 { self.ind_sample(rng) }
167+
}
168+
impl IndependentSample<f64> for Normal {
169+
fn ind_sample<R: Rng>(&self, rng: &mut R) -> f64 {
170+
self.mean + self.std_dev * (*rng.gen::<StandardNormal>())
171+
}
172+
}
173+
174+
/// A wrapper around an `f64` to generate Exp(1) random numbers. Dividing by
175+
/// the desired rate `lambda` will give Exp(lambda) distributed random
176+
/// numbers.
177+
///
178+
/// Note that this has to be unwrapped before use as an `f64` (using either
179+
/// `*` or `cast::transmute` is safe).
162180
pub struct Exp1(f64);
163181

164182
// This could be done via `-rng.gen::<f64>().ln()` but that is slower.
@@ -181,10 +199,53 @@ impl Rand for Exp1 {
181199
}
182200
}
183201

202+
/// The `Exp(lambda)` distribution; i.e. samples from the exponential
203+
/// distribution with rate parameter `lambda`.
204+
///
205+
/// This distribution has density function: `f(x) = lambda *
206+
/// exp(-lambda * x)` for `x > 0`.
207+
///
208+
/// # Example
209+
///
210+
/// ```
211+
/// use std::rand;
212+
/// use std::rand::distributions::{Exp, IndependentSample};
213+
///
214+
/// fn main() {
215+
/// let exp = Exp::new(2.0);
216+
/// let v = exp.ind_sample(rand::task_rng());
217+
/// println!("{} is from a Exp(2) distribution", v);
218+
/// }
219+
/// ```
220+
pub struct Exp {
221+
/// `lambda` stored as `1/lambda`, since this is what we scale by.
222+
priv lambda_inverse: f64
223+
}
224+
225+
impl Exp {
226+
/// Construct a new `Exp` with the given shape parameter
227+
/// `lambda`. Fails if `lambda <= 0`.
228+
pub fn new(lambda: f64) -> Exp {
229+
assert!(lambda > 0.0, "Exp::new called with `lambda` <= 0");
230+
Exp { lambda_inverse: 1.0 / lambda }
231+
}
232+
}
233+
234+
impl Sample<f64> for Exp {
235+
fn sample<R: Rng>(&mut self, rng: &mut R) -> f64 { self.ind_sample(rng) }
236+
}
237+
impl IndependentSample<f64> for Exp {
238+
fn ind_sample<R: Rng>(&self, rng: &mut R) -> f64 {
239+
(*rng.gen::<Exp1>()) * self.lambda_inverse
240+
}
241+
}
242+
184243
#[cfg(test)]
185244
mod tests {
186245
use rand::*;
187246
use super::*;
247+
use iter::range;
248+
use option::{Some, None};
188249

189250
struct ConstRand(uint);
190251
impl Rand for ConstRand {
@@ -200,4 +261,39 @@ mod tests {
200261
assert_eq!(*rand_sample.sample(task_rng()), 0);
201262
assert_eq!(*rand_sample.ind_sample(task_rng()), 0);
202263
}
264+
265+
#[test]
266+
fn test_normal() {
267+
let mut norm = Normal::new(10.0, 10.0);
268+
let rng = task_rng();
269+
for _ in range(0, 1000) {
270+
norm.sample(rng);
271+
norm.ind_sample(rng);
272+
}
273+
}
274+
#[test]
275+
#[should_fail]
276+
fn test_normal_invalid_sd() {
277+
Normal::new(10.0, -1.0);
278+
}
279+
280+
#[test]
281+
fn test_exp() {
282+
let mut exp = Exp::new(10.0);
283+
let rng = task_rng();
284+
for _ in range(0, 1000) {
285+
assert!(exp.sample(rng) >= 0.0);
286+
assert!(exp.ind_sample(rng) >= 0.0);
287+
}
288+
}
289+
#[test]
290+
#[should_fail]
291+
fn test_exp_invalid_lambda_zero() {
292+
Exp::new(0.0);
293+
}
294+
#[test]
295+
#[should_fail]
296+
fn test_exp_invalid_lambda_neg() {
297+
Exp::new(-10.0);
298+
}
203299
}

0 commit comments

Comments
 (0)