Commit c1b77455 authored by Brice Videau's avatar Brice Videau
Browse files

Refactored expressions.

parent 83ae78ea
......@@ -72,11 +72,11 @@ typedef enum ccs_object_type_e ccs_object_type_t;
enum ccs_data_type_e {
CCS_NONE,
CCS_INACTIVE,
CCS_INTEGER,
CCS_FLOAT,
CCS_BOOLEAN,
CCS_STRING,
CCS_INACTIVE,
CCS_OBJECT,
CCS_DATA_TYPE_MAX,
CCS_DATA_TYPE_FORCE_64BIT = INT64_MAX
......
......@@ -24,6 +24,8 @@ enum ccs_expression_type_e {
CCS_NEGATIVE,
CCS_NOT,
CCS_LIST,
CCS_LITERAL,
CCS_VARIABLE,
CCS_EXPRESSION_TYPE_MAX,
CCS_EXPRESSION_FORCE_32BIT = INT_MAX
};
......@@ -41,7 +43,7 @@ typedef enum ccs_expression_type_e ccs_expression_type_t;
// 6 : MULTIPLY, DIVIDE, MODULO
// 7 : POSITIVE, NEGATIVE, NOT
// max - 1: LIST
// max : LITERAL, VARIABLE, HYPERPARAMETER
// max : LITERAL, VARIABLE
// One for each expression type:
extern const int ccs_expression_precedence[];
......@@ -68,6 +70,14 @@ ccs_create_expression(ccs_expression_type_t type,
ccs_datum_t *nodes,
ccs_expression_t *expression_ret);
extern ccs_error_t
ccs_create_literal(ccs_datum_t value,
ccs_expression_t *expression_ret);
extern ccs_error_t
ccs_create_variable(ccs_hyperparameter_t hyperparameter,
ccs_expression_t *expression_ret);
extern ccs_error_t
ccs_expression_get_type(ccs_expression_t expression,
ccs_expression_type_t *type_ret);
......@@ -79,9 +89,17 @@ ccs_expression_get_num_nodes(ccs_expression_t expression,
extern ccs_error_t
ccs_expression_get_nodes(ccs_expression_t expression,
size_t num_nodes,
ccs_datum_t *nodes,
ccs_expression_t *nodes,
size_t *num_nodes_ret);
extern ccs_error_t
ccs_literal_get_value(ccs_expression_t expression,
ccs_datum_t *value_ret);
extern ccs_error_t
ccs_variable_get_hyperparameter(ccs_expression_t expression,
ccs_hyperparameter_t *hyperparameter_ret);
extern ccs_error_t
ccs_expression_eval(ccs_expression_t expression,
ccs_context_t context,
......
......@@ -33,9 +33,14 @@ ccs_objective_space_add_variable(ccs_objective_space_t objective_space,
extern ccs_error_t
ccs_objective_space_add_objective(ccs_objective_space_t objective_space,
ccs_object_t objective,
ccs_expression_t objective,
ccs_objective_type_t type);
extern ccs_error_t
ccs_objective_space_get_objective(ccs_objective_space_t objective_space,
ccs_expression_t *objective_ret,
ccs_objective_type_t *type_ret);
#ifdef __cplusplus
}
#endif
......
......@@ -13,7 +13,8 @@ const int ccs_expression_precedence[] = {
5, 5,
6, 6, 6,
7, 7, 7,
8
8,
9, 9
};
const char *ccs_expression_symbols[] = {
......@@ -25,7 +26,8 @@ const char *ccs_expression_symbols[] = {
"+", "-",
"*", "/", "%",
"+", "-", "!",
NULL
NULL,
NULL, NULL
};
const int ccs_expression_arity[] = {
......@@ -37,7 +39,8 @@ const int ccs_expression_arity[] = {
2, 2,
2, 2, 2,
1, 1, 1,
-1
-1,
0, 0
};
static inline _ccs_expression_ops_t *
......@@ -50,76 +53,41 @@ _ccs_expression_del(ccs_object_t o) {
ccs_expression_t d = (ccs_expression_t)o;
_ccs_expression_data_t *data = d->data;
for (size_t i = 0; i < data->num_nodes; i++)
if (data->nodes[i].type == CCS_OBJECT)
ccs_release_object(data->nodes[i].value.o);
ccs_release_object(data->nodes[i]);
return CCS_SUCCESS;
}
static inline ccs_error_t
_ccs_expr_datum_eval(ccs_datum_t *d,
ccs_context_t context,
ccs_datum_t *values,
ccs_datum_t *result,
ccs_hyperparameter_type_t *ht) {
ccs_object_type_t t;
size_t index;
_ccs_expr_node_eval(ccs_expression_t n,
ccs_context_t context,
ccs_datum_t *values,
ccs_datum_t *result,
ccs_hyperparameter_type_t *ht) {
ccs_error_t err;
switch (d->type) {
case CCS_NONE:
case CCS_INTEGER:
case CCS_FLOAT:
case CCS_BOOLEAN:
case CCS_STRING:
*result = *d;
break;
case CCS_OBJECT:
err = ccs_object_get_type(d->value.o, &t);
if (ht && n->data->type == CCS_VARIABLE) {
_ccs_expression_variable_data_t *d =
(_ccs_expression_variable_data_t *)n->data;
err = ccs_hyperparameter_get_type(
(ccs_hyperparameter_t)(d->hyperparameter), ht);
if (err)
return err;
switch (t) {
case CCS_EXPRESSION:
return ccs_expression_eval((ccs_expression_t)(d->value.o),
context, values, result);
break;
case CCS_HYPERPARAMETER:
err = ccs_context_get_hyperparameter_index(
context, (ccs_hyperparameter_t)(d->value.o), &index);
if (err)
return err;
*result = values[index];
if (result->type == CCS_INACTIVE)
return -CCS_INACTIVE_HYPERPARAMETER;
if (ht) {
err = ccs_hyperparameter_get_type(
(ccs_hyperparameter_t)(d->value.o), ht);
if (err)
return err;
}
break;
default:
return -CCS_INVALID_OBJECT;
}
break;
default:
return -CCS_INVALID_VALUE;
}
return CCS_SUCCESS;
return ccs_expression_eval(n, context, values, result);
}
#define eval_node(data, context, values, node, ht) do { \
ccs_error_t err; \
err = _ccs_expr_datum_eval(data->nodes, context, values, &node, ht); \
err = _ccs_expr_node_eval(data->nodes[0], context, values, &node, ht); \
if (err) \
return err; \
} while(0)
#define eval_left_right(data, context, values, left, right, htl, htr) do { \
ccs_error_t err; \
err = _ccs_expr_datum_eval(data->nodes, context, values, &left, htl); \
err = _ccs_expr_node_eval(data->nodes[0], context, values, &left, htl); \
if (err) \
return err; \
err = _ccs_expr_datum_eval(data->nodes + 1, context, values, &right, htr); \
err = _ccs_expr_node_eval(data->nodes[1], context, values, &right, htr); \
if (err) \
return err; \
} while (0)
......@@ -167,8 +135,7 @@ static _ccs_expression_ops_t _ccs_expr_and_ops = {
#define check_values(param, v) do { \
ccs_bool_t valid; \
ccs_error_t err; \
err = ccs_hyperparameter_check_value( \
(ccs_hyperparameter_t)(param), v, &valid); \
err = ccs_hyperparameter_check_value(param, v, &valid); \
if (unlikely(err)) \
return err; \
if (!valid) \
......@@ -177,7 +144,9 @@ static _ccs_expression_ops_t _ccs_expr_and_ops = {
#define check_hypers(param, v, t) do { \
if (t == CCS_ORDINAL || t == CCS_CATEGORICAL) { \
check_values(param.value.o, v); \
_ccs_expression_variable_data_t *d = \
(_ccs_expression_variable_data_t *)param->data; \
check_values(d->hyperparameter, v); \
} else if (t == CCS_NUMERICAL) {\
if (v.type != CCS_INTEGER && v.type != CCS_FLOAT) \
return -CCS_INVALID_VALUE; \
......@@ -333,8 +302,10 @@ _ccs_expr_less_eval(_ccs_expression_data_t *data,
ccs_error_t err;
if (htl == CCS_ORDINAL) {
ccs_int_t cmp;
_ccs_expression_variable_data_t *d =
(_ccs_expression_variable_data_t *)data->nodes[0]->data;
err = ccs_ordinal_hyperparameter_compare_values(
(ccs_hyperparameter_t)(data->nodes[0].value.o),
d->hyperparameter,
left, right, &cmp);
if (err)
return err;
......@@ -344,8 +315,10 @@ _ccs_expr_less_eval(_ccs_expression_data_t *data,
}
if (htr == CCS_ORDINAL) {
ccs_int_t cmp;
_ccs_expression_variable_data_t *d =
(_ccs_expression_variable_data_t *)data->nodes[1]->data;
err = ccs_ordinal_hyperparameter_compare_values(
(ccs_hyperparameter_t)(data->nodes[1].value.o),
d->hyperparameter,
left, right, &cmp);
if (err)
return err;
......@@ -385,8 +358,10 @@ _ccs_expr_greater_eval(_ccs_expression_data_t *data,
ccs_error_t err;
if (htl == CCS_ORDINAL) {
ccs_int_t cmp;
_ccs_expression_variable_data_t *d =
(_ccs_expression_variable_data_t *)data->nodes[0]->data;
err = ccs_ordinal_hyperparameter_compare_values(
(ccs_hyperparameter_t)(data->nodes[0].value.o),
d->hyperparameter,
left, right, &cmp);
if (err)
return err;
......@@ -396,8 +371,10 @@ _ccs_expr_greater_eval(_ccs_expression_data_t *data,
}
if (htr == CCS_ORDINAL) {
ccs_int_t cmp;
_ccs_expression_variable_data_t *d =
(_ccs_expression_variable_data_t *)data->nodes[1]->data;
err = ccs_ordinal_hyperparameter_compare_values(
(ccs_hyperparameter_t)(data->nodes[1].value.o),
d->hyperparameter,
left, right, &cmp);
if (err)
return err;
......@@ -437,8 +414,10 @@ _ccs_expr_less_or_equal_eval(_ccs_expression_data_t *data,
ccs_error_t err;
if (htl == CCS_ORDINAL) {
ccs_int_t cmp;
_ccs_expression_variable_data_t *d =
(_ccs_expression_variable_data_t *)data->nodes[0]->data;
err = ccs_ordinal_hyperparameter_compare_values(
(ccs_hyperparameter_t)(data->nodes[0].value.o),
d->hyperparameter,
left, right, &cmp);
if (err)
return err;
......@@ -448,8 +427,10 @@ _ccs_expr_less_or_equal_eval(_ccs_expression_data_t *data,
}
if (htr == CCS_ORDINAL) {
ccs_int_t cmp;
_ccs_expression_variable_data_t *d =
(_ccs_expression_variable_data_t *)data->nodes[1]->data;
err = ccs_ordinal_hyperparameter_compare_values(
(ccs_hyperparameter_t)(data->nodes[1].value.o),
d->hyperparameter,
left, right, &cmp);
if (err)
return err;
......@@ -489,8 +470,10 @@ _ccs_expr_greater_or_equal_eval(_ccs_expression_data_t *data,
ccs_error_t err;
if (htl == CCS_ORDINAL) {
ccs_int_t cmp;
_ccs_expression_variable_data_t *d =
(_ccs_expression_variable_data_t *)data->nodes[0]->data;
err = ccs_ordinal_hyperparameter_compare_values(
(ccs_hyperparameter_t)(data->nodes[0].value.o),
d->hyperparameter,
left, right, &cmp);
if (err)
return err;
......@@ -500,8 +483,10 @@ _ccs_expr_greater_or_equal_eval(_ccs_expression_data_t *data,
}
if (htr == CCS_ORDINAL) {
ccs_int_t cmp;
_ccs_expression_variable_data_t *d =
(_ccs_expression_variable_data_t *)data->nodes[1]->data;
err = ccs_ordinal_hyperparameter_compare_values(
(ccs_hyperparameter_t)(data->nodes[1].value.o),
d->hyperparameter,
left, right, &cmp);
if (err)
return err;
......@@ -528,23 +513,15 @@ _ccs_expr_in_eval(_ccs_expression_data_t *data,
ccs_context_t context,
ccs_datum_t *values,
ccs_datum_t *result) {
if (data->nodes[1].type != CCS_OBJECT)
return -CCS_INVALID_VALUE;
ccs_object_type_t type;
ccs_error_t err;
err = ccs_object_get_type(data->nodes[1].value.o, &type);
if (err)
return err;
if (type != CCS_EXPRESSION)
return -CCS_INVALID_VALUE;
ccs_expression_type_t etype;
err = ccs_expression_get_type((ccs_expression_t)(data->nodes[1].value.o), &etype);
err = ccs_expression_get_type(data->nodes[1], &etype);
if (err)
return err;
if (etype != CCS_LIST)
return -CCS_INVALID_VALUE;
size_t num_nodes;
err = ccs_expression_get_num_nodes((ccs_expression_t)(data->nodes[1].value.o), &num_nodes);
err = ccs_expression_get_num_nodes(data->nodes[1], &num_nodes);
if (err)
return err;
if (num_nodes == 0) {
......@@ -558,7 +535,7 @@ _ccs_expr_in_eval(_ccs_expression_data_t *data,
eval_node(data, context, values, left, &htl);
for (size_t i = 0; i < num_nodes; i++) {
ccs_datum_t right;
err = ccs_expression_list_eval_node((ccs_expression_t)(data->nodes[1].value.o), context, values, i, &right);
err = ccs_expression_list_eval_node(data->nodes[1], context, values, i, &right);
if (err)
return err;
check_hypers(data->nodes[0], right, htl);
......@@ -871,9 +848,9 @@ static _ccs_expression_ops_t _ccs_expr_not_ops = {
static ccs_error_t
_ccs_expr_list_eval(_ccs_expression_data_t *data,
ccs_context_t context,
ccs_datum_t *values,
ccs_datum_t *result) {
ccs_context_t context,
ccs_datum_t *values,
ccs_datum_t *result) {
return -CCS_UNSUPPORTED_OPERATION;
}
......@@ -882,6 +859,53 @@ static _ccs_expression_ops_t _ccs_expr_list_ops = {
&_ccs_expr_list_eval
};
static ccs_error_t
_ccs_expr_literal_eval(_ccs_expression_data_t *data,
ccs_context_t context,
ccs_datum_t *values,
ccs_datum_t *result) {
_ccs_expression_literal_data_t *d =
(_ccs_expression_literal_data_t *)data;
*result = d->value;
return CCS_SUCCESS;
}
static _ccs_expression_ops_t _ccs_expr_literal_ops = {
{ &_ccs_expression_del },
&_ccs_expr_literal_eval
};
static ccs_error_t
_ccs_expr_variable_del(ccs_object_t o) {
_ccs_expression_variable_data_t *d =
(_ccs_expression_variable_data_t *)((ccs_expression_t)o)->data;
return ccs_release_object(d->hyperparameter);
}
static ccs_error_t
_ccs_expr_variable_eval(_ccs_expression_data_t *data,
ccs_context_t context,
ccs_datum_t *values,
ccs_datum_t *result) {
_ccs_expression_variable_data_t *d =
(_ccs_expression_variable_data_t *)data;
size_t index;
ccs_error_t err;
err = ccs_context_get_hyperparameter_index(context,
(ccs_hyperparameter_t)(d->hyperparameter), &index);
if (err)
return err;
*result = values[index];
if (result->type == CCS_INACTIVE)
return -CCS_INACTIVE_HYPERPARAMETER;
return CCS_SUCCESS;
}
static _ccs_expression_ops_t _ccs_expr_variable_ops = {
{ &_ccs_expr_variable_del },
&_ccs_expr_variable_eval
};
static inline _ccs_expression_ops_t *
_ccs_expression_ops_broker(ccs_expression_type_t expression_type) {
switch (expression_type) {
......@@ -939,16 +963,98 @@ _ccs_expression_ops_broker(ccs_expression_type_t expression_type) {
case CCS_LIST:
return &_ccs_expr_list_ops;
break;
case CCS_LITERAL:
return &_ccs_expr_literal_ops;
break;
case CCS_VARIABLE:
return &_ccs_expr_variable_ops;
break;
default:
return NULL;
}
}
ccs_error_t
ccs_create_literal(ccs_datum_t value,
ccs_expression_t *expression_ret) {
if (value.type < CCS_NONE || value.type > CCS_STRING)
return -CCS_INVALID_VALUE;
if (!expression_ret)
return -CCS_INVALID_VALUE;
size_t size_str = 0;
if (value.type == CCS_STRING && value.value.s) {
size_str = strlen(value.value.s) + 1;
}
uintptr_t mem = (uintptr_t)calloc(1,
sizeof(struct _ccs_expression_s) +
sizeof(struct _ccs_expression_literal_data_s) +
size_str);
if(!mem)
return -CCS_ENOMEM;
ccs_expression_t expression = (ccs_expression_t)mem;
_ccs_object_init(&(expression->obj), CCS_EXPRESSION,
(_ccs_object_ops_t*)_ccs_expression_ops_broker(CCS_LITERAL));
_ccs_expression_literal_data_t *expression_data =
(_ccs_expression_literal_data_t *)
(mem + sizeof(struct _ccs_expression_s));
expression_data->expr.type = CCS_LITERAL;
expression_data->expr.num_nodes = 0;
expression_data->expr.nodes = NULL;
if (size_str) {
char *str_pool = (char *)(mem +
sizeof(struct _ccs_expression_s) +
sizeof(struct _ccs_expression_literal_data_s));
expression_data->value.type = CCS_STRING;
expression_data->value.value.s = str_pool;
strcpy(str_pool, value.value.s);
} else {
expression_data->value = value;
}
expression->data = (_ccs_expression_data_t *)expression_data;
*expression_ret = expression;
return CCS_SUCCESS;
}
ccs_error_t
ccs_create_variable(ccs_hyperparameter_t hyperparameter,
ccs_expression_t *expression_ret) {
if (!hyperparameter)
return -CCS_INVALID_OBJECT;
if (!expression_ret)
return -CCS_INVALID_VALUE;
ccs_error_t err;
uintptr_t mem = (uintptr_t)calloc(1,
sizeof(struct _ccs_expression_s) +
sizeof(struct _ccs_expression_variable_data_s));
if (!mem)
return -CCS_ENOMEM;
err = ccs_retain_object(hyperparameter);
if (err) {
free((void *)mem);
return err;
}
ccs_expression_t expression = (ccs_expression_t)mem;
_ccs_object_init(&(expression->obj), CCS_EXPRESSION,
(_ccs_object_ops_t*)_ccs_expression_ops_broker(CCS_VARIABLE));
_ccs_expression_variable_data_t *expression_data =
(_ccs_expression_variable_data_t *)
(mem + sizeof(struct _ccs_expression_s));
expression_data->expr.type = CCS_VARIABLE;
expression_data->expr.num_nodes = 0;
expression_data->expr.nodes = NULL;
expression_data->hyperparameter = hyperparameter;
expression->data = (_ccs_expression_data_t *)expression_data;
*expression_ret = expression;
return CCS_SUCCESS;
}
ccs_error_t
ccs_create_expression(ccs_expression_type_t type,
size_t num_nodes,
ccs_datum_t *nodes,
ccs_expression_t *expression_ret) {
if (type < CCS_OR || type > CCS_LIST)
return -CCS_INVALID_VALUE;
if (num_nodes && !nodes)
return -CCS_INVALID_VALUE;
if (!expression_ret)
......@@ -956,33 +1062,26 @@ ccs_create_expression(ccs_expression_type_t type,
int arity = ccs_expression_arity[type];
if (arity >= 0 && num_nodes != (size_t)arity)
return -CCS_INVALID_VALUE;
ccs_error_t err;
size_t size_strs = 0;
for(size_t i = 0; i < num_nodes; i++)
if (nodes[i].type == CCS_STRING) {
if (!nodes[i].value.s)
for(size_t i = 0; i < num_nodes; i++){
if (nodes[i].type == CCS_OBJECT) {
ccs_object_type_t t;
err = ccs_object_get_type(nodes[i].value.o, &t);
if (err)
return err;
if (t != CCS_HYPERPARAMETER && t != CCS_EXPRESSION)
return -CCS_INVALID_VALUE;
size_strs += strlen(nodes[i].value.s) + 1;
}
} else if (nodes[i].type < CCS_NONE || nodes[i].type > CCS_STRING)
return -CCS_INVALID_VALUE;
}
uintptr_t mem = (uintptr_t)calloc(1,
sizeof(struct _ccs_expression_s) +
sizeof(struct _ccs_expression_data_s) +
num_nodes*sizeof(ccs_datum_t) +
size_strs);
num_nodes*sizeof(ccs_expression_t));
if (!mem)
return -CCS_ENOMEM;
for(size_t i = 0; i < num_nodes; i++)
if (nodes[i].type == CCS_OBJECT) {
err = ccs_retain_object(nodes[i].value.o);
if (err) {
free((void *)mem);
return err;
}
}
ccs_expression_t expression = (ccs_expression_t)mem;
_ccs_object_init(&(expression->obj), CCS_EXPRESSION,
(_ccs_object_ops_t*)_ccs_expression_ops_broker(type));
......@@ -990,26 +1089,42 @@ ccs_create_expression(ccs_expression_type_t type,
(_ccs_expression_data_t *)(mem + sizeof(struct _ccs_expression_s));
expression_data->type = type;
expression_data->num_nodes = num_nodes;
expression_data->nodes = (ccs_datum_t *)(mem +
expression_data->nodes = (ccs_expression_t *)(mem +
sizeof(struct _ccs_expression_s) +
sizeof(struct _ccs_expression_data_s));
char *str_pool = (char *)(mem +
sizeof(struct _ccs_expression_s) +
sizeof(struct _ccs_expression_data_s) +
num_nodes*sizeof(ccs_datum_t));
for (size_t i = 0; i < num_nodes; i++) {
if (nodes[i].type == CCS_STRING) {
expression_data->nodes[i].type = CCS_STRING;
expression_data->nodes[i].value.s = str_pool;
strcpy(str_pool, nodes[i].value.s);
str_pool += strlen(nodes[i].value.s) + 1;
if (nodes[i].type == CCS_OBJECT) {
ccs_object_type_t t;
ccs_object_get_type(nodes[i].value.o, &t);
if (t == CCS_EXPRESSION) {
err = ccs_retain_object(nodes[i].value.o);
if (err)
goto cleanup;
expression_data->nodes[i] =
(ccs_expression_t)nodes[i].value.o;
} else {
err = ccs_create_variable(
(ccs_hyperparameter_t)nodes[i].value.o,
expression_data->nodes + i);
if (err)
goto cleanup;
}
} else {
expression_data->nodes[i] = nodes[i];
err = ccs_create_literal(nodes[i], expression_data->nodes + i);
if (err)
goto cleanup;
}
}
expression->data = expression_data;
*expression_ret = expression;
return CCS_SUCCESS;
cleanup:
for (size_t i = 0; i < num_nodes; i++) {
if (expression_data->nodes[i])
ccs_release_object(expression_data->nodes[i]);
}
free((void*)mem);
return err;
}
ccs_error_t
......@@ -1057,7 +1172,7 @@ ccs_expression_get_num_nodes(ccs_expression_t expression,
ccs_error_t
ccs_expression_get_nodes(ccs_expression_t expression,
size_t num_nodes,
ccs_datum_t *nodes,
ccs_expression_t *nodes,
size_t *num_nodes_ret) {
if (!expression || !expression->data)
return -CCS_INVALID_OBJECT;
......@@ -1067,13 +1182,13 @@ ccs_expression_get_nodes(ccs_expression_t expression,
return -CCS_INVALID_VALUE;
size_t count = expression->data->num_nodes;
if (nodes) {
ccs_datum_t *p_nodes =