Commit 926a846a authored by Brice Videau's avatar Brice Videau
Browse files

Added support for ordinal parameters, beware of bit float equality.

parent c8550874
......@@ -62,20 +62,22 @@ enum ccs_object_type_e {
typedef enum ccs_object_type_e ccs_object_type_t;
enum ccs_data_type_e {
CCS_NONE,
CCS_INTEGER,
CCS_FLOAT,
CCS_STRING,
CCS_OBJECT,
CCS_NONE,
CCS_DATA_TYPE_MAX,
CCS_DATA_TYPE_FORCE_32BIT = INT_MAX
CCS_DATA_TYPE_FORCE_64BIT = INT64_MAX
};
typedef enum ccs_data_type_e ccs_data_type_t;
enum ccs_numeric_type_e {
CCS_NUM_INTEGER = CCS_INTEGER,
CCS_NUM_FLOAT = CCS_FLOAT
CCS_NUM_FLOAT = CCS_FLOAT,
CCS_NUM_TYPE_MAX,
CCS_NUM_TYPE_FORCE_64BIT = INT64_MAX
};
typedef enum ccs_numeric_type_e ccs_numeric_type_t;
......@@ -100,14 +102,14 @@ typedef union ccs_numeric_u ccs_numeric_t;
#define CCSF(v) v
#define CCSI(v) v
#else
#define CCSF(v) ( (ccs_numeric_t){ .f = v })
#define CCSI(v) ( (ccs_numeric_t){ .i = v })
#define CCSF(v) ((ccs_numeric_t){ .f = v })
#define CCSI(v) ((ccs_numeric_t){ .i = v })
#endif
union ccs_value_u {
ccs_float_t f;
ccs_int_t i;
char *s;
const char *s;
ccs_object_t o;
#ifdef __cplusplus
ccs_value_u() : i(0L) {}
......@@ -139,6 +141,10 @@ struct ccs_datum_u {
typedef struct ccs_datum_u ccs_datum_t;
extern const ccs_datum_t ccs_none;
#define CCS_NONE_VAL {{0}, CCS_NONE}
extern ccs_error_t
ccs_init();
......
......@@ -37,11 +37,19 @@ ccs_create_categorical_hyperparameter(const char *name,
ccs_hyperparameter_t *hyperparameter_ret);
extern ccs_error_t
ccs_create_ordinal_hyperparameters(const char *name,
void *user_data,
size_t num_possible_values,
ccs_datum_t *possible_values,
ccs_hyperparameter_t *hyperparameter_ret);
ccs_create_ordinal_hyperparameter(const char *name,
size_t num_possible_values,
ccs_datum_t *possible_values,
size_t default_value_index,
ccs_distribution_t distribution,
void *user_data,
ccs_hyperparameter_t *hyperparameter_ret);
extern ccs_error_t
ccs_ordinal_hyperparameter_compare_values(ccs_hyperparameter_t hyperparameter,
ccs_datum_t value1,
ccs_datum_t value2,
ccs_int_t *comp_ret);
// Accessors
......
......@@ -18,6 +18,8 @@ libcconfigspace_la_SOURCES = \
hyperparameter.c \
hyperparameter_internal.h \
hyperparameter_numerical.c \
hyperparameter_categorical.c
hyperparameter_categorical.c \
uthash.h \
hyperparameter_ordinal.c
@VALGRIND_CHECK_RULES@
......@@ -2,6 +2,8 @@
#include <stdlib.h>
#include <gsl/gsl_rng.h>
const ccs_datum_t ccs_none = CCS_NONE_VAL;
ccs_error_t
ccs_init() {
gsl_rng_env_setup();
......
#include "cconfigspace_internal.h"
#include "hyperparameter_internal.h"
#include <string.h>
/* uthash configuration; these must be defined before uthash.h is included.
 * HASH_NONFATAL_OOM makes uthash report allocation failures through the
 * uthash_nonfatal_oom() hook (see the uthash user guide).
 * HASH_FUNCTION and HASH_KEYCMP route hashing and key comparison through
 * _hash_datum and _datum_cmp; the key pointer is treated as a ccs_datum_t *
 * and the len argument is ignored. */
#define HASH_NONFATAL_OOM 1
#define HASH_FUNCTION(s,len,hashv) (hashv) = _hash_datum((ccs_datum_t *)(s))
#define HASH_KEYCMP(a,b,len) (_datum_cmp((ccs_datum_t *)a, (ccs_datum_t *)b))
#include "uthash.h"
/* BEWARE: ccs_float_t are used as hash keys. In order to recall successfully,
* the *SAME* float must be used.
* Alternative is O(n) access for floating point values as they would all go
* in the same bucket. May be the wisest... Switch to find in the possible_values list?
* #define MAXULPDIFF 7 // To define
* // Could be doing type punning...
* static inline int _cmp_float(ccs_float_t a, ccs_float_t b) {
* int64_t t1, t2, cmp;
* memcpy(&t1, &a, sizeof(int64));
* memcpy(&t2, &b, sizeof(int64));
* if (a == b)
* return 0;
* if ((t1 < 0) != (t2 < 0)) {
* if (t1 < 0)
* return -1;
* if (t2 < 0)
* return 1;
* }
* cmp = labs(t1-t2);
* if (cmp <= MAXULPDIFF)
* return 0;
* else if (a < b)
* return -1;
* else
* return 1;
* }
*/
/* Hash a datum used as a uthash key.
 *  - a CCS_STRING with a non-NULL pointer is hashed over its characters
 *    (strlen bytes, terminator excluded);
 *  - CCS_NONE is hashed over the type tag only;
 *  - anything else, including a CCS_STRING holding a NULL pointer, is hashed
 *    over the raw bytes of the whole ccs_datum_t.
 */
static inline unsigned _hash_datum(ccs_datum_t *d) {
	unsigned result;
	if (d->type == CCS_STRING && d->value.s) {
		HASH_JEN(d->value.s, strlen(d->value.s), result);
	} else if (d->type == CCS_NONE) {
		HASH_JEN(&(d->type), sizeof(d->type), result);
	} else {
		HASH_JEN(d, sizeof(ccs_datum_t), result);
	}
	return result;
}
/* Three-way comparison of two datums, used by uthash as the key comparator
 * (0 means "same key").
 * Datums are ordered by type tag first. Within a type: CCS_NONE values are
 * always equal; strings compare by pointer identity, then NULL sorts first,
 * then strcmp; every other type compares the raw bytes of the value union.
 */
static inline int _datum_cmp(ccs_datum_t *a, ccs_datum_t *b) {
	if (a->type != b->type)
		return a->type < b->type ? -1 : 1;
	if (a->type == CCS_NONE)
		return 0;
	if (a->type != CCS_STRING)
		return memcmp(&(a->value), &(b->value), sizeof(ccs_value_t));

	const char *lhs = a->value.s;
	const char *rhs = b->value.s;
	if (lhs == rhs)
		return 0;
	if (!lhs)
		return -1;
	if (!rhs)
		return 1;
	return strcmp(lhs, rhs);
}
/* A stored possible value that is also a uthash node: the datum itself is the
 * hash key (see the HASH_ADD on field d in ccs_create_ordinal_hyperparameter). */
typedef struct _ccs_hash_datum_s {
	ccs_datum_t    d;  /* the value; used as the hash key */
	UT_hash_handle hh; /* uthash bookkeeping for this entry */
} _ccs_hash_datum_t;
/* Private state of an ordinal hyperparameter. */
typedef struct _ccs_hyperparameter_ordinal_data_s {
	/* type, name, user data, distribution, default value, interval and
	 * oversampling flag */
	_ccs_hyperparameter_common_data_t  common_data;
	/* number of entries in possible_values */
	ccs_int_t                          num_possible_values;
	/* the values, stored in the order given at creation */
	_ccs_hash_datum_t                 *possible_values;
	/* uthash head indexing the entries of possible_values by value */
	_ccs_hash_datum_t                 *hash;
} _ccs_hyperparameter_ordinal_data_t;
/* Destructor hook of an ordinal hyperparameter.
 * Unlinks every stored value from the hash table, then drops the reference
 * held on the distribution and returns the result of that release.
 * The object's own memory is not freed here.
 */
static ccs_error_t
_ccs_hyperparameter_ordinal_del(ccs_object_t o) {
	_ccs_hyperparameter_ordinal_data_t *ordinal =
		(_ccs_hyperparameter_ordinal_data_t *)(((ccs_hyperparameter_t)o)->data);
	ccs_int_t idx = 0;
	while (idx < ordinal->num_possible_values) {
		HASH_DELETE(hh, ordinal->hash, ordinal->possible_values + idx);
		idx++;
	}
	return ccs_release_object(ordinal->common_data.distribution);
}
/* Draw num_values samples for an ordinal hyperparameter.
 *
 * The distribution produces integer indices which are mapped to the stored
 * possible values. When the distribution may produce indices outside
 * [0, num_possible_values) (oversampling), out-of-range draws are discarded
 * and additional, progressively larger batches are drawn until values is full.
 *
 * data:       ordinal hyperparameter data
 * rng:        random number generator handed to the distribution
 * num_values: number of samples requested
 * values:     output array of num_values datums
 *
 * Returns CCS_SUCCESS, the error reported by ccs_distribution_samples,
 * -CCS_ENOMEM if a retry buffer cannot be allocated, or
 * -CCS_SAMPLING_UNSUCCESSFUL if values is still not full after the last
 * retry batch.
 */
static ccs_error_t
_ccs_hyperparameter_ordinal_samples(_ccs_hyperparameter_data_t *data,
                                    ccs_rng_t                   rng,
                                    size_t                      num_values,
                                    ccs_datum_t                *values) {
	_ccs_hyperparameter_ordinal_data_t *d =
		(_ccs_hyperparameter_ordinal_data_t *)data;
	ccs_error_t err;
	/* The raw indices are staged inside the output buffer itself, starting
	 * num_values ccs_int_t slots into it.
	 * NOTE(review): this relies on sizeof(ccs_datum_t) being at least
	 * 2 * sizeof(ccs_numeric_t), so that the indices fit and vs[i] is always
	 * read before the datum written at values[i] can overwrite it - confirm. */
	ccs_int_t *vs = (ccs_int_t *)values + num_values;
	err = ccs_distribution_samples(d->common_data.distribution,
	                               rng, num_values, (ccs_numeric_t *)vs);
	if (err)
		return err;

	if (!d->common_data.oversampling) {
		/* Every index is taken as valid: plain index -> value mapping. */
		for (size_t i = 0; i < num_values; i++)
			values[i] = d->possible_values[vs[i]].d;
		return CCS_SUCCESS;
	}

	/* Oversampling: keep only the in-range indices of the first batch. */
	size_t found = 0;
	for (size_t i = 0; i < num_values; i++)
		if (vs[i] >= 0 && vs[i] < d->num_possible_values)
			values[found++] = d->possible_values[vs[i]].d;

	/* Retry with batches of (missing * coeff) draws, coeff = 2, 4, ..., 32. */
	size_t coeff = 2;
	while (found < num_values) {
		/* Give up only if samples are still missing after the last batch;
		 * a final batch that fills the output is a success. */
		if (coeff > 32)
			return -CCS_SAMPLING_UNSUCCESSFUL;
		size_t buff_sz = (num_values - found)*coeff;
		vs = (ccs_int_t *)malloc(sizeof(ccs_numeric_t)*buff_sz);
		if (!vs)
			return -CCS_ENOMEM;
		err = ccs_distribution_samples(d->common_data.distribution,
		                               rng, buff_sz, (ccs_numeric_t *)vs);
		if (err) {
			/* Do not read an unfilled buffer; report the failure. */
			free(vs);
			return err;
		}
		for (size_t i = 0; i < buff_sz && found < num_values; i++)
			if (vs[i] >= 0 && vs[i] < d->num_possible_values)
				values[found++] = d->possible_values[vs[i]].d;
		coeff <<= 1;
		free(vs);
	}
	return CCS_SUCCESS;
}
/* Operation table of ordinal hyperparameters: the object-level destructor
 * followed by the sampling routine. It is attached to new objects by
 * ccs_create_ordinal_hyperparameter through _ccs_object_init. */
static _ccs_hyperparameter_ops_t _ccs_hyperparameter_ordinal_ops = {
{ &_ccs_hyperparameter_ordinal_del },
&_ccs_hyperparameter_ordinal_samples
};
/* Compare two values according to their rank in an ordinal hyperparameter.
 *
 * hyperparameter: the ordinal hyperparameter
 * value1, value2: the values to compare; both must be possible values
 * comp_ret:       receives -1, 0 or 1 when value1 is stored before, at the
 *                 same position as, or after value2
 *
 * Returns CCS_SUCCESS, -CCS_INVALID_OBJECT if hyperparameter or its data is
 * NULL, or -CCS_INVALID_VALUE if comp_ret is NULL or either value is not
 * found among the possible values.
 */
ccs_error_t
ccs_ordinal_hyperparameter_compare_values(ccs_hyperparameter_t  hyperparameter,
                                          ccs_datum_t           value1,
                                          ccs_datum_t           value2,
                                          ccs_int_t            *comp_ret) {
	if (unlikely(!hyperparameter || !hyperparameter->data))
		return -CCS_INVALID_OBJECT;
	if (unlikely(!comp_ret))
		return -CCS_INVALID_VALUE;

	_ccs_hyperparameter_ordinal_data_t *ordinal =
		(_ccs_hyperparameter_ordinal_data_t *)(hyperparameter->data);
	_ccs_hash_datum_t *first;
	_ccs_hash_datum_t *second;
	HASH_FIND(hh, ordinal->hash, &value1, sizeof(ccs_datum_t), first);
	HASH_FIND(hh, ordinal->hash, &value2, sizeof(ccs_datum_t), second);
	if (unlikely(!(first && second)))
		return -CCS_INVALID_VALUE;

	/* Both entries live in the same array, so their addresses give the rank. */
	*comp_ret = (first > second) - (first < second);
	return CCS_SUCCESS;
}
/* Out-of-memory hook used by uthash when HASH_NONFATAL_OOM is set. It is
 * meant to expand inside ccs_create_ordinal_hyperparameter and relies on that
 * function's locals pvs, hyperparam_data, distribution and mem: every entry
 * stored before elt is unlinked from the hash, the distribution reference is
 * dropped, the object's allocation is freed, and the enclosing function
 * returns -CCS_ENOMEM. */
#undef uthash_nonfatal_oom
#define uthash_nonfatal_oom(elt) { \
_ccs_hash_datum_t *v = (_ccs_hash_datum_t *)elt; \
while ( v > pvs) { \
HASH_DELETE(hh, hyperparam_data->hash, --v); \
} \
ccs_release_object(distribution); \
free((void*)mem); \
return -CCS_ENOMEM; \
}
/* Create an ordinal hyperparameter: an ordered list of distinct values.
 *
 * name:                 name of the hyperparameter (copied)
 * num_possible_values:  number of entries in possible_values; must be
 *                       non-zero and no larger than INT64_MAX
 * possible_values:      the values, in increasing rank order; CCS_STRING
 *                       entries must be non-NULL and are copied; duplicates
 *                       are rejected
 * default_value_index:  index of the default value in possible_values
 * distribution:         distribution over the indices, retained on success;
 *                       if NULL a uniform integer distribution over
 *                       [0, num_possible_values) is created
 * user_data:            opaque pointer stored with the hyperparameter
 * hyperparameter_ret:   receives the new hyperparameter
 *
 * Returns CCS_SUCCESS, -CCS_INVALID_VALUE on invalid arguments or duplicate
 * values, -CCS_ENOMEM on allocation failure, or the error reported while
 * creating or checking the distribution.
 *
 * The object, its value table, its name and all string values live in one
 * allocation. The locals mem, pvs, hyperparam_data and distribution are
 * referenced by the uthash_nonfatal_oom hook and must keep these names.
 */
ccs_error_t
ccs_create_ordinal_hyperparameter(const char           *name,
                                  size_t                num_possible_values,
                                  ccs_datum_t          *possible_values,
                                  size_t                default_value_index,
                                  ccs_distribution_t    distribution,
                                  void                 *user_data,
                                  ccs_hyperparameter_t *hyperparameter_ret) {
	if (!hyperparameter_ret || !name)
		return -CCS_INVALID_VALUE;
	if (!num_possible_values ||
	    num_possible_values > INT64_MAX ||
	    num_possible_values <= default_value_index)
		return -CCS_INVALID_VALUE;
	if (!possible_values)
		return -CCS_INVALID_VALUE;

	/* Space needed to hold private copies of all string values. */
	size_t size_strs = 0;
	for (size_t i = 0; i < num_possible_values; i++)
		if (possible_values[i].type == CCS_STRING) {
			if (!possible_values[i].value.s)
				return -CCS_INVALID_VALUE;
			size_strs += strlen(possible_values[i].value.s) + 1;
		}

	/* Layout: object | ordinal data | value table | name | string pool. */
	uintptr_t mem = (uintptr_t)calloc(1,
		sizeof(struct _ccs_hyperparameter_s) +
		sizeof(_ccs_hyperparameter_ordinal_data_t) +
		sizeof(_ccs_hash_datum_t) * num_possible_values +
		strlen(name) + 1 +
		size_strs);
	if (!mem)
		return -CCS_ENOMEM;

	/* Valid indices: [0, num_possible_values). */
	ccs_interval_t interval;
	interval.type = CCS_NUM_INTEGER;
	interval.lower.i = 0;
	interval.upper.i = (ccs_int_t)num_possible_values;
	interval.lower_included = CCS_TRUE;
	interval.upper_included = CCS_FALSE;

	ccs_error_t err;
	ccs_bool_t oversampling;
	if (!distribution) {
		err = ccs_create_uniform_distribution(interval.type,
		                                      interval.lower, interval.upper,
		                                      CCS_LINEAR, CCSI(0),
		                                      &distribution);
		if (err) {
			free((void *)mem);
			return err;
		}
		oversampling = CCS_FALSE;
	} else {
		err = ccs_distribution_check_oversampling(distribution, &interval,
		                                          &oversampling);
		if (err) {
			free((void *)mem);
			return err;
		}
		ccs_retain_object(distribution);
	}

	ccs_hyperparameter_t hyperparam = (ccs_hyperparameter_t)mem;
	_ccs_object_init(&(hyperparam->obj), CCS_HYPERPARAMETER, (_ccs_object_ops_t *)&_ccs_hyperparameter_ordinal_ops);
	_ccs_hyperparameter_ordinal_data_t *hyperparam_data =
		(_ccs_hyperparameter_ordinal_data_t *)(mem +
			sizeof(struct _ccs_hyperparameter_s));
	hyperparam_data->common_data.type = CCS_ORDINAL;
	hyperparam_data->common_data.name = (char *)(mem +
		sizeof(struct _ccs_hyperparameter_s) +
		sizeof(_ccs_hyperparameter_ordinal_data_t) +
		sizeof(_ccs_hash_datum_t) * num_possible_values);
	strcpy((char *)hyperparam_data->common_data.name, name);
	hyperparam_data->common_data.user_data = user_data;
	hyperparam_data->common_data.distribution = distribution;
	hyperparam_data->common_data.interval = interval;
	hyperparam_data->common_data.oversampling = oversampling;
	hyperparam_data->num_possible_values = num_possible_values;
	_ccs_hash_datum_t *pvs = (_ccs_hash_datum_t *)(mem +
		sizeof(struct _ccs_hyperparameter_s) +
		sizeof(_ccs_hyperparameter_ordinal_data_t));
	hyperparam_data->possible_values = pvs;
	hyperparam_data->hash = NULL;

	/* The string pool starts right after the copied name. */
	char *str_pool = (char *)(hyperparam_data->common_data.name) + strlen(name) + 1;
	for (size_t i = 0; i < num_possible_values; i++) {
		_ccs_hash_datum_t *p = NULL;
		HASH_FIND(hh, hyperparam_data->hash, possible_values + i, sizeof(ccs_datum_t), p);
		if (p) {
			/* Duplicate value: undo everything done so far. */
			_ccs_hash_datum_t *tmp;
			HASH_ITER(hh, hyperparam_data->hash, p, tmp) {
				HASH_DELETE(hh, hyperparam_data->hash, p);
			}
			ccs_release_object(distribution);
			free((void *)mem);
			return -CCS_INVALID_VALUE;
		}
		if (possible_values[i].type == CCS_STRING) {
			pvs[i].d.type = CCS_STRING;
			pvs[i].d.value.s = str_pool;
			strcpy(str_pool, possible_values[i].value.s);
			str_pool += strlen(possible_values[i].value.s) + 1;
		} else {
			pvs[i].d = possible_values[i];
		}
		/* On allocation failure uthash_nonfatal_oom cleans up and returns. */
		HASH_ADD(hh, hyperparam_data->hash, d, sizeof(ccs_datum_t), pvs + i);
	}

	/* Take the default from the internal copy, not from the caller's array:
	 * a string default must point into this object's string pool so it does
	 * not depend on the lifetime of the caller's buffer. */
	hyperparam_data->common_data.default_value = pvs[default_value_index].d;

	hyperparam->data = (_ccs_hyperparameter_data_t *)hyperparam_data;
	*hyperparameter_ret = hyperparam;
	return CCS_SUCCESS;
}
This diff is collapsed.
......@@ -10,7 +10,8 @@ RNG_TESTS = \
test_normal_distribution \
test_roulette_distribution \
test_numerical_hyperparameter \
test_categorical_hyperparameter
test_categorical_hyperparameter \
test_ordinal_hyperparameter
# unit tests
UNIT_TESTS = \
......
#include <stdlib.h>
#include <assert.h>
#include <cconfigspace.h>
#include <string.h>
/* Create an ordinal hyperparameter over {2, 4, 6, 8} with the default
 * distribution and check every accessor: type, default value, name, user
 * data, and the distribution (uniform over the integer interval [0, 4)). */
void test_create() {
	const size_t num_possible_values = 4;
	const size_t default_value_index = 2;
	ccs_datum_t possible_values[num_possible_values];
	for (size_t i = 0; i < num_possible_values; i++) {
		ccs_datum_t *pv = possible_values + i;
		pv->type = CCS_INTEGER;
		pv->value.i = (i+1)*2;
	}

	ccs_hyperparameter_t hyperparameter;
	ccs_error_t err;
	err = ccs_create_ordinal_hyperparameter("my_param", num_possible_values,
	                                        possible_values, default_value_index,
	                                        NULL, (void *)0xdeadbeef,
	                                        &hyperparameter);
	assert( err == CCS_SUCCESS );

	/* kind of hyperparameter */
	ccs_hyperparameter_type_t type;
	err = ccs_hyperparameter_get_type(hyperparameter, &type);
	assert( err == CCS_SUCCESS );
	assert( type == CCS_ORDINAL );

	/* default value */
	ccs_datum_t default_value;
	err = ccs_hyperparameter_get_default_value(hyperparameter, &default_value);
	assert( err == CCS_SUCCESS );
	assert( default_value.type == CCS_INTEGER );
	assert( default_value.value.i == possible_values[default_value_index].value.i );

	/* name */
	const char *name;
	err = ccs_hyperparameter_get_name(hyperparameter, &name);
	assert( err == CCS_SUCCESS );
	assert( strcmp(name, "my_param") == 0 );

	/* user data */
	void *user_data;
	err = ccs_hyperparameter_get_user_data(hyperparameter, &user_data);
	assert( err == CCS_SUCCESS );
	assert( user_data == (void *)0xdeadbeef );

	/* default distribution and its bounds */
	ccs_distribution_t distribution;
	err = ccs_hyperparameter_get_distribution(hyperparameter, &distribution);
	assert( err == CCS_SUCCESS );
	assert( distribution );

	ccs_distribution_type_t dist_type;
	err = ccs_distribution_get_type(distribution, &dist_type);
	assert( err == CCS_SUCCESS );
	assert( dist_type == CCS_UNIFORM );

	ccs_interval_t interval;
	err = ccs_distribution_get_bounds(distribution, &interval);
	assert( err == CCS_SUCCESS );
	assert( interval.type == CCS_NUM_INTEGER );
	assert( interval.lower.i == 0 );
	assert( interval.lower_included == CCS_TRUE );
	assert( interval.upper.i == 4 );
	assert( interval.upper_included == CCS_FALSE );

	err = ccs_release_object(hyperparameter);
	assert( err == CCS_SUCCESS );
}
/* Creation must fail with -CCS_INVALID_VALUE when the list of possible values
 * contains a duplicate (here the first entry is made equal to the last). */
void test_create_error() {
	const size_t num_possible_values = 4;
	const size_t default_value_index = 2;
	ccs_datum_t possible_values[num_possible_values];
	for (size_t i = 0; i < num_possible_values; i++) {
		ccs_datum_t *pv = possible_values + i;
		pv->type = CCS_INTEGER;
		pv->value.i = (i+1)*2;
	}
	/* introduce the duplicate */
	possible_values[0].value.i = possible_values[num_possible_values-1].value.i;

	ccs_hyperparameter_t hyperparameter;
	ccs_error_t err;
	err = ccs_create_ordinal_hyperparameter("my_param", num_possible_values,
	                                        possible_values, default_value_index,
	                                        NULL, NULL,
	                                        &hyperparameter);
	assert( err == -CCS_INVALID_VALUE );
}
/* Draw 10000 samples with the default distribution and check that each one
 * is one of the possible values {2, 4, 6, 8}. */
void test_samples() {
	const size_t num_samples = 10000;
	const size_t num_possible_values = 4;
	const size_t default_value_index = 2;
	ccs_datum_t samples[num_samples];
	ccs_datum_t possible_values[num_possible_values];
	for (size_t i = 0; i < num_possible_values; i++) {
		ccs_datum_t *pv = possible_values + i;
		pv->type = CCS_INTEGER;
		pv->value.i = (i+1)*2;
	}

	ccs_rng_t rng;
	ccs_error_t err;
	err = ccs_rng_create(&rng);
	assert( err == CCS_SUCCESS );

	ccs_hyperparameter_t hyperparameter;
	err = ccs_create_ordinal_hyperparameter("my_param", num_possible_values,
	                                        possible_values, default_value_index,
	                                        NULL, NULL,
	                                        &hyperparameter);
	assert( err == CCS_SUCCESS );

	err = ccs_hyperparameter_samples(hyperparameter, rng, num_samples, samples);
	assert( err == CCS_SUCCESS );
	/* every sample is an even integer in [2, 8] */
	for (size_t i = 0; i < num_samples; i++) {
		assert( samples[i].type == CCS_INTEGER );
		assert( samples[i].value.i %2 == 0 );
		assert( samples[i].value.i >= 2 );
		assert( samples[i].value.i <= (ccs_int_t)num_possible_values * 2);
	}

	err = ccs_release_object(hyperparameter);
	assert( err == CCS_SUCCESS );
	err = ccs_release_object(rng);
	assert( err == CCS_SUCCESS );
}
/* Same as test_samples, but with a user-supplied uniform integer distribution
 * whose upper bound (num_possible_values + 1) exceeds the valid index range:
 * every returned sample must still be one of the possible values. */
void test_oversampling() {
	const size_t num_samples = 10000;
	const size_t num_possible_values = 4;
	const size_t default_value_index = 2;
	ccs_datum_t samples[num_samples];
	ccs_datum_t possible_values[num_possible_values];
	for (size_t i = 0; i < num_possible_values; i++) {
		ccs_datum_t *pv = possible_values + i;
		pv->type = CCS_INTEGER;
		pv->value.i = (i+1)*2;
	}

	ccs_rng_t rng;
	ccs_error_t err;
	err = ccs_rng_create(&rng);
	assert( err == CCS_SUCCESS );

	/* distribution wider than the index range */
	ccs_distribution_t distribution;
	err = ccs_create_uniform_int_distribution(0, num_possible_values+1,
	                                          CCS_LINEAR, 0, &distribution);
	assert( err == CCS_SUCCESS );

	ccs_hyperparameter_t hyperparameter;
	err = ccs_create_ordinal_hyperparameter("my_param", num_possible_values,
	                                        possible_values, default_value_index,
	                                        distribution, NULL,
	                                        &hyperparameter);
	assert( err == CCS_SUCCESS );

	/* the local reference is dropped before sampling */
	err = ccs_release_object(distribution);
	assert( err == CCS_SUCCESS );

	err = ccs_hyperparameter_samples(hyperparameter, rng, num_samples, samples);
	assert( err == CCS_SUCCESS );
	/* every sample is an even integer in [2, 8] */
	for (size_t i = 0; i < num_samples; i++) {
		assert( samples[i].type == CCS_INTEGER );
		assert( samples[i].value.i % 2 == 0 );
		assert( samples[i].value.i >= 2 );
		assert( samples[i].value.i <= (ccs_int_t)num_possible_values * 2);
	}

	err = ccs_release_object(hyperparameter);
	assert( err == CCS_SUCCESS );
	err = ccs_release_object(rng);
	assert( err == CCS_SUCCESS );
}
void test_compare() {
ccs_hyperparameter_t hyperparameter;
ccs_error_t err;
const size_t num_possible_values = 4;
ccs_datum_t possible_values[num_possible_values];
const size_t default_value_index = 2;
ccs_int_t comp = 0;
ccs_datum_t invalid;
invalid.value.i = -1;
invalid.type = CCS_INTEGER;
for(size_t i = 0; i < num_possible_values; i++) {
possible_values[i].value.i = (i+1)*2;
possible_values[i].type = CCS_INTEGER;
}
err = ccs_create_ordinal_hyperparameter("my_param", num_possible_values,
possible_values, default_value_index,
NULL, NULL,