Commit 643abc3b authored by Brice Videau's avatar Brice Videau

Added support for (1 for now) destructor callback on CCS objects.

parent 5444ebd9
......@@ -345,11 +345,14 @@ ccs_retain_object = _ccs_get_function("ccs_retain_object", [ccs_object])
ccs_release_object = _ccs_get_function("ccs_release_object", [ccs_object])
ccs_object_get_type = _ccs_get_function("ccs_object_get_type", [ccs_object, ct.POINTER(ccs_object_type)])
ccs_object_get_refcount = _ccs_get_function("ccs_object_get_refcount", [ccs_object, ct.POINTER(ct.c_int)])
ccs_object_destroy_callback_type = ct.CFUNCTYPE(None, ccs_object, ct.c_void_p)
ccs_object_set_destroy_callback = _ccs_get_function("ccs_object_set_destroy_callback", [ccs_object, ccs_object_destroy_callback_type, ct.c_void_p])
_res = ccs_init()
Error.check(_res)
class Object:
def __init__(self, handle, retain = False, auto_release = True):
if handle is None:
raise Error(ccs_error(ccs_error.INVALID_OBJECT))
......@@ -360,8 +363,9 @@ class Object:
Error.check(res)
def __del__(self):
res = ccs_release_object(self._handle)
Error.check(res)
if self.auto_release:
res = ccs_release_object(self._handle)
Error.check(res)
@property
def handle(self):
......@@ -389,28 +393,59 @@ class Object:
t = ccs_object_type(0)
res = ccs_object_get_type(h, ct.byref(t))
Error.check(res)
r = ct.c_int(0)
res = ccs_object_get_refcount(h, ct.byref(r))
Error.check(res)
r = r.value
if r == 0:
retain = False
auto_release = False
else:
retain = True
auto_release = True
v = t.value
if v == ccs_object_type.RNG:
return Rng.from_handle(h)
return Rng.from_handle(h, retain = retain, auto_release = auto_release)
elif v == ccs_object_type.DISTRIBUTION:
return Distribution.from_handle(h)
return Distribution.from_handle(h, retain = retain, auto_release = auto_release)
elif v == ccs_object_type.HYPERPARAMETER:
return Hyperparameter.from_handle(h)
return Hyperparameter.from_handle(h, retain = retain, auto_release = auto_release)
elif v == ccs_object_type.EXPRESSION:
return Expression.from_handle(h)
return Expression.from_handle(h, retain = retain, auto_release = auto_release)
elif v == ccs_object_type.CONFIGURATION_SPACE:
return ConfigurationSpace.from_handle(h)
return ConfigurationSpace.from_handle(h, retain = retain, auto_release = auto_release)
elif v == ccs_object_type.CONFIGURATION:
return Configuration.from_handle(h)
return Configuration.from_handle(h, retain = retain, auto_release = auto_release)
elif v == ccs_object_type.OBJECTIVE_SPACE:
return ObjectiveSpace.from_handle(h)
return ObjectiveSpace.from_handle(h, retain = retain, auto_release = auto_release)
elif v == ccs_object_type.EVALUATION:
return Evaluation.from_handle(h)
elif v == ccs_object_type.Tuner:
return Tuner.from_handle(h)
return Evaluation.from_handle(h, retain = retain, auto_release = auto_release)
elif v == ccs_object_type.TUNER:
return Tuner.from_handle(h, retain = retain, auto_release = auto_release)
else:
raise Error(ccs_error(ccs_error.INVALID_OBJECT))
def set_destroy_callback(self, callback, user_data = None):
_set_destroy_callback(self.handle, callback, user_data = user_data)
_callbacks = {}
def _set_destroy_callback(handle, callback, user_data = None):
if callback is None:
raise Error(ccs_error(ccs_error.INVALID_VALUE))
ptr = ct.c_int(32)
def cb_wrapper(obj, data):
try:
callback(Object.from_handle(obj), data)
del _callbacks[ct.addressof(ptr)]
except Error as e:
None
cb_wrapper_func = ccs_object_destroy_callback_type(cb_wrapper)
res = ccs_object_set_destroy_callback(handle, cb_wrapper_func, user_data)
Error.check(res)
_callbacks[ct.addressof(ptr)] = (cb_wrapper_func, user_data, ptr)
_ccs_id = 0
def _ccs_get_id():
global _ccs_id
......
......@@ -17,7 +17,8 @@ ccs_configuration_hash = _ccs_get_function("ccs_configuration_hash", [ccs_config
ccs_configuration_cmp = _ccs_get_function("ccs_configuration_cmp", [ccs_configuration, ccs_configuration, ct.POINTER(ccs_int)])
class Configuration(Object):
def __init__(self, handle = None, retain = False, configuration_space = None, values = None, user_data = None):
def __init__(self, handle = None, retain = False, auto_release = True,
configuration_space = None, values = None, user_data = None):
if handle is None:
count = 0
if values:
......@@ -32,11 +33,11 @@ class Configuration(Object):
Error.check(res)
super().__init__(handle = handle, retain = False)
else:
super().__init__(handle = handle, retain = retain)
super().__init__(handle = handle, retain = retain, auto_release = auto_release)
@classmethod
def from_handle(cls, handle):
return cls(handle = handle, retain = True)
def from_handle(cls, handle, retain = True, auto_release = True):
return cls(handle = handle, retain = retain, auto_release = auto_release)
@property
def user_data(self):
......
......@@ -29,18 +29,19 @@ ccs_configuration_space_sample = _ccs_get_function("ccs_configuration_space_samp
ccs_configuration_space_samples = _ccs_get_function("ccs_configuration_space_samples", [ccs_configuration_space, ct.c_size_t, ct.POINTER(ccs_configuration)])
class ConfigurationSpace(Context):
def __init__(self, handle = None, retain = False, name = "", user_data = None):
def __init__(self, handle = None, retain = False, auto_release = True,
name = "", user_data = None):
if handle is None:
handle = ccs_configuration_space()
res = ccs_create_configuration_space(str.encode(name), user_data, ct.byref(handle))
Error.check(res)
super().__init__(handle = handle, retain = False)
else:
super().__init__(handle = handle, retain = retain)
super().__init__(handle = handle, retain = retain, auto_release = auto_release)
@classmethod
def from_handle(cls, handle):
return cls(handle = handle, retain = True)
def from_handle(cls, handle, retain = True, auto_release = True):
return cls(handle = handle, retain = retain, auto_release = auto_release)
@property
def rng(self):
......
......@@ -27,21 +27,21 @@ ccs_distribution_samples = _ccs_get_function("ccs_distribution_samples", [ccs_di
class Distribution(Object):
@classmethod
def from_handle(cls, handle):
def from_handle(cls, handle, retain = True, auto_release = True):
v = ccs_distribution_type(0)
res = ccs_distribution_get_type(handle, ct.byref(v))
Error.check(res)
v = v.value
if v == ccs_distribution_type.UNIFORM:
return UniformDistribution(handle = handle, retain = True)
return UniformDistribution(handle = handle, retain = retain, auto_release = auto_release)
elif v == ccs_distribution_type.NORMAL:
return NormalDistribution(handle = handle, retain = True)
return NormalDistribution(handle = handle, retain = retain, auto_release = auto_release)
elif v == ccs_distribution_type.ROULETTE:
return RouletteDistribution(handle = handle, retain = True)
return RouletteDistribution(handle = handle, retain = retain, auto_release = auto_release)
elif v == ccs_distribution_type.MIXTURE:
return MixtureDistribution(handle = handle, retain = True)
return MixtureDistribution(handle = handle, retain = retain, auto_release = auto_release)
elif v == ccs_distribution_type.MULTIVARIATE:
return MultivariateDistribution(handle = handle, retain = True)
return MultivariateDistribution(handle = handle, retain = retain, auto_release = auto_release)
else:
raise Error(ccs_error(ccs_error.INVALID_DISTRIBUTION))
......@@ -134,7 +134,8 @@ ccs_create_uniform_float_distribution = _ccs_get_function("ccs_create_uniform_fl
ccs_uniform_distribution_get_parameters = _ccs_get_function("ccs_uniform_distribution_get_parameters", [ccs_distribution, ct.POINTER(ccs_numeric), ct.POINTER(ccs_numeric), ct.POINTER(ccs_scale_type), ct.POINTER(ccs_numeric)])
class UniformDistribution(Distribution):
def __init__(self, handle = None, retain = False, data_type = NUM_FLOAT, lower = 0.0, upper = 1.0, scale = ccs_scale_type.LINEAR, quantization = 0.0):
def __init__(self, handle = None, retain = False, auto_release = True,
data_type = NUM_FLOAT, lower = 0.0, upper = 1.0, scale = ccs_scale_type.LINEAR, quantization = 0.0):
if handle is None:
handle = ccs_distribution(0)
if data_type == NUM_FLOAT:
......@@ -146,7 +147,7 @@ class UniformDistribution(Distribution):
Error.check(res)
super().__init__(handle = handle, retain = False)
else:
super().__init__(handle = handle, retain = retain)
super().__init__(handle = handle, retain = retain, auto_release = auto_release)
@classmethod
def int(cls, lower, upper, scale = ccs_scale_type.LINEAR, quantization = 0):
......@@ -229,7 +230,8 @@ ccs_create_normal_float_distribution = _ccs_get_function("ccs_create_normal_floa
ccs_normal_distribution_get_parameters = _ccs_get_function("ccs_normal_distribution_get_parameters", [ccs_distribution, ct.POINTER(ccs_float), ct.POINTER(ccs_float), ct.POINTER(ccs_scale_type), ct.POINTER(ccs_numeric)])
class NormalDistribution(Distribution):
def __init__(self, handle = None, retain = False, data_type = NUM_FLOAT, mu = 0.0, sigma = 1.0, scale = ccs_scale_type.LINEAR, quantization = 0.0):
def __init__(self, handle = None, retain = False, auto_release = True,
data_type = NUM_FLOAT, mu = 0.0, sigma = 1.0, scale = ccs_scale_type.LINEAR, quantization = 0.0):
if handle is None:
handle = ccs_distribution(0)
if data_type == NUM_FLOAT:
......@@ -241,7 +243,7 @@ class NormalDistribution(Distribution):
Error.check(res)
super().__init__(handle = handle, retain = False)
else:
super().__init__(handle = handle, retain = retain)
super().__init__(handle = handle, retain = retain, auto_release = auto_release)
@classmethod
def int(cls, mu, sigma, scale = ccs_scale_type.LINEAR, quantization = 0):
......@@ -311,7 +313,8 @@ ccs_roulette_distribution_get_num_areas = _ccs_get_function("ccs_roulette_distri
ccs_roulette_distribution_get_areas = _ccs_get_function("ccs_roulette_distribution_get_areas", [ccs_distribution, ct.c_size_t, ct.POINTER(ccs_float), ct.POINTER(ct.c_size_t)])
class RouletteDistribution(Distribution):
def __init__(self, handle = None, retain = False, areas = []):
def __init__(self, handle = None, retain = False, auto_release = True,
areas = []):
if handle is None:
handle = ccs_distribution(0)
v = (ccs_float * len(areas))(*areas)
......@@ -319,7 +322,7 @@ class RouletteDistribution(Distribution):
Error.check(res)
super().__init__(handle = handle, retain = False)
else:
super().__init__(handle = handle, retain = retain)
super().__init__(handle = handle, retain = retain, auto_release = auto_release)
@property
def data_type(self):
......@@ -354,7 +357,8 @@ ccs_mixture_distribution_get_distributions = _ccs_get_function("ccs_mixture_dist
ccs_mixture_distribution_get_weights = _ccs_get_function("ccs_mixture_distribution_get_weights", [ccs_distribution, ct.c_size_t, ct.POINTER(ccs_float), ct.POINTER(ct.c_size_t)])
class MixtureDistribution(Distribution):
def __init__(self, handle = None, retain = False, distributions = [], weights = None):
def __init__(self, handle = None, retain = False, auto_release = True,
distributions = [], weights = None):
if handle is None:
handle = ccs_distribution(0)
if weights is None:
......@@ -365,7 +369,7 @@ class MixtureDistribution(Distribution):
Error.check(res)
super().__init__(handle = handle, retain = False)
else:
super().__init__(handle = handle, retain = retain)
super().__init__(handle = handle, retain = retain, auto_release = auto_release)
@property
def num_distributions(self):
......@@ -402,7 +406,8 @@ ccs_multivariate_distribution_get_num_distributions = _ccs_get_function("ccs_mul
ccs_multivariate_distribution_get_distributions = _ccs_get_function("ccs_multivariate_distribution_get_distributions", [ccs_distribution, ct.c_size_t, ct.POINTER(ccs_distribution), ct.POINTER(ct.c_size_t)])
class MultivariateDistribution(Distribution):
def __init__(self, handle = None, retain = False, distributions = [], weights = None):
def __init__(self, handle = None, retain = False, auto_release = True,
distributions = [], weights = None):
if handle is None:
handle = ccs_distribution(0)
ds = (ccs_distribution * len(distributions))(*[x.handle.value for x in distributions])
......@@ -410,7 +415,7 @@ class MultivariateDistribution(Distribution):
Error.check(res)
super().__init__(handle = handle, retain = False)
else:
super().__init__(handle = handle, retain = retain)
super().__init__(handle = handle, retain = retain, auto_release = auto_release)
@property
def num_distributions(self):
......
......@@ -28,7 +28,8 @@ ccs_evaluation_get_objective_values = _ccs_get_function("ccs_evaluation_get_obje
ccs_evaluation_cmp = _ccs_get_function("ccs_evaluation_cmp", [ccs_evaluation, ccs_evaluation, ct.POINTER(ccs_comparison)])
class Evaluation(Object):
def __init__(self, handle = None, retain = False, objective_space = None, configuration = None, error = ccs_error.SUCCESS, values = None, user_data = None):
def __init__(self, handle = None, retain = False, auto_release = True,
objective_space = None, configuration = None, error = ccs_error.SUCCESS, values = None, user_data = None):
if handle is None:
count = 0
if values:
......@@ -43,11 +44,11 @@ class Evaluation(Object):
Error.check(res)
super().__init__(handle = handle, retain = False)
else:
super().__init__(handle = handle, retain = retain)
super().__init__(handle = handle, retain = retain, auto_release = auto_release)
@classmethod
def from_handle(cls, handle):
return cls(handle = handle, retain = True)
def from_handle(cls, handle, retain = True, auto_release = True):
return cls(handle = handle, retain = retain, auto_release = auto_release)
@property
def user_data(self):
......
......@@ -69,7 +69,8 @@ ccs_expression_get_hyperparameters = _ccs_get_function("ccs_expression_get_hyper
ccs_expression_check_context = _ccs_get_function("ccs_expression_check_context", [ccs_expression, ccs_context])
class Expression(Object):
def __init__(self, handle = None, retain = False, t = None, nodes = []):
def __init__(self, handle = None, retain = False, auto_release = True,
t = None, nodes = []):
if handle is None:
sz = len(nodes)
handle = ccs_expression()
......@@ -80,22 +81,22 @@ class Expression(Object):
Error.check(res)
super().__init__(handle = handle, retain = False)
else:
super().__init__(handle = handle, retain = retain)
super().__init__(handle = handle, retain = retain, auto_release = auto_release)
@classmethod
def from_handle(cls, handle):
def from_handle(cls, handle, retain = True, auto_release = True):
v = ccs_expression_type(0)
res = ccs_expression_get_type(handle, ct.byref(v))
Error.check(res)
v = v.value
if v == ccs_expression_type.LIST:
return List(handle = handle, retain = True)
return List(handle = handle, retain = retain, auto_release = auto_release)
elif v == ccs_expression_type.LITERAL:
return Literal(handle = handle, retain = True)
return Literal(handle = handle, retain = retain, auto_release = auto_release)
elif v == ccs_expression_type.VARIABLE:
return Variable(handle = handle, retain = True)
return Variable(handle = handle, retain = retain, auto_release = auto_release)
else:
return cls(handle = handle, retain = True)
return cls(handle = handle, retain = retain, auto_release = auto_release)
@classmethod
def binary(cls, t, left, right):
......@@ -208,7 +209,8 @@ class Literal(Expression):
true_aymbol = ccs_terminal_symbols[ccs_terminal_type.TERM_TRUE]
false_symbol = ccs_terminal_symbols[ccs_terminal_type.TERM_FALSE]
def __init__(self, handle = None, retain = False, value = None):
def __init__(self, handle = None, retain = False, auto_release = True,
value = None):
if handle is None:
handle = ccs_expression()
pv = ccs_datum(value)
......@@ -219,7 +221,7 @@ class Literal(Expression):
Error.check(res)
super().__init__(handle = handle, retain = False)
else:
super().__init__(handle = handle, retain = retain)
super().__init__(handle = handle, retain = retain, auto_release = auto_release)
@property
def value(self):
......@@ -246,14 +248,15 @@ class Literal(Expression):
class Variable(Expression):
def __init__(self, handle = None, retain = False, hyperparameter = None):
def __init__(self, handle = None, retain = False, auto_release = True,
hyperparameter = None):
if handle is None:
handle = ccs_expression()
res = ccs_create_variable(hyperparameter.handle, ct.byref(handle))
Error.check(res)
super().__init__(handle = handle, retain = False)
else:
super().__init__(handle = handle, retain = retain)
super().__init__(handle = handle, retain = retain, auto_release = auto_release)
@property
def hyperparameter(self):
......@@ -270,11 +273,12 @@ class Variable(Expression):
class List(Expression):
def __init__(self, handle = None, retain = False, values = []):
def __init__(self, handle = None, retain = False, auto_release = True,
values = []):
if handle is None:
super().__init__(t = ccs_expression_type.LIST, nodes = values)
else:
super().__init__(handle = handle, retain = retain)
super().__init__(handle = handle, retain = retain, auto_release = auto_release)
def eval(self, index, context = None, values = None):
if context and values:
......
......@@ -25,19 +25,19 @@ ccs_hyperparameter_samples = _ccs_get_function("ccs_hyperparameter_samples", [cc
class Hyperparameter(Object):
@classmethod
def from_handle(cls, handle):
def from_handle(cls, handle, retain = True, auto_release = True):
v = ccs_hyperparameter_type(0)
res = ccs_hyperparameter_get_type(handle, ct.byref(v))
Error.check(res)
v = v.value
if v == ccs_hyperparameter_type.NUMERICAL:
return NumericalHyperparameter(handle = handle, retain = True)
return NumericalHyperparameter(handle = handle, retain = retain, auto_release = auto_release)
elif v == ccs_hyperparameter_type.CATEGORICAL:
return CategoricalHyperparameter(handle = handle, retain = True)
return CategoricalHyperparameter(handle = handle, retain = retain, auto_release = auto_release)
elif v == ccs_hyperparameter_type.ORDINAL:
return OrdinalHyperparameter(handle = handle, retain = True)
return OrdinalHyperparameter(handle = handle, retain = retain, auto_release = auto_release)
elif v == ccs_hyperparameter_type.DISCRETE:
return DiscreteHyperparameter(handle = handle, retain = True)
return DiscreteHyperparameter(handle = handle, retain = retain, auto_release = auto_release)
else:
raise Error(ccs_error(ccs_error.INVALID_HYPERPARAMETER))
......@@ -144,7 +144,8 @@ ccs_create_numerical_hyperparameter = _ccs_get_function("ccs_create_numerical_hy
ccs_numerical_hyperparameter_get_parameters = _ccs_get_function("ccs_numerical_hyperparameter_get_parameters", [ccs_hyperparameter, ct.POINTER(ccs_numeric_type), ct.POINTER(ccs_numeric), ct.POINTER(ccs_numeric), ct.POINTER(ccs_numeric)])
class NumericalHyperparameter(Hyperparameter):
def __init__(self, handle = None, retain = False, name = None, data_type = ccs_numeric_type.NUM_FLOAT, lower = 0.0, upper = 1.0, quantization = 0.0, default = None, user_data = None):
def __init__(self, handle = None, retain = False, auto_release = True,
name = None, data_type = ccs_numeric_type.NUM_FLOAT, lower = 0.0, upper = 1.0, quantization = 0.0, default = None, user_data = None):
if handle is None:
if name is None:
name = NumericalHyperparameter.default_name()
......@@ -171,7 +172,7 @@ class NumericalHyperparameter(Hyperparameter):
Error.check(res)
super().__init__(handle = handle, retain = False)
else:
super().__init__(handle = handle, retain = retain)
super().__init__(handle = handle, retain = retain, auto_release = auto_release)
@classmethod
def int(cls, lower, upper, name = None, quantization = 0, default = None, user_data = None):
......@@ -243,7 +244,8 @@ ccs_create_categorical_hyperparameter = _ccs_get_function("ccs_create_categorica
ccs_categorical_hyperparameter_get_values = _ccs_get_function("ccs_categorical_hyperparameter_get_values", [ccs_hyperparameter, ct.c_size_t, ct.POINTER(ccs_datum), ct.POINTER(ct.c_size_t)])
class CategoricalHyperparameter(Hyperparameter):
def __init__(self, handle = None, retain = False, name = None, values = [], default_index = 0, user_data = None):
def __init__(self, handle = None, retain = False, auto_release = True,
name = None, values = [], default_index = 0, user_data = None):
if handle is None:
if name is None:
name = NumericalHyperparameter.default_name()
......@@ -256,7 +258,7 @@ class CategoricalHyperparameter(Hyperparameter):
Error.check(res)
super().__init__(handle = handle, retain = False)
else:
super().__init__(handle = handle, retain = retain)
super().__init__(handle = handle, retain = retain, auto_release = auto_release)
@property
def values(self):
......@@ -273,7 +275,8 @@ ccs_ordinal_hyperparameter_compare_values = _ccs_get_function("ccs_ordinal_hyper
ccs_ordinal_hyperparameter_get_values = _ccs_get_function("ccs_ordinal_hyperparameter_get_values", [ccs_hyperparameter, ct.c_size_t, ct.POINTER(ccs_datum), ct.POINTER(ct.c_size_t)])
class OrdinalHyperparameter(Hyperparameter):
def __init__(self, handle = None, retain = False, name = None, values = [], default_index = 0, user_data = None):
def __init__(self, handle = None, retain = False, auto_release = True,
name = None, values = [], default_index = 0, user_data = None):
if handle is None:
if name is None:
name = NumericalHyperparameter.default_name()
......@@ -286,7 +289,7 @@ class OrdinalHyperparameter(Hyperparameter):
Error.check(res)
super().__init__(handle = handle, retain = False)
else:
super().__init__(handle = handle, retain = retain)
super().__init__(handle = handle, retain = retain, auto_release = auto_release)
@property
def values(self):
......@@ -316,7 +319,8 @@ ccs_create_discrete_hyperparameter = _ccs_get_function("ccs_create_discrete_hype
ccs_discrete_hyperparameter_get_values = _ccs_get_function("ccs_discrete_hyperparameter_get_values", [ccs_hyperparameter, ct.c_size_t, ct.POINTER(ccs_datum), ct.POINTER(ct.c_size_t)])
class DiscreteHyperparameter(Hyperparameter):
def __init__(self, handle = None, retain = False, name = None, values = [], default_index = 0, user_data = None):
def __init__(self, handle = None, retain = False, auto_release = True,
name = None, values = [], default_index = 0, user_data = None):
if handle is None:
if name is None:
name = NumericalHyperparameter.default_name()
......@@ -329,7 +333,7 @@ class DiscreteHyperparameter(Hyperparameter):
Error.check(res)
super().__init__(handle = handle, retain = False)
else:
super().__init__(handle = handle, retain = retain)
super().__init__(handle = handle, retain = retain, auto_release = auto_release)
@property
def values(self):
......
......@@ -22,18 +22,19 @@ ccs_objective_space_get_objective = _ccs_get_function("ccs_objective_space_get_o
ccs_objective_space_get_objectives = _ccs_get_function("ccs_objective_space_get_objectives", [ccs_objective_space, ct.c_size_t, ct.POINTER(ccs_expression), ct.POINTER(ccs_objective_type), ct.POINTER(ct.c_size_t)])
class ObjectiveSpace(Context):
def __init__(self, handle = None, retain = False, name = "", user_data = None):
def __init__(self, handle = None, retain = False, auto_release = True,
name = "", user_data = None):
if handle is None:
handle = ccs_objective_space()
res = ccs_create_objective_space(str.encode(name), user_data, ct.byref(handle))
Error.check(res)
super().__init__(handle = handle, retain = False)
else:
super().__init__(handle = handle, retain = retain)
super().__init__(handle = handle, retain = retain, auto_release = auto_release)
@classmethod
def from_handle(cls, handle):
return cls(handle = handle, retain = True)
def from_handle(cls, handle, retain = True, auto_release = True):
return cls(handle = handle, retain = retain, auto_release = auto_release)
def add_hyperparameter(self, hyperparameter):
res = ccs_objective_space_add_hyperparameter(self.handle, hyperparameter.handle)
......
......@@ -10,18 +10,18 @@ ccs_rng_min = _ccs_get_function("ccs_rng_min", [ccs_rng, ct.POINTER(ct.c_ulong)]
ccs_rng_max = _ccs_get_function("ccs_rng_max", [ccs_rng, ct.POINTER(ct.c_ulong)])
class Rng(Object):
def __init__(self, handle = None, retain = False):
def __init__(self, handle = None, retain = False, auto_release = True):
if handle is None:
handle = ccs_rng(0)
res = ccs_rng_create(ct.byref(handle))
Error.check(res)
super().__init__(handle = handle, retain = False)
else:
super().__init__(handle = handle, retain = retain)
super().__init__(handle = handle, retain = retain, auto_release = auto_release)
@classmethod
def from_handle(cls, handle):
return cls(handle, retain = True)
def from_handle(cls, handle, retain = True, auto_release = True):
return cls(handle, retain = retain, auto_release = auto_release)
def __setattr__(self, name, value):
if name == 'seed':
......
This diff is collapsed.
......@@ -24,6 +24,7 @@ class TestTuner(unittest.TestCase):
def test_create_random(self):
(cs, os) = self.create_tuning_problem()
t = ccs.RandomTuner(name = "tuner", configuration_space = cs, objective_space = os)
t2 = ccs.Object.from_handle(t.handle)
self.assertEqual("tuner", t.name)
self.assertEqual(ccs.TUNER_RANDOM, t.type)
func = lambda x, y, z: [(x-2)*(x-2), sin(z+y)]
......@@ -85,6 +86,7 @@ class TestTuner(unittest.TestCase):
(cs, os) = self.create_tuning_problem()
t = ccs.UserDefinedTuner(name = "tuner", configuration_space = cs, objective_space = os, delete = delete, ask = ask, tell = tell, get_optimums = get_optimums, get_history = get_history)
t2 = ccs.Object.from_handle(t.handle)
self.assertEqual("tuner", t.name)
self.assertEqual(ccs.TUNER_USER_DEFINED, t.type)
self.assertEqual(cs.handle.value, t.configuration_space.handle.value)
......
......@@ -339,6 +339,8 @@ module CCS
attach_function :ccs_release_object, [:ccs_object_t], :ccs_result_t
attach_function :ccs_object_get_type, [:ccs_object_t, :pointer], :ccs_result_t
attach_function :ccs_object_get_refcount, [:ccs_object_t, :pointer], :ccs_result_t
callback :ccs_object_release_callback, [:ccs_object_t, :pointer], :void
attach_function :ccs_object_set_destroy_callback, [:ccs_object_t, :ccs_object_release_callback, :pointer], :ccs_result_t
class << self
alias version ccs_get_version
......@@ -423,34 +425,58 @@ module CCS
ptr = MemoryPointer::new(:ccs_object_type_t)
res = CCS.ccs_object_get_type(handle, ptr)
CCS.error_check(res)
ptr2 = MemoryPointer::new(:int32)
res = CCS.ccs_object_get_refcount(handle, ptr2)
CCS.error_check(res)
opts = ptr2.read_int32 == 0 ? {retain: false, auto_release: false} : {}
case ptr.read_ccs_object_type_t
when :CCS_RNG
CCS::Rng::from_handle(handle)
CCS::Rng
when :CCS_DISTRIBUTION
CCS::Distribution::from_handle(handle)
CCS::Distribution
when :CCS_HYPERPARAMETER
CCS::Hyperparameter::from_handle(handle)
CCS::Hyperparameter
when :CCS_EXPRESSION
CCS::Expression::from_handle(handle)
CCS::Expression
when :CCS_CONFIGURATION_SPACE
CCS::ConfigurationSpace::from_handle(handle)
CCS::ConfigurationSpace
when :CCS_CONFIGURATION
CCS::Configuration::from_handle(handle)
CCS::Configuration
when :CCS_OBJECTIVE_SPACE
CCS::ObjectiveSpace::from_handle(handle)
CCS::ObjectiveSpace
when :CCS_EVALUATION
CCS::Evaluation::from_handle(handle)
CCS::Evaluation
when :CCS_TUNER
CCS::Tuner::from_handle(handle)
CCS::Tuner
else
raise CCSError, :CCS_INVALID_OBJECT
end
end.from_handle(handle, **opts)
end
def to_ptr
@handle
end
def set_destroy_callback(user_data: nil, &block)
CCS.set_destroy_callback(@handle, user_data: user_data, &block)
self
end
end
@@callbacks = {}
def self.set_destroy_callback(handle, user_data: nil, &block)
if block
cb_wrapper = lambda { |object, data|
block.call(Object.from_handle(object), data)
@@callbacks.delete(cb_wrapper)
}
@@callbacks[cb_wrapper] = user_data
else
cb_wrapper = nil
end
res = CCS.ccs_object_set_destroy_callback(handle, cb_wrapper, user_data)
CCS.error_check(res)
end
end
......@@ -17,9 +17,10 @@ module CCS
add_property :hash, :ccs_hash_t, :ccs_configuration_hash, memoize: false
add_handle_property :configuration_space, :ccs_configuration_space_t, :ccs_configuration_get_configuration_space, memoize: true
def initialize(handle = nil, retain: false, configuration_space: nil, values: nil, user_data: nil)
def initialize(handle = nil, retain: false, auto_release: true,
configuration_space: nil, values: nil, user_data: nil)
if (handle)
super(handle, retain: retain)
super(handle, retain: retain, auto_release: auto_release)
else
if values
count = values.size
......@@ -37,8 +38,8 @@ module CCS
end
end
def self.from_handle(handle)
self::new(handle, retain: true)
def self.from_handle(handle, retain: true, auto_release: true)
self::new(handle, retain: retain, auto_release: auto_release)
end
def set_value(hyperparameter, value)
......