Commit c8550874 authored by Brice Videau
Browse files

Added categorical hyperparameter.

parent cdaec3f1
......@@ -29,11 +29,11 @@ ccs_create_numerical_hyperparameter(const char *name,
extern ccs_error_t
ccs_create_categorical_hyperparameter(const char *name,
void *user_data,
size_t num_possible_values,
ccs_datum_t *possible_values,
ccs_datum_t default_value,
ccs_datum_t *weights,
size_t default_value_index,
ccs_distribution_t distribution,
void *user_data,
ccs_hyperparameter_t *hyperparameter_ret);
extern ccs_error_t
......@@ -63,6 +63,10 @@ ccs_hyperparameter_get_user_data(ccs_hyperparameter_t hyperparameter,
// Stores the distribution currently attached to the hyperparameter in
// *distribution. Returns CCS_SUCCESS on success or a negative error code.
extern ccs_error_t
ccs_hyperparameter_get_distribution(ccs_hyperparameter_t hyperparameter,
ccs_distribution_t *distribution);
// Replaces the distribution the hyperparameter samples from. The
// hyperparameter retains the new distribution and releases the one it held
// before. Returns CCS_SUCCESS on success or a negative error code.
extern ccs_error_t
ccs_hyperparameter_set_distribution(ccs_hyperparameter_t hyperparameter,
ccs_distribution_t distribution);
// Sampling Interface
extern ccs_error_t
ccs_hyperparameter_sample(ccs_hyperparameter_t hyperparameter,
......
......@@ -17,6 +17,7 @@ libcconfigspace_la_SOURCES = \
distribution_roulette.c \
hyperparameter.c \
hyperparameter_internal.h \
hyperparameter_numerical.c
hyperparameter_numerical.c \
hyperparameter_categorical.c
@VALGRIND_CHECK_RULES@
......@@ -114,7 +114,7 @@ ccs_distribution_samples(ccs_distribution_t distribution,
return ops->samples(distribution->data, rng, num_values, values);
}
extern ccs_error_t
ccs_error_t
ccs_create_normal_float_distribution(ccs_float_t mu,
ccs_float_t sigma,
ccs_scale_type_t scale,
......@@ -123,3 +123,36 @@ ccs_create_normal_float_distribution(ccs_float_t mu,
return ccs_create_normal_distribution(CCS_NUM_FLOAT, mu, sigma, scale,
CCSF(quantization), distribution_ret);
}
/**
 * Creates a normal distribution producing integer values.
 *
 * Convenience wrapper that forwards to ccs_create_normal_distribution with
 * CCS_NUM_INTEGER as the numeric type and the quantization wrapped by CCSI.
 *
 * @param mu               forwarded unchanged
 * @param sigma            forwarded unchanged
 * @param scale            forwarded unchanged
 * @param quantization     integer quantization, wrapped with CCSI
 * @param distribution_ret forwarded unchanged; receives the new distribution
 * @return whatever ccs_create_normal_distribution returns
 */
ccs_error_t
ccs_create_normal_int_distribution(ccs_float_t         mu,
                                   ccs_float_t         sigma,
                                   ccs_scale_type_t    scale,
                                   ccs_int_t           quantization,
                                   ccs_distribution_t *distribution_ret) {
	const ccs_error_t status =
		ccs_create_normal_distribution(CCS_NUM_INTEGER, mu, sigma, scale,
		                               CCSI(quantization),
		                               distribution_ret);
	return status;
}
/**
 * Creates a uniform distribution producing floating point values.
 *
 * Convenience wrapper that forwards to ccs_create_uniform_distribution with
 * CCS_NUM_FLOAT as the numeric type; bounds and quantization are wrapped
 * by CCSF.
 *
 * @param lower            lower bound, wrapped with CCSF
 * @param upper            upper bound, wrapped with CCSF
 * @param scale            forwarded unchanged
 * @param quantization     float quantization, wrapped with CCSF
 * @param distribution_ret forwarded unchanged; receives the new distribution
 * @return whatever ccs_create_uniform_distribution returns
 */
ccs_error_t
ccs_create_uniform_float_distribution(ccs_float_t         lower,
                                      ccs_float_t         upper,
                                      ccs_scale_type_t    scale,
                                      ccs_float_t         quantization,
                                      ccs_distribution_t *distribution_ret) {
	const ccs_error_t status =
		ccs_create_uniform_distribution(CCS_NUM_FLOAT,
		                                CCSF(lower), CCSF(upper),
		                                scale, CCSF(quantization),
		                                distribution_ret);
	return status;
}
/**
 * Creates a uniform distribution producing integer values.
 *
 * Convenience wrapper that forwards to ccs_create_uniform_distribution with
 * CCS_NUM_INTEGER as the numeric type; bounds and quantization are wrapped
 * by CCSI.
 *
 * @param lower            lower bound, wrapped with CCSI
 * @param upper            upper bound, wrapped with CCSI
 * @param scale            forwarded unchanged
 * @param quantization     integer quantization, wrapped with CCSI
 * @param distribution_ret forwarded unchanged; receives the new distribution
 * @return whatever ccs_create_uniform_distribution returns
 */
ccs_error_t
ccs_create_uniform_int_distribution(ccs_int_t           lower,
                                    ccs_int_t           upper,
                                    ccs_scale_type_t    scale,
                                    ccs_int_t           quantization,
                                    ccs_distribution_t *distribution_ret) {
	const ccs_error_t status =
		ccs_create_uniform_distribution(CCS_NUM_INTEGER,
		                                CCSI(lower), CCSI(upper),
		                                scale, CCSI(quantization),
		                                distribution_ret);
	return status;
}
......@@ -127,7 +127,7 @@ ccs_create_roulette_distribution(size_t num_areas,
return CCS_SUCCESS;
}
extern ccs_error_t
ccs_error_t
ccs_roulette_distribution_get_num_areas(ccs_distribution_t distribution,
size_t *num_areas_ret) {
if (!distribution || distribution->obj.type != CCS_DISTRIBUTION)
......@@ -141,7 +141,7 @@ ccs_roulette_distribution_get_num_areas(ccs_distribution_t distribution,
return CCS_SUCCESS;
}
extern ccs_error_t
ccs_error_t
ccs_roulette_distribution_get_areas(ccs_distribution_t distribution,
size_t num_areas,
ccs_float_t *areas) {
......
......@@ -56,11 +56,35 @@ ccs_hyperparameter_get_distribution(ccs_hyperparameter_t hyperparameter,
if (!hyperparameter || !hyperparameter->data)
return -CCS_INVALID_OBJECT;
if (!distribution)
return -CCS_INVALID_OBJECT;
return -CCS_INVALID_VALUE;
*distribution = ((_ccs_hyperparameter_common_data_t *)(hyperparameter->data))->distribution;
return CCS_SUCCESS;
}
/**
 * Replaces the distribution a hyperparameter samples from.
 *
 * The candidate distribution is first checked against the hyperparameter's
 * interval to learn whether sampling will need rejection (oversampling).
 * The hyperparameter then takes a reference on the new distribution and
 * drops its reference on the previous one.
 *
 * @param hyperparameter hyperparameter to update
 * @param distribution   distribution to install; must not be NULL
 * @return CCS_SUCCESS on success; -CCS_INVALID_OBJECT when either handle is
 *         NULL or the hyperparameter has no data; otherwise the error
 *         reported by the oversampling check or the retain/release calls.
 *         When the check or the retain fails, the hyperparameter is left
 *         unchanged.
 */
ccs_error_t
ccs_hyperparameter_set_distribution(ccs_hyperparameter_t hyperparameter,
                                    ccs_distribution_t   distribution) {
	if (!hyperparameter || !hyperparameter->data)
		return -CCS_INVALID_OBJECT;
	if (!distribution)
		return -CCS_INVALID_OBJECT;
	_ccs_hyperparameter_common_data_t *d =
		(_ccs_hyperparameter_common_data_t *)(hyperparameter->data);
	ccs_error_t err;
	ccs_bool_t  oversampling;
	err = ccs_distribution_check_oversampling(distribution, &(d->interval),
	                                          &oversampling);
	if (err)
		return err;
	// Retain the new distribution BEFORE releasing the old one: if both
	// handles designate the same object, releasing first could drop the
	// last reference and leave us retaining a destroyed object.
	err = ccs_retain_object(distribution);
	if (err)
		return err;
	err = ccs_release_object(d->distribution);
	if (err) {
		// Undo the retain so the reference count stays balanced.
		ccs_release_object(distribution);
		return err;
	}
	d->distribution = distribution;
	d->oversampling = oversampling;
	return CCS_SUCCESS;
}
ccs_error_t
ccs_hyperparameter_sample(ccs_hyperparameter_t hyperparameter,
ccs_rng_t rng,
......
#include "cconfigspace_internal.h"
#include "hyperparameter_internal.h"
#include <string.h>
/* Private state of a categorical hyperparameter. */
typedef struct _ccs_hyperparameter_categorical_data_s {
	/* Fields shared by all hyperparameter kinds. Must stay the first
	 * member: generic code reaches it by casting the data pointer. */
	_ccs_hyperparameter_common_data_t  common_data;
	/* Number of entries in possible_values. */
	ccs_int_t                          num_possible_values;
	/* The categories; a sampled integer is used as an index here. */
	ccs_datum_t                       *possible_values;
} _ccs_hyperparameter_categorical_data_t;
/* Destructor hook for categorical hyperparameters: drops the reference the
 * hyperparameter holds on its distribution and returns the result of that
 * release. */
static ccs_error_t
_ccs_hyperparameter_categorical_del(ccs_object_t o) {
	ccs_hyperparameter_t hyperparameter = (ccs_hyperparameter_t)o;
	_ccs_hyperparameter_categorical_data_t *priv =
		(_ccs_hyperparameter_categorical_data_t *)(hyperparameter->data);
	ccs_distribution_t held = priv->common_data.distribution;

	return ccs_release_object(held);
}
/**
 * Draws num_values categories from a categorical hyperparameter.
 *
 * The distribution produces integer indices which are translated into the
 * corresponding entries of possible_values. When the distribution may
 * produce indices outside [0, num_possible_values) (oversampling), invalid
 * indices are rejected and extra batches are drawn, with a growing
 * over-allocation factor, until the output is full or the retry budget is
 * exhausted.
 *
 * @param data       categorical hyperparameter private data
 * @param rng        random number generator handed to the distribution
 * @param num_values number of samples requested
 * @param values     output array of num_values datums
 * @return CCS_SUCCESS when values is completely filled; the error of the
 *         distribution sampling call if it fails; -CCS_ENOMEM if a retry
 *         buffer cannot be allocated; -CCS_SAMPLING_UNSUCCESSFUL if the
 *         output is still incomplete after all retries.
 */
static ccs_error_t
_ccs_hyperparameter_categorical_samples(_ccs_hyperparameter_data_t *data,
                                        ccs_rng_t                   rng,
                                        size_t                      num_values,
                                        ccs_datum_t                *values) {
	_ccs_hyperparameter_categorical_data_t *d =
		(_ccs_hyperparameter_categorical_data_t *)data;
	ccs_error_t err;
	// The raw indices are staged inside the output buffer itself, starting
	// num_values numerics in. NOTE(review): this relies on a ccs_datum_t
	// being at least twice the size of a ccs_numeric_t -- confirm.
	ccs_numeric_t *vs = (ccs_numeric_t *)values + num_values;
	err = ccs_distribution_samples(d->common_data.distribution,
	                               rng, num_values, vs);
	if (err)
		return err;
	if (!d->common_data.oversampling) {
		// Forward order matters: the index vs[i] is read before values[i]
		// is written, and the write never reaches an index not yet read.
		for (size_t i = 0; i < num_values; i++)
			values[i] = d->possible_values[vs[i].i];
		return CCS_SUCCESS;
	}

	// Oversampling: keep only the indices that designate a category.
	size_t found = 0;
	for (size_t i = 0; i < num_values; i++)
		if (vs[i].i >= 0 && vs[i].i < d->num_possible_values)
			values[found++] = d->possible_values[vs[i].i];

	// Refill with extra batches; coeff doubles each round (2, 4, ..., 32).
	size_t coeff = 2;
	while (found < num_values) {
		// Give up only when the output is still incomplete after the
		// last allowed round; a round that completes the output is a
		// success.
		if (coeff > 32)
			return -CCS_SAMPLING_UNSUCCESSFUL;
		size_t buff_sz = (num_values - found) * coeff;
		ccs_numeric_t *buff =
			(ccs_numeric_t *)malloc(sizeof(ccs_numeric_t) * buff_sz);
		if (!buff)
			return -CCS_ENOMEM;
		err = ccs_distribution_samples(d->common_data.distribution,
		                               rng, buff_sz, buff);
		if (err) {
			// Do not consume a buffer the distribution failed to fill.
			free(buff);
			return err;
		}
		for (size_t i = 0; i < buff_sz && found < num_values; i++)
			if (buff[i].i >= 0 && buff[i].i < d->num_possible_values)
				values[found++] = d->possible_values[buff[i].i];
		free(buff);
		coeff <<= 1;
	}
	return CCS_SUCCESS;
}
/* Operation table installed on every categorical hyperparameter object:
 * first the object-level operations (the destructor), then the batch
 * sampling routine. */
static _ccs_hyperparameter_ops_t _ccs_hyperparameter_categorical_ops = {
{ &_ccs_hyperparameter_categorical_del },
&_ccs_hyperparameter_categorical_samples
};
/**
 * Creates a categorical hyperparameter.
 *
 * The object, its private data, a copy of the possible values and a copy of
 * the name are stored in one allocation, laid out in that order.
 *
 * @param name                 name of the hyperparameter; copied
 * @param num_possible_values  number of categories; must be non-zero and
 *                             no larger than INT64_MAX
 * @param possible_values      array of num_possible_values categories;
 *                             copied
 * @param default_value_index  index of the default category; must be
 *                             smaller than num_possible_values
 * @param distribution         distribution producing category indices, or
 *                             NULL to use a uniform integer distribution
 *                             over [0, num_possible_values). A provided
 *                             distribution is retained.
 * @param user_data            opaque pointer stored with the hyperparameter
 * @param hyperparameter_ret   receives the new hyperparameter
 * @return CCS_SUCCESS on success; -CCS_INVALID_VALUE on a NULL name, NULL
 *         possible_values, NULL hyperparameter_ret or an out-of-range
 *         count/index; -CCS_ENOMEM when the storage cannot be allocated
 *         (including when its size is not representable); otherwise the
 *         error reported while creating, checking or retaining the
 *         distribution.
 */
ccs_error_t
ccs_create_categorical_hyperparameter(const char           *name,
                                      size_t                num_possible_values,
                                      ccs_datum_t          *possible_values,
                                      size_t                default_value_index,
                                      ccs_distribution_t    distribution,
                                      void                 *user_data,
                                      ccs_hyperparameter_t *hyperparameter_ret) {
	if (!hyperparameter_ret || !name)
		return -CCS_INVALID_VALUE;
	if (!num_possible_values ||
	    num_possible_values > INT64_MAX ||
	    num_possible_values <= default_value_index)
		return -CCS_INVALID_VALUE;
	if (!possible_values)
		return -CCS_INVALID_VALUE;

	// Compute the allocation size without letting it wrap around: a
	// wrapped size would yield a short buffer that the copies below
	// would overrun.
	const size_t header_sz = sizeof(struct _ccs_hyperparameter_s) +
	                         sizeof(_ccs_hyperparameter_categorical_data_t);
	const size_t name_sz   = strlen(name) + 1;
	if (name_sz > SIZE_MAX - header_sz ||
	    num_possible_values >
	        (SIZE_MAX - header_sz - name_sz) / sizeof(ccs_datum_t))
		return -CCS_ENOMEM;
	const size_t values_sz = sizeof(ccs_datum_t) * num_possible_values;

	uintptr_t mem = (uintptr_t)calloc(1, header_sz + values_sz + name_sz);
	if (!mem)
		return -CCS_ENOMEM;

	// Valid indices form the half-open integer interval
	// [0, num_possible_values).
	ccs_interval_t interval;
	interval.type           = CCS_NUM_INTEGER;
	interval.lower.i        = 0;
	interval.upper.i        = (ccs_int_t)num_possible_values;
	interval.lower_included = CCS_TRUE;
	interval.upper_included = CCS_FALSE;

	ccs_error_t err;
	ccs_bool_t  oversampling;
	if (!distribution) {
		// Default: uniform over exactly the valid indices, so no
		// rejection is ever needed.
		err = ccs_create_uniform_distribution(interval.type,
		                                      interval.lower, interval.upper,
		                                      CCS_LINEAR, CCSI(0),
		                                      &distribution);
		if (err) {
			free((void *)mem);
			return err;
		}
		oversampling = CCS_FALSE;
	} else {
		err = ccs_distribution_check_oversampling(distribution, &interval,
		                                          &oversampling);
		if (err) {
			free((void *)mem);
			return err;
		}
		// The hyperparameter keeps its own reference on a
		// caller-provided distribution.
		err = ccs_retain_object(distribution);
		if (err) {
			free((void *)mem);
			return err;
		}
	}

	ccs_hyperparameter_t hyperparam = (ccs_hyperparameter_t)mem;
	_ccs_object_init(&(hyperparam->obj), CCS_HYPERPARAMETER,
	                 (_ccs_object_ops_t *)&_ccs_hyperparameter_categorical_ops);

	_ccs_hyperparameter_categorical_data_t *hyperparam_data =
		(_ccs_hyperparameter_categorical_data_t *)(mem +
			sizeof(struct _ccs_hyperparameter_s));
	hyperparam_data->common_data.type = CCS_CATEGORICAL;
	// The name lives at the very end of the allocation, after the values.
	hyperparam_data->common_data.name = (char *)(mem + header_sz + values_sz);
	memcpy((char *)hyperparam_data->common_data.name, name, name_sz);
	hyperparam_data->common_data.user_data     = user_data;
	hyperparam_data->common_data.distribution  = distribution;
	hyperparam_data->common_data.default_value =
		possible_values[default_value_index];
	hyperparam_data->common_data.interval      = interval;
	hyperparam_data->common_data.oversampling  = oversampling;
	hyperparam_data->num_possible_values = (ccs_int_t)num_possible_values;
	// The private copy of the categories directly follows the private data.
	hyperparam_data->possible_values = (ccs_datum_t *)(mem + header_sz);
	memcpy(hyperparam_data->possible_values, possible_values, values_sz);

	hyperparam->data = (_ccs_hyperparameter_data_t *)hyperparam_data;
	*hyperparameter_ret = hyperparam;
	return CCS_SUCCESS;
}
......@@ -26,6 +26,7 @@ struct _ccs_hyperparameter_common_data_s {
void *user_data;
ccs_distribution_t distribution;
ccs_datum_t default_value;
ccs_interval_t interval;
ccs_bool_t oversampling;
};
......
......@@ -4,7 +4,6 @@
struct _ccs_hyperparameter_numerical_data_s {
_ccs_hyperparameter_common_data_t common_data;
ccs_interval_t interval;
ccs_numeric_t quantization;
};
typedef struct _ccs_hyperparameter_numerical_data_s _ccs_hyperparameter_numerical_data_t;
......@@ -16,43 +15,37 @@ _ccs_hyperparameter_numerical_del(ccs_object_t o) {
return ccs_release_object(data->common_data.distribution);
}
static inline
ccs_bool_t _check_value(_ccs_hyperparameter_numerical_data_t *d,
ccs_numeric_t value) {
return ccs_interval_include(&(d->interval), value);
}
static ccs_error_t
_ccs_hyperparameter_numerical_samples(_ccs_hyperparameter_data_t *data,
ccs_rng_t rng,
size_t num_values,
ccs_datum_t *values) {
_ccs_hyperparameter_numerical_data_t *d = (_ccs_hyperparameter_numerical_data_t *)data;
ccs_numeric_type_t type = d->common_data.interval.type;
ccs_interval_t *interval = &(d->common_data.interval);
ccs_error_t err;
ccs_numeric_t *vs = (ccs_numeric_t *)values;
ccs_numeric_t *vs = (ccs_numeric_t *)values + num_values;
err = ccs_distribution_samples(d->common_data.distribution,
rng, num_values, vs);
if (err)
return err;
if (!d->common_data.oversampling) {
if (d->interval.type == CCS_NUM_FLOAT) {
if (type == CCS_NUM_FLOAT) {
for(size_t i = 0; i < num_values; i++)
values[num_values - 1 - i].value.f =
vs[num_values - 1 - i].f;
values[i].value.f = vs[i].f;
} else {
for(size_t i = 0; i < num_values; i++)
values[num_values - 1 - i].value.i =
vs[num_values - 1 - i].i;
values[i].value.i = vs[i].i;
}
} else {
size_t found = 0;
if (d->interval.type == CCS_NUM_FLOAT) {
if (type == CCS_NUM_FLOAT) {
for(size_t i = 0; i < num_values; i++)
if (_check_value(d, vs[i]))
if (ccs_interval_include(interval, vs[i]))
values[found++].value.f = vs[i].f;
} else {
for(size_t i = 0; i < num_values; i++)
if (_check_value(d, vs[i]))
if (ccs_interval_include(interval, vs[i]))
values[found++].value.i = vs[i].i;
}
vs = NULL;
......@@ -64,13 +57,13 @@ _ccs_hyperparameter_numerical_samples(_ccs_hyperparameter_data_t *data,
return -CCS_ENOMEM;
err = ccs_distribution_samples(d->common_data.distribution,
rng, buff_sz, vs);
if (d->interval.type == CCS_NUM_FLOAT) {
if (type == CCS_NUM_FLOAT) {
for(size_t i = 0; i < buff_sz && found < num_values; i++)
if (_check_value(d, vs[i]))
if (ccs_interval_include(interval, vs[i]))
values[found++].value.f = vs[i].f;
} else {
for(size_t i = 0; i < buff_sz && found < num_values; i++)
if (_check_value(d, vs[i]))
if (ccs_interval_include(interval, vs[i]))
values[found++].value.i = vs[i].i;
}
coeff <<= 1;
......@@ -80,7 +73,7 @@ _ccs_hyperparameter_numerical_samples(_ccs_hyperparameter_data_t *data,
}
}
for (size_t i = 0; i < num_values; i++)
values[i].type = (ccs_data_type_t)(d->interval.type);
values[i].type = (ccs_data_type_t)type;
return CCS_SUCCESS;
}
......@@ -164,8 +157,8 @@ ccs_create_numerical_hyperparameter(const char *name,
hyperparam_data->common_data.default_value.type = CCS_INTEGER;
hyperparam_data->common_data.default_value.value.i = default_value.i;
}
hyperparam_data->common_data.interval = interval;
hyperparam_data->common_data.oversampling = oversampling;
hyperparam_data->interval = interval;
hyperparam_data->quantization = quantization;
hyperparam->data = (_ccs_hyperparameter_data_t *)hyperparam_data;
*hyperparameter_ret = hyperparam;
......
......@@ -9,7 +9,8 @@ RNG_TESTS = \
test_uniform_distribution \
test_normal_distribution \
test_roulette_distribution \
test_numerical_hyperparameter
test_numerical_hyperparameter \
test_categorical_hyperparameter
# unit tests
UNIT_TESTS = \
......
#include <stdlib.h>
#include <assert.h>
#include <cconfigspace.h>
#include <string.h>
/* Creates a categorical hyperparameter without an explicit distribution and
 * checks every attribute that can be read back: type, default value, name,
 * user data and the default distribution with its bounds. */
void test_create() {
	const size_t              num_possible_values = 4;
	const size_t              default_value_index = 2;
	ccs_datum_t               possible_values[num_possible_values];
	ccs_hyperparameter_t      hyperparameter;
	ccs_hyperparameter_type_t type;
	ccs_datum_t               default_value;
	const char               *name;
	void                     *user_data;
	ccs_distribution_t        distribution;
	ccs_distribution_type_t   dist_type;
	ccs_interval_t            interval;
	ccs_error_t               err;

	/* Categories are the even integers 2, 4, 6, 8. */
	for (size_t k = 0; k < num_possible_values; k++) {
		ccs_datum_t *pv = &possible_values[k];
		pv->type    = CCS_INTEGER;
		pv->value.i = (k + 1) * 2;
	}

	err = ccs_create_categorical_hyperparameter("my_param",
	                                            num_possible_values,
	                                            possible_values,
	                                            default_value_index,
	                                            NULL, (void *)0xdeadbeef,
	                                            &hyperparameter);
	assert( err == CCS_SUCCESS );

	/* Kind of hyperparameter. */
	err = ccs_hyperparameter_get_type(hyperparameter, &type);
	assert( err == CCS_SUCCESS );
	assert( type == CCS_CATEGORICAL );

	/* Default value is the category at default_value_index. */
	err = ccs_hyperparameter_get_default_value(hyperparameter, &default_value);
	assert( err == CCS_SUCCESS );
	assert( default_value.type == CCS_INTEGER );
	assert( default_value.value.i == possible_values[default_value_index].value.i );

	/* Name and user data round-trip. */
	err = ccs_hyperparameter_get_name(hyperparameter, &name);
	assert( err == CCS_SUCCESS );
	assert( strcmp(name, "my_param") == 0 );

	err = ccs_hyperparameter_get_user_data(hyperparameter, &user_data);
	assert( err == CCS_SUCCESS );
	assert( user_data == (void *)0xdeadbeef );

	/* Default distribution: uniform integers over [0, 4). */
	err = ccs_hyperparameter_get_distribution(hyperparameter, &distribution);
	assert( err == CCS_SUCCESS );
	assert( distribution );

	err = ccs_distribution_get_type(distribution, &dist_type);
	assert( err == CCS_SUCCESS );
	assert( dist_type == CCS_UNIFORM );

	err = ccs_distribution_get_bounds(distribution, &interval);
	assert( err == CCS_SUCCESS );
	assert( interval.type == CCS_NUM_INTEGER );
	assert( interval.lower.i == 0 );
	assert( interval.lower_included == CCS_TRUE );
	assert( interval.upper.i == 4 );
	assert( interval.upper_included == CCS_FALSE );

	err = ccs_release_object(hyperparameter);
	assert( err == CCS_SUCCESS );
}
/* Draws a large batch from a categorical hyperparameter using its default
 * distribution and checks that every sample is one of the categories. */
void test_samples() {
	const size_t         num_samples         = 10000;
	const size_t         num_possible_values = 4;
	const size_t         default_value_index = 2;
	ccs_datum_t          samples[num_samples];
	ccs_datum_t          possible_values[num_possible_values];
	ccs_rng_t            rng;
	ccs_hyperparameter_t hyperparameter;
	ccs_error_t          err;

	/* Categories are the even integers 2, 4, 6, 8. */
	for (size_t k = 0; k < num_possible_values; k++) {
		ccs_datum_t *pv = &possible_values[k];
		pv->type    = CCS_INTEGER;
		pv->value.i = (k + 1) * 2;
	}

	err = ccs_rng_create(&rng);
	assert( err == CCS_SUCCESS );

	err = ccs_create_categorical_hyperparameter("my_param",
	                                            num_possible_values,
	                                            possible_values,
	                                            default_value_index,
	                                            NULL, NULL,
	                                            &hyperparameter);
	assert( err == CCS_SUCCESS );

	err = ccs_hyperparameter_samples(hyperparameter, rng, num_samples, samples);
	assert( err == CCS_SUCCESS );

	/* Each sample must be an even integer within [2, 8]. */
	for (size_t s = 0; s < num_samples; s++) {
		const ccs_datum_t *sample = &samples[s];
		assert( sample->type == CCS_INTEGER );
		assert( sample->value.i %2 == 0 );
		assert( sample->value.i >= 2 );
		assert( sample->value.i <= (ccs_int_t)num_possible_values * 2);
	}

	err = ccs_release_object(hyperparameter);
	assert( err == CCS_SUCCESS );
	err = ccs_release_object(rng);
	assert( err == CCS_SUCCESS );
}
/* Attaches a distribution whose range is wider than the set of valid
 * indices and checks that sampling still only yields actual categories. */
void test_oversampling() {
	const size_t         num_samples         = 10000;
	const size_t         num_possible_values = 4;
	const size_t         default_value_index = 2;
	ccs_datum_t          samples[num_samples];
	ccs_datum_t          possible_values[num_possible_values];
	ccs_rng_t            rng;
	ccs_hyperparameter_t hyperparameter;
	ccs_distribution_t   distribution;
	ccs_error_t          err;

	/* Categories are the even integers 2, 4, 6, 8. */
	for (size_t k = 0; k < num_possible_values; k++) {
		ccs_datum_t *pv = &possible_values[k];
		pv->type    = CCS_INTEGER;
		pv->value.i = (k + 1) * 2;
	}

	err = ccs_rng_create(&rng);
	assert( err == CCS_SUCCESS );

	/* Upper bound is one past the number of categories. */
	err = ccs_create_uniform_int_distribution(0, num_possible_values+1,
	                                          CCS_LINEAR, 0, &distribution);
	assert( err == CCS_SUCCESS );

	err = ccs_create_categorical_hyperparameter("my_param",
	                                            num_possible_values,
	                                            possible_values,
	                                            default_value_index,
	                                            distribution, NULL,
	                                            &hyperparameter);
	assert( err == CCS_SUCCESS );

	/* Drop the local reference on the distribution before sampling. */
	err = ccs_release_object(distribution);
	assert( err == CCS_SUCCESS );

	err = ccs_hyperparameter_samples(hyperparameter, rng, num_samples, samples);
	assert( err == CCS_SUCCESS );

	/* Each sample must be an even integer within [2, 8]. */
	for (size_t s = 0; s < num_samples; s++) {
		const ccs_datum_t *sample = &samples[s];
		assert( sample->type == CCS_INTEGER );
		assert( sample->value.i %2 == 0 );
		assert( sample->value.i >= 2 );
		assert( sample->value.i <= (ccs_int_t)num_possible_values * 2);
	}

	err = ccs_release_object(hyperparameter);
	assert( err == CCS_SUCCESS );
	err = ccs_release_object(rng);
	assert( err == CCS_SUCCESS );
}
/* Test driver: initializes the library, then runs each categorical
 * hyperparameter test in turn. Command line arguments are not used. */
int main(int argc, char *argv[]) {
	(void)argc;
	(void)argv;

	ccs_init();

	test_create();
	test_samples();
	test_oversampling();

	return 0;
}
......@@ -17,7 +17,8 @@ void test_create() {
err = ccs_create_numerical_hyperparameter("my_param", CCS_NUM_FLOAT,
CCSF(-5.0), CCSF(5.0),
CCSF(0.0), CCSF(1.0),
NULL, NULL, &hyperparameter);
NULL, (void *)0xdeadbeef,
&hyperparameter);
assert( err == CCS_SUCCESS );
err = ccs_hyperparameter_get_type(hyperparameter, &type);
......@@ -35,7 +36,7 @@ void test_create() {
err = ccs_hyperparameter_get_user_data(hyperparameter, &user_data);
assert( err == CCS_SUCCESS );
assert( user_data == NULL );
assert( user_data == (void *)0xdeadbeef );
err = ccs_hyperparameter_get_distribution(hyperparameter, &distribution);
assert( err == CCS_SUCCESS );
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment