Commit c9b22afc authored by Brice Videau's avatar Brice Videau

Better support for symbols as hyperparameter names.

parent 98fff1b4
......@@ -44,7 +44,7 @@ module CCS
def set_value(hyperparameter, value)
d = Datum.from_value(value)
case hyperparameter
when String
when String, Symbol
hyperparameter = configuration_space.hyperparameter_index_by_name(hyperparameter)
when Hyperparameter
hyperparameter = configuration_space.hyperparameter_index(hyperparameter)
......@@ -59,6 +59,8 @@ module CCS
case hyperparameter
when String
res = CCS.ccs_configuration_get_value_by_name(@handle, hyperparameter, ptr)
when Symbol
res = CCS.ccs_configuration_get_value_by_name(@handle, hyperparameter.inspect, ptr)
when Hyperparameter
res = CCS.ccs_configuration_get_value(@handle, configuration_space.hyperparameter_index(hyperparameter), ptr)
when Integer
......
......@@ -62,7 +62,7 @@ module CCS
case h
when Hyperparameter
hyperparameter_index(h)
when String
when String, Symbol
hyperparameter_index_by_name(hyperparameter)
else
h
......@@ -79,7 +79,7 @@ module CCS
case hyperparameter
when Hyperparameter
hyperparameter = hyperparameter_index(hyperparameter);
when String
when String, Symbol
hyperparameter = hyperparameter_index_by_name(hyperparameter);
end
p_distribution = MemoryPointer::new(:ccs_distribution_t)
......@@ -113,7 +113,7 @@ module CCS
case hyperparameter
when Hyperparameter
hyperparameter = hyperparameter_index(hyperparameter);
when String
when String, Symbol
hyperparameter = hyperparameter_index_by_name(hyperparameter);
end
res = CCS.ccs_configuration_space_set_condition(@handle, hyperparameter, expression)
......@@ -125,7 +125,7 @@ module CCS
case hyperparameter
when Hyperparameter
hyperparameter = hyperparameter_index(hyperparameter);
when String
when String, Symbol
hyperparameter = hyperparameter_index_by_name(hyperparameter);
end
ptr = MemoryPointer::new(:ccs_expression_t)
......
......@@ -29,6 +29,7 @@ module CCS
end
def hyperparameter_by_name(name)
name = name.inspect if name.kind_of?(Symbol)
ptr = MemoryPointer::new(:ccs_hyperparameter_t)
res = CCS.ccs_context_get_hyperparameter_by_name(@handle, name, ptr)
CCS.error_check(res)
......@@ -36,6 +37,7 @@ module CCS
end
def hyperparameter_index_by_name(name)
name = name.inspect if name.kind_of?(Symbol)
ptr = MemoryPointer::new(:size_t)
res = CCS.ccs_context_get_hyperparameter_index_by_name(@handle, name, ptr)
CCS.error_check(res)
......
......@@ -65,7 +65,7 @@ module CCS
def set_value(hyperparameter, value)
d = Datum.from_value(value)
case hyperparameter
when String
when String, Symbol
hyperparameter = objective_space.hyperparameter_index_by_name(hyperparameter)
when Hyperparameter
hyperparameter = objective_space.hyperparameter_index(hyperparameter)
......@@ -80,6 +80,8 @@ module CCS
case hyperparameter
when String
res = CCS.ccs_evaluation_get_value_by_name(@handle, hyperparameter, ptr)
when Symbol
res = CCS.ccs_evaluation_get_value_by_name(@handle, hyperparameter.inspect, ptr)
when Hyperparameter
res = CCS.ccs_evaluation_get_value(@handle, objective_space.hyperparameter_index(hyperparameter), ptr)
when Integer
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment