Commit 839e23cf authored by Brice Videau's avatar Brice Videau
Browse files

Added support for configurations and configuration space in python.

parent 80a68abd
......@@ -12,3 +12,6 @@ from .interval import *
from .distribution import *
from .hyperparameter import *
from .expression import *
from .context import *
from .configuration_space import *
from .configuration import *
......@@ -379,6 +379,10 @@ class Object:
return Hyperparameter.from_handle(h)
elif v == ccs_object_type.EXPRESSION:
return Expression.from_handle(h)
elif v == ccs_object_type.CONFIGURATION_SPACE:
return ConfigurationSpace.from_handle(h)
elif v == ccs_object_type.CONFIGURATION:
return Configuration.from_handle(h)
else:
raise Error(ccs_error.INVALID_OBJECT)
......@@ -392,3 +396,5 @@ from .rng import Rng
from .distribution import Distribution
from .hyperparameter import Hyperparameter
from .expression import Expression
from .configuration_space import ConfigurationSpace
from .configuration import Configuration
import ctypes as ct
from .base import Object, Error, ccs_error, _ccs_get_function, ccs_context, ccs_hyperparameter, ccs_configuration_space, ccs_configuration, ccs_rng, ccs_distribution, ccs_expression, ccs_datum, ccs_hash, ccs_int
from .context import Context
from .rng import Rng
from .configuration_space import ConfigurationSpace
ccs_create_configuration = _ccs_get_function("ccs_create_configuration", [ccs_configuration_space, ct.c_size_t, ct.POINTER(ccs_datum), ct.c_void_p, ct.POINTER(ccs_configuration)])
ccs_configuration_get_configuration_space = _ccs_get_function("ccs_configuration_get_configuration_space", [ccs_configuration, ct.POINTER(ccs_configuration_space)])
ccs_configuration_get_user_data = _ccs_get_function("ccs_configuration_get_user_data", [ccs_configuration, ct.POINTER(ct.c_void_p)])
ccs_configuration_get_value = _ccs_get_function("ccs_configuration_get_value", [ccs_configuration, ct.c_size_t, ct.POINTER(ccs_datum)])
ccs_configuration_set_value = _ccs_get_function("ccs_configuration_set_value", [ccs_configuration, ct.c_size_t, ccs_datum])
ccs_configuration_get_values = _ccs_get_function("ccs_configuration_get_values", [ccs_configuration, ct.c_size_t, ct.POINTER(ccs_datum), ct.POINTER(ct.c_size_t)])
ccs_configuration_get_value_by_name = _ccs_get_function("ccs_configuration_get_value_by_name", [ccs_configuration, ct.c_char_p, ct.POINTER(ccs_datum)])
ccs_configuration_check = _ccs_get_function("ccs_configuration_check", [ccs_configuration])
ccs_configuration_hash = _ccs_get_function("ccs_configuration_hash", [ccs_configuration, ct.POINTER(ccs_hash)])
ccs_configuration_cmp = _ccs_get_function("ccs_configuration_cmp", [ccs_configuration, ccs_configuration, ct.POINTER(ccs_int)])
class Configuration(Object):
def __init__(self, handle = None, retain = False, configuration_space = None, values = None, user_data = None):
if handle is None:
count = 0
if values:
count = len(values)
vals = (ccs_datum * count)()
for i in range(count):
vals[i].value = values[i]
else:
vals = None
handle = ccs_configuration()
res = ccs_create_configuration(configuration_space.handle, count, vals, ct.byref(handle))
Error.check(res)
super().__init__(handle = handle, retain = False)
else:
super().__init__(handle = handle, retain = retain)
@classmethod
def from_handle(cls, handle):
return cls(handle = handle, retain = True)
@property
def user_data(self):
if hasattr(self, "_user_data"):
return self._user_data
v = ct.c_void_p()
res = ccs_configuration_get_user_data(self.handle, ct.byref(v))
Error.check(res)
self._user_data = v
return v
@property
def configuration_space(self):
if hasattr(self, "_configuration_space"):
return self._configuration_space
v = ccs_configuration_space()
res = ccs_configuration_get_configuration_space(self.handle, ct.byref(v))
Error.check(res)
self._configuration_space = ConfigurationSpace.from_handle(v)
return self._configuration_space
@property
def num_values(self):
if hasattr(self, "_num_values"):
return self._num_values
v = ct.c_size_t()
res = ccs_configuration_get_values(self.handle, 0, None, ct.byref(v))
Error.check(res)
self._num_values = v.value
return self._num_values
@property
def hash(self):
v = ccs_hash()
res = ccs_configuration_hash(self.handle, ct.byref(v))
Error.check(res)
return self.value
def set_value(self, hyperparameter, value):
if isinstance(hyperparameter, Hyperparameter):
hyperparameter = self.configuration_space.hyperparameter_index(hyperparameter)
elif isinstance(hyperparameter, str):
hyperparameter = self.configuration_space.hyperparameter_index_by_name(hyperparameter)
pv = ccs_datum(value)
v = ccs_datum_fix()
v.value = pv._value.i
v.type = pv.type
res = ccs_configuration_set_value(self.handle, hyperparameter, v)
Error.check(res)
def value(self, hyperparameter):
v = ccs_datum()
if isinstance(hyperparameter, Hyperparameter):
res = ccs_configuration_get_value(self.handle, self.configuration_space.hyperparameter_index(hyperparameter), ct.byref(v))
elif isinstance(hyperparameter, str):
res = ccs_configuration_get_value_by_name(self.handle, str.encode(hyperparameter), ct.byref(v))
else:
res = ccs_configuration_get_value(self.handle, hyperparameter, ct.byref(v))
Error.check(res)
return v.value
@property
def values(self):
sz = self.num_values
if sz == 0:
return []
v = (ccs_datum * sz)()
res = ccs_configuration_get_values(self.handle, sz, v, None)
Error.check(res)
return [x.value for x in v]
def check(self):
res = ccs_configuration_check(self.handle)
Error.check(res)
def cmp(self, other):
v = ccs_int()
res = ccs_configuration_cmp(self.handle, other.handle, ct.byref(v))
Error.check(res)
return v.value
def __lt__(self, other):
v = ccs_int()
res = ccs_configuration_cmp(self.handle, other.handle, ct.byref(v))
Error.check(res)
return v.value < 0
def __le__(self, other):
v = ccs_int()
res = ccs_configuration_cmp(self.handle, other.handle, ct.byref(v))
Error.check(res)
return v.value <= 0
def __gt__(self, other):
v = ccs_int()
res = ccs_configuration_cmp(self.handle, other.handle, ct.byref(v))
Error.check(res)
return v.value > 0
def __ge__(self, other):
v = ccs_int()
res = ccs_configuration_cmp(self.handle, other.handle, ct.byref(v))
Error.check(res)
return v.value >= 0
def __eq__(self, other):
v = ccs_int()
res = ccs_configuration_cmp(self.handle, other.handle, ct.byref(v))
Error.check(res)
return v.value == 0
def __ne__(self, other):
v = ccs_int()
res = ccs_configuration_cmp(self.handle, other.handle, ct.byref(v))
Error.check(res)
return v.value != 0
def __hash__(self):
return self.hash
def asdict(self):
res = {}
hyperparameters = self.configuration_space.hyperparameters
values = self.values
for i in range(len(hyperparameters)):
res[hyperparameters[i].name] = values[i]
return res
import ctypes as ct
from .base import Object, Error, ccs_error, _ccs_get_function, ccs_context, ccs_hyperparameter, ccs_configuration_space, ccs_configuration, ccs_rng, ccs_distribution, ccs_expression, ccs_datum
from .context import Context
from .hyperparameter import Hyperparameter
from .expression import Expression
from .rng import Rng
ccs_create_configuration_space = _ccs_get_function("ccs_create_configuration_space", [ct.c_char_p, ct.c_void_p, ct.POINTER(ccs_configuration_space)])
ccs_configuration_space_get_name = _ccs_get_function("ccs_configuration_space_get_name", [ccs_configuration_space, ct.POINTER(ct.c_char_p)])
ccs_configuration_space_get_user_data = _ccs_get_function("ccs_configuration_space_get_user_data", [ccs_configuration_space, ct.POINTER(ct.c_void_p)])
ccs_configuration_space_set_rng = _ccs_get_function("ccs_configuration_space_set_rng", [ccs_configuration_space, ccs_rng])
ccs_configuration_space_get_rng = _ccs_get_function("ccs_configuration_space_get_rng", [ccs_configuration_space, ct.POINTER(ccs_rng)])
ccs_configuration_space_add_hyperparameter = _ccs_get_function("ccs_configuration_space_add_hyperparameter", [ccs_configuration_space, ccs_hyperparameter, ccs_distribution])
ccs_configuration_space_add_hyperparameters = _ccs_get_function("ccs_configuration_space_add_hyperparameters", [ccs_configuration_space, ct.c_size_t, ct.POINTER(ccs_hyperparameter), ct.POINTER(ccs_distribution)])
ccs_configuration_space_get_num_hyperparameters = _ccs_get_function("ccs_configuration_space_get_num_hyperparameters", [ccs_configuration_space, ct.POINTER(ct.c_size_t)])
ccs_configuration_space_get_hyperparameter = _ccs_get_function("ccs_configuration_space_get_hyperparameter", [ccs_configuration_space, ct.c_size_t, ct.POINTER(ccs_hyperparameter)])
ccs_configuration_space_get_hyperparameter_by_name = _ccs_get_function("ccs_configuration_space_get_hyperparameter_by_name", [ccs_configuration_space, ct.c_char_p, ct.POINTER(ccs_hyperparameter)])
ccs_configuration_space_get_hyperparameter_index_by_name = _ccs_get_function("ccs_configuration_space_get_hyperparameter_index_by_name", [ccs_configuration_space, ct.c_char_p, ct.POINTER(ct.c_size_t)])
ccs_configuration_space_get_hyperparameter_index = _ccs_get_function("ccs_configuration_space_get_hyperparameter_index", [ccs_configuration_space, ccs_hyperparameter, ct.POINTER(ct.c_size_t)])
ccs_configuration_space_get_hyperparameter_indexes = _ccs_get_function("ccs_configuration_space_get_hyperparameter_indexes", [ccs_configuration_space, ct.c_size_t, ct.POINTER(ccs_hyperparameter), ct.POINTER(ct.c_size_t)])
ccs_configuration_space_get_hyperparameters = _ccs_get_function("ccs_configuration_space_get_hyperparameters", [ccs_configuration_space, ct.c_size_t, ct.POINTER(ccs_hyperparameter), ct.POINTER(ct.c_size_t)])
ccs_configuration_space_set_condition = _ccs_get_function("ccs_configuration_space_set_condition", [ccs_configuration_space, ct.c_size_t, ccs_expression])
ccs_configuration_space_get_condition = _ccs_get_function("ccs_configuration_space_get_condition", [ccs_configuration_space, ct.c_size_t, ct.POINTER(ccs_expression)])
ccs_configuration_space_get_conditions = _ccs_get_function("ccs_configuration_space_get_conditions", [ccs_configuration_space, ct.c_size_t, ct.POINTER(ccs_expression), ct.POINTER(ct.c_size_t)])
ccs_configuration_space_add_forbidden_clause = _ccs_get_function("ccs_configuration_space_add_forbidden_clause", [ccs_configuration_space, ccs_expression])
ccs_configuration_space_add_forbidden_clauses = _ccs_get_function("ccs_configuration_space_add_forbidden_clauses", [ccs_configuration_space, ct.c_size_t, ccs_expression])
ccs_configuration_space_get_forbidden_clause = _ccs_get_function("ccs_configuration_space_get_forbidden_clause", [ccs_configuration_space, ct.c_size_t, ct.POINTER(ccs_expression)])
ccs_configuration_space_get_forbidden_clauses = _ccs_get_function("ccs_configuration_space_get_forbidden_clauses", [ccs_configuration_space, ct.c_size_t, ct.POINTER(ccs_expression), ct.POINTER(ct.c_size_t)])
ccs_configuration_space_check_configuration = _ccs_get_function("ccs_configuration_space_check_configuration", [ccs_configuration_space, ccs_configuration])
ccs_configuration_space_check_configuration_values = _ccs_get_function("ccs_configuration_space_check_configuration_values", [ccs_configuration_space, ct.c_size_t, ct.POINTER(ccs_datum)])
ccs_configuration_space_get_default_configuration = _ccs_get_function("ccs_configuration_space_get_default_configuration", [ccs_configuration_space, ct.POINTER(ccs_configuration)])
ccs_configuration_space_sample = _ccs_get_function("ccs_configuration_space_sample", [ccs_configuration_space, ct.POINTER(ccs_configuration)])
ccs_configuration_space_samples = _ccs_get_function("ccs_configuration_space_samples", [ccs_configuration_space, ct.c_size_t, ct.POINTER(ccs_configuration)])
class ConfigurationSpace(Context):
def __init__(self, handle = None, retain = False, name = "", user_data = None):
if handle is None:
handle = ccs_configuration_space()
res = ccs_create_configuration_space(str.encode(name), user_data, ct.byref(handle))
Error.check(res)
super().__init__(handle = handle, retain = False)
else:
super().__init__(handle = handle, retain = retain)
@classmethod
def from_handle(cls, handle):
return cls(handle = handle, retain = True)
@property
def user_data(self):
if hasattr(self, "_user_data"):
return self._user_data
v = ct.c_void_p()
res = ccs_configuration_space_get_user_data(self.handle, ct.byref(v))
Error.check(res)
self._user_data = v
return v
@property
def name(self):
if hasattr(self, "_name"):
return self._name
v = ct.c_char_p()
res = ccs_configuration_space_get_name(self.handle, ct.byref(v))
Error.check(res)
self._name = v.value.decode()
return self._name
@property
def rng(self):
v = ccs_rng()
res = ccs_configuration_space_get_rng(self.handle, ct.byref(v))
Error.check(res)
return Rng.from_handle(v)
@rng.setter
def rng(self, r):
res = ccs_configuration_space_set_rng(self.handle, r.handle)
Error.check(res)
def add_hyperparameter(self, hyperparameter, distribution = None):
if distribution:
distribution = distribution.handle
res = ccs_configuration_space_add_hyperparameter(self.handle, hyperparameter.handle, distribution)
Error.check(res)
def add_hyperparameters(self, hyperparameters, distributions = None):
count = len(hyperparameters)
if count == 0:
return None
if distributions:
if count != len(distributions):
raise Error(ccs_error.INVALID_VALUE)
distribs = (ccs_distribution * count)(*[x.handle.value if x else x for x in distributions])
else:
distribs = None
hypers = (ccs_hyperparameter * count)(*[x.handle.value for x in hyperparameters])
res = ccs_configuration_space_add_hyperparameters(self.handle, count, hypers, distribs)
Error.check(res)
def hyperparameter(self, index):
v = ccs_hyperparameter()
res = ccs_configuration_space_get_hyperparameter(self.handle, index, ct.byref(v))
Error.check(res)
return Hyperparameter.from_handle(v)
def hyperparameter_by_name(self, name):
v = ccs_hyperparameter()
res = ccs_configuration_space_get_hyperparameter_by_name(self.handle, str.encode(name), ct.byref(v))
Error.check(res)
return Hyperparameter.from_handle(v)
def hyperparameter_index(self, hyperparameter):
v = ct.c_size_t()
res = ccs_configuration_space_get_hyperparameter_index(self.handle, hyperparameter.handle, ct.byref(v))
Error.check(res)
return v.value
def hyperparameter_index_by_name(self, name):
v = ct.c_size_t()
res = ccs_configuration_space_get_hyperparameter_index_by_name(self.handle, str.encode(name), ct.byref(v))
Error.check(res)
return v.value
@property
def num_hyperparameters(self):
v = ct.c_size_t(0)
res = ccs_configuration_space_get_num_hyperparameters(self.handle, ct.byref(v))
Error.check(res)
return v.value
@property
def hyperparameters(self):
count = self.num_hyperparameters
if count == 0:
return []
v = (ccs_hyperparameter * count)()
res = ccs_configuration_space_get_hyperparameters(self.handle, count, v, None)
Error.check(res)
return [Hyperparameter.from_handle(ccs_hyperparameter(x)) for x in v]
def set_condition(self, hyperparameter, expression):
if isinstance(hyperparameter, Hyperparameter):
hyperparameter = self.hyperparameter_index(hyperparameter)
elif isinstance(hyperparameter, str):
hyperparameter = self.hyperparameter_index_by_name(hyperparameter)
res = ccs_configuration_space_set_condition(self.handle, hyperparameter, expression.handle)
Error.check(res)
def condition(self, hyperparameter):
if isinstance(hyperparameter, Hyperparameter):
hyperparameter = self.hyperparameter_index(hyperparameter)
elif isinstance(hyperparameter, str):
hyperparameter = self.hyperparameter_index_by_name(hyperparameter)
v = ccs_expression()
res = ccs_configuration_space_get_condition(self.handle, hyperparameter, ct.byref(v))
Error.check(res)
if v.value is None:
return None
else:
return Expression.from_handle(v)
@property
def num_conditions(self):
return self.num_hyperparameters
@property
def conditions(self):
sz = self.num_hyperparameters
if sz == 0:
return []
v = (ccs_expression * sz)()
res = ccs_configuration_space_get_conditions(self.handle, sz, v, None)
Error.check(res)
return [Expression.from_handle(ccs_expression(x)) if x else None for x in v]
def add_forbidden_clause(self, expression):
res = ccs_configuration_space_add_forbidden_clause(self.handle, expression.handle)
Error.check(res)
def add_forbidden_clauses(self, expressions):
sz = len(expressions)
if sz == 0:
return None
v = (ccs_expression * sz)(*[x.handle.value if x else x for x in expressions])
res = ccs_configuration_space_add_forbidden_clauses(self.handle, sz, v)
Error.check(res)
def forbidden_clause(self, index):
v = ccs_expression()
res = ccs_configuration_space_get_forbidden_clause(self.handle, index, ct.byref(v))
Error.check(res)
return Expression.from_handle(v)
@property
def num_forbidden_clauses(self):
v = ct.c_size_t()
res = ccs_configuration_space_get_forbidden_clauses(self.handle, 0, None, ct.byref(v))
Error.check(res)
return v.value
@property
def forbidden_clauses(self):
sz = self.num_forbidden_clauses
if sz == 0:
return []
v = (ccs_expression * sz)()
res = ccs_configuration_space_get_forbidden_clauses(self.handle, sz, v, None)
Error.check(res)
return [Expression.from_handle(ccs_expression(x)) for x in v]
def check(self, configuration):
res = ccs_configuration_space_check_configuration(self.handle, configuration.handle)
Error.check(res)
def check_values(self, values):
count = len(values)
if count != self.num_hyperparameters:
raise Error(ccs_error.INVALID_VALUE)
v = (ccs_datum * count)()
for i in range(count):
v[i].value = values[i]
res = ccs_configuration_space_check_configuration_values(self.handle, count, v)
Error.check(res)
@property
def default_configuration(self):
v = ccs_configuration()
res = ccs_configuration_space_get_default_configuration(self.handle, ct.byref(v))
Error.check(res)
return Configuration.from_handle(v)
def sample(self):
v = ccs_configuration()
res = ccs_configuration_space_sample(self.handle, ct.byref(v))
Error.check(res)
return Configuration.from_handle(v)
def samples(self, count):
if count == 0:
return []
v = (ccs_configuration * count)()
res = ccs_configuration_space_samples(self.handle, count, v)
Error.check(res)
return [Configuration.from_handle(x) for x in v]
from .configuration import Configuration
import ctypes as ct
from .base import Object, Error, ccs_error, _ccs_get_function, ccs_context, ccs_hyperparameter
ccs_context_get_hyperparameter_index = _ccs_get_function("ccs_context_get_hyperparameter_index", [ccs_context, ccs_hyperparameter, ct.POINTER(ct.c_size_t)])
class Context(Object):
def hyperparameter_index(self, hyperparameter):
v = ct.c_sizeof_t()
res = ccs_context_get_hyperparameter_index(self.handle, hyperparameter.handle, ct.byref(v))
Error.check(res)
return v.value
......@@ -126,6 +126,8 @@ class Distribution(Object):
raise Error(ccs_error.INVALID_VALUE)
def samples(self, rng, count):
if count == 0:
return []
t = self.data_type.value
if t == ccs_numeric_type.NUM_INTEGER:
v = (ccs_int * count)()
......
......@@ -136,7 +136,7 @@ class Hyperparameter(Object):
Error.check(res)
return [x.value for x in v]
def __eql__(self, other):
def __eq__(self, other):
return self.__class__ == other.__class__ and self.handle == other.handle
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment