Commit 14a5d02c authored by Brice Videau's avatar Brice Videau
Browse files

Added interval and distribution to Python bindings.

parent d09a505c
......@@ -8,3 +8,5 @@ else:
from .base import *
from .rng import *
from .interval import *
from .distribution import *
......@@ -4,12 +4,15 @@ from . import libcconfigspace
ccs_init = libcconfigspace.ccs_init
ccs_init.restype = ct.c_int
class Version(ct.Structure):
class ccs_version(ct.Structure):
_fields_ = [("revision", ct.c_ushort),
("patch", ct.c_ushort),
("minor", ct.c_ushort),
("major", ct.c_ushort)]
def __str__(self):
return "{}.{}.{}.{}".format(self.major, self.minor, self.patch, self.revision)
# Base types
ccs_float = ct.c_double
ccs_int = ct.c_longlong
......@@ -33,6 +36,12 @@ ccs_tuner = ccs_object
ccs_false = 0
ccs_true = 1
def _ccs_get_function(method, argtypes = [], restype = ccs_result):
res = getattr(libcconfigspace, method)
res.restype = restype
res.argtypes = argtypes
return res
# https://www.python-course.eu/python3_metaclasses.php
class Singleton(type):
_instances = {}
......@@ -82,11 +91,14 @@ class CEnumerationType(type(ct.c_int)):
class CEnumeration(ct.c_int, metaclass=CEnumerationType):
_members_ = {}
def __init__(self, value):
if value in self._reverse_members_:
self.name = self._reverse_members_[value]
ct.c_int.__init__(self, value)
@property
def name(self):
if self.value in self._reverse_members_:
return self._reverse_members_[self.value]
else:
raise ValueError("No enumeration member with value %r" % value)
ct.c_int.__init__(self, value)
@classmethod
def from_param(cls, param):
......@@ -139,10 +151,6 @@ class CEnumerationType64(type(ct.c_longlong)):
class CEnumeration64(ct.c_longlong, metaclass=CEnumerationType64):
_members_ = {}
def __init__(self, value):
if value in self._reverse_members_:
self._name = self._reverse_members_[value]
else:
raise ValueError("No enumeration member with value %r" % value)
ct.c_longlong.__init__(self, value)
def __repr__(self):
......@@ -218,10 +226,14 @@ class ccs_numeric_type(CEnumeration64):
('NUM_INTEGER', ccs_data_type.INTEGER),
('NUM_FLOAT', ccs_data_type.FLOAT) ]
class Numeric(ct.Union):
class ccs_numeric(ct.Union):
_fields_ = [('f', ccs_float),
('i', ccs_int)]
def __init__(self, v = 0):
super().__init__()
self.set_value(v)
def get_value(self, t):
if t == ccs_numeric_type.NUM_INTEGER:
return self.f
......@@ -238,14 +250,14 @@ class Numeric(ct.Union):
else:
raise Error(ccs_error.INVALID_VALUE)
class Value(ct.Union):
class ccs_value(ct.Union):
_fields_ = [('f', ccs_float),
('i', ccs_int),
('s', ct.c_char_p),
('o', ccs_object)]
class Datum(ct.Structure):
_fields_ = [('_value', Value),
class ccs_datum(ct.Structure):
_fields_ = [('_value', ccs_value),
('type', ccs_data_type)]
def __init__(self, v = None):
......@@ -312,24 +324,11 @@ class Error(Exception):
if err < 0:
raise cls(ccs_error(-err))
ccs_get_version = libcconfigspace.ccs_get_version
ccs_get_version.restype = Version
ccs_retain_object = libcconfigspace.ccs_retain_object
ccs_retain_object.restype = ccs_result
ccs_retain_object.argtypes = [ccs_object]
ccs_release_object = libcconfigspace.ccs_release_object
ccs_release_object.restype = ccs_result
ccs_release_object.argtypes = [ccs_object]
ccs_object_get_type = libcconfigspace.ccs_object_get_type
ccs_object_get_type.restype = ccs_result
ccs_object_get_type.argtypes = [ccs_object, ct.POINTER(ccs_object_type)]
ccs_object_get_refcount = libcconfigspace.ccs_object_get_refcount
ccs_object_get_refcount.restype = ccs_result
ccs_object_get_refcount.argtypes = [ccs_object, ct.POINTER(ct.c_int)]
ccs_get_version = _ccs_get_function("ccs_get_version", restype = ccs_version)
ccs_retain_object = _ccs_get_function("ccs_retain_object", [ccs_object])
ccs_release_object = _ccs_get_function("ccs_release_object", [ccs_object])
ccs_object_get_type = _ccs_get_function("ccs_object_get_type", [ccs_object, ct.POINTER(ccs_object_type)])
ccs_object_get_refcount = _ccs_get_function("ccs_object_get_refcount", [ccs_object, ct.POINTER(ct.c_int)])
class Object:
def __init__(self, handle, retain = False, auto_release = True):
......
import ctypes as ct
from . import libcconfigspace
from .base import Object, Error, ccs_error, ccs_int, ccs_float, ccs_bool, ccs_result, ccs_rng, ccs_distribution, ccs_numeric_type, ccs_numeric, CEnumeration, NUM_FLOAT, NUM_INTEGER, _ccs_get_function
from .interval import ccs_interval
class ccs_distribution_type(CEnumeration):
_members_ = [
('UNIFORM', 0),
'NORMAL',
'ROULETTE' ]
class ccs_scale_type(CEnumeration):
_members_ = [
('LINEAR', 0),
'LOGARITHMIC' ]
ccs_distribution_get_type = _ccs_get_function("ccs_distribution_get_type", [ccs_distribution, ct.POINTER(ccs_distribution_type)])
ccs_distribution_get_data_type = _ccs_get_function("ccs_distribution_get_data_type", [ccs_distribution, ct.POINTER(ccs_numeric_type)])
ccs_distribution_get_dimension = _ccs_get_function("ccs_distribution_get_dimension", [ccs_distribution, ct.POINTER(ct.c_size_t)])
ccs_distribution_get_scale_type = _ccs_get_function("ccs_distribution_get_scale_type", [ccs_distribution, ct.POINTER(ccs_scale_type)])
ccs_distribution_get_quantization = _ccs_get_function("ccs_distribution_get_quantization", [ccs_distribution, ct.POINTER(ccs_numeric)])
ccs_distribution_get_bounds = _ccs_get_function("ccs_distribution_get_bounds", [ccs_distribution, ct.POINTER(ccs_interval)])
ccs_distribution_check_oversampling = _ccs_get_function("ccs_distribution_check_oversampling", [ccs_distribution, ct.POINTER(ccs_interval), ct.POINTER(ccs_bool)])
ccs_distribution_sample = _ccs_get_function("ccs_distribution_sample", [ccs_distribution, ccs_rng, ct.POINTER(ccs_numeric)])
ccs_distribution_samples = _ccs_get_function("ccs_distribution_samples", [ccs_distribution, ccs_rng, ct.c_size_t, ct.POINTER(ccs_numeric)])
class Distribution(Object):
@classmethod
def from_handle(cls, handle):
v = ccs_distribution_type(0)
res = ccs_distribution_get_type(handle, ct.byref(v))
Error.check(res)
v = v.value
if v == ccs_distribution_type.UNIFORM:
return UniformDistribution(handle = handle, retain = True)
elif v == ccs_distribution_type.NORMAL:
return NormalDistribution(handle = handle, retain = True)
elif v == ccs_distribution_type.ROULETTE:
return RouletteDistribution(handle = handle, retain = True)
else:
raise Error(ccs_error.INVALID_DISTRIBUTION)
@property
def type(self):
if hasattr(self, "_type"):
return self._type
v = ccs_distribution_type(0)
res = ccs_distribution_get_type(self.handle, ct.byref(v))
Error.check(res)
self._type = v
return v
@property
def data_type(self):
if hasattr(self, "_data_type"):
return self._data_type
v = ccs_numeric_type(0)
res = ccs_distribution_get_data_type(self.handle, ct.byref(v))
Error.check(res)
self._data_type = v
return v
@property
def dimension(self):
if hasattr(self, "_dimension"):
return self._dimension
v = ct.c_size_t()
res = ccs_distribution_get_dimension(self.handle, ct.byref(v))
Error.check(res)
self._dimension = v.value
return v.value
@property
def scale_type(self):
if hasattr(self, "_scale_type"):
return self._scale_type
v = ccs_scale_type(0)
res = ccs_distribution_get_scale_type(self.handle, ct.byref(v))
Error.check(res)
self._scale_type = v
return v
@property
def quantization(self):
if hasattr(self, "_quantization"):
return self._quantization
v = ccs_numeric(0)
res = ccs_distribution_get_quantization(self.handle, ct.byref(v))
Error.check(res)
t = self.data_type.value
if t == ccs_numeric_type.NUM_INTEGER:
self._quantization = v.i
elif t == ccs_numeric_type.NUM_FLOAT:
self._quantization = v.f
else:
raise Error(ccs_error.INVALID_VALUE)
return self._quantization
@property
def bounds(self):
if hasattr(self, "_bounds"):
return self._bounds
v = ccs_interval()
res = ccs_distribution_get_bounds(self.handle, ct.byref(v))
Error.check(res)
self._bounds = v
return v
def oversampling(self, interval):
v = ccs_bool()
res = ccs_distribution_check_oversampling(self.handle, ct.byref(interval), ct.byref(v))
Error.check(res)
return False if v.value == ccs_false else True
def sample(self, rng):
v = ccs_numeric()
res = ccs_distribution_sample(self.handle, rng.handle, ct.byref(v))
Error.check(res)
t = self.data_type.value
if t == ccs_numeric_type.NUM_INTEGER:
return v.i
elif t == ccs_numeric_type.NUM_FLOAT:
return v.f
else:
raise Error(ccs_error.INVALID_VALUE)
def samples(self, rng, count):
t = self.data_type.value
if t == ccs_numeric_type.NUM_INTEGER:
v = (ccs_int * count)()
elif t == ccs_numeric_type.NUM_FLOAT:
v = (ccs_float * count)()
else:
raise Error(ccs_error.INVALID_VALUE)
res = ccs_distribution_samples(self.handle, rng.handle, count, ct.cast(v, ct.POINTER(ccs_numeric)))
Error.check(res)
return list(v)
ccs_create_uniform_distribution = _ccs_get_function("ccs_create_uniform_distribution", [ccs_numeric_type, ccs_int, ccs_int, ccs_scale_type, ccs_int, ct.POINTER(ccs_distribution)])
ccs_create_uniform_int_distribution = _ccs_get_function("ccs_create_uniform_int_distribution", [ccs_int, ccs_int, ccs_scale_type, ccs_int, ct.POINTER(ccs_distribution)])
ccs_create_uniform_float_distribution = _ccs_get_function("ccs_create_uniform_float_distribution", [ccs_float, ccs_float, ccs_scale_type, ccs_float, ct.POINTER(ccs_distribution)])
ccs_uniform_distribution_get_parameters = _ccs_get_function("ccs_uniform_distribution_get_parameters", [ccs_distribution, ct.POINTER(ccs_numeric), ct.POINTER(ccs_numeric)])
class UniformDistribution(Distribution):
def __init__(self, handle = None, retain = False, data_type = NUM_FLOAT, lower = 0.0, upper = 1.0, scale = ccs_scale_type.LINEAR, quantization = 0.0):
if handle is None:
handle = ccs_distribution(0)
if data_type == NUM_FLOAT:
res = ccs_create_uniform_float_distribution(lower, upper, scale, quantization, ct.byref(handle))
elif data_type == NUM_INTEGER:
res = ccs_create_uniform_int_distribution(lower, upper, scale, quantization, ct.byref(handle))
else:
raise Error(ccs_error.INVALID_VALUE)
Error.check(res)
super().__init__(handle = handle, retain = False)
else:
super().__init__(handle = handle, retain = retain)
@classmethod
def int(cls, lower, upper, scale = ccs_scale_type.LINEAR, quantization = 0):
return cls(data_type = NUM_INTEGER, lower = lower, upper = upper, scale = scale, quantization = quantization)
@classmethod
def float(cls, lower, upper, scale = ccs_scale_type.LINEAR, quantization = 0.0):
return cls(data_type = NUM_FLOAT, lower = lower, upper = upper, scale = scale, quantization = quantization)
@property
def lower(self):
if hasattr(self, "_lower"):
return self._lower
v = ccs_numeric()
res = ccs_uniform_distribution_get_parameters(self.handle, ct.byref(v), None)
Error.check(res)
t = self.data_type.value
if t == ccs_numeric_type.NUM_INTEGER:
self._lower = v.i
elif t == ccs_numeric_type.NUM_FLOAT:
self._lower = v.f
else:
raise Error(ccs_error.INVALID_VALUE)
return self._lower
@property
def upper(self):
if hasattr(self, "_upper"):
return self._upper
v = ccs_numeric()
res = ccs_uniform_distribution_get_parameters(self.handle, None, ct.byref(v))
Error.check(res)
t = self.data_type.value
if t == ccs_numeric_type.NUM_INTEGER:
self._upper = v.i
elif t == ccs_numeric_type.NUM_FLOAT:
self._upper = v.f
else:
raise Error(ccs_error.INVALID_VALUE)
return self._upper
ccs_create_normal_distribution = _ccs_get_function("ccs_create_normal_distribution", [ccs_numeric_type, ccs_float, ccs_float, ccs_scale_type, ccs_int, ct.POINTER(ccs_distribution)])
ccs_create_normal_int_distribution = _ccs_get_function("ccs_create_normal_int_distribution", [ccs_float, ccs_float, ccs_scale_type, ccs_int, ct.POINTER(ccs_distribution)])
ccs_create_normal_float_distribution = _ccs_get_function("ccs_create_normal_float_distribution", [ccs_float, ccs_float, ccs_scale_type, ccs_float, ct.POINTER(ccs_distribution)])
ccs_normal_distribution_get_parameters = _ccs_get_function("ccs_normal_distribution_get_parameters", [ccs_distribution, ct.POINTER(ccs_float), ct.POINTER(ccs_float)])
class NormalDistribution(Distribution):
def __init__(self, handle = None, retain = False, data_type = NUM_FLOAT, mu = 0.0, sigma = 1.0, scale = ccs_scale_type.LINEAR, quantization = 0.0):
if handle is None:
handle = ccs_distribution(0)
if data_type == NUM_FLOAT:
res = ccs_create_normal_float_distribution(mu, sigma, scale, quantization, ct.byref(handle))
elif data_type == NUM_INTEGER:
res = ccs_create_normal_int_distribution(mu, sigma, scale, quantization, ct.byref(handle))
else:
raise Error(ccs_error.INVALID_VALUE)
Error.check(res)
super().__init__(handle = handle, retain = False)
else:
super().__init__(handle = handle, retain = retain)
@property
def mu(self):
if hasattr(self, "_mu"):
return self._mu
v = ccs_float()
res = ccs_normal_distribution_get_parameters(self.handle, ct.byref(v), None)
Error.check(res)
t = self.data_type.value
self._mu = v.value
return self._mu
@property
def sigma(self):
if hasattr(self, "_sigma"):
return self._sigma
v = ccs_float()
res = ccs_normal_distribution_get_parameters(self.handle, None, ct.byref(v))
Error.check(res)
t = self.data_type.value
self._sigma = v.value
return self._sigma
ccs_create_roulette_distribution = _ccs_get_function("ccs_create_roulette_distribution", [ct.c_size_t, ct.POINTER(ccs_float), ct.POINTER(ccs_distribution)])
ccs_roulette_distribution_get_num_areas = _ccs_get_function("ccs_roulette_distribution_get_num_areas", [ccs_distribution, ct.POINTER(ct.c_size_t)])
ccs_roulette_distribution_get_areas = _ccs_get_function("ccs_roulette_distribution_get_areas", [ccs_distribution, ct.c_size_t, ct.POINTER(ccs_float), ct.POINTER(ct.c_size_t)])
class RouletteDistribution(Distribution):
def __init__(self, handle = None, retain = False, areas = []):
if handle is None:
handle = ccs_distribution(0)
v = (ccs_float * len(areas))(*areas)
res = ccs_create_roulette_distribution(len(areas), v, ct.byref(handle))
Error.check(res)
super().__init__(handle = handle, retain = False)
else:
super().__init__(handle = handle, retain = retain)
@property
def num_areas(self):
if hasattr(self, "_num_areas"):
return self._num_areas
v = ct.c_size_t()
res = ccs_roulette_distribution_get_num_areas(self.handle, ct.byref(v))
Error.check(res)
self._num_areas = v.value
return self._num_areas
@property
def areas(self):
if hasattr(self, "_areas"):
return self._areas
v = (ccs_float * self.num_areas)()
res = ccs_roulette_distribution_get_areas(self.handle, self.num_areas, v, None)
Error.check(res)
self._areas = list(v)
return self._areas
import ctypes as ct
from . import libcconfigspace
from .base import Error, ccs_error, ccs_numeric_type, ccs_numeric, ccs_float, ccs_int, ccs_result, ccs_bool, ccs_false, ccs_true, _ccs_get_function
class ccs_interval(ct.Structure):
_fields_ = [('type', ccs_numeric_type),
('_lower', ccs_numeric),
('_upper', ccs_numeric),
('_lower_included', ccs_bool),
('_upper_included', ccs_bool)]
@property
def lower(self):
if self.type.value == ccs_numeric_type.NUM_INTEGER:
return self._lower.i
elif self.type.value == ccs_numeric_type.NUM_FLOAT:
return self._lower.f
else:
raise Error(ccs_error.INVALID_VALUE)
@lower.setter
def lower(self, value):
if self.type.value == ccs_numeric_type.NUM_INTEGER:
self._lower.i = value
elif self.type.value == ccs_numeric_type.NUM_FLOAT:
self._lower.f = value
else:
raise Error(ccs_error.INVALID_VALUE)
@property
def upper(self):
if self.type.value == ccs_numeric_type.NUM_INTEGER:
return self._upper.i
elif self.type.value == ccs_numeric_type.NUM_FLOAT:
return self._upper.f
else:
raise Error(ccs_error.INVALID_VALUE)
@upper.setter
def upper(self, value):
if self.type.value == ccs_numeric_type.NUM_INTEGER:
self._upper.i = value
elif self.type.value == ccs_numeric_type.NUM_FLOAT:
self._upper.f = value
else:
raise Error(ccs_error.INVALID_VALUE)
@property
def lower_included(self):
return False if self._lower_included == ccs_false else True
@lower_included.setter
def lower_included(self, value):
if value:
self._lower_included = ccs_true
else:
self._lower_included = ccs_false
@property
def upper_included(self):
return False if self._upper_included == ccs_false else True
@upper_included.setter
def upper_included(self, value):
if value:
self._upper_included = ccs_true
else:
self._upper_included = ccs_false
def empty(self):
v = ccs_bool(0)
res = ccs_interval_empty(ct.byref(self), ct.byref(v))
Error.check(res)
return False if v.value == ccs_false else True
def intersect(self, other):
v = ccs_interval()
res = ccs_interval_intersect(ct.byref(self), ct.byref(other), ct.byref(v))
Error.check(res)
return v
def __eq__(self, other):
v = ccs_bool(0)
res = ccs_interval_equal(ct.byref(self), ct.byref(other), ct.byref(v))
Error.check(res)
return False if v.value == ccs_false else True
# this works around a subtle bug in union support...
def include(self, value):
v = ccs_numeric()
if self.type.value == ccs_numeric_type.NUM_INTEGER:
v.i = value
elif self.type.value == ccs_numeric_type.NUM_FLOAT:
v.f = value
res = ccs_interval_include(ct.byref(self), v.i)
return False if res == ccs_false else True
def __str__(self):
s = ""
s += "[" if self.lower_included else "("
s += "{}, {}".format(self.lower, self.upper)
s += "]" if self.upper_included else ")"
return s
ccs_interval_empty = _ccs_get_function("ccs_interval_empty", [ct.POINTER(ccs_interval), ct.POINTER(ccs_bool)])
ccs_interval_intersect = _ccs_get_function("ccs_interval_intersect", [ct.POINTER(ccs_interval), ct.POINTER(ccs_interval), ct.POINTER(ccs_interval)])
ccs_interval_equal = _ccs_get_function("ccs_interval_equal", [ct.POINTER(ccs_interval), ct.POINTER(ccs_interval), ct.POINTER(ccs_bool)])
ccs_interval_include = _ccs_get_function("ccs_interval_include", [ct.POINTER(ccs_interval), ccs_int], ccs_bool)
import ctypes as ct
from . import libcconfigspace
from .base import Object, Error, ccs_float, ccs_result, ccs_rng
from .base import Object, Error, ccs_float, ccs_result, ccs_rng, _ccs_get_function
ccs_rng_create = libcconfigspace.ccs_rng_create
ccs_rng_create.restype = ccs_result
ccs_rng_create.argtypes = [ct.POINTER(ccs_rng)]
ccs_rng_set_seed = libcconfigspace.ccs_rng_set_seed
ccs_rng_set_seed.restype = ccs_result
ccs_rng_set_seed.argtypes = [ccs_rng, ct.c_ulong]
ccs_rng_get = libcconfigspace.ccs_rng_get
ccs_rng_get.restype = ccs_result
ccs_rng_get.argtypes = [ccs_rng, ct.POINTER(ct.c_ulong)]
ccs_rng_uniform = libcconfigspace.ccs_rng_uniform
ccs_rng_uniform.restype = ccs_result
ccs_rng_uniform.argtypes = [ccs_rng, ct.POINTER(ccs_float)]
ccs_rng_min = libcconfigspace.ccs_rng_min
ccs_rng_min.restype = ccs_result
ccs_rng_min.argtypes = [ccs_rng, ct.POINTER(ct.c_ulong)]
ccs_rng_max = libcconfigspace.ccs_rng_max
ccs_rng_max.restype = ccs_result
ccs_rng_max.argtypes = [ccs_rng, ct.POINTER(ct.c_ulong)]
ccs_rng_create = _ccs_get_function("ccs_rng_create", [ct.POINTER(ccs_rng)])
ccs_rng_set_seed = _ccs_get_function("ccs_rng_set_seed", [ccs_rng, ct.c_ulong])
ccs_rng_get = _ccs_get_function("ccs_rng_get", [ccs_rng, ct.POINTER(ct.c_ulong)])
ccs_rng_uniform = _ccs_get_function("ccs_rng_uniform", [ccs_rng, ct.POINTER(ccs_float)])
ccs_rng_min = _ccs_get_function("ccs_rng_min", [ccs_rng, ct.POINTER(ct.c_ulong)])
ccs_rng_max = _ccs_get_function("ccs_rng_max", [ccs_rng, ct.POINTER(ct.c_ulong)])
class Rng(Object):
def __init__(self, handle = None, retain = False):
......@@ -69,3 +52,4 @@ class Rng(Object):
Error.check(res)
return v.value
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment