diff --git a/.bazelversion b/.bazelversion index 643916c03..815da58b7 100644 --- a/.bazelversion +++ b/.bazelversion @@ -1 +1 @@ -7.3.1 +7.4.1 diff --git a/BUILD.bazel b/BUILD.bazel index 95acfb027..b68decb79 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -30,6 +30,7 @@ rust_binary( "@crates//:hyper-1.4.1", "@crates//:hyper-util", "@crates//:mimalloc", + "@crates//:mlua", "@crates//:opentelemetry", "@crates//:opentelemetry-prometheus", "@crates//:opentelemetry_sdk", diff --git a/Cargo.lock b/Cargo.lock index c9609016a..d3705e968 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -647,6 +647,16 @@ dependencies = [ "generic-array", ] +[[package]] +name = "bstr" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40723b8fb387abc38f4f4a37c09073622e41dd12327033091ef8950659e6dc0c" +dependencies = [ + "memchr", + "serde", +] + [[package]] name = "bumpalo" version = "3.16.0" @@ -1714,6 +1724,48 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "mlua" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f6ddbd668297c46be4bdea6c599dcc1f001a129586272d53170b7ac0a62961e" +dependencies = [ + "bstr", + "either", + "futures-util", + "mlua-sys", + "mlua_derive", + "num-traits", + "parking_lot", + "rustc-hash", +] + +[[package]] +name = "mlua-sys" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9eebac25c35a13285456c88ee2fde93d9aee8bcfdaf03f9d6d12be3391351ec" +dependencies = [ + "cc", + "cfg-if", + "pkg-config", +] + +[[package]] +name = "mlua_derive" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2cfc5faa2e0d044b3f5f0879be2920e0a711c97744c42cf1c295cb183668933e" +dependencies = [ + "itertools", + "once_cell", + "proc-macro-error", + "proc-macro2", + "quote", + "regex", + "syn 2.0.79", +] + [[package]] name = "mock_instant" version = "0.5.1" @@ -1737,6 +1789,7 @@ dependencies = [ "hyper 1.4.1", "hyper-util", "mimalloc", + "mlua", "nativelink-config", "nativelink-error", "nativelink-metric", @@ -1781,6 +1834,7 @@ version = "0.5.3" dependencies = [ "fred", "hex", + "mlua", "nativelink-metric", "nativelink-proto", "prost", @@ -1942,6 +1996,7 @@ dependencies = [ "hyper-rustls", "lz4_flex", "memory-stats", + "mlua", "mock_instant", "nativelink-config", "nativelink-error", @@ -1984,6 +2039,7 @@ dependencies = [ "hyper 1.4.1", "hyper-util", "lru", + "mlua", "mock_instant", "nativelink-config", "nativelink-error", @@ -2295,6 +2351,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "pkg-config" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" + [[package]] name = "powerfmt" version = "0.2.0" @@ -2595,6 +2657,12 @@ version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" +[[package]] +name = "rustc-hash" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "583034fd73374156e66797ed8e5b0d5690409c9226b22d87cb7f19821c05d152" + [[package]] name = "rustc_version" version = "0.4.1" diff --git a/Cargo.toml b/Cargo.toml index 33fe54582..9fb5f4047 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,6 +53,7 @@ futures = { version = "0.3.30", default-features = false } hyper = "1.4.1" hyper-util = "0.1.9" mimalloc = "0.1.43" +mlua = { version = "0.10.0", features = ["lua54", "async", "macros", "send"] } parking_lot = "0.12.3" rustls-pemfile = { version = "2.2.0", default-features = false } scopeguard = { version = "1.2.0", default-features = false } diff --git a/MODULE.bazel b/MODULE.bazel index c6eb65e76..73ae81d10 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -77,8 +77,20 @@ crate.from_cargo( "x86_64-unknown-linux-gnu", ], ) +crate.annotation( + build_script_env = { + "LUA_LIB": "external/lua~", + "LUA_LIB_NAME": "liblua", + }, + crate = "mlua-sys", + deps = ["@lua//:liblua"], +) use_repo(crate, "crates") +bazel_dep(name = "lua", version = "5.4.6") + +inject_repo(crate, "lua") + rust_analyzer = use_extension( "@rules_rust//tools/rust_analyzer:extension.bzl", "rust_analyzer_dependencies", diff --git a/nativelink-error/BUILD.bazel b/nativelink-error/BUILD.bazel index 08d8783de..4605e491e 100644 --- a/nativelink-error/BUILD.bazel +++ b/nativelink-error/BUILD.bazel @@ -16,6 +16,7 @@ rust_library( "//nativelink-proto", "@crates//:fred", "@crates//:hex", + "@crates//:mlua", "@crates//:prost", "@crates//:prost-types", "@crates//:serde", diff --git a/nativelink-error/Cargo.toml b/nativelink-error/Cargo.toml index bcbe76e9c..63f7fafed 100644 --- a/nativelink-error/Cargo.toml +++ b/nativelink-error/Cargo.toml @@ -14,6 +14,7 @@ fred = { version = "9.2.1", default-features = false, features = [ "enable-rustls-ring", ] } hex = { version = "0.4.3", default-features = false } +mlua = { version = "0.10.0", default-features = false } prost = { version = "0.13.3", default-features = false } prost-types = { version = "0.13.3", default-features = false } serde = { version = "1.0.210", default-features = false } diff --git a/nativelink-error/src/lib.rs b/nativelink-error/src/lib.rs index b87782c32..955c5e227 100644 --- a/nativelink-error/src/lib.rs +++ b/nativelink-error/src/lib.rs @@ -260,6 +260,23 @@ impl From for tonic::Status { } } +impl From for Error { + fn from(val: mlua::Error) -> Self { + match val { + mlua::Error::CallbackError { traceback, cause } => { + Self::new(Code::Internal, traceback).merge(std::sync::Arc::unwrap_or_clone(cause)) + } + _ => Self::new(Code::Internal, val.to_string()), + } + } +} + +impl From for mlua::Error { + fn from(val: Error) -> Self { + Self::external(val) + } +} + pub trait ResultExt { fn err_tip_with_code(self, tip_fn: F) -> Result where diff --git a/nativelink-store/BUILD.bazel b/nativelink-store/BUILD.bazel index 2c2db113c..58121ef27 100644 --- a/nativelink-store/BUILD.bazel +++ b/nativelink-store/BUILD.bazel @@ -60,6 +60,7 @@ rust_library( "@crates//:hyper-0.14.30", "@crates//:hyper-rustls", "@crates//:lz4_flex", + "@crates//:mlua", "@crates//:parking_lot", "@crates//:patricia_tree", "@crates//:prost", diff --git a/nativelink-store/Cargo.toml b/nativelink-store/Cargo.toml index fa3094417..653c8770f 100644 --- a/nativelink-store/Cargo.toml +++ b/nativelink-store/Cargo.toml @@ -47,6 +47,7 @@ hyper-rustls = { version = "0.24.2", default-features = false, features = [ "webpki-roots", ] } lz4_flex = { version = "0.11.3", default-features = false } +mlua = { version = "0.10.0", features = ["async"] } parking_lot = "0.12.3" prost = { version = "0.13.3", default-features = false } rand = { version = "0.8.5", default-features = false } diff --git a/nativelink-store/src/ref_store.rs b/nativelink-store/src/ref_store.rs index 6014fbbb1..80447ce89 100644 --- a/nativelink-store/src/ref_store.rs +++ b/nativelink-store/src/ref_store.rs @@ -149,3 +149,26 @@ impl StoreDriver for RefStore { } default_health_status_indicator!(RefStore); + +impl mlua::UserData for RefStore { + fn add_methods>(methods: &mut M) { + use futures::{stream, StreamExt}; + use nativelink_util::common::DigestInfo; + methods.add_async_method( + "get_many", + |_lua, this, (digests, window, func): (Vec, usize, mlua::Function)| async move { + let store = this.get_store()?; + let mut data_stream = stream::iter(digests) + .map(move |digest| async move { + let key = StoreKey::from(digest); + store.get_part_unchunked(key, 0, None).await + }) + .buffered(window); + while let Some(result) = data_stream.next().await { + func.call::<()>(result.map(|b| b.to_vec()).ok())?; + } + Ok(()) + }, + ); + } +} diff --git a/nativelink-util/BUILD.bazel b/nativelink-util/BUILD.bazel index ac17063f1..d6831093e 100644 --- a/nativelink-util/BUILD.bazel +++ b/nativelink-util/BUILD.bazel @@ -54,6 +54,7 @@ rust_library( "@crates//:hyper-1.4.1", "@crates//:hyper-util", "@crates//:lru", + "@crates//:mlua", "@crates//:mock_instant", "@crates//:parking_lot", "@crates//:pin-project", diff --git a/nativelink-util/Cargo.toml b/nativelink-util/Cargo.toml index fb635af60..2a42f6ce4 100644 --- a/nativelink-util/Cargo.toml +++ b/nativelink-util/Cargo.toml @@ -27,6 +27,7 @@ hex = { version = "0.4.3", default-features = false, features = ["std"] } hyper = "1.4.1" hyper-util = "0.1.9" lru = { version = "0.12.4", default-features = false } +mlua = { version = "0.10.0", features = ["macros"] } parking_lot = "0.12.3" pin-project-lite = "0.2.14" prost = { version = "0.13.3", default-features = false } diff --git a/nativelink-util/src/common.rs b/nativelink-util/src/common.rs index 0eed60341..b8e1bf472 100644 --- a/nativelink-util/src/common.rs +++ b/nativelink-util/src/common.rs @@ -33,7 +33,7 @@ use tracing::{event, Level}; pub use crate::fs; -#[derive(Default, Clone, Copy, Eq, PartialEq, Hash)] +#[derive(Default, Clone, Copy, Eq, PartialEq, Hash, mlua::FromLua)] #[repr(C)] pub struct DigestInfo { /// Raw hash in packed form. @@ -43,6 +43,8 @@ pub struct DigestInfo { size_bytes: u64, } +impl mlua::UserData for DigestInfo {} + impl MetricsComponent for DigestInfo { fn publish( &self, diff --git a/src/bin/nativelink.rs b/src/bin/nativelink.rs index 7e2276a63..b26a9b89e 100644 --- a/src/bin/nativelink.rs +++ b/src/bin/nativelink.rs @@ -110,6 +110,9 @@ struct Args { /// Config file to use. #[clap(value_parser)] config_file: String, + /// Lua script to run. + #[clap(value_parser)] + lua_file: Option, } /// The root metrics collector struct. All metrics will be @@ -953,6 +956,28 @@ async fn inner_main( root_metrics.write().workers = worker_metrics; } + let lua = mlua::Lua::new_with(mlua::StdLib::ALL_SAFE, mlua::LuaOptions::new())?; + if let Some(script) = futures::executor::block_on(get_lua())? { + use futures::FutureExt; + let store_manager = store_manager.clone(); + let get_store = lua.create_function(move |_, name: String| { + Ok(Arc::into_inner(nativelink_store::ref_store::RefStore::new( + &nativelink_config::stores::RefSpec { name: name.clone() }, + Arc::downgrade(&store_manager), + )) + .unwrap()) + })?; + let digest = lua.create_function(|_, (hash, length): (String, usize)| { + Ok(nativelink_util::common::DigestInfo::try_new(&hash, length)?) + })?; + let globals = lua.globals(); + globals.set("get_store", get_store)?; + globals.set("digest", digest)?; + let fut = lua.load(script).call_async::<()>(()).map_err(Error::from); + let spawn_fut = spawn!("lua", fut.map(|_| std::process::exit(0))); + root_futures.push(Box::pin(spawn_fut.map_ok_or_else(|e| Err(e.into()), |v| v))); + } + if let Err(e) = try_join_all(root_futures).await { panic!("{e:?}"); }; @@ -969,6 +994,15 @@ async fn get_config() -> Result> { Ok(serde_json5::from_str(&json_contents)?) } +async fn get_lua() -> Result>, Box> { + let args = Args::parse(); + Ok(if let Some(script) = args.lua_file { + Some(std::fs::read(&script).err_tip(|| format!("Could not open Lua script {script}"))?) + } else { + None + }) +} + fn main() -> Result<(), Box> { init_tracing()?;