diff --git a/Examples/Python/deep_ssm.py b/Examples/Python/deep_ssm.py index 5dc5abb31a..d0b6b11f8c 100644 --- a/Examples/Python/deep_ssm.py +++ b/Examples/Python/deep_ssm.py @@ -107,7 +107,7 @@ def Run_Pipeline(args): The required grooming steps are: 1. Load mesh 2. Apply clipping with planes for finding alignment transform - 3. Find reflection tansfrom + 3. Find reflection transfrom 4. Select reference mesh 5. Find rigid alignment transform For more information on grooming see docs/workflow/groom.md @@ -479,7 +479,8 @@ def Run_Pipeline(args): transform = [ train_transforms[i].flatten() ] subject.set_groomed_transforms(transform) subject.set_constraints_filenames(rel_plane_files) - subject.set_landmarks_filenames([train_local_particles[i]]) + subject.set_local_particle_filenames([train_local_particles[i]]) + subject.set_world_particle_filenames([train_local_particles[i]]) subject.set_extra_values({"fixed": "yes"}) subjects.append(subject) # Add new validation shapes @@ -500,7 +501,8 @@ def Run_Pipeline(args): transform = [ val_transforms[i].flatten() ] subject.set_groomed_transforms(transform) subject.set_constraints_filenames(rel_plane_files) - subject.set_landmarks_filenames(rel_particle_files) + subject.set_local_particle_filenames(rel_particle_files) + subject.set_world_particle_filenames(rel_particle_files) subject.set_extra_values({"fixed": "no"}) subjects.append(subject) project = sw.Project() @@ -512,11 +514,7 @@ def Run_Pipeline(args): parameter_dictionary["procrustes"] = 0 parameter_dictionary["procrustes_interval"] = 0 parameter_dictionary["procrustes_scaling"] = 0 - parameter_dictionary["use_landmarks"] = 1 - parameter_dictionary["use_fixed_subjects"] = 1 parameter_dictionary["narrow_band"] = 1e10 - parameter_dictionary["fixed_subjects_column"] = "fixed" - parameter_dictionary["fixed_subjects_choice"] = "yes" for key in parameter_dictionary: parameters.set(key, sw.Variant(parameter_dictionary[key])) project.set_parameters("optimize", parameters) diff --git a/Examples/Python/ellipsoid_fd.py b/Examples/Python/ellipsoid_fd.py index 4f3e9e1676..8d3b512a4d 100644 --- a/Examples/Python/ellipsoid_fd.py +++ b/Examples/Python/ellipsoid_fd.py @@ -123,7 +123,8 @@ def Run_Pipeline(args): rel_particle_files = sw.utils.get_relative_paths([os.getcwd() + "/" + fixed_local_particles[i]], project_location) subject.set_original_filenames(original_groom_files) subject.set_groomed_filenames(rel_groom_files) - subject.set_landmarks_filenames(rel_particle_files) + subject.set_local_particle_filenames(rel_particle_files) + subject.set_world_particle_filenames(rel_particle_files) subject.set_extra_values({"fixed": "yes"}) subjects.append(subject) @@ -135,7 +136,8 @@ def Run_Pipeline(args): rel_particle_files = sw.utils.get_relative_paths([os.getcwd() + "/" + mean_shape_path], project_location) subject.set_original_filenames(original_groom_files) subject.set_groomed_filenames(rel_groom_files) - subject.set_landmarks_filenames(rel_particle_files) + subject.set_local_particle_filenames(rel_particle_files) + subject.set_world_particle_filenames(rel_particle_files) subject.set_extra_values({"fixed": "no"}) subjects.append(subject) @@ -160,11 +162,7 @@ def Run_Pipeline(args): "procrustes_scaling": 0, "save_init_splits": 0, "verbosity": 0, - "use_landmarks": 1, - "use_fixed_subjects": 1, "narrow_band": 1e10, - "fixed_subjects_column": "fixed", - "fixed_subjects_choice": "yes" } for key in parameter_dictionary: @@ -186,6 +184,7 @@ def Run_Pipeline(args): # Run optimization optimize_cmd = ('shapeworks optimize --progress --name ' + spreadsheet_file).split() + print(optimize_cmd) subprocess.check_call(optimize_cmd) # If tiny test or verify, check results and exit diff --git a/Libs/Analyze/Shape.cpp b/Libs/Analyze/Shape.cpp index be00dec6c9..55c6276fd1 100644 --- a/Libs/Analyze/Shape.cpp +++ b/Libs/Analyze/Shape.cpp @@ -146,8 +146,10 @@ void Shape::clear_reconstructed_mesh() { reconstructed_meshes_ = MeshGroup(subje bool Shape::import_global_point_files(std::vector filenames) { for (int i = 0; i < filenames.size(); i++) { Eigen::VectorXd points; - if (!Shape::import_point_file(filenames[i], points)) { - throw std::invalid_argument("Unable to import point file: " + filenames[i]); + if (filenames[i] != "") { + if (!Shape::import_point_file(filenames[i], points)) { + throw std::invalid_argument("Unable to import point file: " + filenames[i]); + } } global_point_filenames_.push_back(filenames[i]); particles_.set_world_particles(i, points); @@ -160,8 +162,10 @@ bool Shape::import_global_point_files(std::vector filenames) { bool Shape::import_local_point_files(std::vector filenames) { for (int i = 0; i < filenames.size(); i++) { Eigen::VectorXd points; - if (!Shape::import_point_file(filenames[i], points)) { - throw std::invalid_argument("Unable to import point file: " + filenames[i]); + if (filenames[i] != "") { + if (!Shape::import_point_file(filenames[i], points)) { + throw std::invalid_argument("Unable to import point file: " + filenames[i]); + } } local_point_filenames_.push_back(filenames[i]); particles_.set_local_particles(i, points); @@ -307,9 +311,7 @@ bool Shape::store_constraints() { Eigen::VectorXd Shape::get_global_correspondence_points() { return particles_.get_combined_global_particles(); } //--------------------------------------------------------------------------- -Eigen::VectorXd Shape::get_local_correspondence_points() { - return particles_.get_combined_local_particles(); -} +Eigen::VectorXd Shape::get_local_correspondence_points() { return particles_.get_combined_local_particles(); } //--------------------------------------------------------------------------- int Shape::get_id() { return id_; } @@ -699,9 +701,7 @@ void Shape::set_particle_transform(vtkSmartPointer transform) { } //--------------------------------------------------------------------------- -void Shape::set_alignment_type(int alignment) { - particles_.set_alignment_type(alignment); -} +void Shape::set_alignment_type(int alignment) { particles_.set_alignment_type(alignment); } //--------------------------------------------------------------------------- vtkSmartPointer Shape::get_reconstruction_transform(int domain) { diff --git a/Libs/Groom/Groom.cpp b/Libs/Groom/Groom.cpp index 42f95cf6f1..151f435d45 100644 --- a/Libs/Groom/Groom.cpp +++ b/Libs/Groom/Groom.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include @@ -10,7 +11,6 @@ #include #include #include -#include #include #include @@ -24,16 +24,16 @@ typedef float PixelType; typedef itk::Image ImageType; //--------------------------------------------------------------------------- -Groom::Groom(ProjectHandle project) { this->project_ = project; } +Groom::Groom(ProjectHandle project) { project_ = project; } //--------------------------------------------------------------------------- bool Groom::run() { used_names_.clear(); - this->progress_ = 0; - this->progress_counter_ = 0; - this->total_ops_ = this->get_total_ops(); + progress_ = 0; + progress_counter_ = 0; + total_ops_ = get_total_ops(); - auto subjects = this->project_->get_subjects(); + auto subjects = project_->get_subjects(); if (subjects.empty()) { throw std::invalid_argument("No subjects to groom"); @@ -43,29 +43,33 @@ bool Groom::run() { tbb::parallel_for(tbb::blocked_range{0, subjects.size()}, [&](const tbb::blocked_range& r) { for (size_t i = r.begin(); i < r.end(); ++i) { for (int domain = 0; domain < project_->get_number_of_domains_per_subject(); domain++) { - if (this->abort_) { + if (abort_) { success = false; continue; } + if (subjects[i]->is_fixed()) { + continue; + } + bool is_image = project_->get_original_domain_types()[domain] == DomainType::Image; bool is_mesh = project_->get_original_domain_types()[domain] == DomainType::Mesh; bool is_contour = project_->get_original_domain_types()[domain] == DomainType::Contour; if (is_image) { - if (!this->image_pipeline(subjects[i], domain)) { + if (!image_pipeline(subjects[i], domain)) { success = false; } } if (is_mesh) { - if (!this->mesh_pipeline(subjects[i], domain)) { + if (!mesh_pipeline(subjects[i], domain)) { success = false; } } if (is_contour) { - if (!this->contour_pipeline(subjects[i], domain)) { + if (!contour_pipeline(subjects[i], domain)) { success = false; } } @@ -73,7 +77,7 @@ bool Groom::run() { } }); - if (!this->run_alignment()) { + if (!run_alignment()) { success = false; } increment_progress(10); // alignment complete @@ -85,7 +89,7 @@ bool Groom::run() { //--------------------------------------------------------------------------- bool Groom::image_pipeline(std::shared_ptr subject, size_t domain) { // grab parameters - auto params = GroomParameters(this->project_, this->project_->get_domain_names()[domain]); + auto params = GroomParameters(project_, project_->get_domain_names()[domain]); auto original = subject->get_original_filenames()[domain]; @@ -119,34 +123,34 @@ bool Groom::image_pipeline(std::shared_ptr subject, size_t domain) { return true; } - this->run_image_pipeline(image, params); + run_image_pipeline(image, params); // reflection if (params.get_reflect()) { auto table = subject->get_table_values(); if (table.find(params.get_reflect_column()) != table.end()) { if (table[params.get_reflect_column()] == params.get_reflect_choice()) { - this->add_reflect_transform(transform, params.get_reflect_axis()); + add_reflect_transform(transform, params.get_reflect_axis()); } } } // centering if (params.get_use_center()) { - this->add_center_transform(transform, image); + add_center_transform(transform, image); } - if (this->abort_) { + if (abort_) { return false; } // groomed filename - std::string groomed_name = this->get_output_filename(original, DomainType::Image); + std::string groomed_name = get_output_filename(original, DomainType::Image); if (params.get_convert_to_mesh()) { Mesh mesh = image.toMesh(0.0); - this->run_mesh_pipeline(mesh, params); - groomed_name = this->get_output_filename(original, DomainType::Mesh); + run_mesh_pipeline(mesh, params); + groomed_name = get_output_filename(original, DomainType::Mesh); // save the groomed mesh MeshUtils::threadSafeWriteMesh(groomed_name, mesh); } else { @@ -179,18 +183,18 @@ bool Groom::run_image_pipeline(Image& image, GroomParameters params) { // isolate if (params.get_isolate_tool()) { image.isolate(); - this->increment_progress(); + increment_progress(); } - if (this->abort_) { + if (abort_) { return false; } // fill holes if (params.get_fill_holes_tool()) { image.closeHoles(); - this->increment_progress(); + increment_progress(); } - if (this->abort_) { + if (abort_) { return false; } @@ -198,28 +202,28 @@ bool Groom::run_image_pipeline(Image& image, GroomParameters params) { if (params.get_crop()) { PhysicalRegion region = image.physicalBoundingBox(0.5); image.crop(region); - this->increment_progress(); + increment_progress(); } - if (this->abort_) { + if (abort_) { return false; } // autopad if (params.get_auto_pad_tool()) { image.pad(params.get_padding_amount()); - this->fix_origin(image); - this->increment_progress(); + fix_origin(image); + increment_progress(); } - if (this->abort_) { + if (abort_) { return false; } // antialias if (params.get_antialias_tool()) { image.antialias(params.get_antialias_iterations()); - this->increment_progress(); + increment_progress(); } - if (this->abort_) { + if (abort_) { return false; } @@ -239,26 +243,26 @@ bool Groom::run_image_pipeline(Image& image, GroomParameters params) { } else { image.resample(v, Image::InterpolationType::Linear); } - this->increment_progress(); + increment_progress(); } - if (this->abort_) { + if (abort_) { return false; } // create distance transform if (params.get_fast_marching()) { image.computeDT(); - this->increment_progress(10); + increment_progress(10); } - if (this->abort_) { + if (abort_) { return false; } // blur if (params.get_blur_tool()) { image.gaussianBlur(params.get_blur_amount()); - this->increment_progress(); + increment_progress(); } return true; @@ -267,12 +271,12 @@ bool Groom::run_image_pipeline(Image& image, GroomParameters params) { //--------------------------------------------------------------------------- bool Groom::mesh_pipeline(std::shared_ptr subject, size_t domain) { // grab parameters - auto params = GroomParameters(this->project_, this->project_->get_domain_names()[domain]); + auto params = GroomParameters(project_, project_->get_domain_names()[domain]); auto original = subject->get_original_filenames()[domain]; // groomed mesh name - std::string groom_name = this->get_output_filename(original, DomainType::Mesh); + std::string groom_name = get_output_filename(original, DomainType::Mesh); Mesh mesh = MeshUtils::threadSafeReadMesh(original); @@ -280,21 +284,21 @@ bool Groom::mesh_pipeline(std::shared_ptr subject, size_t domain) { auto transform = vtkSmartPointer::New(); if (!params.get_skip_grooming()) { - this->run_mesh_pipeline(mesh, params); + run_mesh_pipeline(mesh, params); // reflection if (params.get_reflect()) { auto table = subject->get_table_values(); if (table.find(params.get_reflect_column()) != table.end()) { if (table[params.get_reflect_column()] == params.get_reflect_choice()) { - this->add_reflect_transform(transform, params.get_reflect_axis()); + add_reflect_transform(transform, params.get_reflect_axis()); } } } // centering if (params.get_use_center()) { - this->add_center_transform(transform, mesh); + add_center_transform(transform, mesh); } // save the groomed mesh MeshUtils::threadSafeWriteMesh(groom_name, mesh); @@ -302,7 +306,6 @@ bool Groom::mesh_pipeline(std::shared_ptr subject, size_t domain) { groom_name = original; } - { // lock for project data structure std::scoped_lock lock(mutex_); @@ -328,7 +331,7 @@ bool Groom::mesh_pipeline(std::shared_ptr subject, size_t domain) { bool Groom::run_mesh_pipeline(Mesh& mesh, GroomParameters params) { if (params.get_fill_mesh_holes_tool()) { mesh.fillHoles(); - this->increment_progress(); + increment_progress(); } if (params.get_remesh()) { @@ -344,7 +347,7 @@ bool Groom::run_mesh_pipeline(Mesh& mesh, GroomParameters params) { num_vertices = std::max(num_vertices, 25); double gradation = clamp(params.get_remesh_gradation(), 0.0, 2.0); mesh.remesh(num_vertices, gradation); - this->increment_progress(); + increment_progress(); } if (params.get_mesh_smooth()) { @@ -353,7 +356,7 @@ bool Groom::run_mesh_pipeline(Mesh& mesh, GroomParameters params) { } else if (params.get_mesh_smoothing_method() == GroomParameters::GROOM_SMOOTH_VTK_WINDOWED_SINC_C) { mesh.smoothSinc(params.get_mesh_vtk_windowed_sinc_iterations(), params.get_mesh_vtk_windowed_sinc_passband()); } - this->increment_progress(); + increment_progress(); } return true; } @@ -361,12 +364,12 @@ bool Groom::run_mesh_pipeline(Mesh& mesh, GroomParameters params) { //--------------------------------------------------------------------------- bool Groom::contour_pipeline(std::shared_ptr subject, size_t domain) { // grab parameters - auto params = GroomParameters(this->project_, this->project_->get_domain_names()[domain]); + auto params = GroomParameters(project_, project_->get_domain_names()[domain]); auto original = subject->get_original_filenames()[domain]; // groomed mesh name - std::string groom_name = this->get_output_filename(original, DomainType::Mesh); + std::string groom_name = get_output_filename(original, DomainType::Mesh); Mesh mesh = MeshUtils::threadSafeReadMesh(original); @@ -379,14 +382,14 @@ bool Groom::contour_pipeline(std::shared_ptr subject, size_t domain) { auto table = subject->get_table_values(); if (table.find(params.get_reflect_column()) != table.end()) { if (table[params.get_reflect_column()] == params.get_reflect_choice()) { - this->add_reflect_transform(transform, params.get_reflect_axis()); + add_reflect_transform(transform, params.get_reflect_axis()); } } } // centering if (params.get_use_center()) { - this->add_center_transform(transform, mesh); + add_center_transform(transform, mesh); } // save the groomed contour @@ -430,16 +433,16 @@ void Groom::fix_origin(Image& image) { //--------------------------------------------------------------------------- int Groom::get_total_ops() { - int num_subjects = this->project_->get_subjects().size(); + int num_subjects = project_->get_subjects().size(); int num_tools = 0; project_->update_subjects(); - auto domains = this->project_->get_domain_names(); - auto subjects = this->project_->get_subjects(); + auto domains = project_->get_domain_names(); + auto subjects = project_->get_subjects(); for (int i = 0; i < domains.size(); i++) { - auto params = GroomParameters(this->project_, domains[i]); + auto params = GroomParameters(project_, domains[i]); if (project_->get_original_domain_types()[i] == DomainType::Image) { num_tools += params.get_isolate_tool() ? 1 : 0; @@ -453,7 +456,7 @@ int Groom::get_total_ops() { } bool run_mesh = project_->get_original_domain_types()[i] == DomainType::Mesh || - (project_->get_original_domain_types()[i] == DomainType::Image && params.get_convert_to_mesh()); + (project_->get_original_domain_types()[i] == DomainType::Image && params.get_convert_to_mesh()); if (run_mesh) { num_tools += params.get_fill_holes_tool() ? 1 : 0; @@ -469,47 +472,53 @@ int Groom::get_total_ops() { //--------------------------------------------------------------------------- void Groom::increment_progress(int amount) { std::scoped_lock lock(mutex); - this->progress_counter_ += amount; - this->progress_ = static_cast(this->progress_counter_) / static_cast(this->total_ops_) * 100.0; + progress_counter_ += amount; + progress_ = static_cast(progress_counter_) / static_cast(total_ops_) * 100.0; SW_PROGRESS(progress_, fmt::format("Grooming ({}/{} ops)", progress_counter_, total_ops_)); } //--------------------------------------------------------------------------- -void Groom::abort() { this->abort_ = true; } +void Groom::abort() { abort_ = true; } //--------------------------------------------------------------------------- -bool Groom::get_aborted() { return this->abort_; } +bool Groom::get_aborted() { return abort_; } //--------------------------------------------------------------------------- bool Groom::run_alignment() { - size_t num_domains = this->project_->get_number_of_domains_per_subject(); - auto subjects = this->project_->get_subjects(); + size_t num_domains = project_->get_number_of_domains_per_subject(); + auto subjects = project_->get_subjects(); - auto base_params = GroomParameters(this->project_); + auto base_params = GroomParameters(project_); bool global_icp = false; bool global_landmarks = false; // per-domain alignment for (size_t domain = 0; domain < num_domains; domain++) { - if (this->abort_) { + if (abort_) { return false; } - auto params = GroomParameters(this->project_, this->project_->get_domain_names()[domain]); + auto params = GroomParameters(project_, project_->get_domain_names()[domain]); if (params.get_use_icp()) { global_icp = true; + std::vector reference_meshes; std::vector meshes; for (size_t i = 0; i < subjects.size(); i++) { - auto mesh = this->get_mesh(i, domain); + auto mesh = get_mesh(i, domain); auto list = subjects[i]->get_groomed_transforms()[domain]; vtkSmartPointer transform = ProjectUtils::convert_transform(list); mesh.applyTransform(transform); + + if (subjects[i]->is_fixed() || !project_->get_fixed_subjects_present()) { + // if fixed subjects are present, only add the fixed subjects + reference_meshes.push_back(mesh); + } meshes.push_back(mesh); } - size_t reference_mesh = MeshUtils::findReferenceMesh(meshes); + size_t reference_mesh = MeshUtils::findReferenceMesh(reference_meshes); auto transforms = Groom::get_icp_transforms(meshes, reference_mesh); @@ -532,17 +541,18 @@ bool Groom::run_alignment() { std::vector meshes; for (size_t i = 0; i < subjects.size(); i++) { - Mesh mesh = this->get_mesh(i, 0); + Mesh mesh = get_mesh(i, 0); for (size_t domain = 1; domain < num_domains; domain++) { - mesh += this->get_mesh(i, domain); // combine + mesh += get_mesh(i, domain); // combine } // grab the first domain's initial transform (e.g. potentially reflect) and use before ICP auto list = subjects[i]->get_groomed_transforms()[0]; vtkSmartPointer transform = ProjectUtils::convert_transform(list); mesh.applyTransform(transform); - - meshes.push_back(mesh); + if (subjects[i]->is_fixed() || !project_->get_fixed_subjects_present()) { + meshes.push_back(mesh); + } } if (global_icp) { @@ -577,7 +587,7 @@ bool Groom::run_alignment() { //--------------------------------------------------------------------------- void Groom::assign_transforms(std::vector> transforms, int domain, bool global) { - auto subjects = this->project_->get_subjects(); + auto subjects = project_->get_subjects(); for (size_t i = 0; i < subjects.size(); i++) { auto subject = subjects[i]; @@ -593,7 +603,9 @@ void Groom::assign_transforms(std::vector> transforms, int d transform->Concatenate(ProjectUtils::convert_transform(transforms[i])); // store transform - subject->set_groomed_transform(domain, ProjectUtils::convert_transform(transform)); + if (!subject->is_fixed()) { + subject->set_groomed_transform(domain, ProjectUtils::convert_transform(transform)); + } } } @@ -615,10 +627,10 @@ std::string Groom::get_output_filename(std::string input, DomainType domain_type std::scoped_lock lock(mutex_); // grab parameters - auto params = GroomParameters(this->project_); + auto params = GroomParameters(project_); // if the project is not saved, use the path of the input filename - auto filename = this->project_->get_filename(); + auto filename = project_->get_filename(); if (filename == "") { filename = input; } @@ -663,7 +675,7 @@ std::string Groom::get_output_filename(std::string input, DomainType domain_type //--------------------------------------------------------------------------- Mesh Groom::get_mesh(int subject, int domain) { - auto subjects = this->project_->get_subjects(); + auto subjects = project_->get_subjects(); auto path = subjects[subject]->get_original_filenames()[domain]; if (project_->get_original_domain_types()[domain] == DomainType::Image) { @@ -682,7 +694,7 @@ Mesh Groom::get_mesh(int subject, int domain) { //--------------------------------------------------------------------------- vtkSmartPointer Groom::get_landmarks(int subject, int domain) { vtkSmartPointer vtk_points = vtkSmartPointer::New(); - auto subjects = this->project_->get_subjects(); + auto subjects = project_->get_subjects(); auto path = subjects[subject]->get_landmarks_filenames()[domain]; std::ifstream in(path.c_str()); @@ -852,7 +864,7 @@ void Groom::add_center_transform(vtkSmartPointer transform, vtkSma //--------------------------------------------------------------------------- std::vector> Groom::get_combined_points() { auto subjects = project_->get_subjects(); - size_t num_domains = this->project_->get_number_of_domains_per_subject(); + size_t num_domains = project_->get_number_of_domains_per_subject(); std::vector> landmarks; for (size_t i = 0; i < subjects.size(); i++) { diff --git a/Libs/Optimize/Optimize.cpp b/Libs/Optimize/Optimize.cpp index f5aa5a77bd..615db4c9da 100644 --- a/Libs/Optimize/Optimize.cpp +++ b/Libs/Optimize/Optimize.cpp @@ -22,11 +22,8 @@ #include // shapeworks -#include - -//#include "Libs/Optimize/Domain/ImageDomain.h" -//#include "Libs/Optimize/Domain/ImplicitSurfaceDomain.h" #include +#include #include "Libs/Optimize/Domain/VtkMeshWrapper.h" #include "Libs/Optimize/Utils/ObjectReader.h" @@ -1890,8 +1887,8 @@ void Optimize::AddMesh(vtkSmartPointer poly_data) { //--------------------------------------------------------------------------- void Optimize::AddContour(vtkSmartPointer poly_data) { m_sampler->AddContour(poly_data); - this->m_num_shapes++; - this->m_spacing = 0.5; + m_num_shapes++; + m_spacing = 0.5; } //--------------------------------------------------------------------------- @@ -1904,11 +1901,18 @@ void Optimize::SetPointFiles(const std::vector& point_files) { } } +//--------------------------------------------------------------------------- +void Optimize::SetInitialPoints(std::vector>> initial_points) { + m_sampler->SetInitialPoints(initial_points); +} + //--------------------------------------------------------------------------- int Optimize::GetNumShapes() { return this->m_num_shapes; } +//--------------------------------------------------------------------------- shapeworks::OptimizationVisualizer& Optimize::GetVisualizer() { return visualizer_; } +//--------------------------------------------------------------------------- void Optimize::SetShowVisualizer(bool show) { if (show && this->m_verbosity_level > 0) { std::cout << "WARNING Using the visualizer will increase run time!\n"; @@ -1916,10 +1920,8 @@ void Optimize::SetShowVisualizer(bool show) { this->show_visualizer_ = show; } -bool Optimize::GetShowVisualizer() { return this->show_visualizer_; } - //--------------------------------------------------------------------------- -void Optimize::SetMeshFiles(const std::vector& mesh_files) { m_sampler->SetMeshFiles(mesh_files); } +bool Optimize::GetShowVisualizer() { return this->show_visualizer_; } //--------------------------------------------------------------------------- void Optimize::SetAttributeScales(const std::vector& scales) { m_sampler->SetAttributeScales(scales); } @@ -1933,7 +1935,7 @@ void Optimize::SetFieldAttributes(const std::vector& field_attribut void Optimize::SetParticleFlags(std::vector flags) { this->m_particle_flags = flags; } //--------------------------------------------------------------------------- -void Optimize::SetDomainFlags(std::vector flags) { +void Optimize::SetFixedDomains(std::vector flags) { if (flags.size() > 0) { // Fixed domains are in use. this->m_fixed_domains_present = true; @@ -1961,15 +1963,15 @@ void Optimize::SetNarrowBand(double v) { //--------------------------------------------------------------------------- double Optimize::GetNarrowBand() { + if (this->m_fixed_domains_present) { + return 1e10; + } + if (this->m_narrow_band_set) { return this->m_narrow_band; } - if (this->m_fixed_domains_present) { - return 1e10; - } else { - return 4.0; - } + return 4.0; } //--------------------------------------------------------------------------- @@ -2019,8 +2021,8 @@ bool Optimize::LoadParameterFile(std::string filename) { } //--------------------------------------------------------------------------- -bool Optimize::SetUpOptimize(ProjectHandle projectFile) { - OptimizeParameters param(projectFile); +bool Optimize::SetUpOptimize(ProjectHandle project) { + OptimizeParameters param(project); param.set_up_optimize(this); return true; } diff --git a/Libs/Optimize/Optimize.h b/Libs/Optimize/Optimize.h index f25e373b5c..b970519605 100644 --- a/Libs/Optimize/Optimize.h +++ b/Libs/Optimize/Optimize.h @@ -19,10 +19,8 @@ #include #include "Libs/Optimize/Domain/DomainType.h" -#include "Libs/Optimize/Domain/MeshWrapper.h" #include "Libs/Optimize/Function/VectorFunction.h" #include "Libs/Optimize/Utils/OptimizationVisualizer.h" -#include "ParticleSystem.h" #include "ProcrustesRegistration.h" #include "Sampler.h" @@ -66,11 +64,13 @@ class Optimize { //! Load a parameter file bool LoadParameterFile(std::string filename); - bool SetUpOptimize(ProjectHandle projectFile); + //! Set up this Optimize object using a ShapeWorks project + bool SetUpOptimize(ProjectHandle project); - //! Set the Projects + //! Set the Project object void SetProject(std::shared_ptr project); + //! Set an iteration callback function to be called after each iteration void SetIterationCallbackFunction(const std::function& f) { this->iteration_callback_ = f; } //! Abort optimization @@ -226,10 +226,11 @@ class Optimize { //! Set starting point files (TODO: details) void SetPointFiles(const std::vector& point_files); + //! Set initial particle positions (e.g. for fixed subjects) + void SetInitialPoints(std::vector>> initial_points); + //! Get number of shapes int GetNumShapes(); - //! Set the mesh files (TODO: details) - void SetMeshFiles(const std::vector& mesh_files); //! Set attribute scales (TODO: details) void SetAttributeScales(const std::vector& scales); @@ -239,7 +240,7 @@ class Optimize { //! Set Particle Flags (TODO: details) void SetParticleFlags(std::vector flags); //! Set Domain Flags (TODO: details) - void SetDomainFlags(std::vector flags); + void SetFixedDomains(std::vector flags); //! Shared boundary settings void SetSharedBoundaryEnabled(bool enabled); diff --git a/Libs/Optimize/OptimizeParameterFile.cpp b/Libs/Optimize/OptimizeParameterFile.cpp index 5d9f823830..102785c72e 100644 --- a/Libs/Optimize/OptimizeParameterFile.cpp +++ b/Libs/Optimize/OptimizeParameterFile.cpp @@ -698,26 +698,6 @@ bool OptimizeParameterFile::read_mesh_attributes(TiXmlHandle* docHandle, Optimiz std::string filename; int numShapes = optimize->GetNumShapes(); - // load mesh files - elem = docHandle->FirstChild("mesh_files").Element(); - if (elem) { - std::vector meshFiles; - inputsBuffer.str(elem->GetText()); - while (inputsBuffer >> filename) { - meshFiles.push_back(filename); - } - inputsBuffer.clear(); - inputsBuffer.str(""); - - // read mesh files only if they are all present - if (meshFiles.size() != numShapes) { - std::cerr << "Error: incorrect number of mesh files!" << std::endl; - return false; - } else { - optimize->SetMeshFiles(meshFiles); - } - } - std::vector attributes_per_domain = optimize->GetAttributesPerDomain(); // attributes @@ -1048,7 +1028,7 @@ bool OptimizeParameterFile::read_flag_particles(TiXmlHandle* doc_handle, Optimiz //--------------------------------------------------------------------------- bool OptimizeParameterFile::read_flag_domains(TiXmlHandle* doc_handle, Optimize* optimize) { - optimize->SetDomainFlags(this->read_int_list(doc_handle, "fixed_domains")); + optimize->SetFixedDomains(this->read_int_list(doc_handle, "fixed_domains")); return true; } diff --git a/Libs/Optimize/OptimizeParameters.cpp b/Libs/Optimize/OptimizeParameters.cpp index 5882887846..1c85d695c7 100644 --- a/Libs/Optimize/OptimizeParameters.cpp +++ b/Libs/Optimize/OptimizeParameters.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -13,6 +14,7 @@ #include "Optimize.h" using namespace shapeworks; +using namespace shapeworks::particles; namespace Keys { const std::string number_of_particles = "number_of_particles"; @@ -275,6 +277,73 @@ std::string OptimizeParameters::get_output_prefix() { return output; } +//--------------------------------------------------------------------------- +std::vector>> OptimizeParameters::get_initial_points() { + int domains_per_shape = project_->get_number_of_domains_per_subject(); + + auto subjects = project_->get_subjects(); + std::vector>> domain_means; + + for (int d = 0; d < domains_per_shape; d++) { + std::vector> domain_sum; + int count = 0; + for (auto s : subjects) { + if (s->is_fixed()) { + count++; + // read the local points + auto filename = s->get_local_particle_filenames()[d]; + auto particles = read_particles_as_vector(filename); + if (domain_sum.size() == 0) { + domain_sum = particles; + } else { + for (int p = 0; p < particles.size(); p++) { + domain_sum[p] += particles[p]; + } + } + } + } + // now divide to find mean + for (int p = 0; p < domain_sum.size(); p++) { + domain_sum[p] /= count; + } + + domain_means.push_back(domain_sum); + } + + std::vector>> initial_points; + for (auto s : subjects) { + for (int d = 0; d < domains_per_shape; d++) { + if (s->is_fixed()) { + auto filename = s->get_local_particle_filenames()[d]; + auto particles = read_particles_as_vector(filename); + initial_points.push_back(particles); + } else { + // get alignment transform and invert it + auto transforms = s->get_groomed_transforms(); + + // create identify transform in case there are no groomed transforms + auto transform = vtkSmartPointer::New(); + if (d < transforms.size()) { + transform = ProjectUtils::convert_transform(transforms[d]); + transform->Inverse(); + } + + // transform each of the domain mean positions back to the local space of this new shape + std::vector> points; + for (int i = 0; i < domain_means[d].size(); i++) { + itk::Point point; + transform->TransformPoint(domain_means[d][i].GetDataPointer(), point.GetDataPointer()); + points.push_back(point); + } + + initial_points.push_back(points); + } + } + } + + return initial_points; +} + //--------------------------------------------------------------------------- int OptimizeParameters::get_geodesic_cache_multiplier() { return params_.get(Keys::geodesic_cache_multiplier, 0); } @@ -439,6 +508,24 @@ bool OptimizeParameters::set_up_optimize(Optimize* optimize) { throw std::invalid_argument("No subjects to optimize"); } + if (project_->get_fixed_subjects_present()) { + int idx = 0; + std::vector fixed_domains; + + for (auto s : subjects) { + if (s->is_fixed()) { + for (int i = 0; i < domains_per_shape; i++) { + fixed_domains.push_back(idx++); + } + } else { + idx += domains_per_shape; + } + } + + optimize->SetFixedDomains(fixed_domains); + optimize->SetInitialPoints(get_initial_points()); + } + for (auto s : subjects) { if (abort_load_) { return false; @@ -466,13 +553,13 @@ bool OptimizeParameters::set_up_optimize(Optimize* optimize) { int count = 0; for (const auto& subject : subjects) { for (int i = 0; i < domains_per_shape; i++) { // need one flag for each domain - if (is_subject_fixed(subject)) { + if (subject->is_fixed()) { domain_flags.push_back(count); } count++; } } - optimize->SetDomainFlags(domain_flags); + optimize->SetFixedDomains(domain_flags); } // add constraints @@ -544,8 +631,6 @@ bool OptimizeParameters::set_up_optimize(Optimize* optimize) { constraint.clipMesh(mesh); } - /// HERE! - if (get_use_geodesics_to_landmarks()) { auto filenames = s->get_landmarks_filenames(); Eigen::VectorXd points; @@ -587,7 +672,7 @@ bool OptimizeParameters::set_up_optimize(Optimize* optimize) { } } else { Image image(filename); - if (is_subject_fixed(s)) { + if (s->is_fixed()) { optimize->AddImage(nullptr); } else { optimize->AddImage(image); diff --git a/Libs/Optimize/OptimizeParameters.h b/Libs/Optimize/OptimizeParameters.h index ec775cb401..58efdcc158 100644 --- a/Libs/Optimize/OptimizeParameters.h +++ b/Libs/Optimize/OptimizeParameters.h @@ -2,6 +2,8 @@ #include +#include + #include namespace shapeworks { @@ -131,6 +133,8 @@ class OptimizeParameters { private: std::string get_output_prefix(); + std::vector>> get_initial_points(); + Parameters params_; ProjectHandle project_; diff --git a/Libs/Optimize/ParticleSystem.cpp b/Libs/Optimize/ParticleSystem.cpp index 694406458b..9642add3b6 100644 --- a/Libs/Optimize/ParticleSystem.cpp +++ b/Libs/Optimize/ParticleSystem.cpp @@ -41,7 +41,7 @@ void ParticleSystem::SetNumberOfDomains(unsigned int num) { m_Domains.resize(num); m_Transforms.resize(num); m_InverseTransforms.resize(num); - while (num >= m_PrefixTransforms.size()) { + while (num > m_PrefixTransforms.size()) { TransformType transform; transform.set_identity(); m_PrefixTransforms.push_back(transform); @@ -50,7 +50,7 @@ void ParticleSystem::SetNumberOfDomains(unsigned int num) { m_Positions.resize(num); m_IndexCounters.resize(num); m_Neighborhoods.resize(num); - while (num >= this->m_DomainFlags.size()) { + while (num > this->m_DomainFlags.size()) { m_DomainFlags.push_back(false); } this->Modified(); @@ -132,17 +132,13 @@ void ParticleSystem::SetNeighborhood(unsigned int i, NeighborhoodType* N) { this->InvokeEvent(e); } -const typename ParticleSystem::PointType& ParticleSystem::AddPosition(const PointType& p, unsigned int d) { +const ParticleSystem::PointType& ParticleSystem::AddPosition(const PointType& p, unsigned int d) { m_Positions[d]->operator[](m_IndexCounters[d]) = p; // Potentially modifies position! - if (m_DomainFlags[d] == false) { - // debugg - // std::cout << "d" << d << " before apply " << m_Positions[d]->operator[](m_IndexCounters[d]); + if (m_DomainFlags[d] == false) { // Not a fixed domain. Fixed domains won't load the image const auto idx = m_IndexCounters[d]; m_Domains[d]->ApplyConstraints(m_Positions[d]->operator[](idx), idx); - // debugg - // std::cout << " after apply " << m_Positions[d]->operator[](m_IndexCounters[d]) << std::endl; m_Neighborhoods[d]->AddPosition(m_Positions[d]->operator[](idx), idx); } @@ -161,19 +157,12 @@ const typename ParticleSystem::PointType& ParticleSystem::AddPosition(const Poin return m_Positions[d]->operator[](m_IndexCounters[d] - 1); } -const typename ParticleSystem::PointType& ParticleSystem::SetPosition(const PointType& p, unsigned long int k, - unsigned int d) { +const ParticleSystem::PointType& ParticleSystem::SetPosition(const PointType& p, unsigned long int k, unsigned int d) { if (m_FixedParticleFlags[d % m_DomainsPerShape][k] == false) { // Potentially modifies position! if (m_DomainFlags[d] == false) { m_Positions[d]->operator[](k) = p; - - // Debuggg - // std::cout << "SynchronizePositions Apply constraints " << m_Positions[d]->operator[](k); m_Domains[d]->ApplyConstraints(m_Positions[d]->operator[](k), k); - // Debuggg - // std::cout << " updated " << m_Positions[d]->operator[](k) << std::endl; - m_Neighborhoods[d]->SetPosition(m_Positions[d]->operator[](k), k); } } @@ -190,7 +179,7 @@ const typename ParticleSystem::PointType& ParticleSystem::SetPosition(const Poin void ParticleSystem::AddPositionList(const std::vector& p, unsigned int d) { // Traverse the list and add each point to the domain. - for (typename std::vector::const_iterator it = p.begin(); it != p.end(); it++) { + for (auto it = p.begin(); it != p.end(); it++) { this->AddPosition(*it, d); } } @@ -241,11 +230,10 @@ void ParticleSystem::AdvancedAllParticleSplitting(double epsilon, unsigned int d // Only runs if domains were successfully retrieved if (lists.size() > 0) { - // Initialize augmentend lagrangian variables (mus and possibly lambdas) for (size_t domain = dom_to_process; domain < num_doms; domain += domains_per_shape) { size_t num_particles = lists[0].size(); - std::vector zeros(num_particles*2, 0.0); + std::vector zeros(num_particles * 2, 0.0); this->GetDomain(domain)->GetConstraints()->InitializeLagrangianParameters(zeros); } diff --git a/Libs/Optimize/Sampler.cpp b/Libs/Optimize/Sampler.cpp index 1ffb3375fe..2210c10f90 100644 --- a/Libs/Optimize/Sampler.cpp +++ b/Libs/Optimize/Sampler.cpp @@ -7,7 +7,6 @@ #include "Libs/Optimize/Domain/ContourDomain.h" #include "Libs/Optimize/Utils/ObjectReader.h" - namespace shapeworks { Sampler::Sampler() { @@ -19,7 +18,6 @@ Sampler::Sampler() { m_Optimizer = OptimizerType::New(); m_PointsFiles.push_back(""); - m_MeshFiles.push_back(""); m_LinkingFunction = DualVectorFunction::New(); m_EnsembleEntropyFunction = LegacyCorrespondenceFunction::New(); @@ -51,6 +49,7 @@ Sampler::Sampler() { m_CorrespondenceMode = shapeworks::CorrespondenceMode::EnsembleEntropy; } +//--------------------------------------------------------------------------- void Sampler::AllocateDataCaches() { // Set up the various data caches that the optimization functions will use. m_Sigma1Cache = GenericContainerArray::New(); @@ -67,6 +66,7 @@ void Sampler::AllocateDataCaches() { m_ParticleSystem->RegisterObserver(m_MeanCurvatureCache); } +//--------------------------------------------------------------------------- void Sampler::AllocateDomainsAndNeighborhoods() { // Allocate all the necessary domains and neighborhoods. This must be done // *after* registering the attributes to the particle system since some of @@ -125,29 +125,38 @@ void Sampler::AllocateDomainsAndNeighborhoods() { } } +//--------------------------------------------------------------------------- void Sampler::ReadPointsFiles() { // If points file names have been specified, then read the initial points. for (unsigned int i = 0; i < m_PointsFiles.size(); i++) { if (m_PointsFiles[i] != "") { auto points = particles::read_particles_as_vector(m_PointsFiles[i]); - this->GetParticleSystem()->AddPositionList(points, i); + m_ParticleSystem->AddPositionList(points, i); } } // Push position information out to all observers (necessary to correctly // fill out the shape matrix). - this->GetParticleSystem()->SynchronizePositions(); + m_ParticleSystem->SynchronizePositions(); } +//--------------------------------------------------------------------------- +void Sampler::initialize_initial_positions() { + for (unsigned int i = 0; i < initial_points_.size(); i++) { + m_ParticleSystem->AddPositionList(initial_points_[i], i); + } +} + +//--------------------------------------------------------------------------- void Sampler::InitializeOptimizationFunctions() { // Set the minimum neighborhood radius and maximum sigma based on the // domain of the 1st input image. double maxradius = -1.0; double minimumNeighborhoodRadius = this->m_Spacing; - for (unsigned int d = 0; d < this->GetParticleSystem()->GetNumberOfDomains(); d++) { - if (!GetParticleSystem()->GetDomain(d)->IsDomainFixed()) { - double radius = GetParticleSystem()->GetDomain(d)->GetMaxDiameter(); + for (unsigned int d = 0; d < m_ParticleSystem->GetNumberOfDomains(); d++) { + if (!m_ParticleSystem->GetDomain(d)->IsDomainFixed()) { + double radius = m_ParticleSystem->GetDomain(d)->GetMaxDiameter(); maxradius = radius > maxradius ? radius : maxradius; } } @@ -157,7 +166,7 @@ void Sampler::InitializeOptimizationFunctions() { m_CurvatureGradientFunction->SetMinimumNeighborhoodRadius(minimumNeighborhoodRadius); m_CurvatureGradientFunction->SetMaximumNeighborhoodRadius(maxradius); - m_CurvatureGradientFunction->SetParticleSystem(this->GetParticleSystem()); + m_CurvatureGradientFunction->SetParticleSystem(m_ParticleSystem); m_CurvatureGradientFunction->SetDomainNumber(0); if (m_IsSharedBoundaryEnabled) { m_CurvatureGradientFunction->SetSharedBoundaryEnabled(true); @@ -172,8 +181,7 @@ void Sampler::InitializeOptimizationFunctions() { m_GeneralShapeGradMatrix->Initialize(); } -void Sampler::GenerateData() {} - +//--------------------------------------------------------------------------- void Sampler::Execute() { if (this->GetInitialized() == false) { this->AllocateDataCaches(); @@ -186,9 +194,10 @@ void Sampler::Execute() { this->AllocateDomainsAndNeighborhoods(); // Point the optimizer to the particle system. - this->GetOptimizer()->SetParticleSystem(this->GetParticleSystem()); + this->GetOptimizer()->SetParticleSystem(m_ParticleSystem); this->ReadTransforms(); this->ReadPointsFiles(); + initialize_initial_positions(); this->InitializeOptimizationFunctions(); this->SetInitialized(true); @@ -200,14 +209,55 @@ void Sampler::Execute() { this->GetOptimizer()->StartOptimization(); } +//--------------------------------------------------------------------------- +Sampler::CuttingPlaneList Sampler::ComputeCuttingPlanes() { + CuttingPlaneList planes; + for (size_t i = 0; i < m_CuttingPlanes.size(); i++) { + std::vector> domain_i_cps; + for (size_t j = 0; j < m_CuttingPlanes[i].size(); j++) { + std::pair cut_plane; + cut_plane.first = ComputePlaneNormal(m_CuttingPlanes[i][j].a.as_ref(), m_CuttingPlanes[i][j].b.as_ref(), + m_CuttingPlanes[i][j].c.as_ref()); + cut_plane.second = + Eigen::Vector3d(m_CuttingPlanes[i][j].a[0], m_CuttingPlanes[i][j].a[1], m_CuttingPlanes[i][j].a[2]); + domain_i_cps.push_back(cut_plane); + } + planes.push_back(domain_i_cps); + } + return planes; +} + +//--------------------------------------------------------------------------- +Eigen::Vector3d Sampler::ComputePlaneNormal(const vnl_vector& a, const vnl_vector& b, + const vnl_vector& c) { + // See http://mathworld.wolfram.com/Plane.html, for example + vnl_vector q; + q = vnl_cross_3d((b - a), (c - a)); + + if (q.magnitude() > 0.0) { + Eigen::Vector3d qp; + q = q / q.magnitude(); + qp(0) = q[0]; + qp(1) = q[1]; + qp(2) = q[2]; + return qp; + } else { + std::cerr << "Error in Sampler::ComputePlaneNormal" << std::endl; + std::cerr << "There was an issue with a cutting plane that was defined. It has yielded a 0,0,0 vector. Please " + "check the inputs." + << std::endl; + throw std::runtime_error("Error computing plane normal"); + } +} + void Sampler::ReadTransforms() { if (m_TransformFile != "") { ObjectReader reader; reader.SetFileName(m_TransformFile.c_str()); reader.Update(); - for (unsigned int i = 0; i < this->GetParticleSystem()->GetNumberOfDomains(); i++) - this->GetParticleSystem()->SetTransform(i, reader.GetOutput()[i]); + for (unsigned int i = 0; i < m_ParticleSystem->GetNumberOfDomains(); i++) + m_ParticleSystem->SetTransform(i, reader.GetOutput()[i]); } if (m_PrefixTransformFile != "") { @@ -215,8 +265,8 @@ void Sampler::ReadTransforms() { reader.SetFileName(m_PrefixTransformFile.c_str()); reader.Update(); - for (unsigned int i = 0; i < this->GetParticleSystem()->GetNumberOfDomains(); i++) - this->GetParticleSystem()->SetPrefixTransform(i, reader.GetOutput()[i]); + for (unsigned int i = 0; i < m_ParticleSystem->GetNumberOfDomains(); i++) + m_ParticleSystem->SetPrefixTransform(i, reader.GetOutput()[i]); } } @@ -264,17 +314,60 @@ void Sampler::SetFieldAttributes(const std::vector& s) { void Sampler::TransformCuttingPlanes(unsigned int i) { if (m_Initialized == true) { - TransformType T1 = this->GetParticleSystem()->GetTransform(i) * this->GetParticleSystem()->GetPrefixTransform(i); - for (unsigned int d = 0; d < this->GetParticleSystem()->GetNumberOfDomains(); d++) { - if (this->GetParticleSystem()->GetDomainFlag(d) == false) { - TransformType T2 = this->GetParticleSystem()->InvertTransform(this->GetParticleSystem()->GetTransform(d) * - this->GetParticleSystem()->GetPrefixTransform(d)); + TransformType T1 = m_ParticleSystem->GetTransform(i) * m_ParticleSystem->GetPrefixTransform(i); + for (unsigned int d = 0; d < m_ParticleSystem->GetNumberOfDomains(); d++) { + if (m_ParticleSystem->GetDomainFlag(d) == false) { + TransformType T2 = m_ParticleSystem->InvertTransform(m_ParticleSystem->GetTransform(d) * + m_ParticleSystem->GetPrefixTransform(d)); m_ParticleSystem->GetDomain(d)->GetConstraints()->transformPlanes(T2 * T1); } } } } +void Sampler::SetCorrespondenceMode(CorrespondenceMode mode) { + if (mode == shapeworks::CorrespondenceMode::MeanEnergy) { + m_LinkingFunction->SetFunctionB(m_EnsembleEntropyFunction); + m_EnsembleEntropyFunction->UseMeanEnergy(); + } else if (mode == shapeworks::CorrespondenceMode::EnsembleEntropy) { + m_LinkingFunction->SetFunctionB(m_EnsembleEntropyFunction); + m_EnsembleEntropyFunction->UseEntropy(); + } else if (mode == shapeworks::CorrespondenceMode::EnsembleRegressionEntropy) { + m_LinkingFunction->SetFunctionB(m_EnsembleRegressionEntropyFunction); + } else if (mode == shapeworks::CorrespondenceMode::EnsembleMixedEffectsEntropy) { + m_LinkingFunction->SetFunctionB(m_EnsembleMixedEffectsEntropyFunction); + } else if (mode == shapeworks::CorrespondenceMode::MeshBasedGeneralEntropy) { + m_LinkingFunction->SetFunctionB(m_CorrespondenceFunction); + m_CorrespondenceFunction->UseEntropy(); + } else if (mode == shapeworks::CorrespondenceMode::MeshBasedGeneralMeanEnergy) { + m_LinkingFunction->SetFunctionB(m_CorrespondenceFunction); + m_CorrespondenceFunction->UseMeanEnergy(); + } else if (mode == shapeworks::CorrespondenceMode::DisentagledEnsembleEntropy) { + m_LinkingFunction->SetFunctionB(m_DisentangledEnsembleEntropyFunction); + m_DisentangledEnsembleEntropyFunction->UseEntropy(); + } else if (mode == shapeworks::CorrespondenceMode::DisentangledEnsembleMeanEnergy) { + m_LinkingFunction->SetFunctionB(m_DisentangledEnsembleEntropyFunction); + m_DisentangledEnsembleEntropyFunction->UseMeanEnergy(); + } + + m_CorrespondenceMode = mode; +} + +void Sampler::SetAttributesPerDomain(const std::vector s) { + std::vector s1; + if (s.size() == 0) { + s1.resize(m_CorrespondenceFunction->GetDomainsPerShape()); + for (int i = 0; i < m_CorrespondenceFunction->GetDomainsPerShape(); i++) s1[i] = 0; + } else { + s1 = s; + } + + m_AttributesPerDomain = s1; + m_CorrespondenceFunction->SetAttributesPerDomain(s1); + m_GeneralShapeMatrix->SetAttributesPerDomain(s1); + m_GeneralShapeGradMatrix->SetAttributesPerDomain(s1); +} + void Sampler::SetCuttingPlane(unsigned int i, const vnl_vector_fixed& va, const vnl_vector_fixed& vb, const vnl_vector_fixed& vc) { @@ -330,7 +423,7 @@ bool Sampler::initialize_ffcs(size_t dom) { std::cout << "dom " << dom << " point count " << mesh->numPoints() << " faces " << mesh->numFaces() << std::endl; if (m_FFCs[dom].isSet()) { - this->m_DomainList[dom]->GetConstraints()->addFreeFormConstraint(mesh); + m_DomainList[dom]->GetConstraints()->addFreeFormConstraint(mesh); m_FFCs[dom].computeGradientFields(mesh); } diff --git a/Libs/Optimize/Sampler.h b/Libs/Optimize/Sampler.h index c5d759faca..63ab15ee86 100644 --- a/Libs/Optimize/Sampler.h +++ b/Libs/Optimize/Sampler.h @@ -7,13 +7,10 @@ #include "GradientDescentOptimizer.h" #include "Libs/Optimize/Container/GenericContainerArray.h" #include "Libs/Optimize/Container/MeanCurvatureContainer.h" -#include "Libs/Optimize/Domain/DomainType.h" -#include "Libs/Optimize/Domain/ImplicitSurfaceDomain.h" -#include "Libs/Optimize/Domain/MeshDomain.h" #include "Libs/Optimize/Domain/MeshWrapper.h" #include "Libs/Optimize/Function/CorrespondenceFunction.h" -#include "Libs/Optimize/Function/DisentangledCorrespondenceFunction.h" #include "Libs/Optimize/Function/CurvatureSamplingFunction.h" +#include "Libs/Optimize/Function/DisentangledCorrespondenceFunction.h" #include "Libs/Optimize/Function/DualVectorFunction.h" #include "Libs/Optimize/Function/LegacyCorrespondenceFunction.h" #include "Libs/Optimize/Function/SamplingFunction.h" @@ -21,7 +18,6 @@ #include "Libs/Optimize/Matrix/MixedEffectsShapeMatrix.h" #include "Libs/Optimize/Neighborhood/ParticleSurfaceNeighborhood.h" #include "ParticleSystem.h" -#include "TriMesh.h" #include "vnl/vnl_matrix_fixed.h" // Uncomment to visualize FFCs with scalar and vector fields @@ -63,30 +59,24 @@ class Sampler { double radius; }; - /** Returns the particle system used in the surface sampling. */ - itkGetObjectMacro(ParticleSystem, ParticleSystem); - - itkGetConstObjectMacro(ParticleSystem, ParticleSystem); - //! Constructor Sampler(); //! Destructor virtual ~Sampler(){}; - /** Returns a pointer to the gradient function used. */ - SamplingFunction* GetGradientFunction() { - return m_GradientFunction; - } + //! Returns the particle system + ParticleSystem* GetParticleSystem() { return m_ParticleSystem; } + const ParticleSystem* GetParticleSystem() const { return m_ParticleSystem.GetPointer(); } - CurvatureSamplingFunction* GetCurvatureGradientFunction() { - return m_CurvatureGradientFunction; - } + /** Returns a pointer to the gradient function used. */ + SamplingFunction* GetGradientFunction() { return m_GradientFunction; } - /** Return a pointer to the optimizer object. */ - itkGetObjectMacro(Optimizer, OptimizerType); + CurvatureSamplingFunction* GetCurvatureGradientFunction() { return m_CurvatureGradientFunction; } - itkGetConstObjectMacro(Optimizer, OptimizerType); + //! Return a pointer to the optimizer object + OptimizerType* GetOptimizer() { return m_Optimizer; } + const OptimizerType* GetOptimizer() const { return m_Optimizer.GetPointer(); } /**Optionally provide a filename for an initial point set.*/ void SetPointsFile(unsigned int i, const std::string& s) { @@ -98,18 +88,11 @@ class Sampler { void SetPointsFile(const std::string& s) { this->SetPointsFile(0, s); } - /**Optionally provide a filename for a mesh with geodesic distances.*/ - void SetMeshFile(unsigned int i, const std::string& s) { - if (m_MeshFiles.size() < i + 1) { - m_MeshFiles.resize(i + 1); - } - m_MeshFiles[i] = s; + //! Set initial particle positions (e.g. for fixed subjects) + void SetInitialPoints(std::vector>> initial_points) { + initial_points_ = initial_points; } - void SetMeshFile(const std::string& s) { this->SetMeshFile(0, s); } - - void SetMeshFiles(const std::vector& s) { m_MeshFiles = s; } - void AddImage(ImageType::Pointer image, double narrow_band, std::string name = ""); void ApplyConstraintsToZeroCrossing() { @@ -150,8 +133,8 @@ class Sampler { mode 0 = isotropic adaptivity mode 1 = no adaptivity */ - virtual void SetAdaptivityMode(int mode) { - //SW_LOG("SetAdaptivityMode: {}, pairwise_potential_type: {}", mode, m_pairwise_potential_type); + void SetAdaptivityMode(int mode) { + // SW_LOG("SetAdaptivityMode: {}, pairwise_potential_type: {}", mode, m_pairwise_potential_type); if (mode == 0) { m_LinkingFunction->SetFunctionA(this->GetCurvatureGradientFunction()); } else if (mode == 1) { @@ -176,33 +159,7 @@ class Sampler { bool GetSamplingOn() const { return m_LinkingFunction->GetAOn(); } /** This method sets the optimization function for correspondences between surfaces (domains). */ - virtual void SetCorrespondenceMode(shapeworks::CorrespondenceMode mode) { - if (mode == shapeworks::CorrespondenceMode::MeanEnergy) { - m_LinkingFunction->SetFunctionB(m_EnsembleEntropyFunction); - m_EnsembleEntropyFunction->UseMeanEnergy(); - } else if (mode == shapeworks::CorrespondenceMode::EnsembleEntropy) { - m_LinkingFunction->SetFunctionB(m_EnsembleEntropyFunction); - m_EnsembleEntropyFunction->UseEntropy(); - } else if (mode == shapeworks::CorrespondenceMode::EnsembleRegressionEntropy) { - m_LinkingFunction->SetFunctionB(m_EnsembleRegressionEntropyFunction); - } else if (mode == shapeworks::CorrespondenceMode::EnsembleMixedEffectsEntropy) { - m_LinkingFunction->SetFunctionB(m_EnsembleMixedEffectsEntropyFunction); - } else if (mode == shapeworks::CorrespondenceMode::MeshBasedGeneralEntropy) { - m_LinkingFunction->SetFunctionB(m_CorrespondenceFunction); - m_CorrespondenceFunction->UseEntropy(); - } else if (mode == shapeworks::CorrespondenceMode::MeshBasedGeneralMeanEnergy) { - m_LinkingFunction->SetFunctionB(m_CorrespondenceFunction); - m_CorrespondenceFunction->UseMeanEnergy(); - } else if (mode == shapeworks::CorrespondenceMode::DisentagledEnsembleEntropy) { - m_LinkingFunction->SetFunctionB(m_DisentangledEnsembleEntropyFunction); - m_DisentangledEnsembleEntropyFunction->UseEntropy(); - } else if (mode == shapeworks::CorrespondenceMode::DisentangledEnsembleMeanEnergy) { - m_LinkingFunction->SetFunctionB(m_DisentangledEnsembleEntropyFunction); - m_DisentangledEnsembleEntropyFunction->UseMeanEnergy(); - } - - m_CorrespondenceMode = mode; - } + void SetCorrespondenceMode(shapeworks::CorrespondenceMode mode); void RegisterGeneralShapeMatrices() { this->m_ParticleSystem->RegisterObserver(m_GeneralShapeMatrix); @@ -227,19 +184,7 @@ class Sampler { m_GeneralShapeGradMatrix->SetNormals(i, flag); } - void SetAttributesPerDomain(const std::vector s) { - std::vector s1; - if (s.size() == 0) { - s1.resize(m_CorrespondenceFunction->GetDomainsPerShape()); - for (int i = 0; i < m_CorrespondenceFunction->GetDomainsPerShape(); i++) s1[i] = 0; - } else - s1 = s; - - m_AttributesPerDomain = s1; - m_CorrespondenceFunction->SetAttributesPerDomain(s1); - m_GeneralShapeMatrix->SetAttributesPerDomain(s1); - m_GeneralShapeGradMatrix->SetAttributesPerDomain(s1); - } + void SetAttributesPerDomain(const std::vector s); LegacyShapeMatrix* GetShapeMatrix() { return m_LegacyShapeMatrix.GetPointer(); } @@ -250,7 +195,9 @@ class Sampler { LegacyCorrespondenceFunction* GetEnsembleEntropyFunction() { return m_EnsembleEntropyFunction.GetPointer(); } - DisentangledCorrespondenceFunction* GetDisentangledEnsembleEntropyFunction() { return m_DisentangledEnsembleEntropyFunction.GetPointer(); } + DisentangledCorrespondenceFunction* GetDisentangledEnsembleEntropyFunction() { + return m_DisentangledEnsembleEntropyFunction.GetPointer(); + } LegacyCorrespondenceFunction* GetEnsembleRegressionEntropyFunction() { return m_EnsembleRegressionEntropyFunction.GetPointer(); @@ -260,9 +207,7 @@ class Sampler { return m_EnsembleMixedEffectsEntropyFunction.GetPointer(); } - CorrespondenceFunction* GetMeshBasedGeneralEntropyGradientFunction() { - return m_CorrespondenceFunction.GetPointer(); - } + CorrespondenceFunction* GetMeshBasedGeneralEntropyGradientFunction() { return m_CorrespondenceFunction.GetPointer(); } const DualVectorFunction* GetLinkingFunction() const { return m_LinkingFunction.GetPointer(); } @@ -316,66 +261,35 @@ class Sampler { void ReadTransforms(); void ReadPointsFiles(); - virtual void AllocateDataCaches(); - virtual void AllocateDomainsAndNeighborhoods(); - virtual void InitializeOptimizationFunctions(); + void AllocateDataCaches(); + void AllocateDomainsAndNeighborhoods(); + void InitializeOptimizationFunctions(); + + void initialize_initial_positions(); /** */ - virtual void Initialize() { + void Initialize() { this->m_Initializing = true; this->Execute(); this->m_Initializing = false; } - virtual void ReInitialize(); - - virtual void Execute(); - - std::vector>> ComputeCuttingPlanes() { - std::vector>> planes; - for (size_t i = 0; i < m_CuttingPlanes.size(); i++) { - std::vector> domain_i_cps; - for (size_t j = 0; j < m_CuttingPlanes[i].size(); j++) { - std::pair cut_plane; - cut_plane.first = ComputePlaneNormal(m_CuttingPlanes[i][j].a.as_ref(), m_CuttingPlanes[i][j].b.as_ref(), - m_CuttingPlanes[i][j].c.as_ref()); - cut_plane.second = - Eigen::Vector3d(m_CuttingPlanes[i][j].a[0], m_CuttingPlanes[i][j].a[1], m_CuttingPlanes[i][j].a[2]); - domain_i_cps.push_back(cut_plane); - } - planes.push_back(domain_i_cps); - } - return planes; - } + void ReInitialize(); + + void Execute(); + + using CuttingPlaneList = std::vector>>; + + CuttingPlaneList ComputeCuttingPlanes(); Eigen::Vector3d ComputePlaneNormal(const vnl_vector& a, const vnl_vector& b, - const vnl_vector& c) { - // See http://mathworld.wolfram.com/Plane.html, for example - vnl_vector q; - q = vnl_cross_3d((b - a), (c - a)); - - if (q.magnitude() > 0.0) { - Eigen::Vector3d qp; - q = q / q.magnitude(); - qp(0) = q[0]; - qp(1) = q[1]; - qp(2) = q[2]; - return qp; - } else { - std::cerr << "Error in Libs/Optimize/ParticleSystem/Sampler.h::ComputePlaneNormal" << std::endl; - std::cerr << "There was an issue with a cutting plane that was defined. It has yielded a 0,0,0 vector. Please " - "check the inputs." - << std::endl; - throw std::runtime_error("Error computing plane normal"); - } - } + const vnl_vector& c); std::vector GetFFCs() { return m_FFCs; } void SetMeshFFCMode(bool mesh_ffc_mode) { m_meshFFCMode = mesh_ffc_mode; } - protected: - void GenerateData(); + private: bool GetInitialized() { return this->m_Initialized; } @@ -432,10 +346,6 @@ class Sampler { void operator=(const Sampler&); // purposely not implemented std::vector m_PointsFiles; - std::vector m_MeshFiles; - std::vector m_FeaMeshFiles; - std::vector m_FeaGradFiles; - std::vector m_FidsFiles; std::vector m_AttributesPerDomain; int m_DomainsPerShape; double m_Spacing{0}; @@ -452,6 +362,8 @@ class Sampler { std::vector fieldAttributes_; + std::vector>> initial_points_; + unsigned int m_verbosity; }; diff --git a/Libs/Particles/ParticleFile.cpp b/Libs/Particles/ParticleFile.cpp index db185c6a24..b4dc573565 100644 --- a/Libs/Particles/ParticleFile.cpp +++ b/Libs/Particles/ParticleFile.cpp @@ -6,6 +6,8 @@ #include #include +#include + #include namespace shapeworks::particles { @@ -54,7 +56,7 @@ static Eigen::VectorXd read_vtk_particles(std::string filename) { //--------------------------------------------------------------------------- Eigen::VectorXd read_particles(std::string filename) { - if (filename.substr(filename.size() - 4) == ".vtk") { + if (StringUtils::hasSuffix(filename, ".vtk")) { return read_vtk_particles(filename); } @@ -92,7 +94,7 @@ Eigen::VectorXd read_particles(std::string filename) { //--------------------------------------------------------------------------- void write_particles(std::string filename, const Eigen::VectorXd& points) { - if (filename.substr(filename.size() - 4) == ".vtk") { + if (StringUtils::hasSuffix(filename, ".vtk")) { write_vtk_particles(filename, points); return; } diff --git a/Libs/Particles/ParticleShapeStatistics.cpp b/Libs/Particles/ParticleShapeStatistics.cpp index b1fe17515e..c68601591b 100644 --- a/Libs/Particles/ParticleShapeStatistics.cpp +++ b/Libs/Particles/ParticleShapeStatistics.cpp @@ -391,9 +391,6 @@ int ParticleShapeStatistics::ReadPointFiles(const std::string& s) { } } - std::cerr << "group id size = " << m_groupIDs.size() << "\n"; - std::cerr << "numSamples = " << m_numSamples << "\n"; - // If there are no group IDs, make up some bogus ones if (m_groupIDs.size() != m_numSamples) { if (m_groupIDs.size() > 0) { diff --git a/Libs/Project/Project.cpp b/Libs/Project/Project.cpp index a84db08b63..f55c0b66e7 100644 --- a/Libs/Project/Project.cpp +++ b/Libs/Project/Project.cpp @@ -314,6 +314,17 @@ bool Project::get_particles_present() const { return particles_present_; } //--------------------------------------------------------------------------- bool Project::get_images_present() { return images_present_; } +//--------------------------------------------------------------------------- +bool Project::get_fixed_subjects_present() { + // return if any subjects are fixed + for (auto& subject : subjects_) { + if (subject->is_fixed()) { + return true; + } + } + return false; +} + //--------------------------------------------------------------------------- Parameters Project::get_parameters(const std::string& name, std::string domain_name) { Parameters params; diff --git a/Libs/Project/Project.h b/Libs/Project/Project.h index 36a1ac44eb..3a3f776ea7 100644 --- a/Libs/Project/Project.h +++ b/Libs/Project/Project.h @@ -91,6 +91,9 @@ class Project { //! Return if images are present (e.g. CT/MRI) bool get_images_present(); + //! Return if there are fixed subjects present + bool get_fixed_subjects_present(); + //! Get feature names std::vector get_feature_names(); diff --git a/Libs/Project/ProjectReader.cpp b/Libs/Project/ProjectReader.cpp index 68ee31c4aa..3f135f8100 100644 --- a/Libs/Project/ProjectReader.cpp +++ b/Libs/Project/ProjectReader.cpp @@ -62,6 +62,9 @@ void ProjectReader::load_subjects(StringMapList list) { if (contains(item, "name")) { name = item["name"]; } + if (contains(item, "fixed")) { + subject->set_fixed(Variant(item["fixed"])); + } if (name == "") { if (subject->get_original_filenames().size() != 0) { name = StringUtils::getBaseFilenameWithoutExtension(subject->get_original_filenames()[0]); diff --git a/Libs/Project/ProjectUtils.cpp b/Libs/Project/ProjectUtils.cpp index a6beef55b9..daee8aa890 100644 --- a/Libs/Project/ProjectUtils.cpp +++ b/Libs/Project/ProjectUtils.cpp @@ -183,6 +183,7 @@ StringMap ProjectUtils::get_value_map(std::vector prefixes, StringM //--------------------------------------------------------------------------- StringMap ProjectUtils::get_extra_columns(StringMap key_map) { StringList prefixes = {"name", + "fixed", SEGMENTATION_PREFIX, SHAPE_PREFIX, MESH_PREFIX, @@ -280,7 +281,7 @@ static void assign_keys(StringMap& j, std::vector prefixes, std::ve auto prefix = prefixes[0]; if (filenames.size() != domains.size()) { throw std::invalid_argument(prefix + " filenames and number of domains mismatch (" + - std::to_string(filenames.size()) + " vs " + std::to_string(domains.size()) + ")"); + std::to_string(filenames.size()) + " vs " + std::to_string(domains.size()) + ")"); } for (int i = 0; i < domains.size(); i++) { if (prefixes.size() == domains.size()) { @@ -300,7 +301,7 @@ static void assign_transforms(StringMap& j, std::string prefix, std::vectorget_display_name(); + if (project->get_fixed_subjects_present()) { + j["fixed"] = subject->is_fixed() ? "true" : "false"; + } auto original_prefixes = ProjectUtils::convert_domain_types(project->get_original_domain_types()); auto groomed_prefixes = ProjectUtils::convert_groomed_domain_types(project->get_groomed_domain_types()); diff --git a/Libs/Project/Subject.cpp b/Libs/Project/Subject.cpp index 186576cd3d..9742d09159 100644 --- a/Libs/Project/Subject.cpp +++ b/Libs/Project/Subject.cpp @@ -87,6 +87,12 @@ std::string Subject::get_display_name() { return display_name_; } //--------------------------------------------------------------------------- void Subject::set_display_name(std::string display_name) { display_name_ = display_name; } +//--------------------------------------------------------------------------- +bool Subject::is_fixed() { return fixed_; } + +//--------------------------------------------------------------------------- +void Subject::set_fixed(bool fixed) { fixed_ = fixed; } + //--------------------------------------------------------------------------- void Subject::set_local_particle_filenames(StringList filenames) { local_particle_filenames_ = filenames; } diff --git a/Libs/Project/Subject.h b/Libs/Project/Subject.h index dc6a25d6aa..07fe48db99 100644 --- a/Libs/Project/Subject.h +++ b/Libs/Project/Subject.h @@ -97,10 +97,16 @@ class Subject { //! Set the display name void set_display_name(std::string display_name); + //! Get if this subject is fixed or not + bool is_fixed(); + //! Set if this subject is fixed or not + void set_fixed(bool fixed); + private: int number_of_domains_ = 0; std::string display_name_; + bool fixed_ = false; StringList original_filenames_; StringList groomed_filenames_; StringList local_particle_filenames_; diff --git a/Libs/Project/Variant.cpp b/Libs/Project/Variant.cpp index f14d2b2a86..65a97fc7c4 100644 --- a/Libs/Project/Variant.cpp +++ b/Libs/Project/Variant.cpp @@ -11,8 +11,13 @@ Variant::operator std::string() const { return str_; } //--------------------------------------------------------------------------- Variant::operator bool() const { - // first try to read as the word 'true' or 'false', otherwise 0 or 1 + // first try to read as the word 'yes', 'no', 'true' or 'false', otherwise 0 or 1 std::string str_lower(str_); + if (str_ == "yes") { + return true; + } else if (str_ == "no") { + return false; + } std::transform(str_.begin(), str_.end(), str_lower.begin(), [](unsigned char c) { return std::tolower(c); }); bool t = (valid_ && (std::istringstream(str_lower) >> std::boolalpha >> t)) ? t : false; if (!t) { diff --git a/Studio/Analysis/AnalysisTool.cpp b/Studio/Analysis/AnalysisTool.cpp index 5c5df6de1c..38e4b6129f 100644 --- a/Studio/Analysis/AnalysisTool.cpp +++ b/Studio/Analysis/AnalysisTool.cpp @@ -444,6 +444,9 @@ bool AnalysisTool::compute_stats() { number_of_particles_ar.resize(dps); bool flag_get_num_part = false; for (auto& shape : session_->get_shapes()) { + if (shape->get_global_correspondence_points().size() == 0) { + continue; // skip any that don't have particles + } if (groups_enabled) { auto value = shape->get_subject()->get_group_value(group_set); if (value == left_group) { diff --git a/Studio/Data/DataTool.cpp b/Studio/Data/DataTool.cpp index 426c2d7887..bf09c85acd 100644 --- a/Studio/Data/DataTool.cpp +++ b/Studio/Data/DataTool.cpp @@ -13,9 +13,9 @@ #include #include #include +#include #include -#include "qt/QtWidgets/qmenu.h" #ifdef __APPLE__ static QString click_message = "⌘+click"; diff --git a/Studio/Data/Session.cpp b/Studio/Data/Session.cpp index 41482416ab..df433ad360 100644 --- a/Studio/Data/Session.cpp +++ b/Studio/Data/Session.cpp @@ -27,15 +27,15 @@ #include #include #include -#include #include #include +#include #include #include #include #include -#include "ExternalLibs/tinyxml/tinyxml.h" +#include "ExternalLibs/tinyxml/tinyxml.h" namespace shapeworks { @@ -437,7 +437,6 @@ bool Session::load_xl_project(QString filename) { break; } progress.setValue(progress.value() + 1); - } groups_available_ = project_->get_group_names().size() > 0; @@ -446,9 +445,7 @@ bool Session::load_xl_project(QString filename) { } //--------------------------------------------------------------------------- -void Session::set_project_path(QString relative_path) { - project_->set_project_path(relative_path.toStdString()); -} +void Session::set_project_path(QString relative_path) { project_->set_project_path(relative_path.toStdString()); } //--------------------------------------------------------------------------- std::shared_ptr Session::get_project() { return project_; } @@ -690,10 +687,10 @@ void Session::remove_shapes(QList list) { shapes_.erase(shapes_.begin() + i); } - project_->get_subjects(); renumber_shapes(); project_->update_subjects(); Q_EMIT data_changed(); + Q_EMIT update_display(); } //--------------------------------------------------------------------------- @@ -1164,7 +1161,6 @@ void Session::trigger_repaint() { Q_EMIT repaint(); } //--------------------------------------------------------------------------- void Session::trigger_reinsert_shapes() { Q_EMIT reinsert_shapes(); } - //--------------------------------------------------------------------------- void Session::set_display_mode(DisplayMode mode) { if (!is_loading()) { diff --git a/Studio/Interface/ShapeWorksStudioApp.cpp b/Studio/Interface/ShapeWorksStudioApp.cpp index 220609928b..0764dda38d 100644 --- a/Studio/Interface/ShapeWorksStudioApp.cpp +++ b/Studio/Interface/ShapeWorksStudioApp.cpp @@ -1216,9 +1216,11 @@ void ShapeWorksStudioApp::handle_reconstruction_complete() { //--------------------------------------------------------------------------- void ShapeWorksStudioApp::handle_groom_start() { - // clear out old points - session_->clear_particles(); - ui_->action_analysis_mode->setEnabled(false); + // clear out old points (unless fixed subjects) + if (!session_->get_project()->get_fixed_subjects_present()) { + session_->clear_particles(); + ui_->action_analysis_mode->setEnabled(false); + } } //--------------------------------------------------------------------------- diff --git a/Studio/Optimize/OptimizeTool.cpp b/Studio/Optimize/OptimizeTool.cpp index e9d0327209..0adc4c47b0 100644 --- a/Studio/Optimize/OptimizeTool.cpp +++ b/Studio/Optimize/OptimizeTool.cpp @@ -144,6 +144,8 @@ void OptimizeTool::handle_optimize_complete() { telemetry_.record_event("optimize", {{"duration_seconds", duration}, {"num_particles", QVariant::fromValue(session_->get_num_particles())}}); + session_->save_project(session_->get_filename()); + Q_EMIT optimize_complete(); update_run_button(); } diff --git a/Studio/Utils/StudioUtils.cpp b/Studio/Utils/StudioUtils.cpp index d0ddfaaf57..225902fb56 100644 --- a/Studio/Utils/StudioUtils.cpp +++ b/Studio/Utils/StudioUtils.cpp @@ -1,5 +1,11 @@ #include +#include +#include #include +#include +#include +#include +#include #include #include @@ -87,4 +93,46 @@ QString StudioUtils::get_platform_string() { #endif return platform; } + +//--------------------------------------------------------------------------- +void StudioUtils::add_viewport_border(vtkRenderer* renderer, double* color) { + // points start at upper right and proceed anti-clockwise + vtkNew points; + points->SetNumberOfPoints(4); + points->InsertPoint(0, 1, 1, 0); + points->InsertPoint(1, 0, 1, 0); + points->InsertPoint(2, 0, 0, 0); + points->InsertPoint(3, 1, 0, 0); + + vtkNew cells; + cells->Initialize(); + vtkNew lines; + + lines->GetPointIds()->SetNumberOfIds(5); + for (unsigned int i = 0; i < 4; ++i) { + lines->GetPointIds()->SetId(i, i); + } + lines->GetPointIds()->SetId(4, 0); + cells->InsertNextCell(lines); + + vtkNew poly; + poly->Initialize(); + poly->SetPoints(points); + poly->SetLines(cells); + + // use normalized viewport coordinates since they are independent of window size + vtkNew coordinate; + coordinate->SetCoordinateSystemToNormalizedViewport(); + + vtkNew mapper; + mapper->SetInputData(poly); + mapper->SetTransformCoordinate(coordinate); + + vtkNew actor; + actor->SetMapper(mapper); + actor->GetProperty()->SetColor(color); + actor->GetProperty()->SetLineWidth(6.0); + + renderer->AddViewProp(actor); +} } // namespace shapeworks diff --git a/Studio/Utils/StudioUtils.h b/Studio/Utils/StudioUtils.h index dea1db18d3..b72498ad1f 100644 --- a/Studio/Utils/StudioUtils.h +++ b/Studio/Utils/StudioUtils.h @@ -10,6 +10,7 @@ class QWidget; #include class vtkImageData; +class vtkRenderer; namespace shapeworks { @@ -27,8 +28,12 @@ class StudioUtils { //! reverse a poly data static vtkSmartPointer reverse_poly_data(vtkSmartPointer poly_data); + //! return platform string static QString get_platform_string(); + //! add a color border to a viewport + static void add_viewport_border(vtkRenderer* renderer, double* color); + }; } // namespace shapeworks diff --git a/Studio/Visualization/Viewer.cpp b/Studio/Visualization/Viewer.cpp index f737b81328..cb267eb113 100644 --- a/Studio/Visualization/Viewer.cpp +++ b/Studio/Visualization/Viewer.cpp @@ -638,8 +638,15 @@ void Viewer::display_shape(std::shared_ptr shape) { corner_annotation_->SetText(3, (annotations[3]).c_str()); corner_annotation_->GetTextProperty()->SetColor(0.50, 0.5, 0.5); + renderer_->RemoveAllViewProps(); + auto subject = shape->get_subject(); + if (subject && subject->is_fixed()) { + double color[4] = {0.0, 0.0, 1.0, 1.0}; + StudioUtils::add_viewport_border(renderer_, color); + } + number_of_domains_ = session_->get_domains_per_shape(); if (meshes_.valid()) { number_of_domains_ = std::max(number_of_domains_, meshes_.meshes().size()); @@ -704,8 +711,8 @@ void Viewer::display_shape(std::shared_ptr shape) { auto compare_poly_data = compare_meshes_.meshes()[i]->get_poly_data(); if (compare_settings.get_mean_shape_checked()) { - auto transform = - visualizer_->get_transform(shape_, compare_settings.get_display_mode(), visualizer_->get_alignment_domain(), i); + auto transform = visualizer_->get_transform(shape_, compare_settings.get_display_mode(), + visualizer_->get_alignment_domain(), i); transform->Inverse(); auto transform_filter = vtkSmartPointer::New(); @@ -1081,7 +1088,6 @@ void Viewer::insert_compare_meshes() { actor->SetUserTransform(identity); } else { actor->SetUserTransform(transform); - } mapper->SetInputData(poly_data);