Commit c7c4e0b3 authored by Brice Videau's avatar Brice Videau
Browse files

Improved Python bindings.

parent 0d7c2c9c
import ctypes as ct
from . import libcconfigspace
from enum import IntEnum, auto
ccs_init = libcconfigspace.ccs_init
ccs_init.restype = ct.c_int
......@@ -11,16 +10,65 @@ class Version(ct.Structure):
("minor", ct.c_ushort),
("major", ct.c_ushort)]
# http://code.activestate.com/recipes/576415/
# Base types
ccs_float = ct.c_double
ccs_int = ct.c_longlong
ccs_bool = ct.c_int
ccs_result = ct.c_int
ccs_hash = ct.c_uint
ccs_object = ct.c_void_p
# Objects
ccs_rng = ccs_object
ccs_distribution = ccs_object
ccs_hyperparameter = ccs_object
ccs_expression = ccs_object
ccs_context = ccs_object
ccs_configuration_space = ccs_object
ccs_configuration = ccs_object
ccs_objective_space = ccs_object
ccs_evaluation = ccs_object
ccs_tuner = ccs_object
ccs_false = 0
ccs_true = 1
# https://www.python-course.eu/python3_metaclasses.php
class Singleton(type):
_instances = {}
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]
class Inactive(metaclass=Singleton):
pass
ccs_inactive = Inactive()
# derived and adapted from http://code.activestate.com/recipes/576415/
class CEnumerationType(type(ct.c_int)):
def __new__(metacls, name, bases, dict):
if not "_members_" in dict:
raise ValueError("CEnumeration must define a _members_ attribute")
last = -1
if isinstance(dict["_members_"], list):
_members_ = {}
for key,value in dict.items():
if not key.startswith("_"):
_members_[key] = value
for item in dict["_members_"]:
if isinstance(item, tuple):
(i, v) = item
_members_[i] = v
last = v
else:
last += 1
_members_[item] = last
dict["_members_"] = _members_
cls = type(c_uint).__new__(metacls, name, bases, dict)
_reverse_members_ = {}
for key,value in dict["_members_"].items():
dict[key] = value
_reverse_members_[value] = key
dict["_reverse_members_"] = _reverse_members_
cls = type(ct.c_int).__new__(metacls, name, bases, dict)
for key,value in cls._members_.items():
globals()[key] = value
return cls
......@@ -31,14 +79,11 @@ class CEnumerationType(type(ct.c_int)):
def __repr__(self):
return "<Enumeration %s>" % self.__name__
class CEnumeration(ct.c_int):
__metaclass__ = CEnumerationType
class CEnumeration(ct.c_int, metaclass=CEnumerationType):
_members_ = {}
def __init__(self, value):
for k,v in self._members_.items():
if v == value:
self.name = k
break
if value in self._reverse_members_:
self.name = self._reverse_members_[value]
else:
raise ValueError("No enumeration member with value %r" % value)
ct.c_int.__init__(self, value)
......@@ -59,112 +104,266 @@ class CEnumeration(ct.c_int):
def __str__(self):
return "%s.%s" % (self.__class__.__name__, self.name)
class ObjectType(CEnumeration):
_members_ = {
'RNG': 0,
'DISTRIBUTION': 1,
'HYPERPARAMETER': 2,
'EXPRESSION': 3,
'CONFIGURATION_SPACE': 4,
'CONFIGURATION': 5,
'OBJECTIVE_SPACE': 6,
'EVALUATION': 7,
'TUNER': 8 }
class CTypesIntEnum(IntEnum):
class CEnumerationType64(type(ct.c_longlong)):
def __new__(metacls, name, bases, dict):
if not "_members_" in dict:
raise ValueError("CEnumeration must define a _members_ attribute")
last = -1
if isinstance(dict["_members_"], list):
_members_ = {}
for item in dict["_members_"]:
if isinstance(item, tuple):
(i, v) = item
_members_[i] = v
last = v
else:
last += 1
_members_[item] = last
dict["_members_"] = _members_
_reverse_members_ = {}
for key,value in dict["_members_"].items():
dict[key] = value
_reverse_members_[value] = key
dict["_reverse_members_"] = _reverse_members_
cls = type(ct.c_longlong).__new__(metacls, name, bases, dict)
for key,value in cls._members_.items():
globals()[key] = value
return cls
def __contains__(self, value):
return value in self._members_.values()
def __repr__(self):
return "<Enumeration %s>" % self.__name__
class CEnumeration64(ct.c_longlong, metaclass=CEnumerationType64):
_members_ = {}
def __init__(self, value):
if value in self._reverse_members_:
self._name = self._reverse_members_[value]
else:
raise ValueError("No enumeration member with value %r" % value)
ct.c_longlong.__init__(self, value)
def __repr__(self):
return "<member %s=%d of %r>" % (self.name, self.value, self.__class__)
def __str__(self):
return "%s.%s" % (self.__class__.__name__, self.name)
@property
def name(self):
if self.value in self._reverse_members_:
return self._reverse_members_[self.value]
else:
raise ValueError("No enumeration member with value %r" % value)
@classmethod
def from_param(cls, obj):
return ct.c_int(int(obj))
def _generate_next_value_(name, start, count, last_values):
if len(last_values) == 0:
return 0
return last_values[-1] + 1
#class ObjectType(CTypesIntEnum):
# RNG = auto()
# DISTRIBUTION = auto()
# HYPERPARAMETER = auto()
# EXPRESSION = auto()
# CONFIGURATION_SPACE = auto()
# CONFIGURATION = auto()
# OBJECTIVE_SPACE = auto()
# EVALUATION = auto()
# TUNER = auto()
class Error(CTypesIntEnum):
SUCCESS = auto()
INVALID_OBJECT = auto()
INVALID_VALUE = auto()
INVALID_TYPE = auto()
INVALID_SCALE = auto()
INVALID_DISTRIBUTION = auto()
INVALID_EXPRESSION = auto()
INVALID_HYPERPARAMETER = auto()
INVALID_CONFIGURATION = auto()
INVALID_NAME = auto()
INVALID_CONDITION = auto()
INVALID_TUNER = auto()
INVALID_GRAPH = auto()
TYPE_NOT_COMPARABLE = auto()
INVALID_BOUNDS = auto()
OUT_OF_BOUNDS = auto()
SAMPLING_UNSUCCESSFUL = auto()
INACTIVE_HYPERPARAMETER = auto()
OUT_OF_MEMORY = auto()
UNSUPPORTED_OPERATION = auto()
class CCSError(Exception):
def from_param(cls, param):
if isinstance(param, CEnumeration):
if param.__class__ != cls:
raise ValueError("Cannot mix enumeration members")
else:
return param
else:
return cls(param)
class ccs_object_type(CEnumeration):
_members_ = [
('RNG', 0),
'DISTRIBUTION',
'HYPERPARAMETER',
'EXPRESSION',
'CONFIGURATION_SPACE',
'CONFIGURATION',
'OBJECTIVE_SPACE',
'EVALUATION',
'TUNER' ]
class ccs_error(CEnumeration):
_members_ = [
('SUCCESS', 0),
'INVALID_OBJECT',
'INVALID_VALUE',
'INVALID_TYPE',
'INVALID_SCALE',
'INVALID_DISTRIBUTION',
'INVALID_EXPRESSION',
'INVALID_HYPERPARAMETER',
'INVALID_CONFIGURATION',
'INVALID_NAME',
'INVALID_CONDITION',
'INVALID_TUNER',
'INVALID_GRAPH',
'TYPE_NOT_COMPARABLE',
'INVALID_BOUNDS',
'OUT_OF_BOUNDS',
'SAMPLING_UNSUCCESSFUL',
'INACTIVE_HYPERPARAMETER',
'OUT_OF_MEMORY',
'UNSUPPORTED_OPERATION' ]
class ccs_data_type(CEnumeration64):
_members_ = [
('NONE', 0),
'INTEGER',
'FLOAT',
'BOOLEAN',
'STRING',
'INACTIVE',
'OBJECT' ]
class ccs_numeric_type(CEnumeration64):
_members_ = [
('NUM_INTEGER', ccs_data_type.INTEGER),
('NUM_FLOAT', ccs_data_type.FLOAT) ]
class Numeric(ct.Union):
_fields_ = [('f', ccs_float),
('i', ccs_int)]
def get_value(self, t):
if t == ccs_numeric_type.NUM_INTEGER:
return self.f
elif t == ccs_numeric_type.NUM_FLOAT:
return self.v
else:
raise Error(ccs_error.INVALID_VALUE)
def set_value(self, v):
if isinstance(v, int):
self.i = v
elif isinstance(v, float):
self.f = v
else:
raise Error(ccs_error.INVALID_VALUE)
class Value(ct.Union):
_fields_ = [('f', ccs_float),
('i', ccs_int),
('s', ct.c_char_p),
('o', ccs_object)]
class Datum(ct.Structure):
_fields_ = [('_value', Value),
('type', ccs_data_type)]
def __init__(self, v = None):
super().__init__()
self._string = None
self._object = None
self.value = v
@property
def value(self):
if self.type.value == ccs_data_type.NONE:
return None
elif self.type.value == ccs_data_type.INTEGER:
return self._value.i
elif self.type.value == ccs_data_type.FLOAT:
return self._value.f
elif self.type.value == ccs_data_type.BOOLEAN:
return False if self._value.i == ccs_false else True
elif self.type.value == ccs_data_type.STRING:
return self._value.s.decode()
elif self.type.value == ccs_data_type.INACTIVE:
return ccs_inactive
elif self.type.value == ccs_data_type.OBJECT:
return Object.from_handle(ct.c_void_p(self._value.o))
else:
raise Error(ccs_error.INVALID_VALUE)
@value.setter
def value(self, v):
self._string = None
self._object = None
if v is None:
self.type.value = ccs_data_type.NONE
self._value.i = 0
elif isinstance(v, bool):
self.type.value = ccs_data_type.BOOLEAN
self._value.i = 1 if v else 0
elif isinstance(v, int):
self.type.value = ccs_data_type.INTEGER
self._value.i = v
elif isinstance(v, float):
self.type.value = ccs_data_type.FLOAT
self._value.f = v
elif isinstance(v, str):
self.type.value = ccs_data_type.STRING
self._string = str.encode(v)
self._value.s = ct.c_char_p(self._string)
elif v is ccs_inactive:
self.type.value = ccs_data_type.INACTIVE
self._value.i = 0
elif isinstance(v, Object):
self.type.value = ccs_data_type.OBJECT
self_object = v
self._value.o = v.handle
else:
raise Error(ccs_error.INVALID_VALUE)
class Error(Exception):
def __init__(self, message):
self.message = message
@classmethod
def check(cls, err):
if err < 0:
raise cls(Error(-err))
raise cls(ccs_error(-err))
ccs_get_version = libcconfigspace.ccs_get_version
ccs_get_version.restype = Version
ccs_retain_object = libcconfigspace.ccs_retain_object
ccs_retain_object.restype = ct.c_int
ccs_retain_object.argtypes = [ct.c_void_p]
ccs_retain_object.restype = ccs_result
ccs_retain_object.argtypes = [ccs_object]
ccs_release_object = libcconfigspace.ccs_release_object
ccs_release_object.restype = ct.c_int
ccs_release_object.argtypes = [ct.c_void_p]
ccs_release_object.restype = ccs_result
ccs_release_object.argtypes = [ccs_object]
ccs_object_get_type = libcconfigspace.ccs_object_get_type
ccs_object_get_type.restype = ct.c_int
ccs_object_get_type.argtypes = [ct.c_void_p, ct.POINTER(ObjectType)]
ccs_object_get_type.restype = ccs_result
ccs_object_get_type.argtypes = [ccs_object, ct.POINTER(ccs_object_type)]
ccs_object_get_refcount = libcconfigspace.ccs_object_get_refcount
ccs_object_get_refcount.restype = ct.c_int
ccs_object_get_refcount.argtypes = [ct.c_void_p, ct.POINTER(ct.c_int)]
ccs_object_get_refcount.restype = ccs_result
ccs_object_get_refcount.argtypes = [ccs_object, ct.POINTER(ct.c_int)]
class Object:
def __init__(self, handle, retain = False, auto_release = True):
if handle is None:
raise CCSError(Error.INVALID_OBJECT)
raise Error(ccs_error.INVALID_OBJECT)
self.handle = handle
self.auto_release = auto_release
if retain:
res = ccs_retain_object(handle)
CCSError.check(res)
Error.check(res)
def __del__(self):
res = ccs_release_object(self.handle)
CCSError.check(res)
Error.check(res)
def object_type(self):
t = ObjectType(0)
t = ccs_object_type(0)
res = ccs_object_get_type(self.handle, ct.byref(t))
CCSError.check(res)
Error.check(res)
return t
def refcount(self):
c = ct.c_int(0)
res = ccs_object_get_refcount(self.handle, ct.byref(c))
CCSError.check(res)
Error.check(res)
return c.value
@classmethod
def from_handle(cls, h):
t = ccs_object_type(0)
res = ccs_object_get_type(h, ct.byref(t))
Error.check(res)
if t.value == ccs_object_type.RNG:
from .rng import Rng
return Rng.from_handle(h)
else:
raise Error(ccs_error.INVALID_OBJECT)
import ctypes as ct
from . import libcconfigspace
from .base import Object, CCSError
from .base import Object, Error, ccs_float, ccs_result, ccs_rng
ccs_rng_create = libcconfigspace.ccs_rng_create
ccs_rng_create.restype = ct.c_int
ccs_rng_create.argtypes = [ct.c_void_p]
ccs_rng_create.restype = ccs_result
ccs_rng_create.argtypes = [ct.POINTER(ccs_rng)]
ccs_rng_set_seed = libcconfigspace.ccs_rng_set_seed
ccs_rng_set_seed.restype = ct.c_int
ccs_rng_set_seed.argtypes = [ct.c_void_p, ct.c_ulong]
ccs_rng_set_seed.restype = ccs_result
ccs_rng_set_seed.argtypes = [ccs_rng, ct.c_ulong]
ccs_rng_get = libcconfigspace.ccs_rng_get
ccs_rng_get.restype = ct.c_int
ccs_rng_get.argtypes = [ct.c_void_p, ct.c_void_p]
#attach_function :ccs_rng_get, [:ccs_rng_t, :pointer], :ccs_result_t
#attach_function :ccs_rng_min, [:ccs_rng_t, :pointer], :ccs_result_t
#attach_function :ccs_rng_max, [:ccs_rng_t, :pointer], :ccs_result_t
ccs_rng_get.restype = ccs_result
ccs_rng_get.argtypes = [ccs_rng, ct.POINTER(ct.c_ulong)]
ccs_rng_uniform = libcconfigspace.ccs_rng_uniform
ccs_rng_uniform.restype = ccs_result
ccs_rng_uniform.argtypes = [ccs_rng, ct.POINTER(ccs_float)]
ccs_rng_min = libcconfigspace.ccs_rng_min
ccs_rng_min.restype = ccs_result
ccs_rng_min.argtypes = [ccs_rng, ct.POINTER(ct.c_ulong)]
ccs_rng_max = libcconfigspace.ccs_rng_max
ccs_rng_max.restype = ccs_result
ccs_rng_max.argtypes = [ccs_rng, ct.POINTER(ct.c_ulong)]
class Rng(Object):
def __init__(self, handle = None, retain = False):
if handle is None:
handle = ct.c_void_p(0)
handle = ccs_rng(0)
res = ccs_rng_create(ct.byref(handle))
CCSError.check(res)
Error.check(res)
super().__init__(handle = handle, retain = False)
else:
super().__init__(handle = handle, retain = retain)
@classmethod
def from_handle(cls, handle):
cls(handle, retain = True)
return cls(handle, retain = True)
def set_seed(self, value):
res = ccs_rng_set_seed(self.handle, value)
CCSError.check(res)
Error.check(res)
return self
def get(self):
v = ct.c_ulong(0)
res = ccs_rng_get(self.handle, ct.byref(v))
Error.check(res)
return v.value
def uniform(self):
v = ccs_float(0.0)
res = ccs_rng_uniform(self.handle, ct.byref(v))
Error.check(res)
return v.value
def min(self):
v = ct.c_ulong(0)
res = ccs_rng_min(self.handle, ct.byref(v))
Error.check(res)
return v.value
def max(self):
v = ct.c_ulong(0)
res = ccs_rng_max(self.handle, ct.byref(v))
Error.check(res)
return v.value
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment