Commit 2ce430ab authored by Brice Videau's avatar Brice Videau

Uniform usage of enum int values.

parent 5b3c7150
......@@ -265,27 +265,35 @@ class ccs_datum_fix(ct.Structure):
class ccs_datum(ct.Structure):
_fields_ = [('_value', ccs_value),
('type', ccs_data_type)]
('_type', ccs_data_type)]
def __init__(self, v = None):
super().__init__()
self.value = v
@property
def type(self):
return self._type.value
@type.setter
def type(self, v):
self._type.value = v
@property
def value(self):
if self.type.value == ccs_data_type.NONE:
if self.type == ccs_data_type.NONE:
return None
elif self.type.value == ccs_data_type.INTEGER:
elif self.type == ccs_data_type.INTEGER:
return self._value.i
elif self.type.value == ccs_data_type.FLOAT:
elif self.type == ccs_data_type.FLOAT:
return self._value.f
elif self.type.value == ccs_data_type.BOOLEAN:
elif self.type == ccs_data_type.BOOLEAN:
return False if self._value.i == ccs_false else True
elif self.type.value == ccs_data_type.STRING:
elif self.type == ccs_data_type.STRING:
return self._value.s.decode()
elif self.type.value == ccs_data_type.INACTIVE:
elif self.type == ccs_data_type.INACTIVE:
return ccs_inactive
elif self.type.value == ccs_data_type.OBJECT:
elif self.type == ccs_data_type.OBJECT:
return Object.from_handle(ct.c_void_p(self._value.o))
else:
raise Error(ccs_error(ccs_error.INVALID_VALUE))
......@@ -293,26 +301,26 @@ class ccs_datum(ct.Structure):
@value.setter
def value(self, v):
if v is None:
self.type.value = ccs_data_type.NONE
self.type = ccs_data_type.NONE
self._value.i = 0
elif isinstance(v, bool):
self.type.value = ccs_data_type.BOOLEAN
self.type = ccs_data_type.BOOLEAN
self._value.i = 1 if v else 0
elif isinstance(v, int):
self.type.value = ccs_data_type.INTEGER
self.type = ccs_data_type.INTEGER
self._value.i = v
elif isinstance(v, float):
self.type.value = ccs_data_type.FLOAT
self.type = ccs_data_type.FLOAT
self._value.f = v
elif isinstance(v, str):
self.type.value = ccs_data_type.STRING
self.type = ccs_data_type.STRING
self._string = str.encode(v)
self._value.s = ct.c_char_p(self._string)
elif v is ccs_inactive:
self.type.value = ccs_data_type.INACTIVE
self.type = ccs_data_type.INACTIVE
self._value.i = 0
elif isinstance(v, Object):
self.type.value = ccs_data_type.OBJECT
self.type = ccs_data_type.OBJECT
self._object = v
self._value.o = v.handle
else:
......@@ -362,8 +370,8 @@ class Object:
t = ccs_object_type(0)
res = ccs_object_get_type(self.handle, ct.byref(t))
Error.check(res)
self._object_type = t
return t
self._object_type = t.value
return self._object_type
@property
def refcount(self):
......
......@@ -48,8 +48,8 @@ class Distribution(Object):
v = ccs_distribution_type(0)
res = ccs_distribution_get_type(self.handle, ct.byref(v))
Error.check(res)
self._type = v
return v
self._type = v.value
return self._type
@property
def data_type(self):
......@@ -58,8 +58,8 @@ class Distribution(Object):
v = ccs_numeric_type(0)
res = ccs_distribution_get_data_type(self.handle, ct.byref(v))
Error.check(res)
self._data_type = v
return v
self._data_type = v.value
return self._data_type
@property
def dimension(self):
......@@ -78,8 +78,8 @@ class Distribution(Object):
v = ccs_scale_type(0)
res = ccs_distribution_get_scale_type(self.handle, ct.byref(v))
Error.check(res)
self._scale_type = v
return v
self._scale_type = v.value
return self._scale_type
@property
def quantization(self):
......@@ -88,7 +88,7 @@ class Distribution(Object):
v = ccs_numeric(0)
res = ccs_distribution_get_quantization(self.handle, ct.byref(v))
Error.check(res)
t = self.data_type.value
t = self.data_type
if t == ccs_numeric_type.NUM_INTEGER:
self._quantization = v.i
elif t == ccs_numeric_type.NUM_FLOAT:
......@@ -117,7 +117,7 @@ class Distribution(Object):
v = ccs_numeric()
res = ccs_distribution_sample(self.handle, rng.handle, ct.byref(v))
Error.check(res)
t = self.data_type.value
t = self.data_type
if t == ccs_numeric_type.NUM_INTEGER:
return v.i
elif t == ccs_numeric_type.NUM_FLOAT:
......@@ -128,7 +128,7 @@ class Distribution(Object):
def samples(self, rng, count):
if count == 0:
return []
t = self.data_type.value
t = self.data_type
if t == ccs_numeric_type.NUM_INTEGER:
v = (ccs_int * count)()
elif t == ccs_numeric_type.NUM_FLOAT:
......@@ -174,7 +174,7 @@ class UniformDistribution(Distribution):
v = ccs_numeric()
res = ccs_uniform_distribution_get_parameters(self.handle, ct.byref(v), None)
Error.check(res)
t = self.data_type.value
t = self.data_type
if t == ccs_numeric_type.NUM_INTEGER:
self._lower = v.i
elif t == ccs_numeric_type.NUM_FLOAT:
......@@ -190,7 +190,7 @@ class UniformDistribution(Distribution):
v = ccs_numeric()
res = ccs_uniform_distribution_get_parameters(self.handle, None, ct.byref(v))
Error.check(res)
t = self.data_type.value
t = self.data_type
if t == ccs_numeric_type.NUM_INTEGER:
self._upper = v.i
elif t == ccs_numeric_type.NUM_FLOAT:
......@@ -226,7 +226,6 @@ class NormalDistribution(Distribution):
v = ccs_float()
res = ccs_normal_distribution_get_parameters(self.handle, ct.byref(v), None)
Error.check(res)
t = self.data_type.value
self._mu = v.value
return self._mu
......@@ -237,7 +236,6 @@ class NormalDistribution(Distribution):
v = ccs_float()
res = ccs_normal_distribution_get_parameters(self.handle, None, ct.byref(v))
Error.check(res)
t = self.data_type.value
self._sigma = v.value
return self._sigma
......
......@@ -158,4 +158,4 @@ class Evaluation(Object):
v = ccs_comparison(0)
res = ccs_evaluation_cmp(self.handle, other.handle, ct.byref(v))
Error.check(res)
return v
return v.value
......@@ -115,8 +115,8 @@ class Expression(Object):
v = ccs_expression_type(0)
res = ccs_expression_get_type(self.handle, ct.byref(v))
Error.check(res)
self._type = v
return v
self._type = v.value
return self._type
@property
def num_nodes(self):
......
......@@ -52,8 +52,8 @@ class Hyperparameter(Object):
v = ccs_hyperparameter_type(0)
res = ccs_hyperparameter_get_type(self.handle, ct.byref(v))
Error.check(res)
self._type = v
return v
self._type = v.value
return self._type
@property
def user_data(self):
......@@ -188,8 +188,8 @@ class NumericalHyperparameter(Hyperparameter):
v = ccs_numeric_type(0)
res = ccs_numerical_hyperparameter_get_parameters(self.handle, ct.byref(v), None, None, None)
Error.check(res)
self._data_type = v
return v
self._data_type = v.value
return self._data_type
@property
def lower(self):
......@@ -198,7 +198,7 @@ class NumericalHyperparameter(Hyperparameter):
v = ccs_numeric()
res = ccs_numerical_hyperparameter_get_parameters(self.handle, None, ct.byref(v), None, None)
Error.check(res)
t = self.data_type.value
t = self.data_type
if t == ccs_numeric_type.NUM_INTEGER:
self._lower = v.i
elif t == ccs_numeric_type.NUM_FLOAT:
......@@ -214,7 +214,7 @@ class NumericalHyperparameter(Hyperparameter):
v = ccs_numeric()
res = ccs_numerical_hyperparameter_get_parameters(self.handle, None, None, ct.byref(v), None)
Error.check(res)
t = self.data_type.value
t = self.data_type
if t == ccs_numeric_type.NUM_INTEGER:
self._upper = v.i
elif t == ccs_numeric_type.NUM_FLOAT:
......@@ -230,7 +230,7 @@ class NumericalHyperparameter(Hyperparameter):
v = ccs_numeric(0)
res = ccs_numerical_hyperparameter_get_parameters(self.handle, None, None, None, ct.byref(v))
Error.check(res)
t = self.data_type.value
t = self.data_type
if t == ccs_numeric_type.NUM_INTEGER:
self._quantization = v.i
elif t == ccs_numeric_type.NUM_FLOAT:
......
......@@ -3,44 +3,56 @@ from . import libcconfigspace
from .base import Error, ccs_error, ccs_numeric_type, ccs_numeric, ccs_float, ccs_int, ccs_result, ccs_bool, ccs_false, ccs_true, _ccs_get_function
class ccs_interval(ct.Structure):
_fields_ = [('type', ccs_numeric_type),
_fields_ = [('_type', ccs_numeric_type),
('_lower', ccs_numeric),
('_upper', ccs_numeric),
('_lower_included', ccs_bool),
('_upper_included', ccs_bool)]
@property
def type(self):
return self._type.value
@property
def type(self, v):
self._type.value = v
@property
def lower(self):
if self.type.value == ccs_numeric_type.NUM_INTEGER:
t = self.type
if t == ccs_numeric_type.NUM_INTEGER:
return self._lower.i
elif self.type.value == ccs_numeric_type.NUM_FLOAT:
elif t == ccs_numeric_type.NUM_FLOAT:
return self._lower.f
else:
raise Error(ccs_error(ccs_error.INVALID_VALUE))
@lower.setter
def lower(self, value):
if self.type.value == ccs_numeric_type.NUM_INTEGER:
t = self.type
if t == ccs_numeric_type.NUM_INTEGER:
self._lower.i = value
elif self.type.value == ccs_numeric_type.NUM_FLOAT:
elif t == ccs_numeric_type.NUM_FLOAT:
self._lower.f = value
else:
raise Error(ccs_error(ccs_error.INVALID_VALUE))
@property
def upper(self):
if self.type.value == ccs_numeric_type.NUM_INTEGER:
t = self.type
if t == ccs_numeric_type.NUM_INTEGER:
return self._upper.i
elif self.type.value == ccs_numeric_type.NUM_FLOAT:
elif t == ccs_numeric_type.NUM_FLOAT:
return self._upper.f
else:
raise Error(ccs_error(ccs_error.INVALID_VALUE))
@upper.setter
def upper(self, value):
if self.type.value == ccs_numeric_type.NUM_INTEGER:
t = self.type
if t == ccs_numeric_type.NUM_INTEGER:
self._upper.i = value
elif self.type.value == ccs_numeric_type.NUM_FLOAT:
elif t == ccs_numeric_type.NUM_FLOAT:
self._upper.f = value
else:
raise Error(ccs_error(ccs_error.INVALID_VALUE))
......@@ -88,9 +100,10 @@ class ccs_interval(ct.Structure):
# this works around a subtle bug in union support...
def include(self, value):
v = ccs_numeric()
if self.type.value == ccs_numeric_type.NUM_INTEGER:
t = self.type
if t == ccs_numeric_type.NUM_INTEGER:
v.i = value
elif self.type.value == ccs_numeric_type.NUM_FLOAT:
elif t == ccs_numeric_type.NUM_FLOAT:
v.f = value
res = ccs_interval_include(ct.byref(self), v.i)
return False if res == ccs_false else True
......
......@@ -137,7 +137,7 @@ class ObjectiveSpace(Context):
t = ccs_objective_type()
res = ccs_objective_space_get_objective(self.handle, index, ct.byref(v), ct.byref(t))
Error.check(res)
return (Expression.from_handle(v), t)
return (Expression.from_handle(v), t.value)
@property
def num_objective(self):
......@@ -153,4 +153,4 @@ class ObjectiveSpace(Context):
t = (ccs_objective_type * sz)()
res = ccs_objective_space_get_objectives(self.handle, sz, v, t, None)
Error.check(res)
return [(Expression.from_handle(ccs_expression(v[x])), ccs_objective_type(t[x])) for x in range(sz)]
return [(Expression.from_handle(ccs_expression(v[x])), t[x]) for x in range(sz)]
......@@ -43,8 +43,8 @@ class Tuner(Object):
v = ccs_tuner_type(0)
res = ccs_tuner_get_type(self.handle, ct.byref(v))
Error.check(res)
self._type = v
return v
self._type = v.value
return self._type
@property
def user_data(self):
......
......@@ -25,7 +25,7 @@ class TestTuner(unittest.TestCase):
(cs, os) = self.create_tuning_problem()
t = ccs.RandomTuner(name = "tuner", configuration_space = cs, objective_space = os)
self.assertEqual("tuner", t.name)
self.assertEqual(ccs.TUNER_RANDOM, t.type.value)
self.assertEqual(ccs.TUNER_RANDOM, t.type)
func = lambda x, y, z: [(x-2)*(x-2), sin(z+y)]
evals = [ccs.Evaluation(objective_space = os, configuration = c, values = func(*(c.values))) for c in t.ask(100)]
t.tell(evals)
......@@ -66,7 +66,7 @@ class TestTuner(unittest.TestCase):
if discard:
new_optimums.append(o)
else:
c = e.cmp(o).value
c = e.cmp(o)
if c == ccs.EQUIVALENT or c == ccs.WORSE:
discard = True
new_optimums.append(o)
......@@ -86,7 +86,7 @@ class TestTuner(unittest.TestCase):
(cs, os) = self.create_tuning_problem()
t = ccs.UserDefinedTuner(name = "tuner", configuration_space = cs, objective_space = os, delete = delete, ask = ask, tell = tell, get_optimums = get_optimums, get_history = get_history)
self.assertEqual("tuner", t.name)
self.assertEqual(ccs.TUNER_USER_DEFINED, t.type.value)
self.assertEqual(ccs.TUNER_USER_DEFINED, t.type)
self.assertEqual(cs.handle.value, t.configuration_space.handle.value)
self.assertEqual(os.handle.value, t.objective_space.handle.value)
func = lambda x, y, z: [(x-2)*(x-2), sin(z+y)]
......
......@@ -30,7 +30,7 @@ class TestTuner(ccs.UserDefinedTuner):
if discard:
new_optimums.append(o)
else:
c = e.cmp(o).value
c = e.cmp(o)
if c == ccs.EQUIVALENT or c == ccs.WORSE:
discard = True
new_optimums.append(o)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment