Skip to content

Commit

Permalink
src: Fix options handling
Browse files Browse the repository at this point in the history
  • Loading branch information
tmmsartor committed Jul 13, 2024
1 parent f7b52b1 commit 6af570f
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 48 deletions.
26 changes: 12 additions & 14 deletions compiler/generate_precompile.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using MadNLP_C
using Base
using Base: unsafe_convert
using Logging

logger = ConsoleLogger(stderr, Logging.Warn)
Expand Down Expand Up @@ -55,7 +55,7 @@ function eval_grad_f(Cw::Ptr{Cdouble},Cgrad::Ptr{Cdouble}, d::Ptr{Cvoid})::Cint
grad::Vector{Float64} = unsafe_wrap(Array, Cgrad, nnzo)
_eval_grad_f!(w,grad)
@debug "grad-callback" grad
# Cgrad::Ptr{Cdouble} = Base.unsafe_convert(Ptr{Cdouble}, grad)
# Cgrad::Ptr{Cdouble} = unsafe_convert(Ptr{Cdouble}, grad)
return 0
end

Expand Down Expand Up @@ -140,10 +140,8 @@ lin_solver_names = Dict(
4=>"LapackGPUSolver",
5=>"CuCholeskySolver",
)
cases::Vector{Pair{UInt64,Csize_t}} = [0=>3]
# cases::Vector{Pair{UInt64,Csize_t}} = [0=>3,1=>3,5=>3,3=>0]
for (lin_solver_id,print_level) in cases

cases::Vector{Tuple{Int,Int,Int}} = [(0,3,3),(2,2,1000),(1,1,1000),(0,0,1000)]
for (lin_solver_id,print_level, max_iters) in cases
nlp_interface = MadnlpCInterface(
@cfunction(eval_f,Cint,(Ptr{Cdouble},Ptr{Cdouble},Ptr{Cvoid})),
@cfunction(eval_g,Cint,(Ptr{Cdouble},Ptr{Cdouble},Ptr{Cvoid})),
Expand All @@ -162,7 +160,7 @@ for (lin_solver_id,print_level) in cases
user_data
)

s = madnlp_c_create(Base.unsafe_convert(Ptr{MadnlpCInterface}, pointer_from_objref(nlp_interface)))
s = madnlp_c_create(unsafe_convert(Ptr{MadnlpCInterface}, pointer_from_objref(nlp_interface)))

inp = MadnlpCNumericIn(Cx0,Cy0,Clbx,Cubx,Clbg,Cubg)
# inp = MadnlpCNumericIn{Ptr{typeof(Cx0)}}()
Expand All @@ -179,16 +177,16 @@ for (lin_solver_id,print_level) in cases
s_jl::MadnlpCSolver = unsafe_load(s)
s_jl.in_c = inp
s_jl.out_c = out
s = Base.unsafe_convert(Ptr{MadnlpCSolver}, pointer_from_objref(s_jl))
# s = unsafe_convert(Ptr{MadnlpCSolver}, pointer_from_objref(s_jl))
unsafe_store!(s, s_jl)

test_x0 = unsafe_wrap(Array, inp.x0, (nvar,))
@info test_x0

#madnlp_c_set_option_int(s, "lin_solver_id", lin_solver_id)
#madnlp_c_set_option_int(s, "max_iters", max_iters)
#madnlp_c_set_option_int(s, "print_level", print_level)
#madnlp_c_set_option_bool(s, "minimize", minimize)
#

madnlp_c_set_option_int(s, unsafe_convert(Ptr{Int8},"lin_solver_id"), lin_solver_id)
madnlp_c_set_option_int(s, unsafe_convert(Ptr{Int8},"max_iters"), max_iters)
madnlp_c_set_option_int(s, unsafe_convert(Ptr{Int8},"print_level"), print_level)
madnlp_c_set_option_bool(s, unsafe_convert(Ptr{Int8},"minimize"), minimize)

Cret = madnlp_c_solve(s)
# Base.unsafe_convert(Ptr{MadnlpCNumericIn}, pointer_from_objref(inp)),
Expand Down
51 changes: 23 additions & 28 deletions src/MadNLP_C.jl
Original file line number Diff line number Diff line change
Expand Up @@ -260,17 +260,24 @@ end


function set_option(s::Ptr{MadnlpCSolver}, name::String, value::Any)
s_jl::MadnlpCSolver = unsafe_load(s)
if name == "print_level"
s.print_level = Int(value)
if s.print_level > 5 s.print_level = 5 end
if s.print_level < 0 s.print_level = 0 end
if value > 5 value = 5 end
if value < 0 value = 0 end
s_jl.print_level = Int(value)
elseif name == "lin_solver_id"
s.lin_solver_id = Int(value)
if s.lin_solver_id > 5 s.lin_solver_id = 5 end
if s.lin_solver_id < 0 s.lin_solver_id = 0 end
if value > 5 value = 5 end
if value < 0 value = 0 end
s_jl.lin_solver_id = Int(value)
elseif name == "max_iters"
if value < 0 value = 0 end
s_jl.max_iters = Int(value)
elseif name == "minimize"
s_jl.minimize = Bool(value)
else
@warn "Unknown option $name"
end
unsafe_store!(s, s_jl)
end

Base.@ccallable function madnlp_c_startup(argc::Cint, argv::Ptr{Ptr{Cchar}})::Cvoid
Expand Down Expand Up @@ -338,24 +345,11 @@ end

Base.@ccallable function madnlp_c_option_type(name::Ptr{Cchar})::Cint
n = unsafe_string(name)
if n == "acceptable_tol" return 0 end
if n == "bound_frac" return 0 end
if n == "bound_push" return 0 end
if n == "bound_relax_factor" return 0 end
if n == "constr_viol_tol" return 0 end
if n == "lammax" return 0 end
if n == "mu_init" return 0 end
if n == "recalc_y_feas_tol" return 0 end
if n == "tol" return 0 end
if n == "warm_start_mult_bound_push" return 0 end
if n == "acceptable_iter" return 1 end
if n == "max_iter" return 1 end
if n == "print_level" return 1 end
if n == "lin_solver_id" return 1 end
if n == "iterative_refinement" return 2 end
if n == "ls_scaling" return 2 end
if n == "recalc_y" return 2 end
if n == "warm_start_init_point" return 2 end
if n == "minimize" return 2 end
return -1
end

Expand All @@ -368,7 +362,7 @@ Base.@ccallable function madnlp_c_set_option_double(s::Ptr{MadnlpCSolver}, name:
return 0
end

Base.@ccallable function madnlp_c_set_option_bool(s::Ptr{MadnlpCSolver}, name::Ptr{Cchar}, val::Cint)::Cint
Base.@ccallable function madnlp_c_set_option_bool(s::Ptr{MadnlpCSolver}, name::Ptr{Cchar}, val::Bool)::Cint
try
set_option(s, unsafe_string(name), Bool(val))
catch e
Expand All @@ -377,7 +371,7 @@ Base.@ccallable function madnlp_c_set_option_bool(s::Ptr{MadnlpCSolver}, name::P
return 0
end

Base.@ccallable function madnlp_c_set_option_int(s::Ptr{MadnlpCSolver}, name::Ptr{Cchar}, val::Cint)::Cint
Base.@ccallable function madnlp_c_set_option_int(s::Ptr{MadnlpCSolver}, name::Ptr{Cchar}, val::Clong)::Cint
try
set_option(s, unsafe_string(name), val)
catch e
Expand Down Expand Up @@ -420,22 +414,22 @@ Base.@ccallable function madnlp_c_solve(s::Ptr{MadnlpCSolver})::Cint
main_log_level = Logging.Warn
madnlp_log = MadNLP.NOTICE

if solver.print_level == 1
if solver.print_level == 0
main_log_level = Logging.Error
madnlp_log = MadNLP.ERROR
elseif solver.print_level == 2
elseif solver.print_level == 1
main_log_level = Logging.Warn
madnlp_log = MadNLP.WARN
elseif solver.print_level == 3
elseif solver.print_level == 2
main_log_level = Logging.Warn
madnlp_log = MadNLP.NOTICE
elseif solver.print_level == 4
elseif solver.print_level == 3
main_log_level = Logging.Info
madnlp_log = MadNLP.INFO
elseif solver.print_level == 5
elseif solver.print_level == 4
main_log_level = Logging.Debug
madnlp_log = MadNLP.DEBUG
elseif solver.print_level == 6
elseif solver.print_level == 5
main_log_level = Logging.Debug
madnlp_log = MadNLP.TRACE
end
Expand Down Expand Up @@ -578,6 +572,7 @@ Base.@ccallable function madnlp_c_solve(s::Ptr{MadnlpCSolver})::Cint
solver.out_c.mul_U = Base.unsafe_convert(Ptr{Cdouble},solver.res.multipliers_U)

return 0

end


Expand Down
4 changes: 2 additions & 2 deletions src/madnlp_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ MADNLP_SYMBOL_EXPORT madnlp_int madnlp_c_solve(struct MadnlpCSolver*);
/* -1 for not found, 0 for double, 1 for int, 2 for bool, 3 for string */
MADNLP_SYMBOL_EXPORT int madnlp_c_option_type(const char* name);
MADNLP_SYMBOL_EXPORT int madnlp_c_set_option_double(struct MadnlpCSolver* s, const char* name, double val);
MADNLP_SYMBOL_EXPORT int madnlp_c_set_option_bool(struct MadnlpCSolver* s, const char* name, int val);
MADNLP_SYMBOL_EXPORT int madnlp_c_set_option_int(struct MadnlpCSolver* s, const char* name, int val);
MADNLP_SYMBOL_EXPORT int madnlp_c_set_option_bool(struct MadnlpCSolver* s, const char* name, bool val);
MADNLP_SYMBOL_EXPORT int madnlp_c_set_option_int(struct MadnlpCSolver* s, const char* name, madnlp_int val);
MADNLP_SYMBOL_EXPORT int madnlp_c_set_option_string(struct MadnlpCSolver* s, const char* name, const char* val);

MADNLP_SYMBOL_EXPORT const struct MadnlpCStats* madnlp_c_get_stats(struct MadnlpCSolver* s);
Expand Down
13 changes: 9 additions & 4 deletions tests/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,9 @@ int main(int argc, char** argv) {

struct MadnlpCSolver* solver = madnlp_c_create(&interf);

//madnlp_c_set_option_int(solver, "max_iter", 5);
madnlp_c_set_option_int(solver, "max_iter", 5);
madnlp_c_set_option_int(solver, "print_level", 2);
madnlp_c_set_option_int(solver, "lin_solver_id", 1);

const MadnlpCNumericIn* in = madnlp_c_input(solver);
std::copy(x0,x0+2,in->x0);
Expand All @@ -99,11 +101,14 @@ int main(int argc, char** argv) {

double sol[2];
double cons[1];
double obj[1];
std::copy(out->sol,out->sol+2,sol);
std::copy(out->sol,out->sol+1,cons);
std::copy(out->con,out->con+1,cons);
std::copy(out->obj,out->obj+1,obj);

std::cout << sol[0] << std::endl;
std::cout << sol[1] << std::endl;
std::cout << "sol: " << sol[0] << ", " << sol[1] << std::endl;
std::cout << "obj: " << obj[1] << std::endl;
std::cout << "con: " << con[1] << std::endl;

shutdown_julia(0);

Expand Down

0 comments on commit 6af570f

Please sign in to comment.