Commit ee8e698b authored by Brice Videau's avatar Brice Videau
Browse files

Added topological sort.

parent a3833aba
......@@ -2,10 +2,12 @@ include_HEADERS=cconfigspace.h
include_ccsdir=$(includedir)/ccs
include_ccs_HEADERS = \
ccs/base.h \
ccs/rng.h \
ccs/distribution.h \
ccs/hyperparameter.h \
ccs/expression.h\
ccs/interval.h \
ccs/configuration_space.h \
ccs/hyperparameter.h \
ccs/expression.h \
ccs/configuration.h
......@@ -38,6 +38,8 @@ enum ccs_error_e {
CCS_INVALID_HYPERPARAMETER,
CCS_INVALID_CONFIGURATION,
CCS_INVALID_NAME,
CCS_INVALID_CONDITION,
CCS_INVALID_GRAPH,
CCS_TYPE_NOT_COMPARABLE,
CCS_INVALID_BOUNDS,
CCS_OUT_OF_BOUNDS,
......
......@@ -69,6 +69,13 @@ ccs_configuration_space_get_hyperparameter_index(
ccs_hyperparameter_t hyperparameter,
size_t *index_ret);
extern ccs_error_t
ccs_configuration_space_get_hyperparameter_indexes(
ccs_configuration_space_t configuration_space,
size_t num_hyperparameters,
ccs_hyperparameter_t *hyperparameters,
size_t *indexes);
extern ccs_error_t
ccs_configuration_space_get_hyperparameters(ccs_configuration_space_t configuration_space,
size_t num_hyperparameters,
......
......@@ -95,6 +95,7 @@ ccs_expression_list_eval_node(ccs_expression_t expression,
size_t index,
ccs_datum_t *result);
//uniq and sorted list of hyperparameters handle
extern ccs_error_t
ccs_expression_get_hyperparameters(ccs_expression_t expression,
size_t num_hyperparameters,
......
......@@ -16,6 +16,8 @@ _ccs_configuration_space_del(ccs_object_t object) {
ccs_release_object(wrapper->hyperparameter);
if (wrapper->condition)
ccs_release_object(wrapper->condition);
utarray_free(wrapper->parents);
utarray_free(wrapper->children);
}
array = configuration_space->data->forbidden_clauses;
ccs_expression_t *expr = NULL;
......@@ -26,6 +28,7 @@ _ccs_configuration_space_del(ccs_object_t object) {
HASH_CLEAR(hh_handle, configuration_space->data->handle_hash);
utarray_free(configuration_space->data->hyperparameters);
utarray_free(configuration_space->data->forbidden_clauses);
utarray_free(configuration_space->data->sorted_indexes);
_ccs_distribution_wrapper_t *dw;
_ccs_distribution_wrapper_t *tmp;
DL_FOREACH_SAFE(configuration_space->data->distribution_list, dw, tmp) {
......@@ -54,6 +57,13 @@ static const UT_icd _forbidden_clauses_icd = {
NULL,
};
static UT_icd _size_t_icd = {
sizeof(size_t),
NULL,
NULL,
NULL
};
#undef utarray_oom
#define utarray_oom() { \
ccs_release_object(config_space->data->rng); \
......@@ -84,18 +94,23 @@ ccs_create_configuration_space(const char *name,
config_space->data->rng = rng;
config_space->data->hyperparameters = NULL;
config_space->data->forbidden_clauses = NULL;
config_space->data->sorted_indexes = NULL;
utarray_new(config_space->data->hyperparameters, &_hyperparameter_wrapper_icd);
utarray_new(config_space->data->forbidden_clauses, &_forbidden_clauses_icd);
utarray_new(config_space->data->sorted_indexes, &_size_t_icd);
config_space->data->name_hash = NULL;
config_space->data->distribution_list = NULL;
config_space->data->graph_ok = CCS_FALSE;
strcpy((char *)(config_space->data->name), name);
*configuration_space_ret = config_space;
return CCS_SUCCESS;
arrays:
if(config_space->data->hyperparameters)
if (config_space->data->hyperparameters)
utarray_free(config_space->data->hyperparameters);
if(config_space->data->forbidden_clauses)
if (config_space->data->forbidden_clauses)
utarray_free(config_space->data->forbidden_clauses);
if (config_space->data->sorted_indexes)
utarray_free(config_space->data->sorted_indexes);
free((void *)mem);
return err;
}
......@@ -222,6 +237,10 @@ ccs_configuration_space_add_hyperparameter(ccs_configuration_space_t configurati
hyper_wrapper.distribution = distrib_wrapper;
hyper_wrapper.name = name;
hyper_wrapper.condition = NULL;
hyper_wrapper.parents = NULL;
hyper_wrapper.children = NULL;
utarray_new(hyper_wrapper.parents, &_size_t_icd);
utarray_new(hyper_wrapper.children, &_size_t_icd);
utarray_push_back(hyperparameters, &hyper_wrapper);
p_hyper_wrapper =
......@@ -236,6 +255,10 @@ ccs_configuration_space_add_hyperparameter(ccs_configuration_space_t configurati
errorutarray:
utarray_pop_back(hyperparameters);
errordistrib_wrapper:
if (hyper_wrapper.parents)
utarray_free(hyper_wrapper.parents);
if (hyper_wrapper.children)
utarray_free(hyper_wrapper.children);
free(distrib_wrapper);
errordistrib:
ccs_release_object(distribution);
......@@ -357,6 +380,27 @@ ccs_configuration_space_get_hyperparameter_index(
return CCS_SUCCESS;
}
ccs_error_t
ccs_configuration_space_get_hyperparameter_indexes(
ccs_configuration_space_t configuration_space,
size_t num_hyperparameters,
ccs_hyperparameter_t *hyperparameters,
size_t *indexes) {
if (!configuration_space || !configuration_space->data)
return -CCS_INVALID_OBJECT;
if (num_hyperparameters && (!hyperparameters || !indexes ))
return -CCS_INVALID_VALUE;
_ccs_hyperparameter_wrapper_t *wrapper;
for(size_t i = 0; i < num_hyperparameters; i++) {
HASH_FIND(hh_handle, configuration_space->data->handle_hash,
hyperparameters + i, sizeof(ccs_hyperparameter_t), wrapper);
if (!wrapper)
return -CCS_INVALID_HYPERPARAMETER;
indexes[i] = wrapper->index;
}
return CCS_SUCCESS;
}
ccs_error_t
ccs_configuration_space_get_hyperparameters(ccs_configuration_space_t configuration_space,
size_t num_hyperparameters
......@@ -539,6 +583,155 @@ ccs_configuration_space_samples(ccs_configuration_space_t configuration_space,
return CCS_SUCCESS;
}
static int _size_t_sort(const void *a, const void *b) {
const size_t sa = *(const size_t *)a;
const size_t sb = *(const size_t *)b;
return sa < sb ? -1 : sa > sb ? 1 : 0;
}
static void _uniq_size_t_array(UT_array *array) {
size_t count = utarray_len(array);
if (count == 0)
return;
utarray_sort(array, &_size_t_sort);
size_t real_count = 0;
size_t *p = (size_t *)utarray_front(array);
size_t *p2 = p;
real_count++;
while ( (p = (size_t *)utarray_next(array, p)) ) {
if (*p != *p2) {
p2 = (size_t *)utarray_next(array, p2);
*p2 = *p;
real_count++;
}
}
utarray_resize(array, real_count);
}
struct _hyper_list_s;
struct _hyper_list_s {
size_t in_edges;
size_t index;
struct _hyper_list_s *next;
struct _hyper_list_s *prev;
};
#undef utarray_oom
#define utarray_oom() { \
free((void *)list); \
return -CCS_ENOMEM; \
}
static ccs_error_t
_topological_sort(ccs_configuration_space_t configuration_space) {
UT_array *array = configuration_space->data->hyperparameters;
size_t count = utarray_len(array);
struct _hyper_list_s *list = (struct _hyper_list_s *)malloc(
sizeof(struct _hyper_list_s *) * count);
if (!list)
return -CCS_ENOMEM;
struct _hyper_list_s *to_process = NULL;
struct _hyper_list_s *queue = NULL;
_ccs_hyperparameter_wrapper_t *wrapper = NULL;
size_t index = 0;
while ( (wrapper = (_ccs_hyperparameter_wrapper_t *)utarray_next(array, wrapper)) ) {
size_t in_edges = utarray_len(wrapper->parents);
list[index].in_edges = in_edges;
list[index].index = index;
if (in_edges == 0)
DL_APPEND(queue, list + index);
else
DL_APPEND(list, list + index);
index++;
}
size_t processed = 0;
while(queue) {
struct _hyper_list_s *e = queue;
DL_DELETE(queue, queue);
wrapper = (_ccs_hyperparameter_wrapper_t *)
utarray_eltptr(array, e->index);
size_t *child = NULL;
while ( (child = (size_t *)utarray_next(wrapper->children, child)) ) {
list[*child].in_edges--;
if (list[*child].in_edges == 0) {
DL_DELETE(to_process, list + *child);
DL_APPEND(queue, list + *child);
}
}
utarray_push_back(configuration_space->data->sorted_indexes, &(e->index));
processed++;
};
free(list);
if (processed < count)
return -CCS_INVALID_GRAPH;
return CCS_SUCCESS;
}
#undef utarray_oom
#define utarray_oom() { \
free((void *)mem); \
return -CCS_ENOMEM; \
}
static ccs_error_t
_recompute_graph(ccs_configuration_space_t configuration_space) {
configuration_space->data->graph_ok = CCS_FALSE;
_ccs_hyperparameter_wrapper_t *wrapper = NULL;
UT_array *array = configuration_space->data->hyperparameters;
while ( (wrapper = (_ccs_hyperparameter_wrapper_t *)utarray_next(array, wrapper)) ) {
utarray_clear(wrapper->parents);
utarray_clear(wrapper->children);
}
wrapper = NULL;
while ( (wrapper = (_ccs_hyperparameter_wrapper_t *)utarray_next(array, wrapper)) ) {
if (!wrapper->condition)
continue;
size_t count;
ccs_error_t err;
err = ccs_expression_get_hyperparameters(wrapper->condition, 0, NULL, &count);
if (err)
return err;
if (count == 0)
continue;
ccs_hyperparameter_t *parents = NULL;
size_t *parents_index = NULL;
_ccs_hyperparameter_wrapper_t *parent_wrapper = NULL;
intptr_t mem = (intptr_t)malloc(count *
(sizeof(ccs_hyperparameter_t) + sizeof(size_t)));
if (!mem)
return -CCS_ENOMEM;
parents = (ccs_hyperparameter_t *)mem;
parents_index = (size_t *)(mem + count * sizeof(ccs_hyperparameter_t));
err = ccs_expression_get_hyperparameters(wrapper->condition, count, parents, NULL);
if (err) {
free((void *)mem);
return err;
}
err = ccs_configuration_space_get_hyperparameter_indexes(
configuration_space, count, parents, parents_index);
if (err) {
free((void *)mem);
return err;
}
for (size_t i = 0; i < count; i++) {
utarray_push_back(wrapper->parents, parents_index + i);
parent_wrapper = (_ccs_hyperparameter_wrapper_t *)utarray_eltptr(array, parents_index[i]);
utarray_push_back(parent_wrapper->children, &(wrapper->index));
}
}
wrapper = NULL;
while ( (wrapper = (_ccs_hyperparameter_wrapper_t *)utarray_next(array, wrapper)) ) {
_uniq_size_t_array(wrapper->parents);
_uniq_size_t_array(wrapper->children);
}
configuration_space->data->graph_ok = CCS_TRUE;
return CCS_SUCCESS;
}
#undef utarray_oom
#define utarray_oom() { \
exit(-1); \
}
ccs_error_t
ccs_configuration_space_set_condition(ccs_configuration_space_t configuration_space,
size_t hyperparameter_index,
......@@ -553,13 +746,24 @@ ccs_configuration_space_set_condition(ccs_configuration_space_t configuration_sp
if (wrapper->condition)
return -CCS_INVALID_HYPERPARAMETER;
ccs_error_t err;
err = ccs_expression_check_context(expression, configuration_space);
if (err)
return err;
err = ccs_retain_object(expression);
if (err)
return err;
wrapper->condition = expression;
// Recompute the whole graph for now
err = _recompute_graph(configuration_space);
if (err) {
ccs_release_object(expression);
wrapper->condition = NULL;
return err;
}
err = _topological_sort(configuration_space);
if (err) {
configuration_space->data->graph_ok = CCS_FALSE;
ccs_release_object(expression);
wrapper->condition = NULL;
return err;
}
return CCS_SUCCESS;
}
......
......@@ -17,6 +17,8 @@ struct _ccs_hyperparameter_wrapper_s {
size_t distribution_index;
_ccs_distribution_wrapper_t *distribution;
ccs_expression_t condition;
UT_array *parents;
UT_array *children;
};
typedef struct _ccs_hyperparameter_wrapper_s _ccs_hyperparameter_wrapper_t;
......@@ -50,6 +52,8 @@ struct _ccs_configuration_space_data_s {
_ccs_hyperparameter_wrapper_t *handle_hash;
_ccs_distribution_wrapper_t *distribution_list;
UT_array *forbidden_clauses;
ccs_bool_t graph_ok;
UT_array *sorted_indexes;
};
#endif //_CONFIGURATION_SPACE_INTERNAL_H
......@@ -1144,7 +1144,7 @@ static const UT_icd _hyperparameter_icd = {
NULL,
};
static int hyper_sort(const void *a, const void *b) {
static int _hyper_sort(const void *a, const void *b) {
ccs_hyperparameter_t ha = *(ccs_hyperparameter_t *)a;
ccs_hyperparameter_t hb = *(ccs_hyperparameter_t *)b;
return ha < hb ? -1 : ha > hb ? 1 : 0;
......@@ -1167,7 +1167,7 @@ ccs_expression_get_hyperparameters(ccs_expression_t expression,
utarray_free(array);
return err;
}
utarray_sort(array, &hyper_sort);
utarray_sort(array, &_hyper_sort);
size_t count = 0;
if (utarray_len(array) > 0) {
ccs_hyperparameter_t previous = NULL;
......@@ -1210,7 +1210,7 @@ ccs_expression_check_context(ccs_expression_t expression,
UT_array *array;
utarray_new(array, &_hyperparameter_icd);
err = _get_hyperparameters(expression, array);
utarray_sort(array, &hyper_sort);
utarray_sort(array, &_hyper_sort);
if (err) {
utarray_free(array);
return err;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment