diff --git a/Python/shapeworks/shapeworks/shape_scalars.py b/Python/shapeworks/shapeworks/shape_scalars.py
index f6f1246959..495413b637 100644
--- a/Python/shapeworks/shapeworks/shape_scalars.py
+++ b/Python/shapeworks/shapeworks/shape_scalars.py
@@ -90,7 +90,7 @@ def run_find_num_components(x, y, max_components, cv=5):
return figdata_png
-def pred_from_mbpls(new_x, n_components=3):
+def pred_from_mbpls(new_x):
""" Predict new_y from new_x using existing mbpls fit """
if not does_mbpls_model_exist():
@@ -112,4 +112,16 @@ def does_mbpls_model_exist():
except NameError:
return False
- return True
\ No newline at end of file
+ return True
+
+def clear_mbpls_model():
+ """ Clear mbpls model """
+
+ global mbpls_model
+ try:
+ mbpls_model
+ except NameError:
+ return
+
+ del mbpls_model
+ return
\ No newline at end of file
diff --git a/Studio/Analysis/AnalysisTool.cpp b/Studio/Analysis/AnalysisTool.cpp
index 281415af1e..4bab5d8c6b 100644
--- a/Studio/Analysis/AnalysisTool.cpp
+++ b/Studio/Analysis/AnalysisTool.cpp
@@ -463,6 +463,23 @@ void AnalysisTool::network_analysis_clicked() {
app_->get_py_worker()->run_job(network_analysis_job_);
}
+//-----------------------------------------------------------------------------
+Eigen::VectorXd AnalysisTool::extract_positions(Eigen::VectorXd& data) {
+ /*
+ auto positions = data;
+
+ if (pca_shape_plus_scalar_mode()) {
+ positions = extract_shape_only(data);
+ } else if (pca_scalar_only_mode()) {
+ computed_scalars_ = temp_shape_;
+ if (ui_->pca_predict_shape->isChecked()) {
+ positions = ShapeScalarJob::predict_shape(session_, QString::fromStdString(feature_map_), computed_scalars_);
+ } else {
+ positions = construct_mean_shape();
+ }
+ }*/
+}
+
//-----------------------------------------------------------------------------
bool AnalysisTool::compute_stats() {
if (stats_ready_) {
@@ -637,6 +654,10 @@ Particles AnalysisTool::get_mean_shape_points() {
return Particles();
}
+ if (ui_->pca_scalar_only->isChecked()) {
+ return convert_from_combined(construct_mean_shape());
+ }
+
if (ui_->group1_button->isChecked() || ui_->difference_button->isChecked()) {
return convert_from_combined(stats_.get_group1_mean());
} else if (ui_->group2_button->isChecked()) {
@@ -700,7 +721,11 @@ Particles AnalysisTool::get_shape_points(int mode, double value) {
computed_scalars_ = extract_scalar_only(temp_shape_);
} else if (pca_scalar_only_mode()) {
computed_scalars_ = temp_shape_;
- positions = construct_mean_shape();
+ if (ui_->pca_predict_shape->isChecked()) {
+ positions = ShapeScalarJob::predict_shape(session_, QString::fromStdString(feature_map_), computed_scalars_);
+ } else {
+ positions = construct_mean_shape();
+ }
}
return convert_from_combined(positions);
@@ -1245,7 +1270,7 @@ ShapeHandle AnalysisTool::create_shape_from_points(Particles points) {
shape->set_reconstruction_transforms(reconstruction_transforms_);
if (feature_map_ != "") {
- if (ui_->pca_predict_scalar->isChecked()) {
+ if (ui_->pca_scalar_shape_only->isChecked() && ui_->pca_predict_scalar->isChecked()) {
auto scalars = ShapeScalarJob::predict_scalars(session_, QString::fromStdString(feature_map_),
points.get_combined_global_particles());
shape->set_point_features(feature_map_, scalars);
@@ -1730,6 +1755,11 @@ void AnalysisTool::change_pca_analysis_type() {
stats_ready_ = false;
evals_ready_ = false;
stats_ = ParticleShapeStatistics();
+ ShapeScalarJob::clear_model();
+
+ ui_->pca_predict_scalar->setEnabled(ui_->pca_scalar_shape_only->isChecked());
+ ui_->pca_predict_shape->setEnabled(ui_->pca_scalar_only->isChecked());
+
compute_stats();
Q_EMIT pca_update();
}
diff --git a/Studio/Analysis/AnalysisTool.h b/Studio/Analysis/AnalysisTool.h
index a41a1c0d69..67c5f7868f 100644
--- a/Studio/Analysis/AnalysisTool.h
+++ b/Studio/Analysis/AnalysisTool.h
@@ -194,9 +194,14 @@ class AnalysisTool : public QWidget {
void change_pca_analysis_type();
+ //Eigen::VectorXd get_mean_shape();
+
//! Compute the mean shape outside of the PCA in case we are using scalars only
Eigen::VectorXd construct_mean_shape();
+
+ Eigen::VectorXd extract_positions(Eigen::VectorXd& data);
+
Q_SIGNALS:
void update_view();
diff --git a/Studio/Analysis/AnalysisTool.ui b/Studio/Analysis/AnalysisTool.ui
index 4ca9f8d1dc..3d1b62f58e 100644
--- a/Studio/Analysis/AnalysisTool.ui
+++ b/Studio/Analysis/AnalysisTool.ui
@@ -1451,14 +1451,14 @@ QWidget#particles_panel {
-
- Predict scalar
+ Predict Scalar
-
- Predict shape
+ Predict Shape
diff --git a/Studio/Job/ShapeScalarJob.cpp b/Studio/Job/ShapeScalarJob.cpp
index 1ea16fc3a8..64eb08eba7 100644
--- a/Studio/Job/ShapeScalarJob.cpp
+++ b/Studio/Job/ShapeScalarJob.cpp
@@ -16,10 +16,12 @@ using namespace pybind11::literals; // to bring in the `_a` literal
namespace shapeworks {
+std::atomic ShapeScalarJob::needs_clear_ = false;
+
//---------------------------------------------------------------------------
ShapeScalarJob::ShapeScalarJob(QSharedPointer session, QString target_feature,
Eigen::MatrixXd target_particles, JobType job_type)
- : session_(session), target_feature_(target_feature), target_particles_(target_particles), job_type_(job_type) {}
+ : session_(session), target_feature_(target_feature), target_values_(target_particles), job_type_(job_type) {}
//---------------------------------------------------------------------------
void ShapeScalarJob::run() {
@@ -68,10 +70,23 @@ QPixmap ShapeScalarJob::get_plot() { return plot_; }
//---------------------------------------------------------------------------
Eigen::VectorXd ShapeScalarJob::predict_scalars(QSharedPointer session, QString target_feature,
Eigen::MatrixXd target_particles) {
+ return predict(session, target_feature, target_particles, Direction::To_Scalar);
+}
+
+//---------------------------------------------------------------------------
+Eigen::VectorXd ShapeScalarJob::predict_shape(QSharedPointer session, QString target_feature,
+ Eigen::MatrixXd target_scalars) {
+ return predict(session, target_feature, target_scalars, Direction::To_Shape);
+}
+
+//---------------------------------------------------------------------------
+Eigen::VectorXd ShapeScalarJob::predict(QSharedPointer session, QString target_feature,
+ Eigen::MatrixXd target_values, Direction direction) {
// blocking call to predict scalars for given target particles
- auto job = QSharedPointer::create(session, target_feature, target_particles, JobType::Predict);
+ auto job = QSharedPointer::create(session, target_feature, target_values, JobType::Predict);
job->set_quiet_mode(true);
+ job->set_direction(direction);
Eigen::VectorXd prediction;
@@ -104,8 +119,15 @@ void ShapeScalarJob::run_fit() {
py::module np = py::module::import("numpy");
py::module sw = py::module::import("shapeworks");
- py::object A = np.attr("array")(all_particles_);
- py::object B = np.attr("array")(all_scalars_);
+ py::object A;
+ py::object B;
+ if (direction_ == Direction::To_Scalar) {
+ A = np.attr("array")(all_particles_);
+ B = np.attr("array")(all_scalars_);
+ } else {
+ A = np.attr("array")(all_scalars_);
+ B = np.attr("array")(all_particles_);
+ }
// returns a tuple of (png_raw_bytes, y_pred, mse)
using ResultType = std::tuple;
@@ -131,12 +153,13 @@ void ShapeScalarJob::run_prediction() {
py::module sw = py::module::import("shapeworks");
py::object does_mbpls_model_exist = sw.attr("shape_scalars").attr("does_mbpls_model_exist");
- if (!does_mbpls_model_exist().cast()) {
+ if (needs_clear_ == true || !does_mbpls_model_exist().cast()) {
SW_LOG("No MBPLS model exists, running fit");
run_fit();
+ needs_clear_ = false;
}
- py::object new_x = np.attr("array")(target_particles_.transpose());
+ py::object new_x = np.attr("array")(target_values_.transpose());
py::object run_prediction = sw.attr("shape_scalars").attr("pred_from_mbpls");
using ResultType = Eigen::VectorXd;
diff --git a/Studio/Job/ShapeScalarJob.h b/Studio/Job/ShapeScalarJob.h
index 9bd8eb3b88..d1607f1397 100644
--- a/Studio/Job/ShapeScalarJob.h
+++ b/Studio/Job/ShapeScalarJob.h
@@ -13,6 +13,7 @@ class ShapeScalarJob : public Job {
Q_OBJECT
public:
enum class JobType { Find_Components, MSE_Plot, Predict };
+ enum class Direction { To_Shape, To_Scalar };
ShapeScalarJob(QSharedPointer session, QString target_feature, Eigen::MatrixXd target_particles,
JobType job_type);
@@ -30,12 +31,22 @@ class ShapeScalarJob : public Job {
static Eigen::VectorXd predict_scalars(QSharedPointer session, QString target_feature,
Eigen::MatrixXd target_particles);
+ static Eigen::VectorXd predict_shape(QSharedPointer session, QString target_feature,
+ Eigen::MatrixXd target_particles);
+
+ static void clear_model() { needs_clear_ = true; };
+
+ void set_direction(Direction direction) { direction_ = direction; }
+
private:
void prep_data();
void run_fit();
void run_prediction();
+ static Eigen::VectorXd predict(QSharedPointer session, QString target_feature,
+ Eigen::MatrixXd target_particles, Direction direction);
+
QSharedPointer session_;
ParticleShapeStatistics stats_;
@@ -47,13 +58,16 @@ class ShapeScalarJob : public Job {
Eigen::MatrixXd all_particles_;
Eigen::MatrixXd all_scalars_;
- Eigen::MatrixXd target_particles_;
+ Eigen::MatrixXd target_values_;
Eigen::VectorXd prediction_;
bool num_components_ = 3;
int num_folds_ = 5;
int max_components_ = 20;
+ Direction direction_{Direction::To_Scalar};
JobType job_type_;
+
+ static std::atomic needs_clear_;
};
} // namespace shapeworks