From 75f486a6f8c34aa495a6e081ef644bdb5db25974 Mon Sep 17 00:00:00 2001
From: Arthur Brussee
Date: Mon, 9 Dec 2024 13:55:24 +0000
Subject: [PATCH 1/3] Start of vfs

---
 Cargo.lock                                    |   1 +
 crates/brush-dataset/src/brush_vfs.rs         | 193 ++++++++++++++++++
 crates/brush-dataset/src/formats/colmap.rs    |  39 ++--
 crates/brush-dataset/src/formats/mod.rs       |  33 +--
 .../brush-dataset/src/formats/nerfstudio.rs   |  56 +++--
 crates/brush-dataset/src/lib.rs               |   2 +-
 crates/brush-dataset/src/zip.rs               | 127 ------------
 crates/brush-viewer/src/train_loop.rs         |  11 +-
 crates/colmap-reader/Cargo.toml               |   1 +
 crates/colmap-reader/src/lib.rs               | 130 +++++++-----
 10 files changed, 356 insertions(+), 237 deletions(-)
 create mode 100644 crates/brush-dataset/src/brush_vfs.rs
 delete mode 100644 crates/brush-dataset/src/zip.rs

diff --git a/Cargo.lock b/Cargo.lock
index 3b38afac..689d4054 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1575,6 +1575,7 @@ version = "0.1.0"
 dependencies = [
  "byteorder",
  "glam",
+ "tokio",
 ]
 
 [[package]]
diff --git a/crates/brush-dataset/src/brush_vfs.rs b/crates/brush-dataset/src/brush_vfs.rs
new file mode 100644
index 00000000..7bc832f5
--- /dev/null
+++ b/crates/brush-dataset/src/brush_vfs.rs
@@ -0,0 +1,193 @@
+//
+// This class helps treat an archive as a somewhat more regular filesystem.
+//
+// Really we want to just read directories.
+// The reason we don't is that picking directories isn't supported on
+// rfd on wasm, nor is drag-and-dropping folders in egui.
+use std::{
+    collections::HashMap,
+    io::{Cursor, Read},
+    path::{Path, PathBuf},
+    sync::Arc,
+};
+
+use anyhow::Context;
+use tokio::io::AsyncReadExt;
+use tokio::{io::AsyncRead, sync::Mutex};
+use zip::{
+    result::{ZipError, ZipResult},
+    ZipArchive,
+};
+
+#[derive(Clone)]
+pub struct ZipData {
+    data: Arc<Vec<u8>>,
+}
+
+type ZipReader = Cursor<ZipData>;
+
+impl AsRef<[u8]> for ZipData {
+    fn as_ref(&self) -> &[u8] {
+        &self.data
+    }
+}
+
+impl ZipData {
+    pub fn open_for_read(&self) -> ZipReader {
+        Cursor::new(self.clone())
+    }
+}
+
+impl From<Vec<u8>> for ZipData {
+    fn from(value: Vec<u8>) -> Self {
+        Self {
+            data: Arc::new(value),
+        }
+    }
+}
+
+pub(crate) fn normalized_path(path: &Path) -> PathBuf {
+    Path::new(path)
+        .components()
+        .skip_while(|c| matches!(c, std::path::Component::CurDir))
+        .collect::<PathBuf>()
+}
+
+#[derive(Clone)]
+pub struct ConsumableReader {
+    // Option allows us to take ownership when first used.
+    inner: Arc<Mutex<Option<Box<dyn AsyncRead + Send + Unpin>>>>,
+}
+
+#[derive(Clone, Default)]
+pub struct PathReader {
+    paths: HashMap<PathBuf, ConsumableReader>,
+}
+
+impl PathReader {
+    fn paths(&self) -> impl Iterator<Item = &PathBuf> {
+        self.paths.keys()
+    }
+
+    pub fn add(&mut self, path: PathBuf, reader: impl AsyncRead + Send + Unpin + 'static) {
+        self.paths.insert(
+            path,
+            ConsumableReader {
+                inner: Arc::new(Mutex::new(Some(Box::new(reader)))),
+            },
+        );
+    }
+
+    async fn open(&mut self, path: &Path) -> anyhow::Result<Box<dyn AsyncRead + Send + Unpin>> {
+        let entry = self.paths.remove(path).context("File not found")?;
+        let reader = entry.inner.lock().await.take();
+        reader.context("Missing reader")
+    }
+}
+
+#[derive(Clone)]
+pub enum BrushVfs {
+    Zip(ZipArchive<Cursor<ZipData>>),
+    Manual(PathReader),
+    Directory(PathBuf, Vec<PathBuf>),
+}
+
+// TODO: This is all awfully ad-hoc.
+impl BrushVfs {
+    pub async fn from_zip_reader(reader: impl AsyncRead + Unpin) -> ZipResult<Self> {
+        let mut bytes = vec![];
+        let mut reader = reader;
+        reader.read_to_end(&mut bytes).await?;
+
+        let zip_data = ZipData::from(bytes);
+        let archive = ZipArchive::new(zip_data.open_for_read())?;
+        Ok(BrushVfs::Zip(archive))
+    }
+
+    pub fn from_paths(paths: PathReader) -> Self {
+        BrushVfs::Manual(paths)
+    }
+
+    pub async fn from_directory(dir: &Path) -> anyhow::Result<Self> {
+        let mut read = ::tokio::fs::read_dir(dir).await?;
+        let mut paths = vec![];
+        while let Some(entry) = read.next_entry().await? {
+            paths.push(entry.path());
+        }
+        Ok(BrushVfs::Directory(dir.to_path_buf(), paths))
+    }
+
+    pub(crate) fn file_names(&self) -> impl Iterator<Item = &str> + '_ {
+        let iterator: Box<dyn Iterator<Item = &str> + '_> = match self {
+            BrushVfs::Zip(archive) => Box::new(archive.file_names()),
+            BrushVfs::Manual(map) => Box::new(map.paths().filter_map(|k| k.to_str())),
+            BrushVfs::Directory(_, paths) => Box::new(paths.iter().filter_map(|k| k.to_str())),
+        };
+        // stupid macOS.
+        iterator.filter(|p| !p.contains("__MACOSX"))
+    }
+
+    pub(crate) fn find_with_extension(
+        &self,
+        extension: &str,
+        contains: &str,
+    ) -> anyhow::Result<PathBuf> {
+        let names: Vec<_> = self
+            .file_names()
+            .filter(|name| name.ends_with(extension))
+            .collect();
+
+        if names.len() == 1 {
+            return Ok(Path::new(names[0]).to_owned());
+        }
+
+        let names: Vec<_> = names
+            .iter()
+            .filter(|name| name.contains(contains))
+            .collect();
+
+        if names.len() == 1 {
+            return Ok(Path::new(names[0]).to_owned());
+        }
+
+        anyhow::bail!("Failed to find file ending in {extension}, maybe containing {contains}.");
+    }
+
+    pub(crate) async fn open_reader_at_path(
+        &mut self,
+        path: &Path,
+    ) -> anyhow::Result<Box<dyn AsyncRead + Send + Unpin>> {
+        match self {
+            BrushVfs::Zip(archive) => {
+                let name = archive
+                    .file_names()
+                    .find(|name| path == Path::new(name))
+                    .ok_or(ZipError::FileNotFound)?;
+                let name = name.to_owned();
+
+                let mut buffer = vec![];
+                archive.by_name(&name)?.read_to_end(&mut buffer)?;
+
+                Ok(Box::new(Cursor::new(buffer)))
+            }
+            BrushVfs::Manual(map) => map.open(path).await,
+            BrushVfs::Directory(path_buf, _) => {
+                let file = tokio::fs::File::open(path_buf).await?;
+                Ok(Box::new(file))
+            }
+        }
+    }
+
+    pub(crate) fn find_base_path(&self, search_path: &str) -> Option<PathBuf> {
+        for file in self.file_names() {
+            let path = normalized_path(Path::new(file));
+            if path.ends_with(search_path) {
+                return path
+                    .ancestors()
+                    .nth(Path::new(search_path).components().count())
+                    .map(|x| x.to_owned());
+            }
+        }
+        None
+    }
+}
diff --git a/crates/brush-dataset/src/formats/colmap.rs b/crates/brush-dataset/src/formats/colmap.rs
index 3d2c5a3b..fd15984b 100644
--- a/crates/brush-dataset/src/formats/colmap.rs
+++ b/crates/brush-dataset/src/formats/colmap.rs
@@ -1,7 +1,7 @@
 use std::{future::Future, sync::Arc};
 
-use super::{DataStream, DatasetZip, LoadDatasetArgs};
-use crate::{splat_import::SplatMessage, stream_fut_parallel, Dataset};
+use super::{DataStream, LoadDatasetArgs};
+use crate::{brush_vfs::BrushVfs, splat_import::SplatMessage, stream_fut_parallel, Dataset};
 use anyhow::{Context, Result};
 use async_fn_stream::try_fn_stream;
 use brush_render::{
@@ -12,13 +12,15 @@ use brush_render::{
 };
 use brush_train::scene::SceneView;
 use glam::Vec3;
+use tokio::io::AsyncReadExt;
 use tokio_stream::StreamExt;
 
-fn read_views(
-    mut archive: DatasetZip,
+async fn read_views(
+    archive: BrushVfs,
     load_args: &LoadDatasetArgs,
 ) -> Result<Vec<impl Future<Output = Result<SceneView>>>> {
     log::info!("Loading colmap dataset");
+    let mut archive = archive;
     let (is_binary, base_path) = if
let Some(path) = archive.find_base_path("sparse/0/cameras.bin") { @@ -26,7 +28,7 @@ fn read_views( } else if let Some(path) = archive.find_base_path("sparse/0/cameras.txt") { (false, path) } else { - anyhow::bail!("No COLMAP data found (either text or binary.") + anyhow::bail!("No COLMAP data found (either text or binary.)") }; let (cam_path, img_path) = if is_binary { @@ -42,14 +44,14 @@ fn read_views( }; let cam_model_data = { - let mut cam_file = archive.file_at_path(&cam_path)?; - colmap_reader::read_cameras(&mut cam_file, is_binary)? + let mut cam_file = archive.open_reader_at_path(&cam_path).await?; + colmap_reader::read_cameras(&mut cam_file, is_binary).await? }; let img_infos = { - let img_file = archive.file_at_path(&img_path)?; - let mut buf_reader = std::io::BufReader::new(img_file); - colmap_reader::read_images(&mut buf_reader, is_binary)? + let img_file = archive.open_reader_at_path(&img_path).await?; + let mut buf_reader = tokio::io::BufReader::new(img_file); + colmap_reader::read_images(&mut buf_reader, is_binary).await? }; let mut img_info_list = img_infos.into_iter().collect::>(); @@ -82,7 +84,12 @@ fn read_views( let img_path = base_path.join(format!("images/{}", img_info.name)); - let img_bytes = archive.read_bytes_at_path(&img_path)?; + let mut img_bytes = vec![]; + archive + .open_reader_at_path(&img_path) + .await? + .read_to_end(&mut img_bytes) + .await?; let mut img = image::load_from_memory(&img_bytes)?; if let Some(max) = load_args.max_resolution { @@ -110,12 +117,12 @@ fn read_views( Ok(handles) } -pub(crate) fn load_dataset( - mut archive: DatasetZip, +pub(crate) async fn load_dataset( + mut archive: BrushVfs, load_args: &LoadDatasetArgs, device: &B::Device, ) -> Result<(DataStream>, DataStream)> { - let mut handles = read_views(archive.clone(), load_args)?; + let mut handles = read_views(archive.clone(), load_args).await?; if let Some(subsample) = load_args.subsample_frames { handles = handles.into_iter().step_by(subsample as usize).collect(); @@ -165,8 +172,8 @@ pub(crate) fn load_dataset( // Extract COLMAP sfm points. let points_data = { - let mut points_file = archive.file_at_path(&points_path)?; - colmap_reader::read_points3d(&mut points_file, is_binary) + let mut points_file = archive.open_reader_at_path(&points_path).await?; + colmap_reader::read_points3d(&mut points_file, is_binary).await }; // Ignore empty points data. diff --git a/crates/brush-dataset/src/formats/mod.rs b/crates/brush-dataset/src/formats/mod.rs index 7a95d7c9..d4de6ff2 100644 --- a/crates/brush-dataset/src/formats/mod.rs +++ b/crates/brush-dataset/src/formats/mod.rs @@ -1,11 +1,11 @@ use crate::{ + brush_vfs::BrushVfs, splat_import::{load_splat_from_ply, SplatMessage}, - zip::DatasetZip, Dataset, LoadDatasetArgs, }; use anyhow::Result; use brush_render::Backend; -use std::{io::Cursor, pin::Pin}; +use std::pin::Pin; use tokio_stream::Stream; pub mod colmap; @@ -14,32 +14,39 @@ pub mod nerfstudio; // A dynamic stream of datasets type DataStream = Pin> + Send + 'static>>; -pub fn load_dataset( - mut archive: DatasetZip, +pub async fn load_dataset( + mut vfs: BrushVfs, load_args: &LoadDatasetArgs, device: &B::Device, ) -> anyhow::Result<(DataStream>, DataStream)> { - let streams = nerfstudio::read_dataset(archive.clone(), load_args, device) - .or_else(|_| colmap::load_dataset::(archive.clone(), load_args, device)); + let stream = nerfstudio::read_dataset(vfs.clone(), load_args, device).await; - let Ok(streams) = streams else { - anyhow::bail!("Couldn't parse dataset as any format. 
Only some formats are supported.") + let stream = match stream { + Ok(s) => Ok(s), + Err(_) => colmap::load_dataset::(vfs.clone(), load_args, device).await, + }; + + let stream = match stream { + Ok(stream) => stream, + Err(e) => anyhow::bail!( + "Couldn't parse dataset as any format. Only some formats are supported. {e}" + ), }; // If there's an init.ply definitey override the init stream with that. - let init_path = archive.find_with_extension(".ply", "init"); + let init_path = vfs.find_with_extension(".ply", "init"); let init_stream = if let Ok(path) = init_path { - let ply_data = archive.read_bytes_at_path(&path)?; + let ply_data = vfs.open_reader_at_path(&path).await?; log::info!("Using {path:?} as initial point cloud."); Box::pin(load_splat_from_ply( - Cursor::new(ply_data), + ply_data, load_args.subsample_points, device.clone(), )) } else { - streams.0 + stream.0 }; - Ok((init_stream, streams.1)) + Ok((init_stream, stream.1)) } diff --git a/crates/brush-dataset/src/formats/nerfstudio.rs b/crates/brush-dataset/src/formats/nerfstudio.rs index 86a88efd..264d538d 100644 --- a/crates/brush-dataset/src/formats/nerfstudio.rs +++ b/crates/brush-dataset/src/formats/nerfstudio.rs @@ -1,5 +1,5 @@ -use super::DatasetZip; use super::LoadDatasetArgs; +use crate::brush_vfs::BrushVfs; use crate::splat_import::load_splat_from_ply; use crate::splat_import::SplatMessage; use crate::stream_fut_parallel; @@ -11,9 +11,9 @@ use brush_render::camera::{focal_to_fov, fov_to_focal, Camera}; use brush_render::Backend; use brush_train::scene::SceneView; use std::future::Future; -use std::io::Cursor; use std::path::PathBuf; use std::sync::Arc; +use tokio::io::AsyncReadExt; use tokio_stream::StreamExt; #[derive(serde::Deserialize, Clone)] @@ -95,7 +95,7 @@ struct FrameData { fn read_transforms_file( scene: JsonScene, transforms_path: PathBuf, - archive: DatasetZip, + vfs: BrushVfs, load_args: &LoadDatasetArgs, ) -> Result>>> { let iter = scene @@ -103,7 +103,7 @@ fn read_transforms_file( .into_iter() .take(load_args.max_frames.unwrap_or(usize::MAX)) .map(move |frame| { - let mut archive = archive.clone(); + let mut archive = vfs.clone(); let load_args = load_args.clone(); let transforms_path = transforms_path.clone(); @@ -126,7 +126,13 @@ fn read_transforms_file( if path.extension().is_none() { path = path.with_extension("png"); } - let img_buffer = archive.read_bytes_at_path(&path)?; + + let mut img_buffer = vec![]; + archive + .open_reader_at_path(&path) + .await? + .read_to_end(&mut img_buffer) + .await?; let comp_span = tracing::trace_span!("Decompress image").entered(); drop(comp_span); @@ -171,19 +177,26 @@ fn read_transforms_file( Ok(iter.collect()) } -pub fn read_dataset( - mut archive: DatasetZip, +pub async fn read_dataset( + mut vfs: BrushVfs, load_args: &LoadDatasetArgs, device: &B::Device, ) -> Result<(DataStream>, DataStream)> { log::info!("Loading nerfstudio dataset"); - let transforms_path = archive.find_with_extension(".json", "_train")?; - let train_scene: JsonScene = serde_json::from_reader(archive.file_at_path(&transforms_path)?)?; + let transforms_path = vfs.find_with_extension(".json", "_train")?; + + let mut buf = String::new(); + vfs.open_reader_at_path(&transforms_path) + .await? 
+ .read_to_string(&mut buf) + .await?; + let train_scene: JsonScene = serde_json::from_str(&buf)?; + let mut train_handles = read_transforms_file( train_scene.clone(), transforms_path.clone(), - archive.clone(), + vfs.clone(), load_args, )?; @@ -195,20 +208,26 @@ pub fn read_dataset( } let load_args_clone = load_args.clone(); - let mut archive_clone = archive.clone(); let transforms_path_clone = transforms_path.clone(); + let mut data_clone = vfs.clone(); let dataset_stream = try_fn_stream(|emitter| async move { let mut train_views = vec![]; let mut eval_views = vec![]; - let eval_trans_path = archive_clone.find_with_extension(".json", "_val")?; + let eval_trans_path = data_clone.find_with_extension(".json", "_val")?; // If a seperate eval file is specified, read it. let val_stream = if eval_trans_path != transforms_path_clone { - let val_scene = serde_json::from_reader(archive_clone.file_at_path(&eval_trans_path)?)?; - read_transforms_file(val_scene, eval_trans_path, archive_clone, &load_args_clone).ok() + let mut json_str = String::new(); + data_clone + .open_reader_at_path(&eval_trans_path) + .await? + .read_to_string(&mut json_str) + .await?; + let val_scene = serde_json::from_str(&json_str)?; + read_transforms_file(val_scene, eval_trans_path, data_clone, &load_args_clone).ok() } else { None }; @@ -260,14 +279,11 @@ pub fn read_dataset( let splat_stream = try_fn_stream(|emitter| async move { if let Some(init) = train_scene.ply_file_path { let init_path = transforms_path.parent().unwrap().join(init); - let ply_data = archive.read_bytes_at_path(&init_path); + let ply_data = vfs.open_reader_at_path(&init_path).await; if let Ok(ply_data) = ply_data { - let splat_stream = load_splat_from_ply( - Cursor::new(ply_data), - load_args.subsample_points, - device.clone(), - ); + let splat_stream = + load_splat_from_ply(ply_data, load_args.subsample_points, device.clone()); let mut splat_stream = std::pin::pin!(splat_stream); diff --git a/crates/brush-dataset/src/lib.rs b/crates/brush-dataset/src/lib.rs index c0661d36..e56b1148 100644 --- a/crates/brush-dataset/src/lib.rs +++ b/crates/brush-dataset/src/lib.rs @@ -1,8 +1,8 @@ +pub mod brush_vfs; mod formats; pub mod scene_loader; pub mod splat_export; pub mod splat_import; -pub mod zip; pub use formats::load_dataset; diff --git a/crates/brush-dataset/src/zip.rs b/crates/brush-dataset/src/zip.rs deleted file mode 100644 index 22ce879e..00000000 --- a/crates/brush-dataset/src/zip.rs +++ /dev/null @@ -1,127 +0,0 @@ -// Currently, we make all datasets go through a zip file [1] -// This class helps working with an archive as a somewhat more regular filesystem. -// -// [1] really we want to just read directories. -// The reason is that picking directories isn't supported on -// rfd on wasm, nor is drag-and-dropping folders in egui. 
-use std::{ - io::{Cursor, Read}, - path::{Path, PathBuf}, - sync::Arc, -}; - -use zip::{ - read::ZipFile, - result::{ZipError, ZipResult}, - ZipArchive, -}; - -#[derive(Clone)] -pub struct ZipData { - data: Arc>, -} - -type ZipReader = Cursor; - -impl AsRef<[u8]> for ZipData { - fn as_ref(&self) -> &[u8] { - &self.data - } -} - -impl ZipData { - pub fn open_for_read(&self) -> ZipReader { - Cursor::new(self.clone()) - } -} - -impl From> for ZipData { - fn from(value: Vec) -> Self { - Self { - data: Arc::new(value), - } - } -} - -pub(crate) fn normalized_path(path: &Path) -> PathBuf { - Path::new(path) - .components() - .skip_while(|c| matches!(c, std::path::Component::CurDir)) - .collect::() -} - -#[derive(Clone)] -pub struct DatasetZip { - archive: ZipArchive>, -} - -// TODO: This is all awfully ad-hoc. -impl DatasetZip { - pub fn from_data(data: Vec) -> ZipResult { - let zip_data = ZipData::from(data); - let archive = ZipArchive::new(zip_data.open_for_read())?; - Ok(Self { archive }) - } - - pub(crate) fn file_names(&self) -> impl Iterator + '_ { - self.archive - .file_names() - // stupic macOS. - .filter(|p| !p.contains("__MACOSX")) - } - - pub(crate) fn find_with_extension( - &self, - extension: &str, - contains: &str, - ) -> anyhow::Result { - let names: Vec<_> = self - .file_names() - .filter(|name| name.ends_with(extension)) - .collect(); - - if names.len() == 1 { - return Ok(Path::new(names[0]).to_owned()); - } - - let names: Vec<_> = names - .iter() - .filter(|name| name.contains(contains)) - .collect(); - - if names.len() == 1 { - return Ok(Path::new(names[0]).to_owned()); - } - - anyhow::bail!("Failed to find file ending in {extension} maybe containing {contains}."); - } - - pub(crate) fn file_at_path<'a>(&'a mut self, path: &Path) -> Result, ZipError> { - let name = self - .archive - .file_names() - .find(|name| path == Path::new(name)) - .ok_or(ZipError::FileNotFound)?; - let name = name.to_owned(); - self.archive.by_name(&name) - } - - pub(crate) fn read_bytes_at_path(&mut self, path: &Path) -> anyhow::Result> { - let mut buffer = vec![]; - self.file_at_path(path)?.read_to_end(&mut buffer)?; - Ok(buffer) - } - - pub(crate) fn find_base_path(&self, search_path: &str) -> Option { - for file in self.archive.file_names() { - let path = normalized_path(Path::new(file)); - if path.ends_with(search_path) { - return path - .ancestors() - .nth(Path::new(search_path).components().count()) - .map(|x| x.to_owned()); - } - } - None - } -} diff --git a/crates/brush-viewer/src/train_loop.rs b/crates/brush-viewer/src/train_loop.rs index 6b359f3c..4260b90a 100644 --- a/crates/brush-viewer/src/train_loop.rs +++ b/crates/brush-viewer/src/train_loop.rs @@ -1,7 +1,7 @@ use async_fn_stream::try_fn_stream; use brush_dataset::{ - scene_loader::SceneLoader, zip::DatasetZip, Dataset, LoadDatasetArgs, LoadInitArgs, + brush_vfs::BrushVfs, scene_loader::SceneLoader, Dataset, LoadDatasetArgs, LoadInitArgs, }; use brush_render::gaussian_splats::{RandomSplatsConfig, Splats}; use brush_train::train::{SplatTrainer, TrainConfig}; @@ -9,7 +9,6 @@ use burn::module::AutodiffModule; use burn_jit::cubecl::Runtime; use burn_wgpu::{Wgpu, WgpuDevice, WgpuRuntime}; use rand::SeedableRng; -use tokio::io::AsyncReadExt; use tokio::{ io::AsyncRead, sync::mpsc::{error::TryRecvError, Receiver}, @@ -28,7 +27,7 @@ pub enum TrainMessage { } pub(crate) fn train_loop( - mut data: T, + data_reader: T, device: WgpuDevice, mut receiver: Receiver, load_data_args: LoadDatasetArgs, @@ -36,10 +35,8 @@ pub(crate) fn train_loop( config: 
TrainConfig,
 ) -> impl Stream<Item = anyhow::Result<ProcessMessage>> {
     try_fn_stream(|emitter| async move {
-        let mut bytes = vec![];
-        data.read_to_end(&mut bytes).await?;
         // TODO: async zip ideally.
-        let zip_data = DatasetZip::from_data(bytes)?;
+        let zip_data = BrushVfs::from_zip_reader(data_reader).await?;
 
         let batch_size = 1;
 
@@ -53,7 +50,7 @@
         let mut dataset = Dataset::empty();
 
         let (mut splat_stream, mut data_stream) =
-            brush_dataset::load_dataset(zip_data.clone(), &load_data_args, &device)?;
+            brush_dataset::load_dataset(zip_data.clone(), &load_data_args, &device).await?;
 
         // Read initial splats if any.
         while let Some(message) = splat_stream.next().await {
diff --git a/crates/colmap-reader/Cargo.toml b/crates/colmap-reader/Cargo.toml
index 5082c356..f8e5b1c2 100644
--- a/crates/colmap-reader/Cargo.toml
+++ b/crates/colmap-reader/Cargo.toml
@@ -8,3 +8,4 @@ license.workspace = true
 [dependencies]
 glam.workspace = true
 byteorder.workspace = true
+tokio.workspace = true
diff --git a/crates/colmap-reader/src/lib.rs b/crates/colmap-reader/src/lib.rs
index 361b2032..15a4c9c5 100644
--- a/crates/colmap-reader/src/lib.rs
+++ b/crates/colmap-reader/src/lib.rs
@@ -3,6 +3,9 @@
 use byteorder::{LittleEndian, ReadBytesExt};
 use std::collections::HashMap;
 use std::io::{self, BufRead, Read};
+use tokio::io::AsyncBufReadExt;
+use tokio::io::AsyncReadExt;
+use tokio::io::{AsyncBufRead, AsyncRead};
 
 // TODO: Really these should each hold their respective params, but it's a bit of an annoying refactor. We just need
 // basic params.
@@ -156,12 +159,12 @@ fn parse<T: FromStr>(s: &str) -> io::Result<T> {
         .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Parse error"))
 }
 
-fn read_cameras_text<R: Read>(reader: R) -> io::Result<HashMap<i32, Camera>> {
+async fn read_cameras_text<R: AsyncRead + Unpin>(reader: R) -> io::Result<HashMap<i32, Camera>> {
     let mut cameras = HashMap::new();
-    let mut buf_reader = io::BufReader::new(reader);
+    let mut buf_reader = tokio::io::BufReader::new(reader);
     let mut line = String::new();
 
-    while buf_reader.read_line(&mut line)? > 0 {
+    while buf_reader.read_line(&mut line).await? > 0 {
         if line.starts_with('#') {
             line.clear();
             continue;
@@ -209,15 +212,17 @@ fn read_cameras_text<R: Read>(reader: R) -> io::Result<HashMap<i32, Camera>> {
     Ok(cameras)
 }
 
-fn read_cameras_binary<R: Read>(mut reader: R) -> io::Result<HashMap<i32, Camera>> {
+async fn read_cameras_binary<R: AsyncRead + Unpin>(
+    mut reader: R,
+) -> io::Result<HashMap<i32, Camera>> {
     let mut cameras = HashMap::new();
-    let num_cameras = reader.read_u64::<LittleEndian>()?;
+    let num_cameras = reader.read_u64_le().await?;
 
     for _ in 0..num_cameras {
-        let camera_id = reader.read_i32::<LittleEndian>()?;
-        let model_id = reader.read_i32::<LittleEndian>()?;
-        let width = reader.read_u64::<LittleEndian>()?;
-        let height = reader.read_u64::<LittleEndian>()?;
+        let camera_id = reader.read_i32_le().await?;
+        let model_id = reader.read_i32_le().await?;
+        let width = reader.read_u64_le().await?;
+        let height = reader.read_u64_le().await?;
 
         let model = CameraModel::from_id(model_id)
             .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Invalid camera model"))?;
@@ -225,7 +230,7 @@ fn read_cameras_binary<R: Read>(mut reader: R) -> io::Result<HashMap<i32, Camera>> {
         let num_params = model.num_params();
         let mut params = Vec::with_capacity(num_params);
         for _ in 0..num_params {
-            params.push(reader.read_f64::<LittleEndian>()?);
+            params.push(reader.read_f64_le().await?);
         }
 
         cameras.insert(
@@ -243,16 +248,16 @@ fn read_cameras_binary<R: Read>(mut reader: R) -> io::Result<HashMap<i32, Camera>> {
     Ok(cameras)
 }
 
-fn read_images_text<R: Read>(mut reader: R) -> io::Result<HashMap<i32, Image>> {
+async fn read_images_text<R: AsyncRead + Unpin>(mut reader: R) -> io::Result<HashMap<i32, Image>> {
     let mut images = HashMap::new();
-    let mut buf_reader = io::BufReader::new(reader);
+    let mut buf_reader = tokio::io::BufReader::new(reader);
     let mut line = String::new();
     let mut img_data = true;
 
     loop {
         line.clear();
-        if buf_reader.read_line(&mut line)? == 0 {
+        if buf_reader.read_line(&mut line).await? == 0 {
             break;
         }
 
@@ -272,7 +277,7 @@ fn read_images_text<R: Read>(mut reader: R) -> io::Result<HashMap<i32, Image>> {
         let name = elems[9].to_string();
 
         line.clear();
-        buf_reader.read_line(&mut line)?;
+        buf_reader.read_line(&mut line).await?;
         let elems: Vec<&str> = line.split_whitespace().collect();
         let mut xys = Vec::new();
         let mut point3d_ids = Vec::new();
@@ -299,44 +304,46 @@ fn read_images_text<R: Read>(mut reader: R) -> io::Result<HashMap<i32, Image>> {
     Ok(images)
 }
 
-fn read_images_binary<R: BufRead>(mut reader: R) -> io::Result<HashMap<i32, Image>> {
+async fn read_images_binary<R: AsyncBufRead + Unpin>(
+    mut reader: R,
+) -> io::Result<HashMap<i32, Image>> {
     let mut images = HashMap::new();
-    let num_images = reader.read_u64::<LittleEndian>()?;
+    let num_images = reader.read_u64_le().await?;
 
     for _ in 0..num_images {
-        let image_id = reader.read_i32::<LittleEndian>()?;
+        let image_id = reader.read_i32_le().await?;
 
         let [w, x, y, z] = [
-            reader.read_f64::<LittleEndian>()? as f32,
-            reader.read_f64::<LittleEndian>()? as f32,
-            reader.read_f64::<LittleEndian>()? as f32,
-            reader.read_f64::<LittleEndian>()? as f32,
+            reader.read_f64_le().await? as f32,
+            reader.read_f64_le().await? as f32,
+            reader.read_f64_le().await? as f32,
+            reader.read_f64_le().await? as f32,
         ];
         let quat = glam::quat(x, y, z, w);
 
         let tvec = glam::vec3(
-            reader.read_f64::<LittleEndian>()? as f32,
-            reader.read_f64::<LittleEndian>()? as f32,
-            reader.read_f64::<LittleEndian>()? as f32,
+            reader.read_f64_le().await? as f32,
+            reader.read_f64_le().await? as f32,
+            reader.read_f64_le().await? as f32,
        );
 
-        let camera_id = reader.read_i32::<LittleEndian>()?;
+        let camera_id = reader.read_i32_le().await?;
 
         let mut name_bytes = Vec::new();
-        reader.read_until(b'\0', &mut name_bytes)?;
+        reader.read_until(b'\0', &mut name_bytes).await?;
         let name = std::str::from_utf8(&name_bytes[..name_bytes.len() - 1])
             .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?
             .to_owned();
 
-        let num_points2d = reader.read_u64::<LittleEndian>()?;
+        let num_points2d = reader.read_u64_le().await?;
         let mut xys = Vec::with_capacity(num_points2d as usize);
         let mut point3d_ids = Vec::with_capacity(num_points2d as usize);
 
         for _ in 0..num_points2d {
             xys.push(glam::Vec2::new(
-                reader.read_f64::<LittleEndian>()? as f32,
-                reader.read_f64::<LittleEndian>()? as f32,
+                reader.read_f64_le().await? as f32,
+                reader.read_f64_le().await? as f32,
             ));
-            point3d_ids.push(reader.read_i64::<LittleEndian>()?);
+            point3d_ids.push(reader.read_i64_le().await?);
         }
 
         images.insert(
@@ -355,12 +362,14 @@ fn read_images_binary<R: BufRead>(mut reader: R) -> io::Result<HashMap<i32, Image>> {
     Ok(images)
 }
 
-fn read_points3d_text<R: Read>(mut reader: R) -> io::Result<HashMap<i64, Point3D>> {
+async fn read_points3d_text<R: AsyncRead + Unpin>(
+    mut reader: R,
+) -> io::Result<HashMap<i64, Point3D>> {
     let mut points3d = HashMap::new();
-    let mut buf_reader = io::BufReader::new(reader);
+    let mut buf_reader = tokio::io::BufReader::new(reader);
     let mut line = String::new();
 
-    while buf_reader.read_line(&mut line)? > 0 {
+    while buf_reader.read_line(&mut line).await? > 0 {
         if line.starts_with('#') {
             line.clear();
             continue;
@@ -413,27 +422,33 @@ fn read_points3d_text<R: Read>(mut reader: R) -> io::Result<HashMap<i64, Point3D>> {
     Ok(points3d)
 }
 
-fn read_points3d_binary<R: Read>(mut reader: R) -> io::Result<HashMap<i64, Point3D>> {
+async fn read_points3d_binary<R: AsyncRead + Unpin>(
+    mut reader: R,
+) -> io::Result<HashMap<i64, Point3D>> {
     let mut points3d = HashMap::new();
-    let num_points = reader.read_u64::<LittleEndian>()?;
+    let num_points = reader.read_u64_le().await?;
 
     for _ in 0..num_points {
-        let point3d_id = reader.read_i64::<LittleEndian>()?;
+        let point3d_id = reader.read_i64_le().await?;
 
         let xyz = glam::Vec3::new(
-            reader.read_f64::<LittleEndian>()? as f32,
-            reader.read_f64::<LittleEndian>()? as f32,
-            reader.read_f64::<LittleEndian>()? as f32,
+            reader.read_f64_le().await? as f32,
+            reader.read_f64_le().await? as f32,
+            reader.read_f64_le().await? as f32,
         );
 
-        let rgb = [reader.read_u8()?, reader.read_u8()?, reader.read_u8()?];
-        let error = reader.read_f64::<LittleEndian>()?;
+        let rgb = [
+            reader.read_u8().await?,
+            reader.read_u8().await?,
+            reader.read_u8().await?,
+        ];
+        let error = reader.read_f64_le().await?;
 
-        let track_length = reader.read_u64::<LittleEndian>()?;
+        let track_length = reader.read_u64_le().await?;
         let mut image_ids = Vec::with_capacity(track_length as usize);
         let mut point2d_idxs = Vec::with_capacity(track_length as usize);
 
         for _ in 0..track_length {
-            image_ids.push(reader.read_i32::<LittleEndian>()?);
-            point2d_idxs.push(reader.read_i32::<LittleEndian>()?);
+            image_ids.push(reader.read_i32_le().await?);
+            point2d_idxs.push(reader.read_i32_le().await?);
         }
 
         points3d.insert(
@@ -451,26 +466,35 @@ fn read_points3d_binary<R: Read>(mut reader: R) -> io::Result<HashMap<i64, Point3D>> {
     Ok(points3d)
 }
 
-pub fn read_cameras<R: Read>(mut reader: R, binary: bool) -> io::Result<HashMap<i32, Camera>> {
+pub async fn read_cameras<R: AsyncRead + Unpin>(
+    mut reader: R,
+    binary: bool,
+) -> io::Result<HashMap<i32, Camera>> {
     if binary {
-        read_cameras_binary(reader)
+        read_cameras_binary(reader).await
     } else {
-        read_cameras_text(reader)
+        read_cameras_text(reader).await
     }
 }
 
-pub fn read_images<R: BufRead>(reader: R, binary: bool) -> io::Result<HashMap<i32, Image>> {
+pub async fn read_images<R: AsyncBufRead + Unpin>(
+    reader: R,
+    binary: bool,
+) -> io::Result<HashMap<i32, Image>> {
     if binary {
-        read_images_binary(reader)
+        read_images_binary(reader).await
     } else {
-        read_images_text(reader)
+        read_images_text(reader).await
    }
 }
 
-pub fn read_points3d<R: Read>(reader: R, binary: bool) -> io::Result<HashMap<i64, Point3D>> {
+pub async fn read_points3d<R: AsyncRead + Unpin>(
+    reader: R,
+    binary: bool,
+) -> io::Result<HashMap<i64, Point3D>> {
     if binary {
-        read_points3d_binary(reader)
+        read_points3d_binary(reader).await
     } else {
-        read_points3d_text(reader)
+        read_points3d_text(reader).await
     }
 }

From 483b66c7e893ecaf80db821d3fe56734a0bfa306 Mon Sep 17 00:00:00 2001
From: Arthur Brussee
Date: Mon, 9 Dec 2024 16:01:30 +0000
Subject: [PATCH 2/3] Add basics for directory, cleanup, support in
 process_loop

---
 crates/brush-dataset/src/brush_vfs.rs         | 109 ++++--------
 crates/brush-dataset/src/formats/colmap.rs    |  57 ++++---
 crates/brush-dataset/src/formats/mod.rs       |  11 +-
 .../brush-dataset/src/formats/nerfstudio.rs   |  43 +++--
 crates/brush-viewer/src/train_loop.rs         |  14 +--
 crates/brush-viewer/src/viewer.rs             |  39 ++++---
 6 files changed, 129 insertions(+), 144 deletions(-)

diff --git a/crates/brush-dataset/src/brush_vfs.rs b/crates/brush-dataset/src/brush_vfs.rs
index 7bc832f5..df39b406 100644
--- a/crates/brush-dataset/src/brush_vfs.rs
+++ b/crates/brush-dataset/src/brush_vfs.rs
@@ -7,7 +7,7 @@
 use std::{
     collections::HashMap,
     io::{Cursor, Read},
-    path::{Path, PathBuf},
+    path::{Component, Path, PathBuf},
     sync::Arc,
 };
 
@@ -19,49 +19,28 @@ use zip::{
     ZipArchive,
 };
 
+type DynRead = Box<dyn AsyncRead + Send + Unpin>;
+
 #[derive(Clone)]
 pub struct ZipData {
     data: Arc<Vec<u8>>,
 }
 
-type ZipReader = Cursor<ZipData>;
-
 impl AsRef<[u8]> for ZipData {
     fn as_ref(&self) -> &[u8] {
         &self.data
     }
 }
 
-impl ZipData {
-    pub fn open_for_read(&self) -> ZipReader {
-        Cursor::new(self.clone())
-    }
-}
-
-impl From<Vec<u8>> for ZipData {
-    fn from(value: Vec<u8>) -> Self {
-        Self {
-            data: Arc::new(value),
-        }
-    }
-}
-
 pub(crate) fn normalized_path(path: &Path) -> PathBuf {
-    Path::new(path)
-        .components()
-        .skip_while(|c| matches!(c, std::path::Component::CurDir))
-        .collect::<PathBuf>()
-}
-
-#[derive(Clone)]
-pub struct ConsumableReader {
-    // Option allows us to take ownership when first used.
-    inner: Arc<Mutex<Option<Box<dyn AsyncRead + Send + Unpin>>>>,
+    path.components()
+        .filter(|c| !matches!(c, Component::CurDir | Component::ParentDir))
+        .collect()
 }
 
 #[derive(Clone, Default)]
 pub struct PathReader {
-    paths: HashMap<PathBuf, ConsumableReader>,
+    paths: HashMap<PathBuf, Arc<Mutex<Option<DynRead>>>>,
 }
 
 impl
PathReader { self.paths.keys() } - pub fn add(&mut self, path: PathBuf, reader: impl AsyncRead + Send + Unpin + 'static) { + pub fn add(&mut self, path: &Path, reader: impl AsyncRead + Send + Unpin + 'static) { self.paths.insert( - path, - ConsumableReader { - inner: Arc::new(Mutex::new(Some(Box::new(reader)))), - }, + path.to_path_buf(), + Arc::new(Mutex::new(Some(Box::new(reader)))), ); } async fn open(&mut self, path: &Path) -> anyhow::Result> { let entry = self.paths.remove(path).context("File not found")?; - let reader = entry.inner.lock().await.take(); + let reader = entry.lock().await.take(); reader.context("Missing reader") } } @@ -99,8 +76,10 @@ impl BrushVfs { let mut reader = reader; reader.read_to_end(&mut bytes).await?; - let zip_data = ZipData::from(bytes); - let archive = ZipArchive::new(zip_data.open_for_read())?; + let zip_data = ZipData { + data: Arc::new(bytes), + }; + let archive = ZipArchive::new(Cursor::new(zip_data))?; Ok(BrushVfs::Zip(archive)) } @@ -117,46 +96,17 @@ impl BrushVfs { Ok(BrushVfs::Directory(dir.to_path_buf(), paths)) } - pub(crate) fn file_names(&self) -> impl Iterator + '_ { - let iterator: Box> = match self { - BrushVfs::Zip(archive) => Box::new(archive.file_names()), - BrushVfs::Manual(map) => Box::new(map.paths().filter_map(|k| k.to_str())), - BrushVfs::Directory(_, paths) => Box::new(paths.iter().filter_map(|k| k.to_str())), + pub fn file_names(&self) -> impl Iterator + '_ { + let iterator: Box> = match self { + BrushVfs::Zip(archive) => Box::new(archive.file_names().map(Path::new)), + BrushVfs::Manual(map) => Box::new(map.paths().map(|p| p.as_path())), + BrushVfs::Directory(_, paths) => Box::new(paths.iter().map(|p| p.as_path())), }; // stupic macOS. - iterator.filter(|p| !p.contains("__MACOSX")) + iterator.filter(|p| !p.starts_with("__MACOSX")) } - pub(crate) fn find_with_extension( - &self, - extension: &str, - contains: &str, - ) -> anyhow::Result { - let names: Vec<_> = self - .file_names() - .filter(|name| name.ends_with(extension)) - .collect(); - - if names.len() == 1 { - return Ok(Path::new(names[0]).to_owned()); - } - - let names: Vec<_> = names - .iter() - .filter(|name| name.contains(contains)) - .collect(); - - if names.len() == 1 { - return Ok(Path::new(names[0]).to_owned()); - } - - anyhow::bail!("Failed to find file ending in {extension} maybe containing {contains}."); - } - - pub(crate) async fn open_reader_at_path( - &mut self, - path: &Path, - ) -> anyhow::Result> { + pub async fn open_path(&mut self, path: &Path) -> anyhow::Result { match self { BrushVfs::Zip(archive) => { let name = archive @@ -164,10 +114,8 @@ impl BrushVfs { .find(|name| path == Path::new(name)) .ok_or(ZipError::FileNotFound)?; let name = name.to_owned(); - let mut buffer = vec![]; archive.by_name(&name)?.read_to_end(&mut buffer)?; - Ok(Box::new(Cursor::new(buffer))) } BrushVfs::Manual(map) => map.open(path).await, @@ -177,17 +125,4 @@ impl BrushVfs { } } } - - pub(crate) fn find_base_path(&self, search_path: &str) -> Option { - for file in self.file_names() { - let path = normalized_path(Path::new(file)); - if path.ends_with(search_path) { - return path - .ancestors() - .nth(Path::new(search_path).components().count()) - .map(|x| x.to_owned()); - } - } - None - } } diff --git a/crates/brush-dataset/src/formats/colmap.rs b/crates/brush-dataset/src/formats/colmap.rs index fd15984b..7d0c35a7 100644 --- a/crates/brush-dataset/src/formats/colmap.rs +++ b/crates/brush-dataset/src/formats/colmap.rs @@ -1,8 +1,16 @@ -use std::{future::Future, sync::Arc}; +use 
std::{ + future::Future, + path::{Path, PathBuf}, + sync::Arc, +}; use super::{DataStream, LoadDatasetArgs}; -use crate::{brush_vfs::BrushVfs, splat_import::SplatMessage, stream_fut_parallel, Dataset}; -use anyhow::{Context, Result}; +use crate::{ + brush_vfs::{normalized_path, BrushVfs}, + splat_import::SplatMessage, + stream_fut_parallel, Dataset, +}; +use anyhow::Result; use async_fn_stream::try_fn_stream; use brush_render::{ camera::{self, Camera}, @@ -15,6 +23,19 @@ use glam::Vec3; use tokio::io::AsyncReadExt; use tokio_stream::StreamExt; +fn find_base_path(archive: &BrushVfs, search_path: &str) -> Option { + for file in archive.file_names() { + let path = normalized_path(Path::new(file)); + if path.ends_with(search_path) { + return path + .ancestors() + .nth(Path::new(search_path).components().count()) + .map(|x| x.to_owned()); + } + } + None +} + async fn read_views( archive: BrushVfs, load_args: &LoadDatasetArgs, @@ -22,14 +43,14 @@ async fn read_views( log::info!("Loading colmap dataset"); let mut archive = archive; - let (is_binary, base_path) = if let Some(path) = archive.find_base_path("sparse/0/cameras.bin") - { - (true, path) - } else if let Some(path) = archive.find_base_path("sparse/0/cameras.txt") { - (false, path) - } else { - anyhow::bail!("No COLMAP data found (either text or binary.)") - }; + let (is_binary, base_path) = + if let Some(path) = find_base_path(&archive, "sparse/0/cameras.bin") { + (true, path) + } else if let Some(path) = find_base_path(&archive, "sparse/0/cameras.txt") { + (false, path) + } else { + anyhow::bail!("No COLMAP data found (either text or binary.)") + }; let (cam_path, img_path) = if is_binary { ( @@ -44,12 +65,12 @@ async fn read_views( }; let cam_model_data = { - let mut cam_file = archive.open_reader_at_path(&cam_path).await?; + let mut cam_file = archive.open_path(&cam_path).await?; colmap_reader::read_cameras(&mut cam_file, is_binary).await? }; let img_infos = { - let img_file = archive.open_reader_at_path(&img_path).await?; + let img_file = archive.open_path(&img_path).await?; let mut buf_reader = tokio::io::BufReader::new(img_file); colmap_reader::read_images(&mut buf_reader, is_binary).await? }; @@ -86,7 +107,7 @@ async fn read_views( let mut img_bytes = vec![]; archive - .open_reader_at_path(&img_path) + .open_path(&img_path) .await? .read_to_end(&mut img_bytes) .await?; @@ -105,7 +126,7 @@ async fn read_views( let camera = Camera::new(translation, quat, fovx, fovy, center_uv); let view = SceneView { - name: img_path.to_str().context("Invalid file name")?.to_owned(), + name: img_path.to_string_lossy().to_string(), camera, image: Arc::new(img), }; @@ -156,9 +177,9 @@ pub(crate) async fn load_dataset( let init_stream = try_fn_stream(|emitter| async move { let (is_binary, base_path) = - if let Some(path) = archive.find_base_path("sparse/0/cameras.bin") { + if let Some(path) = find_base_path(&archive, "sparse/0/cameras.bin") { (true, path) - } else if let Some(path) = archive.find_base_path("sparse/0/cameras.txt") { + } else if let Some(path) = find_base_path(&archive, "sparse/0/cameras.txt") { (false, path) } else { anyhow::bail!("No COLMAP data found (either text or binary.") @@ -172,7 +193,7 @@ pub(crate) async fn load_dataset( // Extract COLMAP sfm points. 
let points_data = { - let mut points_file = archive.open_reader_at_path(&points_path).await?; + let mut points_file = archive.open_path(&points_path).await?; colmap_reader::read_points3d(&mut points_file, is_binary).await }; diff --git a/crates/brush-dataset/src/formats/mod.rs b/crates/brush-dataset/src/formats/mod.rs index d4de6ff2..182f59a7 100644 --- a/crates/brush-dataset/src/formats/mod.rs +++ b/crates/brush-dataset/src/formats/mod.rs @@ -5,7 +5,7 @@ use crate::{ }; use anyhow::Result; use brush_render::Backend; -use std::pin::Pin; +use std::{path::Path, pin::Pin}; use tokio_stream::Stream; pub mod colmap; @@ -34,13 +34,10 @@ pub async fn load_dataset( }; // If there's an init.ply definitey override the init stream with that. - let init_path = vfs.find_with_extension(".ply", "init"); - - let init_stream = if let Ok(path) = init_path { - let ply_data = vfs.open_reader_at_path(&path).await?; - log::info!("Using {path:?} as initial point cloud."); + let init_stream = if let Ok(reader) = vfs.open_path(Path::new("init.ply")).await { + log::info!("Using init.ply as initial point cloud."); Box::pin(load_splat_from_ply( - ply_data, + reader, load_args.subsample_points, device.clone(), )) diff --git a/crates/brush-dataset/src/formats/nerfstudio.rs b/crates/brush-dataset/src/formats/nerfstudio.rs index 264d538d..02c0a4b0 100644 --- a/crates/brush-dataset/src/formats/nerfstudio.rs +++ b/crates/brush-dataset/src/formats/nerfstudio.rs @@ -129,7 +129,7 @@ fn read_transforms_file( let mut img_buffer = vec![]; archive - .open_reader_at_path(&path) + .open_path(&path) .await? .read_to_end(&mut img_buffer) .await?; @@ -184,10 +184,27 @@ pub async fn read_dataset( ) -> Result<(DataStream>, DataStream)> { log::info!("Loading nerfstudio dataset"); - let transforms_path = vfs.find_with_extension(".json", "_train")?; + let json_files: Vec<_> = vfs + .file_names() + .filter(|&n| n.extension().is_some_and(|p| p == "json")) + .map(|x| x.to_path_buf()) + .collect(); + + let transforms_path = if json_files.len() == 1 { + json_files.first().cloned().unwrap() + } else { + let train = json_files.iter().find(|x| { + x.file_name() + .is_some_and(|p| p.to_string_lossy().contains("_train")) + }); + let Some(train) = train else { + anyhow::bail!("No json file found."); + }; + train.clone() + }; let mut buf = String::new(); - vfs.open_reader_at_path(&transforms_path) + vfs.open_path(&transforms_path) .await? .read_to_string(&mut buf) .await?; @@ -209,25 +226,33 @@ pub async fn read_dataset( let load_args_clone = load_args.clone(); - let transforms_path_clone = transforms_path.clone(); let mut data_clone = vfs.clone(); let dataset_stream = try_fn_stream(|emitter| async move { let mut train_views = vec![]; let mut eval_views = vec![]; - let eval_trans_path = data_clone.find_with_extension(".json", "_val")?; + let eval_trans_path = json_files.iter().find(|x| { + x.file_name() + .is_some_and(|p| p.to_string_lossy().contains("_val")) + }); // If a seperate eval file is specified, read it. - let val_stream = if eval_trans_path != transforms_path_clone { + let val_stream = if let Some(eval_trans_path) = eval_trans_path { let mut json_str = String::new(); data_clone - .open_reader_at_path(&eval_trans_path) + .open_path(eval_trans_path) .await? 
.read_to_string(&mut json_str) .await?; let val_scene = serde_json::from_str(&json_str)?; - read_transforms_file(val_scene, eval_trans_path, data_clone, &load_args_clone).ok() + read_transforms_file( + val_scene, + eval_trans_path.clone(), + data_clone, + &load_args_clone, + ) + .ok() } else { None }; @@ -279,7 +304,7 @@ pub async fn read_dataset( let splat_stream = try_fn_stream(|emitter| async move { if let Some(init) = train_scene.ply_file_path { let init_path = transforms_path.parent().unwrap().join(init); - let ply_data = vfs.open_reader_at_path(&init_path).await; + let ply_data = vfs.open_path(&init_path).await; if let Ok(ply_data) = ply_data { let splat_stream = diff --git a/crates/brush-viewer/src/train_loop.rs b/crates/brush-viewer/src/train_loop.rs index 4260b90a..18302f6a 100644 --- a/crates/brush-viewer/src/train_loop.rs +++ b/crates/brush-viewer/src/train_loop.rs @@ -9,10 +9,7 @@ use burn::module::AutodiffModule; use burn_jit::cubecl::Runtime; use burn_wgpu::{Wgpu, WgpuDevice, WgpuRuntime}; use rand::SeedableRng; -use tokio::{ - io::AsyncRead, - sync::mpsc::{error::TryRecvError, Receiver}, -}; +use tokio::sync::mpsc::{error::TryRecvError, Receiver}; use tokio_stream::{Stream, StreamExt}; use web_time::Instant; @@ -26,8 +23,8 @@ pub enum TrainMessage { Eval { view_count: Option }, } -pub(crate) fn train_loop( - data_reader: T, +pub(crate) fn train_loop( + vfs: BrushVfs, device: WgpuDevice, mut receiver: Receiver, load_data_args: LoadDatasetArgs, @@ -35,9 +32,6 @@ pub(crate) fn train_loop( config: TrainConfig, ) -> impl Stream> { try_fn_stream(|emitter| async move { - // TODO: async zip ideally. - let zip_data = BrushVfs::from_zip_reader(data_reader).await?; - let batch_size = 1; // Maybe good if the seed would be configurable. @@ -50,7 +44,7 @@ pub(crate) fn train_loop( let mut dataset = Dataset::empty(); let (mut splat_stream, mut data_stream) = - brush_dataset::load_dataset(zip_data.clone(), &load_data_args, &device).await?; + brush_dataset::load_dataset(vfs.clone(), &load_data_args, &device).await?; // Read initial splats if any. while let Some(message) = splat_stream.next().await { diff --git a/crates/brush-viewer/src/viewer.rs b/crates/brush-viewer/src/viewer.rs index ff179c2f..4ed3127b 100644 --- a/crates/brush-viewer/src/viewer.rs +++ b/crates/brush-viewer/src/viewer.rs @@ -1,9 +1,11 @@ use core::f32; use std::ops::Range; +use std::path::Path; use std::{pin::Pin, sync::Arc}; use async_fn_stream::try_fn_stream; +use brush_dataset::brush_vfs::{BrushVfs, PathReader}; use brush_dataset::{self, splat_import, Dataset, LoadDatasetArgs, LoadInitArgs}; use brush_render::camera::Camera; use brush_render::gaussian_splats::Splats; @@ -147,19 +149,36 @@ fn process_loop( let mut data = BufReader::new(data); let mut peek = [0; 128]; data.read_exact(&mut peek).await?; - let data = std::io::Cursor::new(peek).chain(data); + let reader = std::io::Cursor::new(peek).chain(data); - log::info!("{:?}", String::from_utf8(peek.to_vec())); + let mut vfs = if peek.starts_with("ply".as_bytes()) { + let mut path_reader = PathReader::default(); + path_reader.add(Path::new("input.ply"), reader); + BrushVfs::from_paths(path_reader) + } else if peek.starts_with("PK".as_bytes()) { + BrushVfs::from_zip_reader(reader).await? + } else if peek.starts_with("".as_bytes()) { + anyhow::bail!("Failed to download data (are you trying to download from Google Drive? 
You might have to use the proxy.") + } else { + anyhow::bail!("only zip and ply files are supported."); + }; - if peek.starts_with("ply".as_bytes()) { - log::info!("Attempting to load data as .ply data"); + let names: Vec<_> = vfs.file_names().map(|x| x.to_path_buf()).collect(); + log::info!("Mounted VFS with {} files", names.len()); + + if names.len() == 1 && names[0].extension().is_some_and(|p| p == "ply") { + log::info!("Loading single ply file"); let _ = emitter .emit(ProcessMessage::StartLoading { training: false }) .await; let sub_sample = None; // Subsampling a trained ply doesn't really make sense. - let splat_stream = splat_import::load_splat_from_ply(data, sub_sample, device.clone()); + let splat_stream = splat_import::load_splat_from_ply( + vfs.open_path(&names[0]).await?, + sub_sample, + device.clone(), + ); let mut splat_stream = std::pin::pin!(splat_stream); @@ -178,15 +197,13 @@ fn process_loop( emitter .emit(ProcessMessage::DoneLoading { training: true }) .await; - } else if peek.starts_with("PK".as_bytes()) { - log::info!("Attempting to load data as .zip data"); - + } else { let _ = emitter .emit(ProcessMessage::StartLoading { training: true }) .await; let stream = train_loop::train_loop( - data, + vfs, device, train_receiver, load_data_args, @@ -197,10 +214,6 @@ fn process_loop( while let Some(message) = stream.next().await { emitter.emit(message?).await; } - } else if peek.starts_with("".as_bytes()) { - anyhow::bail!("Failed to download data (are you trying to download from Google Drive? You might have to use the proxy.") - } else { - anyhow::bail!("only zip and ply files are supported."); } Ok(()) From 42148736c6cf18ed7d4920045be00f2815223d4b Mon Sep 17 00:00:00 2001 From: Arthur Brussee Date: Mon, 9 Dec 2024 17:14:16 +0000 Subject: [PATCH 3/3] Fixes for wasm --- crates/brush-dataset/src/brush_vfs.rs | 8 ++- crates/brush-viewer/src/data_source.rs | 0 crates/brush-viewer/src/viewer.rs | 67 ++++++++++++++++++-------- 3 files changed, 52 insertions(+), 23 deletions(-) create mode 100644 crates/brush-viewer/src/data_source.rs diff --git a/crates/brush-dataset/src/brush_vfs.rs b/crates/brush-dataset/src/brush_vfs.rs index df39b406..628b9c0f 100644 --- a/crates/brush-dataset/src/brush_vfs.rs +++ b/crates/brush-dataset/src/brush_vfs.rs @@ -12,8 +12,8 @@ use std::{ }; use anyhow::Context; -use tokio::io::AsyncReadExt; -use tokio::{io::AsyncRead, sync::Mutex}; +use tokio::{io::AsyncRead, io::AsyncReadExt, sync::Mutex}; + use zip::{ result::{ZipError, ZipResult}, ZipArchive, @@ -66,6 +66,7 @@ impl PathReader { pub enum BrushVfs { Zip(ZipArchive>), Manual(PathReader), + #[cfg(not(target_family = "wasm"))] Directory(PathBuf, Vec), } @@ -87,6 +88,7 @@ impl BrushVfs { BrushVfs::Manual(paths) } + #[cfg(not(target_family = "wasm"))] pub async fn from_directory(dir: &Path) -> anyhow::Result { let mut read = ::tokio::fs::read_dir(dir).await?; let mut paths = vec![]; @@ -100,6 +102,7 @@ impl BrushVfs { let iterator: Box> = match self { BrushVfs::Zip(archive) => Box::new(archive.file_names().map(Path::new)), BrushVfs::Manual(map) => Box::new(map.paths().map(|p| p.as_path())), + #[cfg(not(target_family = "wasm"))] BrushVfs::Directory(_, paths) => Box::new(paths.iter().map(|p| p.as_path())), }; // stupic macOS. 
@@ -119,6 +122,7 @@ impl BrushVfs { Ok(Box::new(Cursor::new(buffer))) } BrushVfs::Manual(map) => map.open(path).await, + #[cfg(not(target_family = "wasm"))] BrushVfs::Directory(path_buf, _) => { let file = tokio::fs::File::open(path_buf).await?; Ok(Box::new(file)) diff --git a/crates/brush-viewer/src/data_source.rs b/crates/brush-viewer/src/data_source.rs new file mode 100644 index 00000000..e69de29b diff --git a/crates/brush-viewer/src/viewer.rs b/crates/brush-viewer/src/viewer.rs index 4ed3127b..faf25c05 100644 --- a/crates/brush-viewer/src/viewer.rs +++ b/crates/brush-viewer/src/viewer.rs @@ -16,6 +16,9 @@ use burn_wgpu::{Wgpu, WgpuDevice}; use eframe::egui; use egui_tiles::{Container, Tile, TileId, Tiles}; use glam::{Affine3A, Quat, Vec3, Vec3A}; +use tokio_stream::wrappers::ReceiverStream; +use tokio_util::bytes::Bytes; +use tokio_util::io::StreamReader; use tokio_with_wasm::alias as tokio; use ::tokio::io::AsyncReadExt; @@ -145,7 +148,7 @@ fn process_loop( // Small hack to peek some bytes: Read them // and add them at the start again. - let data = source.read().await?; + let data = source.into_reader()?; let mut data = BufReader::new(data); let mut peek = [0; 128]; data.read_exact(&mut peek).await?; @@ -228,31 +231,53 @@ pub enum DataSource { Url(String), } -#[cfg(target_family = "wasm")] -type DataRead = Pin>; - -#[cfg(not(target_family = "wasm"))] type DataRead = Pin>; impl DataSource { - async fn read(&self) -> anyhow::Result { - match self { - DataSource::PickFile => { - let picked = rrfd::pick_file().await?; - let data = picked.read().await; - Ok(Box::pin(std::io::Cursor::new(data))) - } - DataSource::Url(url) => { - let mut url = url.to_owned(); - if !url.starts_with("http://") && !url.starts_with("https://") { - url = format!("https://{}", url); + fn into_reader(self) -> anyhow::Result { + let (send, rec) = ::tokio::sync::mpsc::channel(16); + + // Spawn the data reading. + tokio::spawn(async move { + let stream = try_fn_stream(|emitter| async move { + match self { + DataSource::PickFile => { + let picked = rrfd::pick_file() + .await + .map_err(|_| std::io::ErrorKind::NotFound)?; + let data = picked.read().await; + emitter.emit(Bytes::from_owner(data)).await; + } + DataSource::Url(url) => { + let mut url = url.to_owned(); + if !url.starts_with("http://") && !url.starts_with("https://") { + url = format!("https://{}", url); + } + let mut response = reqwest::get(url) + .await + .map_err(|_| std::io::ErrorKind::InvalidInput)? + .bytes_stream(); + + while let Some(bytes) = response.next().await { + let bytes = bytes.map_err(|_| std::io::ErrorKind::ConnectionAborted)?; + emitter.emit(bytes).await; + } + } + }; + anyhow::Result::<(), std::io::Error>::Ok(()) + }); + + let mut stream = std::pin::pin!(stream); + + while let Some(data) = stream.next().await { + if send.send(data).await.is_err() { + break; } - let response = reqwest::get(url).await?.bytes_stream(); - let mapped = response - .map(|e| e.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))); - Ok(Box::pin(tokio_util::io::StreamReader::new(mapped))) } - } + }); + + let reader = StreamReader::new(ReceiverStream::new(rec)); + Ok(reader) } }
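For reviewers, a minimal sketch of how the VFS introduced by this series fits together, using the post-series API (`from_zip_reader`, `from_paths`, `file_names`, `open_path`; `from_directory` is native-only). The file names (`scene.zip`, `points.ply`, `sparse/0/cameras.bin`) are hypothetical and only for illustration:

```rust
use std::path::Path;

use brush_dataset::brush_vfs::{BrushVfs, PathReader};
use tokio::io::AsyncReadExt;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Mount a zip archive. The reader is drained into memory up front, so the
    // backing bytes can be shared by cheap clones of the archive.
    let file = tokio::fs::File::open("scene.zip").await?;
    let mut vfs = BrushVfs::from_zip_reader(file).await?;

    // Alternatively, hand the VFS individual readers. Note that entries are
    // take-once: PathReader::open removes the reader stored behind the path.
    let mut paths = PathReader::default();
    paths.add(Path::new("points.ply"), tokio::fs::File::open("points.ply").await?);
    let _manual = BrushVfs::from_paths(paths);

    // Enumerate entries; __MACOSX resource forks are already filtered out.
    for name in vfs.file_names() {
        println!("{}", name.display());
    }

    // Read one entry through the unified async interface.
    let mut bytes = vec![];
    vfs.open_path(Path::new("sparse/0/cameras.bin"))
        .await?
        .read_to_end(&mut bytes)
        .await?;
    println!("cameras.bin: {} bytes", bytes.len());
    Ok(())
}
```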
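One porting detail in the colmap-reader changes deserves a callout: COLMAP's binary files are little-endian throughout, so `byteorder`'s `read_X::<LittleEndian>()` calls must map to tokio's `read_X_le()` methods; the bare `read_i32()`/`read_i64()` on `AsyncReadExt` read big-endian. A small sketch of the correspondence (the helper name is hypothetical, not part of the patch):

```rust
use tokio::io::{AsyncRead, AsyncReadExt};

// Hypothetical helper mirroring the track parsing above. Every multi-byte
// field goes through a _le reader; tokio's bare read_i32()/read_i64() are
// big-endian and would silently scramble COLMAP's little-endian data.
async fn read_track_entry<R: AsyncRead + Unpin>(r: &mut R) -> std::io::Result<(i32, i32)> {
    let image_id = r.read_i32_le().await?;
    let point2d_idx = r.read_i32_le().await?;
    Ok((image_id, point2d_idx))
}
```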