Skip to content

Commit

Permalink
Remaining pieces for pca shape scalar
Browse files Browse the repository at this point in the history
  • Loading branch information
akenmorris committed Nov 21, 2023
1 parent bed1237 commit 7e2cf17
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 13 deletions.
16 changes: 14 additions & 2 deletions Python/shapeworks/shapeworks/shape_scalars.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def run_find_num_components(x, y, max_components, cv=5):
return figdata_png


def pred_from_mbpls(new_x, n_components=3):
def pred_from_mbpls(new_x):
""" Predict new_y from new_x using existing mbpls fit """

if not does_mbpls_model_exist():
Expand All @@ -112,4 +112,16 @@ def does_mbpls_model_exist():
except NameError:
return False

return True
return True

def clear_mbpls_model():
""" Clear mbpls model """

global mbpls_model
try:
mbpls_model
except NameError:
return

del mbpls_model
return
34 changes: 32 additions & 2 deletions Studio/Analysis/AnalysisTool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,23 @@ void AnalysisTool::network_analysis_clicked() {
app_->get_py_worker()->run_job(network_analysis_job_);
}

//-----------------------------------------------------------------------------
Eigen::VectorXd AnalysisTool::extract_positions(Eigen::VectorXd& data) {
/*
auto positions = data;
if (pca_shape_plus_scalar_mode()) {
positions = extract_shape_only(data);
} else if (pca_scalar_only_mode()) {
computed_scalars_ = temp_shape_;
if (ui_->pca_predict_shape->isChecked()) {
positions = ShapeScalarJob::predict_shape(session_, QString::fromStdString(feature_map_), computed_scalars_);
} else {
positions = construct_mean_shape();
}
}*/
}

//-----------------------------------------------------------------------------
bool AnalysisTool::compute_stats() {
if (stats_ready_) {
Expand Down Expand Up @@ -637,6 +654,10 @@ Particles AnalysisTool::get_mean_shape_points() {
return Particles();
}

if (ui_->pca_scalar_only->isChecked()) {
return convert_from_combined(construct_mean_shape());
}

if (ui_->group1_button->isChecked() || ui_->difference_button->isChecked()) {
return convert_from_combined(stats_.get_group1_mean());
} else if (ui_->group2_button->isChecked()) {
Expand Down Expand Up @@ -700,7 +721,11 @@ Particles AnalysisTool::get_shape_points(int mode, double value) {
computed_scalars_ = extract_scalar_only(temp_shape_);
} else if (pca_scalar_only_mode()) {
computed_scalars_ = temp_shape_;
positions = construct_mean_shape();
if (ui_->pca_predict_shape->isChecked()) {
positions = ShapeScalarJob::predict_shape(session_, QString::fromStdString(feature_map_), computed_scalars_);
} else {
positions = construct_mean_shape();
}
}

return convert_from_combined(positions);
Expand Down Expand Up @@ -1245,7 +1270,7 @@ ShapeHandle AnalysisTool::create_shape_from_points(Particles points) {
shape->set_reconstruction_transforms(reconstruction_transforms_);

if (feature_map_ != "") {
if (ui_->pca_predict_scalar->isChecked()) {
if (ui_->pca_scalar_shape_only->isChecked() && ui_->pca_predict_scalar->isChecked()) {
auto scalars = ShapeScalarJob::predict_scalars(session_, QString::fromStdString(feature_map_),
points.get_combined_global_particles());
shape->set_point_features(feature_map_, scalars);
Expand Down Expand Up @@ -1730,6 +1755,11 @@ void AnalysisTool::change_pca_analysis_type() {
stats_ready_ = false;
evals_ready_ = false;
stats_ = ParticleShapeStatistics();
ShapeScalarJob::clear_model();

ui_->pca_predict_scalar->setEnabled(ui_->pca_scalar_shape_only->isChecked());
ui_->pca_predict_shape->setEnabled(ui_->pca_scalar_only->isChecked());

compute_stats();
Q_EMIT pca_update();
}
Expand Down
5 changes: 5 additions & 0 deletions Studio/Analysis/AnalysisTool.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,14 @@ class AnalysisTool : public QWidget {

void change_pca_analysis_type();

//Eigen::VectorXd get_mean_shape();

//! Compute the mean shape outside of the PCA in case we are using scalars only
Eigen::VectorXd construct_mean_shape();


Eigen::VectorXd extract_positions(Eigen::VectorXd& data);

Q_SIGNALS:

void update_view();
Expand Down
4 changes: 2 additions & 2 deletions Studio/Analysis/AnalysisTool.ui
Original file line number Diff line number Diff line change
Expand Up @@ -1451,14 +1451,14 @@ QWidget#particles_panel {
<item row="1" column="1">
<widget class="QCheckBox" name="pca_predict_scalar">
<property name="text">
<string>Predict scalar</string>
<string>Predict Scalar</string>
</property>
</widget>
</item>
<item row="2" column="1">
<widget class="QCheckBox" name="pca_predict_shape">
<property name="text">
<string>Predict shape</string>
<string>Predict Shape</string>
</property>
</widget>
</item>
Expand Down
35 changes: 29 additions & 6 deletions Studio/Job/ShapeScalarJob.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ using namespace pybind11::literals; // to bring in the `_a` literal

namespace shapeworks {

std::atomic<bool> ShapeScalarJob::needs_clear_ = false;

//---------------------------------------------------------------------------
ShapeScalarJob::ShapeScalarJob(QSharedPointer<Session> session, QString target_feature,
Eigen::MatrixXd target_particles, JobType job_type)
: session_(session), target_feature_(target_feature), target_particles_(target_particles), job_type_(job_type) {}
: session_(session), target_feature_(target_feature), target_values_(target_particles), job_type_(job_type) {}

//---------------------------------------------------------------------------
void ShapeScalarJob::run() {
Expand Down Expand Up @@ -68,10 +70,23 @@ QPixmap ShapeScalarJob::get_plot() { return plot_; }
//---------------------------------------------------------------------------
Eigen::VectorXd ShapeScalarJob::predict_scalars(QSharedPointer<Session> session, QString target_feature,
Eigen::MatrixXd target_particles) {
return predict(session, target_feature, target_particles, Direction::To_Scalar);
}

//---------------------------------------------------------------------------
Eigen::VectorXd ShapeScalarJob::predict_shape(QSharedPointer<Session> session, QString target_feature,
Eigen::MatrixXd target_scalars) {
return predict(session, target_feature, target_scalars, Direction::To_Shape);
}

//---------------------------------------------------------------------------
Eigen::VectorXd ShapeScalarJob::predict(QSharedPointer<Session> session, QString target_feature,
Eigen::MatrixXd target_values, Direction direction) {
// blocking call to predict scalars for given target particles

auto job = QSharedPointer<ShapeScalarJob>::create(session, target_feature, target_particles, JobType::Predict);
auto job = QSharedPointer<ShapeScalarJob>::create(session, target_feature, target_values, JobType::Predict);
job->set_quiet_mode(true);
job->set_direction(direction);

Eigen::VectorXd prediction;

Expand Down Expand Up @@ -104,8 +119,15 @@ void ShapeScalarJob::run_fit() {
py::module np = py::module::import("numpy");
py::module sw = py::module::import("shapeworks");

py::object A = np.attr("array")(all_particles_);
py::object B = np.attr("array")(all_scalars_);
py::object A;
py::object B;
if (direction_ == Direction::To_Scalar) {
A = np.attr("array")(all_particles_);
B = np.attr("array")(all_scalars_);
} else {
A = np.attr("array")(all_scalars_);
B = np.attr("array")(all_particles_);
}

// returns a tuple of (png_raw_bytes, y_pred, mse)
using ResultType = std::tuple<py::array, Eigen::MatrixXd, double>;
Expand All @@ -131,12 +153,13 @@ void ShapeScalarJob::run_prediction() {
py::module sw = py::module::import("shapeworks");

py::object does_mbpls_model_exist = sw.attr("shape_scalars").attr("does_mbpls_model_exist");
if (!does_mbpls_model_exist().cast<bool>()) {
if (needs_clear_ == true || !does_mbpls_model_exist().cast<bool>()) {
SW_LOG("No MBPLS model exists, running fit");
run_fit();
needs_clear_ = false;
}

py::object new_x = np.attr("array")(target_particles_.transpose());
py::object new_x = np.attr("array")(target_values_.transpose());
py::object run_prediction = sw.attr("shape_scalars").attr("pred_from_mbpls");

using ResultType = Eigen::VectorXd;
Expand Down
16 changes: 15 additions & 1 deletion Studio/Job/ShapeScalarJob.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class ShapeScalarJob : public Job {
Q_OBJECT
public:
enum class JobType { Find_Components, MSE_Plot, Predict };
enum class Direction { To_Shape, To_Scalar };

ShapeScalarJob(QSharedPointer<Session> session, QString target_feature, Eigen::MatrixXd target_particles,
JobType job_type);
Expand All @@ -30,12 +31,22 @@ class ShapeScalarJob : public Job {
static Eigen::VectorXd predict_scalars(QSharedPointer<Session> session, QString target_feature,
Eigen::MatrixXd target_particles);

static Eigen::VectorXd predict_shape(QSharedPointer<Session> session, QString target_feature,
Eigen::MatrixXd target_particles);

static void clear_model() { needs_clear_ = true; };

void set_direction(Direction direction) { direction_ = direction; }

private:
void prep_data();

void run_fit();
void run_prediction();

static Eigen::VectorXd predict(QSharedPointer<Session> session, QString target_feature,
Eigen::MatrixXd target_particles, Direction direction);

QSharedPointer<Session> session_;

ParticleShapeStatistics stats_;
Expand All @@ -47,13 +58,16 @@ class ShapeScalarJob : public Job {
Eigen::MatrixXd all_particles_;
Eigen::MatrixXd all_scalars_;

Eigen::MatrixXd target_particles_;
Eigen::MatrixXd target_values_;
Eigen::VectorXd prediction_;

bool num_components_ = 3;
int num_folds_ = 5;
int max_components_ = 20;

Direction direction_{Direction::To_Scalar};
JobType job_type_;

static std::atomic<bool> needs_clear_;
};
} // namespace shapeworks

0 comments on commit 7e2cf17

Please sign in to comment.