Commit 9d25e57d authored by Brice Videau's avatar Brice Videau

Minor improvements to ruby bindings.

parent 839e23cf
......@@ -23,8 +23,8 @@ module CCS
attach_function :ccs_configuration_space_check_configuration, [:ccs_configuration_space_t, :ccs_configuration_t], :ccs_result_t
attach_function :ccs_configuration_space_check_configuration_values, [:ccs_configuration_space_t, :size_t, :pointer], :ccs_result_t
attach_function :ccs_configuration_space_get_default_configuration, [:ccs_configuration_space_t, :pointer], :ccs_result_t
attach_function :ccs_configuration_space_sample, [:ccs_configuration_space_t, :ccs_configuration_t], :ccs_result_t
attach_function :ccs_configuration_space_samples, [:ccs_configuration_space_t, :size_t, :ccs_configuration_t], :ccs_result_t
attach_function :ccs_configuration_space_sample, [:ccs_configuration_space_t, :pointer], :ccs_result_t
attach_function :ccs_configuration_space_samples, [:ccs_configuration_space_t, :size_t, :pointer], :ccs_result_t
class ConfigurationSpace < Context
add_property :user_data, :pointer, :ccs_configuration_space_get_user_data, memoize: true
......@@ -80,11 +80,12 @@ module CCS
raise CCSError, :CCS_INVALID_VALUE if count != distributions.size
p_dists = MemoryPointer::new(:ccs_distribution_t, count)
p_dists.write_array_of_pointer(distributions.collect(&:handle))
distributions = p_dist
else
p_dists = nil
end
p_hypers = MemoryPointer::new(:ccs_hyperparameter_t, count)
p_hypers.write_array_of_pointer(hyperparameters.collect(&:handle))
res = CCS.ccs_configuration_space_add_hyperparameters(@handle, count, p_hypers, distributions)
res = CCS.ccs_configuration_space_add_hyperparameters(@handle, count, p_hypers, p_dists)
CCS.error_check(res)
self
end
......@@ -199,11 +200,15 @@ module CCS
Expression.from_handle(ptr.read_ccs_expression_t)
end
def forbidden_clauses
def num_forbidden_clauses
ptr = MemoryPointer::new(:size_t)
res = CCS.ccs_configuration_space_get_forbidden_clauses(@handle, 0, nil, ptr)
CCS.error_check(res)
count = ptr.read_size_t
ptr.read_size_t
end
def forbidden_clauses
count = num_forbidden_clauses
ptr = MemoryPointer::new(:ccs_expression_t, count)
res = CCS.ccs_configuration_space_get_forbidden_clauses(@handle, count, ptr, nil)
CCS.error_check(res)
......
......@@ -94,6 +94,7 @@ module CCS
end
def samples(rng, count)
return [] if count == 0
ptr = MemoryPointer::new(:ccs_numeric_t, count)
res = CCS.ccs_distribution_samples(@handle, rng, count, ptr)
CCS.error_check(res)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment