Commit 01498aaf authored by Brice Videau's avatar Brice Videau
Browse files

Added user defined tuner.

parent b04251e1
import ctypes as ct import ctypes as ct
from .base import Object, Error, CEnumeration, ccs_error, ccs_result, _ccs_get_function, ccs_context, ccs_hyperparameter, ccs_configuration_space, ccs_configuration, ccs_datum, ccs_objective_space, ccs_evaluation, ccs_tuner from .base import Object, Error, CEnumeration, ccs_error, ccs_result, _ccs_get_function, ccs_context, ccs_hyperparameter, ccs_configuration_space, ccs_configuration, ccs_datum, ccs_objective_space, ccs_evaluation, ccs_tuner, ccs_retain_object
from .context import Context from .context import Context
from .hyperparameter import Hyperparameter from .hyperparameter import Hyperparameter
from .configuration_space import ConfigurationSpace from .configuration_space import ConfigurationSpace
...@@ -142,3 +142,145 @@ class RandomTuner(Tuner): ...@@ -142,3 +142,145 @@ class RandomTuner(Tuner):
else: else:
super().__init__(handle = handle, retain = retain) super().__init__(handle = handle, retain = retain)
class ccs_tuner_common_data(ct.Structure):
_fields_ = [
('type', ccs_tuner_type),
('name', ct.c_char_p),
('user_data', ct.c_void_p),
('configuration_space', ccs_configuration_space),
('objective_space', ccs_objective_space) ]
ccs_user_defined_tuner_del_type = ct.CFUNCTYPE(ccs_result, ct.c_void_p)
ccs_user_defined_tuner_ask_type = ct.CFUNCTYPE(ccs_result, ct.c_void_p, ct.c_size_t, ct.POINTER(ccs_configuration), ct.POINTER(ct.c_size_t))
ccs_user_defined_tuner_tell_type = ct.CFUNCTYPE(ccs_result, ct.c_void_p, ct.c_size_t, ct.POINTER(ccs_evaluation))
ccs_user_defined_tuner_get_optimums_type = ct.CFUNCTYPE(ccs_result, ct.c_void_p, ct.c_size_t, ct.POINTER(ccs_evaluation), ct.POINTER(ct.c_size_t))
ccs_user_defined_tuner_get_history_type = ct.CFUNCTYPE(ccs_result, ct.c_void_p, ct.c_size_t, ct.POINTER(ccs_evaluation), ct.POINTER(ct.c_size_t))
class ccs_user_defined_tuner_vector(ct.Structure):
_fields_ = [
('delete', ccs_user_defined_tuner_del_type),
('ask', ccs_user_defined_tuner_ask_type),
('tell', ccs_user_defined_tuner_tell_type),
('get_optimums', ccs_user_defined_tuner_get_optimums_type),
('get_history', ccs_user_defined_tuner_get_history_type) ]
class ccs_user_defined_tuner_data(ct.Structure):
_fields_ = [
('common_data', ccs_tuner_common_data),
('vector', ccs_user_defined_tuner_vector),
('tuner_data', ct.c_void_p) ]
ccs_create_user_defined_tuner = _ccs_get_function("ccs_create_user_defined_tuner", [ct.c_char_p, ccs_configuration_space, ccs_objective_space, ct.c_void_p, ct.POINTER(ccs_user_defined_tuner_vector), ct.c_void_p, ct.POINTER(ccs_tuner)])
class UserDefinedTuner(Tuner):
callbacks = {}
def __init__(self, handle = None, retain = False, name = None, configuration_space = None, objective_space = None, user_data = None, delete = None, ask = None, tell = None, get_optimums = None, get_history = None, tuner_data = None ):
if handle is None:
if delete is None or ask is None or tell is None or get_optimums is None or get_history is None:
raise Error(ccs_error(ccs_error.INVALID_VALUE))
def delete_wrapper(data):
try:
data = ct.cast(data, ct.POINTER(ccs_user_defined_tuner_data))
delete(data.contents)
del UserDefinedTuner.callbacks[self]
return ccs_error.SUCCESS
except Error as e:
return -e.message.value
def ask_wrapper(data, count, p_configurations, p_count):
try:
data = ct.cast(data, ct.POINTER(ccs_user_defined_tuner_data))
p_confs = ct.cast(p_configurations, ct.c_void_p)
p_c = ct.cast(p_count, ct.c_void_p)
(configurations, count_ret) = ask(data.contents, count if p_confs.value else None)
if p_confs.value is not None and count < count_ret:
raise Error(ccs_error(ccs_error.INVALID_VALUE))
if p_confs.value is not None:
for i in range(len(configurations)):
res = ccs_retain_object(configurations[i].handle)
Error.check(res)
p_configurations[i] = configurations[i].handle.value
for i in range(len(configurations), count):
p_configurations[i] = None
if p_c.value is not None:
p_count[0] = count_ret
return ccs_error.SUCCESS
except Error as e:
return -e.message.value
def tell_wrapper(data, count, p_evaluations):
try:
if count == 0:
return ccs_error.SUCCESS
data = ct.cast(data, ct.POINTER(ccs_user_defined_tuner_data))
p_evals = ct.cast(p_evaluations, ct.c_void_p)
if p_evals.value is None:
raise Error(ccs_error(ccs_error.INVALID_VALUE))
evals = [Evaluation.from_handle(ccs_evaluation(p_evaluations[i])) for i in range(count)]
tell(data.contents, evals)
return ccs_error.SUCCESS
except Error as e:
return -e.message.value
def get_optimums_wrapper(data, count, p_evaluations, p_count):
try:
data = ct.cast(data, ct.POINTER(ccs_user_defined_tuner_data))
p_evals = ct.cast(p_evaluations, ct.c_void_p)
p_c = ct.cast(p_count, ct.c_void_p)
optimums = get_optimums(data.contents)
count_ret = len(optimums)
if p_evals.value is not None and count < count_ret:
raise Error(ccs_error(ccs_error.INVALID_VALUE))
if p_evals.value is not None:
for i in range(count_ret):
p_evaluations[i] = optimums[i].handle.value
for i in range(count_ret, count):
p_evaluations[i] = None
if p_c.value is not None:
p_count[0] = count_ret
return ccs_error.SUCCESS
except Error as e:
return -e.message.value
def get_history_wrapper(data, count, p_evaluations, p_count):
try:
data = ct.cast(data, ct.POINTER(ccs_user_defined_tuner_data))
p_evals = ct.cast(p_evaluations, ct.c_void_p)
p_c = ct.cast(p_count, ct.c_void_p)
history = get_history(data.contents)
count_ret = len(history)
if p_evals.value is not None and count < count_ret:
raise Error(ccs_error(ccs_error.INVALID_VALUE))
if p_evals.value is not None:
for i in range(count_ret):
p_evaluations[i] = history[i].handle.value
for i in range(count_ret, count):
p_evaluations[i] = None
if p_c.value is not None:
p_count[0] = count_ret
return ccs_error.SUCCESS
except Error as e:
return -e.message.value
handle = ccs_tuner()
vec = ccs_user_defined_tuner_vector()
delete_wrapper_func = ccs_user_defined_tuner_del_type(delete_wrapper)
vec.delete = delete_wrapper_func
ask_wrapper_func = ccs_user_defined_tuner_ask_type(ask_wrapper)
vec.ask = ask_wrapper_func
tell_wrapper_func = ccs_user_defined_tuner_tell_type(tell_wrapper)
vec.tell = tell_wrapper_func
get_optimums_wrapper_func = ccs_user_defined_tuner_get_optimums_type(get_optimums_wrapper)
vec.get_optimums = get_optimums_wrapper_func
get_history_wrapper_func = ccs_user_defined_tuner_get_history_type(get_history_wrapper)
vec.get_history = get_history_wrapper_func
res = ccs_create_user_defined_tuner(str.encode(name), configuration_space.handle, objective_space.handle, user_data, ct.byref(vec), tuner_data, ct.byref(handle))
Error.check(res)
super().__init__(handle = handle, retain = False)
UserDefinedTuner.callbacks[self] = [delete_wrapper, ask_wrapper, tell_wrapper, get_optimums_wrapper, get_history_wrapper, delete_wrapper_func, ask_wrapper_func, tell_wrapper_func, get_optimums_wrapper_func, get_history_wrapper_func]
else:
super().__init__(handle = handle, retain = retain)
...@@ -38,7 +38,76 @@ class TestTuner(unittest.TestCase): ...@@ -38,7 +38,76 @@ class TestTuner(unittest.TestCase):
optims = t.optimums optims = t.optimums
objs = [x.objective_values for x in optims] objs = [x.objective_values for x in optims]
objs.sort(key = lambda x: x[0]) objs.sort(key = lambda x: x[0])
self.assertTrue(all(objs[i] <= objs[i+1] for i in range(len(objs)-1))) # assert pareto front
self.assertTrue(all(objs[i][1] >= objs[i+1][1] for i in range(len(objs)-1)))
def test_user_defined(self):
global history
history = []
global optimums
optimums = []
def delete(data):
return None
def ask(data, count):
if count is None:
return (None, 1)
else:
cs = ccs.ConfigurationSpace.from_handle(ccs.ccs_configuration_space(data.common_data.configuration_space))
return (cs.samples(count), count)
def tell(data, evaluations):
global history
global optimums
history += evaluations
for e in evaluations:
discard = False
new_optimums = []
for o in optimums:
if discard:
new_optimums.append(o)
else:
c = e.cmp(o).value
if c == ccs.EQUIVALENT or c == ccs.WORSE:
discard = True
new_optimums.append(o)
elif c == ccs.NOT_COMPARABLE:
new_optimums.append(o)
if not discard:
new_optimums.append(e)
optimums = new_optimums
return None
def get_history(data):
global history
return history
def get_optimums(data):
global optimums
return optimums
(cs, os) = self.create_tuning_problem()
t = ccs.UserDefinedTuner(name = "tuner", configuration_space = cs, objective_space = os, delete = delete, ask = ask, tell = tell, get_optimums = get_optimums, get_history = get_history)
self.assertEqual("tuner", t.name)
self.assertEqual(ccs.TUNER_USER_DEFINED, t.type.value)
self.assertEqual(cs.handle.value, t.configuration_space.handle.value)
self.assertEqual(os.handle.value, t.objective_space.handle.value)
func = lambda x, y, z: [(x-2)*(x-2), sin(z+y)]
evals = [ccs.Evaluation(objective_space = os, configuration = c, values = func(*(c.values))) for c in t.ask(100)]
t.tell(evals)
hist = t.history
self.assertEqual(100, len(hist))
evals = [ccs.Evaluation(objective_space = os, configuration = c, values = func(*(c.values))) for c in t.ask(100)]
t.tell(evals)
hist = t.history
self.assertEqual(200, len(hist))
optims = t.optimums
objs = [x.objective_values for x in optims]
objs.sort(key = lambda x: x[0])
# assert pareto front
self.assertTrue(all(objs[i][1] >= objs[i+1][1] for i in range(len(objs)-1)))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment