Commit f74a8a11 authored by Brice Videau's avatar Brice Videau

Added strided sampling for distributions.

parent 888f6930
......@@ -136,6 +136,16 @@ ccs_distribution_samples(ccs_distribution_t distribution,
size_t num_values,
ccs_numeric_t *values);
// Stride between elements given in number of ccs_numeric_t.
// stride equal to the the distribution dimension is
// equivalent to ccs_distribution_samples
extern ccs_result_t
ccs_distribution_strided_samples(ccs_distribution_t distribution,
ccs_rng_t rng,
size_t num_values,
size_t stride,
ccs_numeric_t *values);
#ifdef __cplusplus
}
#endif
......
......@@ -113,6 +113,20 @@ ccs_distribution_samples(ccs_distribution_t distribution,
return ops->samples(distribution->data, rng, num_values, values);
}
ccs_result_t
ccs_distribution_strided_samples(ccs_distribution_t distribution,
ccs_rng_t rng,
size_t num_values,
size_t stride,
ccs_numeric_t *values) {
CCS_CHECK_OBJ(distribution, CCS_DISTRIBUTION);
if (!num_values)
return CCS_SUCCESS;
CCS_CHECK_ARY(num_values, values);
_ccs_distribution_ops_t *ops = ccs_distribution_get_ops(distribution);
return ops->strided_samples(distribution->data, rng, num_values, stride, values);
}
ccs_result_t
ccs_create_normal_float_distribution(ccs_float_t mu,
ccs_float_t sigma,
......
......@@ -17,6 +17,13 @@ struct _ccs_distribution_ops_s {
_ccs_distribution_data_t *distribution,
ccs_interval_t *interval_ret);
ccs_result_t (*strided_samples)(
_ccs_distribution_data_t *distribution,
ccs_rng_t rng,
size_t num_values,
size_t stride,
ccs_numeric_t *values);
};
typedef struct _ccs_distribution_ops_s _ccs_distribution_ops_t;
......
......@@ -28,10 +28,18 @@ _ccs_distribution_normal_samples(_ccs_distribution_data_t *data,
size_t num_values,
ccs_numeric_t *values);
static ccs_result_t
_ccs_distribution_normal_strided_samples(_ccs_distribution_data_t *data,
ccs_rng_t rng,
size_t num_values,
size_t stride,
ccs_numeric_t *values);
static _ccs_distribution_ops_t _ccs_distribution_normal_ops = {
{ &_ccs_distribution_del },
&_ccs_distribution_normal_samples,
&_ccs_distribution_normal_get_bounds
&_ccs_distribution_normal_get_bounds,
&_ccs_distribution_normal_strided_samples
};
static ccs_result_t
......@@ -207,6 +215,121 @@ _ccs_distribution_normal_samples(_ccs_distribution_data_t *data,
num_values, values);
}
static inline ccs_result_t
_ccs_distribution_normal_strided_samples_float(gsl_rng *grng,
const ccs_scale_type_t scale_type,
const ccs_float_t quantization,
const ccs_float_t mu,
const ccs_float_t sigma,
const int quantize,
size_t num_values,
size_t stride,
ccs_float_t *values) {
size_t i;
if (scale_type == CCS_LOGARITHMIC && quantize) {
ccs_float_t lq = log(quantization*0.5);
if (mu - lq >= 0.0)
//at least 50% chance to get a valid value
for (i = 0; i < num_values; i++)
do {
values[i*stride] = gsl_ran_gaussian(grng, sigma) + mu;
} while (values[i*stride] < lq);
else
//use tail distribution
for (i = 0; i < num_values; i++)
values[i*stride] = gsl_ran_gaussian_tail(grng, lq - mu, sigma) + mu;
} else
for (i = 0; i < num_values; i++)
values[i*stride] = gsl_ran_gaussian(grng, sigma) + mu;
if (scale_type == CCS_LOGARITHMIC)
for (i = 0; i < num_values; i++)
values[i*stride] = exp(values[i*stride]);
if (quantize) {
ccs_float_t rquantization = 1.0 / quantization;
for (i = 0; i < num_values; i++)
values[i*stride] = round(values[i*stride] * rquantization) * quantization;
}
return CCS_SUCCESS;
}
static inline ccs_result_t
_ccs_distribution_normal_strided_samples_int(gsl_rng *grng,
const ccs_scale_type_t scale_type,
const ccs_int_t quantization,
const ccs_float_t mu,
const ccs_float_t sigma,
const int quantize,
size_t num_values,
size_t stride,
ccs_numeric_t *values) {
size_t i;
ccs_float_t q;
if (quantize)
q = quantization*0.5;
else
q = 0.5;
if (scale_type == CCS_LOGARITHMIC) {
ccs_float_t lq = log(q);
if (mu - lq >= 0.0)
for (i = 0; i < num_values; i++)
do {
do {
values[i*stride].f = gsl_ran_gaussian(grng, sigma) + mu;
} while (values[i*stride].f < lq);
values[i*stride].f = exp(values[i*stride].f);
} while (unlikely(values[i*stride].f - q > CCS_INT_MAX));
else
for (i = 0; i < num_values; i++)
do {
values[i*stride].f = gsl_ran_gaussian_tail(grng, lq - mu, sigma) + mu;
values[i*stride].f = exp(values[i*stride].f);
} while (unlikely(values[i*stride].f - q > CCS_INT_MAX));
}
else
for (i = 0; i < num_values; i++)
do {
values[i*stride].f = gsl_ran_gaussian(grng, sigma) + mu;
} while (unlikely(values[i*stride].f - q > CCS_INT_MAX || values[i*stride].f + q < CCS_INT_MIN));
if (quantize) {
ccs_float_t rquantization = 1.0 / quantization;
for (i = 0; i < num_values; i++)
values[i*stride].i = (ccs_int_t)round(values[i*stride].f * rquantization) * quantization;
} else
for (i = 0; i < num_values; i++)
values[i*stride].i = round(values[i*stride].f);
return CCS_SUCCESS;
}
static ccs_result_t
_ccs_distribution_normal_strided_samples(_ccs_distribution_data_t *data,
ccs_rng_t rng,
size_t num_values,
size_t stride,
ccs_numeric_t *values) {
_ccs_distribution_normal_data_t *d = (_ccs_distribution_normal_data_t *)data;
const ccs_numeric_type_t data_type = d->common_data.data_type;
const ccs_scale_type_t scale_type = d->common_data.scale_type;
const ccs_numeric_t quantization = d->common_data.quantization;
const ccs_float_t mu = d->mu;
const ccs_float_t sigma = d->sigma;
const int quantize = d->quantize;
gsl_rng *grng;
ccs_result_t err = ccs_rng_get_gsl_rng(rng, &grng);
if (err)
return err;
if (data_type == CCS_NUM_FLOAT)
return _ccs_distribution_normal_strided_samples_float(grng, scale_type,
quantization.f, mu,
sigma, quantize,
num_values, stride,
(ccs_float_t*) values);
else
return _ccs_distribution_normal_strided_samples_int(grng, scale_type,
quantization.i, mu,
sigma, quantize,
num_values, stride, values);
}
extern ccs_result_t
ccs_create_normal_distribution(ccs_numeric_type_t data_type,
ccs_float_t mu,
......
......@@ -27,10 +27,18 @@ _ccs_distribution_roulette_samples(_ccs_distribution_data_t *data,
size_t num_values,
ccs_numeric_t *values);
static ccs_result_t
_ccs_distribution_roulette_strided_samples(_ccs_distribution_data_t *data,
ccs_rng_t rng,
size_t num_values,
size_t stride,
ccs_numeric_t *values);
static _ccs_distribution_ops_t _ccs_distribution_roulette_ops = {
{ &_ccs_distribution_del },
&_ccs_distribution_roulette_samples,
&_ccs_distribution_roulette_get_bounds
&_ccs_distribution_roulette_get_bounds,
&_ccs_distribution_roulette_strided_samples
};
static ccs_result_t
......@@ -79,6 +87,40 @@ _ccs_distribution_roulette_samples(_ccs_distribution_data_t *data,
return CCS_SUCCESS;
}
static ccs_result_t
_ccs_distribution_roulette_strided_samples(_ccs_distribution_data_t *data,
ccs_rng_t rng,
size_t num_values,
size_t stride,
ccs_numeric_t *values) {
_ccs_distribution_roulette_data_t *d = (_ccs_distribution_roulette_data_t *)data;
gsl_rng *grng;
ccs_result_t err = ccs_rng_get_gsl_rng(rng, &grng);
if (err)
return err;
for (size_t i = 0; i < num_values; i++) {
ccs_float_t rnd = gsl_rng_uniform(grng);
ccs_int_t upper = d->num_areas - 1;
ccs_int_t lower = 0;
ccs_int_t index = upper * rnd;
int found = 0;
while( !found ) {
if ( rnd < d->areas[index] ) {
upper = index - 1;
index = (lower+upper)/2;
} else if ( rnd >= d->areas[index+1] ) {
lower = index + 1;
index = (lower+upper)/2;
} else
found = 1;
}
values[i*stride].i = index;
}
return CCS_SUCCESS;
}
ccs_result_t
ccs_create_roulette_distribution(size_t num_areas,
ccs_float_t *areas,
......
......@@ -31,10 +31,18 @@ _ccs_distribution_uniform_samples(_ccs_distribution_data_t *data,
size_t num_values,
ccs_numeric_t *values);
static ccs_result_t
_ccs_distribution_uniform_strided_samples(_ccs_distribution_data_t *data,
ccs_rng_t rng,
size_t num_values,
size_t stride,
ccs_numeric_t *values);
static _ccs_distribution_ops_t _ccs_distribution_uniform_ops = {
{ &_ccs_distribution_del },
&_ccs_distribution_uniform_samples,
&_ccs_distribution_uniform_get_bounds
&_ccs_distribution_uniform_get_bounds,
&_ccs_distribution_uniform_strided_samples
};
static ccs_result_t
......@@ -58,6 +66,67 @@ _ccs_distribution_uniform_get_bounds(_ccs_distribution_data_t *data,
return CCS_SUCCESS;
}
static ccs_result_t
_ccs_distribution_uniform_strided_samples(_ccs_distribution_data_t *data,
ccs_rng_t rng,
size_t num_values,
size_t stride,
ccs_numeric_t *values) {
_ccs_distribution_uniform_data_t *d = (_ccs_distribution_uniform_data_t *)data;
size_t i;
const ccs_numeric_type_t data_type = d->common_data.data_type;
const ccs_scale_type_t scale_type = d->common_data.scale_type;
const ccs_numeric_t quantization = d->common_data.quantization;
const ccs_numeric_t lower = d->lower;
const ccs_numeric_t internal_lower = d->internal_lower;
const ccs_numeric_t internal_upper = d->internal_upper;
const int quantize = d->quantize;
gsl_rng *grng;
ccs_result_t err = ccs_rng_get_gsl_rng(rng, &grng);
if (err)
return err;
if (data_type == CCS_NUM_FLOAT) {
for (i = 0; i < num_values; i++) {
values[i*stride].f = gsl_ran_flat(grng, internal_lower.f, internal_upper.f);
}
if (scale_type == CCS_LOGARITHMIC) {
for (i = 0; i < num_values; i++)
values[i*stride].f = exp(values[i*stride].f);
if (quantize)
for (i = 0; i < num_values; i++)
values[i*stride].f = floor((values[i*stride].f - lower.f)/quantization.f) * quantization.f + lower.f;
} else
if (quantize)
for (i = 0; i < num_values; i++)
values[i*stride].f = floor(values[i*stride].f) * quantization.f + lower.f;
else
for (i = 0; i < num_values; i++)
values[i*stride].f += lower.f;
} else {
if (scale_type == CCS_LOGARITHMIC) {
for (i = 0; i < num_values; i++) {
values[i*stride].i = floor(exp(gsl_ran_flat(grng, internal_lower.f, internal_upper.f)));
}
if (quantize)
for (i = 0; i < num_values; i++)
values[i*stride].i = ((values[i*stride].i - lower.i)/quantization.i) * quantization.i + lower.i;
} else {
for (i = 0; i < num_values; i++) {
values[i*stride].i = gsl_rng_uniform_int(grng, internal_upper.i);
}
if (quantize)
for (i = 0; i < num_values; i++)
values[i*stride].i = values[i*stride].i * quantization.i + lower.i;
else
for (i = 0; i < num_values; i++)
values[i*stride].i += lower.i;
}
}
return CCS_SUCCESS;
}
static ccs_result_t
_ccs_distribution_uniform_samples(_ccs_distribution_data_t *data,
ccs_rng_t rng,
......
......@@ -47,7 +47,7 @@ static void test_create_normal_distribution() {
assert( err == CCS_SUCCESS );
assert( quantization.f == 0.0 );
err = ccs_distribution_get_bounds(distrib, &interval);
err = ccs_distribution_get_bounds(distrib, &interval);
assert( err == CCS_SUCCESS );
assert( interval.type == CCS_NUM_FLOAT );
assert( interval.lower.f == -CCS_INFINITY );
......@@ -149,7 +149,7 @@ static void test_normal_distribution_int() {
&distrib);
assert( err == CCS_SUCCESS );
err = ccs_distribution_get_bounds(distrib, &interval);
err = ccs_distribution_get_bounds(distrib, &interval);
assert( err == CCS_SUCCESS );
assert( interval.type == CCS_NUM_INTEGER );
assert( interval.lower.i == CCS_INT_MIN );
......@@ -196,7 +196,7 @@ static void test_normal_distribution_float() {
&distrib);
assert( err == CCS_SUCCESS );
err = ccs_distribution_get_bounds(distrib, &interval);
err = ccs_distribution_get_bounds(distrib, &interval);
assert( err == CCS_SUCCESS );
assert( interval.type == CCS_NUM_FLOAT );
assert( interval.lower.f == -CCS_INFINITY );
......@@ -243,7 +243,7 @@ static void test_normal_distribution_int_log() {
&distrib);
assert( err == CCS_SUCCESS );
err = ccs_distribution_get_bounds(distrib, &interval);
err = ccs_distribution_get_bounds(distrib, &interval);
assert( err == CCS_SUCCESS );
assert( interval.type == CCS_NUM_INTEGER );
assert( interval.lower.i == 1 );
......@@ -298,7 +298,7 @@ static void test_normal_distribution_float_log() {
&distrib);
assert( err == CCS_SUCCESS );
err = ccs_distribution_get_bounds(distrib, &interval);
err = ccs_distribution_get_bounds(distrib, &interval);
assert( err == CCS_SUCCESS );
assert( interval.type == CCS_NUM_FLOAT );
assert( interval.lower.f == 0.0 );
......@@ -346,7 +346,7 @@ static void test_normal_distribution_int_quantize() {
&distrib);
assert( err == CCS_SUCCESS );
err = ccs_distribution_get_bounds(distrib, &interval);
err = ccs_distribution_get_bounds(distrib, &interval);
assert( err == CCS_SUCCESS );
assert( interval.type == CCS_NUM_INTEGER );
assert( interval.lower.i == (CCS_INT_MIN/q)*q );
......@@ -393,7 +393,7 @@ static void test_normal_distribution_float_quantize() {
&distrib);
assert( err == CCS_SUCCESS );
err = ccs_distribution_get_bounds(distrib, &interval);
err = ccs_distribution_get_bounds(distrib, &interval);
assert( err == CCS_SUCCESS );
assert( interval.type == CCS_NUM_FLOAT );
assert( interval.lower.f == -CCS_INFINITY );
......@@ -441,7 +441,7 @@ static void test_normal_distribution_int_log_quantize() {
&distrib);
assert( err == CCS_SUCCESS );
err = ccs_distribution_get_bounds(distrib, &interval);
err = ccs_distribution_get_bounds(distrib, &interval);
assert( err == CCS_SUCCESS );
assert( interval.type == CCS_NUM_INTEGER );
assert( interval.lower.i == quantize );
......@@ -498,7 +498,7 @@ static void test_normal_distribution_float_log_quantize() {
&distrib);
assert( err == CCS_SUCCESS );
err = ccs_distribution_get_bounds(distrib, &interval);
err = ccs_distribution_get_bounds(distrib, &interval);
assert( err == CCS_SUCCESS );
assert( interval.type == CCS_NUM_FLOAT );
assert( interval.lower.f == quantization );
......@@ -530,6 +530,66 @@ static void test_normal_distribution_float_log_quantize() {
assert( err == CCS_SUCCESS );
}
static void test_normal_distribution_strided_samples() {
ccs_distribution_t distrib1 = NULL;
ccs_distribution_t distrib2 = NULL;
ccs_rng_t rng = NULL;
ccs_result_t err = CCS_SUCCESS;
const size_t num_samples = 10000;
const ccs_float_t mu1 = 1;
const ccs_float_t sigma1 = 2;
const ccs_float_t mu2 = 0;
const ccs_float_t sigma2 = 2;
ccs_numeric_t samples[num_samples*2];
double mean, sig;
err = ccs_rng_create(&rng);
assert( err == CCS_SUCCESS );
err = ccs_create_normal_distribution(
CCS_NUM_FLOAT,
mu1,
sigma1,
CCS_LINEAR,
CCSF(0.0),
&distrib1);
assert( err == CCS_SUCCESS );
err = ccs_create_normal_distribution(
CCS_NUM_FLOAT,
mu2,
sigma2,
CCS_LINEAR,
CCSF(0.0),
&distrib2);
assert( err == CCS_SUCCESS );
err = ccs_distribution_strided_samples(distrib1, rng, num_samples, 2, samples);
assert( err == CCS_SUCCESS );
err = ccs_distribution_strided_samples(distrib2, rng, num_samples, 2, &(samples[0]) + 1);
assert( err == CCS_SUCCESS );
mean = gsl_stats_mean((double*)samples, 2, num_samples);
assert( mean < mu1 + 0.1 );
assert( mean > mu1 - 0.1 );
sig = gsl_stats_sd_m((double*)samples, 2, num_samples, mu1);
assert( sig < sigma1 + 0.1 );
assert( sig > sigma1 - 0.1 );
mean = gsl_stats_mean((double*)samples + 1, 2, num_samples);
assert( mean < mu2 + 0.1 );
assert( mean > mu2 - 0.1 );
sig = gsl_stats_sd_m((double*)samples + 1, 2, num_samples, mu2);
assert( sig < sigma2 + 0.1 );
assert( sig > sigma2 - 0.1 );
err = ccs_release_object(distrib1);
assert( err == CCS_SUCCESS );
err = ccs_release_object(distrib2);
assert( err == CCS_SUCCESS );
err = ccs_release_object(rng);
assert( err == CCS_SUCCESS );
}
int main(int argc, char *argv[]) {
ccs_init();
test_create_normal_distribution();
......@@ -542,5 +602,6 @@ int main(int argc, char *argv[]) {
test_normal_distribution_float_log();
test_normal_distribution_float_quantize();
test_normal_distribution_float_log_quantize();
test_normal_distribution_strided_samples();
return 0;
}
......@@ -220,11 +220,90 @@ void test_roulette_distribution_zero() {
assert( err == CCS_SUCCESS );
}
void test_roulette_distribution_strided_sample() {
ccs_distribution_t distrib1 = NULL;
ccs_distribution_t distrib2 = NULL;
ccs_rng_t rng = NULL;
ccs_result_t err = CCS_SUCCESS;
const size_t num_samples = 10000;
ccs_numeric_t samples[num_samples*2];
const size_t num_areas = 4;
ccs_float_t areas1[num_areas];
ccs_float_t areas2[num_areas];
int counts1[num_areas];
int counts2[num_areas];
for(size_t i = 0; i < num_areas; i++) {
areas1[i] = (double)(i+1);
counts1[i] = 0;
areas2[i] = (double)(i+2);
counts2[i] = 0;
}
err = ccs_rng_create(&rng);
assert( err == CCS_SUCCESS );
err = ccs_create_roulette_distribution(
num_areas,
areas1,
&distrib1);
assert( err == CCS_SUCCESS );
err = ccs_create_roulette_distribution(
num_areas,
areas2,
&distrib2);
assert( err == CCS_SUCCESS );
err = ccs_distribution_strided_samples(distrib1, rng, num_samples, 2, samples);
assert( err == CCS_SUCCESS );
err = ccs_distribution_strided_samples(distrib2, rng, num_samples, 2, &(samples[0]) + 1);
assert( err == CCS_SUCCESS );
ccs_float_t sum = 0.0;
ccs_float_t inv_sum = 0.0;
for(size_t i = 0; i < num_areas; i++) {
sum += areas1[i];
}
inv_sum = 1.0 / sum;
for(size_t i = 0; i < num_samples; i++) {
assert( samples[2*i].i >=0 && samples[2*i].i < 4 );
counts1[samples[2*i].i]++;
}
for(size_t i = 0; i < num_areas; i++) {
ccs_float_t target = num_samples * areas1[i] * inv_sum;
assert( counts1[i] >= target * 0.95 && counts1[i] <= target * 1.05 );
}
sum = 0.0;
inv_sum = 0.0;
for(size_t i = 0; i < num_areas; i++) {
sum += areas2[i];
}
inv_sum = 1.0 / sum;
for(size_t i = 0; i < num_samples; i++) {
assert( samples[2*i+1].i >=0 && samples[2*i+1].i < 4 );
counts2[samples[2*i+1].i]++;
}
for(size_t i = 0; i < num_areas; i++) {
ccs_float_t target = num_samples * areas2[i] * inv_sum;
assert( counts2[i] >= target * 0.95 && counts2[i] <= target * 1.05 );
}
err = ccs_release_object(distrib1);
assert( err == CCS_SUCCESS );
err = ccs_release_object(distrib2);
assert( err == CCS_SUCCESS );
err = ccs_release_object(rng);
assert( err == CCS_SUCCESS );
}
int main(int argc, char *argv[]) {
ccs_init();
test_create_roulette_distribution();
test_create_roulette_distribution_errors();
test_roulette_distribution();
test_roulette_distribution_zero();
test_roulette_distribution_strided_sample();
return 0;
}
......@@ -401,6 +401,60 @@ static void test_uniform_distribution_float_quantize() {
assert( err == CCS_SUCCESS );
}
static void test_uniform_distribution_strided_samples() {
ccs_distribution_t distrib1 = NULL;
ccs_distribution_t distrib2 = NULL;
ccs_rng_t rng = NULL;
ccs_result_t err = CCS_SUCCESS;
const size_t num_samples = 100;
ccs_int_t lower1 = -10;
ccs_int_t upper1 = 11;
ccs_int_t lower2 = 12;
ccs_int_t upper2 = 20;
ccs_numeric_t samples[num_samples*2];
err = ccs_rng_create(&rng);
assert( err == CCS_SUCCESS );
err = ccs_create_uniform_distribution(
CCS_NUM_INTEGER,
CCSI(lower1),
CCSI(upper1),
CCS_LINEAR,
CCSI(0),
&distrib1);
assert( err == CCS_SUCCESS );
err = ccs_create_uniform_distribution(
CCS_NUM_INTEGER,
CCSI(lower2),
CCSI(upper2),
CCS_LINEAR,
CCSI(0),
&distrib2);
assert( err == CCS_SUCCESS );
err = ccs_distribution_strided_samples(distrib1, rng, num_samples, 2, samples);
err = ccs_distribution_strided_samples(distrib2, rng, num_samples, 2, &(samples[0])+1);
assert( err == CCS_SUCCESS );
for (size_t i = 0; i < num_samples; i++) {
assert(samples[i*2].i >= lower1);
assert(samples[i*2].i < upper1);
}
for (size_t i = 0; i < num_samples; i++) {
assert(samples[i*2+1].i >= lower2);
assert(samples[i*2+1].i < upper2);
}
err = ccs_release_object(distrib1);
assert( err == CCS_SUCCESS );
err = ccs_release_object(distrib2);
assert( err == CCS_SUCCESS );
err = ccs_release_object(rng);
assert( err == CCS_SUCCESS );
}
int main(int argc, char *argv[]) {
ccs_init();
test_create_uniform_distribution();
......@@ -413,5 +467,6 @@ int main(int argc, char *argv[]) {
test_uniform_distribution_float_log();
test_uniform_distribution_float_quantize();
test_uniform_distribution_float_log_quantize();
test_uniform_distribution_strided_samples();
return 0;
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment