Commit ebcc01e0 authored by Brice Videau's avatar Brice Videau
Browse files

Added get_hyperparameters to expressions.

parent 51a09c9d
......@@ -76,7 +76,7 @@ ccs_configuration_space_get_hyperparameters(ccs_configuration_space_t configura
size_t *num_hyperparameters_ret);
extern ccs_error_t
ccs_configuration_space_add_condition(ccs_configuration_space_t configuration_space,
ccs_configuration_space_set_condition(ccs_configuration_space_t configuration_space,
size_t hyperparameter_index,
ccs_expression_t expression);
......@@ -102,13 +102,13 @@ ccs_configuration_space_add_forbidden_clauses(ccs_configuration_space_t configu
extern ccs_error_t
ccs_configuration_space_get_forbidden_clause(ccs_configuration_space_t configuration_space,
size_t hyperparameter_index,
size_t index,
ccs_expression_t *expression_ret);
extern ccs_error_t
ccs_configuration_space_get_forbidden_clauses(ccs_configuration_space_t configuration_space,
size_t num_expressions,
ccs_expression_t *expressionss,
ccs_expression_t *expressions,
size_t *num_expressions_ret);
// Configuration related functions
......@@ -137,6 +137,7 @@ ccs_configuration_space_samples(ccs_configuration_space_t configuration_space,
// Hyperparameter related functions
extern ccs_error_t
ccs_configuration_space_get_active_hyperparameters(ccs_configuration_space_t configuration_space,
ccs_configuration_t configuration,
size_t num_hyperparameters,
ccs_hyperparameter_t *hyperparameters,
size_t *num_hyperparameters_ret);
......
......@@ -89,6 +89,11 @@ extern ccs_error_t
ccs_expression_get_type(ccs_expression_t expression,
ccs_expression_type_t *type_ret);
extern ccs_error_t
ccs_expression_get_hyperparameters(ccs_expression_t expression,
size_t num_hyperparameters,
ccs_hyperparameter_t *hyperparameters,
size_t *num_hyperparameters_ret);
#ifdef __cplusplus
}
#endif
......
......@@ -14,10 +14,18 @@ _ccs_configuration_space_del(ccs_object_t object) {
_ccs_hyperparameter_wrapper_t *wrapper = NULL;
while ( (wrapper = (_ccs_hyperparameter_wrapper_t *)utarray_next(array, wrapper)) ) {
ccs_release_object(wrapper->hyperparameter);
if (wrapper->condition)
ccs_release_object(wrapper->condition);
}
array = configuration_space->data->forbidden_clauses;
ccs_expression_t *expr = NULL;
while ( (expr = (ccs_expression_t *)utarray_next(array, expr)) ) {
ccs_release_object(*expr);
}
HASH_CLEAR(hh_name, configuration_space->data->name_hash);
HASH_CLEAR(hh_handle, configuration_space->data->handle_hash);
utarray_free(array);
utarray_free(configuration_space->data->hyperparameters);
utarray_free(configuration_space->data->forbidden_clauses);
_ccs_distribution_wrapper_t *dw;
_ccs_distribution_wrapper_t *tmp;
DL_FOREACH_SAFE(configuration_space->data->distribution_list, dw, tmp) {
......@@ -39,11 +47,18 @@ static const UT_icd _hyperparameter_wrapper_icd = {
NULL,
};
static const UT_icd _forbidden_clauses_icd = {
sizeof(ccs_expression_t),
NULL,
NULL,
NULL,
};
#undef utarray_oom
#define utarray_oom() { \
ccs_release_object(config_space->data->rng); \
free((void *)mem); \
return -CCS_ENOMEM; \
err = -CCS_ENOMEM; \
goto arrays; \
}
ccs_error_t
ccs_create_configuration_space(const char *name,
......@@ -67,12 +82,22 @@ ccs_create_configuration_space(const char *name,
config_space->data->name = (const char *)(mem + sizeof(struct _ccs_configuration_space_s) + sizeof(struct _ccs_configuration_space_data_s));
config_space->data->user_data = user_data;
config_space->data->rng = rng;
config_space->data->hyperparameters = NULL;
config_space->data->forbidden_clauses = NULL;
utarray_new(config_space->data->hyperparameters, &_hyperparameter_wrapper_icd);
utarray_new(config_space->data->forbidden_clauses, &_forbidden_clauses_icd);
config_space->data->name_hash = NULL;
config_space->data->distribution_list = NULL;
strcpy((char *)(config_space->data->name), name);
*configuration_space_ret = config_space;
return CCS_SUCCESS;
arrays:
if(config_space->data->hyperparameters)
utarray_free(config_space->data->hyperparameters);
if(config_space->data->forbidden_clauses)
utarray_free(config_space->data->forbidden_clauses);
free((void *)mem);
return err;
}
ccs_error_t
......@@ -196,6 +221,7 @@ ccs_configuration_space_add_hyperparameter(ccs_configuration_space_t configurati
hyper_wrapper.distribution_index = 0;
hyper_wrapper.distribution = distrib_wrapper;
hyper_wrapper.name = name;
hyper_wrapper.condition = NULL;
utarray_push_back(hyperparameters, &hyper_wrapper);
p_hyper_wrapper =
......@@ -437,7 +463,7 @@ ccs_configuration_space_check_configuration_values(ccs_configuration_space_t co
}
// This is temporary until I fugure out how correlated sampling should work
// This is temporary until I figure out how correlated sampling should work
ccs_error_t
ccs_configuration_space_sample(ccs_configuration_space_t configuration_space,
ccs_configuration_t *configuration_ret) {
......@@ -513,4 +539,147 @@ ccs_configuration_space_samples(ccs_configuration_space_t configuration_space,
return CCS_SUCCESS;
}
ccs_error_t
ccs_configuration_space_set_condition(ccs_configuration_space_t configuration_space,
size_t hyperparameter_index,
ccs_expression_t expression) {
if (!configuration_space || !configuration_space->data || !expression)
return -CCS_INVALID_OBJECT;
_ccs_hyperparameter_wrapper_t *wrapper = (_ccs_hyperparameter_wrapper_t*)
utarray_eltptr(configuration_space->data->hyperparameters,
(unsigned int)hyperparameter_index);
if (!wrapper)
return -CCS_OUT_OF_BOUNDS;
if (wrapper->condition)
return -CCS_INVALID_HYPERPARAMETER;
ccs_error_t err = ccs_retain_object(expression);
if (err)
return err;
wrapper->condition = expression;
return CCS_SUCCESS;
}
ccs_error_t
ccs_configuration_space_get_condition(ccs_configuration_space_t configuration_space,
size_t hyperparameter_index,
ccs_expression_t *expression_ret) {
if (!configuration_space || !configuration_space->data)
return -CCS_INVALID_OBJECT;
if (!expression_ret)
return -CCS_INVALID_VALUE;
_ccs_hyperparameter_wrapper_t *wrapper = (_ccs_hyperparameter_wrapper_t*)
utarray_eltptr(configuration_space->data->hyperparameters,
(unsigned int)hyperparameter_index);
if (!wrapper)
return -CCS_OUT_OF_BOUNDS;
*expression_ret = wrapper->condition;
return CCS_SUCCESS;
}
ccs_error_t
ccs_configuration_space_get_conditions(ccs_configuration_space_t configuration_space,
size_t num_expressions,
ccs_expression_t *expressions,
size_t *num_expressions_ret) {
if (!configuration_space || !configuration_space->data)
return -CCS_INVALID_OBJECT;
if (num_expressions && !expressions)
return -CCS_INVALID_VALUE;
if (!expressions && !num_expressions_ret)
return -CCS_INVALID_VALUE;
UT_array *array = configuration_space->data->hyperparameters;
size_t size = utarray_len(array);
if (expressions) {
if (num_expressions < size)
return -CCS_INVALID_VALUE;
_ccs_hyperparameter_wrapper_t *wrapper = NULL;
size_t index = 0;
while ( (wrapper = (_ccs_hyperparameter_wrapper_t *)utarray_next(array, wrapper)) )
expressions[index++] = wrapper->condition;
for (size_t i = size; i < num_expressions; i++)
expressions[i] = NULL;
}
if (num_expressions_ret)
*num_expressions_ret = size;
return CCS_SUCCESS;
}
#undef utarray_oom
#define utarray_oom() { \
return -CCS_ENOMEM; \
}
ccs_error_t
ccs_configuration_space_add_forbidden_clause(ccs_configuration_space_t configuration_space,
ccs_expression_t expression) {
if (!configuration_space || !configuration_space->data)
return -CCS_INVALID_OBJECT;
ccs_error_t err = ccs_retain_object(expression);
if (err)
return err;
utarray_push_back(configuration_space->data->forbidden_clauses, &expression);
return CCS_SUCCESS;
}
ccs_error_t
ccs_configuration_space_add_forbidden_clauses(ccs_configuration_space_t configuration_space,
size_t num_expressions,
ccs_expression_t *expressions) {
if (!configuration_space || !configuration_space->data)
return -CCS_INVALID_OBJECT;
if (num_expressions && !expressions)
return -CCS_INVALID_VALUE;
for (size_t i = 0; i < num_expressions; i++) {
ccs_error_t err = ccs_retain_object(expressions[i]);
if (err)
return err;
utarray_push_back(configuration_space->data->forbidden_clauses, expressions + i);
}
return CCS_SUCCESS;
}
#undef utarray_oom
#define utarray_oom() exit(-1)
ccs_error_t
ccs_configuration_space_get_forbidden_clause(ccs_configuration_space_t configuration_space,
size_t index,
ccs_expression_t *expression_ret) {
if (!configuration_space || !configuration_space->data)
return -CCS_INVALID_OBJECT;
if (!expression_ret)
return -CCS_INVALID_VALUE;
ccs_expression_t *p_expr = (ccs_expression_t*)
utarray_eltptr(configuration_space->data->forbidden_clauses,
(unsigned int)index);
if (!p_expr)
return -CCS_OUT_OF_BOUNDS;
*expression_ret = *p_expr;
return CCS_SUCCESS;
}
ccs_error_t
ccs_configuration_space_get_forbidden_clauses(ccs_configuration_space_t configuration_space,
size_t num_expressions,
ccs_expression_t *expressions,
size_t *num_expressions_ret) {
if (!configuration_space || !configuration_space->data)
return -CCS_INVALID_OBJECT;
if (num_expressions && !expressions)
return -CCS_INVALID_VALUE;
if (!expressions && !num_expressions_ret)
return -CCS_INVALID_VALUE;
UT_array *array = configuration_space->data->forbidden_clauses;
size_t size = utarray_len(array);
if (expressions) {
if (num_expressions < size)
return -CCS_INVALID_VALUE;
ccs_expression_t *p_expr = NULL;
size_t index = 0;
while ( (p_expr = (ccs_expression_t *)utarray_next(array, p_expr)) )
expressions[index++] = *p_expr;
for (size_t i = size; i < num_expressions; i++)
expressions[i] = NULL;
}
if (num_expressions_ret)
*num_expressions_ret = size;
return CCS_SUCCESS;
}
......@@ -16,6 +16,7 @@ struct _ccs_hyperparameter_wrapper_s {
UT_hash_handle hh_handle;
size_t distribution_index;
_ccs_distribution_wrapper_t *distribution;
ccs_expression_t condition;
};
typedef struct _ccs_hyperparameter_wrapper_s _ccs_hyperparameter_wrapper_t;
......@@ -48,6 +49,7 @@ struct _ccs_configuration_space_data_s {
_ccs_hyperparameter_wrapper_t *name_hash;
_ccs_hyperparameter_wrapper_t *handle_hash;
_ccs_distribution_wrapper_t *distribution_list;
UT_array *forbidden_clauses;
};
#endif //_CONFIGURATION_SPACE_INTERNAL_H
......@@ -2,6 +2,7 @@
#include "expression_internal.h"
#include <math.h>
#include <string.h>
#include "utarray.h"
const int ccs_expression_precedence[] = {
0,
......@@ -1081,3 +1082,98 @@ ccs_expression_get_type(ccs_expression_t expression,
return CCS_SUCCESS;
}
#undef utarray_oom
#define utarray_oom() { \
return -CCS_ENOMEM; \
}
static ccs_error_t _get_hyperparameters(ccs_expression_t expression,
UT_array *array) {
if (!expression || !expression->data)
return CCS_INVALID_OBJECT;
ccs_error_t err;
for (size_t i = 0; i < expression->data->num_nodes; i++) {
ccs_datum_t *d = expression->data->nodes + i;
if (d->type == CCS_OBJECT) {
ccs_object_type_t t;
err = ccs_object_get_type(d->value.o, &t);
if (err)
return err;
if (t == CCS_HYPERPARAMETER)
utarray_push_back(array, &(d->value.o));
else if (t == CCS_EXPRESSION) {
err = _get_hyperparameters((ccs_expression_t)(d->value.o), array);
if (err)
return err;
}
}
}
return CCS_SUCCESS;
}
static const UT_icd _hyperparameter_icd = {
sizeof(ccs_hyperparameter_t),
NULL,
NULL,
NULL,
};
static int hyper_sort(const void *a, const void *b) {
ccs_hyperparameter_t ha = *(ccs_hyperparameter_t *)a;
ccs_hyperparameter_t hb = *(ccs_hyperparameter_t *)b;
return ha < hb ? -1 : ha > hb ? 1 : 0;
}
ccs_error_t
ccs_expression_get_hyperparameters(ccs_expression_t expression,
size_t num_hyperparameters,
ccs_hyperparameter_t *hyperparameters,
size_t *num_hyperparameters_ret) {
if (num_hyperparameters && !hyperparameters)
return -CCS_INVALID_VALUE;
if (!hyperparameters && !num_hyperparameters_ret)
return -CCS_INVALID_VALUE;
ccs_error_t err;
UT_array *array;
utarray_new(array, &_hyperparameter_icd);
err = _get_hyperparameters(expression, array);
if (err) {
utarray_free(array);
return err;
}
utarray_sort(array, &hyper_sort);
size_t count = 0;
if (utarray_len(array) > 0) {
ccs_hyperparameter_t previous = NULL;
ccs_hyperparameter_t *p_h = NULL;
while ( (p_h = (ccs_hyperparameter_t *)utarray_next(array, p_h)) ) {
if (*p_h != previous) {
count += 1;
previous = *p_h;
}
}
} else
count = 0;
if (hyperparameters) {
if (count > num_hyperparameters) {
utarray_free(array);
return -CCS_INVALID_VALUE;
}
ccs_hyperparameter_t previous = NULL;
ccs_hyperparameter_t *p_h = NULL;
size_t index = 0;
while ( (p_h = (ccs_hyperparameter_t *)utarray_next(array, p_h)) ) {
if (*p_h != previous) {
hyperparameters[index++] = *p_h;
previous = *p_h;
}
}
for (size_t i = count; i < num_hyperparameters; i++)
hyperparameters[i] = NULL;
}
if (num_hyperparameters_ret)
*num_hyperparameters_ret = count;
utarray_free(array);
return CCS_SUCCESS;
}
......@@ -738,6 +738,66 @@ test_compound() {
assert( err == CCS_SUCCESS );
}
void test_get_hyperparameters() {
ccs_expression_t expression1, expression2;
ccs_hyperparameter_t hyperparameter1, hyperparameter2;
ccs_hyperparameter_t hyperparameters[3];
ccs_error_t err;
size_t count;
hyperparameter1 = create_dummy_categorical("param1");
hyperparameter2 = create_dummy_numerical("param2");
err = ccs_create_binary_expression(CCS_ADD, ccs_float(3.0), ccs_object(hyperparameter2), &expression1);
assert( err == CCS_SUCCESS );
err = ccs_expression_get_hyperparameters(expression1, 0, NULL, &count);
assert( err == CCS_SUCCESS );
assert( count == 1 );
err = ccs_expression_get_hyperparameters(expression1, 3, hyperparameters, &count);
assert( err == CCS_SUCCESS );
assert( count == 1 );
assert( hyperparameters[0] == hyperparameter2 );
assert( hyperparameters[1] == NULL );
assert( hyperparameters[2] == NULL );
err = ccs_create_binary_expression(CCS_EQUAL,
ccs_object(hyperparameter1), ccs_object(expression1), &expression2);
assert( err == CCS_SUCCESS );
err = ccs_expression_get_hyperparameters(expression2, 0, NULL, &count);
assert( err == CCS_SUCCESS );
assert( count == 2 );
err = ccs_expression_get_hyperparameters(expression2, 3, hyperparameters, NULL);
assert( err == CCS_SUCCESS );
assert( hyperparameters[0] != hyperparameters[1] );
assert( hyperparameters[0] == hyperparameter1 || hyperparameters[0] == hyperparameter2 );
assert( hyperparameters[1] == hyperparameter1 || hyperparameters[1] == hyperparameter2 );
assert( hyperparameters[2] == NULL );
err = ccs_release_object(expression2);
assert( err == CCS_SUCCESS );
err = ccs_create_binary_expression(CCS_EQUAL,
ccs_object(hyperparameter2), ccs_object(expression1), &expression2);
assert( err == CCS_SUCCESS );
err = ccs_expression_get_hyperparameters(expression2, 3, hyperparameters, &count);
assert( err == CCS_SUCCESS );
assert( count == 1 );
assert( hyperparameters[0] == hyperparameter2 );
assert( hyperparameters[1] == NULL );
assert( hyperparameters[2] == NULL );
err = ccs_release_object(hyperparameter1);
assert( err == CCS_SUCCESS );
err = ccs_release_object(hyperparameter2);
assert( err == CCS_SUCCESS );
err = ccs_release_object(expression1);
assert( err == CCS_SUCCESS );
err = ccs_release_object(expression2);
assert( err == CCS_SUCCESS );
}
int main(int argc, char *argv[]) {
ccs_init();
test_equal_literal();
......@@ -760,4 +820,5 @@ int main(int argc, char *argv[]) {
test_arithmetic_greater_or_equal();
test_compound();
test_in();
test_get_hyperparameters();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment