Commit 98fff1b4 authored by Brice Videau's avatar Brice Videau

Allow symbols to be used for hyperparameter names in Ruby.

parent 0beb7360
...@@ -53,7 +53,7 @@ module CCS ...@@ -53,7 +53,7 @@ module CCS
Literal::new(value: Float(num)) } Literal::new(value: Float(num)) }
rule(:integer => Regexp.new(TerminalRegexp[:CCS_TERM_INTEGER])).as { |num| rule(:integer => Regexp.new(TerminalRegexp[:CCS_TERM_INTEGER])).as { |num|
Literal::new(value: Integer(num)) } Literal::new(value: Integer(num)) }
rule(:identifier => Regexp.new(TerminalRegexp[:CCS_TERM_IDENTIFIER])).as { |identifier| rule(:identifier => /[:a-zA-Z_][a-zA-Z_0-9]*/).as { |identifier|
Variable::new(hyperparameter: context.hyperparameter_by_name(identifier)) } Variable::new(hyperparameter: context.hyperparameter_by_name(identifier)) }
rule(:string => Regexp.new(TerminalRegexp[:CCS_TERM_STRING])).as { |str| rule(:string => Regexp.new(TerminalRegexp[:CCS_TERM_STRING])).as { |str|
Literal::new(value: eval(str)) } Literal::new(value: eval(str)) }
......
...@@ -60,7 +60,9 @@ module CCS ...@@ -60,7 +60,9 @@ module CCS
ptr = MemoryPointer::new(:pointer) ptr = MemoryPointer::new(:pointer)
res = CCS.ccs_hyperparameter_get_name(@handle, ptr) res = CCS.ccs_hyperparameter_get_name(@handle, ptr)
CCS.error_check(res) CCS.error_check(res)
ptr.read_pointer.read_string r = ptr.read_pointer.read_string
r = r.sub(/^:/, "").to_sym if r.match(/^:/)
r
end end
end end
...@@ -143,6 +145,7 @@ module CCS ...@@ -143,6 +145,7 @@ module CCS
else else
raise CCSError, :CCS_INVALID_TYPE raise CCSError, :CCS_INVALID_TYPE
end end
name = name.inspect if name.kind_of?(Symbol)
res = CCS.ccs_create_numerical_hyperparameter(name, data_type, lower, upper, quantization, default, user_data, ptr) res = CCS.ccs_create_numerical_hyperparameter(name, data_type, lower, upper, quantization, default, user_data, ptr)
CCS.error_check(res) CCS.error_check(res)
super(ptr.read_pointer, retain: false) super(ptr.read_pointer, retain: false)
...@@ -218,6 +221,7 @@ module CCS ...@@ -218,6 +221,7 @@ module CCS
vals = MemoryPointer::new(:ccs_datum_t, count) vals = MemoryPointer::new(:ccs_datum_t, count)
values.each_with_index{ |v, i| Datum::new(vals[i]).value = v } values.each_with_index{ |v, i| Datum::new(vals[i]).value = v }
ptr = MemoryPointer::new(:ccs_hyperparameter_t) ptr = MemoryPointer::new(:ccs_hyperparameter_t)
name = name.inspect if name.kind_of?(Symbol)
res = CCS.ccs_create_categorical_hyperparameter(name, count, vals, default_index, user_data, ptr) res = CCS.ccs_create_categorical_hyperparameter(name, count, vals, default_index, user_data, ptr)
CCS.error_check(res) CCS.error_check(res)
super(ptr.read_ccs_hyperparameter_t, retain: false) super(ptr.read_ccs_hyperparameter_t, retain: false)
...@@ -251,6 +255,7 @@ module CCS ...@@ -251,6 +255,7 @@ module CCS
vals = MemoryPointer::new(:ccs_datum_t, count) vals = MemoryPointer::new(:ccs_datum_t, count)
values.each_with_index{ |v, i| Datum::new(vals[i]).value = v } values.each_with_index{ |v, i| Datum::new(vals[i]).value = v }
ptr = MemoryPointer::new(:ccs_hyperparameter_t) ptr = MemoryPointer::new(:ccs_hyperparameter_t)
name = name.inspect if name.kind_of?(Symbol)
res = CCS.ccs_create_ordinal_hyperparameter(name, count, vals, default_index, user_data, ptr) res = CCS.ccs_create_ordinal_hyperparameter(name, count, vals, default_index, user_data, ptr)
CCS.error_check(res) CCS.error_check(res)
super(ptr.read_ccs_hyperparameter_t, retain: false) super(ptr.read_ccs_hyperparameter_t, retain: false)
...@@ -292,6 +297,7 @@ module CCS ...@@ -292,6 +297,7 @@ module CCS
vals = MemoryPointer::new(:ccs_datum_t, count) vals = MemoryPointer::new(:ccs_datum_t, count)
values.each_with_index{ |v, i| Datum::new(vals[i]).value = v } values.each_with_index{ |v, i| Datum::new(vals[i]).value = v }
ptr = MemoryPointer::new(:ccs_hyperparameter_t) ptr = MemoryPointer::new(:ccs_hyperparameter_t)
name = name.inspect if name.kind_of?(Symbol)
res = CCS.ccs_create_discrete_hyperparameter(name, count, vals, default_index, user_data, ptr) res = CCS.ccs_create_discrete_hyperparameter(name, count, vals, default_index, user_data, ptr)
CCS.error_check(res) CCS.error_check(res)
super(ptr.read_ccs_hyperparameter_t, retain: false) super(ptr.read_ccs_hyperparameter_t, retain: false)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment