Commit 897caeb5 authored by Brice Videau's avatar Brice Videau
Browse files

Added a get default distribution to hyperparameters.

parent 926a846a
......@@ -68,6 +68,11 @@ ccs_hyperparameter_get_name(ccs_hyperparameter_t hyperparameter,
extern ccs_error_t
ccs_hyperparameter_get_user_data(ccs_hyperparameter_t hyperparameter,
void **user_data_ret);
extern ccs_error_t
ccs_hyperparameter_get_default_distribution(ccs_hyperparameter_t hyperparameter,
ccs_distribution_t *distribution);
extern ccs_error_t
ccs_hyperparameter_get_distribution(ccs_hyperparameter_t hyperparameter,
ccs_distribution_t *distribution);
......
......@@ -50,6 +50,18 @@ ccs_hyperparameter_get_user_data(ccs_hyperparameter_t hyperparameter,
return CCS_SUCCESS;
}
ccs_error_t
ccs_hyperparameter_get_default_distribution(ccs_hyperparameter_t hyperparameter,
ccs_distribution_t *distribution) {
if (!hyperparameter || !hyperparameter->data)
return -CCS_INVALID_OBJECT;
if (!distribution)
return -CCS_INVALID_VALUE;
_ccs_hyperparameter_ops_t *ops = ccs_hyperparameter_get_ops(hyperparameter);
return ops->get_default_distribution( hyperparameter->data, distribution);
}
ccs_error_t
ccs_hyperparameter_get_distribution(ccs_hyperparameter_t hyperparameter,
ccs_distribution_t *distribution) {
......
......@@ -58,9 +58,22 @@ _ccs_hyperparameter_categorical_samples(_ccs_hyperparameter_data_t *data,
return CCS_SUCCESS;
}
ccs_error_t
_ccs_hyperparameter_categorical_get_default_distribution(
_ccs_hyperparameter_data_t *data,
ccs_distribution_t *distribution) {
_ccs_hyperparameter_categorical_data_t *d = (_ccs_hyperparameter_categorical_data_t *)data;
ccs_interval_t *interval = &(d->common_data.interval);
return ccs_create_uniform_distribution(interval->type,
interval->lower, interval->upper,
CCS_LINEAR, CCSI(0),
distribution);
}
static _ccs_hyperparameter_ops_t _ccs_hyperparameter_categorical_ops = {
{ &_ccs_hyperparameter_categorical_del },
&_ccs_hyperparameter_categorical_samples
&_ccs_hyperparameter_categorical_samples,
&_ccs_hyperparameter_categorical_get_default_distribution
};
ccs_error_t
......
......@@ -12,6 +12,10 @@ struct _ccs_hyperparameter_ops_s {
ccs_rng_t rng,
size_t num_values,
ccs_datum_t *values);
ccs_error_t (*get_default_distribution)(
_ccs_hyperparameter_data_t *data,
ccs_distribution_t *distribution);
};
typedef struct _ccs_hyperparameter_ops_s _ccs_hyperparameter_ops_t;
......
......@@ -17,9 +17,9 @@ _ccs_hyperparameter_numerical_del(ccs_object_t o) {
static ccs_error_t
_ccs_hyperparameter_numerical_samples(_ccs_hyperparameter_data_t *data,
ccs_rng_t rng,
size_t num_values,
ccs_datum_t *values) {
ccs_rng_t rng,
size_t num_values,
ccs_datum_t *values) {
_ccs_hyperparameter_numerical_data_t *d = (_ccs_hyperparameter_numerical_data_t *)data;
ccs_numeric_type_t type = d->common_data.interval.type;
ccs_interval_t *interval = &(d->common_data.interval);
......@@ -77,9 +77,22 @@ _ccs_hyperparameter_numerical_samples(_ccs_hyperparameter_data_t *data,
return CCS_SUCCESS;
}
ccs_error_t
_ccs_hyperparameter_numerical_get_default_distribution(
_ccs_hyperparameter_data_t *data,
ccs_distribution_t *distribution) {
_ccs_hyperparameter_numerical_data_t *d = (_ccs_hyperparameter_numerical_data_t *)data;
ccs_interval_t *interval = &(d->common_data.interval);
return ccs_create_uniform_distribution(interval->type,
interval->lower, interval->upper,
CCS_LINEAR, d->quantization,
distribution);
}
static _ccs_hyperparameter_ops_t _ccs_hyperparameter_numerical_ops = {
{ &_ccs_hyperparameter_numerical_del },
&_ccs_hyperparameter_numerical_samples
&_ccs_hyperparameter_numerical_samples,
&_ccs_hyperparameter_numerical_get_default_distribution
};
ccs_error_t
......
......@@ -144,9 +144,22 @@ _ccs_hyperparameter_ordinal_samples(_ccs_hyperparameter_data_t *data,
return CCS_SUCCESS;
}
ccs_error_t
_ccs_hyperparameter_ordinal_get_default_distribution(
_ccs_hyperparameter_data_t *data,
ccs_distribution_t *distribution) {
_ccs_hyperparameter_ordinal_data_t *d = (_ccs_hyperparameter_ordinal_data_t *)data;
ccs_interval_t *interval = &(d->common_data.interval);
return ccs_create_uniform_distribution(interval->type,
interval->lower, interval->upper,
CCS_LINEAR, CCSI(0),
distribution);
}
static _ccs_hyperparameter_ops_t _ccs_hyperparameter_ordinal_ops = {
{ &_ccs_hyperparameter_ordinal_del },
&_ccs_hyperparameter_ordinal_samples
&_ccs_hyperparameter_ordinal_samples,
&_ccs_hyperparameter_ordinal_get_default_distribution
};
ccs_error_t
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment