Commit a692f4b3 authored by Brice Videau's avatar Brice Videau
Browse files

Working equal expression.

parent 6ae4708d
......@@ -63,13 +63,13 @@ ccs_create_unary_expression(ccs_expression_type_t expression_type,
ccs_expression_t *expression_ret);
extern ccs_error_t
ccs_create_epression(ccs_expression_type_t expression_type,
size_t num_nodes,
ccs_datum_t *nodes,
ccs_expression_t *expression_ret);
ccs_create_expression(ccs_expression_type_t expression_type,
size_t num_nodes,
ccs_datum_t *nodes,
ccs_expression_t *expression_ret);
extern ccs_error_t
ccs_eval_expression(ccs_expression_t expression,
ccs_expression_eval(ccs_expression_t expression,
ccs_configuration_space_t context,
ccs_datum_t *values,
ccs_datum_t *result);
......
......@@ -58,7 +58,8 @@ static inline ccs_error_t
_ccs_expr_datum_eval(ccs_datum_t *d,
ccs_configuration_space_t context,
ccs_datum_t *values,
ccs_datum_t *result) {
ccs_datum_t *result,
ccs_hyperparameter_type_t *ht) {
ccs_object_type_t t;
size_t index;
ccs_error_t err;
......@@ -76,7 +77,7 @@ _ccs_expr_datum_eval(ccs_datum_t *d,
return err;
switch (t) {
case CCS_EXPRESSION:
return ccs_eval_expression((ccs_expression_t)(d->value.o),
return ccs_expression_eval((ccs_expression_t)(d->value.o),
context, values, result);
break;
case CCS_HYPERPARAMETER:
......@@ -85,6 +86,12 @@ _ccs_expr_datum_eval(ccs_datum_t *d,
if (err)
return err;
*result = values[index];
if (ht) {
err = ccs_hyperparameter_get_type(
(ccs_hyperparameter_t)(d->value.o), ht);
if (err)
return err;
}
break;
default:
return CCS_INVALID_OBJECT;
......@@ -96,15 +103,15 @@ _ccs_expr_datum_eval(ccs_datum_t *d,
return CCS_SUCCESS;
}
#define eval_left_right(data, context, values, left, right) { \
#define eval_left_right(data, context, values, left, right, htl, htr) do { \
ccs_error_t err; \
err = _ccs_expr_datum_eval(data->nodes, context, values, &left); \
err = _ccs_expr_datum_eval(data->nodes, context, values, &left, htl); \
if (err) \
return err; \
err = _ccs_expr_datum_eval(data->nodes + 1, context, values, &right); \
err = _ccs_expr_datum_eval(data->nodes + 1, context, values, &right, htr); \
if (err) \
return err; \
}
} while (0)
static ccs_error_t
_ccs_expr_or_eval(_ccs_expression_data_t *data,
......@@ -113,7 +120,7 @@ _ccs_expr_or_eval(_ccs_expression_data_t *data,
ccs_datum_t *result) {
ccs_datum_t left;
ccs_datum_t right;
eval_left_right(data, context, values, left, right);
eval_left_right(data, context, values, left, right, NULL, NULL);
if (left.type != CCS_BOOLEAN || right.type != CCS_BOOLEAN)
return -CCS_INVALID_VALUE;
result->type = CCS_BOOLEAN;
......@@ -133,7 +140,7 @@ _ccs_expr_and_eval(_ccs_expression_data_t *data,
ccs_datum_t *result) {
ccs_datum_t left;
ccs_datum_t right;
eval_left_right(data, context, values, left, right);
eval_left_right(data, context, values, left, right, NULL, NULL);
if (left.type != CCS_BOOLEAN || right.type != CCS_BOOLEAN)
return -CCS_INVALID_VALUE;
result->type = CCS_BOOLEAN;
......@@ -146,6 +153,94 @@ static _ccs_expression_ops_t _ccs_expr_and_ops = {
&_ccs_expr_and_eval
};
#define check_values(param, v) do { \
ccs_bool_t valid; \
ccs_error_t err; \
err = ccs_hyperparameter_check_value( \
(ccs_hyperparameter_t)(param), v, &valid); \
if (unlikely(err)) \
return err; \
if (!valid) \
return -CCS_INVALID_VALUE; \
} while(0)
#define check_hypers(param, v, t) do { \
if (t == CCS_ORDINAL || t == CCS_CATEGORICAL) { \
check_values(param.value.o, v); \
} else if (t == CCS_NUMERICAL) {\
if (v.type != CCS_INTEGER && v.type != CCS_FLOAT) \
return -CCS_INVALID_VALUE; \
} \
} while(0)
static inline ccs_int_t
_ccs_string_cmp(const char *a, const char *b) {
if (a == b)
return 0;
if (!a)
return -1;
if (!b)
return 1;
return strcmp(a, b);
}
static inline ccs_error_t
_ccs_datum_test_equal_generic(ccs_datum_t *a, ccs_datum_t *b, ccs_bool_t *equal) {
if (a->type == b->type) {
switch(a->type) {
case CCS_STRING:
*equal = _ccs_string_cmp(a->value.s, b->value.s) == 0 ? CCS_TRUE : CCS_FALSE;
break;
case CCS_NONE:
*equal = CCS_TRUE;
break;
default:
*equal = memcmp(&(a->value), &(b->value), sizeof(ccs_value_t)) == 0 ? CCS_TRUE : CCS_FALSE;
}
} else {
if (a->type == CCS_INTEGER && b->type == CCS_FLOAT) {
*equal = (a->value.i == b->value.f) ? CCS_TRUE : CCS_FALSE;
} else if (a->type == CCS_FLOAT && b->type == CCS_INTEGER) {
*equal = (a->value.f == b->value.i) ? CCS_TRUE : CCS_FALSE;
} else {
*equal = CCS_FALSE;
return -CCS_INVALID_VALUE;
}
}
return CCS_SUCCESS;
}
static ccs_error_t
_ccs_expr_equal_eval(_ccs_expression_data_t *data,
ccs_configuration_space_t context,
ccs_datum_t *values,
ccs_datum_t *result) {
ccs_datum_t left;
ccs_datum_t right;
ccs_hyperparameter_type_t htl = CCS_HYPERPARAMETER_TYPE_MAX;
ccs_hyperparameter_type_t htr = CCS_HYPERPARAMETER_TYPE_MAX;
eval_left_right(data, context, values, left, right, &htl, &htr);
check_hypers(data->nodes[0], right, htl);
check_hypers(data->nodes[1], left, htr);
ccs_bool_t equal;
ccs_error_t err = _ccs_datum_test_equal_generic(&left, &right, &equal);
if(htl != CCS_HYPERPARAMETER_TYPE_MAX || htr != CCS_HYPERPARAMETER_TYPE_MAX) {
result->value.i = equal;
} else {
if (err)
return err;
result->value.i = equal;
}
result->type = CCS_BOOLEAN;
return CCS_SUCCESS;
}
static _ccs_expression_ops_t _ccs_expr_equal_ops = {
{ &_ccs_expression_del },
&_ccs_expr_equal_eval
};
static inline _ccs_expression_ops_t *
_ccs_expression_ops_broker(ccs_expression_type_t expression_type) {
switch (expression_type) {
......@@ -155,16 +250,19 @@ _ccs_expression_ops_broker(ccs_expression_type_t expression_type) {
case CCS_AND:
return &_ccs_expr_and_ops;
break;
case CCS_EQUAL:
return &_ccs_expr_equal_ops;
break;
default:
return NULL;
}
}
ccs_error_t
ccs_create_epression(ccs_expression_type_t expression_type,
size_t num_nodes,
ccs_datum_t *nodes,
ccs_expression_t *expression_ret) {
ccs_create_expression(ccs_expression_type_t expression_type,
size_t num_nodes,
ccs_datum_t *nodes,
ccs_expression_t *expression_ret) {
if (num_nodes && !nodes)
return -CCS_INVALID_VALUE;
if (!expression_ret)
......@@ -229,7 +327,7 @@ ccs_create_epression(ccs_expression_type_t expression_type,
}
ccs_error_t
ccs_eval_expression(ccs_expression_t expression,
ccs_expression_eval(ccs_expression_t expression,
ccs_configuration_space_t context,
ccs_datum_t *values,
ccs_datum_t *result) {
......@@ -240,5 +338,3 @@ ccs_eval_expression(ccs_expression_t expression,
_ccs_expression_ops_t *ops = ccs_expression_get_ops(expression);
return ops->eval(expression->data, context, values, result);
}
......@@ -480,7 +480,7 @@ ccs_configuration_space_samples(ccs_configuration_space_t configuration_space,
ccs_error_t err;
UT_array *array = configuration_space->data->hyperparameters;
size_t num_hyper = utarray_len(array);
ccs_datum_t *values = (ccs_datum_t *)malloc(sizeof(ccs_datum_t)*num_configurations*num_hyper);
ccs_datum_t *values = (ccs_datum_t *)calloc(1, sizeof(ccs_datum_t)*num_configurations*num_hyper);
ccs_datum_t *p_values = values;
ccs_rng_t rng = configuration_space->data->rng;
_ccs_hyperparameter_wrapper_t *wrapper = NULL;
......
......@@ -12,7 +12,8 @@ RNG_TESTS = \
test_numerical_hyperparameter \
test_categorical_hyperparameter \
test_ordinal_hyperparameter \
test_configuration_space
test_configuration_space \
test_expression
# unit tests
UNIT_TESTS = \
......
#include <stdlib.h>
#include <assert.h>
#include <cconfigspace.h>
#include <string.h>
static inline ccs_datum_t
ccs_bool(ccs_bool_t v) {
ccs_datum_t d;
d.type = CCS_BOOLEAN;
d.value.i = v;
return d;
}
static inline ccs_datum_t
ccs_float(ccs_float_t v) {
ccs_datum_t d;
d.type = CCS_FLOAT;
d.value.f = v;
return d;
}
static inline ccs_datum_t
ccs_int(ccs_int_t v) {
ccs_datum_t d;
d.type = CCS_INTEGER;
d.value.i = v;
return d;
}
static inline ccs_datum_t
ccs_object(ccs_object_t v) {
ccs_datum_t d;
d.type = CCS_OBJECT;
d.value.o = v;
return d;
}
static inline ccs_datum_t
ccs_string(const char *v) {
ccs_datum_t d;
d.type = CCS_STRING;
d.value.s = v;
return d;
}
double d = -2.0;
ccs_hyperparameter_t create_dummy_numerical(const char * name) {
ccs_hyperparameter_t hyperparameter;
ccs_error_t err;
err = ccs_create_numerical_hyperparameter(name, CCS_NUM_FLOAT,
CCSF(-5.0), CCSF(5.0),
CCSF(0.0), CCSF(d),
NULL, &hyperparameter);
d += 1.0;
if (d >= 5.0)
d = -5.0;
assert( err == CCS_SUCCESS );
return hyperparameter;
}
ccs_hyperparameter_t create_dummy_categorical(const char * name) {
ccs_datum_t possible_values[4];
ccs_hyperparameter_t hyperparameter;
ccs_error_t err;
possible_values[0] = ccs_int(1);
possible_values[1] = ccs_float(2.0);
possible_values[2] = ccs_string("toto");
possible_values[3] = ccs_none;
err = ccs_create_categorical_hyperparameter(name, 4, possible_values, 0,
NULL, &hyperparameter);
assert( err == CCS_SUCCESS );
return hyperparameter;
}
ccs_hyperparameter_t create_dummy_ordinal(const char * name) {
ccs_datum_t possible_values[4];
ccs_hyperparameter_t hyperparameter;
ccs_error_t err;
possible_values[0] = ccs_int(1);
possible_values[1] = ccs_float(2.0);
possible_values[2] = ccs_string("toto");
possible_values[3] = ccs_none;
err = ccs_create_ordinal_hyperparameter(name, 4, possible_values, 0,
NULL, &hyperparameter);
assert( err == CCS_SUCCESS );
return hyperparameter;
}
void test_expression_wrapper(ccs_expression_type_t type,
size_t count,
ccs_datum_t *nodes,
ccs_configuration_space_t context,
ccs_datum_t *inputs,
ccs_datum_t eres,
ccs_error_t eerr) {
ccs_error_t err;
ccs_expression_t expression;
ccs_datum_t result;
err = ccs_create_expression(type, count, nodes, &expression);
assert( err == CCS_SUCCESS );
err = ccs_expression_eval(expression, context, inputs, &result);
assert( err == eerr );
if (eerr != CCS_SUCCESS) {
err = ccs_release_object(expression);
assert( err == CCS_SUCCESS );
return;
}
assert( result.type == eres.type );
switch (result.type) {
case CCS_INTEGER:
case CCS_BOOLEAN:
assert( result.value.i == eres.value.i );
break;
case CCS_FLOAT:
assert( result.value.f == eres.value.f );
break;
default:
assert( 0 );
}
err = ccs_release_object(expression);
assert( err == CCS_SUCCESS );
}
void test_equal_literal() {
ccs_datum_t nodes[2];
nodes[0] = ccs_float(1.0);
nodes[1] = ccs_float(1.0);
test_expression_wrapper(CCS_EQUAL, 2, nodes, NULL, NULL, ccs_bool(CCS_TRUE), CCS_SUCCESS);
nodes[0] = ccs_float(0.0);
nodes[1] = ccs_float(1.0);
test_expression_wrapper(CCS_EQUAL, 2, nodes, NULL, NULL, ccs_bool(CCS_FALSE), CCS_SUCCESS);
nodes[0] = ccs_int(1);
nodes[1] = ccs_float(1.0);
test_expression_wrapper(CCS_EQUAL, 2, nodes, NULL, NULL, ccs_bool(CCS_TRUE), CCS_SUCCESS);
nodes[0] = ccs_float(0.0);
nodes[1] = ccs_int(1);
test_expression_wrapper(CCS_EQUAL, 2, nodes, NULL, NULL, ccs_bool(CCS_FALSE), CCS_SUCCESS);
nodes[0] = ccs_none;
nodes[1] = ccs_none;
test_expression_wrapper(CCS_EQUAL, 2, nodes, NULL, NULL, ccs_bool(CCS_TRUE), CCS_SUCCESS);
nodes[0] = ccs_none;
nodes[1] = ccs_int(1);
test_expression_wrapper(CCS_EQUAL, 2, nodes, NULL, NULL, ccs_bool(CCS_FALSE), -CCS_INVALID_VALUE);
nodes[0] = ccs_string("toto");
nodes[1] = ccs_string("toto");
test_expression_wrapper(CCS_EQUAL, 2, nodes, NULL, NULL, ccs_bool(CCS_TRUE), CCS_SUCCESS);
nodes[0] = ccs_string("tata");
nodes[1] = ccs_string("toto");
test_expression_wrapper(CCS_EQUAL, 2, nodes, NULL, NULL, ccs_bool(CCS_FALSE),CCS_SUCCESS);
nodes[0] = ccs_string("tata");
nodes[1] = ccs_int(1);
test_expression_wrapper(CCS_EQUAL, 2, nodes, NULL, NULL, ccs_bool(CCS_FALSE), -CCS_INVALID_VALUE);
}
void test_equal_numerical() {
ccs_configuration_space_t configuration_space;
ccs_hyperparameter_t hyperparameters[2];
ccs_datum_t nodes[2];
ccs_datum_t values[2];
ccs_error_t err;
err = ccs_create_configuration_space("my_config_space", NULL,
&configuration_space);
assert( err == CCS_SUCCESS );
hyperparameters[0] = create_dummy_numerical("param1");
hyperparameters[1] = create_dummy_numerical("param2");
err = ccs_configuration_space_add_hyperparameters(configuration_space, 2,
hyperparameters, NULL);
assert( err == CCS_SUCCESS );
nodes[0] = ccs_object(hyperparameters[0]);
nodes[1] = ccs_float(1.0);
values[0] = ccs_float(1.0);
values[1] = ccs_float(0.0);
test_expression_wrapper(CCS_EQUAL, 2, nodes, configuration_space, values, ccs_bool(CCS_TRUE), CCS_SUCCESS);
values[0] = ccs_float(0.0);
test_expression_wrapper(CCS_EQUAL, 2, nodes, configuration_space, values, ccs_bool(CCS_FALSE), CCS_SUCCESS);
nodes[0] = ccs_float(1.0);
nodes[1] = ccs_object(hyperparameters[1]);
values[0] = ccs_float(1.0);
values[1] = ccs_float(1.0);
test_expression_wrapper(CCS_EQUAL, 2, nodes, configuration_space, values, ccs_bool(CCS_TRUE), CCS_SUCCESS);
values[1] = ccs_float(0.0);
test_expression_wrapper(CCS_EQUAL, 2, nodes, configuration_space, values, ccs_bool(CCS_FALSE), CCS_SUCCESS);
nodes[0] = ccs_int(0);
test_expression_wrapper(CCS_EQUAL, 2, nodes, configuration_space, values, ccs_bool(CCS_TRUE), CCS_SUCCESS);
nodes[0] = ccs_bool(CCS_FALSE);
test_expression_wrapper(CCS_EQUAL, 2, nodes, configuration_space, values, ccs_bool(CCS_TRUE), -CCS_INVALID_VALUE);
for (size_t i = 0; i < 2; i++) {
err = ccs_release_object(hyperparameters[i]);
assert( err == CCS_SUCCESS );
}
err = ccs_release_object(configuration_space);
assert( err == CCS_SUCCESS );
}
void test_equal_categorical() {
ccs_configuration_space_t configuration_space;
ccs_hyperparameter_t hyperparameters[2];
ccs_datum_t nodes[2];
ccs_datum_t values[2];
ccs_error_t err;
err = ccs_create_configuration_space("my_config_space", NULL,
&configuration_space);
assert( err == CCS_SUCCESS );
hyperparameters[0] = create_dummy_categorical("param1");
hyperparameters[1] = create_dummy_categorical("param2");
err = ccs_configuration_space_add_hyperparameters(configuration_space, 2,
hyperparameters, NULL);
assert( err == CCS_SUCCESS );
nodes[0] = ccs_object(hyperparameters[0]);
nodes[1] = ccs_float(2.0);
values[0] = ccs_float(2.0);
values[1] = ccs_int(1);
test_expression_wrapper(CCS_EQUAL, 2, nodes, configuration_space, values, ccs_bool(CCS_TRUE), CCS_SUCCESS);
// Values tested must exist in the set
nodes[1] = ccs_float(3.0);
test_expression_wrapper(CCS_EQUAL, 2, nodes, configuration_space, values, ccs_bool(CCS_TRUE), -CCS_INVALID_VALUE);
nodes[1] = ccs_int(1);
test_expression_wrapper(CCS_EQUAL, 2, nodes, configuration_space, values, ccs_bool(CCS_FALSE), CCS_SUCCESS);
for (size_t i = 0; i < 2; i++) {
err = ccs_release_object(hyperparameters[i]);
assert( err == CCS_SUCCESS );
}
err = ccs_release_object(configuration_space);
assert( err == CCS_SUCCESS );
}
void test_equal_ordinal() {
ccs_configuration_space_t configuration_space;
ccs_hyperparameter_t hyperparameters[2];
ccs_datum_t nodes[2];
ccs_datum_t values[2];
ccs_error_t err;
err = ccs_create_configuration_space("my_config_space", NULL,
&configuration_space);
assert( err == CCS_SUCCESS );
hyperparameters[0] = create_dummy_ordinal("param1");
hyperparameters[1] = create_dummy_ordinal("param2");
err = ccs_configuration_space_add_hyperparameters(configuration_space, 2,
hyperparameters, NULL);
assert( err == CCS_SUCCESS );
nodes[0] = ccs_object(hyperparameters[0]);
nodes[1] = ccs_float(2.0);
values[0] = ccs_float(2.0);
values[1] = ccs_int(1);
test_expression_wrapper(CCS_EQUAL, 2, nodes, configuration_space, values, ccs_bool(CCS_TRUE), CCS_SUCCESS);
// Values tested must exist in the set
nodes[1] = ccs_float(3.0);
test_expression_wrapper(CCS_EQUAL, 2, nodes, configuration_space, values, ccs_bool(CCS_TRUE), -CCS_INVALID_VALUE);
nodes[1] = ccs_int(1);
test_expression_wrapper(CCS_EQUAL, 2, nodes, configuration_space, values, ccs_bool(CCS_FALSE), CCS_SUCCESS);
for (size_t i = 0; i < 2; i++) {
err = ccs_release_object(hyperparameters[i]);
assert( err == CCS_SUCCESS );
}
err = ccs_release_object(configuration_space);
assert( err == CCS_SUCCESS );
}
int main(int argc, char *argv[]) {
ccs_init();
test_equal_literal();
test_equal_numerical();
test_equal_categorical();
test_equal_ordinal();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment