Skip to content

Commit

Permalink
Merge pull request #310 from MeasureTransport/dannys4/atmBinding
Browse files Browse the repository at this point in the history
Bindings for ATM
  • Loading branch information
dannys4 authored Apr 6, 2023
2 parents ec2ccfd + 05b5199 commit fdd4afc
Show file tree
Hide file tree
Showing 31 changed files with 560 additions and 125 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
build
*.egg-info*
.vscode
docs/_build
bin
Expand Down
20 changes: 10 additions & 10 deletions MParT/MapObjective.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,14 @@ class MapObjective {
unsigned int Dim(){return train_.extent(0);}
unsigned int NumSamples(){return train_.extent(1);}

/**
* @brief Shortcut to calculate the error of the map on the training dataset
*
* @param map Map to calculate the error on
* @return double training error
*/
double TrainError(std::shared_ptr<ConditionalMapBase<MemorySpace>> map) const;

/**
* @brief Shortcut to calculate the error of the map on the testing dataset
*
Expand All @@ -82,14 +90,6 @@ class MapObjective {
*/
StridedVector<double, MemorySpace> TrainCoeffGrad(std::shared_ptr<ConditionalMapBase<MemorySpace>> map) const;

/**
* @brief Shortcut to calculate the error of the map on the training dataset
*
* @param map Map to calculate the error on
* @return double training error
*/
double TrainError(std::shared_ptr<ConditionalMapBase<MemorySpace>> map) const;

/**
* @brief Shortcut to calculate the gradient of the objective on the training dataset w.r.t. the map coefficients
*
Expand Down Expand Up @@ -187,10 +187,10 @@ class KLObjective: public MapObjective<MemorySpace> {

namespace ObjectiveFactory {
template<typename MemorySpace>
std::shared_ptr<MapObjective<MemorySpace>> CreateGaussianKLObjective(StridedMatrix<const double, MemorySpace> train);
std::shared_ptr<MapObjective<MemorySpace>> CreateGaussianKLObjective(StridedMatrix<const double, MemorySpace> train, unsigned int dim=0);

template<typename MemorySpace>
std::shared_ptr<MapObjective<MemorySpace>> CreateGaussianKLObjective(StridedMatrix<const double, MemorySpace> train, StridedMatrix<const double, MemorySpace> test);
std::shared_ptr<MapObjective<MemorySpace>> CreateGaussianKLObjective(StridedMatrix<const double, MemorySpace> train, StridedMatrix<const double, MemorySpace> test, unsigned int dim=0);
} // namespace ObjectiveFactory

} // namespace mpart
Expand Down
2 changes: 1 addition & 1 deletion MParT/MapOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ namespace mpart{
return ret;
}

std::string String() {
virtual std::string String() {
std::string btypes[3] = {"ProbabilistHermite", "PhysicistHermite", "HermiteFunctions"};
std::string pftypes[2] = {"Exp", "SoftPlus"};
std::string qtypes[3] = {"ClenshawCurtis", "AdaptiveSimpson", "AdaptiveClenshawCurtis"};
Expand Down
1 change: 1 addition & 0 deletions MParT/TrainMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ struct TrainOptions {
return ss.str();
}
};

/**
* @brief Function to train a map inplace given an objective and optimization options
*
Expand Down
16 changes: 12 additions & 4 deletions MParT/TrainMapAdaptive.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,20 @@

namespace mpart {

// Options specifically for ATM algorithm
// 0. indicates that MParT uses the optimization's default value
// Options specifically for ATM algorithm, with map eval opts -> training opts-> ATM specific opts
struct ATMOptions: public MapOptions, public TrainOptions {
int maxPatience = 10;
int maxSize = 10;
unsigned int maxPatience = 10;
unsigned int maxSize = 10;
MultiIndex maxDegrees;
std::string String() override {
std::string md_str = maxDegrees.String();
std::stringstream ss;
ss << MapOptions::String() << "\n" << TrainOptions::String() << "\n";
ss << "maxPatience = " << maxPatience << "\n";
ss << "maxSize = " << maxSize << "\n";
ss << "maxDegrees = " << maxDegrees.String();
return ss.str();
}
};

template<typename MemorySpace>
Expand Down
1 change: 1 addition & 0 deletions bindings/julia/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ if(MPART_OPT)
set(JULIA_BINDING_SOURCES ${JULIA_BINDING_SOURCES}
src/MapObjective.cpp
src/TrainMap.cpp
src/TrainMapAdaptive.cpp
)
endif()

Expand Down
10 changes: 10 additions & 0 deletions bindings/julia/include/CommonJuliaUtilities.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ void MapFactoryWrapper(jlcxx::Module&);
*/
void ComposedMapWrapper(jlcxx::Module &);

#if defined(MPART_HAS_NLOPT)

/**
* @brief Adds MapObjective bindings to the existing module m.
* @param mod CxxWrap.jl module
Expand All @@ -83,6 +85,14 @@ void MapObjectiveWrapper(jlcxx::Module &);
*/
void TrainMapWrapper(jlcxx::Module&);

/**
* @brief Adds TrainMapAdaptive and ATMOptions to the existing module m.
* @param mod CxxWrap.jl module
*/
void TrainMapAdaptiveWrapper(jlcxx::Module&);

#endif // defined(MPART_HAS_NLOPT)

#if defined(MPART_ENABLE_GPU)
void ConditionalMapBaseDeviceWrapper(jlcxx::Module&);
#endif
Expand Down
12 changes: 6 additions & 6 deletions bindings/julia/src/ComposedMap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ using namespace mpart;

void mpart::binding::ComposedMapWrapper(jlcxx::Module &mod)
{
mod.add_type<ComposedMap<Kokkos::HostSpace>>("ComposedMap", jlcxx::julia_base_type<ConditionalMapBase<Kokkos::HostSpace>>())
.constructor<std::vector<std::shared_ptr<ConditionalMapBase<Kokkos::HostSpace>>> const&>()
// .method("EvaluateUntilK", [](ComposedMap<Kokkos::HostSpace> &map, int k, jlcxx::ArrayRef<double,2>& intPts, jlcxx::ArrayRef<double,2>& output){
// return KokkosToJulia(map.EvaluateUntilK(k, JuliaToKokkos(intPts), JuliaToKokkos(output)));
// })
;
mod.add_type<ComposedMap<Kokkos::HostSpace>>("ComposedMap", jlcxx::julia_base_type<ConditionalMapBase<Kokkos::HostSpace>>());
mod.method("ComposedMap", [](std::vector<std::shared_ptr<ConditionalMapBase<Kokkos::HostSpace>>> const& maps){
std::shared_ptr<ConditionalMapBase<Kokkos::HostSpace>> ret = std::make_shared<ComposedMap<Kokkos::HostSpace>>(maps);
return ret;
})
;
}
8 changes: 4 additions & 4 deletions bindings/julia/src/MapObjective.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ void mpart::binding::MapObjectiveWrapper(jlcxx::Module &mod) {
;

mod.add_type<KLObjective<MemorySpace>>(tName,jlcxx::julia_base_type<MapObjective<MemorySpace>>());
mod.method(mName, [](jlcxx::ArrayRef<double,2> train) {
mod.method(mName, [](jlcxx::ArrayRef<double,2> train, unsigned int dim) {
StridedMatrix<const double, MemorySpace> trainView = JuliaToKokkos(train);
Kokkos::View<double**,MemorySpace> storeTrain ("Training data", trainView.extent(0), trainView.extent(1));
Kokkos::deep_copy(storeTrain, trainView);
trainView = storeTrain;
return ObjectiveFactory::CreateGaussianKLObjective(trainView);
return ObjectiveFactory::CreateGaussianKLObjective(trainView, dim);
});
mod.method(mName, [](jlcxx::ArrayRef<double,2> train, jlcxx::ArrayRef<double,2> test) {
mod.method(mName, [](jlcxx::ArrayRef<double,2> train, jlcxx::ArrayRef<double,2> test, unsigned int dim) {
StridedMatrix<const double, MemorySpace> trainView = JuliaToKokkos(train);
StridedMatrix<const double, MemorySpace> testView = JuliaToKokkos(test);
Kokkos::View<double**,MemorySpace> storeTrain ("Training data", trainView.extent(0), trainView.extent(1));
Expand All @@ -41,7 +41,7 @@ void mpart::binding::MapObjectiveWrapper(jlcxx::Module &mod) {
Kokkos::deep_copy(storeTest, testView);
trainView = storeTrain;
testView = storeTest;
return ObjectiveFactory::CreateGaussianKLObjective(trainView, testView);
return ObjectiveFactory::CreateGaussianKLObjective(trainView, testView, dim);
});
}

40 changes: 40 additions & 0 deletions bindings/julia/src/TrainMapAdaptive.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#include "CommonJuliaUtilities.h"
#include "MParT/MapOptions.h"
#include "MParT/MapObjective.h"
#include "MParT/TrainMap.h"
#include "MParT/TrainMapAdaptive.h"

#include <Kokkos_Core.hpp>

using namespace mpart;

namespace jlcxx {
// Tell CxxWrap.jl the supertype structure for ConditionalMapBase
template<> struct SuperType<mpart::ATMOptions> {typedef mpart::MapOptions type;};
}

void mpart::binding::TrainMapAdaptiveWrapper(jlcxx::Module &mod) {
// Can only do single inheritence, so I arbitrarily picked inheritence from MapOptions
// If you need to convert to TrainOptions, I allow a conversion
mod.add_type<ATMOptions>("__ATMOptions", jlcxx::julia_base_type<MapOptions>())
.method("__opt_alg!", [](ATMOptions &opts, std::string alg){opts.opt_alg = alg;})
.method("__opt_ftol_rel!", [](ATMOptions &opts, double tol){opts.opt_ftol_rel = tol;})
.method("__opt_ftol_abs!", [](ATMOptions &opts, double tol){opts.opt_ftol_abs = tol;})
.method("__opt_xtol_rel!", [](ATMOptions &opts, double tol){opts.opt_xtol_rel = tol;})
.method("__opt_xtol_abs!", [](ATMOptions &opts, double tol){opts.opt_xtol_abs = tol;})
.method("__opt_maxeval!", [](ATMOptions &opts, int eval){opts.opt_maxeval = eval;})
.method("__verbose!", [](ATMOptions &opts, int verbose){opts.verbose = verbose;})
.method("__maxPatience!", [](ATMOptions &opts, int maxPatience){opts.maxPatience = maxPatience;})
.method("__maxSize!", [](ATMOptions &opts, int maxSize){opts.maxSize = maxSize;})
.method("__maxDegrees!", [](ATMOptions &opts, MultiIndex &maxDegrees){opts.maxDegrees = maxDegrees;})
.method("TrainOptions", [](ATMOptions &opts){ return static_cast<TrainOptions>(opts);})
;


mod.method("TrainMapAdaptive", [](jlcxx::ArrayRef<MultiIndexSet> arr, std::shared_ptr<MapObjective<Kokkos::HostSpace>> objective, ATMOptions options) {
std::vector<MultiIndexSet> vec (arr.begin(), arr.end());
auto map = TrainMapAdaptive(vec, objective, options);
for(int i = 0; i < vec.size(); i++) arr[i] = vec[i];
return map;
});
}
3 changes: 3 additions & 0 deletions bindings/julia/src/Wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ JLCXX_MODULE MParT_julia_module(jlcxx::Module& mod)
binding::AffineMapWrapper(mod);
binding::AffineFunctionWrapper(mod);
binding::MapFactoryWrapper(mod);
#if defined(MPART_HAS_NLOPT)
binding::MapObjectiveWrapper(mod);
binding::TrainMapWrapper(mod);
binding::TrainMapAdaptiveWrapper(mod);
#endif // MPART_HAS_NLOPT
}
32 changes: 23 additions & 9 deletions bindings/matlab/include/MexOptionsConversions.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,36 @@
#include <iostream>
#include <mexplus.h>
#include "MParT/MapOptions.h"
#if defined(MPART_HAS_NLOPT)
#include "MParT/TrainMapAdaptive.h"
#endif // defined(MPART_HAS_NLOPT)

namespace mpart{
namespace binding{

/** Converts a real-valued matlab vector to a Kokkos::View. The memory in matlab vector is not copied for performance
reasons. However, this means that the user is responsible for ensuring the vector is not freed before the view.
*/
MapOptions MapOptionsFromMatlab(std::string basisType, std::string posFuncType,
std::string quadType, double quadAbsTol,
double quadRelTol, unsigned int quadMaxSub,
unsigned int quadMinSub,unsigned int quadPts,
bool contDeriv, double basisLB, double basisUB, bool basisNorm);
MapOptions MapOptionsFromMatlab(std::string basisType, std::string posFuncType,
std::string quadType, double quadAbsTol,
double quadRelTol, unsigned int quadMaxSub,
unsigned int quadMinSub,unsigned int quadPts,
bool contDeriv, double basisLB, double basisUB, bool basisNorm);

void MapOptionsToMatlab(MapOptions opts, mexplus::OutputArguments &output, int start = 0);
#if defined(MPART_HAS_NLOPT)
TrainOptions TrainOptionsFromMatlab(mexplus::InputArguments &input, unsigned int start);
ATMOptions ATMOptionsFromMatlab(mexplus::InputArguments &input, unsigned int start);
ATMOptions ATMOptionsFromMatlab(std::string basisType, std::string posFuncType,
std::string quadType, double quadAbsTol,
double quadRelTol, unsigned int quadMaxSub,
unsigned int quadMinSub,unsigned int quadPts,
bool contDeriv, double basisLB, double basisUB, bool basisNorm,
std::string opt_alg, double opt_stopval,
double opt_ftol_rel, double opt_ftol_abs,
double opt_xtol_rel, double opt_xtol_abs,
int opt_maxeval, double opt_maxtime, int verbose,
unsigned int maxPatience, unsigned int maxSize, MultiIndex& maxDegrees);
#endif // defined(MPART_HAS_NLOPT)
}
}


#endif
#endif
44 changes: 44 additions & 0 deletions bindings/matlab/mat/ATMOptions.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
classdef ATMOptions < TrainOptions & MapOptions
properties (Access = public)
maxPatience = 10;
maxSize = 10;
maxDegrees = MultiIndex(0);
end
methods
function obj = set.maxPatience(obj,value)
obj.maxPatience = value;
end
function obj = set.maxSize(obj,value)
obj.maxSize = value;
end
function obj = set.maxDegrees(obj,value)
obj.maxDegrees = value;
end
function optionsArray = getMexOptions(obj)
optionsArray{1} = char(obj.basisType);
optionsArray{2} = char(obj.posFuncType);
optionsArray{3} = char(obj.quadType);
optionsArray{4} = obj.quadAbsTol;
optionsArray{5} = obj.quadRelTol;
optionsArray{6} = obj.quadMaxSub;
optionsArray{7} = obj.quadMinSub;
optionsArray{8} = obj.quadPts;
optionsArray{9} = obj.contDeriv;
optionsArray{10} = obj.basisLB;
optionsArray{11} = obj.basisUB;
optionsArray{12} = obj.basisNorm;
optionsArray{12+1} = char(obj.opt_alg);
optionsArray{12+2} = obj.opt_stopval;
optionsArray{12+3} = obj.opt_ftol_rel;
optionsArray{12+4} = obj.opt_ftol_abs;
optionsArray{12+5} = obj.opt_xtol_rel;
optionsArray{12+6} = obj.opt_xtol_abs;
optionsArray{12+7} = obj.opt_maxeval;
optionsArray{12+8} = obj.opt_maxtime;
optionsArray{12+9} = obj.verbose;
optionsArray{12+9+1} = obj.maxPatience;
optionsArray{12+9+2} = obj.maxSize;
optionsArray{12+9+3} = obj.maxDegrees.get_id();
end
end
end
12 changes: 12 additions & 0 deletions bindings/matlab/mat/AdaptiveTransportMap.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
function map = TrainMapAdaptive(mset0, objective, atm_options)
mexOptions = atm_options.getMexOptions;
mset_ids = arrayfun(@(mset) mset.get_id(), mset0);
input_str=['map_ptr = MParT_(',char(39),'ConditionalMap_TrainMapAdaptive',char(39),',mset_ids,objective.get_id()'];
for o=1:length(mexOptions)
input_o=[',mexOptions{',num2str(o),'}'];
input_str=[input_str,input_o];
end
input_str=[input_str,');'];
eval(input_str);
map = ConditionalMap(map_ptr,"id");
end
10 changes: 7 additions & 3 deletions bindings/matlab/mat/GaussianKLObjective.m
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@
end

methods
function this = GaussianKLObjective(varargin)
function this = GaussianKLObjective(train, test, dim)
if(nargin==1)
this.id_ = MParT_('GaussianKLObjective_newTrain',varargin{1});
this.id_ = MParT_('GaussianKLObjective_newTrain',train,0);
elseif(isinteger(test))
this.id_ = MParT_('GaussianKLObjective_newTrain',train,test);
elseif(nargin==2)
this.id_ = MParT_('GaussianKLObjective_newTrainTest',train,test,0);
else
this.id_ = MParT_('GaussianKLObjective_newTrainTest',varargin{1},varargin{2});
this.id_ = MParT_('GaussianKLObjective_newTrainTest',train,test,dim);
end
end

Expand Down
6 changes: 3 additions & 3 deletions bindings/matlab/mat/MapOptions/MapOptions.m
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
classdef MapOptions
properties (Access = public)
basisType = BasisTypes.ProbabilistHermite;
basisLB = log(0);
basisUB = 1.0/0.0;
basisNorm = true;
posFuncType = PosFuncTypes.SoftPlus;
quadType = QuadTypes.AdaptiveSimpson;
quadAbsTol = 1e-6;
Expand All @@ -12,6 +9,9 @@
quadMinSub = 0;
quadPts = 5;
contDeriv = true;
basisLB = log(0);
basisUB = 1.0/0.0;
basisNorm = true;
end

methods
Expand Down
4 changes: 2 additions & 2 deletions bindings/matlab/mat/TrainOptions.m
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
classdef TrainOptions
properties (Access = public)
opt_alg = "LD_LBFGS"
opt_alg = "LD_SLSQP"
opt_stopval = log(0)
opt_ftol_rel = 1e-3
opt_ftol_abs = 1e-3
opt_xtol_rel = 1e-4
opt_xtol_abs = 1e-4
opt_maxeval = 30
opt_maxtime = 100
opt_maxtime = 1e2
verbose = 0
end

Expand Down
Loading

0 comments on commit fdd4afc

Please sign in to comment.