Commit 02dfa188 authored by Brice Videau's avatar Brice Videau
Browse files

Updated features to support binding abstractions.

parent c8d1e2c7
......@@ -9,7 +9,7 @@ ccs_binding_set_value = _ccs_get_function("ccs_binding_set_value", [ccs_binding,
ccs_binding_get_values = _ccs_get_function("ccs_binding_get_values", [ccs_binding, ct.c_size_t, ct.POINTER(ccs_datum), ct.POINTER(ct.c_size_t)])
ccs_binding_get_value_by_name = _ccs_get_function("ccs_binding_get_value_by_name", [ccs_binding, ct.c_char_p, ct.POINTER(ccs_datum)])
ccs_binding_hash = _ccs_get_function("ccs_binding_hash", [ccs_binding, ct.POINTER(ccs_hash)])
ccs_binding_cmp = _ccs_get_function("ccs_binding_cmp", [ccs_binding, ct.POINTER(ccs_int)])
ccs_binding_cmp = _ccs_get_function("ccs_binding_cmp", [ccs_binding, ccs_binding, ct.POINTER(ccs_int)])
class Binding(Object):
......
import ctypes as ct
from .base import Object, Error, ccs_error, _ccs_get_function, ccs_context, ccs_hyperparameter, ccs_configuration_space, ccs_configuration, ccs_rng, ccs_distribution, ccs_expression, ccs_datum, ccs_hash, ccs_int
from .base import Object, Error, ccs_error, _ccs_get_function, ccs_context, ccs_hyperparameter, ccs_configuration_space, ccs_configuration, ccs_distribution, ccs_expression, ccs_datum, ccs_hash, ccs_int
from .context import Context
from .rng import Rng
from .hyperparameter import Hyperparameter
from .configuration_space import ConfigurationSpace
from .binding import Binding
......
......@@ -3,19 +3,13 @@ from .base import Object, Error, ccs_error, _ccs_get_function, ccs_context, ccs_
from .context import Context
from .hyperparameter import Hyperparameter
from .features_space import FeaturesSpace
from .binding import Binding
ccs_create_features = _ccs_get_function("ccs_create_features", [ccs_features_space, ct.c_size_t, ct.POINTER(ccs_datum), ct.c_void_p, ct.POINTER(ccs_features)])
ccs_features_get_features_space = _ccs_get_function("ccs_features_get_features_space", [ccs_features, ct.POINTER(ccs_features_space)])
ccs_features_get_user_data = _ccs_get_function("ccs_features_get_user_data", [ccs_features, ct.POINTER(ct.c_void_p)])
ccs_features_get_value = _ccs_get_function("ccs_features_get_value", [ccs_features, ct.c_size_t, ct.POINTER(ccs_datum)])
ccs_features_set_value = _ccs_get_function("ccs_features_set_value", [ccs_features, ct.c_size_t, ccs_datum])
ccs_features_get_values = _ccs_get_function("ccs_features_get_values", [ccs_features, ct.c_size_t, ct.POINTER(ccs_datum), ct.POINTER(ct.c_size_t)])
ccs_features_get_value_by_name = _ccs_get_function("ccs_features_get_value_by_name", [ccs_features, ct.c_char_p, ct.POINTER(ccs_datum)])
ccs_features_check = _ccs_get_function("ccs_features_check", [ccs_features])
ccs_features_hash = _ccs_get_function("ccs_features_hash", [ccs_features, ct.POINTER(ccs_hash)])
ccs_features_cmp = _ccs_get_function("ccs_features_cmp", [ccs_features, ccs_features, ct.POINTER(ccs_int)])
class Features(Object):
class Features(Binding):
def __init__(self, handle = None, retain = False, auto_release = True,
features_space = None, values = None, user_data = None):
if handle is None:
......@@ -38,16 +32,6 @@ class Features(Object):
def from_handle(cls, handle, retain = True, auto_release = True):
return cls(handle = handle, retain = retain, auto_release = auto_release)
@property
def user_data(self):
if hasattr(self, "_user_data"):
return self._user_data
v = ct.c_void_p()
res = ccs_features_get_user_data(self.handle, ct.byref(v))
Error.check(res)
self._user_data = v
return v
@property
def features_space(self):
if hasattr(self, "_features_space"):
......@@ -58,109 +42,6 @@ class Features(Object):
self._features_space = ConfigurationSpace.from_handle(v)
return self._features_space
@property
def num_values(self):
if hasattr(self, "_num_values"):
return self._num_values
v = ct.c_size_t()
res = ccs_features_get_values(self.handle, 0, None, ct.byref(v))
Error.check(res)
self._num_values = v.value
return self._num_values
@property
def hash(self):
v = ccs_hash()
res = ccs_features_hash(self.handle, ct.byref(v))
Error.check(res)
return self.value
def set_value(self, hyperparameter, value):
if isinstance(hyperparameter, Hyperparameter):
hyperparameter = self.features_space.hyperparameter_index(hyperparameter)
elif isinstance(hyperparameter, str):
hyperparameter = self.features_space.hyperparameter_index_by_name(hyperparameter)
pv = ccs_datum(value)
v = ccs_datum_fix()
v.value = pv._value.i
v.type = pv.type
res = ccs_features_set_value(self.handle, hyperparameter, v)
Error.check(res)
def value(self, hyperparameter):
v = ccs_datum()
if isinstance(hyperparameter, Hyperparameter):
res = ccs_features_get_value(self.handle, self.features_space.hyperparameter_index(hyperparameter), ct.byref(v))
elif isinstance(hyperparameter, str):
res = ccs_features_get_value_by_name(self.handle, str.encode(hyperparameter), ct.byref(v))
else:
res = ccs_features_get_value(self.handle, hyperparameter, ct.byref(v))
Error.check(res)
return v.value
@property
def values(self):
sz = self.num_values
if sz == 0:
return []
v = (ccs_datum * sz)()
res = ccs_features_get_values(self.handle, sz, v, None)
Error.check(res)
return [x.value for x in v]
def check(self):
res = ccs_features_check(self.handle)
Error.check(res)
def cmp(self, other):
v = ccs_int()
res = ccs_features_cmp(self.handle, other.handle, ct.byref(v))
Error.check(res)
return v.value
def __lt__(self, other):
v = ccs_int()
res = ccs_features_cmp(self.handle, other.handle, ct.byref(v))
Error.check(res)
return v.value < 0
def __le__(self, other):
v = ccs_int()
res = ccs_features_cmp(self.handle, other.handle, ct.byref(v))
Error.check(res)
return v.value <= 0
def __gt__(self, other):
v = ccs_int()
res = ccs_features_cmp(self.handle, other.handle, ct.byref(v))
Error.check(res)
return v.value > 0
def __ge__(self, other):
v = ccs_int()
res = ccs_features_cmp(self.handle, other.handle, ct.byref(v))
Error.check(res)
return v.value >= 0
def __eq__(self, other):
v = ccs_int()
res = ccs_features_cmp(self.handle, other.handle, ct.byref(v))
Error.check(res)
return v.value == 0
def __ne__(self, other):
v = ccs_int()
res = ccs_features_cmp(self.handle, other.handle, ct.byref(v))
Error.check(res)
return v.value != 0
def __hash__(self):
return self.hash
def asdict(self):
res = {}
hyperparameters = self.features_space.hyperparameters
values = self.values
for i in range(len(hyperparameters)):
res[hyperparameters[i].name] = values[i]
return res
......@@ -8,23 +8,20 @@ from .features_space import FeaturesSpace
from .features import Features
from .objective_space import ObjectiveSpace
from .evaluation import ccs_comparison
from .binding import Binding
ccs_create_features_evaluation = _ccs_get_function("ccs_create_features_evaluation", [ccs_objective_space, ccs_configuration, ccs_features, ccs_result, ct.c_size_t, ct.POINTER(ccs_datum), ct.c_void_p, ct.POINTER(ccs_features_evaluation)])
ccs_features_evaluation_get_objective_space = _ccs_get_function("ccs_features_evaluation_get_objective_space", [ccs_features_evaluation, ct.POINTER(ccs_objective_space)])
ccs_features_evaluation_get_configuration = _ccs_get_function("ccs_features_evaluation_get_configuration", [ccs_features_evaluation, ct.POINTER(ccs_configuration)])
ccs_features_evaluation_get_features = _ccs_get_function("ccs_features_evaluation_get_features", [ccs_features_evaluation, ct.POINTER(ccs_features)])
ccs_features_evaluation_get_user_data = _ccs_get_function("ccs_features_evaluation_get_user_data", [ccs_features_evaluation, ct.POINTER(ct.c_void_p)])
ccs_features_evaluation_get_error = _ccs_get_function("ccs_features_evaluation_get_error", [ccs_features_evaluation, ct.POINTER(ccs_result)])
ccs_features_evaluation_set_error = _ccs_get_function("ccs_features_evaluation_set_error", [ccs_features_evaluation, ccs_result])
ccs_features_evaluation_get_value = _ccs_get_function("ccs_features_evaluation_get_value", [ccs_features_evaluation, ct.c_size_t, ct.POINTER(ccs_datum)])
ccs_features_evaluation_set_value = _ccs_get_function("ccs_features_evaluation_set_value", [ccs_features_evaluation, ct.c_size_t, ccs_datum_fix])
ccs_features_evaluation_get_values = _ccs_get_function("ccs_features_evaluation_get_values", [ccs_features_evaluation, ct.c_size_t, ct.POINTER(ccs_datum), ct.POINTER(ct.c_size_t)])
ccs_features_evaluation_get_value_by_name = _ccs_get_function("ccs_features_evaluation_get_value_by_name", [ccs_features_evaluation, ct.c_char_p, ccs_datum])
ccs_features_evaluation_get_objective_value = _ccs_get_function("ccs_features_evaluation_get_objective_value", [ccs_features_evaluation, ct.c_size_t, ct.POINTER(ccs_datum)])
ccs_features_evaluation_get_objective_values = _ccs_get_function("ccs_features_evaluation_get_objective_values", [ccs_features_evaluation, ct.c_size_t, ct.POINTER(ccs_datum), ct.POINTER(ct.c_size_t)])
ccs_features_evaluation_cmp = _ccs_get_function("ccs_features_evaluation_cmp", [ccs_features_evaluation, ccs_features_evaluation, ct.POINTER(ccs_comparison)])
ccs_features_evaluation_compare = _ccs_get_function("ccs_features_evaluation_compare", [ccs_features_evaluation, ccs_features_evaluation, ct.POINTER(ccs_comparison)])
ccs_features_evaluation_check = _ccs_get_function("ccs_features_evaluation_check", [ccs_features_evaluation])
class FeaturesEvaluation(Object):
class FeaturesEvaluation(Binding):
def __init__(self, handle = None, retain = False, auto_release = True,
objective_space = None, configuration = None, features = None, error = ccs_error.SUCCESS, values = None, user_data = None):
if handle is None:
......@@ -47,16 +44,6 @@ class FeaturesEvaluation(Object):
def from_handle(cls, handle, retain = True, auto_release = True):
return cls(handle = handle, retain = retain, auto_release = auto_release)
@property
def user_data(self):
if hasattr(self, "_user_data"):
return self._user_data
v = ct.c_void_p()
res = ccs_features_evaluation_get_user_data(self.handle, ct.byref(v))
Error.check(res)
self._user_data = v
return v
@property
def objective_space(self):
if hasattr(self, "_objective_space"):
......@@ -99,49 +86,6 @@ class FeaturesEvaluation(Object):
res = ccs_features_evaluation_set_error(self.handle, v)
Error.check(res)
def set_value(self, hyperparameter, value):
if isinstance(hyperparameter, Hyperparameter):
hyperparameter = self.objective_space.hyperparameter_index(hyperparameter)
elif isinstance(hyperparameter, str):
hyperparameter = self.objective_space.hyperparameter_index_by_name(hyperparameter)
pv = ccs_datum(value)
v = ccs_datum_fix()
v.value = pv._value.i
v.type = pv.type
res = ccs_features_evaluation_set_value(self.handle, hyperparameter, v)
Error.check(res)
def value(self, hyperparameter):
v = ccs_datum()
if isinstance(hyperparameter, Hyperparameter):
res = ccs_features_evaluation_get_value(self.handle, self.objective_space.hyperparameter_index(hyperparameter), ct.byref(v))
elif isinstance(hyperparameter, str):
res = ccs_features_evaluation_get_value_by_name(self.handle, str.encode(hyperparameter), ct.byref(v))
else:
res = ccs_features_evaluation_get_value(self.handle, hyperparameter, ct.byref(v))
Error.check(res)
return v.value
@property
def num_values(self):
if hasattr(self, "_num_values"):
return self._num_values
v = ct.c_size_t()
res = ccs_features_evaluation_get_values(self.handle, 0, None, ct.byref(v))
Error.check(res)
self._num_values = v.value
return self._num_values
@property
def values(self):
sz = self.num_values
if sz == 0:
return []
v = (ccs_datum * sz)()
res = ccs_features_evaluation_get_values(self.handle, sz, v, None)
Error.check(res)
return [x.value for x in v]
@property
def num_objective_values(self):
if hasattr(self, "_num_objective_values"):
......@@ -162,8 +106,12 @@ class FeaturesEvaluation(Object):
Error.check(res)
return [x.value for x in v]
def cmp(self, other):
def compare(self, other):
v = ccs_comparison(0)
res = ccs_features_evaluation_cmp(self.handle, other.handle, ct.byref(v))
res = ccs_features_evaluation_compare(self.handle, other.handle, ct.byref(v))
Error.check(res)
return v.value
def check(self):
res = res = ccs_features_evaluation(self.handle)
Error.check(res)
......@@ -84,7 +84,7 @@ class TestFeaturesTuner(unittest.TestCase):
if discard:
new_optimums.append(o)
else:
c = e.cmp(o)
c = e.compare(o)
if c == ccs.EQUIVALENT or c == ccs.WORSE:
discard = True
new_optimums.append(o)
......
......@@ -83,14 +83,14 @@ module CCS
end
def compare(other)
ptr = MemoryPointer::new(:ccs_objective_type_t)
ptr = MemoryPointer::new(:ccs_comparison_t)
res = CCS.ccs_evaluation_compare(@handle, other, ptr)
CCS.error_check(res)
ptr.read_ccs_comparison_t
end
def <=>(other)
ptr = MemoryPointer::new(:ccs_objective_type_t)
ptr = MemoryPointer::new(:ccs_comparison_t)
res = CCS.ccs_evaluation_compare(@handle, other, ptr)
CCS.error_check(res)
r = ptr.read_int32
......
module CCS
attach_function :ccs_create_features, [:ccs_features_space_t, :size_t, :pointer, :pointer, :pointer], :ccs_result_t
attach_function :ccs_features_get_features_space, [:ccs_features_t, :pointer], :ccs_result_t
attach_function :ccs_features_get_user_data, [:ccs_features_t, :pointer], :ccs_result_t
attach_function :ccs_features_get_value, [:ccs_features_t, :size_t, :pointer], :ccs_result_t
attach_function :ccs_features_set_value, [:ccs_features_t, :size_t, :ccs_datum_t], :ccs_result_t
attach_function :ccs_features_get_values, [:ccs_features_t, :size_t, :pointer, :pointer], :ccs_result_t
attach_function :ccs_features_get_value_by_name, [:ccs_features_t, :string, :pointer], :ccs_result_t
attach_function :ccs_features_check, [:ccs_features_t], :ccs_result_t
attach_function :ccs_features_hash, [:ccs_features_t, :pointer], :ccs_result_t
attach_function :ccs_features_cmp, [:ccs_features_t, :ccs_features_t, :pointer], :ccs_result_t
class Features < Object
class Features < Binding
alias features_space context
include Comparable
add_property :user_data, :pointer, :ccs_features_get_user_data, memoize: true
add_property :hash, :ccs_hash_t, :ccs_features_hash, memoize: false
add_handle_property :features_space, :ccs_features_space_t, :ccs_features_get_features_space, memoize: true
def initialize(handle = nil, retain: false, auto_release: true,
features_space: nil, values: nil, user_data: nil)
......@@ -42,71 +32,12 @@ module CCS
self::new(handle, retain: retain, auto_release: auto_release)
end
def set_value(hyperparameter, value)
d = Datum.from_value(value)
case hyperparameter
when String, Symbol
hyperparameter = features_space.hyperparameter_index_by_name(hyperparameter)
when Hyperparameter
hyperparameter = features_space.hyperparameter_index(hyperparameter)
end
res = CCS.ccs_features_set_value(@handle, hyperparameter, d)
CCS.error_check(res)
self
end
def value(hyperparameter)
ptr = MemoryPointer::new(:ccs_datum_t)
case hyperparameter
when String
res = CCS.ccs_features_get_value_by_name(@handle, hyperparameter, ptr)
when Symbol
res = CCS.ccs_features_get_value_by_name(@handle, hyperparameter.inspect, ptr)
when Hyperparameter
res = CCS.ccs_features_get_value(@handle, features_space.hyperparameter_index(hyperparameter), ptr)
when Integer
res = CCS.ccs_features_get_value(@handle, hyperparameter, ptr)
else
raise CCSError, :CCS_INVALID_VALUE
end
CCS.error_check(res)
Datum::new(ptr).value
end
def num_values
@num_values ||= begin
ptr = MemoryPointer::new(:size_t)
res = CCS.ccs_features_get_values(@handle, 0, nil, ptr)
CCS.error_check(res)
ptr.read_size_t
end
end
def values
count = num_values
return [] if count == 0
values = MemoryPointer::new(:ccs_datum_t, count)
res = CCS.ccs_features_get_values(@handle, count, values, nil)
CCS.error_check(res)
count.times.collect { |i| Datum::new(values[i]).value }
end
def check
res = CCS.ccs_features_check(@handle)
CCS.error_check(res)
self
end
def <=>(other)
ptr = MemoryPointer::new(:int)
res = CCS.ccs_features_cmp(@handle, other, ptr)
CCS.error_check(res)
return ptr.read_int
end
def to_h
features_space.hyperparameters.collect(&:name).zip(values).to_h
end
end
end
module CCS
attach_function :ccs_create_features_evaluation, [:ccs_objective_space_t, :ccs_configuration_t, :ccs_features_t, :ccs_result_t, :size_t, :pointer, :pointer, :pointer], :ccs_result_t
attach_function :ccs_features_evaluation_get_objective_space, [:ccs_features_evaluation_t, :pointer], :ccs_result_t
attach_function :ccs_features_evaluation_get_configuration, [:ccs_features_evaluation_t, :pointer], :ccs_result_t
attach_function :ccs_features_evaluation_get_features, [:ccs_features_evaluation_t, :pointer], :ccs_result_t
attach_function :ccs_features_evaluation_get_user_data, [:ccs_features_evaluation_t, :pointer], :ccs_result_t
attach_function :ccs_features_evaluation_get_error, [:ccs_features_evaluation_t, :pointer], :ccs_result_t
attach_function :ccs_features_evaluation_set_error, [:ccs_features_evaluation_t, :ccs_result_t], :ccs_result_t
attach_function :ccs_features_evaluation_get_value, [:ccs_features_evaluation_t, :size_t, :pointer], :ccs_result_t
attach_function :ccs_features_evaluation_set_value, [:ccs_features_evaluation_t, :size_t, :ccs_datum_t], :ccs_result_t
attach_function :ccs_features_evaluation_get_values, [:ccs_features_evaluation_t, :size_t, :pointer, :pointer], :ccs_result_t
attach_function :ccs_features_evaluation_get_value_by_name, [:ccs_features_evaluation_t, :string, :ccs_datum_t], :ccs_result_t
attach_function :ccs_features_evaluation_get_objective_value, [:ccs_features_evaluation_t, :size_t, :pointer], :ccs_result_t
attach_function :ccs_features_evaluation_get_objective_values, [:ccs_features_evaluation_t, :size_t, :pointer, :pointer], :ccs_result_t
attach_function :ccs_features_evaluation_cmp, [:ccs_features_evaluation_t, :ccs_features_evaluation_t, :pointer], :ccs_result_t
class FeaturesEvaluation < Object
add_handle_property :objective_space, :ccs_objective_space_t, :ccs_features_evaluation_get_objective_space, memoize: true
attach_function :ccs_features_evaluation_compare, [:ccs_features_evaluation_t, :ccs_features_evaluation_t, :pointer], :ccs_result_t
attach_function :ccs_features_evaluation_check, [:ccs_features_evaluation_t], :ccs_result_t
class FeaturesEvaluation < Binding
alias objective_space context
add_handle_property :configuration, :ccs_configuration_t, :ccs_features_evaluation_get_configuration, memoize: true
add_handle_property :features, :ccs_features_t, :ccs_features_evaluation_get_features, memoize: true
add_property :user_data, :pointer, :ccs_features_evaluation_get_user_data, memoize: true
add_property :error, :ccs_result_t, :ccs_features_evaluation_get_error, memoize: false
def initialize(handle = nil, retain: false, auto_release: true,
......@@ -54,55 +48,6 @@ module CCS
err
end
def set_value(hyperparameter, value)
d = Datum.from_value(value)
case hyperparameter
when String, Symbol
hyperparameter = objective_space.hyperparameter_index_by_name(hyperparameter)
when Hyperparameter
hyperparameter = objective_space.hyperparameter_index(hyperparameter)
end
res = CCS.ccs_features_evaluation_set_value(@handle, hyperparameter, d)
CCS.error_check(res)
self
end
def value(hyperparameter)
ptr = MemoryPointer::new(:ccs_datum_t)
case hyperparameter
when String
res = CCS.ccs_features_evaluation_get_value_by_name(@handle, hyperparameter, ptr)
when Symbol
res = CCS.ccs_features_evaluation_get_value_by_name(@handle, hyperparameter.inspect, ptr)
when Hyperparameter
res = CCS.ccs_features_evaluation_get_value(@handle, objective_space.hyperparameter_index(hyperparameter), ptr)
when Integer
res = CCS.ccs_features_evaluation_get_value(@handle, hyperparameter, ptr)
else
raise CCSError, :CCS_INVALID_VALUE
end
CCS.error_check(res)
Datum::new(ptr).value
end
def num_values
@num_values ||= begin
ptr = MemoryPointer::new(:size_t)
res = CCS.ccs_features_evaluation_get_values(@handle, 0, nil, ptr)
CCS.error_check(res)
ptr.read_size_t
end
end
def values
count = num_values
return [] if count == 0
values = MemoryPointer::new(:ccs_datum_t, count)
res = CCS.ccs_features_evaluation_get_values(@handle, count, values, nil)
CCS.error_check(res)
count.times.collect { |i| Datum::new(values[i]).value }
end
def num_objective_values
@num_values ||= begin
ptr = MemoryPointer::new(:size_t)
......@@ -121,16 +66,22 @@ module CCS
count.times.collect { |i| Datum::new(values[i]).value }
end
def cmp(other)
ptr = MemoryPointer::new(:ccs_objective_type_t)
res = CCS.ccs_features_evaluation_cmp(@handle, other, ptr)
def check
res = CCS.ccs_features_evaluation_check(@handle)
CCS.error_check(res)
self
end
def compare(other)
ptr = MemoryPointer::new(:ccs_comparison_t)
res = CCS.ccs_features_evaluation_compare(@handle, other, ptr)
CCS.error_check(res)
ptr.read_ccs_comparison_t
end
def <=>(other)
ptr = MemoryPointer::new(:ccs_objective_type_t)
res = CCS.ccs_features_evaluation_cmp(@handle, other, ptr)
ptr = MemoryPointer::new(:ccs_comparison_t)
res = CCS.ccs_features_evaluation_compare(@handle, other, ptr)
CCS.error_check(res)
r = ptr.read_int32
r == 2 ? nil : r
......
......@@ -81,7 +81,7 @@ class CConfigSpaceTestFeaturesTuner < Minitest::Test
discard = false
optimums = optimums.collect { |o|
unless discard
case e.cmp(o)
case e.compare(o)
when :CCS_EQUIVALENT, :CCS_WORSE
discard = true
o
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment