Commit 3ba7aa30 authored by Brice Videau's avatar Brice Videau

Started refactoring distributions.

parent 4b101bc7
......@@ -17,8 +17,6 @@ class ccs_scale_type(CEnumeration):
ccs_distribution_get_type = _ccs_get_function("ccs_distribution_get_type", [ccs_distribution, ct.POINTER(ccs_distribution_type)])
ccs_distribution_get_data_type = _ccs_get_function("ccs_distribution_get_data_type", [ccs_distribution, ct.POINTER(ccs_numeric_type)])
ccs_distribution_get_dimension = _ccs_get_function("ccs_distribution_get_dimension", [ccs_distribution, ct.POINTER(ct.c_size_t)])
ccs_distribution_get_scale_type = _ccs_get_function("ccs_distribution_get_scale_type", [ccs_distribution, ct.POINTER(ccs_scale_type)])
ccs_distribution_get_quantization = _ccs_get_function("ccs_distribution_get_quantization", [ccs_distribution, ct.POINTER(ccs_numeric)])
ccs_distribution_get_bounds = _ccs_get_function("ccs_distribution_get_bounds", [ccs_distribution, ct.POINTER(ccs_interval)])
ccs_distribution_check_oversampling = _ccs_get_function("ccs_distribution_check_oversampling", [ccs_distribution, ct.POINTER(ccs_interval), ct.POINTER(ccs_bool)])
ccs_distribution_sample = _ccs_get_function("ccs_distribution_sample", [ccs_distribution, ccs_rng, ct.POINTER(ccs_numeric)])
......@@ -71,34 +69,6 @@ class Distribution(Object):
self._dimension = v.value
return v.value
@property
def scale_type(self):
if hasattr(self, "_scale_type"):
return self._scale_type
v = ccs_scale_type(0)
res = ccs_distribution_get_scale_type(self.handle, ct.byref(v))
Error.check(res)
self._scale_type = v.value
return self._scale_type
scale = scale_type
@property
def quantization(self):
if hasattr(self, "_quantization"):
return self._quantization
v = ccs_numeric(0)
res = ccs_distribution_get_quantization(self.handle, ct.byref(v))
Error.check(res)
t = self.data_type
if t == ccs_numeric_type.NUM_INTEGER:
self._quantization = v.i
elif t == ccs_numeric_type.NUM_FLOAT:
self._quantization = v.f
else:
raise Error(ccs_error(ccs_error.INVALID_VALUE))
return self._quantization
@property
def bounds(self):
if hasattr(self, "_bounds"):
......@@ -144,7 +114,7 @@ class Distribution(Object):
ccs_create_uniform_distribution = _ccs_get_function("ccs_create_uniform_distribution", [ccs_numeric_type, ccs_int, ccs_int, ccs_scale_type, ccs_int, ct.POINTER(ccs_distribution)])
ccs_create_uniform_int_distribution = _ccs_get_function("ccs_create_uniform_int_distribution", [ccs_int, ccs_int, ccs_scale_type, ccs_int, ct.POINTER(ccs_distribution)])
ccs_create_uniform_float_distribution = _ccs_get_function("ccs_create_uniform_float_distribution", [ccs_float, ccs_float, ccs_scale_type, ccs_float, ct.POINTER(ccs_distribution)])
ccs_uniform_distribution_get_parameters = _ccs_get_function("ccs_uniform_distribution_get_parameters", [ccs_distribution, ct.POINTER(ccs_numeric), ct.POINTER(ccs_numeric)])
ccs_uniform_distribution_get_parameters = _ccs_get_function("ccs_uniform_distribution_get_parameters", [ccs_distribution, ct.POINTER(ccs_numeric), ct.POINTER(ccs_numeric), ct.POINTER(ccs_scale_type), ct.POINTER(ccs_numeric)])
class UniformDistribution(Distribution):
def __init__(self, handle = None, retain = False, data_type = NUM_FLOAT, lower = 0.0, upper = 1.0, scale = ccs_scale_type.LINEAR, quantization = 0.0):
......@@ -174,7 +144,7 @@ class UniformDistribution(Distribution):
if hasattr(self, "_lower"):
return self._lower
v = ccs_numeric()
res = ccs_uniform_distribution_get_parameters(self.handle, ct.byref(v), None)
res = ccs_uniform_distribution_get_parameters(self.handle, ct.byref(v), None, None, None)
Error.check(res)
t = self.data_type
if t == ccs_numeric_type.NUM_INTEGER:
......@@ -190,7 +160,7 @@ class UniformDistribution(Distribution):
if hasattr(self, "_upper"):
return self._upper
v = ccs_numeric()
res = ccs_uniform_distribution_get_parameters(self.handle, None, ct.byref(v))
res = ccs_uniform_distribution_get_parameters(self.handle, None, ct.byref(v), None, None)
Error.check(res)
t = self.data_type
if t == ccs_numeric_type.NUM_INTEGER:
......@@ -201,10 +171,38 @@ class UniformDistribution(Distribution):
raise Error(ccs_error(ccs_error.INVALID_VALUE))
return self._upper
@property
def scale_type(self):
if hasattr(self, "_scale_type"):
return self._scale_type
v = ccs_scale_type(0)
res = ccs_uniform_distribution_get_parameters(self.handle, None, None, ct.byref(v), None)
Error.check(res)
self._scale_type = v.value
return self._scale_type
scale = scale_type
@property
def quantization(self):
if hasattr(self, "_quantization"):
return self._quantization
v = ccs_numeric(0)
res = ccs_uniform_distribution_get_parameters(self.handle, None, None, None, ct.byref(v))
Error.check(res)
t = self.data_type
if t == ccs_numeric_type.NUM_INTEGER:
self._quantization = v.i
elif t == ccs_numeric_type.NUM_FLOAT:
self._quantization = v.f
else:
raise Error(ccs_error(ccs_error.INVALID_VALUE))
return self._quantization
ccs_create_normal_distribution = _ccs_get_function("ccs_create_normal_distribution", [ccs_numeric_type, ccs_float, ccs_float, ccs_scale_type, ccs_int, ct.POINTER(ccs_distribution)])
ccs_create_normal_int_distribution = _ccs_get_function("ccs_create_normal_int_distribution", [ccs_float, ccs_float, ccs_scale_type, ccs_int, ct.POINTER(ccs_distribution)])
ccs_create_normal_float_distribution = _ccs_get_function("ccs_create_normal_float_distribution", [ccs_float, ccs_float, ccs_scale_type, ccs_float, ct.POINTER(ccs_distribution)])
ccs_normal_distribution_get_parameters = _ccs_get_function("ccs_normal_distribution_get_parameters", [ccs_distribution, ct.POINTER(ccs_float), ct.POINTER(ccs_float)])
ccs_normal_distribution_get_parameters = _ccs_get_function("ccs_normal_distribution_get_parameters", [ccs_distribution, ct.POINTER(ccs_float), ct.POINTER(ccs_float), ct.POINTER(ccs_scale_type), ct.POINTER(ccs_numeric)])
class NormalDistribution(Distribution):
def __init__(self, handle = None, retain = False, data_type = NUM_FLOAT, mu = 0.0, sigma = 1.0, scale = ccs_scale_type.LINEAR, quantization = 0.0):
......@@ -234,7 +232,7 @@ class NormalDistribution(Distribution):
if hasattr(self, "_mu"):
return self._mu
v = ccs_float()
res = ccs_normal_distribution_get_parameters(self.handle, ct.byref(v), None)
res = ccs_normal_distribution_get_parameters(self.handle, ct.byref(v), None, None, None)
Error.check(res)
self._mu = v.value
return self._mu
......@@ -244,11 +242,39 @@ class NormalDistribution(Distribution):
if hasattr(self, "_sigma"):
return self._sigma
v = ccs_float()
res = ccs_normal_distribution_get_parameters(self.handle, None, ct.byref(v))
res = ccs_normal_distribution_get_parameters(self.handle, None, ct.byref(v), None, None)
Error.check(res)
self._sigma = v.value
return self._sigma
@property
def scale_type(self):
if hasattr(self, "_scale_type"):
return self._scale_type
v = ccs_scale_type(0)
res = ccs_normal_distribution_get_parameters(self.handle, None, None, ct.byref(v), None)
Error.check(res)
self._scale_type = v.value
return self._scale_type
scale = scale_type
@property
def quantization(self):
if hasattr(self, "_quantization"):
return self._quantization
v = ccs_numeric(0)
res = ccs_normal_distribution_get_parameters(self.handle, None, None, None, ct.byref(v))
Error.check(res)
t = self.data_type
if t == ccs_numeric_type.NUM_INTEGER:
self._quantization = v.i
elif t == ccs_numeric_type.NUM_FLOAT:
self._quantization = v.f
else:
raise Error(ccs_error(ccs_error.INVALID_VALUE))
return self._quantization
ccs_create_roulette_distribution = _ccs_get_function("ccs_create_roulette_distribution", [ct.c_size_t, ct.POINTER(ccs_float), ct.POINTER(ccs_distribution)])
ccs_roulette_distribution_get_num_areas = _ccs_get_function("ccs_roulette_distribution_get_num_areas", [ccs_distribution, ct.POINTER(ct.c_size_t)])
ccs_roulette_distribution_get_areas = _ccs_get_function("ccs_roulette_distribution_get_areas", [ccs_distribution, ct.c_size_t, ct.POINTER(ccs_float), ct.POINTER(ct.c_size_t)])
......
......@@ -19,7 +19,6 @@ class TestDistribution(unittest.TestCase):
self.assertEqual( ccs.DISTRIBUTION, d.object_type )
self.assertEqual( ccs.ROULETTE, d.type )
self.assertEqual( ccs.NUM_INTEGER, d.data_type )
self.assertEqual( ccs.LINEAR, d.scale )
self.assertEqual( 1, d.dimension )
a = d.areas
self.assertTrue( sum(a) > 0.999 )
......
......@@ -24,8 +24,6 @@ module CCS
attach_function :ccs_distribution_get_type, [:ccs_distribution_t, :pointer], :ccs_result_t
attach_function :ccs_distribution_get_data_type, [:ccs_distribution_t, :pointer], :ccs_result_t
attach_function :ccs_distribution_get_dimension, [:ccs_distribution_t, :pointer], :ccs_result_t
attach_function :ccs_distribution_get_scale_type, [:ccs_distribution_t, :pointer], :ccs_result_t
attach_function :ccs_distribution_get_quantization, [:ccs_distribution_t, :pointer], :ccs_result_t
attach_function :ccs_distribution_get_bounds, [:ccs_distribution_t, :pointer], :ccs_result_t
attach_function :ccs_distribution_check_oversampling, [:ccs_distribution_t, Interval.by_ref, :pointer], :ccs_result_t
attach_function :ccs_distribution_sample, [:ccs_distribution_t, :ccs_rng_t, :pointer], :ccs_result_t
......@@ -35,7 +33,6 @@ module CCS
add_property :type, :ccs_distribution_type_t, :ccs_distribution_get_type, memoize: true
add_property :data_type, :ccs_numeric_type_t, :ccs_distribution_get_data_type, memoize: true
add_property :dimension, :size_t, :ccs_distribution_get_dimension, memoize: true
add_property :scale, :ccs_scale_type_t, :ccs_distribution_get_scale_type, memoize: true
def self.from_handle(handle)
ptr = MemoryPointer::new(:ccs_distribution_type_t)
......@@ -53,19 +50,6 @@ module CCS
end
end
def quantization
@quantization ||= begin
ptr = MemoryPointer::new(:ccs_numeric_t)
res = CCS.ccs_distribution_get_quantization(@handle, ptr)
CCS.error_check(res)
if data_type == :CCS_NUM_FLOAT
ptr.read_ccs_float_t
else
ptr.read_ccs_int_t
end
end
end
def bounds
@bounds ||= begin
interval = Interval::new(type: :CCS_NUM_FLOAT)
......@@ -110,7 +94,7 @@ module CCS
attach_function :ccs_create_uniform_distribution, [:ccs_numeric_type_t, :ccs_numeric_t, :ccs_numeric_t, :ccs_scale_type_t, :ccs_numeric_t, :pointer], :ccs_result_t
attach_function :ccs_create_uniform_int_distribution, [:ccs_int_t, :ccs_int_t, :ccs_scale_type_t, :ccs_int_t, :pointer], :ccs_result_t
attach_function :ccs_create_uniform_float_distribution, [:ccs_float_t, :ccs_float_t, :ccs_scale_type_t, :ccs_float_t, :pointer], :ccs_result_t
attach_function :ccs_uniform_distribution_get_parameters, [:ccs_distribution_t, :pointer, :pointer], :ccs_result_t
attach_function :ccs_uniform_distribution_get_parameters, [:ccs_distribution_t, :pointer, :pointer, :pointer, :pointer], :ccs_result_t
class UniformDistribution < Distribution
def initialize(handle = nil, retain: false, data_type: :CCS_NUM_FLOAT, lower: 0.0, upper: 1.0, scale: :CCS_LINEAR, quantization: 0.0)
......@@ -139,7 +123,7 @@ module CCS
def lower
@lower ||= begin
ptr = MemoryPointer::new(:ccs_numeric_t)
res = CCS.ccs_uniform_distribution_get_parameters(@handle, ptr, nil)
res = CCS.ccs_uniform_distribution_get_parameters(@handle, ptr, nil, nil, nil)
CCS.error_check(res)
if data_type == :CCS_NUM_FLOAT
ptr.read_ccs_float_t
......@@ -152,7 +136,29 @@ module CCS
def upper
@upper ||= begin
ptr = MemoryPointer::new(:ccs_numeric_t)
res = CCS.ccs_uniform_distribution_get_parameters(@handle, nil, ptr)
res = CCS.ccs_uniform_distribution_get_parameters(@handle, nil, ptr, nil, nil)
CCS.error_check(res)
if data_type == :CCS_NUM_FLOAT
ptr.read_ccs_float_t
else
ptr.read_ccs_int_t
end
end
end
def scale
@scale ||= begin
ptr = MemoryPointer::new(:ccs_scale_type_t)
res = CCS.ccs_uniform_distribution_get_parameters(@handle, nil, nil, ptr, nil)
CCS.error_check(res)
ptr.read_ccs_scale_type_t
end
end
def quantization
@quantization ||= begin
ptr = MemoryPointer::new(:ccs_numeric_t)
res = CCS.ccs_uniform_distribution_get_parameters(@handle, nil, nil, nil, ptr)
CCS.error_check(res)
if data_type == :CCS_NUM_FLOAT
ptr.read_ccs_float_t
......@@ -166,7 +172,7 @@ module CCS
attach_function :ccs_create_normal_distribution, [:ccs_numeric_type_t, :ccs_float_t, :ccs_float_t, :ccs_scale_type_t, :ccs_numeric_t, :pointer], :ccs_result_t
attach_function :ccs_create_normal_int_distribution, [:ccs_float_t, :ccs_float_t, :ccs_scale_type_t, :ccs_int_t, :pointer], :ccs_result_t
attach_function :ccs_create_normal_float_distribution, [:ccs_float_t, :ccs_float_t, :ccs_scale_type_t, :ccs_float_t, :pointer], :ccs_result_t
attach_function :ccs_normal_distribution_get_parameters, [:ccs_distribution_t, :pointer, :pointer], :ccs_result_t
attach_function :ccs_normal_distribution_get_parameters, [:ccs_distribution_t, :pointer, :pointer, :pointer, :pointer], :ccs_result_t
class NormalDistribution < Distribution
def initialize(handle = nil, retain: false, data_type: :CCS_NUM_FLOAT, mu: 0.0, sigma: 1.0, scale: :CCS_LINEAR, quantization: 0.0)
if handle
......@@ -194,7 +200,7 @@ module CCS
def mu
@mu ||= begin
ptr = MemoryPointer::new(:ccs_numeric_t)
res = CCS.ccs_normal_distribution_get_parameters(@handle, ptr, nil)
res = CCS.ccs_normal_distribution_get_parameters(@handle, ptr, nil, nil, nil)
CCS.error_check(res)
ptr.read_ccs_float_t
end
......@@ -203,11 +209,33 @@ module CCS
def sigma
@sigma ||= begin
ptr = MemoryPointer::new(:ccs_numeric_t)
res = CCS.ccs_normal_distribution_get_parameters(@handle, nil, ptr)
res = CCS.ccs_normal_distribution_get_parameters(@handle, nil, ptr, nil, nil)
CCS.error_check(res)
ptr.read_ccs_float_t
end
end
def scale
@scale ||= begin
ptr = MemoryPointer::new(:ccs_scale_type_t)
res = CCS.ccs_normal_distribution_get_parameters(@handle, nil, nil, ptr, nil)
CCS.error_check(res)
ptr.read_ccs_scale_type_t
end
end
def quantization
@quantization ||= begin
ptr = MemoryPointer::new(:ccs_numeric_t)
res = CCS.ccs_normal_distribution_get_parameters(@handle, nil, nil, nil, ptr)
CCS.error_check(res)
if data_type == :CCS_NUM_FLOAT
ptr.read_ccs_float_t
else
ptr.read_ccs_int_t
end
end
end
end
attach_function :ccs_create_roulette_distribution, [:size_t, :pointer, :pointer], :ccs_result_t
......
......@@ -21,7 +21,6 @@ class CConfigSpaceTestDistribution < Minitest::Test
assert_equal( :CCS_DISTRIBUTION, d.object_type )
assert_equal( :CCS_ROULETTE, d.type )
assert_equal( :CCS_NUM_INTEGER, d.data_type )
assert_equal( :CCS_LINEAR, d.scale )
assert_equal( 1, d.dimension )
assert_equal( 4, d.num_areas )
assert( d.areas.reduce(:+) > 0.999 )
......
......@@ -74,26 +74,29 @@ ccs_create_roulette_distribution(size_t num_areas,
ccs_float_t *areas,
ccs_distribution_t *distribution_ret);
extern ccs_result_t
ccs_create_mixture_distribution(size_t num_distributions,
ccs_distribution_t *distributions,
ccs_float_t *weights,
ccs_distribution_t *distribution_ret);
extern ccs_result_t
ccs_create_multivariate_distribution(size_t num_distributions,
ccs_distribution_t *distributions,
ccs_distribution_t *distribution_ret);
// Accessors
extern ccs_result_t
ccs_distribution_get_type(ccs_distribution_t distribution,
ccs_distribution_type_t *type_ret);
extern ccs_result_t
ccs_distribution_get_data_type(ccs_distribution_t distribution,
ccs_numeric_type_t *data_type_ret);
extern ccs_result_t
ccs_distribution_get_dimension(ccs_distribution_t distribution,
size_t *dimension);
extern ccs_result_t
ccs_distribution_get_scale_type(ccs_distribution_t distribution,
ccs_scale_type_t *scale_type_ret);
extern ccs_result_t
ccs_distribution_get_quantization(ccs_distribution_t distribution,
ccs_numeric_t *quantization);
ccs_distribution_get_data_type(ccs_distribution_t distribution,
ccs_numeric_type_t *data_type_ret);
extern ccs_result_t
ccs_distribution_get_bounds(ccs_distribution_t distribution,
......@@ -107,12 +110,16 @@ ccs_distribution_check_oversampling(ccs_distribution_t distribution,
extern ccs_result_t
ccs_normal_distribution_get_parameters(ccs_distribution_t distribution,
ccs_float_t *mu_ret,
ccs_float_t *sigma_ret);
ccs_float_t *sigma_ret,
ccs_scale_type_t *scale_ret,
ccs_numeric_t *quantization_ret);
extern ccs_result_t
ccs_uniform_distribution_get_parameters(ccs_distribution_t distribution,
ccs_numeric_t *lower_ret,
ccs_numeric_t *upper_ret);
ccs_numeric_t *upper_ret,
ccs_scale_type_t *scale_ret,
ccs_numeric_t *quantization_ret);
extern ccs_result_t
ccs_roulette_distribution_get_num_areas(ccs_distribution_t distribution,
......
......@@ -30,26 +30,7 @@ ccs_distribution_get_dimension(ccs_distribution_t distribution,
size_t *dimension_ret) {
CCS_CHECK_OBJ(distribution, CCS_DISTRIBUTION);
CCS_CHECK_PTR(dimension_ret);
*dimension_ret = 1;
return CCS_SUCCESS;
}
ccs_result_t
ccs_distribution_get_scale_type(ccs_distribution_t distribution,
ccs_scale_type_t *scale_type_ret) {
CCS_CHECK_OBJ(distribution, CCS_DISTRIBUTION);
CCS_CHECK_PTR(scale_type_ret);
*scale_type_ret = ((_ccs_distribution_common_data_t *)(distribution->data))->scale_type;
return CCS_SUCCESS;
}
ccs_result_t
ccs_distribution_get_quantization(ccs_distribution_t distribution,
ccs_numeric_t *quantization_ret) {
CCS_CHECK_OBJ(distribution, CCS_DISTRIBUTION);
CCS_CHECK_PTR(quantization_ret);
*quantization_ret = ((_ccs_distribution_common_data_t *)(distribution->data))->quantization;
*dimension_ret = ((_ccs_distribution_common_data_t *)(distribution->data))->dimension;
return CCS_SUCCESS;
}
......
......@@ -33,10 +33,9 @@ struct _ccs_distribution_s {
};
struct _ccs_distribution_common_data_s {
ccs_distribution_type_t type;
ccs_distribution_type_t type;
size_t dimension;
ccs_numeric_type_t data_type;
ccs_scale_type_t scale_type;
ccs_numeric_t quantization;
};
typedef struct _ccs_distribution_common_data_s _ccs_distribution_common_data_t;
#endif //_DISTRIBUTION_INTERNAL_H
......@@ -8,6 +8,8 @@ struct _ccs_distribution_normal_data_s {
_ccs_distribution_common_data_t common_data;
ccs_float_t mu;
ccs_float_t sigma;
ccs_scale_type_t scale_type;
ccs_numeric_t quantization;
int quantize;
};
typedef struct _ccs_distribution_normal_data_s _ccs_distribution_normal_data_t;
......@@ -47,8 +49,8 @@ _ccs_distribution_normal_get_bounds(_ccs_distribution_data_t *data,
ccs_interval_t *interval_ret) {
_ccs_distribution_normal_data_t *d = (_ccs_distribution_normal_data_t *)data;
const ccs_numeric_type_t data_type = d->common_data.data_type;
const ccs_scale_type_t scale_type = d->common_data.scale_type;
const ccs_numeric_t quantization = d->common_data.quantization;
const ccs_scale_type_t scale_type = d->scale_type;
const ccs_numeric_t quantization = d->quantization;
const int quantize = d->quantize;
ccs_numeric_t l;
ccs_bool_t li;
......@@ -193,8 +195,8 @@ _ccs_distribution_normal_samples(_ccs_distribution_data_t *data,
ccs_numeric_t *values) {
_ccs_distribution_normal_data_t *d = (_ccs_distribution_normal_data_t *)data;
const ccs_numeric_type_t data_type = d->common_data.data_type;
const ccs_scale_type_t scale_type = d->common_data.scale_type;
const ccs_numeric_t quantization = d->common_data.quantization;
const ccs_scale_type_t scale_type = d->scale_type;
const ccs_numeric_t quantization = d->quantization;
const ccs_float_t mu = d->mu;
const ccs_float_t sigma = d->sigma;
const int quantize = d->quantize;
......@@ -308,8 +310,8 @@ _ccs_distribution_normal_strided_samples(_ccs_distribution_data_t *data,
ccs_numeric_t *values) {
_ccs_distribution_normal_data_t *d = (_ccs_distribution_normal_data_t *)data;
const ccs_numeric_type_t data_type = d->common_data.data_type;
const ccs_scale_type_t scale_type = d->common_data.scale_type;
const ccs_numeric_t quantization = d->common_data.quantization;
const ccs_scale_type_t scale_type = d->scale_type;
const ccs_numeric_t quantization = d->quantization;
const ccs_float_t mu = d->mu;
const ccs_float_t sigma = d->sigma;
const int quantize = d->quantize;
......@@ -354,9 +356,10 @@ ccs_create_normal_distribution(ccs_numeric_type_t data_type,
_ccs_object_init(&(distrib->obj), CCS_DISTRIBUTION, (_ccs_object_ops_t *)&_ccs_distribution_normal_ops);
_ccs_distribution_normal_data_t * distrib_data = (_ccs_distribution_normal_data_t *)(mem + sizeof(struct _ccs_distribution_s));
distrib_data->common_data.type = CCS_NORMAL;
distrib_data->common_data.dimension = 1;
distrib_data->common_data.data_type = data_type;
distrib_data->common_data.scale_type = scale_type;
distrib_data->common_data.quantization = quantization;
distrib_data->scale_type = scale_type;
distrib_data->quantization = quantization;
distrib_data->mu = mu;
distrib_data->sigma = sigma;
if (data_type == CCS_NUM_FLOAT) {
......@@ -373,21 +376,25 @@ ccs_create_normal_distribution(ccs_numeric_type_t data_type,
extern ccs_result_t
ccs_normal_distribution_get_parameters(ccs_distribution_t distribution,
ccs_float_t *mu,
ccs_float_t *sigma) {
ccs_float_t *mu_ret,
ccs_float_t *sigma_ret,
ccs_scale_type_t *scale_type_ret,
ccs_numeric_t *quantization_ret) {
CCS_CHECK_OBJ(distribution, CCS_DISTRIBUTION);
if (((_ccs_distribution_common_data_t*)distribution->data)->type != CCS_NORMAL)
return -CCS_INVALID_OBJECT;
if (!mu && !sigma)
if (!mu_ret && !sigma_ret && !scale_type_ret && !quantization_ret)
return -CCS_INVALID_VALUE;
_ccs_distribution_normal_data_t * data = (_ccs_distribution_normal_data_t *)distribution->data;
if (mu) {
*mu = data->mu;
}
if (sigma) {
*sigma = data->sigma;
}
if (mu_ret)
*mu_ret = data->mu;
if (sigma_ret)
*sigma_ret = data->sigma;
if (scale_type_ret)
*scale_type_ret = data->scale_type;
if (quantization_ret)
*quantization_ret = data->quantization;
return CCS_SUCCESS;
}
......@@ -151,9 +151,8 @@ ccs_create_roulette_distribution(size_t num_areas,
_ccs_object_init(&(distrib->obj), CCS_DISTRIBUTION, (_ccs_object_ops_t *)&_ccs_distribution_roulette_ops);
_ccs_distribution_roulette_data_t * distrib_data = (_ccs_distribution_roulette_data_t *)(mem + sizeof(struct _ccs_distribution_s));
distrib_data->common_data.type = CCS_ROULETTE;
distrib_data->common_data.dimension = 1;
distrib_data->common_data.data_type = CCS_NUM_INTEGER;
distrib_data->common_data.scale_type = CCS_LINEAR;
distrib_data->common_data.quantization = CCSI(0);
distrib_data->num_areas = num_areas;
distrib_data->areas = (ccs_float_t *)(mem + sizeof(struct _ccs_distribution_s) + sizeof(_ccs_distribution_roulette_data_t));
......
......@@ -8,6 +8,8 @@ struct _ccs_distribution_uniform_data_s {
_ccs_distribution_common_data_t common_data;
ccs_numeric_t lower;
ccs_numeric_t upper;
ccs_scale_type_t scale_type;
ccs_numeric_t quantization;
ccs_numeric_type_t internal_type;
ccs_numeric_t internal_lower;
ccs_numeric_t internal_upper;
......@@ -76,8 +78,8 @@ _ccs_distribution_uniform_strided_samples(_ccs_distribution_data_t *data,
_ccs_distribution_uniform_data_t *d = (_ccs_distribution_uniform_data_t *)data;
size_t i;
const ccs_numeric_type_t data_type = d->common_data.data_type;
const ccs_scale_type_t scale_type = d->common_data.scale_type;
const ccs_numeric_t quantization = d->common_data.quantization;
const ccs_scale_type_t scale_type = d->scale_type;
const ccs_numeric_t quantization = d->quantization;
const ccs_numeric_t lower = d->lower;
const ccs_numeric_t internal_lower = d->internal_lower;
const ccs_numeric_t internal_upper = d->internal_upper;
......@@ -135,8 +137,8 @@ _ccs_distribution_uniform_samples(_ccs_distribution_data_t *data,
_ccs_distribution_uniform_data_t *d = (_ccs_distribution_uniform_data_t *)data;
size_t i;
const ccs_numeric_type_t data_type = d->common_data.data_type;
const ccs_scale_type_t scale_type = d->common_data.scale_type;
const ccs_numeric_t quantization = d->common_data.quantization;
const ccs_scale_type_t scale_type = d->scale_type;
const ccs_numeric_t quantization = d->quantization;
const ccs_numeric_t lower = d->lower;
const ccs_numeric_t internal_lower = d->internal_lower;
const ccs_numeric_t internal_upper = d->internal_upper;
......@@ -218,9 +220,10 @@ ccs_create_uniform_distribution(ccs_numeric_type_t data_type,
_ccs_object_init(&(distrib->obj), CCS_DISTRIBUTION, (_ccs_object_ops_t *)&_ccs_distribution_uniform_ops);
_ccs_distribution_uniform_data_t * distrib_data = (_ccs_distribution_uniform_data_t *)(mem + sizeof(struct _ccs_distribution_s));
distrib_data->common_data.type = CCS_UNIFORM;
distrib_data->common_data.dimension = 1;
distrib_data->common_data.data_type = data_type;
distrib_data->common_data.scale_type = scale_type;
distrib_data->common_data.quantization = quantization;
distrib_data->scale_type = scale_type;
distrib_data->quantization = quantization;
distrib_data->lower = lower;
distrib_data->upper = upper;
......@@ -257,20 +260,24 @@ ccs_create_uniform_distribution(ccs_numeric_type_t data_type,
ccs_result_t
ccs_uniform_distribution_get_parameters(ccs_distribution_t distribution,
ccs_numeric_t *lower_ret,
ccs_numeric_t *upper_ret) {
ccs_numeric_t *upper_ret,
ccs_scale_type_t *scale_type_ret,
ccs_numeric_t *quantization_ret) {
CCS_CHECK_OBJ(distribution, CCS_DISTRIBUTION);
if (((_ccs_distribution_common_data_t*)distribution->data)->type != CCS_UNIFORM)
return -CCS_INVALID_OBJECT;
if (!lower_ret && !upper_ret)
if (!lower_ret && !upper_ret && !scale_type_ret && !quantization_ret)
return -CCS_INVALID_VALUE;
_ccs_distribution_uniform_data_t * data = (_ccs_distribution_uniform_data_t *)distribution->data;
if (lower_ret) {
if (lower_ret)
*lower_ret = data->lower;
}
if (upper_ret) {
if (upper_ret)
*upper_ret = data->upper;
}
if (scale_type_ret)
*scale_type_ret = data->scale_type;
if (quantization_ret)
*quantization_ret = data->quantization;
return CCS_SUCCESS;
}
......
......@@ -39,14 +39,6 @@ static void test_create_normal_distribution() {
assert( err == CCS_SUCCESS );
assert( data_type == CCS_NUM_FLOAT );
err = ccs_distribution_get_scale_type(distrib, &stype);
assert( err == CCS_SUCCESS );
assert( stype == CCS_LINEAR );
err = ccs_distribution_get_quantization(distrib, &quantization);
assert( err == CCS_SUCCESS );
assert( quantization.f == 0.0 );
err = ccs_distribution_get_bounds(distrib, &interval);
assert( err == CCS_SUCCESS );
assert( interval.type == CCS_NUM_FLOAT );
......@@ -55,10 +47,12 @@ static void test_create_normal_distribution() {
assert( interval.upper.f == CCS_INFINITY );
assert( interval.upper_included == CCS_FALSE );
err = ccs_normal_distribution_get_parameters(distrib, &mu, &sigma);
err = ccs_normal_distribution_get_parameters(distrib, &mu, &sigma, &stype, &quantization);
assert( err == CCS_SUCCESS );
assert( mu == 1.0 );
assert( sigma == 2.0 );
assert( stype == CCS_LINEAR );
assert( quantization.f == 0.0 );
err = ccs_object_get_refcount(distrib, &refcount);
assert( err == CCS_SUCCESS );
......
......@@ -12,9 +12,7 @@ void test_create_roulette_distribution() {
int32_t refcount;
ccs_object_type_t otype;
ccs_distribution_type_t dtype;
ccs_scale_type_t stype;
ccs_numeric_type_t data_type;
ccs_numeric_t quantization;
ccs_interval_t interval;
const size_t num_areas = 4;
ccs_float_t areas[num_areas];
......@@ -44,14 +42,6 @@ void test_create_roulette_distribution() {
assert( err == CCS_SUCCESS );
assert( data_type == CCS_NUM_INTEGER );
err = ccs_distribution_get_scale_type(distrib, &stype);
assert( err == CCS_SUCCESS );
assert( stype == CCS_LINEAR );
err = ccs_distribution_get_quantization(distrib, &quantization);
assert( err == CCS_SUCCESS );
assert( quantization.i == 0 );
err = ccs_distribution_get_bounds(distrib, &interval);
assert( err == CCS_SUCCESS );
assert( interval.type == CCS_NUM_INTEGER );
......
......@@ -37,14 +37,6 @@ static void test_create_uniform_distribution() {
assert( err == CCS_SUCCESS );
assert( data_type == CCS_NUM_INTEGER );
err = ccs_distribution_get_scale_type(distrib, &stype);
assert( err == CCS_SUCCESS );
assert( stype == CCS_LINEAR );
err = ccs_distribution_get_quantization(distrib, &quantization);
assert( err == CCS_SUCCESS );
assert( quantization.i == q );
err = ccs_distribution_get_bounds(distrib, &interval);
assert( err == CCS_SUCCESS );
assert( interval.type == CCS_NUM_INTEGER );
......@@ -53,10 +45,12 @@ static void test_create_uniform_distribution() {
assert( interval.upper.i == u );
assert( interval.upper_included == CCS_FALSE );
err = ccs_uniform_distribution_get_parameters(distrib, &lower, &upper);
err = ccs_uniform_distribution_get_parameters(distrib, &lower, &upper, &stype, &quantization);
assert( err == CCS_SUCCESS );
assert( lower.i == l );
assert( upper.i == u );
assert( stype == CCS_LINEAR );
assert( quantization.i == q );
err = ccs_object_get_refcount(distrib, &refcount);
assert( err == CCS_SUCCESS );
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment