Commit 7518220c authored by Brice Videau's avatar Brice Videau
Browse files

Added expression support to python bindings.

parent 7b44ffe0
......@@ -11,3 +11,4 @@ from .rng import *
from .interval import *
from .distribution import *
from .hyperparameter import *
from .expression import *
......@@ -111,7 +111,7 @@ class CEnumeration(ct.c_int, metaclass=CEnumerationType):
return cls(param)
def __repr__(self):
return "<member %s=%d of %r>" % (self.name, self.value, self.__class__)
return "<member %s(%d) of %r>" % (self.name, self.value, self.__class__)
def __str__(self):
return "%s.%s" % (self.__class__.__name__, self.name)
......@@ -154,7 +154,7 @@ class CEnumeration64(ct.c_longlong, metaclass=CEnumerationType64):
ct.c_longlong.__init__(self, value)
def __repr__(self):
return "<member %s=%d of %r>" % (self.name, self.value, self.__class__)
return "<member %s(%d) of %r>" % (self.name, self.value, self.__class__)
def __str__(self):
return "%s.%s" % (self.__class__.__name__, self.name)
......@@ -256,6 +256,10 @@ class ccs_value(ct.Union):
('s', ct.c_char_p),
('o', ccs_object)]
class ccs_datum_fix(ct.Structure):
_fields_ = [('value', ccs_int),
('type', ccs_data_type)]
class ccs_datum(ct.Structure):
_fields_ = [('_value', ccs_value),
('type', ccs_data_type)]
......@@ -330,22 +334,31 @@ class Object:
def __init__(self, handle, retain = False, auto_release = True):
if handle is None:
raise Error(ccs_error.INVALID_OBJECT)
self.handle = handle
self._handle = handle
self.auto_release = auto_release
if retain:
res = ccs_retain_object(handle)
Error.check(res)
def __del__(self):
res = ccs_release_object(self.handle)
res = ccs_release_object(self._handle)
Error.check(res)
@property
def handle(self):
return self._handle
@property
def object_type(self):
if hasattr(self, "_object_type"):
return self._object_type
t = ccs_object_type(0)
res = ccs_object_get_type(self.handle, ct.byref(t))
Error.check(res)
self._object_type = t
return t
@property
def refcount(self):
c = ct.c_int(0)
res = ccs_object_get_refcount(self.handle, ct.byref(c))
......@@ -364,6 +377,8 @@ class Object:
return Distribution.from_handle(h)
elif v == ccs_object_type.HYPERPARAMETER:
return Hyperparameter.from_handle(h)
elif v == ccs_object_type.EXPRESSION:
return Expression.from_handle(h)
else:
raise Error(ccs_error.INVALID_OBJECT)
......@@ -376,3 +391,4 @@ def _ccs_get_id():
from .rng import Rng
from .distribution import Distribution
from .hyperparameter import Hyperparameter
from .expression import Expression
import ctypes as ct
from . import libcconfigspace
from .base import Object, Error, ccs_error, CEnumeration, _ccs_get_function, ccs_expression, ccs_datum, ccs_datum_fix, ccs_hyperparameter, ccs_context
from .hyperparameter import Hyperparameter
class ccs_expression_type(CEnumeration):
_members_ = [
('OR', 0),
'AND',
'EQUAL',
'NOT_EQUAL',
'LESS',
'GREATER',
'LESS_OR_EQUAL',
'GREATER_OR_EQUAL',
'ADD',
'SUBSTRACT',
'MULTIPLY',
'DIVIDE',
'MODULO',
'POSITIVE',
'NEGATIVE',
'NOT',
'IN',
'LIST',
'LITERAL',
'VARIABLE' ]
class ccs_associativity_type(CEnumeration):
_members_ = [
('ASSOCIATIVITY_NONE', 0),
'LEFT_TO_RIGHT',
'RIGHT_TO_LEFT' ]
_sz_expr = len(ccs_expression_type._members_)
ccs_expression_precedence = (ct.c_int * _sz_expr).in_dll(libcconfigspace, "ccs_expression_precedence")
ccs_expression_associativity = (ccs_associativity_type * _sz_expr).in_dll(libcconfigspace, "ccs_expression_associativity")
ccs_expression_symbols = [x.decode() if x else x for x in (ct.c_char_p * _sz_expr).in_dll(libcconfigspace, "ccs_expression_symbols")]
ccs_expression_arity = (ct.c_int * _sz_expr).in_dll(libcconfigspace, "ccs_expression_arity")
ccs_create_binary_expression = _ccs_get_function("ccs_create_binary_expression", [ccs_expression_type, ccs_datum_fix, ccs_datum_fix, ct.POINTER(ccs_expression)])
ccs_create_unary_expression = _ccs_get_function("ccs_create_unary_expression", [ccs_expression_type, ccs_datum_fix, ct.POINTER(ccs_expression)])
ccs_create_expression = _ccs_get_function("ccs_create_expression", [ccs_expression_type, ct.c_size_t, ct.POINTER(ccs_datum), ct.POINTER(ccs_expression)])
ccs_create_literal = _ccs_get_function("ccs_create_literal", [ccs_datum_fix, ct.POINTER(ccs_expression)])
ccs_create_variable = _ccs_get_function("ccs_create_variable", [ccs_hyperparameter, ct.POINTER(ccs_expression)])
ccs_expression_get_type = _ccs_get_function("ccs_expression_get_type", [ccs_expression, ct.POINTER(ccs_expression_type)])
ccs_expression_get_num_nodes = _ccs_get_function("ccs_expression_get_num_nodes", [ccs_expression, ct.POINTER(ct.c_size_t)])
ccs_expression_get_nodes = _ccs_get_function("ccs_expression_get_nodes", [ccs_expression, ct.c_size_t, ct.POINTER(ccs_expression), ct.POINTER(ct.c_size_t)])
ccs_literal_get_value = _ccs_get_function("ccs_literal_get_value", [ccs_expression, ct.POINTER(ccs_datum)])
ccs_variable_get_hyperparameter = _ccs_get_function("ccs_variable_get_hyperparameter", [ccs_expression, ct.POINTER(ccs_hyperparameter)])
ccs_expression_eval = _ccs_get_function("ccs_expression_eval", [ccs_expression, ccs_context, ct.POINTER(ccs_datum), ct.POINTER(ccs_datum)])
ccs_expression_list_eval_node = _ccs_get_function("ccs_expression_list_eval_node", [ccs_expression, ccs_context, ct.POINTER(ccs_datum), ct.c_size_t, ct.POINTER(ccs_datum)])
ccs_expression_get_hyperparameters = _ccs_get_function("ccs_expression_get_hyperparameters", [ccs_expression, ct.c_size_t, ct.POINTER(ccs_hyperparameter), ct.POINTER(ct.c_size_t)])
ccs_expression_check_context = _ccs_get_function("ccs_expression_check_context", [ccs_expression, ccs_context])
class Expression(Object):
def __init__(self, handle = None, retain = False, t = None, nodes = []):
if handle is None:
sz = len(nodes)
handle = ccs_expression()
v = (ccs_datum*sz)()
for i in range(sz):
v[i].value = nodes[i]
res = ccs_create_expression(t, sz, v, ct.byref(handle))
Error.check(res)
super().__init__(handle = handle, retain = False)
else:
super().__init__(handle = handle, retain = retain)
@classmethod
def from_handle(cls, handle):
v = ccs_expression_type(0)
res = ccs_expression_get_type(handle, ct.byref(v))
Error.check(res)
v = v.value
if v == ccs_expression_type.LIST:
return List(handle = handle, retain = True)
elif v == ccs_expression_type.LITERAL:
return Literal(handle = handle, retain = True)
elif v == ccs_expression_type.VARIABLE:
return Variable(handle = handle, retain = True)
else:
return cls(handle = handle, retain = True)
@classmethod
def binary(cls, t, left, right):
pvleft = ccs_datum(left)
pvright = ccs_datum(right)
vleft = ccs_datum_fix()
vright = ccs_datum_fix()
vleft.value = pvleft._value.i
vleft.type = pvleft.type
vright.value = pvright._value.i
vright.type = pvright.type
handle = ccs_expression()
res = ccs_create_binary_expression(t, vleft, vright, ct.byref(handle))
Error.check(res)
return cls(handle = handle, retain = False)
@classmethod
def unary(cls, t, node):
pvnode = ccs_datum(node)
vnode = ccs_datum_fix()
vnode.value = pvnode._value.i
vnode.type = pvnode.type
handle = ccs_expression()
res = ccs_create_unary_expression(t, vnode, ct.byref(handle))
Error.check(res)
return cls(handle = handle, retain = False)
@property
def type(self):
if hasattr(self, "_type"):
return self._type
v = ccs_expression_type(0)
res = ccs_expression_get_type(self.handle, ct.byref(v))
Error.check(res)
self._type = v
return v
@property
def num_nodes(self):
if hasattr(self, "_num_nodes"):
return self._num_nodes
v = ct.c_size_t(0)
res = ccs_expression_get_num_nodes(self.handle, ct.byref(v))
Error.check(res)
self._num_nodes = v.value
return self._num_nodes
@property
def nodes(self):
if hasattr(self, "_nodes"):
return self._nodes
sz = self.num_nodes
v = (ccs_expression * sz)()
res = ccs_expression_get_nodes(self.handle, sz, v, None)
Error.check(res)
self._nodes = [Expression.from_handle(handle = ccs_expression(x)) for x in v]
return self._nodes
@property
def hyperparameters(self):
if hasattr(self, "_hyperparameters"):
return self._hyperparameters
sz = ct.c_size_t()
res = ccs_expression_get_hyperparameters(self.handle, 0, None, ct.byref(sz))
Error.check(res)
sz = sz.value
if sz == 0:
self._hyperparameters = []
return []
v = (ccs_hyperparameter * sz.value)()
res = ccs_expression_get_hyperparameters(self.handle, sz, v, None)
Error.check(res)
self._hyperparameters = [Hyperparameter.from_handle(ccs_hyperparameter(x)) for x in v]
return self._hyperparameters
def eval(self, context = None, values = None):
if context and values:
count = context.num_hyperparameters
if count != len(values):
raise Error(ccs_error.INVALID_VALUE)
v = (ccs_datum * count)()
for i in range(count):
v[i].value = values[i]
values = v
context = context.handle
elif context or values:
raise Error(ccs_error.INVALID_VALUE)
v = ccs_datum()
res = ccs_expression_eval(self.handle, context, values, ct.byref(v))
Error.check(res)
return v.value
def check_context(self, context):
res = ccs_expression_check_context(self.handle, context.handle)
Error.check(res)
def __str__(self):
t = self.type.value
symbol = ccs_expression_symbols[t]
prec = ccs_expression_precedence[t]
nds = ["({})".format(n) if ccs_expression_precedence[n.type.value] < prec else n.__str__() for n in self.nodes]
if len(nds) == 1:
return "{}{}".format(symbol, nds[0])
else:
return "{} {} {}".format(nds[0], symbol, nds[1])
class Literal(Expression):
def __init__(self, handle = None, retain = False, value = None):
if handle is None:
handle = ccs_expression()
pv = ccs_datum(value)
v = ccs_datum_fix()
v.value = pv._value.i
v.type = pv.type
res = ccs_create_literal(v, ct.byref(handle))
Error.check(res)
super().__init__(handle = handle, retain = False)
else:
super().__init__(handle = handle, retain = retain)
@property
def value(self):
if hasattr(self, "_value"):
return self._value
v = ccs_datum()
res = ccs_literal_get_value(self.handle, ct.byref(v))
Error.check(res)
self._value = v.value
return self._value
def __str__(self):
v = self.value
if isinstance(v, str):
return repr(v)
else:
return "{}".format(v)
class Variable(Expression):
def __init__(self, handle = None, retain = False, hyperparameter = None):
if handle is None:
handle = ccs_expression()
res = ccs_create_variable(hyperparameter.handle, ct.byref(handle))
Error.check(res)
super().__init__(handle = handle, retain = False)
else:
super().__init__(handle = handle, retain = retain)
@property
def hyperparameter(self):
if hasattr(self, "_hyperparameter"):
return self._hyperparameter
v = ccs_hyperparameter()
res = ccs_variable_get_hyperparameter(self.handle, ct.byref(v))
Error.check(res)
self._hyperparameter = Hyperparameter.from_handle(v)
return self._hyperparameter
def __str__(self):
return self.hyperparameter.name
class List(Expression):
def __init__(self, handle = None, retain = False, values = []):
if handle is None:
super().__init__(t = ccs_expression_type.LIST, nodes = values)
else:
super().__init__(handle = handle, retain = retain)
def eval(self, index, context = None, values = None):
if context and values:
count = context.num_hyperparameters
if count != len(values):
raise Error(ccs_error.INVALID_VALUE)
v = (ccs_datum * count)()
for i in range(count):
v[i].value = values[i]
values = v
context = context.handle
elif context or values:
raise Error(ccs_error.INVALID_VALUE)
v = ccs_datum()
res = ccs_expression_list_eval_node(self.handle, context, values, index, ct.byref(v))
Error.check(res)
return v.value
def __str__(self):
return "[ {} ]".format(", ".join(map(str, self.nodes)))
import ctypes as ct
from . import libcconfigspace
from .base import Object, Error, ccs_error, CEnumeration, _ccs_get_function, ccs_hyperparameter, ccs_datum, ccs_distribution, ccs_rng, ccs_int, ccs_data_type, ccs_bool, ccs_numeric_type, ccs_numeric, _ccs_get_id
from .base import Object, Error, ccs_error, CEnumeration, _ccs_get_function, ccs_hyperparameter, ccs_datum, ccs_datum_fix, ccs_distribution, ccs_rng, ccs_int, ccs_data_type, ccs_bool, ccs_numeric_type, ccs_numeric, _ccs_get_id
from .rng import ccs_default_rng
from .distribution import Distribution
......@@ -12,10 +12,6 @@ class ccs_hyperparameter_type(CEnumeration):
'DISCRETE'
]
class ccs_datum_fix(ct.Structure):
_fields_ = [('value', ccs_int),
('type', ccs_data_type)]
ccs_hyperparameter_get_type = _ccs_get_function("ccs_hyperparameter_get_type", [ccs_hyperparameter, ct.POINTER(ccs_hyperparameter_type)])
ccs_hyperparameter_get_default_value = _ccs_get_function("ccs_hyperparameter_get_default_value", [ccs_hyperparameter, ct.POINTER(ccs_datum)])
ccs_hyperparameter_get_name = _ccs_get_function("ccs_hyperparameter_get_name", [ccs_hyperparameter, ct.POINTER(ct.c_char_p)])
......@@ -113,7 +109,7 @@ class Hyperparameter(Object):
def check_values(self, values):
sz = len(values)
v = (ccs_datum * sz)()
for i in range(len(values)):
for i in range(sz):
v[i].value = values[i]
b = (ccs_bool * sz)()
res = ccs_hyperparameter_check_values(self.handle, sz, v, b)
......@@ -243,4 +239,106 @@ class NumericalHyperparameter(Hyperparameter):
raise Error(ccs_error.INVALID_VALUE)
return self._quantization
ccs_create_categorical_hyperparameter = _ccs_get_function("ccs_create_categorical_hyperparameter", [ct.c_char_p, ct.c_size_t, ct.POINTER(ccs_datum), ct.c_size_t, ct.c_void_p, ct.POINTER(ccs_hyperparameter)])
ccs_categorical_hyperparameter_get_values = _ccs_get_function("ccs_categorical_hyperparameter_get_values", [ccs_hyperparameter, ct.c_size_t, ct.POINTER(ccs_datum), ct.POINTER(ct.c_size_t)])
class CategoricalHyperparameter(Hyperparameter):
def __init__(self, handle = None, retain = False, name = None, values = [], default_index = 0, user_data = None):
if handle is None:
if name is None:
name = NumericalHyperparameter.default_name()
sz = len(values)
handle = ccs_hyperparameter()
v = (ccs_datum*sz)()
for i in range(sz):
v[i].value = values[i]
res = ccs_create_categorical_hyperparameter(str.encode(name), sz, v, default_index, user_data, ct.byref(handle))
Error.check(res)
super().__init__(handle = handle, retain = False)
else:
super().__init__(handle = handle, retain = retain)
@property
def values(self):
sz = ct.c_size_t()
res = ccs_categorical_hyperparameter_get_values(self.handle, 0, None, ct.byref(sz))
Error.check(res)
v = (ccs_datum*sz.value)()
res = ccs_categorical_hyperparameter_get_values(self.handle, sz, v, None)
Error.check(res)
return [x.value for x in v]
ccs_create_ordinal_hyperparameter = _ccs_get_function("ccs_create_ordinal_hyperparameter", [ct.c_char_p, ct.c_size_t, ct.POINTER(ccs_datum), ct.c_size_t, ct.c_void_p, ct.POINTER(ccs_hyperparameter)])
ccs_ordinal_hyperparameter_compare_values = _ccs_get_function("ccs_ordinal_hyperparameter_compare_values", [ccs_hyperparameter, ccs_datum_fix, ccs_datum_fix, ct.POINTER(ccs_int)])
ccs_ordinal_hyperparameter_get_values = _ccs_get_function("ccs_ordinal_hyperparameter_get_values", [ccs_hyperparameter, ct.c_size_t, ct.POINTER(ccs_datum), ct.POINTER(ct.c_size_t)])
class OrdinalHyperparameter(Hyperparameter):
def __init__(self, handle = None, retain = False, name = None, values = [], default_index = 0, user_data = None):
if handle is None:
if name is None:
name = NumericalHyperparameter.default_name()
sz = len(values)
handle = ccs_hyperparameter()
v = (ccs_datum*sz)()
for i in range(sz):
v[i].value = values[i]
res = ccs_create_ordinal_hyperparameter(str.encode(name), sz, v, default_index, user_data, ct.byref(handle))
Error.check(res)
super().__init__(handle = handle, retain = False)
else:
super().__init__(handle = handle, retain = retain)
@property
def values(self):
sz = ct.c_size_t()
res = ccs_ordinal_hyperparameter_get_values(self.handle, 0, None, ct.byref(sz))
Error.check(res)
v = (ccs_datum*sz.value)()
res = ccs_ordinal_hyperparameter_get_values(self.handle, sz, v, None)
Error.check(res)
return [x.value for x in v]
def compare(self, value1, value2):
pv1 = ccs_datum(value1)
pv2 = ccs_datum(value2)
v1 = ccs_datum_fix()
v2 = ccs_datum_fix()
v1.value = pv1._value.i
v1.type = pv1.type
v2.value = pv2._value.i
v2.type = pv2.type
c = ccs_int()
res = ccs_ordinal_hyperparameter_compare_values(self.handle, v1, v2, ct.byref(c))
Error.check(res)
return c.value
ccs_create_discrete_hyperparameter = _ccs_get_function("ccs_create_discrete_hyperparameter", [ct.c_char_p, ct.c_size_t, ct.POINTER(ccs_datum), ct.c_size_t, ct.c_void_p, ct.POINTER(ccs_hyperparameter)])
ccs_discrete_hyperparameter_get_values = _ccs_get_function("ccs_discrete_hyperparameter_get_values", [ccs_hyperparameter, ct.c_size_t, ct.POINTER(ccs_datum), ct.POINTER(ct.c_size_t)])
class DiscreteHyperparameter(Hyperparameter):
def __init__(self, handle = None, retain = False, name = None, values = [], default_index = 0, user_data = None):
if handle is None:
if name is None:
name = NumericalHyperparameter.default_name()
sz = len(values)
handle = ccs_hyperparameter()
v = (ccs_datum*sz)()
for i in range(sz):
v[i].value = values[i]
res = ccs_create_discrete_hyperparameter(str.encode(name), sz, v, default_index, user_data, ct.byref(handle))
Error.check(res)
super().__init__(handle = handle, retain = False)
else:
super().__init__(handle = handle, retain = retain)
@property
def values(self):
sz = ct.c_size_t()
res = ccs_discrete_hyperparameter_get_values(self.handle, 0, None, ct.byref(sz))
Error.check(res)
v = (ccs_datum*sz.value)()
res = ccs_discrete_hyperparameter_get_values(self.handle, sz, v, None)
Error.check(res)
return [x.value for x in v]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment