Commit 01498aaf authored by Brice Videau's avatar Brice Videau

Added user defined tuner.

parent b04251e1
import ctypes as ct
from .base import Object, Error, CEnumeration, ccs_error, ccs_result, _ccs_get_function, ccs_context, ccs_hyperparameter, ccs_configuration_space, ccs_configuration, ccs_datum, ccs_objective_space, ccs_evaluation, ccs_tuner
from .base import Object, Error, CEnumeration, ccs_error, ccs_result, _ccs_get_function, ccs_context, ccs_hyperparameter, ccs_configuration_space, ccs_configuration, ccs_datum, ccs_objective_space, ccs_evaluation, ccs_tuner, ccs_retain_object
from .context import Context
from .hyperparameter import Hyperparameter
from .configuration_space import ConfigurationSpace
......@@ -142,3 +142,145 @@ class RandomTuner(Tuner):
else:
super().__init__(handle = handle, retain = retain)
class ccs_tuner_common_data(ct.Structure):
_fields_ = [
('type', ccs_tuner_type),
('name', ct.c_char_p),
('user_data', ct.c_void_p),
('configuration_space', ccs_configuration_space),
('objective_space', ccs_objective_space) ]
ccs_user_defined_tuner_del_type = ct.CFUNCTYPE(ccs_result, ct.c_void_p)
ccs_user_defined_tuner_ask_type = ct.CFUNCTYPE(ccs_result, ct.c_void_p, ct.c_size_t, ct.POINTER(ccs_configuration), ct.POINTER(ct.c_size_t))
ccs_user_defined_tuner_tell_type = ct.CFUNCTYPE(ccs_result, ct.c_void_p, ct.c_size_t, ct.POINTER(ccs_evaluation))
ccs_user_defined_tuner_get_optimums_type = ct.CFUNCTYPE(ccs_result, ct.c_void_p, ct.c_size_t, ct.POINTER(ccs_evaluation), ct.POINTER(ct.c_size_t))
ccs_user_defined_tuner_get_history_type = ct.CFUNCTYPE(ccs_result, ct.c_void_p, ct.c_size_t, ct.POINTER(ccs_evaluation), ct.POINTER(ct.c_size_t))
class ccs_user_defined_tuner_vector(ct.Structure):
_fields_ = [
('delete', ccs_user_defined_tuner_del_type),
('ask', ccs_user_defined_tuner_ask_type),
('tell', ccs_user_defined_tuner_tell_type),
('get_optimums', ccs_user_defined_tuner_get_optimums_type),
('get_history', ccs_user_defined_tuner_get_history_type) ]
class ccs_user_defined_tuner_data(ct.Structure):
_fields_ = [
('common_data', ccs_tuner_common_data),
('vector', ccs_user_defined_tuner_vector),
('tuner_data', ct.c_void_p) ]
ccs_create_user_defined_tuner = _ccs_get_function("ccs_create_user_defined_tuner", [ct.c_char_p, ccs_configuration_space, ccs_objective_space, ct.c_void_p, ct.POINTER(ccs_user_defined_tuner_vector), ct.c_void_p, ct.POINTER(ccs_tuner)])
class UserDefinedTuner(Tuner):
callbacks = {}
def __init__(self, handle = None, retain = False, name = None, configuration_space = None, objective_space = None, user_data = None, delete = None, ask = None, tell = None, get_optimums = None, get_history = None, tuner_data = None ):
if handle is None:
if delete is None or ask is None or tell is None or get_optimums is None or get_history is None:
raise Error(ccs_error(ccs_error.INVALID_VALUE))
def delete_wrapper(data):
try:
data = ct.cast(data, ct.POINTER(ccs_user_defined_tuner_data))
delete(data.contents)
del UserDefinedTuner.callbacks[self]
return ccs_error.SUCCESS
except Error as e:
return -e.message.value
def ask_wrapper(data, count, p_configurations, p_count):
try:
data = ct.cast(data, ct.POINTER(ccs_user_defined_tuner_data))
p_confs = ct.cast(p_configurations, ct.c_void_p)
p_c = ct.cast(p_count, ct.c_void_p)
(configurations, count_ret) = ask(data.contents, count if p_confs.value else None)
if p_confs.value is not None and count < count_ret:
raise Error(ccs_error(ccs_error.INVALID_VALUE))
if p_confs.value is not None:
for i in range(len(configurations)):
res = ccs_retain_object(configurations[i].handle)
Error.check(res)
p_configurations[i] = configurations[i].handle.value
for i in range(len(configurations), count):
p_configurations[i] = None
if p_c.value is not None:
p_count[0] = count_ret
return ccs_error.SUCCESS
except Error as e:
return -e.message.value
def tell_wrapper(data, count, p_evaluations):
try:
if count == 0:
return ccs_error.SUCCESS
data = ct.cast(data, ct.POINTER(ccs_user_defined_tuner_data))
p_evals = ct.cast(p_evaluations, ct.c_void_p)
if p_evals.value is None:
raise Error(ccs_error(ccs_error.INVALID_VALUE))
evals = [Evaluation.from_handle(ccs_evaluation(p_evaluations[i])) for i in range(count)]
tell(data.contents, evals)
return ccs_error.SUCCESS
except Error as e:
return -e.message.value
def get_optimums_wrapper(data, count, p_evaluations, p_count):
try:
data = ct.cast(data, ct.POINTER(ccs_user_defined_tuner_data))
p_evals = ct.cast(p_evaluations, ct.c_void_p)
p_c = ct.cast(p_count, ct.c_void_p)
optimums = get_optimums(data.contents)
count_ret = len(optimums)
if p_evals.value is not None and count < count_ret:
raise Error(ccs_error(ccs_error.INVALID_VALUE))
if p_evals.value is not None:
for i in range(count_ret):
p_evaluations[i] = optimums[i].handle.value
for i in range(count_ret, count):
p_evaluations[i] = None
if p_c.value is not None:
p_count[0] = count_ret
return ccs_error.SUCCESS
except Error as e:
return -e.message.value
def get_history_wrapper(data, count, p_evaluations, p_count):
try:
data = ct.cast(data, ct.POINTER(ccs_user_defined_tuner_data))
p_evals = ct.cast(p_evaluations, ct.c_void_p)
p_c = ct.cast(p_count, ct.c_void_p)
history = get_history(data.contents)
count_ret = len(history)
if p_evals.value is not None and count < count_ret:
raise Error(ccs_error(ccs_error.INVALID_VALUE))
if p_evals.value is not None:
for i in range(count_ret):
p_evaluations[i] = history[i].handle.value
for i in range(count_ret, count):
p_evaluations[i] = None
if p_c.value is not None:
p_count[0] = count_ret
return ccs_error.SUCCESS
except Error as e:
return -e.message.value
handle = ccs_tuner()
vec = ccs_user_defined_tuner_vector()
delete_wrapper_func = ccs_user_defined_tuner_del_type(delete_wrapper)
vec.delete = delete_wrapper_func
ask_wrapper_func = ccs_user_defined_tuner_ask_type(ask_wrapper)
vec.ask = ask_wrapper_func
tell_wrapper_func = ccs_user_defined_tuner_tell_type(tell_wrapper)
vec.tell = tell_wrapper_func
get_optimums_wrapper_func = ccs_user_defined_tuner_get_optimums_type(get_optimums_wrapper)
vec.get_optimums = get_optimums_wrapper_func
get_history_wrapper_func = ccs_user_defined_tuner_get_history_type(get_history_wrapper)
vec.get_history = get_history_wrapper_func
res = ccs_create_user_defined_tuner(str.encode(name), configuration_space.handle, objective_space.handle, user_data, ct.byref(vec), tuner_data, ct.byref(handle))
Error.check(res)
super().__init__(handle = handle, retain = False)
UserDefinedTuner.callbacks[self] = [delete_wrapper, ask_wrapper, tell_wrapper, get_optimums_wrapper, get_history_wrapper, delete_wrapper_func, ask_wrapper_func, tell_wrapper_func, get_optimums_wrapper_func, get_history_wrapper_func]
else:
super().__init__(handle = handle, retain = retain)
......@@ -38,7 +38,76 @@ class TestTuner(unittest.TestCase):
optims = t.optimums
objs = [x.objective_values for x in optims]
objs.sort(key = lambda x: x[0])
self.assertTrue(all(objs[i] <= objs[i+1] for i in range(len(objs)-1)))
# assert pareto front
self.assertTrue(all(objs[i][1] >= objs[i+1][1] for i in range(len(objs)-1)))
def test_user_defined(self):
global history
history = []
global optimums
optimums = []
def delete(data):
return None
def ask(data, count):
if count is None:
return (None, 1)
else:
cs = ccs.ConfigurationSpace.from_handle(ccs.ccs_configuration_space(data.common_data.configuration_space))
return (cs.samples(count), count)
def tell(data, evaluations):
global history
global optimums
history += evaluations
for e in evaluations:
discard = False
new_optimums = []
for o in optimums:
if discard:
new_optimums.append(o)
else:
c = e.cmp(o).value
if c == ccs.EQUIVALENT or c == ccs.WORSE:
discard = True
new_optimums.append(o)
elif c == ccs.NOT_COMPARABLE:
new_optimums.append(o)
if not discard:
new_optimums.append(e)
optimums = new_optimums
return None
def get_history(data):
global history
return history
def get_optimums(data):
global optimums
return optimums
(cs, os) = self.create_tuning_problem()
t = ccs.UserDefinedTuner(name = "tuner", configuration_space = cs, objective_space = os, delete = delete, ask = ask, tell = tell, get_optimums = get_optimums, get_history = get_history)
self.assertEqual("tuner", t.name)
self.assertEqual(ccs.TUNER_USER_DEFINED, t.type.value)
self.assertEqual(cs.handle.value, t.configuration_space.handle.value)
self.assertEqual(os.handle.value, t.objective_space.handle.value)
func = lambda x, y, z: [(x-2)*(x-2), sin(z+y)]
evals = [ccs.Evaluation(objective_space = os, configuration = c, values = func(*(c.values))) for c in t.ask(100)]
t.tell(evals)
hist = t.history
self.assertEqual(100, len(hist))
evals = [ccs.Evaluation(objective_space = os, configuration = c, values = func(*(c.values))) for c in t.ask(100)]
t.tell(evals)
hist = t.history
self.assertEqual(200, len(hist))
optims = t.optimums
objs = [x.objective_values for x in optims]
objs.sort(key = lambda x: x[0])
# assert pareto front
self.assertTrue(all(objs[i][1] >= objs[i+1][1] for i in range(len(objs)-1)))
if __name__ == '__main__':
unittest.main()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment