From 3ba7aa3035231310e7144929bb7ac08a5c123c8f Mon Sep 17 00:00:00 2001 From: Brice Videau Date: Mon, 10 Aug 2020 14:13:20 -0500 Subject: [PATCH] Started refactoring distributions. --- bindings/python/cconfigspace/distribution.py | 98 ++++++++++++------- bindings/python/test/test_distribution.py | 1 - .../ruby/lib/cconfigspace/distribution.rb | 72 +++++++++----- bindings/ruby/test/test_distribution.rb | 1 - include/cconfigspace/distribution.h | 31 +++--- src/distribution.c | 21 +--- src/distribution_internal.h | 5 +- src/distribution_normal.c | 41 ++++---- src/distribution_roulette.c | 3 +- src/distribution_uniform.c | 31 +++--- tests/test_normal_distribution.c | 12 +-- tests/test_roulette_distribution.c | 10 -- tests/test_uniform_distribution.c | 12 +-- 13 files changed, 184 insertions(+), 154 deletions(-) diff --git a/bindings/python/cconfigspace/distribution.py b/bindings/python/cconfigspace/distribution.py index b24791e..5adfc54 100644 --- a/bindings/python/cconfigspace/distribution.py +++ b/bindings/python/cconfigspace/distribution.py @@ -17,8 +17,6 @@ class ccs_scale_type(CEnumeration): ccs_distribution_get_type = _ccs_get_function("ccs_distribution_get_type", [ccs_distribution, ct.POINTER(ccs_distribution_type)]) ccs_distribution_get_data_type = _ccs_get_function("ccs_distribution_get_data_type", [ccs_distribution, ct.POINTER(ccs_numeric_type)]) ccs_distribution_get_dimension = _ccs_get_function("ccs_distribution_get_dimension", [ccs_distribution, ct.POINTER(ct.c_size_t)]) -ccs_distribution_get_scale_type = _ccs_get_function("ccs_distribution_get_scale_type", [ccs_distribution, ct.POINTER(ccs_scale_type)]) -ccs_distribution_get_quantization = _ccs_get_function("ccs_distribution_get_quantization", [ccs_distribution, ct.POINTER(ccs_numeric)]) ccs_distribution_get_bounds = _ccs_get_function("ccs_distribution_get_bounds", [ccs_distribution, ct.POINTER(ccs_interval)]) ccs_distribution_check_oversampling = _ccs_get_function("ccs_distribution_check_oversampling", [ccs_distribution, ct.POINTER(ccs_interval), ct.POINTER(ccs_bool)]) ccs_distribution_sample = _ccs_get_function("ccs_distribution_sample", [ccs_distribution, ccs_rng, ct.POINTER(ccs_numeric)]) @@ -71,34 +69,6 @@ class Distribution(Object): self._dimension = v.value return v.value - @property - def scale_type(self): - if hasattr(self, "_scale_type"): - return self._scale_type - v = ccs_scale_type(0) - res = ccs_distribution_get_scale_type(self.handle, ct.byref(v)) - Error.check(res) - self._scale_type = v.value - return self._scale_type - - scale = scale_type - - @property - def quantization(self): - if hasattr(self, "_quantization"): - return self._quantization - v = ccs_numeric(0) - res = ccs_distribution_get_quantization(self.handle, ct.byref(v)) - Error.check(res) - t = self.data_type - if t == ccs_numeric_type.NUM_INTEGER: - self._quantization = v.i - elif t == ccs_numeric_type.NUM_FLOAT: - self._quantization = v.f - else: - raise Error(ccs_error(ccs_error.INVALID_VALUE)) - return self._quantization - @property def bounds(self): if hasattr(self, "_bounds"): @@ -144,7 +114,7 @@ class Distribution(Object): ccs_create_uniform_distribution = _ccs_get_function("ccs_create_uniform_distribution", [ccs_numeric_type, ccs_int, ccs_int, ccs_scale_type, ccs_int, ct.POINTER(ccs_distribution)]) ccs_create_uniform_int_distribution = _ccs_get_function("ccs_create_uniform_int_distribution", [ccs_int, ccs_int, ccs_scale_type, ccs_int, ct.POINTER(ccs_distribution)]) ccs_create_uniform_float_distribution = _ccs_get_function("ccs_create_uniform_float_distribution", [ccs_float, ccs_float, ccs_scale_type, ccs_float, ct.POINTER(ccs_distribution)]) -ccs_uniform_distribution_get_parameters = _ccs_get_function("ccs_uniform_distribution_get_parameters", [ccs_distribution, ct.POINTER(ccs_numeric), ct.POINTER(ccs_numeric)]) +ccs_uniform_distribution_get_parameters = _ccs_get_function("ccs_uniform_distribution_get_parameters", [ccs_distribution, ct.POINTER(ccs_numeric), ct.POINTER(ccs_numeric), ct.POINTER(ccs_scale_type), ct.POINTER(ccs_numeric)]) class UniformDistribution(Distribution): def __init__(self, handle = None, retain = False, data_type = NUM_FLOAT, lower = 0.0, upper = 1.0, scale = ccs_scale_type.LINEAR, quantization = 0.0): @@ -174,7 +144,7 @@ class UniformDistribution(Distribution): if hasattr(self, "_lower"): return self._lower v = ccs_numeric() - res = ccs_uniform_distribution_get_parameters(self.handle, ct.byref(v), None) + res = ccs_uniform_distribution_get_parameters(self.handle, ct.byref(v), None, None, None) Error.check(res) t = self.data_type if t == ccs_numeric_type.NUM_INTEGER: @@ -190,7 +160,7 @@ class UniformDistribution(Distribution): if hasattr(self, "_upper"): return self._upper v = ccs_numeric() - res = ccs_uniform_distribution_get_parameters(self.handle, None, ct.byref(v)) + res = ccs_uniform_distribution_get_parameters(self.handle, None, ct.byref(v), None, None) Error.check(res) t = self.data_type if t == ccs_numeric_type.NUM_INTEGER: @@ -201,10 +171,38 @@ class UniformDistribution(Distribution): raise Error(ccs_error(ccs_error.INVALID_VALUE)) return self._upper + @property + def scale_type(self): + if hasattr(self, "_scale_type"): + return self._scale_type + v = ccs_scale_type(0) + res = ccs_uniform_distribution_get_parameters(self.handle, None, None, ct.byref(v), None) + Error.check(res) + self._scale_type = v.value + return self._scale_type + + scale = scale_type + + @property + def quantization(self): + if hasattr(self, "_quantization"): + return self._quantization + v = ccs_numeric(0) + res = ccs_uniform_distribution_get_parameters(self.handle, None, None, None, ct.byref(v)) + Error.check(res) + t = self.data_type + if t == ccs_numeric_type.NUM_INTEGER: + self._quantization = v.i + elif t == ccs_numeric_type.NUM_FLOAT: + self._quantization = v.f + else: + raise Error(ccs_error(ccs_error.INVALID_VALUE)) + return self._quantization + ccs_create_normal_distribution = _ccs_get_function("ccs_create_normal_distribution", [ccs_numeric_type, ccs_float, ccs_float, ccs_scale_type, ccs_int, ct.POINTER(ccs_distribution)]) ccs_create_normal_int_distribution = _ccs_get_function("ccs_create_normal_int_distribution", [ccs_float, ccs_float, ccs_scale_type, ccs_int, ct.POINTER(ccs_distribution)]) ccs_create_normal_float_distribution = _ccs_get_function("ccs_create_normal_float_distribution", [ccs_float, ccs_float, ccs_scale_type, ccs_float, ct.POINTER(ccs_distribution)]) -ccs_normal_distribution_get_parameters = _ccs_get_function("ccs_normal_distribution_get_parameters", [ccs_distribution, ct.POINTER(ccs_float), ct.POINTER(ccs_float)]) +ccs_normal_distribution_get_parameters = _ccs_get_function("ccs_normal_distribution_get_parameters", [ccs_distribution, ct.POINTER(ccs_float), ct.POINTER(ccs_float), ct.POINTER(ccs_scale_type), ct.POINTER(ccs_numeric)]) class NormalDistribution(Distribution): def __init__(self, handle = None, retain = False, data_type = NUM_FLOAT, mu = 0.0, sigma = 1.0, scale = ccs_scale_type.LINEAR, quantization = 0.0): @@ -234,7 +232,7 @@ class NormalDistribution(Distribution): if hasattr(self, "_mu"): return self._mu v = ccs_float() - res = ccs_normal_distribution_get_parameters(self.handle, ct.byref(v), None) + res = ccs_normal_distribution_get_parameters(self.handle, ct.byref(v), None, None, None) Error.check(res) self._mu = v.value return self._mu @@ -244,11 +242,39 @@ class NormalDistribution(Distribution): if hasattr(self, "_sigma"): return self._sigma v = ccs_float() - res = ccs_normal_distribution_get_parameters(self.handle, None, ct.byref(v)) + res = ccs_normal_distribution_get_parameters(self.handle, None, ct.byref(v), None, None) Error.check(res) self._sigma = v.value return self._sigma + @property + def scale_type(self): + if hasattr(self, "_scale_type"): + return self._scale_type + v = ccs_scale_type(0) + res = ccs_normal_distribution_get_parameters(self.handle, None, None, ct.byref(v), None) + Error.check(res) + self._scale_type = v.value + return self._scale_type + + scale = scale_type + + @property + def quantization(self): + if hasattr(self, "_quantization"): + return self._quantization + v = ccs_numeric(0) + res = ccs_normal_distribution_get_parameters(self.handle, None, None, None, ct.byref(v)) + Error.check(res) + t = self.data_type + if t == ccs_numeric_type.NUM_INTEGER: + self._quantization = v.i + elif t == ccs_numeric_type.NUM_FLOAT: + self._quantization = v.f + else: + raise Error(ccs_error(ccs_error.INVALID_VALUE)) + return self._quantization + ccs_create_roulette_distribution = _ccs_get_function("ccs_create_roulette_distribution", [ct.c_size_t, ct.POINTER(ccs_float), ct.POINTER(ccs_distribution)]) ccs_roulette_distribution_get_num_areas = _ccs_get_function("ccs_roulette_distribution_get_num_areas", [ccs_distribution, ct.POINTER(ct.c_size_t)]) ccs_roulette_distribution_get_areas = _ccs_get_function("ccs_roulette_distribution_get_areas", [ccs_distribution, ct.c_size_t, ct.POINTER(ccs_float), ct.POINTER(ct.c_size_t)]) diff --git a/bindings/python/test/test_distribution.py b/bindings/python/test/test_distribution.py index 877efaa..58ecc40 100644 --- a/bindings/python/test/test_distribution.py +++ b/bindings/python/test/test_distribution.py @@ -19,7 +19,6 @@ class TestDistribution(unittest.TestCase): self.assertEqual( ccs.DISTRIBUTION, d.object_type ) self.assertEqual( ccs.ROULETTE, d.type ) self.assertEqual( ccs.NUM_INTEGER, d.data_type ) - self.assertEqual( ccs.LINEAR, d.scale ) self.assertEqual( 1, d.dimension ) a = d.areas self.assertTrue( sum(a) > 0.999 ) diff --git a/bindings/ruby/lib/cconfigspace/distribution.rb b/bindings/ruby/lib/cconfigspace/distribution.rb index a81415a..afdb847 100644 --- a/bindings/ruby/lib/cconfigspace/distribution.rb +++ b/bindings/ruby/lib/cconfigspace/distribution.rb @@ -24,8 +24,6 @@ module CCS attach_function :ccs_distribution_get_type, [:ccs_distribution_t, :pointer], :ccs_result_t attach_function :ccs_distribution_get_data_type, [:ccs_distribution_t, :pointer], :ccs_result_t attach_function :ccs_distribution_get_dimension, [:ccs_distribution_t, :pointer], :ccs_result_t - attach_function :ccs_distribution_get_scale_type, [:ccs_distribution_t, :pointer], :ccs_result_t - attach_function :ccs_distribution_get_quantization, [:ccs_distribution_t, :pointer], :ccs_result_t attach_function :ccs_distribution_get_bounds, [:ccs_distribution_t, :pointer], :ccs_result_t attach_function :ccs_distribution_check_oversampling, [:ccs_distribution_t, Interval.by_ref, :pointer], :ccs_result_t attach_function :ccs_distribution_sample, [:ccs_distribution_t, :ccs_rng_t, :pointer], :ccs_result_t @@ -35,7 +33,6 @@ module CCS add_property :type, :ccs_distribution_type_t, :ccs_distribution_get_type, memoize: true add_property :data_type, :ccs_numeric_type_t, :ccs_distribution_get_data_type, memoize: true add_property :dimension, :size_t, :ccs_distribution_get_dimension, memoize: true - add_property :scale, :ccs_scale_type_t, :ccs_distribution_get_scale_type, memoize: true def self.from_handle(handle) ptr = MemoryPointer::new(:ccs_distribution_type_t) @@ -53,19 +50,6 @@ module CCS end end - def quantization - @quantization ||= begin - ptr = MemoryPointer::new(:ccs_numeric_t) - res = CCS.ccs_distribution_get_quantization(@handle, ptr) - CCS.error_check(res) - if data_type == :CCS_NUM_FLOAT - ptr.read_ccs_float_t - else - ptr.read_ccs_int_t - end - end - end - def bounds @bounds ||= begin interval = Interval::new(type: :CCS_NUM_FLOAT) @@ -110,7 +94,7 @@ module CCS attach_function :ccs_create_uniform_distribution, [:ccs_numeric_type_t, :ccs_numeric_t, :ccs_numeric_t, :ccs_scale_type_t, :ccs_numeric_t, :pointer], :ccs_result_t attach_function :ccs_create_uniform_int_distribution, [:ccs_int_t, :ccs_int_t, :ccs_scale_type_t, :ccs_int_t, :pointer], :ccs_result_t attach_function :ccs_create_uniform_float_distribution, [:ccs_float_t, :ccs_float_t, :ccs_scale_type_t, :ccs_float_t, :pointer], :ccs_result_t - attach_function :ccs_uniform_distribution_get_parameters, [:ccs_distribution_t, :pointer, :pointer], :ccs_result_t + attach_function :ccs_uniform_distribution_get_parameters, [:ccs_distribution_t, :pointer, :pointer, :pointer, :pointer], :ccs_result_t class UniformDistribution < Distribution def initialize(handle = nil, retain: false, data_type: :CCS_NUM_FLOAT, lower: 0.0, upper: 1.0, scale: :CCS_LINEAR, quantization: 0.0) @@ -139,7 +123,7 @@ module CCS def lower @lower ||= begin ptr = MemoryPointer::new(:ccs_numeric_t) - res = CCS.ccs_uniform_distribution_get_parameters(@handle, ptr, nil) + res = CCS.ccs_uniform_distribution_get_parameters(@handle, ptr, nil, nil, nil) CCS.error_check(res) if data_type == :CCS_NUM_FLOAT ptr.read_ccs_float_t @@ -152,7 +136,29 @@ module CCS def upper @upper ||= begin ptr = MemoryPointer::new(:ccs_numeric_t) - res = CCS.ccs_uniform_distribution_get_parameters(@handle, nil, ptr) + res = CCS.ccs_uniform_distribution_get_parameters(@handle, nil, ptr, nil, nil) + CCS.error_check(res) + if data_type == :CCS_NUM_FLOAT + ptr.read_ccs_float_t + else + ptr.read_ccs_int_t + end + end + end + + def scale + @scale ||= begin + ptr = MemoryPointer::new(:ccs_scale_type_t) + res = CCS.ccs_uniform_distribution_get_parameters(@handle, nil, nil, ptr, nil) + CCS.error_check(res) + ptr.read_ccs_scale_type_t + end + end + + def quantization + @quantization ||= begin + ptr = MemoryPointer::new(:ccs_numeric_t) + res = CCS.ccs_uniform_distribution_get_parameters(@handle, nil, nil, nil, ptr) CCS.error_check(res) if data_type == :CCS_NUM_FLOAT ptr.read_ccs_float_t @@ -166,7 +172,7 @@ module CCS attach_function :ccs_create_normal_distribution, [:ccs_numeric_type_t, :ccs_float_t, :ccs_float_t, :ccs_scale_type_t, :ccs_numeric_t, :pointer], :ccs_result_t attach_function :ccs_create_normal_int_distribution, [:ccs_float_t, :ccs_float_t, :ccs_scale_type_t, :ccs_int_t, :pointer], :ccs_result_t attach_function :ccs_create_normal_float_distribution, [:ccs_float_t, :ccs_float_t, :ccs_scale_type_t, :ccs_float_t, :pointer], :ccs_result_t - attach_function :ccs_normal_distribution_get_parameters, [:ccs_distribution_t, :pointer, :pointer], :ccs_result_t + attach_function :ccs_normal_distribution_get_parameters, [:ccs_distribution_t, :pointer, :pointer, :pointer, :pointer], :ccs_result_t class NormalDistribution < Distribution def initialize(handle = nil, retain: false, data_type: :CCS_NUM_FLOAT, mu: 0.0, sigma: 1.0, scale: :CCS_LINEAR, quantization: 0.0) if handle @@ -194,7 +200,7 @@ module CCS def mu @mu ||= begin ptr = MemoryPointer::new(:ccs_numeric_t) - res = CCS.ccs_normal_distribution_get_parameters(@handle, ptr, nil) + res = CCS.ccs_normal_distribution_get_parameters(@handle, ptr, nil, nil, nil) CCS.error_check(res) ptr.read_ccs_float_t end @@ -203,11 +209,33 @@ module CCS def sigma @sigma ||= begin ptr = MemoryPointer::new(:ccs_numeric_t) - res = CCS.ccs_normal_distribution_get_parameters(@handle, nil, ptr) + res = CCS.ccs_normal_distribution_get_parameters(@handle, nil, ptr, nil, nil) CCS.error_check(res) ptr.read_ccs_float_t end end + + def scale + @scale ||= begin + ptr = MemoryPointer::new(:ccs_scale_type_t) + res = CCS.ccs_normal_distribution_get_parameters(@handle, nil, nil, ptr, nil) + CCS.error_check(res) + ptr.read_ccs_scale_type_t + end + end + + def quantization + @quantization ||= begin + ptr = MemoryPointer::new(:ccs_numeric_t) + res = CCS.ccs_normal_distribution_get_parameters(@handle, nil, nil, nil, ptr) + CCS.error_check(res) + if data_type == :CCS_NUM_FLOAT + ptr.read_ccs_float_t + else + ptr.read_ccs_int_t + end + end + end end attach_function :ccs_create_roulette_distribution, [:size_t, :pointer, :pointer], :ccs_result_t diff --git a/bindings/ruby/test/test_distribution.rb b/bindings/ruby/test/test_distribution.rb index b96d02d..d539fdc 100644 --- a/bindings/ruby/test/test_distribution.rb +++ b/bindings/ruby/test/test_distribution.rb @@ -21,7 +21,6 @@ class CConfigSpaceTestDistribution < Minitest::Test assert_equal( :CCS_DISTRIBUTION, d.object_type ) assert_equal( :CCS_ROULETTE, d.type ) assert_equal( :CCS_NUM_INTEGER, d.data_type ) - assert_equal( :CCS_LINEAR, d.scale ) assert_equal( 1, d.dimension ) assert_equal( 4, d.num_areas ) assert( d.areas.reduce(:+) > 0.999 ) diff --git a/include/cconfigspace/distribution.h b/include/cconfigspace/distribution.h index 545f8bb..b9bc09f 100644 --- a/include/cconfigspace/distribution.h +++ b/include/cconfigspace/distribution.h @@ -74,26 +74,29 @@ ccs_create_roulette_distribution(size_t num_areas, ccs_float_t *areas, ccs_distribution_t *distribution_ret); +extern ccs_result_t +ccs_create_mixture_distribution(size_t num_distributions, + ccs_distribution_t *distributions, + ccs_float_t *weights, + ccs_distribution_t *distribution_ret); + +extern ccs_result_t +ccs_create_multivariate_distribution(size_t num_distributions, + ccs_distribution_t *distributions, + ccs_distribution_t *distribution_ret); + // Accessors extern ccs_result_t ccs_distribution_get_type(ccs_distribution_t distribution, ccs_distribution_type_t *type_ret); -extern ccs_result_t -ccs_distribution_get_data_type(ccs_distribution_t distribution, - ccs_numeric_type_t *data_type_ret); - extern ccs_result_t ccs_distribution_get_dimension(ccs_distribution_t distribution, size_t *dimension); extern ccs_result_t -ccs_distribution_get_scale_type(ccs_distribution_t distribution, - ccs_scale_type_t *scale_type_ret); - -extern ccs_result_t -ccs_distribution_get_quantization(ccs_distribution_t distribution, - ccs_numeric_t *quantization); +ccs_distribution_get_data_type(ccs_distribution_t distribution, + ccs_numeric_type_t *data_type_ret); extern ccs_result_t ccs_distribution_get_bounds(ccs_distribution_t distribution, @@ -107,12 +110,16 @@ ccs_distribution_check_oversampling(ccs_distribution_t distribution, extern ccs_result_t ccs_normal_distribution_get_parameters(ccs_distribution_t distribution, ccs_float_t *mu_ret, - ccs_float_t *sigma_ret); + ccs_float_t *sigma_ret, + ccs_scale_type_t *scale_ret, + ccs_numeric_t *quantization_ret); extern ccs_result_t ccs_uniform_distribution_get_parameters(ccs_distribution_t distribution, ccs_numeric_t *lower_ret, - ccs_numeric_t *upper_ret); + ccs_numeric_t *upper_ret, + ccs_scale_type_t *scale_ret, + ccs_numeric_t *quantization_ret); extern ccs_result_t ccs_roulette_distribution_get_num_areas(ccs_distribution_t distribution, diff --git a/src/distribution.c b/src/distribution.c index 440d115..c722fdb 100644 --- a/src/distribution.c +++ b/src/distribution.c @@ -30,26 +30,7 @@ ccs_distribution_get_dimension(ccs_distribution_t distribution, size_t *dimension_ret) { CCS_CHECK_OBJ(distribution, CCS_DISTRIBUTION); CCS_CHECK_PTR(dimension_ret); - *dimension_ret = 1; - return CCS_SUCCESS; -} - - -ccs_result_t -ccs_distribution_get_scale_type(ccs_distribution_t distribution, - ccs_scale_type_t *scale_type_ret) { - CCS_CHECK_OBJ(distribution, CCS_DISTRIBUTION); - CCS_CHECK_PTR(scale_type_ret); - *scale_type_ret = ((_ccs_distribution_common_data_t *)(distribution->data))->scale_type; - return CCS_SUCCESS; -} - -ccs_result_t -ccs_distribution_get_quantization(ccs_distribution_t distribution, - ccs_numeric_t *quantization_ret) { - CCS_CHECK_OBJ(distribution, CCS_DISTRIBUTION); - CCS_CHECK_PTR(quantization_ret); - *quantization_ret = ((_ccs_distribution_common_data_t *)(distribution->data))->quantization; + *dimension_ret = ((_ccs_distribution_common_data_t *)(distribution->data))->dimension; return CCS_SUCCESS; } diff --git a/src/distribution_internal.h b/src/distribution_internal.h index bd2f06f..57ae70c 100644 --- a/src/distribution_internal.h +++ b/src/distribution_internal.h @@ -33,10 +33,9 @@ struct _ccs_distribution_s { }; struct _ccs_distribution_common_data_s { - ccs_distribution_type_t type; + ccs_distribution_type_t type; + size_t dimension; ccs_numeric_type_t data_type; - ccs_scale_type_t scale_type; - ccs_numeric_t quantization; }; typedef struct _ccs_distribution_common_data_s _ccs_distribution_common_data_t; #endif //_DISTRIBUTION_INTERNAL_H diff --git a/src/distribution_normal.c b/src/distribution_normal.c index fe494ef..0dfbe8a 100644 --- a/src/distribution_normal.c +++ b/src/distribution_normal.c @@ -8,6 +8,8 @@ struct _ccs_distribution_normal_data_s { _ccs_distribution_common_data_t common_data; ccs_float_t mu; ccs_float_t sigma; + ccs_scale_type_t scale_type; + ccs_numeric_t quantization; int quantize; }; typedef struct _ccs_distribution_normal_data_s _ccs_distribution_normal_data_t; @@ -47,8 +49,8 @@ _ccs_distribution_normal_get_bounds(_ccs_distribution_data_t *data, ccs_interval_t *interval_ret) { _ccs_distribution_normal_data_t *d = (_ccs_distribution_normal_data_t *)data; const ccs_numeric_type_t data_type = d->common_data.data_type; - const ccs_scale_type_t scale_type = d->common_data.scale_type; - const ccs_numeric_t quantization = d->common_data.quantization; + const ccs_scale_type_t scale_type = d->scale_type; + const ccs_numeric_t quantization = d->quantization; const int quantize = d->quantize; ccs_numeric_t l; ccs_bool_t li; @@ -193,8 +195,8 @@ _ccs_distribution_normal_samples(_ccs_distribution_data_t *data, ccs_numeric_t *values) { _ccs_distribution_normal_data_t *d = (_ccs_distribution_normal_data_t *)data; const ccs_numeric_type_t data_type = d->common_data.data_type; - const ccs_scale_type_t scale_type = d->common_data.scale_type; - const ccs_numeric_t quantization = d->common_data.quantization; + const ccs_scale_type_t scale_type = d->scale_type; + const ccs_numeric_t quantization = d->quantization; const ccs_float_t mu = d->mu; const ccs_float_t sigma = d->sigma; const int quantize = d->quantize; @@ -308,8 +310,8 @@ _ccs_distribution_normal_strided_samples(_ccs_distribution_data_t *data, ccs_numeric_t *values) { _ccs_distribution_normal_data_t *d = (_ccs_distribution_normal_data_t *)data; const ccs_numeric_type_t data_type = d->common_data.data_type; - const ccs_scale_type_t scale_type = d->common_data.scale_type; - const ccs_numeric_t quantization = d->common_data.quantization; + const ccs_scale_type_t scale_type = d->scale_type; + const ccs_numeric_t quantization = d->quantization; const ccs_float_t mu = d->mu; const ccs_float_t sigma = d->sigma; const int quantize = d->quantize; @@ -354,9 +356,10 @@ ccs_create_normal_distribution(ccs_numeric_type_t data_type, _ccs_object_init(&(distrib->obj), CCS_DISTRIBUTION, (_ccs_object_ops_t *)&_ccs_distribution_normal_ops); _ccs_distribution_normal_data_t * distrib_data = (_ccs_distribution_normal_data_t *)(mem + sizeof(struct _ccs_distribution_s)); distrib_data->common_data.type = CCS_NORMAL; + distrib_data->common_data.dimension = 1; distrib_data->common_data.data_type = data_type; - distrib_data->common_data.scale_type = scale_type; - distrib_data->common_data.quantization = quantization; + distrib_data->scale_type = scale_type; + distrib_data->quantization = quantization; distrib_data->mu = mu; distrib_data->sigma = sigma; if (data_type == CCS_NUM_FLOAT) { @@ -373,21 +376,25 @@ ccs_create_normal_distribution(ccs_numeric_type_t data_type, extern ccs_result_t ccs_normal_distribution_get_parameters(ccs_distribution_t distribution, - ccs_float_t *mu, - ccs_float_t *sigma) { + ccs_float_t *mu_ret, + ccs_float_t *sigma_ret, + ccs_scale_type_t *scale_type_ret, + ccs_numeric_t *quantization_ret) { CCS_CHECK_OBJ(distribution, CCS_DISTRIBUTION); if (((_ccs_distribution_common_data_t*)distribution->data)->type != CCS_NORMAL) return -CCS_INVALID_OBJECT; - if (!mu && !sigma) + if (!mu_ret && !sigma_ret && !scale_type_ret && !quantization_ret) return -CCS_INVALID_VALUE; _ccs_distribution_normal_data_t * data = (_ccs_distribution_normal_data_t *)distribution->data; - if (mu) { - *mu = data->mu; - } - if (sigma) { - *sigma = data->sigma; - } + if (mu_ret) + *mu_ret = data->mu; + if (sigma_ret) + *sigma_ret = data->sigma; + if (scale_type_ret) + *scale_type_ret = data->scale_type; + if (quantization_ret) + *quantization_ret = data->quantization; return CCS_SUCCESS; } diff --git a/src/distribution_roulette.c b/src/distribution_roulette.c index d01bb8b..0bdb9ad 100644 --- a/src/distribution_roulette.c +++ b/src/distribution_roulette.c @@ -151,9 +151,8 @@ ccs_create_roulette_distribution(size_t num_areas, _ccs_object_init(&(distrib->obj), CCS_DISTRIBUTION, (_ccs_object_ops_t *)&_ccs_distribution_roulette_ops); _ccs_distribution_roulette_data_t * distrib_data = (_ccs_distribution_roulette_data_t *)(mem + sizeof(struct _ccs_distribution_s)); distrib_data->common_data.type = CCS_ROULETTE; + distrib_data->common_data.dimension = 1; distrib_data->common_data.data_type = CCS_NUM_INTEGER; - distrib_data->common_data.scale_type = CCS_LINEAR; - distrib_data->common_data.quantization = CCSI(0); distrib_data->num_areas = num_areas; distrib_data->areas = (ccs_float_t *)(mem + sizeof(struct _ccs_distribution_s) + sizeof(_ccs_distribution_roulette_data_t)); diff --git a/src/distribution_uniform.c b/src/distribution_uniform.c index 0482cf4..c0e4dcb 100644 --- a/src/distribution_uniform.c +++ b/src/distribution_uniform.c @@ -8,6 +8,8 @@ struct _ccs_distribution_uniform_data_s { _ccs_distribution_common_data_t common_data; ccs_numeric_t lower; ccs_numeric_t upper; + ccs_scale_type_t scale_type; + ccs_numeric_t quantization; ccs_numeric_type_t internal_type; ccs_numeric_t internal_lower; ccs_numeric_t internal_upper; @@ -76,8 +78,8 @@ _ccs_distribution_uniform_strided_samples(_ccs_distribution_data_t *data, _ccs_distribution_uniform_data_t *d = (_ccs_distribution_uniform_data_t *)data; size_t i; const ccs_numeric_type_t data_type = d->common_data.data_type; - const ccs_scale_type_t scale_type = d->common_data.scale_type; - const ccs_numeric_t quantization = d->common_data.quantization; + const ccs_scale_type_t scale_type = d->scale_type; + const ccs_numeric_t quantization = d->quantization; const ccs_numeric_t lower = d->lower; const ccs_numeric_t internal_lower = d->internal_lower; const ccs_numeric_t internal_upper = d->internal_upper; @@ -135,8 +137,8 @@ _ccs_distribution_uniform_samples(_ccs_distribution_data_t *data, _ccs_distribution_uniform_data_t *d = (_ccs_distribution_uniform_data_t *)data; size_t i; const ccs_numeric_type_t data_type = d->common_data.data_type; - const ccs_scale_type_t scale_type = d->common_data.scale_type; - const ccs_numeric_t quantization = d->common_data.quantization; + const ccs_scale_type_t scale_type = d->scale_type; + const ccs_numeric_t quantization = d->quantization; const ccs_numeric_t lower = d->lower; const ccs_numeric_t internal_lower = d->internal_lower; const ccs_numeric_t internal_upper = d->internal_upper; @@ -218,9 +220,10 @@ ccs_create_uniform_distribution(ccs_numeric_type_t data_type, _ccs_object_init(&(distrib->obj), CCS_DISTRIBUTION, (_ccs_object_ops_t *)&_ccs_distribution_uniform_ops); _ccs_distribution_uniform_data_t * distrib_data = (_ccs_distribution_uniform_data_t *)(mem + sizeof(struct _ccs_distribution_s)); distrib_data->common_data.type = CCS_UNIFORM; + distrib_data->common_data.dimension = 1; distrib_data->common_data.data_type = data_type; - distrib_data->common_data.scale_type = scale_type; - distrib_data->common_data.quantization = quantization; + distrib_data->scale_type = scale_type; + distrib_data->quantization = quantization; distrib_data->lower = lower; distrib_data->upper = upper; @@ -257,20 +260,24 @@ ccs_create_uniform_distribution(ccs_numeric_type_t data_type, ccs_result_t ccs_uniform_distribution_get_parameters(ccs_distribution_t distribution, ccs_numeric_t *lower_ret, - ccs_numeric_t *upper_ret) { + ccs_numeric_t *upper_ret, + ccs_scale_type_t *scale_type_ret, + ccs_numeric_t *quantization_ret) { CCS_CHECK_OBJ(distribution, CCS_DISTRIBUTION); if (((_ccs_distribution_common_data_t*)distribution->data)->type != CCS_UNIFORM) return -CCS_INVALID_OBJECT; - if (!lower_ret && !upper_ret) + if (!lower_ret && !upper_ret && !scale_type_ret && !quantization_ret) return -CCS_INVALID_VALUE; _ccs_distribution_uniform_data_t * data = (_ccs_distribution_uniform_data_t *)distribution->data; - if (lower_ret) { + if (lower_ret) *lower_ret = data->lower; - } - if (upper_ret) { + if (upper_ret) *upper_ret = data->upper; - } + if (scale_type_ret) + *scale_type_ret = data->scale_type; + if (quantization_ret) + *quantization_ret = data->quantization; return CCS_SUCCESS; } diff --git a/tests/test_normal_distribution.c b/tests/test_normal_distribution.c index 31ff4c4..06f0b86 100644 --- a/tests/test_normal_distribution.c +++ b/tests/test_normal_distribution.c @@ -39,14 +39,6 @@ static void test_create_normal_distribution() { assert( err == CCS_SUCCESS ); assert( data_type == CCS_NUM_FLOAT ); - err = ccs_distribution_get_scale_type(distrib, &stype); - assert( err == CCS_SUCCESS ); - assert( stype == CCS_LINEAR ); - - err = ccs_distribution_get_quantization(distrib, &quantization); - assert( err == CCS_SUCCESS ); - assert( quantization.f == 0.0 ); - err = ccs_distribution_get_bounds(distrib, &interval); assert( err == CCS_SUCCESS ); assert( interval.type == CCS_NUM_FLOAT ); @@ -55,10 +47,12 @@ static void test_create_normal_distribution() { assert( interval.upper.f == CCS_INFINITY ); assert( interval.upper_included == CCS_FALSE ); - err = ccs_normal_distribution_get_parameters(distrib, &mu, &sigma); + err = ccs_normal_distribution_get_parameters(distrib, &mu, &sigma, &stype, &quantization); assert( err == CCS_SUCCESS ); assert( mu == 1.0 ); assert( sigma == 2.0 ); + assert( stype == CCS_LINEAR ); + assert( quantization.f == 0.0 ); err = ccs_object_get_refcount(distrib, &refcount); assert( err == CCS_SUCCESS ); diff --git a/tests/test_roulette_distribution.c b/tests/test_roulette_distribution.c index 4abb6b5..c4e0581 100644 --- a/tests/test_roulette_distribution.c +++ b/tests/test_roulette_distribution.c @@ -12,9 +12,7 @@ void test_create_roulette_distribution() { int32_t refcount; ccs_object_type_t otype; ccs_distribution_type_t dtype; - ccs_scale_type_t stype; ccs_numeric_type_t data_type; - ccs_numeric_t quantization; ccs_interval_t interval; const size_t num_areas = 4; ccs_float_t areas[num_areas]; @@ -44,14 +42,6 @@ void test_create_roulette_distribution() { assert( err == CCS_SUCCESS ); assert( data_type == CCS_NUM_INTEGER ); - err = ccs_distribution_get_scale_type(distrib, &stype); - assert( err == CCS_SUCCESS ); - assert( stype == CCS_LINEAR ); - - err = ccs_distribution_get_quantization(distrib, &quantization); - assert( err == CCS_SUCCESS ); - assert( quantization.i == 0 ); - err = ccs_distribution_get_bounds(distrib, &interval); assert( err == CCS_SUCCESS ); assert( interval.type == CCS_NUM_INTEGER ); diff --git a/tests/test_uniform_distribution.c b/tests/test_uniform_distribution.c index 5b7f703..53388fe 100644 --- a/tests/test_uniform_distribution.c +++ b/tests/test_uniform_distribution.c @@ -37,14 +37,6 @@ static void test_create_uniform_distribution() { assert( err == CCS_SUCCESS ); assert( data_type == CCS_NUM_INTEGER ); - err = ccs_distribution_get_scale_type(distrib, &stype); - assert( err == CCS_SUCCESS ); - assert( stype == CCS_LINEAR ); - - err = ccs_distribution_get_quantization(distrib, &quantization); - assert( err == CCS_SUCCESS ); - assert( quantization.i == q ); - err = ccs_distribution_get_bounds(distrib, &interval); assert( err == CCS_SUCCESS ); assert( interval.type == CCS_NUM_INTEGER ); @@ -53,10 +45,12 @@ static void test_create_uniform_distribution() { assert( interval.upper.i == u ); assert( interval.upper_included == CCS_FALSE ); - err = ccs_uniform_distribution_get_parameters(distrib, &lower, &upper); + err = ccs_uniform_distribution_get_parameters(distrib, &lower, &upper, &stype, &quantization); assert( err == CCS_SUCCESS ); assert( lower.i == l ); assert( upper.i == u ); + assert( stype == CCS_LINEAR ); + assert( quantization.i == q ); err = ccs_object_get_refcount(distrib, &refcount); assert( err == CCS_SUCCESS ); -- 2.26.2