Commit ada52eb1 authored by Brice Videau's avatar Brice Videau

Added support for discrete hyperparameters to bindings.

parent 4a9490e3
......@@ -10,7 +10,8 @@ module CCS
HyperparameterType = enum FFI::Type::INT32, :ccs_hyperparameter_type_t, [
:CCS_NUMERICAL,
:CCS_CATEGORICAL,
:CCS_ORDINAL
:CCS_ORDINAL,
:CCS_DISCRETE
]
class MemoryPointer
def read_ccs_hyperparameter_type_t
......@@ -47,6 +48,8 @@ module CCS
CategoricalHyperparameter::new(handle, retain: true)
when :CCS_ORDINAL
OrdinalHyperparameter::new(handle, retain: true)
when :CCS_DISCRETE
DiscreteHyperparameter::new(handle, retain: true)
else
raise CCSError, :CCS_INVALID_HYPERPARAMETER
end
......@@ -277,4 +280,36 @@ module CCS
end
end
attach_function :ccs_create_discrete_hyperparameter, [:string, :size_t, :pointer, :size_t, :pointer, :pointer], :ccs_result_t
attach_function :ccs_discrete_hyperparameter_get_values, [:ccs_hyperparameter_t, :size_t, :pointer, :pointer], :ccs_result_t
class DiscreteHyperparameter < Hyperparameter
def initialize(handle = nil, retain: false, name: Hyperparameter.default_name, values: [], default_index: 0, user_data: nil)
if handle
super(handle, retain: retain)
else
count = values.size
return [] if count == 0
vals = MemoryPointer::new(:ccs_datum_t, count)
values.each_with_index{ |v, i| Datum::new(vals[i]).value = v }
ptr = MemoryPointer::new(:ccs_hyperparameter_t)
res = CCS.ccs_create_discrete_hyperparameter(name, count, vals, default_index, user_data, ptr)
CCS.error_check(res)
super(ptr.read_ccs_hyperparameter_t, retain: false)
end
end
def values
@values ||= begin
ptr = MemoryPointer::new(:size_t)
res = CCS.ccs_discrete_hyperparameter_get_values(@handle, 0, nil, ptr)
CCS.error_check(res)
count = ptr.read_size_t
ptr = MemoryPointer::new(:ccs_datum_t, count)
res = CCS.ccs_discrete_hyperparameter_get_values(@handle, count, ptr, nil)
CCS.error_check(res)
count.times.collect { |i| Datum::new(ptr[i]).value }
end
end
end
end
......@@ -7,6 +7,31 @@ class CConfigSpaceTestHyperparameter < Minitest::Test
CCS.init
end
def test_from_handle_discrete
values = [0, 1.5, 2, 7.2]
h = CCS::DiscreteHyperparameter::new(values: values)
h2 = CCS::Object::from_handle(h)
assert_equal( h.class, h2.class )
end
def test_discrete
values = [0, 1.5, 2, 7.2]
h = CCS::DiscreteHyperparameter::new(values: values)
assert_equal( :CCS_HYPERPARAMETER, h.object_type )
assert_equal( :CCS_DISCRETE, h.type )
assert_match( /param/, h.name )
assert( h.user_data.null? )
assert_equal( 0, h.default_value )
values.each { |v| assert( h.check_value(v) ) }
refute( h.check_value("foo") )
v = h.sample
assert( values.include? v )
vals = h.samples(100)
vals.each { |v|
assert( values.include? v )
}
end
def test_ordinal_compare
values = ["foo", 2, 3.0]
h = CCS::OrdinalHyperparameter::new(values: values)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment