diff --git a/Python/shapeworks/shapeworks/shape_scalars.py b/Python/shapeworks/shapeworks/shape_scalars.py index f6f1246959..495413b637 100644 --- a/Python/shapeworks/shapeworks/shape_scalars.py +++ b/Python/shapeworks/shapeworks/shape_scalars.py @@ -90,7 +90,7 @@ def run_find_num_components(x, y, max_components, cv=5): return figdata_png -def pred_from_mbpls(new_x, n_components=3): +def pred_from_mbpls(new_x): """ Predict new_y from new_x using existing mbpls fit """ if not does_mbpls_model_exist(): @@ -112,4 +112,16 @@ def does_mbpls_model_exist(): except NameError: return False - return True \ No newline at end of file + return True + +def clear_mbpls_model(): + """ Clear mbpls model """ + + global mbpls_model + try: + mbpls_model + except NameError: + return + + del mbpls_model + return \ No newline at end of file diff --git a/Studio/Analysis/AnalysisTool.cpp b/Studio/Analysis/AnalysisTool.cpp index 281415af1e..4bab5d8c6b 100644 --- a/Studio/Analysis/AnalysisTool.cpp +++ b/Studio/Analysis/AnalysisTool.cpp @@ -463,6 +463,23 @@ void AnalysisTool::network_analysis_clicked() { app_->get_py_worker()->run_job(network_analysis_job_); } +//----------------------------------------------------------------------------- +Eigen::VectorXd AnalysisTool::extract_positions(Eigen::VectorXd& data) { + /* + auto positions = data; + + if (pca_shape_plus_scalar_mode()) { + positions = extract_shape_only(data); + } else if (pca_scalar_only_mode()) { + computed_scalars_ = temp_shape_; + if (ui_->pca_predict_shape->isChecked()) { + positions = ShapeScalarJob::predict_shape(session_, QString::fromStdString(feature_map_), computed_scalars_); + } else { + positions = construct_mean_shape(); + } + }*/ +} + //----------------------------------------------------------------------------- bool AnalysisTool::compute_stats() { if (stats_ready_) { @@ -637,6 +654,10 @@ Particles AnalysisTool::get_mean_shape_points() { return Particles(); } + if (ui_->pca_scalar_only->isChecked()) { + return convert_from_combined(construct_mean_shape()); + } + if (ui_->group1_button->isChecked() || ui_->difference_button->isChecked()) { return convert_from_combined(stats_.get_group1_mean()); } else if (ui_->group2_button->isChecked()) { @@ -700,7 +721,11 @@ Particles AnalysisTool::get_shape_points(int mode, double value) { computed_scalars_ = extract_scalar_only(temp_shape_); } else if (pca_scalar_only_mode()) { computed_scalars_ = temp_shape_; - positions = construct_mean_shape(); + if (ui_->pca_predict_shape->isChecked()) { + positions = ShapeScalarJob::predict_shape(session_, QString::fromStdString(feature_map_), computed_scalars_); + } else { + positions = construct_mean_shape(); + } } return convert_from_combined(positions); @@ -1245,7 +1270,7 @@ ShapeHandle AnalysisTool::create_shape_from_points(Particles points) { shape->set_reconstruction_transforms(reconstruction_transforms_); if (feature_map_ != "") { - if (ui_->pca_predict_scalar->isChecked()) { + if (ui_->pca_scalar_shape_only->isChecked() && ui_->pca_predict_scalar->isChecked()) { auto scalars = ShapeScalarJob::predict_scalars(session_, QString::fromStdString(feature_map_), points.get_combined_global_particles()); shape->set_point_features(feature_map_, scalars); @@ -1730,6 +1755,11 @@ void AnalysisTool::change_pca_analysis_type() { stats_ready_ = false; evals_ready_ = false; stats_ = ParticleShapeStatistics(); + ShapeScalarJob::clear_model(); + + ui_->pca_predict_scalar->setEnabled(ui_->pca_scalar_shape_only->isChecked()); + ui_->pca_predict_shape->setEnabled(ui_->pca_scalar_only->isChecked()); + compute_stats(); Q_EMIT pca_update(); } diff --git a/Studio/Analysis/AnalysisTool.h b/Studio/Analysis/AnalysisTool.h index a41a1c0d69..67c5f7868f 100644 --- a/Studio/Analysis/AnalysisTool.h +++ b/Studio/Analysis/AnalysisTool.h @@ -194,9 +194,14 @@ class AnalysisTool : public QWidget { void change_pca_analysis_type(); + //Eigen::VectorXd get_mean_shape(); + //! Compute the mean shape outside of the PCA in case we are using scalars only Eigen::VectorXd construct_mean_shape(); + + Eigen::VectorXd extract_positions(Eigen::VectorXd& data); + Q_SIGNALS: void update_view(); diff --git a/Studio/Analysis/AnalysisTool.ui b/Studio/Analysis/AnalysisTool.ui index 4ca9f8d1dc..3d1b62f58e 100644 --- a/Studio/Analysis/AnalysisTool.ui +++ b/Studio/Analysis/AnalysisTool.ui @@ -1451,14 +1451,14 @@ QWidget#particles_panel { - Predict scalar + Predict Scalar - Predict shape + Predict Shape diff --git a/Studio/Job/ShapeScalarJob.cpp b/Studio/Job/ShapeScalarJob.cpp index 1ea16fc3a8..64eb08eba7 100644 --- a/Studio/Job/ShapeScalarJob.cpp +++ b/Studio/Job/ShapeScalarJob.cpp @@ -16,10 +16,12 @@ using namespace pybind11::literals; // to bring in the `_a` literal namespace shapeworks { +std::atomic ShapeScalarJob::needs_clear_ = false; + //--------------------------------------------------------------------------- ShapeScalarJob::ShapeScalarJob(QSharedPointer session, QString target_feature, Eigen::MatrixXd target_particles, JobType job_type) - : session_(session), target_feature_(target_feature), target_particles_(target_particles), job_type_(job_type) {} + : session_(session), target_feature_(target_feature), target_values_(target_particles), job_type_(job_type) {} //--------------------------------------------------------------------------- void ShapeScalarJob::run() { @@ -68,10 +70,23 @@ QPixmap ShapeScalarJob::get_plot() { return plot_; } //--------------------------------------------------------------------------- Eigen::VectorXd ShapeScalarJob::predict_scalars(QSharedPointer session, QString target_feature, Eigen::MatrixXd target_particles) { + return predict(session, target_feature, target_particles, Direction::To_Scalar); +} + +//--------------------------------------------------------------------------- +Eigen::VectorXd ShapeScalarJob::predict_shape(QSharedPointer session, QString target_feature, + Eigen::MatrixXd target_scalars) { + return predict(session, target_feature, target_scalars, Direction::To_Shape); +} + +//--------------------------------------------------------------------------- +Eigen::VectorXd ShapeScalarJob::predict(QSharedPointer session, QString target_feature, + Eigen::MatrixXd target_values, Direction direction) { // blocking call to predict scalars for given target particles - auto job = QSharedPointer::create(session, target_feature, target_particles, JobType::Predict); + auto job = QSharedPointer::create(session, target_feature, target_values, JobType::Predict); job->set_quiet_mode(true); + job->set_direction(direction); Eigen::VectorXd prediction; @@ -104,8 +119,15 @@ void ShapeScalarJob::run_fit() { py::module np = py::module::import("numpy"); py::module sw = py::module::import("shapeworks"); - py::object A = np.attr("array")(all_particles_); - py::object B = np.attr("array")(all_scalars_); + py::object A; + py::object B; + if (direction_ == Direction::To_Scalar) { + A = np.attr("array")(all_particles_); + B = np.attr("array")(all_scalars_); + } else { + A = np.attr("array")(all_scalars_); + B = np.attr("array")(all_particles_); + } // returns a tuple of (png_raw_bytes, y_pred, mse) using ResultType = std::tuple; @@ -131,12 +153,13 @@ void ShapeScalarJob::run_prediction() { py::module sw = py::module::import("shapeworks"); py::object does_mbpls_model_exist = sw.attr("shape_scalars").attr("does_mbpls_model_exist"); - if (!does_mbpls_model_exist().cast()) { + if (needs_clear_ == true || !does_mbpls_model_exist().cast()) { SW_LOG("No MBPLS model exists, running fit"); run_fit(); + needs_clear_ = false; } - py::object new_x = np.attr("array")(target_particles_.transpose()); + py::object new_x = np.attr("array")(target_values_.transpose()); py::object run_prediction = sw.attr("shape_scalars").attr("pred_from_mbpls"); using ResultType = Eigen::VectorXd; diff --git a/Studio/Job/ShapeScalarJob.h b/Studio/Job/ShapeScalarJob.h index 9bd8eb3b88..d1607f1397 100644 --- a/Studio/Job/ShapeScalarJob.h +++ b/Studio/Job/ShapeScalarJob.h @@ -13,6 +13,7 @@ class ShapeScalarJob : public Job { Q_OBJECT public: enum class JobType { Find_Components, MSE_Plot, Predict }; + enum class Direction { To_Shape, To_Scalar }; ShapeScalarJob(QSharedPointer session, QString target_feature, Eigen::MatrixXd target_particles, JobType job_type); @@ -30,12 +31,22 @@ class ShapeScalarJob : public Job { static Eigen::VectorXd predict_scalars(QSharedPointer session, QString target_feature, Eigen::MatrixXd target_particles); + static Eigen::VectorXd predict_shape(QSharedPointer session, QString target_feature, + Eigen::MatrixXd target_particles); + + static void clear_model() { needs_clear_ = true; }; + + void set_direction(Direction direction) { direction_ = direction; } + private: void prep_data(); void run_fit(); void run_prediction(); + static Eigen::VectorXd predict(QSharedPointer session, QString target_feature, + Eigen::MatrixXd target_particles, Direction direction); + QSharedPointer session_; ParticleShapeStatistics stats_; @@ -47,13 +58,16 @@ class ShapeScalarJob : public Job { Eigen::MatrixXd all_particles_; Eigen::MatrixXd all_scalars_; - Eigen::MatrixXd target_particles_; + Eigen::MatrixXd target_values_; Eigen::VectorXd prediction_; bool num_components_ = 3; int num_folds_ = 5; int max_components_ = 20; + Direction direction_{Direction::To_Scalar}; JobType job_type_; + + static std::atomic needs_clear_; }; } // namespace shapeworks