Commit 0385cf93 authored by Brice Videau's avatar Brice Videau
Browse files

WIP on numerical hyper-parameters.

parent 0e2df867
......@@ -5,6 +5,7 @@
#include <stdint.h>
#include <limits.h>
#include "ccs/base.h"
#include "ccs/interval.h"
#include "ccs/rng.h"
#include "ccs/distribution.h"
#include "ccs/hyperparameter.h"
......
......@@ -36,6 +36,9 @@ enum ccs_error_e {
CCS_INVALID_VALUE,
CCS_INVALID_TYPE,
CCS_INVALID_SCALE,
CCS_TYPE_NOT_COMPARABLE,
CCS_INVALID_BOUNDS,
CCS_SAMPLING_UNSUCCESSFUL,
CCS_ENOMEM,
CCS_ERROR_MAX,
CCS_ERROR_FORCE_32BIT = INT_MAX
......@@ -70,6 +73,13 @@ enum ccs_data_type_e {
typedef enum ccs_data_type_e ccs_data_type_t;
// Numeric subset of the data types: each enumerator is defined to have the
// same value as the corresponding ccs_data_type_e enumerator.
enum ccs_numeric_type_e {
// Integer numeric type; same value as CCS_INTEGER.
CCS_NUM_INTEGER = CCS_INTEGER,
// Floating point numeric type; same value as CCS_FLOAT.
CCS_NUM_FLOAT = CCS_FLOAT
};
typedef enum ccs_numeric_type_e ccs_numeric_type_t;
union ccs_object_u {
void *ptr;
ccs_rng_t rng;
......@@ -95,6 +105,17 @@ union ccs_object_u {
typedef union ccs_object_u ccs_object_t;
// Untagged numeric value: holds either a float or an integer. The union does
// not record which member is active; the type is carried separately (for
// instance in the `type` field of struct ccs_interval_s).
union ccs_numeric_u {
// Floating point view of the value.
ccs_float_t f;
// Integer view of the value.
ccs_int_t i;
#ifdef __cplusplus
// C++ only: constructors so that a ccs_numeric_u can be built directly from
// a ccs_float_t or a ccs_int_t. The default constructor sets the integer
// member to 0.
ccs_numeric_u() : i(0L) {}
ccs_numeric_u(ccs_float_t v) : f(v) {}
ccs_numeric_u(ccs_int_t v) : i(v) {}
#endif
};
typedef union ccs_numeric_u ccs_numeric_t;
union ccs_value_u {
ccs_float_t f;
......@@ -112,6 +133,15 @@ union ccs_value_u {
typedef union ccs_value_u ccs_value_t;
// Numeric interval described by two bounds and, for each bound, whether the
// bound itself belongs to the interval.
struct ccs_interval_s {
// Selects which member of `lower` and `upper` is meaningful
// (`.f` for CCS_NUM_FLOAT, `.i` otherwise; see ccs_interval_include).
ccs_numeric_type_t type;
// Lower bound of the interval.
ccs_numeric_t lower;
// Upper bound of the interval.
ccs_numeric_t upper;
// Non-zero when `lower` itself is part of the interval (closed lower end).
ccs_bool_t lower_included;
// Non-zero when `upper` itself is part of the interval (closed upper end).
ccs_bool_t upper_included;
};
typedef struct ccs_interval_s ccs_interval_t;
struct ccs_datum_u {
ccs_value_t value;
......
......@@ -26,22 +26,22 @@ typedef enum ccs_scale_type_e ccs_scale_type_t;
// Distribution
extern ccs_error_t
ccs_create_distribution(ccs_distribution_type_t distribution_type,
ccs_data_type_t data_type,
ccs_numeric_type_t data_type,
ccs_scale_type_t scale,
ccs_value_t quantization,
ccs_numeric_t quantization,
size_t num_parameters,
ccs_value_t *parameters,
ccs_numeric_t *parameters,
ccs_distribution_t *distribution_ret);
extern ccs_error_t
_ccs_create_normal_distribution(ccs_data_type_t data_type,
_ccs_create_normal_distribution(ccs_numeric_type_t data_type,
ccs_float_t mu,
ccs_float_t sigma,
ccs_scale_type_t scale,
ccs_value_t quantization,
ccs_numeric_t quantization,
ccs_distribution_t *distribution_ret);
#define ccs_create_normal_distribution(t, m, s, sc, q, d) \
_ccs_create_normal_distribution(t, m, s, sc, (ccs_value_t)(q), d)
_ccs_create_normal_distribution(t, m, s, sc, (ccs_numeric_t)(q), d)
extern ccs_error_t
ccs_create_normal_int_distribution(ccs_float_t mu,
......@@ -58,15 +58,15 @@ ccs_create_normal_float_distribution(ccs_float_t mu,
ccs_distribution_t *distribution_ret);
extern ccs_error_t
_ccs_create_uniform_distribution(ccs_data_type_t data_type,
ccs_value_t lower,
ccs_value_t upper,
_ccs_create_uniform_distribution(ccs_numeric_type_t data_type,
ccs_numeric_t lower,
ccs_numeric_t upper,
ccs_scale_type_t scale_type,
ccs_value_t quantization,
ccs_numeric_t quantization,
ccs_distribution_t *distribution_ret);
#define ccs_create_uniform_distribution(t, l, u, s, q, d) \
_ccs_create_uniform_distribution(t, (ccs_value_t)(l), (ccs_value_t)(u), s, (ccs_value_t)(q), d)
_ccs_create_uniform_distribution(t, (ccs_numeric_t)(l), (ccs_numeric_t)(u), s, (ccs_numeric_t)(q), d)
extern ccs_error_t
ccs_create_uniform_int_distribution(ccs_int_t lower,
......@@ -89,7 +89,7 @@ ccs_distribution_get_type(ccs_distribution_t distribution,
extern ccs_error_t
ccs_distribution_get_data_type(ccs_distribution_t distribution,
ccs_data_type_t *data_type_ret);
ccs_numeric_type_t *data_type_ret);
extern ccs_error_t
ccs_distribution_get_scale_type(ccs_distribution_t distribution,
......@@ -97,46 +97,38 @@ ccs_distribution_get_scale_type(ccs_distribution_t distribution,
extern ccs_error_t
ccs_distribution_get_quantization(ccs_distribution_t distribution,
ccs_datum_t *quantization);
ccs_numeric_t *quantization);
extern ccs_error_t
ccs_distribution_get_num_parameters(ccs_distribution_t distribution,
size_t *num_parameters_ret);
extern ccs_error_t
ccs_distribution_get_parameters(ccs_distribution_t distribution,
size_t num_parameters,
ccs_datum_t *parameters,
size_t *num_parameters_ret);
ccs_distribution_get_bounds(ccs_distribution_t distribution,
ccs_interval_t *interval_ret);
extern ccs_error_t
ccs_distribution_get_bounds(ccs_distribution_t distribution,
ccs_datum_t *lower,
ccs_bool_t *lower_included,
ccs_datum_t *upper,
ccs_bool_t *upper_included);
ccs_distribution_check_oversampling(ccs_distribution_t distribution,
ccs_interval_t *interval,
ccs_bool_t *oversampling_ret);
extern ccs_error_t
ccs_normal_distribution_get_parameters(ccs_distribution_t distribution,
ccs_datum_t *mu_ret,
ccs_datum_t *sigma_ret);
ccs_float_t *mu_ret,
ccs_float_t *sigma_ret);
extern ccs_error_t
ccs_uniform_distribution_get_parameters(ccs_distribution_t distribution,
ccs_datum_t *lower,
ccs_datum_t *upper);
ccs_numeric_t *lower,
ccs_numeric_t *upper);
// Sampling Interface
extern ccs_error_t
ccs_distribution_sample(ccs_distribution_t distribution,
ccs_rng_t rng,
ccs_value_t *value);
ccs_numeric_t *value);
extern ccs_error_t
ccs_distribution_samples(ccs_distribution_t distribution,
ccs_rng_t rng,
size_t num_values,
ccs_value_t *values);
ccs_numeric_t *values);
#ifdef __cplusplus
}
......
......@@ -18,11 +18,11 @@ typedef enum ccs_hyperparameter_type_e ccs_hyperparameter_type_t;
// Hyperparameter Interface
extern ccs_error_t
_ccs_create_numerical_hyperparameter(const char *name,
ccs_data_type_t data_type,
ccs_value_t lower,
ccs_value_t upper,
ccs_value_t quantization,
ccs_value_t default_value,
ccs_numeric_type_t data_type,
ccs_numeric_t lower,
ccs_numeric_t upper,
ccs_numeric_t quantization,
ccs_numeric_t default_value,
ccs_distribution_t distribution,
void *user_data,
ccs_hyperparameter_t *hyperparameter_ret);
......
#ifndef _CCS_INTERVAL_H
#define _CCS_INTERVAL_H
#ifdef __cplusplus
extern "C" {
#endif
// Emptiness query on `interval`; the answer is reported through `empty_ret`
// and the function result is an error code.
// NOTE(review): the definition lives in interval.c, which is not visible
// here; exact error conditions (e.g. on NULL arguments) should be confirmed
// there.
extern ccs_error_t
ccs_interval_empty(ccs_interval_t *interval, ccs_bool_t *empty_ret);
// Combines `interval1` and `interval2` and stores the resulting interval in
// `interval_res`; presumably the set intersection, as the name suggests
// (defined in interval.c, not visible here; confirm there).
// Returns an error code; callers treat any non-zero value as a failure.
extern ccs_error_t
ccs_interval_intersect(ccs_interval_t *interval1,
ccs_interval_t *interval2,
ccs_interval_t *interval_res);
// Compares `interval1` with `interval2` and reports the outcome through
// `equal_res` (a true value meaning the intervals are considered equal).
// Returns an error code; callers treat any non-zero value as a failure.
// NOTE(review): the precise notion of equality is implemented in interval.c,
// which is not visible here.
extern ccs_error_t
ccs_interval_equal(ccs_interval_t *interval1,
ccs_interval_t *interval2,
ccs_bool_t *equal_res);
// Tells whether `value` lies inside `interval`.
//
// The `.f` members are compared when the interval type is CCS_NUM_FLOAT and
// the `.i` members otherwise. A bound flagged as included is compared
// non-strictly, a bound not included strictly.
//
// Returns a true value when both the lower and the upper bound checks pass,
// and a false value otherwise. `interval` is dereferenced unconditionally and
// must therefore not be NULL.
static inline ccs_bool_t
ccs_interval_include(ccs_interval_t *interval, ccs_numeric_t value) {
int above_lower;
int below_upper;

if (interval->type == CCS_NUM_FLOAT) {
const ccs_float_t lo = interval->lower.f;
const ccs_float_t hi = interval->upper.f;
const ccs_float_t v = value.f;

if (interval->lower_included)
above_lower = (lo <= v);
else
above_lower = (lo < v);
if (interval->upper_included)
below_upper = (hi >= v);
else
below_upper = (hi > v);
} else {
const ccs_int_t lo = interval->lower.i;
const ccs_int_t hi = interval->upper.i;
const ccs_int_t v = value.i;

if (interval->lower_included)
above_lower = (lo <= v);
else
above_lower = (lo < v);
if (interval->upper_included)
below_upper = (hi >= v);
else
below_upper = (hi > v);
}
return above_lower && below_upper;
}
#ifdef __cplusplus
}
#endif
#endif //_CCS_INTERVAL_H
......@@ -7,6 +7,7 @@ lib_LTLIBRARIES = libcconfigspace.la
libcconfigspace_la_SOURCES = \
cconfigspace.c \
cconfigspace_internal.h \
interval.c \
rng.c \
rng_internal.h \
distribution.c \
......
......@@ -20,7 +20,7 @@ ccs_distribution_get_type(ccs_distribution_t distribution,
ccs_error_t
ccs_distribution_get_data_type(ccs_distribution_t distribution,
ccs_data_type_t *data_type_ret) {
ccs_numeric_type_t *data_type_ret) {
if (!distribution || !distribution->data)
return -CCS_INVALID_OBJECT;
if (!data_type_ret)
......@@ -42,58 +42,57 @@ ccs_distribution_get_scale_type(ccs_distribution_t distribution,
ccs_error_t
ccs_distribution_get_quantization(ccs_distribution_t distribution,
ccs_datum_t *quantization_ret) {
ccs_numeric_t *quantization_ret) {
if (!distribution || !distribution->data)
return -CCS_INVALID_OBJECT;
if (!quantization_ret)
return -CCS_INVALID_VALUE;
quantization_ret->value = ((_ccs_distribution_common_data_t *)(distribution->data))->quantization;
quantization_ret->type = ((_ccs_distribution_common_data_t *)(distribution->data))->data_type;
*quantization_ret = ((_ccs_distribution_common_data_t *)(distribution->data))->quantization;
return CCS_SUCCESS;
}
ccs_error_t
ccs_distribution_get_num_parameters(ccs_distribution_t distribution,
size_t *num_parameters_ret) {
ccs_distribution_get_bounds(ccs_distribution_t distribution,
ccs_interval_t *interval_ret) {
if (!distribution || !distribution->data)
return -CCS_INVALID_OBJECT;
if (!num_parameters_ret)
if (!interval_ret)
return -CCS_INVALID_VALUE;
_ccs_distribution_ops_t *ops = ccs_distribution_get_ops(distribution);
return ops->get_num_parameters(distribution->data, num_parameters_ret);
return ops->get_bounds(distribution->data, interval_ret);
}
ccs_error_t
ccs_distribution_get_parameters(ccs_distribution_t distribution,
size_t num_parameters,
ccs_datum_t *parameters,
size_t *num_parameters_ret) {
if (!distribution || !distribution->data)
return -CCS_INVALID_OBJECT;
if (num_parameters > 0 && !parameters)
ccs_distribution_check_oversampling(ccs_distribution_t distribution,
ccs_interval_t *interval,
ccs_bool_t *oversampling_ret) {
ccs_error_t err;
ccs_interval_t d_interval;
if (!interval || !oversampling_ret)
return -CCS_INVALID_VALUE;
_ccs_distribution_ops_t *ops = ccs_distribution_get_ops(distribution);
return ops->get_parameters(distribution->data, num_parameters, parameters, num_parameters_ret);
}
ccs_error_t
ccs_distribution_get_bounds(ccs_distribution_t distribution,
ccs_datum_t *lower,
ccs_bool_t *lower_included,
ccs_datum_t *upper,
ccs_bool_t *upper_included) {
if (!distribution || !distribution->data)
return -CCS_INVALID_OBJECT;
if (!lower && !lower_included && !upper && !upper_included)
return -CCS_INVALID_VALUE;
_ccs_distribution_ops_t *ops = ccs_distribution_get_ops(distribution);
return ops->get_bounds(distribution->data, lower, lower_included, upper, upper_included);
err = ccs_distribution_get_bounds(distribution, &d_interval);
if (err)
return err;
ccs_interval_t intersection;
err = ccs_interval_intersect(&d_interval, interval, &intersection);
if (err)
return err;
ccs_bool_t eql;
err = ccs_interval_equal(&d_interval, &intersection, &eql);
if (err)
return err;
*oversampling_ret = (eql ? CCS_FALSE : CCS_TRUE);
return CCS_SUCCESS;
}
ccs_error_t
ccs_distribution_sample(ccs_distribution_t distribution,
ccs_rng_t rng,
ccs_value_t *value) {
ccs_numeric_t *value) {
if (!distribution || !distribution->data)
return -CCS_INVALID_OBJECT;
if (!value)
......@@ -106,7 +105,7 @@ ccs_error_t
ccs_distribution_samples(ccs_distribution_t distribution,
ccs_rng_t rng,
size_t num_values,
ccs_value_t *values) {
ccs_numeric_t *values) {
if (!distribution || !distribution->data)
return -CCS_INVALID_OBJECT;
if (!num_values || !values)
......
......@@ -7,28 +7,15 @@ typedef struct _ccs_distribution_data_s _ccs_distribution_data_t;
struct _ccs_distribution_ops_s {
_ccs_object_ops_t obj_ops;
ccs_error_t (*get_num_parameters)(
_ccs_distribution_data_t *distribution,
size_t *num_parameters_ret);
ccs_error_t (*get_parameters)(
_ccs_distribution_data_t *distribution,
size_t num_parameters,
ccs_datum_t *parameters,
size_t *num_parameters_ret);
ccs_error_t (*samples)(
_ccs_distribution_data_t *distribution,
ccs_rng_t rng,
size_t num_values,
ccs_value_t *values);
ccs_numeric_t *values);
ccs_error_t (*get_bounds)(
_ccs_distribution_data_t *distribution,
ccs_datum_t *lower,
ccs_bool_t *lower_included,
ccs_datum_t *upper,
ccs_bool_t *upper_included);
ccs_interval_t *interval_ret);
};
typedef struct _ccs_distribution_ops_s _ccs_distribution_ops_t;
......@@ -40,9 +27,9 @@ struct _ccs_distribution_s {
struct _ccs_distribution_common_data_s {
ccs_distribution_type_t type;
ccs_data_type_t data_type;
ccs_numeric_type_t data_type;
ccs_scale_type_t scale_type;
ccs_value_t quantization;
ccs_numeric_t quantization;
};
typedef struct _ccs_distribution_common_data_s _ccs_distribution_common_data_t;
#endif //_DISTRIBUTION_INTERNAL_H
......@@ -18,131 +18,80 @@ _ccs_distribution_del(ccs_object_t o) {
return CCS_SUCCESS;
}
static ccs_error_t
_ccs_distribution_normal_get_num_parameters(_ccs_distribution_data_t *data,
size_t *num_parameters_ret);
static ccs_error_t
_ccs_distribution_normal_get_parameters(_ccs_distribution_data_t *data,
size_t num_parameters,
ccs_datum_t *parameters,
size_t *num_parameters_ret);
static ccs_error_t
_ccs_distribution_normal_get_bounds(_ccs_distribution_data_t *data,
ccs_datum_t *lower,
ccs_bool_t *lower_included,
ccs_datum_t *upper,
ccs_bool_t *upper_included);
ccs_interval_t *interval_ret);
static ccs_error_t
_ccs_distribution_normal_samples(_ccs_distribution_data_t *data,
ccs_rng_t rng,
size_t num_values,
ccs_value_t *values);
ccs_numeric_t *values);
_ccs_distribution_ops_t _ccs_distribution_normal_ops = {
{ &_ccs_distribution_del },
&_ccs_distribution_normal_get_num_parameters,
&_ccs_distribution_normal_get_parameters,
&_ccs_distribution_normal_samples,
&_ccs_distribution_normal_get_bounds
};
static ccs_error_t
_ccs_distribution_normal_get_num_parameters(_ccs_distribution_data_t *data,
size_t *num_parameters_ret) {
*num_parameters_ret = 2;
return CCS_SUCCESS;
}
static ccs_error_t
_ccs_distribution_normal_get_parameters(_ccs_distribution_data_t *data,
size_t num_parameters,
ccs_datum_t *parameters,
size_t *num_parameters_ret) {
if (num_parameters > 0 && num_parameters < 2)
return -CCS_INVALID_VALUE;
_ccs_distribution_normal_data_t *d = (_ccs_distribution_normal_data_t *)data;
if (num_parameters_ret)
*num_parameters_ret = 2;
if (parameters) {
parameters[0].type = CCS_FLOAT;
parameters[0].value.f = d->mu;
parameters[1].type = CCS_FLOAT;
parameters[1].value.f = d->sigma;
}
return CCS_SUCCESS;
}
static ccs_error_t
_ccs_distribution_normal_get_bounds(_ccs_distribution_data_t *data,
ccs_datum_t *lower,
ccs_bool_t *lower_included,
ccs_datum_t *upper,
ccs_bool_t *upper_included) {
ccs_interval_t *interval_ret) {
_ccs_distribution_normal_data_t *d = (_ccs_distribution_normal_data_t *)data;
const ccs_data_type_t data_type = d->common_data.data_type;
const ccs_numeric_type_t data_type = d->common_data.data_type;
const ccs_scale_type_t scale_type = d->common_data.scale_type;
const ccs_value_t quantization = d->common_data.quantization;
const ccs_numeric_t quantization = d->common_data.quantization;
const int quantize = d->quantize;
ccs_datum_t l;
ccs_numeric_t l;
ccs_bool_t li;
ccs_datum_t u;
ccs_numeric_t u;
ccs_bool_t ui;
l.type = data_type;
u.type = data_type;
if (scale_type == CCS_LOGARITHMIC) {
if (data_type == CCS_FLOAT) {
if (data_type == CCS_NUM_FLOAT) {
if (quantize) {
l.value.f = quantization.f;
l.f = quantization.f;
li = CCS_TRUE;
} else {
l.value.f = 0.0;
l.f = 0.0;
li = CCS_FALSE;
}
u.value.f = CCS_INFINITY;
u.f = CCS_INFINITY;
ui = CCS_FALSE;
} else {
if (quantize) {
l.value.i = quantization.i;
u.value.i = (CCS_INT_MAX/quantization.i)*quantization.i;
l.i = quantization.i;
u.i = (CCS_INT_MAX/quantization.i)*quantization.i;
} else {
l.value.i = 1;
u.value.i = CCS_INT_MAX;
l.i = 1;
u.i = CCS_INT_MAX;
}
li = CCS_TRUE;
ui = CCS_TRUE;
}
} else {
if (data_type == CCS_FLOAT) {
l.value.f = -CCS_INFINITY;
if (data_type == CCS_NUM_FLOAT) {
l.f = -CCS_INFINITY;
li = CCS_FALSE;
u.value.f = CCS_INFINITY;
u.f = CCS_INFINITY;
ui = CCS_FALSE;
} else {
if (quantize) {
l.value.i = (CCS_INT_MIN/quantization.i)*quantization.i;
u.value.i = (CCS_INT_MAX/quantization.i)*quantization.i;
l.i = (CCS_INT_MIN/quantization.i)*quantization.i;
u.i = (CCS_INT_MAX/quantization.i)*quantization.i;
} else {
l.value.i = CCS_INT_MIN;
u.value.i = CCS_INT_MAX;
l.i = CCS_INT_MIN;
u.i = CCS_INT_MAX;
}
li = CCS_TRUE;
ui = CCS_TRUE;
}
}
if (lower)
*lower = l;
if (lower_included)
*lower_included = li;
if (upper)
*upper = u;
if (upper_included)
*upper_included = ui;
interval_ret->type = data_type;
interval_ret->lower = l;
interval_ret->upper = u;
interval_ret->lower_included = li;
interval_ret->upper_included = ui;
return CCS_SUCCESS;
}
......@@ -190,7 +139,7 @@ _ccs_distribution_normal_samples_int(gsl_rng *grng,
const ccs_float_t sigma,
const int quantize,
size_t num_values,
ccs_value_t *values) {
ccs_numeric_t *values) {
size_t i;
ccs_float_t q;
if (quantize)
......@@ -233,11 +182,11 @@ static ccs_error_t
_ccs_distribution_normal_samples(_ccs_distribution_data_t *data,
ccs_rng_t rng,
size_t num_values,
ccs_value_t *values) {
ccs_numeric_t *values) {
_ccs_distribution_normal_data_t *d = (_ccs_distribution_normal_data_t *)data;
const ccs_data_type_t data_type = d->common_data.data_type;
const ccs_numeric_type_t data_type = d->common_data.data_type;
const ccs_scale_type_t scale_type = d->common_data.scale_type;
const ccs_value_t quantization = d->common_data.quantization;
const ccs_numeric_t quantization = d->common_data.quantization;
const ccs_float_t mu = d->mu;
const ccs_float_t sigma = d->sigma;
const int quantize = d->quantize;
......@@ -245,7 +194,7 @@ _ccs_distribution_normal_samples(_ccs_distribution_data_t *data,
ccs_error_t err = ccs_rng_get_gsl_rng(rng, &grng);
if (err)
return err;
if (data_type == CCS_FLOAT)
if (data_type == CCS_NUM_FLOAT)
return _ccs_distribution_normal_samples_float(grng, scale_type,
quantization.f, mu,
sigma, quantize,
......@@ -259,21 +208,21 @@ _ccs_distribution_normal_samples(_ccs_distribution_data_t *data,
}
extern ccs_error_t
_ccs_create_normal_distribution(ccs_data_type_t data_type,
_ccs_create_normal_distribution(ccs_numeric_type_t data_type,
ccs_float_t mu,
ccs_float_t sigma,
ccs_scale_type_t scale_type,
ccs_value_t quantization,
ccs_numeric_t quantization,
ccs_distribution_t *distribution_ret) {
if (!distribution_ret)
return -CCS_INVALID_VALUE;
if (data_type != CCS_FLOAT && data_type != CCS_INTEGER)
if (data_type != CCS_NUM_FLOAT && data_type != CCS_NUM_INTEGER)
return -CCS_INVALID_TYPE;
if (scale_type != CCS_LINEAR && scale_type != CCS_LOGARITHMIC)
return -CCS_INVALID_SCALE;
if (data_type == CCS_INTEGER && quantization.i < 0 )