Commit 930ffb7e authored by Brice Videau's avatar Brice Videau
Browse files

Multidimensional distrib

parent 3ba7aa30
import ctypes as ct
from .base import Object, Error, ccs_error, _ccs_get_function, ccs_context, ccs_hyperparameter, ccs_configuration_space, ccs_configuration, ccs_rng, ccs_distribution, ccs_expression, ccs_datum
from .context import Context
from .distribution import Distribution
from .hyperparameter import Hyperparameter
from .expression import Expression
from .expression_parser import ccs_parser
......@@ -12,6 +13,8 @@ ccs_configuration_space_set_rng = _ccs_get_function("ccs_configuration_space_set
ccs_configuration_space_get_rng = _ccs_get_function("ccs_configuration_space_get_rng", [ccs_configuration_space, ct.POINTER(ccs_rng)])
ccs_configuration_space_add_hyperparameter = _ccs_get_function("ccs_configuration_space_add_hyperparameter", [ccs_configuration_space, ccs_hyperparameter, ccs_distribution])
ccs_configuration_space_add_hyperparameters = _ccs_get_function("ccs_configuration_space_add_hyperparameters", [ccs_configuration_space, ct.c_size_t, ct.POINTER(ccs_hyperparameter), ct.POINTER(ccs_distribution)])
ccs_configuration_space_set_distribution = _ccs_get_function("ccs_configuration_space_set_distribution", [ccs_configuration_space, ccs_distribution, ct.POINTER(ct.c_size_t)])
ccs_configuration_space_get_hyperparameter_distribution = _ccs_get_function("ccs_configuration_space_get_hyperparameter_distribution", [ccs_configuration_space, ct.c_size_t, ct.POINTER(ccs_distribution), ct.POINTER(ct.c_size_t)])
ccs_configuration_space_set_condition = _ccs_get_function("ccs_configuration_space_set_condition", [ccs_configuration_space, ct.c_size_t, ccs_expression])
ccs_configuration_space_get_condition = _ccs_get_function("ccs_configuration_space_get_condition", [ccs_configuration_space, ct.c_size_t, ct.POINTER(ccs_expression)])
ccs_configuration_space_get_conditions = _ccs_get_function("ccs_configuration_space_get_conditions", [ccs_configuration_space, ct.c_size_t, ct.POINTER(ccs_expression), ct.POINTER(ct.c_size_t)])
......@@ -71,6 +74,33 @@ class ConfigurationSpace(Context):
res = ccs_configuration_space_add_hyperparameters(self.handle, count, hypers, distribs)
Error.check(res)
def set_distribution(self, distribution, hyperparameters):
count = distribution.dimension
if count != len(hyperparameters):
raise Error(ccs_error(ccs_error.INVALID_VALUE))
hyps = []
for h in hyperparameters:
if isinstance(h, Hyperparameter):
hyps.append(self.hyperparameter_index(h))
elif isinstance(h, str):
hyps.append(self.hyperparameter_index_by_name(h))
else:
hyps.append(h)
v = (ct.c_size_t * count)(*hyps)
res = ccs_configuration_space_set_distribution(self.handle, distribution.handle, v)
Error.check(res)
def get_hyperparameter_distribution(self, hyperparameter):
if isinstance(hyperparameter, Hyperparameter):
hyperparameter = self.hyperparameter_index(hyperparameter)
elif isinstance(hyperparameter, str):
hyperparameter = self.hyperparameter_index_by_name(hyperparameter)
v1 = ccs_distribution()
v2 = ct.c_size_t()
res = ccs_configuration_space_get_hyperparameter_distribution(self.handle, hyperparameter, ct.byref(v1), ct.byref(v2))
Error.check(res)
return [Distribution.from_handle(v1), v2.value]
def set_condition(self, hyperparameter, expression):
if isinstance(expression, str):
expression = ccs_parser.parse(expression, context = PContext(extra=self))
......
......@@ -7,7 +7,9 @@ class ccs_distribution_type(CEnumeration):
_members_ = [
('UNIFORM', 0),
'NORMAL',
'ROULETTE' ]
'ROULETTE',
'MIXTURE',
'MULTIVARIATE' ]
class ccs_scale_type(CEnumeration):
_members_ = [
......@@ -15,7 +17,7 @@ class ccs_scale_type(CEnumeration):
'LOGARITHMIC' ]
ccs_distribution_get_type = _ccs_get_function("ccs_distribution_get_type", [ccs_distribution, ct.POINTER(ccs_distribution_type)])
ccs_distribution_get_data_type = _ccs_get_function("ccs_distribution_get_data_type", [ccs_distribution, ct.POINTER(ccs_numeric_type)])
ccs_distribution_get_data_types = _ccs_get_function("ccs_distribution_get_data_types", [ccs_distribution, ct.POINTER(ccs_numeric_type)])
ccs_distribution_get_dimension = _ccs_get_function("ccs_distribution_get_dimension", [ccs_distribution, ct.POINTER(ct.c_size_t)])
ccs_distribution_get_bounds = _ccs_get_function("ccs_distribution_get_bounds", [ccs_distribution, ct.POINTER(ccs_interval)])
ccs_distribution_check_oversampling = _ccs_get_function("ccs_distribution_check_oversampling", [ccs_distribution, ct.POINTER(ccs_interval), ct.POINTER(ccs_bool)])
......@@ -36,6 +38,10 @@ class Distribution(Object):
return NormalDistribution(handle = handle, retain = True)
elif v == ccs_distribution_type.ROULETTE:
return RouletteDistribution(handle = handle, retain = True)
elif v == ccs_distribution_type.MIXTURE:
return MixtureDistribution(handle = handle, retain = True)
elif v == ccs_distribution_type.MULTIVARIATE:
return MultivariateDistribution(handle = handle, retain = True)
else:
raise Error(ccs_error(ccs_error.INVALID_DISTRIBUTION))
......@@ -50,14 +56,14 @@ class Distribution(Object):
return self._type
@property
def data_type(self):
if hasattr(self, "_data_type"):
return self._data_type
v = ccs_numeric_type(0)
res = ccs_distribution_get_data_type(self.handle, ct.byref(v))
def data_types(self):
if hasattr(self, "_data_types"):
return self._data_types
v = (ccs_numeric_type*self.dimension)()
res = ccs_distribution_get_data_types(self.handle, v)
Error.check(res)
self._data_type = v.value
return self._data_type
self._data_types = [t.value for t in v]
return self._data_types
@property
def dimension(self):
......@@ -86,30 +92,41 @@ class Distribution(Object):
return False if v.value == ccs_false else True
def sample(self, rng):
v = ccs_numeric()
res = ccs_distribution_sample(self.handle, rng.handle, ct.byref(v))
dim = self.dimension
v = (ccs_numeric*dim)()
res = ccs_distribution_sample(self.handle, rng.handle, v)
Error.check(res)
t = self.data_type
if t == ccs_numeric_type.NUM_INTEGER:
return v.i
elif t == ccs_numeric_type.NUM_FLOAT:
return v.f
if dim == 1:
t = self.data_types[0]
if t == ccs_numeric_type.NUM_INTEGER:
return v[0].i
elif t == ccs_numeric_type.NUM_FLOAT:
return v[0].f
else:
raise Error(ccs_error(ccs_error.INVALID_VALUE))
else:
raise Error(ccs_error(ccs_error.INVALID_VALUE))
return [ v[i].i if self.data_types[i] == ccs_numeric_type.NUM_INTEGER else v[i].f for i in range(dim) ]
def samples(self, rng, count):
if count == 0:
return []
t = self.data_type
if t == ccs_numeric_type.NUM_INTEGER:
v = (ccs_int * count)()
elif t == ccs_numeric_type.NUM_FLOAT:
v = (ccs_float * count)()
dim = self.dimension
if dim == 1:
t = self.data_types[0]
if t == ccs_numeric_type.NUM_INTEGER:
v = (ccs_int * (count * dim))()
elif t == ccs_numeric_type.NUM_FLOAT:
v = (ccs_float * (count * dim))()
else:
raise Error(ccs_error(ccs_error.INVALID_VALUE))
res = ccs_distribution_samples(self.handle, rng.handle, count, ct.cast(v, ct.POINTER(ccs_numeric)))
Error.check(res)
return list(v)
else:
raise Error(ccs_error(ccs_error.INVALID_VALUE))
res = ccs_distribution_samples(self.handle, rng.handle, count, ct.cast(v, ct.POINTER(ccs_numeric)))
Error.check(res)
return list(v)
v = (ccs_numeric*dim*count)()
res = ccs_distribution_samples(self.handle, rng.handle, count, v)
Error.check(res)
return [ [v[j][i].i if self.data_types[i] == ccs_numeric_type.NUM_INTEGER else v[j][i].f for i in range(dim) ] for j in range(count) ]
ccs_create_uniform_distribution = _ccs_get_function("ccs_create_uniform_distribution", [ccs_numeric_type, ccs_int, ccs_int, ccs_scale_type, ccs_int, ct.POINTER(ccs_distribution)])
ccs_create_uniform_int_distribution = _ccs_get_function("ccs_create_uniform_int_distribution", [ccs_int, ccs_int, ccs_scale_type, ccs_int, ct.POINTER(ccs_distribution)])
......@@ -139,6 +156,13 @@ class UniformDistribution(Distribution):
def float(cls, lower, upper, scale = ccs_scale_type.LINEAR, quantization = 0.0):
return cls(data_type = NUM_FLOAT, lower = lower, upper = upper, scale = scale, quantization = quantization)
@property
def data_type(self):
if hasattr(self, "_data_type"):
return self._data_type
self._data_type = self.data_types[0]
return self._data_type
@property
def lower(self):
if hasattr(self, "_lower"):
......@@ -227,6 +251,13 @@ class NormalDistribution(Distribution):
def float(cls, mu, sigma, scale = ccs_scale_type.LINEAR, quantization = 0.0):
return cls(data_type = NUM_FLOAT, mu = mu, sigma = sigma, scale = scale, quantization = quantization)
@property
def data_type(self):
if hasattr(self, "_data_type"):
return self._data_type
self._data_type = self.data_types[0]
return self._data_type
@property
def mu(self):
if hasattr(self, "_mu"):
......@@ -290,6 +321,13 @@ class RouletteDistribution(Distribution):
else:
super().__init__(handle = handle, retain = retain)
@property
def data_type(self):
if hasattr(self, "_data_type"):
return self._data_type
self._data_type = self.data_types[0]
return self._data_type
@property
def num_areas(self):
if hasattr(self, "_num_areas"):
......@@ -310,4 +348,86 @@ class RouletteDistribution(Distribution):
self._areas = list(v)
return self._areas
ccs_create_mixture_distribution = _ccs_get_function("ccs_create_mixture_distribution", [ct.c_size_t, ct.POINTER(ccs_distribution), ct.POINTER(ccs_float), ct.POINTER(ccs_distribution)])
ccs_mixture_distribution_get_num_distributions = _ccs_get_function("ccs_mixture_distribution_get_num_distributions", [ccs_distribution, ct.POINTER(ct.c_size_t)])
ccs_mixture_distribution_get_distributions = _ccs_get_function("ccs_mixture_distribution_get_distributions", [ccs_distribution, ct.c_size_t, ct.POINTER(ccs_distribution), ct.POINTER(ct.c_size_t)])
ccs_mixture_distribution_get_weights = _ccs_get_function("ccs_mixture_distribution_get_weights", [ccs_distribution, ct.c_size_t, ct.POINTER(ccs_float), ct.POINTER(ct.c_size_t)])
class MixtureDistribution(Distribution):
def __init__(self, handle = None, retain = False, distributions = [], weights = None):
if handle is None:
handle = ccs_distribution(0)
if weights is None:
weights = [1.0] * len(distributions)
ws = (ccs_float * len(distributions))(*weights)
ds = (ccs_distribution * len(distributions))(*[x.handle.value for x in distributions])
res = ccs_create_mixture_distribution(len(distributions), ds, ws, ct.byref(handle))
Error.check(res)
super().__init__(handle = handle, retain = False)
else:
super().__init__(handle = handle, retain = retain)
@property
def num_distributions(self):
if hasattr(self, "_num_distributions"):
return self._num_distributions
v = ct.c_size_t()
res = ccs_mixture_distribution_get_num_distributions(self.handle, ct.byref(v))
Error.check(res)
self._num_distributions = v.value
return self._num_distributions
@property
def weights(self):
if hasattr(self, "_weights"):
return self._weights
v = (ccs_float * self.num_distributions)()
res = ccs_mixture_distribution_get_weights(self.handle, self.num_distributions, v, None)
Error.check(res)
self._weights = list(v)
return self._weights
@property
def distributions(self):
if hasattr(self, "_distributions"):
return self._distributions
v = (ccs_distribution * self.num_distributions)()
res = ccs_mixture_distribution_get_distributions(self.handle, self.num_distributions, v, None)
Error.check(res)
self._distributions = [Distribution.from_handle(ccs_distribution(x)) for x in v]
return self._distributions
ccs_create_multivariate_distribution = _ccs_get_function("ccs_create_multivariate_distribution", [ct.c_size_t, ct.POINTER(ccs_distribution), ct.POINTER(ccs_distribution)])
ccs_multivariate_distribution_get_num_distributions = _ccs_get_function("ccs_multivariate_distribution_get_num_distributions", [ccs_distribution, ct.POINTER(ct.c_size_t)])
ccs_multivariate_distribution_get_distributions = _ccs_get_function("ccs_multivariate_distribution_get_distributions", [ccs_distribution, ct.c_size_t, ct.POINTER(ccs_distribution), ct.POINTER(ct.c_size_t)])
class MultivariateDistribution(Distribution):
def __init__(self, handle = None, retain = False, distributions = [], weights = None):
if handle is None:
handle = ccs_distribution(0)
ds = (ccs_distribution * len(distributions))(*[x.handle.value for x in distributions])
res = ccs_create_multivariate_distribution(len(distributions), ds, ct.byref(handle))
Error.check(res)
super().__init__(handle = handle, retain = False)
else:
super().__init__(handle = handle, retain = retain)
@property
def num_distributions(self):
if hasattr(self, "_num_distributions"):
return self._num_distributions
v = ct.c_size_t()
res = ccs_multivariate_distribution_get_num_distributions(self.handle, ct.byref(v))
Error.check(res)
self._num_distributions = v.value
return self._num_distributions
@property
def distributions(self):
if hasattr(self, "_distributions"):
return self._distributions
v = (ccs_distribution * self.num_distributions)()
res = ccs_multivariate_distribution_get_distributions(self.handle, self.num_distributions, v, None)
Error.check(res)
self._distributions = [Distribution.from_handle(ccs_distribution(x)) for x in v]
return self._distributions
......@@ -34,6 +34,31 @@ class TestConfigurationSpace(unittest.TestCase):
for c in cs.samples(100):
cs.check(c)
def test_set_distribution(self):
cs = ccs.ConfigurationSpace(name = "space")
h1 = ccs.NumericalHyperparameter()
h2 = ccs.NumericalHyperparameter()
h3 = ccs.NumericalHyperparameter()
cs.add_hyperparameters([h1, h2, h3])
distributions = [ ccs.UniformDistribution.float(lower = 0.1, upper = 0.3),
ccs.UniformDistribution.float(lower = 0.2, upper = 0.6) ]
d = ccs.MultivariateDistribution(distributions = distributions)
cs.set_distribution(d, [h1, h2])
(dist, indx) = cs.get_hyperparameter_distribution(h1)
self.assertEqual( d.handle.value, dist.handle.value )
self.assertEqual( 0, indx )
(dist, indx) = cs.get_hyperparameter_distribution(h2)
self.assertEqual( d.handle.value, dist.handle.value )
self.assertEqual( 1, indx )
cs.set_distribution(d, [h3, h1])
(dist, indx) = cs.get_hyperparameter_distribution(h1)
self.assertEqual( d.handle.value, dist.handle.value )
self.assertEqual( 1, indx )
(dist, indx) = cs.get_hyperparameter_distribution(h3)
self.assertEqual( d.handle.value, dist.handle.value )
self.assertEqual( 0, indx )
def test_conditions(self):
h1 = ccs.NumericalHyperparameter(lower = -1.0, upper = 1.0, default = 0.0)
h2 = ccs.NumericalHyperparameter(lower = -1.0, upper = 1.0)
......
......@@ -169,6 +169,39 @@ class TestDistribution(unittest.TestCase):
for v in a:
self.assertTrue( i.include(v) )
def test_create_mixture(self):
distributions = [ ccs.UniformDistribution.float(lower = -5.0, upper = 0.0),
ccs.UniformDistribution.float(lower = 0.0, upper = 2.0) ]
d = ccs.MixtureDistribution(distributions = distributions)
self.assertEqual( d.object_type, ccs.DISTRIBUTION )
self.assertEqual( d.type, ccs.MIXTURE )
self.assertEqual( d.data_types, [ccs.NUM_FLOAT] )
self.assertEqual( d.weights, [0.5, 0.5] )
self.assertEqual( [x.handle.value for x in d.distributions], [x.handle.value for x in distributions] )
d2 = ccs.Object.from_handle(d.handle)
self.assertEqual( d.__class__, d2.__class__ )
def test_create_multivariate(self):
distributions = [ ccs.UniformDistribution.float(lower = -5.0, upper = 0.0),
ccs.UniformDistribution.int(lower = 0, upper = 2) ]
d = ccs.MultivariateDistribution(distributions = distributions)
self.assertEqual( d.object_type, ccs.DISTRIBUTION )
self.assertEqual( d.type, ccs.MULTIVARIATE )
self.assertEqual( d.data_types, [ccs.NUM_FLOAT, ccs.NUM_INTEGER] )
self.assertEqual( [x.handle.value for x in d.distributions], [x.handle.value for x in distributions] )
d2 = ccs.Object.from_handle(d.handle)
self.assertEqual( d.__class__, d2.__class__ )
def test_mixture_multidim(self):
distributions = [ ccs.UniformDistribution.float(lower = -5.0, upper = 0.0),
ccs.UniformDistribution.int(lower = 0, upper = 2) ]
d = ccs.MultivariateDistribution(distributions = distributions)
d2 = ccs.MixtureDistribution(distributions = [d, d])
self.assertEqual( d2.object_type, ccs.DISTRIBUTION )
self.assertEqual( d2.type, ccs.MIXTURE )
self.assertEqual( d2.data_types, [ccs.NUM_FLOAT, ccs.NUM_INTEGER] )
self.assertEqual( d2.weights, [0.5, 0.5] )
if __name__ == '__main__':
unittest.main()
......@@ -35,9 +35,11 @@ module CCS
class MemoryPointer
alias read_ccs_float_t read_double
alias get_ccs_float_t get_double
alias read_array_of_ccs_float_t read_array_of_double
alias write_array_of_ccs_float_t write_array_of_double
alias read_ccs_int_t read_int64
alias get_ccs_int_t get_int64
alias read_array_of_ccs_int_t read_array_of_int64
alias write_array_of_ccs_int_t write_array_of_int64
alias read_ccs_bool_t read_int32
......@@ -45,8 +47,12 @@ module CCS
alias read_ccs_hash_t read_uint32
if FFI.find_type(:size_t).size == 8
alias read_size_t read_uint64
alias write_size_t write_uint64
alias write_array_of_size_t write_array_of_uint64
else
alias read_size_t read_uint32
alias write_size_t write_uint32
alias write_array_of_size_t write_array_of_uint32
end
end
......
......@@ -4,6 +4,8 @@ module CCS
attach_function :ccs_configuration_space_get_rng, [:ccs_configuration_space_t, :pointer], :ccs_result_t
attach_function :ccs_configuration_space_add_hyperparameter, [:ccs_configuration_space_t, :ccs_hyperparameter_t, :ccs_distribution_t], :ccs_result_t
attach_function :ccs_configuration_space_add_hyperparameters, [:ccs_configuration_space_t, :size_t, :pointer, :pointer], :ccs_result_t
attach_function :ccs_configuration_space_set_distribution, [:ccs_configuration_space_t, :ccs_distribution_t, :pointer], :ccs_result_t
attach_function :ccs_configuration_space_get_hyperparameter_distribution, [:ccs_configuration_space_t, :size_t, :pointer, :pointer], :ccs_result_t
attach_function :ccs_configuration_space_set_condition, [:ccs_configuration_space_t, :size_t, :ccs_expression_t], :ccs_result_t
attach_function :ccs_configuration_space_get_condition, [:ccs_configuration_space_t, :size_t, :pointer], :ccs_result_t
attach_function :ccs_configuration_space_get_conditions, [:ccs_configuration_space_t, :size_t, :pointer, :pointer], :ccs_result_t
......@@ -53,6 +55,40 @@ module CCS
self
end
def set_distribution(distribution, hyperparameters )
count = distribution.dimension
raise CCSError, :CCS_INVALID_VALUE if count != hyperparameters.size
hyperparameters = hyperparameters.collect { |h|
case h
when Hyperparameter
hyperparameter_index(h)
when String
hyperparameter_index_by_name(hyperparameter)
else
h
end
}
p_hypers = MemoryPointer::new(:size_t, count)
p_hypers.write_array_of_size_t(hyperparameters)
res = CCS.ccs_configuration_space_set_distribution(@handle, distribution, p_hypers)
CCS.error_check(res)
self
end
def get_hyperparameter_distribution(hyperparameter)
case hyperparameter
when Hyperparameter
hyperparameter = hyperparameter_index(hyperparameter);
when String
hyperparameter = hyperparameter_index_by_name(hyperparameter);
end
p_distribution = MemoryPointer::new(:ccs_distribution_t)
p_indx = MemoryPointer::new(:size_t)
res = CCS.ccs_configuration_space_get_hyperparameter_distribution(@handle, hyperparameter, p_distribution, p_indx)
CCS.error_check(res)
[CCS::Distribution.from_handle(p_distribution.read_ccs_distribution_t), p_indx.read_size_t]
end
def add_hyperparameters(hyperparameters, distributions: nil)
count = hyperparameters.size
return self if count == 0
......
......@@ -3,7 +3,9 @@ module CCS
DistributionType = enum FFI::Type::INT32, :ccs_distribution_type_t, [
:CCS_UNIFORM,
:CCS_NORMAL,
:CCS_ROULETTE
:CCS_ROULETTE,
:CCS_MIXTURE,
:CCS_MULTIVARIATE
]
class MemoryPointer
def read_ccs_distribution_type_t
......@@ -22,7 +24,7 @@ module CCS
end
attach_function :ccs_distribution_get_type, [:ccs_distribution_t, :pointer], :ccs_result_t
attach_function :ccs_distribution_get_data_type, [:ccs_distribution_t, :pointer], :ccs_result_t
attach_function :ccs_distribution_get_data_types, [:ccs_distribution_t, :pointer], :ccs_result_t
attach_function :ccs_distribution_get_dimension, [:ccs_distribution_t, :pointer], :ccs_result_t
attach_function :ccs_distribution_get_bounds, [:ccs_distribution_t, :pointer], :ccs_result_t
attach_function :ccs_distribution_check_oversampling, [:ccs_distribution_t, Interval.by_ref, :pointer], :ccs_result_t
......@@ -31,7 +33,6 @@ module CCS
class Distribution < Object
add_property :type, :ccs_distribution_type_t, :ccs_distribution_get_type, memoize: true
add_property :data_type, :ccs_numeric_type_t, :ccs_distribution_get_data_type, memoize: true
add_property :dimension, :size_t, :ccs_distribution_get_dimension, memoize: true
def self.from_handle(handle)
......@@ -45,11 +46,24 @@ module CCS
NormalDistribution::new(handle, retain: true)
when :CCS_ROULETTE
RouletteDistribution::new(handle, retain: true)
when :CCS_MIXTURE
MixtureDistribution::new(handle, retain: true)
when :CCS_MULTIVARIATE
MultivariateDistribution::new(handle, retain: true)
else
raise CCSError, :CCS_INVALID_DISTRIBUTION
end
end
def data_types
@data_types ||= begin
ptr = MemoryPointer::new(:ccs_numeric_type_t, dimension)
res = CCS.ccs_distribution_get_data_types(@handle, ptr)
CCS.error_check(res)
ptr.read_array_of_int64(dimension).collect { |i| NumericType.from_native(i, nil) }
end
end
def bounds
@bounds ||= begin
interval = Interval::new(type: :CCS_NUM_FLOAT)
......@@ -67,25 +81,50 @@ module CCS
end
def sample(rng)
ptr = MemoryPointer::new(:ccs_numeric_t)
dim = dimension
ptr = MemoryPointer::new(:ccs_numeric_t, dim)
res = CCS.ccs_distribution_sample(@handle, rng, ptr)
CCS.error_check(res)
if data_type == :CCS_NUM_FLOAT
ptr.read_ccs_float_t
if dim == 1
if data_types.first == :CCS_NUM_FLOAT
ptr.read_ccs_float_t
else
ptr.read_ccs_int_t
end
else
ptr.read_ccs_int_t
data_types.each_with_index.collect { |t, i|
if t == :CCS_NUM_FLOAT
ptr.get_ccs_float_t(i*8)
else
ptr.get_ccs_int_t(i*8)
end
}
end
end
def samples(rng, count)
return [] if count == 0
ptr = MemoryPointer::new(:ccs_numeric_t, count)
dim = dimension
ptr = MemoryPointer::new(:ccs_numeric_t, count*dim)
res = CCS.ccs_distribution_samples(@handle, rng, count, ptr)
CCS.error_check(res)
if data_type == :CCS_NUM_FLOAT
ptr.read_array_of_ccs_float_t(count)
if dim == 1
if data_types.first == :CCS_NUM_FLOAT
ptr.read_array_of_ccs_float_t(count)
else
ptr.read_array_of_ccs_int_t(count)
end
else
ptr.read_array_of_ccs_int_t(count)
sz = CCS.find_type(:ccs_numeric_t).size
count.times.collect { |j|
data_types.each_with_index.collect { |t, i|
if t == :CCS_NUM_FLOAT
ptr.get_ccs_float_t((j*dim + i)*sz)
else
ptr.get_ccs_int_t((j*dim + i)*sz)
end
}
}
end
end
......@@ -120,6 +159,10 @@ module CCS
self.new(nil, data_type: :CCS_NUM_FLOAT, lower: lower, upper: upper, scale: scale, quantization: quantization)
end
def data_type
@data_type ||= data_types.first
end
def lower
@lower ||= begin
ptr = MemoryPointer::new(:ccs_numeric_t)
......@@ -197,6 +240,10 @@ module CCS
self::new(nil, retain: false, data_type: :CCS_NUM_FLOAT, mu: mu, sigma: sigma, scale: scale, quantization: quantization)
end
def data_type
@data_type ||= data_types.first
end
def mu
@mu ||= begin
ptr = MemoryPointer::new(:ccs_numeric_t)
......@@ -256,6 +303,10 @@ module CCS
end
end