Commit f9388fcb authored by Brice Videau's avatar Brice Videau

Refactoring context/configuration space/objective space.

parent 4f84993f
......@@ -6,19 +6,10 @@ from .expression import Expression
from .rng import Rng
ccs_create_configuration_space = _ccs_get_function("ccs_create_configuration_space", [ct.c_char_p, ct.c_void_p, ct.POINTER(ccs_configuration_space)])
ccs_configuration_space_get_name = _ccs_get_function("ccs_configuration_space_get_name", [ccs_configuration_space, ct.POINTER(ct.c_char_p)])
ccs_configuration_space_get_user_data = _ccs_get_function("ccs_configuration_space_get_user_data", [ccs_configuration_space, ct.POINTER(ct.c_void_p)])
ccs_configuration_space_set_rng = _ccs_get_function("ccs_configuration_space_set_rng", [ccs_configuration_space, ccs_rng])
ccs_configuration_space_get_rng = _ccs_get_function("ccs_configuration_space_get_rng", [ccs_configuration_space, ct.POINTER(ccs_rng)])
ccs_configuration_space_add_hyperparameter = _ccs_get_function("ccs_configuration_space_add_hyperparameter", [ccs_configuration_space, ccs_hyperparameter, ccs_distribution])
ccs_configuration_space_add_hyperparameters = _ccs_get_function("ccs_configuration_space_add_hyperparameters", [ccs_configuration_space, ct.c_size_t, ct.POINTER(ccs_hyperparameter), ct.POINTER(ccs_distribution)])
ccs_configuration_space_get_num_hyperparameters = _ccs_get_function("ccs_configuration_space_get_num_hyperparameters", [ccs_configuration_space, ct.POINTER(ct.c_size_t)])
ccs_configuration_space_get_hyperparameter = _ccs_get_function("ccs_configuration_space_get_hyperparameter", [ccs_configuration_space, ct.c_size_t, ct.POINTER(ccs_hyperparameter)])
ccs_configuration_space_get_hyperparameter_by_name = _ccs_get_function("ccs_configuration_space_get_hyperparameter_by_name", [ccs_configuration_space, ct.c_char_p, ct.POINTER(ccs_hyperparameter)])
ccs_configuration_space_get_hyperparameter_index_by_name = _ccs_get_function("ccs_configuration_space_get_hyperparameter_index_by_name", [ccs_configuration_space, ct.c_char_p, ct.POINTER(ct.c_size_t)])
ccs_configuration_space_get_hyperparameter_index = _ccs_get_function("ccs_configuration_space_get_hyperparameter_index", [ccs_configuration_space, ccs_hyperparameter, ct.POINTER(ct.c_size_t)])
ccs_configuration_space_get_hyperparameter_indexes = _ccs_get_function("ccs_configuration_space_get_hyperparameter_indexes", [ccs_configuration_space, ct.c_size_t, ct.POINTER(ccs_hyperparameter), ct.POINTER(ct.c_size_t)])
ccs_configuration_space_get_hyperparameters = _ccs_get_function("ccs_configuration_space_get_hyperparameters", [ccs_configuration_space, ct.c_size_t, ct.POINTER(ccs_hyperparameter), ct.POINTER(ct.c_size_t)])
ccs_configuration_space_set_condition = _ccs_get_function("ccs_configuration_space_set_condition", [ccs_configuration_space, ct.c_size_t, ccs_expression])
ccs_configuration_space_get_condition = _ccs_get_function("ccs_configuration_space_get_condition", [ccs_configuration_space, ct.c_size_t, ct.POINTER(ccs_expression)])
ccs_configuration_space_get_conditions = _ccs_get_function("ccs_configuration_space_get_conditions", [ccs_configuration_space, ct.c_size_t, ct.POINTER(ccs_expression), ct.POINTER(ct.c_size_t)])
......@@ -46,26 +37,6 @@ class ConfigurationSpace(Context):
def from_handle(cls, handle):
return cls(handle = handle, retain = True)
@property
def user_data(self):
if hasattr(self, "_user_data"):
return self._user_data
v = ct.c_void_p()
res = ccs_configuration_space_get_user_data(self.handle, ct.byref(v))
Error.check(res)
self._user_data = v
return v
@property
def name(self):
if hasattr(self, "_name"):
return self._name
v = ct.c_char_p()
res = ccs_configuration_space_get_name(self.handle, ct.byref(v))
Error.check(res)
self._name = v.value.decode()
return self._name
@property
def rng(self):
v = ccs_rng()
......@@ -98,47 +69,6 @@ class ConfigurationSpace(Context):
res = ccs_configuration_space_add_hyperparameters(self.handle, count, hypers, distribs)
Error.check(res)
def hyperparameter(self, index):
v = ccs_hyperparameter()
res = ccs_configuration_space_get_hyperparameter(self.handle, index, ct.byref(v))
Error.check(res)
return Hyperparameter.from_handle(v)
def hyperparameter_by_name(self, name):
v = ccs_hyperparameter()
res = ccs_configuration_space_get_hyperparameter_by_name(self.handle, str.encode(name), ct.byref(v))
Error.check(res)
return Hyperparameter.from_handle(v)
def hyperparameter_index(self, hyperparameter):
v = ct.c_size_t()
res = ccs_configuration_space_get_hyperparameter_index(self.handle, hyperparameter.handle, ct.byref(v))
Error.check(res)
return v.value
def hyperparameter_index_by_name(self, name):
v = ct.c_size_t()
res = ccs_configuration_space_get_hyperparameter_index_by_name(self.handle, str.encode(name), ct.byref(v))
Error.check(res)
return v.value
@property
def num_hyperparameters(self):
v = ct.c_size_t(0)
res = ccs_configuration_space_get_num_hyperparameters(self.handle, ct.byref(v))
Error.check(res)
return v.value
@property
def hyperparameters(self):
count = self.num_hyperparameters
if count == 0:
return []
v = (ccs_hyperparameter * count)()
res = ccs_configuration_space_get_hyperparameters(self.handle, count, v, None)
Error.check(res)
return [Hyperparameter.from_handle(ccs_hyperparameter(x)) for x in v]
def set_condition(self, hyperparameter, expression):
if isinstance(hyperparameter, Hyperparameter):
hyperparameter = self.hyperparameter_index(hyperparameter)
......
import ctypes as ct
from .base import Object, Error, ccs_error, _ccs_get_function, ccs_context, ccs_hyperparameter
from .hyperparameter import Hyperparameter
ccs_context_get_name = _ccs_get_function("ccs_context_get_name", [ccs_context, ct.POINTER(ct.c_char_p)])
ccs_context_get_user_data = _ccs_get_function("ccs_context_get_user_data", [ccs_context, ct.POINTER(ct.c_void_p)])
ccs_context_get_num_hyperparameters = _ccs_get_function("ccs_context_get_num_hyperparameters", [ccs_context, ct.POINTER(ct.c_size_t)])
ccs_context_get_hyperparameter = _ccs_get_function("ccs_context_get_hyperparameter", [ccs_context, ct.c_size_t, ct.POINTER(ccs_hyperparameter)])
ccs_context_get_hyperparameter_by_name = _ccs_get_function("ccs_context_get_hyperparameter_by_name", [ccs_context, ct.c_char_p, ct.POINTER(ccs_hyperparameter)])
ccs_context_get_hyperparameter_index_by_name = _ccs_get_function("ccs_context_get_hyperparameter_index_by_name", [ccs_context, ct.c_char_p, ct.POINTER(ct.c_size_t)])
ccs_context_get_hyperparameter_index = _ccs_get_function("ccs_context_get_hyperparameter_index", [ccs_context, ccs_hyperparameter, ct.POINTER(ct.c_size_t)])
ccs_context_get_hyperparameter_indexes = _ccs_get_function("ccs_context_get_hyperparameter_indexes", [ccs_context, ct.c_size_t, ct.POINTER(ccs_hyperparameter), ct.POINTER(ct.c_size_t)])
ccs_context_get_hyperparameters = _ccs_get_function("ccs_context_get_hyperparameters", [ccs_context, ct.c_size_t, ct.POINTER(ccs_hyperparameter), ct.POINTER(ct.c_size_t)])
class Context(Object):
@property
def user_data(self):
if hasattr(self, "_user_data"):
return self._user_data
v = ct.c_void_p()
res = ccs_context_get_user_data(self.handle, ct.byref(v))
Error.check(res)
self._user_data = v
return v
@property
def name(self):
if hasattr(self, "_name"):
return self._name
v = ct.c_char_p()
res = ccs_context_get_name(self.handle, ct.byref(v))
Error.check(res)
self._name = v.value.decode()
return self._name
def hyperparameter(self, index):
v = ccs_hyperparameter()
res = ccs_context_get_hyperparameter(self.handle, index, ct.byref(v))
Error.check(res)
return Hyperparameter.from_handle(v)
def hyperparameter_by_name(self, name):
v = ccs_hyperparameter()
res = ccs_context_get_hyperparameter_by_name(self.handle, str.encode(name), ct.byref(v))
Error.check(res)
return Hyperparameter.from_handle(v)
def hyperparameter_index(self, hyperparameter):
v = ct.c_sizeof_t()
v = ct.c_size_t()
res = ccs_context_get_hyperparameter_index(self.handle, hyperparameter.handle, ct.byref(v))
Error.check(res)
return v.value
def hyperparameter_index_by_name(self, name):
v = ct.c_size_t()
res = ccs_context_get_hyperparameter_index_by_name(self.handle, str.encode(name), ct.byref(v))
Error.check(res)
return v.value
@property
def num_hyperparameters(self):
v = ct.c_size_t(0)
res = ccs_context_get_num_hyperparameters(self.handle, ct.byref(v))
Error.check(res)
return v.value
@property
def hyperparameters(self):
count = self.num_hyperparameters
if count == 0:
return []
v = (ccs_hyperparameter * count)()
res = ccs_context_get_hyperparameters(self.handle, count, v, None)
Error.check(res)
return [Hyperparameter.from_handle(ccs_hyperparameter(x)) for x in v]
......@@ -12,16 +12,8 @@ class ccs_objective_type(CEnumeration):
'MAXIMIZE' ]
ccs_create_objective_space = _ccs_get_function("ccs_create_objective_space", [ct.c_char_p, ct.c_void_p, ct.POINTER(ccs_objective_space)])
ccs_objective_space_get_name = _ccs_get_function("ccs_objective_space_get_name", [ccs_objective_space, ct.POINTER(ct.c_char_p)])
ccs_objective_space_get_user_data = _ccs_get_function("ccs_objective_space_get_user_data", [ccs_objective_space, ct.POINTER(ct.c_void_p)])
ccs_objective_space_add_hyperparameter = _ccs_get_function("ccs_objective_space_add_hyperparameter", [ccs_objective_space, ccs_hyperparameter])
ccs_objective_space_add_hyperparameters = _ccs_get_function("ccs_objective_space_add_hyperparameters", [ccs_objective_space, ct.c_size_t, ct.POINTER(ccs_hyperparameter)])
ccs_objective_space_get_num_hyperparameters = _ccs_get_function("ccs_objective_space_get_num_hyperparameters", [ccs_objective_space, ct.POINTER(ct.c_size_t,)])
ccs_objective_space_get_hyperparameter = _ccs_get_function("ccs_objective_space_get_hyperparameter", [ccs_objective_space, ct.c_size_t, ct.POINTER(ccs_hyperparameter)])
ccs_objective_space_get_hyperparameter_by_name = _ccs_get_function("ccs_objective_space_get_hyperparameter_by_name", [ccs_objective_space, ct.c_char_p, ct.POINTER(ccs_hyperparameter)])
ccs_objective_space_get_hyperparameter_index_by_name = _ccs_get_function("ccs_objective_space_get_hyperparameter_index_by_name", [ccs_objective_space, ct.c_char_p, ct.POINTER(ct.c_size_t)])
ccs_objective_space_get_hyperparameter_index = _ccs_get_function("ccs_objective_space_get_hyperparameter_index", [ccs_objective_space, ccs_hyperparameter, ct.POINTER(ct.c_size_t)])
ccs_objective_space_get_hyperparameters = _ccs_get_function("ccs_objective_space_get_hyperparameters", [ccs_objective_space, ct.c_size_t, ct.POINTER(ccs_hyperparameter), ct.POINTER(ct.c_size_t)])
ccs_objective_space_add_objective = _ccs_get_function("ccs_objective_space_add_objective", [ccs_objective_space, ccs_expression, ccs_objective_type])
ccs_objective_space_add_objectives = _ccs_get_function("ccs_objective_space_add_objectives", [ccs_objective_space, ct.c_size_t, ct.POINTER(ccs_expression), ct.POINTER(ccs_objective_type)])
ccs_objective_space_get_objective = _ccs_get_function("ccs_objective_space_get_objective", [ccs_objective_space, ct.c_size_t, ct.POINTER(ccs_expression), ct.POINTER(ccs_objective_type)])
......@@ -41,26 +33,6 @@ class ObjectiveSpace(Context):
def from_handle(cls, handle):
return cls(handle = handle, retain = True)
@property
def user_data(self):
if hasattr(self, "_user_data"):
return self._user_data
v = ct.c_void_p()
res = ccs_objective_space_get_user_data(self.handle, ct.byref(v))
Error.check(res)
self._user_data = v
return v
@property
def name(self):
if hasattr(self, "_name"):
return self._name
v = ct.c_char_p()
res = ccs_objective_space_get_name(self.handle, ct.byref(v))
Error.check(res)
self._name = v.value.decode()
return self._name
def add_hyperparameter(self, hyperparameter):
res = ccs_objective_space_add_hyperparameter(self.handle, hyperparameter.handle)
Error.check(res)
......@@ -73,47 +45,6 @@ class ObjectiveSpace(Context):
res = ccs_objective_space_add_hyperparameters(self.handle, count, hypers)
Error.check(res)
def hyperparameter(self, index):
v = ccs_hyperparameter()
res = ccs_objective_space_get_hyperparameter(self.handle, index, ct.byref(v))
Error.check(res)
return Hyperparameter.from_handle(v)
def hyperparameter_by_name(self, name):
v = ccs_hyperparameter()
res = ccs_objective_space_get_hyperparameter_by_name(self.handle, str.encode(name), ct.byref(v))
Error.check(res)
return Hyperparameter.from_handle(v)
def hyperparameter_index(self, hyperparameter):
v = ct.c_size_t()
res = ccs_objective_space_get_hyperparameter_index(self.handle, hyperparameter.handle, ct.byref(v))
Error.check(res)
return v.value
def hyperparameter_index_by_name(self, name):
v = ct.c_size_t()
res = ccs_objective_space_get_hyperparameter_index_by_name(self.handle, str.encode(name), ct.byref(v))
Error.check(res)
return v.value
@property
def num_hyperparameters(self):
v = ct.c_size_t(0)
res = ccs_objective_space_get_num_hyperparameters(self.handle, ct.byref(v))
Error.check(res)
return v.value
@property
def hyperparameters(self):
count = self.num_hyperparameters
if count == 0:
return []
v = (ccs_hyperparameter * count)()
res = ccs_objective_space_get_hyperparameters(self.handle, count, v, None)
Error.check(res)
return [Hyperparameter.from_handle(ccs_hyperparameter(x)) for x in v]
def add_objective(self, expression, t = ccs_objective_type.MINIMIZE):
res = ccs_objective_space_add_objective(self.handle, expression.handle, t)
Error.check(res)
......
module CCS
attach_function :ccs_create_configuration_space, [:string, :pointer, :pointer], :ccs_result_t
attach_function :ccs_configuration_space_get_name, [:ccs_configuration_space_t, :pointer], :ccs_result_t
attach_function :ccs_configuration_space_get_user_data, [:ccs_configuration_space_t, :pointer], :ccs_result_t
attach_function :ccs_configuration_space_set_rng, [:ccs_configuration_space_t, :ccs_rng_t], :ccs_result_t
attach_function :ccs_configuration_space_get_rng, [:ccs_configuration_space_t, :pointer], :ccs_result_t
attach_function :ccs_configuration_space_add_hyperparameter, [:ccs_configuration_space_t, :ccs_hyperparameter_t, :ccs_distribution_t], :ccs_result_t
attach_function :ccs_configuration_space_add_hyperparameters, [:ccs_configuration_space_t, :size_t, :pointer, :pointer], :ccs_result_t
attach_function :ccs_configuration_space_get_num_hyperparameters, [:ccs_configuration_space_t, :pointer], :ccs_result_t
attach_function :ccs_configuration_space_get_hyperparameter, [:ccs_configuration_space_t, :size_t, :pointer], :ccs_result_t
attach_function :ccs_configuration_space_get_hyperparameter_by_name, [:ccs_configuration_space_t, :string, :pointer], :ccs_result_t
attach_function :ccs_configuration_space_get_hyperparameter_index_by_name, [:ccs_configuration_space_t, :string, :pointer], :ccs_result_t
attach_function :ccs_configuration_space_get_hyperparameter_index, [:ccs_configuration_space_t, :ccs_hyperparameter_t, :pointer], :ccs_result_t
attach_function :ccs_configuration_space_get_hyperparameter_indexes, [:ccs_configuration_space_t, :size_t, :pointer, :pointer], :ccs_result_t
attach_function :ccs_configuration_space_get_hyperparameters, [:ccs_configuration_space_t, :size_t, :pointer, :pointer], :ccs_result_t
attach_function :ccs_configuration_space_set_condition, [:ccs_configuration_space_t, :size_t, :ccs_expression_t], :ccs_result_t
attach_function :ccs_configuration_space_get_condition, [:ccs_configuration_space_t, :size_t, :pointer], :ccs_result_t
attach_function :ccs_configuration_space_get_conditions, [:ccs_configuration_space_t, :size_t, :pointer, :pointer], :ccs_result_t
......@@ -27,8 +18,6 @@ module CCS
attach_function :ccs_configuration_space_samples, [:ccs_configuration_space_t, :size_t, :pointer], :ccs_result_t
class ConfigurationSpace < Context
add_property :user_data, :pointer, :ccs_configuration_space_get_user_data, memoize: true
add_property :num_hyperparameters, :size_t, :ccs_configuration_space_get_num_hyperparameters, memoize: false
def initialize(handle = nil, retain: false, name: "", user_data: nil)
if handle
......@@ -45,15 +34,6 @@ module CCS
self::new(handle, retain: true)
end
def name
@name ||= begin
ptr = MemoryPointer::new(:pointer)
res = CCS.ccs_configuration_space_get_name(@handle, ptr)
CCS.error_check(res)
ptr.read_pointer.read_string
end
end
def rng
ptr = MemoryPointer::new(:ccs_rng_t)
res = CCS.ccs_configuration_space_get_rng(@handle, ptr)
......@@ -90,43 +70,6 @@ module CCS
self
end
def hyperparameter(index)
ptr = MemoryPointer::new(:ccs_hyperparameter_t)
res = CCS.ccs_configuration_space_get_hyperparameter(@handle, index, ptr)
CCS.error_check(res)
Hyperparameter.from_handle(ptr.read_ccs_hyperparameter_t)
end
def hyperparameter_by_name(name)
ptr = MemoryPointer::new(:ccs_hyperparameter_t)
res = CCS.ccs_configuration_space_get_hyperparameter_by_name(@handle, name, ptr)
CCS.error_check(res)
Hyperparameter.from_handle(ptr.read_ccs_hyperparameter_t)
end
def hyperparameter_index(hyperparameter)
ptr = MemoryPointer::new(:size_t)
res = CCS.ccs_configuration_space_get_hyperparameter_index(@handle, hyperparameter, ptr)
CCS.error_check(res)
ptr.read_size_t
end
def hyperparameter_index_by_name(name)
ptr = MemoryPointer::new(:size_t)
res = CCS.ccs_configuration_space_get_hyperparameter_index_by_name(@handle, name, ptr)
CCS.error_check(res)
ptr.read_size_t
end
def hyperparameters
count = num_hyperparameters
return [] if count == 0
ptr = MemoryPointer::new(:ccs_hyperparameter_t, count)
res = CCS.ccs_configuration_space_get_hyperparameters(@handle, count, ptr, nil)
CCS.error_check(res)
count.times.collect { |i| Hyperparameter.from_handle(ptr[i].read_pointer) }
end
def set_condition(hyperparameter, expression)
if expression.kind_of? String
expression = ExpressionParser::new(self).parse(expression)
......
module CCS
attach_function :ccs_context_get_name, [:ccs_context_t, :pointer], :ccs_result_t
attach_function :ccs_context_get_user_data, [:ccs_context_t, :pointer], :ccs_result_t
attach_function :ccs_context_get_num_hyperparameters, [:ccs_context_t, :pointer], :ccs_result_t
attach_function :ccs_context_get_hyperparameter, [:ccs_context_t, :size_t, :pointer], :ccs_result_t
attach_function :ccs_context_get_hyperparameter_by_name, [:ccs_context_t, :string, :pointer], :ccs_result_t
attach_function :ccs_context_get_hyperparameter_index_by_name, [:ccs_context_t, :string, :pointer], :ccs_result_t
attach_function :ccs_context_get_hyperparameter_index, [:ccs_context_t, :ccs_hyperparameter_t, :pointer], :ccs_result_t
attach_function :ccs_context_get_hyperparameter_indexes, [:ccs_context_t, :size_t, :pointer, :pointer], :ccs_result_t
attach_function :ccs_context_get_hyperparameters, [:ccs_context_t, :size_t, :pointer, :pointer], :ccs_result_t
class Context < Object
add_property :user_data, :pointer, :ccs_context_get_user_data, memoize: true
add_property :num_hyperparameters, :size_t, :ccs_context_get_num_hyperparameters, memoize: false
def name
@name ||= begin
ptr = MemoryPointer::new(:pointer)
res = CCS.ccs_context_get_name(@handle, ptr)
CCS.error_check(res)
ptr.read_pointer.read_string
end
end
def hyperparameter(index)
ptr = MemoryPointer::new(:ccs_hyperparameter_t)
res = CCS.ccs_context_get_hyperparameter(@handle, index, ptr)
CCS.error_check(res)
Hyperparameter.from_handle(ptr.read_ccs_hyperparameter_t)
end
def hyperparameter_by_name(name)
ptr = MemoryPointer::new(:ccs_hyperparameter_t)
res = CCS.ccs_context_get_hyperparameter_by_name(@handle, name, ptr)
CCS.error_check(res)
Hyperparameter.from_handle(ptr.read_ccs_hyperparameter_t)
end
def hyperparameter_index_by_name(name)
ptr = MemoryPointer::new(:size_t)
res = CCS.ccs_context_get_hyperparameter_index_by_name(@handle, name, ptr)
CCS.error_check(res)
ptr.read_size_t
end
def hyperparameter_index(hyperparameter)
ptr = MemoryPointer::new(:size_t)
res = CCS.ccs_context_get_hyperparameter_index(@handle, hyperparameter, ptr)
CCS.error_check(res)
ptr.read_size_t
end
def hyperparameters
count = num_hyperparameters
return [] if count == 0
ptr = MemoryPointer::new(:ccs_hyperparameter_t, count)
res = CCS.ccs_context_get_hyperparameters(@handle, count, ptr, nil)
CCS.error_check(res)
count.times.collect { |i| Hyperparameter.from_handle(ptr[i].read_pointer) }
end
end
end
......@@ -19,24 +19,14 @@ module CCS
end
attach_function :ccs_create_objective_space, [:string, :pointer, :pointer], :ccs_result_t
attach_function :ccs_objective_space_get_name, [:ccs_objective_space_t, :pointer], :ccs_result_t
attach_function :ccs_objective_space_get_user_data, [:ccs_objective_space_t, :pointer], :ccs_result_t
attach_function :ccs_objective_space_add_hyperparameter, [:ccs_objective_space_t, :ccs_hyperparameter_t], :ccs_result_t
attach_function :ccs_objective_space_add_hyperparameters, [:ccs_objective_space_t, :size_t, :pointer], :ccs_result_t
attach_function :ccs_objective_space_get_num_hyperparameters, [:ccs_objective_space_t, :pointer], :ccs_result_t
attach_function :ccs_objective_space_get_hyperparameter, [:ccs_objective_space_t, :size_t, :pointer], :ccs_result_t
attach_function :ccs_objective_space_get_hyperparameter_by_name, [:ccs_objective_space_t, :string, :pointer], :ccs_result_t
attach_function :ccs_objective_space_get_hyperparameter_index_by_name, [:ccs_objective_space_t, :string, :pointer], :ccs_result_t
attach_function :ccs_objective_space_get_hyperparameter_index, [:ccs_objective_space_t, :ccs_hyperparameter_t, :pointer], :ccs_result_t
attach_function :ccs_objective_space_get_hyperparameters, [:ccs_objective_space_t, :size_t, :pointer, :pointer], :ccs_result_t
attach_function :ccs_objective_space_add_objective, [:ccs_objective_space_t, :ccs_expression_t, :ccs_objective_type_t], :ccs_result_t
attach_function :ccs_objective_space_add_objectives, [:ccs_objective_space_t, :size_t, :pointer, :pointer], :ccs_result_t
attach_function :ccs_objective_space_get_objective, [:ccs_objective_space_t, :size_t, :pointer, :pointer], :ccs_result_t
attach_function :ccs_objective_space_get_objectives, [:ccs_objective_space_t, :size_t, :pointer, :pointer, :pointer], :ccs_result_t
class ObjectiveSpace < Object
add_property :user_data, :pointer, :ccs_objective_space_get_user_data, memoize: true
add_property :num_hyperparameters, :size_t, :ccs_objective_space_get_num_hyperparameters, memoize: false
class ObjectiveSpace < Context
def initialize(handle = nil, retain: false, name: "", user_data: nil)
if handle
......@@ -53,15 +43,6 @@ module CCS
self::new(handle, retain: true)
end
def name
@name ||= begin
ptr = MemoryPointer::new(:pointer)
res = CCS.ccs_objective_space_get_name(@handle, ptr)
CCS.error_check(res)
ptr.read_pointer.read_string
end
end
def add_hyperparameter(hyperparameter)
res = CCS.ccs_objective_space_add_hyperparameter(@handle, hyperparameter)
CCS.error_check(res)
......@@ -78,43 +59,6 @@ module CCS
self
end
def hyperparameter(index)
ptr = MemoryPointer::new(:ccs_hyperparameter_t)
res = CCS.ccs_objective_space_get_hyperparameter(@handle, index, ptr)
CCS.error_check(res)
Hyperparameter.from_handle(ptr.read_ccs_hyperparameter_t)
end
def hyperparameter_by_name(name)
ptr = MemoryPointer::new(:ccs_hyperparameter_t)
res = CCS.ccs_objective_space_get_hyperparameter_by_name(@handle, name, ptr)
CCS.error_check(res)
Hyperparameter.from_handle(ptr.read_ccs_hyperparameter_t)
end
def hyperparameter_index(hyperparameter)
ptr = MemoryPointer::new(:size_t)
res = CCS.ccs_objective_space_get_hyperparameter_index(@handle, hyperparameter, ptr)
CCS.error_check(res)
ptr.read_size_t
end
def hyperparameter_index_by_name(name)
ptr = MemoryPointer::new(:size_t)
res = CCS.ccs_objective_space_get_hyperparameter_index_by_name(@handle, name, ptr)
CCS.error_check(res)
ptr.read_size_t
end
def hyperparameters
count = num_hyperparameters
return [] if count == 0
ptr = MemoryPointer::new(:ccs_hyperparameter_t, count)
res = CCS.ccs_objective_space_get_hyperparameters(@handle, count, ptr, nil)
CCS.error_check(res)
count.times.collect { |i| Hyperparameter.from_handle(ptr[i].read_pointer) }
end
def add_objective(expression, type: :CCS_MINIMIZE)
if expression.kind_of? String
expression = ExpressionParser::new(self).parse(expression)
......
......@@ -5,11 +5,51 @@
extern "C" {
#endif
extern ccs_result_t
ccs_context_get_name(ccs_context_t context,
const char **name_ret);
extern ccs_result_t
ccs_context_get_user_data(ccs_context_t context,
void **user_data_ret);
extern ccs_result_t
ccs_context_get_hyperparameter_index(ccs_context_t context,
ccs_hyperparameter_t hyperparameter,
size_t *index_ret);
extern ccs_result_t
ccs_context_get_num_hyperparameters(ccs_context_t context,
size_t *num_hyperparameters_ret);
extern ccs_result_t
ccs_context_get_hyperparameter(ccs_context_t context,
size_t index,
ccs_hyperparameter_t *hyperparameter_ret);
extern ccs_result_t
ccs_context_get_hyperparameter_by_name(ccs_context_t context,
const char * name,
ccs_hyperparameter_t *hyperparameter_ret);
extern ccs_result_t
ccs_context_get_hyperparameter_index_by_name(ccs_context_t context,
const char *name,
size_t *index_ret);
extern ccs_result_t
ccs_context_get_hyperparameters(ccs_context_t context,
size_t num_hyperparameters,
ccs_hyperparameter_t *hyperparameters,
size_t *num_hyperparameters_ret);
extern ccs_result_t
ccs_context_get_hyperparameter_indexes(
ccs_context_t context,
size_t num_hyperparameters,
ccs_hyperparameter_t *hyperparameters,
size_t *indexes);
#ifdef __cplusplus
}
#endif
......
This diff is collapsed.
#ifndef _CONFIGURATION_SPACE_INTERNAL_H
#define _CONFIGURATION_SPACE_INTERNAL_H
#include "utarray.h"
#include "context_internal.h"
#define HASH_NONFATAL_OOM 1
#include "uthash.h"
struct _ccs_distribution_wrapper_s;
typedef struct _ccs_distribution_wrapper_s _ccs_distribution_wrapper_t;
struct _ccs_hyperparameter_wrapper_s {
struct _ccs_hyperparameter_wrapper_cs_s {
ccs_hyperparameter_t hyperparameter;
size_t index;
const char *name;
......@@ -20,7 +17,7 @@ struct _ccs_hyperparameter_wrapper_s {
UT_array *parents;
UT_array *children;
};
typedef struct _ccs_hyperparameter_wrapper_s _ccs_hyperparameter_wrapper_t;
typedef struct _ccs_hyperparameter_wrapper_cs_s _ccs_hyperparameter_wrapper_cs_t;
struct _ccs_distribution_wrapper_s {
ccs_distribution_t distribution;
......@@ -44,16 +41,16 @@ struct _ccs_configuration_space_s {
};
struct _ccs_configuration_space_data_s {
const char *name;
void *user_data;
ccs_rng_t rng;
UT_array *hyperparameters;
_ccs_hyperparameter_wrapper_t *name_hash;
_ccs_hyperparameter_wrapper_t *handle_hash;
_ccs_distribution_wrapper_t *distribution_list;
UT_array *forbidden_clauses;
ccs_bool_t graph_ok;
UT_array *sorted_indexes;
const char *name;
void *user_data;
UT_array *hyperparameters;
_ccs_hyperparameter_wrapper_cs_t *name_hash;
_ccs_hyperparameter_wrapper_cs_t *handle_hash;
ccs_rng_t rng;
_ccs_distribution_wrapper_t *distribution_list;
UT_array *forbidden_clauses;
ccs_bool_t graph_ok;
UT_array *sorted_indexes;
};
#endif //_CONFIGURATION_SPACE_INTERNAL_H
......@@ -13,12 +13,85 @@ ccs_context_get_hyperparameter_index(
size_t *index_ret) {
if (!context || !context->data)
return -CCS_INVALID_OBJECT;
if (!hyperparameter)
return -CCS_INVALID_HYPERPARAMETER;
if (!index_ret)
return -CCS_INVALID_VALUE;
_ccs_context_ops_t *ops = ccs_context_get_ops(context);
return ops->get_hyperparameter_index(context->data, hyperparameter, index_ret);
return _ccs_context_get_hyperparameter_index(
context, hyperparameter, index_ret);
}
ccs_result_t
ccs_context_get_num_hyperparameters(
ccs_context_t context,
size_t *num_hyperparameters_ret) {
if (!context || !context->data)
return -CCS_INVALID_OBJECT;
return _ccs_context_get_num_hyperparameters(context, num_hyperparameters_ret);
}
ccs_result_t
ccs_context_get_hyperparameter(
ccs_context_t context,
size_t index,
ccs_hyperparameter_t *hyperparameter_ret) {
if (!context || !context->data)
return -CCS_INVALID_OBJECT;
return _ccs_context_get_hyperparameter(context, index, hyperparameter_ret);
}
ccs_result_t
ccs_context_get_hyperparameter_by_name(
ccs_context_t context,
const char * name,
ccs_hyperparameter_t *hyperparameter_ret) {
if (!context || !context->data)
return -CCS_INVALID_OBJECT;
return _ccs_context_get_hyperparameter_by_name(context, name, hyperparameter_ret);
}
ccs_result_t
ccs_context_get_hyperparameter_index_by_name(
ccs_context_t context,
const char *name,
size_t *index_ret) {
if (!context || !context->data)
return -CCS_INVALID_OBJECT;
return _ccs_context_get_hyperparameter_index_by_name(context, name, index_ret);
}
ccs_result_t
ccs_context_get_hyperparameters(
ccs_context_t context,
size_t num_hyperparameters,
ccs_hyperparameter_t *hyperparameters,
size_t *num_hyperparameters_ret) {
if (!context || !context->data)
return -CCS_INVALID_OBJECT;
return _ccs_context_get_hyperparameters(context, num_hyperparameters,
hyperparameters, num_hyperparameters_ret);
}
ccs_result_t
ccs_context_get_hyperparameter_indexes(
ccs_context_t context,
size_t num_hyperparameters,
ccs_hyperparameter_t *hyperparameters,