Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement loading snapshots from stream #1462

Merged
merged 1 commit into from
Oct 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/neural-graphics-primitives/testbed.h
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ class Testbed {
);
void visualize_nerf_cameras(ImDrawList* list, const mat4& world2proj);
fs::path find_network_config(const fs::path& network_config_path);
nlohmann::json load_network_config(std::istream& stream, bool is_compressed);
nlohmann::json load_network_config(const fs::path& network_config_path);
void reload_network_from_file(const fs::path& path = "");
void reload_network_from_json(const nlohmann::json& json, const std::string& config_base_path=""); // config_base_path is needed so that if the passed in json uses the 'parent' feature, we know where to look... be sure to use a filename, or if a directory, end with a trailing slash
Expand Down Expand Up @@ -484,7 +485,9 @@ class Testbed {
vec2 fov_xy() const ;
void set_fov_xy(const vec2& val);
void save_snapshot(const fs::path& path, bool include_optimizer_state, bool compress);
void load_snapshot(nlohmann::json config);
void load_snapshot(const fs::path& path);
void load_snapshot(std::istream& stream, bool is_compressed = true);
CameraKeyframe copy_camera_to_keyframe() const;
void set_camera_from_keyframe(const CameraKeyframe& k);
void set_camera_from_time(float t);
Expand Down
2 changes: 1 addition & 1 deletion src/main.cu
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ int main_func(const std::vector<std::string>& arguments) {
}

if (snapshot_flag) {
testbed.load_snapshot(get(snapshot_flag));
testbed.load_snapshot(static_cast<fs::path>(get(snapshot_flag)));
} else if (network_config_flag) {
testbed.reload_network_from_file(get(network_config_flag));
}
Expand Down
2 changes: 1 addition & 1 deletion src/python_api.cu
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ PYBIND11_MODULE(pyngp, m) {
.def("n_params", &Testbed::n_params, "Number of trainable parameters")
.def("n_encoding_params", &Testbed::n_encoding_params, "Number of trainable parameters in the encoding")
.def("save_snapshot", &Testbed::save_snapshot, py::arg("path"), py::arg("include_optimizer_state")=false, py::arg("compress")=true, "Save a snapshot of the currently trained model. Optionally compressed (only when saving '.ingp' files).")
.def("load_snapshot", &Testbed::load_snapshot, py::arg("path"), "Load a previously saved snapshot")
.def("load_snapshot", py::overload_cast<const fs::path&>(&Testbed::load_snapshot), py::arg("path"), "Load a previously saved snapshot")
.def("load_camera_path", &Testbed::load_camera_path, py::arg("path"), "Load a camera path")
.def("load_file", &Testbed::load_file, py::arg("path"), "Load a file and automatically determine how to handle it. Can be a snapshot, dataset, network config, or camera path.")
.def_property("loop_animation", &Testbed::loop_animation, &Testbed::set_loop_animation)
Expand Down
49 changes: 37 additions & 12 deletions src/testbed.cu
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,14 @@ fs::path Testbed::find_network_config(const fs::path& network_config_path) {
return network_config_path;
}

json Testbed::load_network_config(std::istream& stream, bool is_compressed) {
if (is_compressed) {
zstr::istream zstream{stream};
return json::from_msgpack(zstream);
}
return json::from_msgpack(stream);
}

json Testbed::load_network_config(const fs::path& network_config_path) {
bool is_snapshot = equals_case_insensitive(network_config_path.extension(), "msgpack") || equals_case_insensitive(network_config_path.extension(), "ingp");
if (network_config_path.empty() || !network_config_path.exists()) {
Expand Down Expand Up @@ -1543,7 +1551,7 @@ void Testbed::imgui() {
ImGui::SameLine();
if (ImGui::Button("Load")) {
try {
load_snapshot(m_imgui.snapshot_path);
load_snapshot(static_cast<fs::path>(m_imgui.snapshot_path));
} catch (const std::exception& e) {
imgui_error_string = fmt::format("Failed to load snapshot: {}", e.what());
ImGui::OpenPopup("Error");
Expand Down Expand Up @@ -2339,14 +2347,14 @@ void Testbed::SecondWindow::draw(GLuint texture) {
}

void Testbed::init_opengl_shaders() {
static const char* shader_vert = R"(#version 140
static const char* shader_vert = R"glsl(#version 140
out vec2 UVs;
void main() {
UVs = vec2((gl_VertexID << 1) & 2, gl_VertexID & 2);
gl_Position = vec4(UVs * 2.0 - 1.0, 0.0, 1.0);
})";
})glsl";

static const char* shader_frag = R"(#version 140
static const char* shader_frag = R"glsl(#version 140
in vec2 UVs;
out vec4 frag_color;
uniform sampler2D rgba_texture;
Expand Down Expand Up @@ -2386,7 +2394,7 @@ void Testbed::init_opengl_shaders() {
//Uncomment the following line of code to visualize debug the depth buffer for debugging.
// frag_color = vec4(vec3(texture(depth_texture, tex_coords.xy).r), 1.0);
gl_FragDepth = texture(depth_texture, tex_coords.xy).r;
})";
})glsl";

GLuint vert = glCreateShader(GL_VERTEX_SHADER);
glShaderSource(vert, 1, &shader_vert, NULL);
Expand Down Expand Up @@ -4746,12 +4754,7 @@ void Testbed::save_snapshot(const fs::path& path, bool include_optimizer_state,
tlog::success() << "Saved snapshot '" << path.str() << "'";
}

void Testbed::load_snapshot(const fs::path& path) {
auto config = load_network_config(path);
if (!config.contains("snapshot")) {
throw std::runtime_error{fmt::format("File '{}' does not contain a snapshot.", path.str())};
}

void Testbed::load_snapshot(nlohmann::json config) {
const auto& snapshot = config["snapshot"];
if (snapshot.value("version", 0) < SNAPSHOT_FORMAT_VERSION) {
throw std::runtime_error{"Snapshot uses an old format and can not be loaded."};
Expand Down Expand Up @@ -4841,7 +4844,6 @@ void Testbed::load_snapshot(const fs::path& path) {
m_render_aabb = snapshot.value("render_aabb", m_render_aabb);
if (snapshot.contains("up_dir")) from_json(snapshot.at("up_dir"), m_up_dir);

m_network_config_path = path;
m_network_config = std::move(config);

reset_network(false);
Expand All @@ -4868,6 +4870,29 @@ void Testbed::load_snapshot(const fs::path& path) {
set_all_devices_dirty();
}

void Testbed::load_snapshot(const fs::path& path) {
auto config = load_network_config(path);
if (!config.contains("snapshot")) {
throw std::runtime_error{fmt::format("File '{}' does not contain a snapshot.", path.str())};
}

load_snapshot(std::move(config));

m_network_config_path = path;
}

void Testbed::load_snapshot(std::istream& stream, bool is_compressed) {
auto config = load_network_config(stream, is_compressed);
if (!config.contains("snapshot")) {
throw std::runtime_error{"Given stream does not contain a snapshot."};
}

load_snapshot(std::move(config));

// Network config path is unknown.
m_network_config_path = "";
}

Testbed::CudaDevice::CudaDevice(int id, bool is_primary) : m_id{id}, m_is_primary{is_primary} {
auto guard = device_guard();
m_stream = std::make_unique<StreamAndEvent>();
Expand Down
Loading