Skip to content

Commit

Permalink
Speed up 2bpls prediction.
Browse files Browse the repository at this point in the history
  • Loading branch information
akenmorris committed Nov 20, 2023
1 parent f87ebe3 commit 181baf4
Show file tree
Hide file tree
Showing 8 changed files with 154 additions and 69 deletions.
45 changes: 30 additions & 15 deletions Python/shapeworks/shapeworks/shape_scalars.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,22 @@ def get_fig_png():
def run_mbpls(x, y, n_components=3, cv=5):
""" Run MBPLS on shape and scalar data """

model = MBPLS(n_components=n_components)
# don't set cv higher than the number of samples
cv = min(cv, len(x))

global mbpls_model
mbpls_model = MBPLS(n_components=n_components)
if cv != 1:
y_pred = cross_val_predict(mbpls_model, x, y, cv=cv)

mbpls_model.fit(x, y)

if cv == 1:
model.fit(x, y)
y_pred = model.predict(x)
else:
y_pred = cross_val_predict(model, x, y, cv=cv)
y_pred = mbpls_model.predict(x)

mse = mean_squared_error(y, y_pred)

sw_message(f'MSE: {mse}')
sw_message(f'Python MSE: {mse}')

prediction = pd.DataFrame(np.array(y_pred))

Expand Down Expand Up @@ -84,17 +90,26 @@ def run_find_num_components(x, y, max_components, cv=5):
return figdata_png


def pred_from_mbpls(x, y, new_x, n_components=3):
""" Run MBPLS on shape and scalar data, then predict new_y from new_x """
def pred_from_mbpls(new_x, n_components=3):
""" Predict new_y from new_x using existing mbpls fit """

if not does_mbpls_model_exist():
sw_message('MBPLS model does not exist, returning none')
return None

global mbpls_model
y_pred = mbpls_model.predict(new_x)
# return as vector
return y_pred.flatten()

def does_mbpls_model_exist():
""" Check if mbpls model exists """

# check if global variable model exists, otherwise create it
global model
global mbpls_model
try:
model
mbpls_model
except NameError:
model = MBPLS(n_components=n_components)
model.fit(x, y)
return False

y_pred = model.predict(new_x)
# return as vector
return y_pred.flatten()
return True
53 changes: 36 additions & 17 deletions Studio/Analysis/AnalysisTool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ bool AnalysisTool::compute_stats() {
return false;
}

SW_LOG("Compute Stats!");
SW_DEBUG("Compute Stats!");
compute_reconstructed_domain_transforms();

ui_->pcaModeSpinBox->setMaximum(std::max<double>(1, session_->get_shapes().size() - 1));
Expand Down Expand Up @@ -691,11 +691,14 @@ Particles AnalysisTool::get_shape_points(int mode, double value) {

auto positions = temp_shape_;

computed_scalars_ = Eigen::VectorXd();
if (pca_shape_plus_scalar_mode()) {
positions = extract_shape_only(temp_shape_);
temp_scalars_ = extract_scalar_only(temp_shape_);
computed_scalars_ = extract_scalar_only(temp_shape_);
} else if (pca_scalar_only_mode()) {
SW_LOG("Scalar only mode not implemented yet");
computed_scalars_ = temp_shape_;
positions = construct_mean_shape();
}

return convert_from_combined(positions);
Expand Down Expand Up @@ -908,19 +911,13 @@ AnalysisTool::GroupAnalysisType AnalysisTool::get_group_analysis_type() {
}

//---------------------------------------------------------------------------
bool AnalysisTool::pca_scalar_only_mode() {
return ui_->pca_scalar_only->isChecked();
}
bool AnalysisTool::pca_scalar_only_mode() { return ui_->pca_scalar_only->isChecked(); }

//---------------------------------------------------------------------------
bool AnalysisTool::pca_shape_plus_scalar_mode() {
return ui_->pca_shape_and_scalar->isChecked();
}
bool AnalysisTool::pca_shape_plus_scalar_mode() { return ui_->pca_shape_and_scalar->isChecked(); }

//---------------------------------------------------------------------------
bool AnalysisTool::pca_shape_only_mode() {
return ui_->pca_scalar_shape_only->isChecked();
}
bool AnalysisTool::pca_shape_only_mode() { return ui_->pca_scalar_shape_only->isChecked(); }

//---------------------------------------------------------------------------
void AnalysisTool::on_tabWidget_currentChanged() { update_analysis_mode(); }
Expand Down Expand Up @@ -976,6 +973,8 @@ void AnalysisTool::handle_pca_timer() {
}

ui_->pcaSlider->setValue(value);

QApplication::processEvents();
}

//---------------------------------------------------------------------------
Expand Down Expand Up @@ -1244,12 +1243,13 @@ ShapeHandle AnalysisTool::create_shape_from_points(Particles points) {
shape->set_reconstruction_transforms(reconstruction_transforms_);

if (feature_map_ != "") {
// auto scalars = ShapeScalarJob::predict_scalars(session_, QString::fromStdString(feature_map_),
// points.get_combined_global_particles());

// shape->set_point_features(feature_map_, scalars);

shape->set_point_features(feature_map_, temp_scalars_);
if (ui_->pca_predict_scalar->isChecked()) {
auto scalars = ShapeScalarJob::predict_scalars(session_, QString::fromStdString(feature_map_),
points.get_combined_global_particles());
shape->set_point_features(feature_map_, scalars);
} else {
shape->set_point_features(feature_map_, computed_scalars_);
}
}
return shape;
}
Expand Down Expand Up @@ -1729,6 +1729,25 @@ void AnalysisTool::change_pca_analysis_type() {
evals_ready_ = false;
stats_ = ParticleShapeStatistics();
compute_stats();
Q_EMIT pca_update();
}

//---------------------------------------------------------------------------
Eigen::VectorXd AnalysisTool::construct_mean_shape() {
if (session_->get_shapes().empty()) {
return Eigen::VectorXd();
}

Eigen::VectorXd sum_shape =
Eigen::VectorXd::Zero(session_->get_shapes()[0]->get_global_correspondence_points().size());

for (auto& shape : session_->get_shapes()) {
Eigen::VectorXd particles = shape->get_global_correspondence_points();
sum_shape += particles;
}

Eigen::VectorXd mean_shape = sum_shape / session_->get_shapes().size();
return mean_shape;
}

//---------------------------------------------------------------------------
Expand Down
5 changes: 4 additions & 1 deletion Studio/Analysis/AnalysisTool.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,9 @@ class AnalysisTool : public QWidget {

void change_pca_analysis_type();

//! Compute the mean shape outside of the PCA in case we are using scalars only
Eigen::VectorXd construct_mean_shape();

Q_SIGNALS:

void update_view();
Expand Down Expand Up @@ -253,7 +256,7 @@ class AnalysisTool : public QWidget {
Eigen::VectorXd temp_shape_mca;
std::vector<int> number_of_particles_array_;

Eigen::VectorXd temp_scalars_;
Eigen::VectorXd computed_scalars_;

bool pca_animate_direction_ = true;
QTimer pca_animate_timer_;
Expand Down
8 changes: 7 additions & 1 deletion Studio/Analysis/AnalysisTool.ui
Original file line number Diff line number Diff line change
Expand Up @@ -1401,7 +1401,13 @@ QWidget#particles_panel {
</widget>
</item>
<item row="0" column="1">
<widget class="QComboBox" name="pca_scalar_combo"/>
<widget class="QComboBox" name="pca_scalar_combo">
<property name="font">
<font>
<family>.AppleSystemUIFont</family>
</font>
</property>
</widget>
</item>
<item row="1" column="0">
<widget class="QRadioButton" name="pca_scalar_shape_only">
Expand Down
7 changes: 7 additions & 0 deletions Studio/Job/Job.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ class Job : public QObject {
//! was the job aborted?
bool is_aborted() const { return abort_; }

//! set to quiet mode (no progress messages)
void set_quiet_mode(bool quiet) { quiet_mode_ = quiet; }

//! get quiet mode
bool get_quiet_mode() { return quiet_mode_; }

public Q_SLOTS:

Q_SIGNALS:
Expand All @@ -51,6 +57,7 @@ class Job : public QObject {
private:
std::atomic<bool> complete_ = false;
std::atomic<bool> abort_ = false;
std::atomic<bool> quiet_mode_ = false;

QElapsedTimer timer_;
};
Expand Down
95 changes: 62 additions & 33 deletions Studio/Job/ShapeScalarJob.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <pybind11/stl.h>

#include <Eigen/Dense>
#include <QApplication>
#include <QImage>

namespace py = pybind11;
Expand All @@ -22,46 +23,22 @@ ShapeScalarJob::ShapeScalarJob(QSharedPointer<Session> session, QString target_f

//---------------------------------------------------------------------------
void ShapeScalarJob::run() {
SW_DEBUG("Running shape scalar job");
// SW_DEBUG("Running shape scalar job");

try {
prep_data();

py::module np = py::module::import("numpy");
py::object A = np.attr("array")(all_particles_);
py::object B = np.attr("array")(all_scalars_);

py::module sw = py::module::import("shapeworks");

if (job_type_ == JobType::MSE_Plot) {
// returns a tuple of (png_raw_bytes, y_pred, mse)
using ResultType = std::tuple<py::array, Eigen::MatrixXd, double>;

py::object run_mbpls = sw.attr("shape_scalars").attr("run_mbpls");
ResultType result = run_mbpls(A, B, num_components_, num_folds_).cast<ResultType>();

py::array png_raw_bytes = std::get<0>(result);
Eigen::MatrixXd y_pred = std::get<1>(result);
double mse = std::get<2>(result);

// interpret png_raw_bytes as a QImage
QImage image;
image.loadFromData((const uchar*)png_raw_bytes.data(), png_raw_bytes.size(), "PNG");
plot_ = QPixmap::fromImage(image);

SW_LOG("mse = {}", mse);

run_fit();
} else if (job_type_ == JobType::Predict) {
py::object new_x = np.attr("array")(target_particles_.transpose());
py::object run_prediction = sw.attr("shape_scalars").attr("pred_from_mbpls");

using ResultType = Eigen::VectorXd;
ResultType result = run_prediction(A, B, new_x).cast<ResultType>();

auto y_pred = result;

prediction_ = y_pred;
run_prediction();
} else if (job_type_ == JobType::Find_Components) {
prep_data();
py::object A = np.attr("array")(all_particles_);
py::object B = np.attr("array")(all_scalars_);

// returns a tuple of (png_raw_bytes, y_pred, mse)
using ResultType = py::array;

Expand All @@ -75,7 +52,7 @@ void ShapeScalarJob::run() {
image.loadFromData((const uchar*)png_raw_bytes.data(), png_raw_bytes.size(), "PNG");
plot_ = QPixmap::fromImage(image);
}
SW_DEBUG("End shape scalar job");
// SW_DEBUG("End shape scalar job");

} catch (const std::exception& e) {
SW_ERROR("Exception in shape scalar job: {}", e.what());
Expand All @@ -94,6 +71,7 @@ Eigen::VectorXd ShapeScalarJob::predict_scalars(QSharedPointer<Session> session,
// blocking call to predict scalars for given target particles

auto job = QSharedPointer<ShapeScalarJob>::create(session, target_feature, target_particles, JobType::Predict);
job->set_quiet_mode(true);

Eigen::VectorXd prediction;

Expand All @@ -106,8 +84,9 @@ Eigen::VectorXd ShapeScalarJob::predict_scalars(QSharedPointer<Session> session,

session->get_py_worker()->run_job(job);

// wait for job to finish without using sleep
while (!finished) {
QThread::msleep(100);
QApplication::processEvents();
}

return prediction;
Expand All @@ -119,5 +98,55 @@ void ShapeScalarJob::prep_data() {
all_scalars_ = session_->get_all_scalars(target_feature_.toStdString());
}

//---------------------------------------------------------------------------
void ShapeScalarJob::run_fit() {
prep_data();
py::module np = py::module::import("numpy");
py::module sw = py::module::import("shapeworks");

py::object A = np.attr("array")(all_particles_);
py::object B = np.attr("array")(all_scalars_);

// returns a tuple of (png_raw_bytes, y_pred, mse)
using ResultType = std::tuple<py::array, Eigen::MatrixXd, double>;

py::object run_mbpls = sw.attr("shape_scalars").attr("run_mbpls");
ResultType result = run_mbpls(A, B, num_components_, num_folds_).cast<ResultType>();

py::array png_raw_bytes = std::get<0>(result);
Eigen::MatrixXd y_pred = std::get<1>(result);
double mse = std::get<2>(result);

// interpret png_raw_bytes as a QImage
QImage image;
image.loadFromData((const uchar*)png_raw_bytes.data(), png_raw_bytes.size(), "PNG");
plot_ = QPixmap::fromImage(image);

SW_LOG("mse = {}", mse);
}

//---------------------------------------------------------------------------
void ShapeScalarJob::run_prediction() {
py::module np = py::module::import("numpy");
py::module sw = py::module::import("shapeworks");

py::object does_mbpls_model_exist = sw.attr("shape_scalars").attr("does_mbpls_model_exist");
if (!does_mbpls_model_exist().cast<bool>()) {
SW_LOG("No MBPLS model exists, running fit");
run_fit();
}

py::object new_x = np.attr("array")(target_particles_.transpose());
py::object run_prediction = sw.attr("shape_scalars").attr("pred_from_mbpls");

using ResultType = Eigen::VectorXd;

ResultType result = run_prediction(new_x).cast<ResultType>();

auto y_pred = result;

prediction_ = y_pred;
}

//---------------------------------------------------------------------------
} // namespace shapeworks
2 changes: 2 additions & 0 deletions Studio/Job/ShapeScalarJob.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ class ShapeScalarJob : public Job {
private:
void prep_data();

void run_fit();
void run_prediction();

QSharedPointer<Session> session_;

Expand Down
8 changes: 6 additions & 2 deletions Studio/Python/PythonWorker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,16 @@ void PythonWorker::start_job(QSharedPointer<Job> job) {
if (init()) {
try {
job->start_timer();
SW_LOG("Running Task: " + job->name().toStdString());
if (!job->get_quiet_mode()) {
SW_LOG("Running Task: " + job->name().toStdString());
}
Q_EMIT job->progress(0);
current_job_ = job;
current_job_->run();
current_job_->set_complete(true);
SW_LOG(current_job_->get_completion_message().toStdString());
if (!job->get_quiet_mode()) {
SW_LOG(current_job_->get_completion_message().toStdString());
}
} catch (py::error_already_set& e) {
SW_ERROR(e.what());
}
Expand Down

0 comments on commit 181baf4

Please sign in to comment.