Commit d11426eb authored by Brice Videau's avatar Brice Videau
Browse files

First test of conditions.

parent 2e952b38
......@@ -12,6 +12,7 @@ extern ccs_error_t
ccs_create_configuration(ccs_configuration_space_t configuration_space,
size_t num_values,
ccs_datum_t *values,
ccs_bool_t *actives,
void *user_data,
ccs_configuration_t *configuration_ret);
......@@ -40,6 +41,24 @@ ccs_configuration_get_values(ccs_configuration_t configuration,
ccs_datum_t *values,
size_t *num_values_ret);
extern ccs_error_t
ccs_configuration_get_active(ccs_configuration_t configuration,
size_t index,
ccs_bool_t *active_ret);
extern ccs_error_t
ccs_configuration_set_active(ccs_configuration_t configuration,
size_t index,
ccs_bool_t active);
extern ccs_error_t
ccs_configuration_get_actives(ccs_configuration_t configuration,
size_t num_actives,
ccs_bool_t *actives,
size_t *num_actives_ret);
extern ccs_error_t
ccs_configuration_get_value_by_name(ccs_configuration_t configuration,
const char *name,
......
......@@ -21,6 +21,7 @@ ccs_error_t
ccs_create_configuration(ccs_configuration_space_t configuration_space,
size_t num_values,
ccs_datum_t *values,
ccs_bool_t *actives,
void *user_data,
ccs_configuration_t *configuration_ret) {
if (!configuration_ret)
......@@ -36,7 +37,7 @@ ccs_create_configuration(ccs_configuration_space_t configuration_space,
return err;
if (values && num != num_values)
return -CCS_INVALID_VALUE;
uintptr_t mem = (uintptr_t)calloc(1, sizeof(struct _ccs_configuration_s) + sizeof(struct _ccs_configuration_data_s) + num * sizeof(ccs_datum_t));
uintptr_t mem = (uintptr_t)calloc(1, sizeof(struct _ccs_configuration_s) + sizeof(struct _ccs_configuration_data_s) + num * sizeof(ccs_datum_t) + num * sizeof(ccs_bool_t));
if (!mem)
return CCS_ENOMEM;
err = ccs_retain_object(configuration_space);
......@@ -51,8 +52,13 @@ ccs_create_configuration(ccs_configuration_space_t configuration_space,
config->data->num_values = num;
config->data->configuration_space = configuration_space;
config->data->values = (ccs_datum_t *)(mem + sizeof(struct _ccs_configuration_s) + sizeof(struct _ccs_configuration_data_s));
config->data->actives = (ccs_bool_t *)(config->data->values+num);
if (values)
memcpy(config->data->values, values, num*sizeof(ccs_datum_t));
if (actives)
memcpy(config->data->actives, actives, num*sizeof(ccs_bool_t));
else for (size_t i = 0; i < num; i++)
config->data->actives[i] = CCS_TRUE;
*configuration_ret = config;
return CCS_SUCCESS;
}
......
......@@ -19,6 +19,7 @@ struct _ccs_configuration_data_s {
ccs_configuration_space_t configuration_space;
size_t num_values;
ccs_datum_t *values;
ccs_bool_t *actives;
};
#endif //_CONFIGURATION_INTERNAL_H
......@@ -100,7 +100,7 @@ ccs_create_configuration_space(const char *name,
utarray_new(config_space->data->sorted_indexes, &_size_t_icd);
config_space->data->name_hash = NULL;
config_space->data->distribution_list = NULL;
config_space->data->graph_ok = CCS_FALSE;
config_space->data->graph_ok = CCS_TRUE;
strcpy((char *)(config_space->data->name), name);
*configuration_space_ret = config_space;
return CCS_SUCCESS;
......@@ -442,7 +442,7 @@ ccs_configuration_space_get_default_configuration(ccs_configuration_space_t con
return -CCS_INVALID_VALUE;
ccs_error_t err;
ccs_configuration_t config;
err = ccs_create_configuration(configuration_space, 0, NULL, NULL, &config);
err = ccs_create_configuration(configuration_space, 0, NULL, NULL, NULL, &config);
if (err)
return err;
UT_array *array = configuration_space->data->hyperparameters;
......@@ -508,6 +508,46 @@ ccs_configuration_space_check_configuration_values(ccs_configuration_space_t co
}
static ccs_error_t
_check_actives(ccs_configuration_space_t configuration_space,
ccs_configuration_t configuration) {
size_t *p_index = NULL;
UT_array *index_array = configuration_space->data->sorted_indexes;
UT_array *array = configuration_space->data->hyperparameters;
ccs_bool_t *actives = configuration->data->actives;
ccs_datum_t *values = configuration->data->values;
while ( (p_index = (size_t *)utarray_next(index_array, p_index)) ) {
_ccs_hyperparameter_wrapper_t *wrapper = NULL;
wrapper = (_ccs_hyperparameter_wrapper_t *)utarray_eltptr(array, *p_index);
if (!wrapper->condition)
continue;
UT_array *parents = wrapper->parents;
size_t *p_parent = NULL;
while ( (p_parent = (size_t*)utarray_next(parents, p_parent)) ) {
if (!actives[*p_parent]) {
actives[*p_index] = CCS_FALSE;
values[*p_index] = ccs_none;
break;
}
}
if (!actives[*p_index])
continue;
ccs_datum_t result;
ccs_error_t err;
err = ccs_expression_eval(wrapper->condition, configuration_space,
values, &result);
if (err)
return err;
if (!(result.type == CCS_BOOLEAN && result.value.i == CCS_TRUE)) {
actives[*p_index] = CCS_FALSE;
values[*p_index] = ccs_none;
}
}
return CCS_SUCCESS;
}
static ccs_error_t
_generate_constraints(ccs_configuration_space_t configuration_space);
// This is temporary until I figure out how correlated sampling should work
ccs_error_t
ccs_configuration_space_sample(ccs_configuration_space_t configuration_space,
......@@ -518,7 +558,12 @@ ccs_configuration_space_sample(ccs_configuration_space_t configuration_space,
return -CCS_INVALID_VALUE;
ccs_error_t err;
ccs_configuration_t config;
err = ccs_create_configuration(configuration_space, 0, NULL, NULL, &config);
if (!configuration_space->data->graph_ok) {
err = _generate_constraints(configuration_space);
if (err)
return err;
}
err = ccs_create_configuration(configuration_space, 0, NULL, NULL, NULL, &config);
if (err)
return err;
ccs_rng_t rng = configuration_space->data->rng;
......@@ -534,6 +579,11 @@ ccs_configuration_space_sample(ccs_configuration_space_t configuration_space,
return err;
}
}
err = _check_actives(configuration_space, config);
if (err) {
ccs_release_object(config);
return err;
}
*configuration_ret = config;
return CCS_SUCCESS;
}
......@@ -567,19 +617,27 @@ ccs_configuration_space_samples(ccs_configuration_space_t configuration_space,
}
size_t i;
for(i = 0; i < num_configurations; i++) {
err = ccs_create_configuration(configuration_space, 0, NULL, NULL, configurations + i);
err = ccs_create_configuration(configuration_space, 0, NULL, NULL, NULL, configurations + i);
if (unlikely(err)) {
free(values);
for(size_t j = 0; j < i; j++) {
for(size_t j = 0; j < i; j++)
ccs_release_object(configurations + j);
}
return err;
}
}
for(i = 0; i < num_configurations; i++)
for(i = 0; i < num_configurations; i++) {
for(size_t j = 0; j < num_hyper; j++)
configurations[i]->data->values[j] =
values[j*num_configurations + i];
err = _check_actives(configuration_space, configurations[i]);
if (err) {
free(values);
for(size_t j = 0; j < num_configurations; j++)
ccs_release_object(configurations + j);
return err;
}
}
free(values);
return CCS_SUCCESS;
}
......@@ -628,8 +686,8 @@ _topological_sort(ccs_configuration_space_t configuration_space) {
UT_array *array = configuration_space->data->hyperparameters;
size_t count = utarray_len(array);
struct _hyper_list_s *list = (struct _hyper_list_s *)malloc(
sizeof(struct _hyper_list_s *) * count);
struct _hyper_list_s *list = (struct _hyper_list_s *)calloc(1,
sizeof(struct _hyper_list_s) * count);
if (!list)
return -CCS_ENOMEM;
struct _hyper_list_s *queue = NULL;
......@@ -673,7 +731,6 @@ _topological_sort(ccs_configuration_space_t configuration_space) {
}
static ccs_error_t
_recompute_graph(ccs_configuration_space_t configuration_space) {
configuration_space->data->graph_ok = CCS_FALSE;
_ccs_hyperparameter_wrapper_t *wrapper = NULL;
UT_array *array = configuration_space->data->hyperparameters;
while ( (wrapper = (_ccs_hyperparameter_wrapper_t *)utarray_next(array, wrapper)) ) {
......@@ -716,6 +773,7 @@ _recompute_graph(ccs_configuration_space_t configuration_space) {
parent_wrapper = (_ccs_hyperparameter_wrapper_t *)utarray_eltptr(array, parents_index[i]);
utarray_push_back(parent_wrapper->children, &(wrapper->index));
}
free((void *)mem);
}
wrapper = NULL;
while ( (wrapper = (_ccs_hyperparameter_wrapper_t *)utarray_next(array, wrapper)) ) {
......@@ -729,6 +787,19 @@ _recompute_graph(ccs_configuration_space_t configuration_space) {
exit(-1); \
}
static ccs_error_t
_generate_constraints(ccs_configuration_space_t configuration_space) {
ccs_error_t err;
err = _recompute_graph(configuration_space);
if (err)
return err;
err = _topological_sort(configuration_space);
if (err)
return err;
configuration_space->data->graph_ok = CCS_TRUE;
return CCS_SUCCESS;
}
ccs_error_t
ccs_configuration_space_set_condition(ccs_configuration_space_t configuration_space,
size_t hyperparameter_index,
......@@ -748,20 +819,13 @@ ccs_configuration_space_set_condition(ccs_configuration_space_t configuration_sp
return err;
wrapper->condition = expression;
// Recompute the whole graph for now
err = _recompute_graph(configuration_space);
if (err) {
ccs_release_object(expression);
wrapper->condition = NULL;
return err;
}
err = _topological_sort(configuration_space);
configuration_space->data->graph_ok = CCS_FALSE;
err = _generate_constraints(configuration_space);
if (err) {
configuration_space->data->graph_ok = CCS_FALSE;
ccs_release_object(expression);
wrapper->condition = NULL;
return err;
}
configuration_space->data->graph_ok = CCS_TRUE;
return CCS_SUCCESS;
}
......
......@@ -13,7 +13,8 @@ RNG_TESTS = \
test_categorical_hyperparameter \
test_ordinal_hyperparameter \
test_configuration_space \
test_expression
test_expression \
test_condition
# unit tests
UNIT_TESTS = \
......
#include <stdlib.h>
#include <assert.h>
#include <cconfigspace.h>
#include <string.h>
#include <math.h>
static inline ccs_datum_t
ccs_bool(ccs_bool_t v) {
ccs_datum_t d;
d.type = CCS_BOOLEAN;
d.value.i = v;
return d;
}
static inline ccs_datum_t
ccs_float(ccs_float_t v) {
ccs_datum_t d;
d.type = CCS_FLOAT;
d.value.f = v;
return d;
}
static inline ccs_datum_t
ccs_int(ccs_int_t v) {
ccs_datum_t d;
d.type = CCS_INTEGER;
d.value.i = v;
return d;
}
static inline ccs_datum_t
ccs_object(ccs_object_t v) {
ccs_datum_t d;
d.type = CCS_OBJECT;
d.value.o = v;
return d;
}
static inline ccs_datum_t
ccs_string(const char *v) {
ccs_datum_t d;
d.type = CCS_STRING;
d.value.s = v;
return d;
}
ccs_hyperparameter_t create_numerical(const char * name) {
ccs_hyperparameter_t hyperparameter;
ccs_error_t err;
err = ccs_create_numerical_hyperparameter(name, CCS_NUM_FLOAT,
CCSF(-1.0), CCSF(1.0),
CCSF(0.0), CCSF(0),
NULL, &hyperparameter);
assert( err == CCS_SUCCESS );
return hyperparameter;
}
void
test_simple() {
ccs_hyperparameter_t hyperparameter1, hyperparameter2;
ccs_configuration_space_t space;
ccs_expression_t expression;
ccs_configuration_t configuration;
ccs_datum_t values[2];
ccs_configuration_t configurations[100];
ccs_error_t err;
hyperparameter1 = create_numerical("param1");
hyperparameter2 = create_numerical("param2");
err = ccs_create_configuration_space("space", NULL, &space);
assert( err == CCS_SUCCESS );
err = ccs_configuration_space_add_hyperparameter(space, hyperparameter1, NULL);
assert( err == CCS_SUCCESS );
err = ccs_configuration_space_add_hyperparameter(space, hyperparameter2, NULL);
assert( err == CCS_SUCCESS );
err = ccs_create_binary_expression(CCS_LESS, ccs_object(hyperparameter1),
ccs_float(0.0), &expression);
assert( err == CCS_SUCCESS );
err = ccs_configuration_space_set_condition(space, 1, expression);
assert( err == CCS_SUCCESS );
for (int i = 0; i < 100; i ++) {
ccs_float_t f;
err = ccs_configuration_space_sample(space, &configuration);
assert( err == CCS_SUCCESS );
err = ccs_configuration_get_values(configuration, 2, values, NULL);
assert( err == CCS_SUCCESS );
assert( values[0].type == CCS_FLOAT );
f = values[0].value.f;
assert( f >= -1.0 && f < 1.0 );
if ( f < 0.0 )
assert( values[1].type == CCS_FLOAT );
else
assert( values[1].type == CCS_NONE );
err = ccs_release_object(configuration);
assert( err == CCS_SUCCESS );
}
err = ccs_configuration_space_samples(space, 100, configurations);
assert( err == CCS_SUCCESS );
for (int i = 0; i < 100; i ++) {
ccs_float_t f;
err = ccs_configuration_get_values(configurations[i], 2, values, NULL);
assert( err == CCS_SUCCESS );
assert( values[0].type == CCS_FLOAT );
f = values[0].value.f;
assert( f >= -1.0 && f < 1.0 );
if ( f < 0.0 )
assert( values[1].type == CCS_FLOAT );
else
assert( values[1].type == CCS_NONE );
err = ccs_release_object(configurations[i]);
assert( err == CCS_SUCCESS );
}
err = ccs_release_object(expression);
err = ccs_release_object(hyperparameter1);
err = ccs_release_object(hyperparameter2);
err = ccs_release_object(space);
}
int main() {
ccs_init();
test_simple();
return 0;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment