Commit f41c5245 authored by Brice Videau's avatar Brice Videau
Browse files

Added sampling interface to configuration_space.

parent fb59273d
......@@ -132,13 +132,13 @@ ccs_configuration_space_get_default_configuration(ccs_configuration_space_t con
ccs_configuration_t *configuration_ret);
extern ccs_error_t
ccs_configuration_space_sample_configuration(ccs_configuration_space_t configuration_space,
ccs_configuration_t *configuration);
ccs_configuration_space_sample(ccs_configuration_space_t configuration_space,
ccs_configuration_t *configuration_ret);
extern ccs_error_t
ccs_configuration_space_sample_configurations(ccs_configuration_space_t configuration_space,
size_t num_configurations,
ccs_configuration_t *configurations);
ccs_configuration_space_samples(ccs_configuration_space_t configuration_space,
size_t num_configurations,
ccs_configuration_t *configurations);
// Hyperparameter related functions
extern ccs_error_t
......
......@@ -123,8 +123,10 @@ ccs_configuration_get_values(ccs_configuration_t configuration,
if (num_values < num)
return -CCS_INVALID_VALUE;
memcpy(values, configuration->data->values, num*sizeof(ccs_datum_t));
if (num < num_values)
memset(values + num, 0, (num_values - num)*sizeof(ccs_datum_t));
for (size_t i = num; i < num_values; i++) {
values[i].type = CCS_NONE;
values[i].value.i = 0;
}
}
if (num_values_ret)
*num_values_ret = num;
......
......@@ -188,9 +188,9 @@ ccs_configuration_space_add_hyperparameter(ccs_configuration_space_t configurati
distrib_wrapper->distribution = distribution;
distrib_wrapper->dimension = dimension;
distrib_wrapper->hyperparameter_indexes = (size_t *)(pmem + sizeof(_ccs_distribution_wrapper_t));
distrib_wrapper->hyperparameter_indexes[0] = 0;
hyperparameters = configuration_space->data->hyperparameters;
index = utarray_len(hyperparameters);
distrib_wrapper->hyperparameter_indexes[0] = index;
hyper_wrapper.index = index;
hyper_wrapper.distribution_index = 0;
hyper_wrapper.distribution = distrib_wrapper;
......@@ -352,22 +352,39 @@ ccs_configuration_space_get_default_configuration(ccs_configuration_space_t con
if (err)
return err;
UT_array *array = configuration_space->data->hyperparameters;
size_t index = 0;
_ccs_hyperparameter_wrapper_t *wrapper = NULL;
ccs_datum_t d;
ccs_datum_t *values = config->data->values;
while ( (wrapper = (_ccs_hyperparameter_wrapper_t *)utarray_next(array, wrapper)) ) {
err = ccs_hyperparameter_get_default_value(wrapper->hyperparameter, &d);
if (unlikely(err))
goto error;
err = ccs_configuration_set_value(config, index++, d);
if (unlikely(err))
goto error;
err = ccs_hyperparameter_get_default_value(wrapper->hyperparameter,
values++);
if (unlikely(err)) {
ccs_release_object(config);
return err;
}
}
*configuration_ret = config;
return CCS_SUCCESS;
error:
ccs_release_object(config);
return err;
}
static inline ccs_error_t
_check_configuration(UT_array *array,
size_t num_values,
ccs_datum_t *values) {
if (num_values != utarray_len(array))
return -CCS_INVALID_CONFIGURATION;
_ccs_hyperparameter_wrapper_t *wrapper = NULL;
size_t index = 0;
while ( (wrapper = (_ccs_hyperparameter_wrapper_t *)utarray_next(array, wrapper)) ) {
ccs_bool_t res;
ccs_error_t err;
err = ccs_hyperparameter_check_value(wrapper->hyperparameter,
values[index++], &res);
if (unlikely(err))
return err;
if (res == CCS_FALSE)
return -CCS_INVALID_CONFIGURATION;
}
return CCS_SUCCESS;
}
ccs_error_t
......@@ -379,22 +396,97 @@ ccs_configuration_space_check_configuration(ccs_configuration_space_t configurat
return -CCS_INVALID_OBJECT;
if (configuration->data->configuration_space != configuration_space)
return -CCS_INVALID_CONFIGURATION;
size_t index = 0;
return _check_configuration(configuration_space->data->hyperparameters,
configuration->data->num_values,
configuration->data->values);
}
ccs_error_t
ccs_configuration_space_check_configuration_values(ccs_configuration_space_t configuration_space,
size_t num_values,
ccs_datum_t *values) {
if (!configuration_space || !configuration_space->data)
return -CCS_INVALID_OBJECT;
if (!values)
return -CCS_INVALID_VALUE;
return _check_configuration(configuration_space->data->hyperparameters,
num_values, values);
}
// This is temporary until I fugure out how correlated sampling should work
ccs_error_t
ccs_configuration_space_sample(ccs_configuration_space_t configuration_space,
ccs_configuration_t *configuration_ret) {
if (!configuration_space || !configuration_space->data)
return -CCS_INVALID_OBJECT;
if (!configuration_ret)
return -CCS_INVALID_VALUE;
ccs_error_t err;
ccs_configuration_t config;
err = ccs_create_configuration(configuration_space, 0, NULL, NULL, &config);
if (err)
return err;
ccs_rng_t rng = configuration_space->data->rng;
UT_array *array = configuration_space->data->hyperparameters;
if (configuration->data->num_values != utarray_len(array))
return -CCS_INVALID_CONFIGURATION;
_ccs_hyperparameter_wrapper_t *wrapper = NULL;
ccs_datum_t *values = configuration->data->values;
ccs_datum_t *values = config->data->values;
while ( (wrapper = (_ccs_hyperparameter_wrapper_t *)utarray_next(array, wrapper)) ) {
ccs_bool_t res;
ccs_error_t err;
err = ccs_hyperparameter_check_value(wrapper->hyperparameter,
values[index++], &res);
if (unlikely(err))
err = ccs_hyperparameter_sample(wrapper->hyperparameter,
wrapper->distribution->distribution,
rng, values++);
if (unlikely(err)) {
ccs_release_object(config);
return err;
if (res == CCS_FALSE)
return -CCS_INVALID_CONFIGURATION;
}
}
*configuration_ret = config;
return CCS_SUCCESS;
}
ccs_error_t
ccs_configuration_space_samples(ccs_configuration_space_t configuration_space,
size_t num_configurations,
ccs_configuration_t *configurations) {
if (!configuration_space || !configuration_space->data)
return -CCS_INVALID_OBJECT;
if (num_configurations && !configurations)
return -CCS_INVALID_VALUE;
if (!num_configurations)
return CCS_SUCCESS;
ccs_error_t err;
UT_array *array = configuration_space->data->hyperparameters;
size_t num_hyper = utarray_len(array);
ccs_datum_t *values = (ccs_datum_t *)malloc(sizeof(ccs_datum_t)*num_configurations*num_hyper);
ccs_datum_t *p_values = values;
ccs_rng_t rng = configuration_space->data->rng;
_ccs_hyperparameter_wrapper_t *wrapper = NULL;
while ( (wrapper = (_ccs_hyperparameter_wrapper_t *)utarray_next(array, wrapper)) ) {
err = ccs_hyperparameter_samples(wrapper->hyperparameter,
wrapper->distribution->distribution,
rng, num_configurations, p_values);
if (unlikely(err)) {
free(values);
return err;
}
p_values += num_configurations;
}
size_t i;
for(i = 0; i < num_configurations; i++) {
err = ccs_create_configuration(configuration_space, 0, NULL, NULL, configurations + i);
if (unlikely(err)) {
free(values);
for(size_t j = 0; j < i; j++) {
ccs_release_object(configurations + j);
}
return err;
}
}
for(i = 0; i < num_configurations; i++)
for(size_t j = 0; j < num_hyper; j++)
configurations[i]->data->values[j] =
values[j*num_configurations + i];
free(values);
return CCS_SUCCESS;
}
......
......@@ -92,7 +92,7 @@ ccs_hyperparameter_sample(ccs_hyperparameter_t hyperparameter,
ccs_distribution_t distribution,
ccs_rng_t rng,
ccs_datum_t *value) {
if (!hyperparameter || distribution || !hyperparameter->data)
if (!hyperparameter || !distribution || !hyperparameter->data)
return -CCS_INVALID_OBJECT;
if (!value)
return -CCS_INVALID_VALUE;
......
......@@ -163,7 +163,7 @@ void test_add_list() {
hyperparameters[2] = create_dummy_hyperparameter("param3");
err = ccs_configuration_space_add_hyperparameters(configuration_space, 3,
hyperparameters, NULL);
hyperparameters, NULL);
assert( err == CCS_SUCCESS );
check_configuration(configuration_space, 3, hyperparameters);
......@@ -176,10 +176,62 @@ void test_add_list() {
assert( err == CCS_SUCCESS );
}
void test_sample() {
ccs_hyperparameter_t hyperparameters[4];
ccs_configuration_t configuration;
ccs_configuration_t configurations[100];
ccs_configuration_space_t configuration_space;
ccs_error_t err;
err = ccs_create_configuration_space("my_config_space", NULL,
&configuration_space);
assert( err == CCS_SUCCESS );
hyperparameters[0] = create_dummy_hyperparameter("param1");
hyperparameters[1] = create_dummy_hyperparameter("param2");
hyperparameters[2] = create_dummy_hyperparameter("param3");
err = ccs_create_numerical_hyperparameter("param4", CCS_NUM_INTEGER,
CCSI(-5), CCSI(5),
CCSI(0), CCSI(0),
NULL, hyperparameters+3);
assert( err == CCS_SUCCESS );
err = ccs_configuration_space_add_hyperparameters(configuration_space, 4,
hyperparameters, NULL);
assert( err == CCS_SUCCESS );
err = ccs_configuration_space_sample(configuration_space, &configuration);
assert( err == CCS_SUCCESS );
err = ccs_configuration_check(configuration);
assert( err == CCS_SUCCESS );
err = ccs_configuration_space_samples(configuration_space,
100, configurations);
assert( err == CCS_SUCCESS );
for (size_t i = 0; i < 100; i++) {
err = ccs_configuration_check(configurations[i]);
assert( err == CCS_SUCCESS );
err = ccs_release_object(configurations[i]);
assert( err == CCS_SUCCESS );
}
err = ccs_release_object(configuration);
assert( err == CCS_SUCCESS );
for (size_t i = 0; i < 4; i++) {
err = ccs_release_object(hyperparameters[i]);
assert( err == CCS_SUCCESS );
}
err = ccs_release_object(configuration_space);
assert( err == CCS_SUCCESS );
}
int main(int argc, char *argv[]) {
ccs_init();
test_create();
test_add();
test_add_list();
test_sample();
return 0;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment