Commit c2fb328a authored by Brice Videau's avatar Brice Videau
Browse files

Further decoupling of hyperparameters and distributions.

parent 897caeb5
......@@ -23,7 +23,6 @@ ccs_create_numerical_hyperparameter(const char *name,
ccs_numeric_t upper,
ccs_numeric_t quantization,
ccs_numeric_t default_value,
ccs_distribution_t distribution,
void *user_data,
ccs_hyperparameter_t *hyperparameter_ret);
......@@ -32,7 +31,6 @@ ccs_create_categorical_hyperparameter(const char *name,
size_t num_possible_values,
ccs_datum_t *possible_values,
size_t default_value_index,
ccs_distribution_t distribution,
void *user_data,
ccs_hyperparameter_t *hyperparameter_ret);
......@@ -41,7 +39,6 @@ ccs_create_ordinal_hyperparameter(const char *name,
size_t num_possible_values,
ccs_datum_t *possible_values,
size_t default_value_index,
ccs_distribution_t distribution,
void *user_data,
ccs_hyperparameter_t *hyperparameter_ret);
......@@ -73,21 +70,16 @@ extern ccs_error_t
ccs_hyperparameter_get_default_distribution(ccs_hyperparameter_t hyperparameter,
ccs_distribution_t *distribution);
extern ccs_error_t
ccs_hyperparameter_get_distribution(ccs_hyperparameter_t hyperparameter,
ccs_distribution_t *distribution);
extern ccs_error_t
ccs_hyperparameter_set_distribution(ccs_hyperparameter_t hyperparameter,
ccs_distribution_t distribution);
// Sampling Interface
extern ccs_error_t
ccs_hyperparameter_sample(ccs_hyperparameter_t hyperparameter,
ccs_distribution_t distribution,
ccs_rng_t rng,
ccs_datum_t *value);
extern ccs_error_t
ccs_hyperparameter_samples(ccs_hyperparameter_t hyperparameter,
ccs_distribution_t distribution,
ccs_rng_t rng,
size_t num_values,
ccs_datum_t *values);
......
......@@ -62,63 +62,30 @@ ccs_hyperparameter_get_default_distribution(ccs_hyperparameter_t hyperparameter
return ops->get_default_distribution( hyperparameter->data, distribution);
}
ccs_error_t
ccs_hyperparameter_get_distribution(ccs_hyperparameter_t hyperparameter,
ccs_distribution_t *distribution) {
if (!hyperparameter || !hyperparameter->data)
return -CCS_INVALID_OBJECT;
if (!distribution)
return -CCS_INVALID_VALUE;
*distribution = ((_ccs_hyperparameter_common_data_t *)(hyperparameter->data))->distribution;
return CCS_SUCCESS;
}
ccs_error_t
ccs_hyperparameter_set_distribution(ccs_hyperparameter_t hyperparameter,
ccs_distribution_t distribution) {
if (!hyperparameter || !hyperparameter->data)
return -CCS_INVALID_OBJECT;
if (!distribution)
return -CCS_INVALID_OBJECT;
_ccs_hyperparameter_common_data_t * d = ((_ccs_hyperparameter_common_data_t *)(hyperparameter->data));
ccs_error_t err;
ccs_bool_t oversampling;
err = ccs_distribution_check_oversampling(distribution, &(d->interval),
&oversampling);
if (err)
return err;
err = ccs_release_object(d->distribution);
if (err)
return err;
err = ccs_retain_object(distribution);
d->distribution = distribution;
d->oversampling = oversampling;
return CCS_SUCCESS;
}
ccs_error_t
ccs_hyperparameter_sample(ccs_hyperparameter_t hyperparameter,
ccs_distribution_t distribution,
ccs_rng_t rng,
ccs_datum_t *value) {
if (!hyperparameter || !hyperparameter->data)
if (!hyperparameter || distribution || !hyperparameter->data)
return -CCS_INVALID_OBJECT;
if (!value)
return -CCS_INVALID_VALUE;
_ccs_hyperparameter_ops_t *ops = ccs_hyperparameter_get_ops(hyperparameter);
return ops->samples(hyperparameter->data, rng, 1, value);
return ops->samples(hyperparameter->data, distribution, rng, 1, value);
}
ccs_error_t
ccs_hyperparameter_samples(ccs_hyperparameter_t hyperparameter,
ccs_distribution_t distribution,
ccs_rng_t rng,
size_t num_values,
ccs_datum_t *values) {
if (!hyperparameter || !hyperparameter->data)
if (!hyperparameter || !distribution || !hyperparameter->data)
return -CCS_INVALID_OBJECT;
if (!num_values || !values)
return -CCS_INVALID_VALUE;
_ccs_hyperparameter_ops_t *ops = ccs_hyperparameter_get_ops(hyperparameter);
return ops->samples(hyperparameter->data, rng, num_values, values);
return ops->samples(hyperparameter->data, distribution, rng, num_values, values);
}
......@@ -11,13 +11,13 @@ typedef struct _ccs_hyperparameter_categorical_data_s _ccs_hyperparameter_catego
static ccs_error_t
_ccs_hyperparameter_categorical_del(ccs_object_t o) {
ccs_hyperparameter_t d = (ccs_hyperparameter_t)o;
_ccs_hyperparameter_categorical_data_t *data = (_ccs_hyperparameter_categorical_data_t *)(d->data);
return ccs_release_object(data->common_data.distribution);
(void)o;
return CCS_SUCCESS;
}
static ccs_error_t
_ccs_hyperparameter_categorical_samples(_ccs_hyperparameter_data_t *data,
ccs_distribution_t distribution,
ccs_rng_t rng,
size_t num_values,
ccs_datum_t *values) {
......@@ -25,11 +25,16 @@ _ccs_hyperparameter_categorical_samples(_ccs_hyperparameter_data_t *data,
(_ccs_hyperparameter_categorical_data_t *)data;
ccs_error_t err;
ccs_numeric_t *vs = (ccs_numeric_t *)values + num_values;
err = ccs_distribution_samples(d->common_data.distribution,
rng, num_values, vs);
ccs_bool_t oversampling;
err = ccs_distribution_check_oversampling(distribution,
&(d->common_data.interval),
&oversampling);
if (err)
return err;
err = ccs_distribution_samples(distribution, rng, num_values, vs);
if (err)
return err;
if (!d->common_data.oversampling) {
if (!oversampling) {
for(size_t i = 0; i < num_values; i++)
values[i] = d->possible_values[vs[i].i];
} else {
......@@ -44,8 +49,8 @@ _ccs_hyperparameter_categorical_samples(_ccs_hyperparameter_data_t *data,
vs = (ccs_numeric_t *)malloc(sizeof(ccs_numeric_t)*buff_sz);
if (!vs)
return -CCS_ENOMEM;
err = ccs_distribution_samples(d->common_data.distribution,
rng, buff_sz, vs);
err = ccs_distribution_samples(distribution, rng,
buff_sz, vs);
for(size_t i = 0; i < buff_sz && found < num_values; i++)
if (vs[i].i >= 0 && vs[i].i < d->num_possible_values)
values[found++] = d->possible_values[vs[i].i];
......@@ -81,7 +86,6 @@ ccs_create_categorical_hyperparameter(const char *name,
size_t num_possible_values,
ccs_datum_t *possible_values,
size_t default_value_index,
ccs_distribution_t distribution,
void *user_data,
ccs_hyperparameter_t *hyperparameter_ret) {
if (!hyperparameter_ret || !name)
......@@ -107,27 +111,6 @@ ccs_create_categorical_hyperparameter(const char *name,
interval.lower_included = CCS_TRUE;
interval.upper_included = CCS_FALSE;
ccs_error_t err;
ccs_bool_t oversampling;
if (!distribution) {
err = ccs_create_uniform_distribution(interval.type,
interval.lower, interval.upper,
CCS_LINEAR, CCSI(0),
&distribution);
if (err) {
free((void *)mem);
return err;
}
oversampling = CCS_FALSE;
} else {
err = ccs_distribution_check_oversampling(distribution, &interval,
&oversampling);
if (err) {
free((void *)mem);
return err;
}
ccs_retain_object(distribution);
}
ccs_hyperparameter_t hyperparam = (ccs_hyperparameter_t)mem;
_ccs_object_init(&(hyperparam->obj), CCS_HYPERPARAMETER, (_ccs_object_ops_t *)&_ccs_hyperparameter_categorical_ops);
_ccs_hyperparameter_categorical_data_t *hyperparam_data =
......@@ -140,10 +123,8 @@ ccs_create_categorical_hyperparameter(const char *name,
sizeof(ccs_datum_t) * num_possible_values);
strcpy((char *)hyperparam_data->common_data.name, name);
hyperparam_data->common_data.user_data = user_data;
hyperparam_data->common_data.distribution = distribution;
hyperparam_data->common_data.default_value = possible_values[default_value_index];
hyperparam_data->common_data.interval = interval;
hyperparam_data->common_data.oversampling = oversampling;
hyperparam_data->num_possible_values = num_possible_values;
hyperparam_data->possible_values = (ccs_datum_t *)(mem +
sizeof(struct _ccs_hyperparameter_s) +
......
......@@ -8,7 +8,8 @@ struct _ccs_hyperparameter_ops_s {
_ccs_object_ops_t obj_ops;
ccs_error_t (*samples)(
_ccs_hyperparameter_data_t *hyperparameter,
_ccs_hyperparameter_data_t *data,
ccs_distribution_t distribution,
ccs_rng_t rng,
size_t num_values,
ccs_datum_t *values);
......@@ -28,10 +29,8 @@ struct _ccs_hyperparameter_common_data_s {
ccs_hyperparameter_type_t type;
const char *name;
void *user_data;
ccs_distribution_t distribution;
ccs_datum_t default_value;
ccs_interval_t interval;
ccs_bool_t oversampling;
};
typedef struct _ccs_hyperparameter_common_data_s _ccs_hyperparameter_common_data_t;
......
......@@ -10,26 +10,34 @@ typedef struct _ccs_hyperparameter_numerical_data_s _ccs_hyperparameter_numerica
static ccs_error_t
_ccs_hyperparameter_numerical_del(ccs_object_t o) {
ccs_hyperparameter_t d = (ccs_hyperparameter_t)o;
_ccs_hyperparameter_numerical_data_t *data = (_ccs_hyperparameter_numerical_data_t *)(d->data);
return ccs_release_object(data->common_data.distribution);
(void)o;
return CCS_SUCCESS;
}
static ccs_error_t
_ccs_hyperparameter_numerical_samples(_ccs_hyperparameter_data_t *data,
ccs_distribution_t distribution,
ccs_rng_t rng,
size_t num_values,
ccs_datum_t *values) {
_ccs_hyperparameter_numerical_data_t *d = (_ccs_hyperparameter_numerical_data_t *)data;
_ccs_hyperparameter_numerical_data_t *d =
(_ccs_hyperparameter_numerical_data_t *)data;
ccs_numeric_type_t type = d->common_data.interval.type;
ccs_interval_t *interval = &(d->common_data.interval);
ccs_error_t err;
ccs_numeric_t *vs = (ccs_numeric_t *)values + num_values;
err = ccs_distribution_samples(d->common_data.distribution,
ccs_bool_t oversampling;
err = ccs_distribution_check_oversampling(distribution,
interval,
&oversampling);
if (err)
return err;
err = ccs_distribution_samples(distribution,
rng, num_values, vs);
if (err)
return err;
if (!d->common_data.oversampling) {
if (!oversampling) {
if (type == CCS_NUM_FLOAT) {
for(size_t i = 0; i < num_values; i++)
values[i].value.f = vs[i].f;
......@@ -55,8 +63,8 @@ _ccs_hyperparameter_numerical_samples(_ccs_hyperparameter_data_t *data,
vs = (ccs_numeric_t *)malloc(sizeof(ccs_numeric_t)*buff_sz);
if (!vs)
return -CCS_ENOMEM;
err = ccs_distribution_samples(d->common_data.distribution,
rng, buff_sz, vs);
err = ccs_distribution_samples(distribution, rng,
buff_sz, vs);
if (type == CCS_NUM_FLOAT) {
for(size_t i = 0; i < buff_sz && found < num_values; i++)
if (ccs_interval_include(interval, vs[i]))
......@@ -102,7 +110,6 @@ ccs_create_numerical_hyperparameter(const char *name,
ccs_numeric_t upper,
ccs_numeric_t quantization,
ccs_numeric_t default_value,
ccs_distribution_t distribution,
void *user_data,
ccs_hyperparameter_t *hyperparameter_ret) {
if (!hyperparameter_ret || !name)
......@@ -134,27 +141,6 @@ ccs_create_numerical_hyperparameter(const char *name,
interval.lower_included = CCS_TRUE;
interval.upper_included = CCS_FALSE;
ccs_error_t err;
ccs_bool_t oversampling;
if (!distribution) {
err = ccs_create_uniform_distribution(data_type, lower, upper,
CCS_LINEAR, quantization,
&distribution);
if (err) {
free((void *)mem);
return err;
}
oversampling = CCS_FALSE;
} else {
err = ccs_distribution_check_oversampling(distribution, &interval,
&oversampling);
if (err) {
free((void *)mem);
return err;
}
ccs_retain_object(distribution);
}
ccs_hyperparameter_t hyperparam = (ccs_hyperparameter_t)mem;
_ccs_object_init(&(hyperparam->obj), CCS_HYPERPARAMETER, (_ccs_object_ops_t *)&_ccs_hyperparameter_numerical_ops);
_ccs_hyperparameter_numerical_data_t *hyperparam_data = (_ccs_hyperparameter_numerical_data_t *)(mem + sizeof(struct _ccs_hyperparameter_s));
......@@ -162,7 +148,6 @@ ccs_create_numerical_hyperparameter(const char *name,
hyperparam_data->common_data.name = (char *)(mem + sizeof(struct _ccs_hyperparameter_s) + sizeof(_ccs_hyperparameter_numerical_data_t));
strcpy((char *)hyperparam_data->common_data.name, name);
hyperparam_data->common_data.user_data = user_data;
hyperparam_data->common_data.distribution = distribution;
if (data_type == CCS_NUM_FLOAT) {
hyperparam_data->common_data.default_value.type = CCS_FLOAT;
hyperparam_data->common_data.default_value.value.f = default_value.f;
......@@ -171,7 +156,6 @@ ccs_create_numerical_hyperparameter(const char *name,
hyperparam_data->common_data.default_value.value.i = default_value.i;
}
hyperparam_data->common_data.interval = interval;
hyperparam_data->common_data.oversampling = oversampling;
hyperparam_data->quantization = quantization;
hyperparam->data = (_ccs_hyperparameter_data_t *)hyperparam_data;
*hyperparameter_ret = hyperparam;
......
......@@ -99,11 +99,12 @@ _ccs_hyperparameter_ordinal_del(ccs_object_t o) {
for (ccs_int_t i = 0; i < data->num_possible_values; i++) {
HASH_DELETE(hh, data->hash, data->possible_values + i);
}
return ccs_release_object(data->common_data.distribution);
return CCS_SUCCESS;
}
static ccs_error_t
_ccs_hyperparameter_ordinal_samples(_ccs_hyperparameter_data_t *data,
ccs_distribution_t distribution,
ccs_rng_t rng,
size_t num_values,
ccs_datum_t *values) {
......@@ -111,11 +112,17 @@ _ccs_hyperparameter_ordinal_samples(_ccs_hyperparameter_data_t *data,
(_ccs_hyperparameter_ordinal_data_t *)data;
ccs_error_t err;
ccs_int_t *vs = (ccs_int_t *)values + num_values;
err = ccs_distribution_samples(d->common_data.distribution,
rng, num_values, (ccs_numeric_t *)vs);
ccs_bool_t oversampling;
err = ccs_distribution_check_oversampling(distribution,
&(d->common_data.interval),
&oversampling);
if (err)
return err;
err = ccs_distribution_samples(distribution, rng,
num_values, (ccs_numeric_t *)vs);
if (err)
return err;
if (!d->common_data.oversampling) {
if (!oversampling) {
for(size_t i = 0; i < num_values; i++)
values[i] = d->possible_values[vs[i]].d;
} else {
......@@ -130,8 +137,8 @@ _ccs_hyperparameter_ordinal_samples(_ccs_hyperparameter_data_t *data,
vs = (ccs_int_t *)malloc(sizeof(ccs_numeric_t)*buff_sz);
if (!vs)
return -CCS_ENOMEM;
err = ccs_distribution_samples(d->common_data.distribution,
rng, buff_sz, (ccs_numeric_t *)vs);
err = ccs_distribution_samples(distribution, rng,
buff_sz, (ccs_numeric_t *)vs);
for(size_t i = 0; i < buff_sz && found < num_values; i++)
if (vs[i] >= 0 && vs[i] < d->num_possible_values)
values[found++] = d->possible_values[vs[i]].d;
......@@ -194,7 +201,6 @@ ccs_ordinal_hyperparameter_compare_values(ccs_hyperparameter_t hyperparameter,
while ( v > pvs) { \
HASH_DELETE(hh, hyperparam_data->hash, --v); \
} \
ccs_release_object(distribution); \
free((void*)mem); \
return -CCS_ENOMEM; \
}
......@@ -204,7 +210,6 @@ ccs_create_ordinal_hyperparameter(const char *name,
size_t num_possible_values,
ccs_datum_t *possible_values,
size_t default_value_index,
ccs_distribution_t distribution,
void *user_data,
ccs_hyperparameter_t *hyperparameter_ret) {
if (!hyperparameter_ret || !name)
......@@ -239,27 +244,6 @@ ccs_create_ordinal_hyperparameter(const char *name,
interval.lower_included = CCS_TRUE;
interval.upper_included = CCS_FALSE;
ccs_error_t err;
ccs_bool_t oversampling;
if (!distribution) {
err = ccs_create_uniform_distribution(interval.type,
interval.lower, interval.upper,
CCS_LINEAR, CCSI(0),
&distribution);
if (err) {
free((void *)mem);
return err;
}
oversampling = CCS_FALSE;
} else {
err = ccs_distribution_check_oversampling(distribution, &interval,
&oversampling);
if (err) {
free((void *)mem);
return err;
}
ccs_retain_object(distribution);
}
ccs_hyperparameter_t hyperparam = (ccs_hyperparameter_t)mem;
_ccs_object_init(&(hyperparam->obj), CCS_HYPERPARAMETER, (_ccs_object_ops_t *)&_ccs_hyperparameter_ordinal_ops);
_ccs_hyperparameter_ordinal_data_t *hyperparam_data =
......@@ -272,10 +256,8 @@ ccs_create_ordinal_hyperparameter(const char *name,
sizeof(_ccs_hash_datum_t) * num_possible_values);
strcpy((char *)hyperparam_data->common_data.name, name);
hyperparam_data->common_data.user_data = user_data;
hyperparam_data->common_data.distribution = distribution;
hyperparam_data->common_data.default_value = possible_values[default_value_index];
hyperparam_data->common_data.interval = interval;
hyperparam_data->common_data.oversampling = oversampling;
hyperparam_data->num_possible_values = num_possible_values;
_ccs_hash_datum_t *pvs = (_ccs_hash_datum_t *)(mem +
sizeof(struct _ccs_hyperparameter_s) +
......@@ -292,7 +274,6 @@ ccs_create_ordinal_hyperparameter(const char *name,
HASH_ITER(hh, hyperparam_data->hash, p, tmp) {
HASH_DELETE(hh, hyperparam_data->hash, p);
}
ccs_release_object(distribution);
free((void *)mem);
return -CCS_INVALID_VALUE;
}
......
......@@ -24,8 +24,7 @@ void test_create() {
err = ccs_create_categorical_hyperparameter("my_param", num_possible_values,
possible_values, default_value_index,
NULL, (void *)0xdeadbeef,
&hyperparameter);
(void *)0xdeadbeef, &hyperparameter);
assert( err == CCS_SUCCESS );
err = ccs_hyperparameter_get_type(hyperparameter, &type);
......@@ -45,7 +44,7 @@ void test_create() {
assert( err == CCS_SUCCESS );
assert( user_data == (void *)0xdeadbeef );
err = ccs_hyperparameter_get_distribution(hyperparameter, &distribution);
err = ccs_hyperparameter_get_default_distribution(hyperparameter, &distribution);
assert( err == CCS_SUCCESS );
assert( distribution );
......@@ -61,6 +60,8 @@ void test_create() {
assert( interval.upper.i == 4 );
assert( interval.upper_included == CCS_FALSE );
err = ccs_release_object(distribution);
assert( err == CCS_SUCCESS );
err = ccs_release_object(hyperparameter);
assert( err == CCS_SUCCESS );
}
......@@ -68,6 +69,7 @@ void test_create() {
void test_samples() {
ccs_rng_t rng;
ccs_hyperparameter_t hyperparameter;
ccs_distribution_t distribution;
const size_t num_samples = 10000;
ccs_datum_t samples[num_samples];
ccs_error_t err;
......@@ -84,11 +86,15 @@ void test_samples() {
assert( err == CCS_SUCCESS );
err = ccs_create_categorical_hyperparameter("my_param", num_possible_values,
possible_values, default_value_index,
NULL, NULL,
&hyperparameter);
NULL, &hyperparameter);
assert( err == CCS_SUCCESS );
err = ccs_hyperparameter_samples(hyperparameter, rng, num_samples, samples);
err = ccs_hyperparameter_get_default_distribution(hyperparameter, &distribution);
assert( err == CCS_SUCCESS );
assert( distribution );
err = ccs_hyperparameter_samples(hyperparameter, distribution, rng,
num_samples, samples);
assert( err == CCS_SUCCESS );
for (size_t i = 0; i < num_samples; i++) {
......@@ -98,6 +104,8 @@ void test_samples() {
assert( samples[i].value.i <= (ccs_int_t)num_possible_values * 2);
}
err = ccs_release_object(distribution);
assert( err == CCS_SUCCESS );
err = ccs_release_object(hyperparameter);
assert( err == CCS_SUCCESS );
err = ccs_release_object(rng);
......@@ -128,14 +136,11 @@ void test_oversampling() {
err = ccs_create_categorical_hyperparameter("my_param", num_possible_values,
possible_values, default_value_index,
distribution, NULL,
&hyperparameter);
assert( err == CCS_SUCCESS );
err = ccs_release_object(distribution);
NULL, &hyperparameter);
assert( err == CCS_SUCCESS );
err = ccs_hyperparameter_samples(hyperparameter, rng, num_samples, samples);
err = ccs_hyperparameter_samples(hyperparameter, distribution, rng,
num_samples, samples);
assert( err == CCS_SUCCESS );
for (size_t i = 0; i < num_samples; i++) {
......@@ -145,6 +150,8 @@ void test_oversampling() {
assert( samples[i].value.i <= (ccs_int_t)num_possible_values * 2);
}
err = ccs_release_object(distribution);
assert( err == CCS_SUCCESS );
err = ccs_release_object(hyperparameter);
assert( err == CCS_SUCCESS );
err = ccs_release_object(rng);
......
......@@ -17,7 +17,7 @@ void test_create() {
err = ccs_create_numerical_hyperparameter("my_param", CCS_NUM_FLOAT,
CCSF(-5.0), CCSF(5.0),
CCSF(0.0), CCSF(1.0),
NULL, (void *)0xdeadbeef,
(void *)0xdeadbeef,
&hyperparameter);
assert( err == CCS_SUCCESS );
......@@ -38,7 +38,7 @@ void test_create() {
assert( err == CCS_SUCCESS );
assert( user_data == (void *)0xdeadbeef );
err = ccs_hyperparameter_get_distribution(hyperparameter, &distribution);
err = ccs_hyperparameter_get_default_distribution(hyperparameter, &distribution);
assert( err == CCS_SUCCESS );
assert( distribution );
......@@ -54,6 +54,8 @@ void test_create() {
assert( interval.upper.f == 5.0 );
assert( interval.upper_included == CCS_FALSE );
err = ccs_release_object(distribution);
assert( err == CCS_SUCCESS );
err = ccs_release_object(hyperparameter);
assert( err == CCS_SUCCESS );
}
......@@ -61,6 +63,7 @@ void test_create() {
void test_samples() {
ccs_rng_t rng;
ccs_hyperparameter_t hyperparameter;
ccs_distribution_t distribution;
const size_t num_samples = 10000;
ccs_datum_t samples[num_samples];
ccs_error_t err;
......@@ -70,10 +73,15 @@ void test_samples() {
err = ccs_create_numerical_hyperparameter("my_param", CCS_NUM_FLOAT,
CCSF(-5.0), CCSF(5.0),
CCSF(0.0), CCSF(1.0),
NULL, NULL, &hyperparameter);
NULL, &hyperparameter);
assert( err == CCS_SUCCESS );
err = ccs_hyperparameter_samples(hyperparameter, rng, num_samples, samples);
err = ccs_hyperparameter_get_default_distribution(hyperparameter, &distribution);
assert( err == CCS_SUCCESS );
assert( distribution );
err = ccs_hyperparameter_samples(hyperparameter, distribution, rng,
num_samples, samples);
assert( err == CCS_SUCCESS );
for( size_t i = 0; i < num_samples; i++) {
......@@ -81,6 +89,8 @@ void test_samples() {
assert( samples[i].value.f >= -5.0 && samples[i].value.f < 5.0 );
}
err = ccs_release_object(distribution);
assert( err == CCS_SUCCESS );
err = ccs_release_object(hyperparameter);
assert( err == CCS_SUCCESS );
err = ccs_release_object(rng);
......@@ -105,13 +115,11 @@ void test_oversampling() {
err = ccs_create_numerical_hyperparameter("my_param", CCS_NUM_FLOAT,
CCSF(-1.0), CCSF(1.0),
CCSF(0.0), CCSF(0.0),
distribution, NULL, &hyperparameter);
assert( err == CCS_SUCCESS );
err = ccs_release_object(distribution);
NULL, &hyperparameter);
assert( err == CCS_SUCCESS );
err = ccs_hyperparameter_samples(hyperparameter, rng, num_samples, samples);
err = ccs_hyperparameter_samples(hyperparameter, distribution, rng,
num_samples, samples);
assert( err == CCS_SUCCESS );
for( size_t i = 0; i < num_samples; i++) {
......@@ -119,6 +127,8 @@ void test_oversampling() {
assert( samples[i].value.f >= -1.0 && samples[i].value.f < 1.0 );