Commit 6bbb1214 authored by Brice Videau's avatar Brice Videau

Added hyperparameter binding.

parent fe06805f
......@@ -72,6 +72,19 @@ module CCS
typedef :pointer, :ccs_evaluation_t
typedef :pointer, :ccs_tuner_t
typedef :pointer, :ccs_object_t
class MemoryPointer
alias read_ccs_rng_t read_pointer
alias read_ccs_distribution_t read_pointer
alias read_ccs_hyperparameter_t read_pointer
alias read_ccs_expression_t read_pointer
alias read_ccs_context_t read_pointer
alias read_ccs_configuration_space_t read_pointer
alias read_ccs_configuration_t read_pointer
alias read_ccs_objective_space_t read_pointer
alias read_ccs_evaluation_t read_pointer
alias read_ccs_tuner_t read_pointer
alias read_ccs_object_t read_pointer
end
Error = enum FFI::Type::INT32, :ccs_error_t, [
:CCS_SUCCESS,
......@@ -130,6 +143,42 @@ module CCS
class Numeric < FFI::Union
layout :f, :ccs_float_t,
:i, :ccs_int_t
def value(type)
case type
when :CCS_NUM_FLOAT
self[:f]
when :CCS_NUM_INTEGER
self[:i]
else
raise StandardError, :CCS_INVALID_TYPE
end
end
def self.from_value(v)
case v
when Float
n = self::new
n[:f] = v
n
when Integer
n = self::new
n[:i] = v
n
else
raise StandardError, :CCS_INVALID_TYPE
end
end
def value=(v)
case v
when Float
self[:f] = v
when Integer
self[:i] = v
else
raise StandardError, :CCS_INVALID_TYPE
end
end
end
typedef Numeric.by_value, :ccs_numeric_t
......@@ -165,7 +214,7 @@ module CCS
INACTIVE[:value][:i] = 0
def value
case self[:type]
when :NONE
when :CCS_NONE
nil
when :CCS_INTEGER
self[:value][:i]
......@@ -179,7 +228,54 @@ module CCS
Inactive
when :CCS_OBJECT
Object::from_handle(self[:value][:o])
else
raise StandardError, :CCS_INVALID_TYPE
end
end
def value=(v, string_store: nil, object_store: nil)
@string = nil if defined?(@string) && @string
@object = nil if defined?(@object) && @object
case v
when nil
self[:type] = :CCS_NONE
self[:value][:i] = 0
when true
self[:type] = :CCS_BOOLEAN
self[:value][:i] = 1
when false
self[:type] = :CCS_BOOLEAN
self[:value][:i] = 0
when Inactive
self[:type] = :CCS_INACTIVE
self[:value][:i] = 0
when Float
self[:type] = :CCS_FLOAT
self[:value][:f] = v
when Integer
self[:type] = :CCS_INTEGER
self[:value][:i] = v
when String
ptr = MemoryPointer::from_string(v)
if string_store
string_store.push ptr
else
@string = ptr
end
self[:type] = :CCS_STRING
self[:value][:s] = ptr
when Object
if object_store
object_store.push v
else
@object = v
end
self[:type] = :CCS_OBJECT
self[:value][:o] = v.handle
else
raise StandardError, :CCS_INVALID_TYPE
end
v
end
def self.from_value(v)
......@@ -206,8 +302,8 @@ module CCS
d = self::new
ptr = MemoryPointer::from_string(v)
d.instance_variable_set(:@string, ptr)
d[:type] = :STRING
d[:valus][:s] = ptr
d[:type] = :CCS_STRING
d[:value][:s] = ptr
d
when Object
d = self::new
......@@ -216,7 +312,7 @@ module CCS
d.instance_variable_set(:@object, v)
d
else
raise StandardError, :CCS_INVALID_VALUE
raise StandardError, :CCS_INVALID_TYPE
end
end
end
......
module CCS
@hyperparameter_counter = 0
def self.get_id
id = @hyperparameter_counter
@hyperparameter_counter += 1
id
end
HyperparameterType = enum FFI::Type::INT32, :ccs_hyperparameter_type_t, [
:CCS_NUMERICAL,
:CCS_CATEGORICAL,
......@@ -18,9 +25,11 @@ module CCS
attach_function :ccs_hyperparameter_get_default_distribution, [:ccs_hyperparameter_t, :pointer], :ccs_result_t
attach_function :ccs_hyperparameter_check_value, [:ccs_hyperparameter_t, :ccs_datum_t, :pointer], :ccs_result_t
attach_function :ccs_hyperparameter_check_values, [:ccs_hyperparameter_t, :size_t, :pointer, :pointer], :ccs_result_t
attach_function :ccs_hyperparameter_sample, [:ccs_hyperparameter_t, :ccs_distribution_t, :ccs_rng_t, :pointer], :ccs_result_t
attach_function :ccs_hyperparameter_samples, [:ccs_hyperparameter_t, :ccs_distribution_t, :ccs_rng_t, :size_t, :pointer], :ccs_result_t
class Hyperparameter < Object
add_property :type, :ccs_distribution_type_t, :ccs_hyperparameter_get_type, memoize:true
add_property :type, :ccs_hyperparameter_type_t, :ccs_hyperparameter_get_type, memoize:true
add_property :user_data, :pointer, :ccs_hyperparameter_get_user_data, memoize: true
def initialize(handle, retain: false)
if !handle
......@@ -29,6 +38,10 @@ module CCS
super
end
def self.default_name
"param#{"%03d"%CCS.get_id}"
end
def self.from_handle(handle)
ptr = MemoryPointer::new(:ccs_hyperparameter_type_t)
res = CCS.ccs_hyperparameter_get_type(handle, ptr)
......@@ -57,7 +70,7 @@ module CCS
def default_value
@default_value ||= begin
ptr = MemoryPointer::new(:ccs_datum_t)
res = CCS.ccs_hyperparameter_get_default_value(handle, ptr)
res = CCS.ccs_hyperparameter_get_default_value(@handle, ptr)
CCS.error_check(res)
d = Datum::new(ptr)
d.value
......@@ -67,10 +80,203 @@ module CCS
def default_distribution
@default_distribution ||= begin
ptr = MemoryPointer::new(:ccs_distribution_t)
res = CCS.ccs_hyperparameter_get_default_distribution(handle, ptr)
res = CCS.ccs_hyperparameter_get_default_distribution(@handle, ptr)
CCS.error_check(res)
Object::from_handle(ptr.read_pointer)
end
end
def check_value(v)
ptr = MemoryPointer::new(:ccs_bool_t)
res = CCS.ccs_hyperparameter_check_value(@handle, Datum::from_value(v), ptr)
CCS.error_check(res)
ptr.read_ccs_bool_t == CCS::FALSE ? false : true
end
def check_values(vals)
count = vals.size
return [] if count == 0
values = MemoryPointer::new(:ccs_datum_t, count)
vals.each_with_index{ |v, i| Datum::new(values[i]).value = v }
ptr = MemoryPointer::new(:ccs_bool_t, count)
res = CCS.ccs_hyperparameter_check_values(@handle, count, values, ptr)
CCS.error_check(res)
count.times.collect { |i| ptr[i].read_ccs_bool_t == CCS::FALSE ? false : true }
end
def sample(distribution: default_distribution, rng: CCS::DefaultRng)
value = MemoryPointer::new(:ccs_datum_t)
res = CCS.ccs_hyperparameter_sample(@handle, distribution, rng, value)
CCS.error_check(res)
Datum::new(value).value
end
def samples(count, distribution: default_distribution, rng: CCS::DefaultRng)
return [] if count <= 0
values = MemoryPointer::new(:ccs_datum_t, count)
res = CCS.ccs_hyperparameter_samples(@handle, distribution, rng, count, values)
CCS.error_check(res)
count.times.collect { |i| Datum::new(values[i]).value }
end
end
attach_function :ccs_create_numerical_hyperparameter, [:string, :ccs_numeric_type_t, :ccs_numeric_t, :ccs_numeric_t, :ccs_numeric_t, :ccs_numeric_t, :pointer, :pointer], :ccs_result_t
attach_function :ccs_numerical_hyperparameter_get_parameters, [:ccs_hyperparameter_t, :pointer, :pointer, :pointer, :pointer], :ccs_result_t
class NumericalHyperparameter < Hyperparameter
def initialize(handle = nil, retain: false, name: Hyperparameter.default_name, data_type: :CCS_NUM_FLOAT, lower: 0.0, upper: 1.0, quantization: 0.0, default: lower, user_data: nil)
if (handle)
super(handle, retain: retain)
else
ptr = MemoryPointer::new(:ccs_hyperparameter_t)
case data_type
when :CCS_NUM_FLOAT
lower = Numeric::from_value(lower.to_f)
upper = Numeric::from_value(upper.to_f)
quantization = Numeric::from_value(quantization.to_f)
default = Numeric::from_value(default.to_f)
when :CCS_NUM_INTEGER
lower = Numeric::from_value(lower.to_i)
upper = Numeric::from_value(upper.to_i)
quantization = Numeric::from_value(quantization.to_i)
default = Numeric::from_value(default.to_i)
else
raise StandardError, :CCS_INVALID_TYPE
end
res = CCS.ccs_create_numerical_hyperparameter(name, data_type, lower, upper, quantization, default, user_data, ptr)
CCS.error_check(res)
super(ptr.read_pointer, retain: false)
end
end
def self.int(name: default_name, lower:, upper:, quantization: 0, default: lower, user_data: nil)
self.new(nil, name: name, data_type: :CCS_NUM_INTEGER, lower: lower, upper: upper, quantization: quantization, default: default, user_data: user_data)
end
def self.float(name: default_name, lower:, upper:, quantization: 0.0, default: lower, user_data: nil)
self.new(nil, name: name, data_type: :CCS_NUM_FLOAT, lower: lower, upper: upper, quantization: quantization, default: default, user_data: user_data)
end
def data_type
@data_type ||= begin
ptr = MemoryPointer::new(:ccs_numeric_type_t)
res = CCS.ccs_numerical_hyperparameter_get_parameters(@handle, ptr, nil, nil, nil)
CCS.error_check(res)
ptr.read_ccs_numeric_type_t
end
end
def lower
@lower ||= begin
ptr = MemoryPointer::new(:ccs_numeric_t)
res = CCS.ccs_numerical_hyperparameter_get_parameters(@handle, nil, ptr, nil, nil)
CCS.error_check(res)
if data_type == :CCS_NUM_FLOAT
ptr.read_ccs_float_t
else
ptr.read_ccs_int_t
end
end
end
def upper
@upper ||= begin
ptr = MemoryPointer::new(:ccs_numeric_t)
res = CCS.ccs_numerical_hyperparameter_get_parameters(@handle, nil, nil, ptr, nil)
CCS.error_check(res)
if data_type == :CCS_NUM_FLOAT
ptr.read_ccs_float_t
else
ptr.read_ccs_int_t
end
end
end
def quantization
@quantization ||= begin
ptr = MemoryPointer::new(:ccs_numeric_t)
res = CCS.ccs_numerical_hyperparameter_get_parameters(@handle, nil, nil, nil, ptr)
CCS.error_check(res)
if data_type == :CCS_NUM_FLOAT
ptr.read_ccs_float_t
else
ptr.read_ccs_int_t
end
end
end
end
attach_function :ccs_create_categorical_hyperparameter, [:string, :size_t, :pointer, :size_t, :pointer, :pointer], :ccs_result_t
attach_function :ccs_categorical_hyperparameter_get_values, [:ccs_hyperparameter_t, :size_t, :pointer, :pointer], :ccs_result_t
class CategoricalHyperparameter < Hyperparameter
def initialize(handle = nil, retain: false, name: Hyperparameter.default_name, values: [], default_index: 0, user_data: nil)
if handle
super(handle, retain: retain)
else
count = values.size
return [] if count == 0
vals = MemoryPointer::new(:ccs_datum_t, count)
values.each_with_index{ |v, i| Datum::new(vals[i]).value = v }
ptr = MemoryPointer::new(:ccs_hyperparameter_t)
res = CCS.ccs_create_categorical_hyperparameter(name, count, vals, default_index, user_data, ptr)
CCS.error_check(res)
super(ptr.read_ccs_hyperparameter_t, retain: false)
end
end
def values
@values ||= begin
ptr = MemoryPointer::new(:size_t)
res = CCS.ccs_categorical_hyperparameter_get_values(@handle, 0, nil, ptr)
CCS.error_check(res)
count = ptr.read_size_t
ptr = MemoryPointer::new(:ccs_datum_t, count)
res = CCS.ccs_categorical_hyperparameter_get_values(@handle, count, ptr, nil)
CCS.error_check(res)
count.times.collect { |i| Datum::new(ptr[i]).value }
end
end
end
attach_function :ccs_create_ordinal_hyperparameter, [:string, :size_t, :pointer, :size_t, :pointer, :pointer], :ccs_result_t
attach_function :ccs_ordinal_hyperparameter_compare_values, [:ccs_hyperparameter_t, :ccs_datum_t, :ccs_datum_t, :pointer], :ccs_result_t
attach_function :ccs_ordinal_hyperparameter_get_values, [:ccs_hyperparameter_t, :size_t, :pointer, :pointer], :ccs_result_t
class OrdinalHyperparameter < Hyperparameter
def initialize(handle = nil, retain: false, name: Hyperparameter.default_name, values: [], default_index: 0, user_data: nil)
if handle
super(handle, retain: retain)
else
count = values.size
return [] if count == 0
vals = MemoryPointer::new(:ccs_datum_t, count)
values.each_with_index{ |v, i| Datum::new(vals[i]).value = v }
ptr = MemoryPointer::new(:ccs_hyperparameter_t)
res = CCS.ccs_create_ordinal_hyperparameter(name, count, vals, default_index, user_data, ptr)
CCS.error_check(res)
super(ptr.read_ccs_hyperparameter_t, retain: false)
end
end
def compare(value1, value2)
v1 = Datum::from_value(value1)
v2 = Datum::from_value(value2)
ptr = MemoryPointer::new(:ccs_int_t)
res = CCS.ccs_ordinal_hyperparameter_compare_values(@handle, v1, v2, ptr)
CCS.error_check(res)
ptr.read_ccs_int_t
end
def values
@values ||= begin
ptr = MemoryPointer::new(:size_t)
res = CCS.ccs_ordinal_hyperparameter_get_values(@handle, 0, nil, ptr)
CCS.error_check(res)
count = ptr.read_size_t
ptr = MemoryPointer::new(:ccs_datum_t, count)
res = CCS.ccs_ordinal_hyperparameter_get_values(@handle, count, ptr, nil)
CCS.error_check(res)
count.times.collect { |i| Datum::new(ptr[i]).value }
end
end
end
end
......@@ -45,4 +45,6 @@ module CCS
ptr.read_ccs_float_t
end
end
DefaultRng = Rng::new
end
......@@ -60,4 +60,43 @@ class CConfigSpaceTest < Minitest::Test
assert_equal( :CCS_OBJECT, d[:type] )
assert_equal( rng.handle, d[:value][:o] )
end
def test_value_affect
d = CCS::Datum::from_value(nil)
assert_equal( :CCS_NONE, d[:type] )
assert_equal( 0, d[:value][:i] )
d.value = CCS::Inactive
assert_equal( :CCS_INACTIVE, d[:type] )
assert_equal( 0, d[:value][:i] )
d.value = false
assert_equal( :CCS_BOOLEAN, d[:type] )
assert_equal( CCS::FALSE, d[:value][:i] )
d.value = true
assert_equal( :CCS_BOOLEAN, d[:type] )
assert_equal( CCS::TRUE, d[:value][:i] )
d.value = 15
assert_equal( :CCS_INTEGER, d[:type] )
assert_equal( 15, d[:value][:i] )
d.value = 15.0
assert_equal( :CCS_FLOAT, d[:type] )
assert_equal( 15.0, d[:value][:f] )
rng = CCS::Rng::new
d.value = rng
assert_equal( :CCS_OBJECT, d[:type] )
assert_equal( rng.handle, d[:value][:o] )
d.value = nil
assert_equal( :CCS_NONE, d[:type] )
assert_equal( 0, d[:value][:i] )
end
def test_numeric
n = CCS::Numeric::from_value(1)
assert_equal( 1, n.value(:CCS_NUM_INTEGER) )
n.value = 2
assert_equal( 2, n.value(:CCS_NUM_INTEGER) )
n.value = 1.0
assert_equal( 1.0, n.value(:CCS_NUM_FLOAT) )
n = CCS::Numeric::from_value(2.0)
assert_equal( 2.0, n.value(:CCS_NUM_FLOAT) )
end
end
[ '../lib', 'lib' ].each { |d| $:.unshift(d) if File::directory?(d) }
require 'minitest/autorun'
require 'cconfigspace'
class CConfigSpaceTestHyperparameter < Minitest::Test
def setup
CCS.init
end
def test_ordinal_compare
values = ["foo", 2, 3.0]
h = CCS::OrdinalHyperparameter::new(values: values)
assert_equal( 0, h.compare("foo", "foo") )
assert_equal( -1, h.compare("foo", 2) )
assert_equal( -1, h.compare("foo", 3.0) )
assert_equal( 1, h.compare(2, "foo") )
assert_equal( 0, h.compare(2, 2) )
assert_equal( -1, h.compare(2, 3.0) )
assert_equal( 1, h.compare(3.0, "foo") )
assert_equal( 1, h.compare(3.0, 2) )
assert_equal( 0, h.compare(3.0, 3.0) )
assert_raises(StandardError, :CCS_INVALID_VALUE) { h.compare(4.0, "foo") }
end
def test_from_handle_ordinal
values = ["foo", 2, 3.0]
h = CCS::OrdinalHyperparameter::new(values: values)
h2 = CCS::Object::from_handle(h)
assert_equal( h.class, h2.class )
end
def test_ordinal
values = ["foo", 2, 3.0]
h = CCS::OrdinalHyperparameter::new(values: values)
assert_equal( :CCS_HYPERPARAMETER, h.object_type )
assert_equal( :CCS_ORDINAL, h.type )
assert_match( /param/, h.name )
assert( h.user_data.null? )
assert_equal( "foo", h.default_value )
assert_equal( :CCS_UNIFORM, h.default_distribution.type )
assert_equal( values, h.values )
assert( h.check_value("foo") )
assert( h.check_value(2) )
assert( h.check_value(3.0) )
refute( h.check_value(1.5) )
v = h.sample
assert( values.include? v )
vals = h.samples(100)
vals.each { |v|
assert( values.include? v )
}
end
def test_from_handle_categorical
values = ["foo", 2, 3.0]
h = CCS::CategoricalHyperparameter::new(values: values)
h2 = CCS::Object::from_handle(h)
assert_equal( h.class, h2.class )
end
def test_categorical
values = ["foo", 2, 3.0]
h = CCS::CategoricalHyperparameter::new(values: values)
assert_equal( :CCS_HYPERPARAMETER, h.object_type )
assert_equal( :CCS_CATEGORICAL, h.type )
assert_match( /param/, h.name )
assert( h.user_data.null? )
assert_equal( "foo", h.default_value )
assert_equal( :CCS_UNIFORM, h.default_distribution.type )
assert_equal( values, h.values )
assert( h.check_value("foo") )
assert( h.check_value(2) )
assert( h.check_value(3.0) )
refute( h.check_value(1.5) )
v = h.sample
assert( values.include? v )
vals = h.samples(100)
vals.each { |v|
assert( values.include? v )
}
end
def test_from_handle_numerical
h = CCS::NumericalHyperparameter::new
h2 = CCS::Object::from_handle(h)
assert_equal( h.class, h2.class )
end
def test_create_numerical
h = CCS::NumericalHyperparameter::new
assert_equal( :CCS_HYPERPARAMETER, h.object_type )
assert_equal( :CCS_NUMERICAL, h.type )
assert_match( /param/, h.name )
assert( h.user_data.null? )
assert_equal( 0.0, h.default_value )
assert_equal( :CCS_UNIFORM, h.default_distribution.type )
assert( h.check_value(0.5) )
refute( h.check_value(1.5) )
v = h.sample
assert( v.kind_of?(Float) )
assert( v >= 0.0 && v < 1.0 )
vals = h.samples(100)
vals.each { |v|
assert( v.kind_of?(Float) )
assert( v >= 0.0 && v < 1.0 )
}
end
def test_create_numerical_float
h = CCS::NumericalHyperparameter::float(lower: 0.0, upper: 1.0)
assert_equal( :CCS_HYPERPARAMETER, h.object_type )
assert_equal( :CCS_NUMERICAL, h.type )
assert_match( /param/, h.name )
assert( h.user_data.null? )
assert_equal( 0.0, h.default_value )
assert_equal( :CCS_UNIFORM, h.default_distribution.type )
assert_equal( :CCS_NUM_FLOAT, h.data_type )
assert_equal( 0.0, h.lower )
assert_equal( 1.0, h.upper )
assert_equal( 0.0, h.quantization )
assert( h.check_value(0.5) )
refute( h.check_value(1.5) )
v = h.sample
assert( v.kind_of?(Float) )
assert( v >= 0.0 && v < 1.0 )
vals = h.samples(100)
vals.each { |v|
assert( v.kind_of?(Float) )
assert( v >= 0.0 && v < 1.0 )
}
end
def test_create_numerical_int
h = CCS::NumericalHyperparameter::int(lower: 0, upper: 100)
assert_equal( :CCS_NUMERICAL, h.type )
assert_match( /param/, h.name )
assert( h.user_data.null? )
assert_equal( 0, h.default_value )
assert_equal( :CCS_UNIFORM, h.default_distribution.type )
assert_equal( :CCS_NUM_INTEGER, h.data_type )
assert_equal( 0, h.lower )
assert_equal( 100, h.upper )
assert_equal( 0, h.quantization )
assert( h.check_value(50) )
refute( h.check_value(150) )
v = h.sample
assert( v.kind_of?(Integer) )
assert( v >= 0 && v < 100 )
vals = h.samples(100)
vals.each { |v|
assert( v.kind_of?(Integer) )
assert( v >= 0 && v < 100 )
}
end
end
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment