Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Lua embedding #1468

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .bazelversion
Original file line number Diff line number Diff line change
@@ -1 +1 @@
7.3.1
7.4.1
1 change: 1 addition & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ rust_binary(
"@crates//:hyper-1.4.1",
"@crates//:hyper-util",
"@crates//:mimalloc",
"@crates//:mlua",
"@crates//:opentelemetry",
"@crates//:opentelemetry-prometheus",
"@crates//:opentelemetry_sdk",
Expand Down
68 changes: 68 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ futures = { version = "0.3.30", default-features = false }
hyper = "1.4.1"
hyper-util = "0.1.9"
mimalloc = "0.1.43"
mlua = { version = "0.10.0", features = ["lua54", "async", "macros", "send"] }
parking_lot = "0.12.3"
rustls-pemfile = { version = "2.2.0", default-features = false }
scopeguard = { version = "1.2.0", default-features = false }
Expand Down
12 changes: 12 additions & 0 deletions MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,20 @@ crate.from_cargo(
"x86_64-unknown-linux-gnu",
],
)
crate.annotation(
build_script_env = {
"LUA_LIB": "external/lua~",
"LUA_LIB_NAME": "liblua",
},
crate = "mlua-sys",
deps = ["@lua//:liblua"],
)
use_repo(crate, "crates")

bazel_dep(name = "lua", version = "5.4.6")

inject_repo(crate, "lua")

rust_analyzer = use_extension(
"@rules_rust//tools/rust_analyzer:extension.bzl",
"rust_analyzer_dependencies",
Expand Down
1 change: 1 addition & 0 deletions nativelink-error/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ rust_library(
"//nativelink-proto",
"@crates//:fred",
"@crates//:hex",
"@crates//:mlua",
"@crates//:prost",
"@crates//:prost-types",
"@crates//:serde",
Expand Down
1 change: 1 addition & 0 deletions nativelink-error/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ fred = { version = "9.2.1", default-features = false, features = [
"enable-rustls-ring",
] }
hex = { version = "0.4.3", default-features = false }
mlua = { version = "0.10.0", default-features = false }
prost = { version = "0.13.3", default-features = false }
prost-types = { version = "0.13.3", default-features = false }
serde = { version = "1.0.210", default-features = false }
Expand Down
17 changes: 17 additions & 0 deletions nativelink-error/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,23 @@ impl From<Error> for tonic::Status {
}
}

impl From<mlua::Error> for Error {
fn from(val: mlua::Error) -> Self {
match val {
mlua::Error::CallbackError { traceback, cause } => {
Self::new(Code::Internal, traceback).merge(std::sync::Arc::unwrap_or_clone(cause))
}
_ => Self::new(Code::Internal, val.to_string()),
}
}
}

impl From<Error> for mlua::Error {
fn from(val: Error) -> Self {
Self::external(val)
}
}

pub trait ResultExt<T> {
fn err_tip_with_code<F, S>(self, tip_fn: F) -> Result<T, Error>
where
Expand Down
1 change: 1 addition & 0 deletions nativelink-store/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ rust_library(
"@crates//:hyper-0.14.30",
"@crates//:hyper-rustls",
"@crates//:lz4_flex",
"@crates//:mlua",
"@crates//:parking_lot",
"@crates//:patricia_tree",
"@crates//:prost",
Expand Down
1 change: 1 addition & 0 deletions nativelink-store/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ hyper-rustls = { version = "0.24.2", default-features = false, features = [
"webpki-roots",
] }
lz4_flex = { version = "0.11.3", default-features = false }
mlua = { version = "0.10.0", features = ["async"] }
parking_lot = "0.12.3"
prost = { version = "0.13.3", default-features = false }
rand = { version = "0.8.5", default-features = false }
Expand Down
23 changes: 23 additions & 0 deletions nativelink-store/src/ref_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,26 @@ impl StoreDriver for RefStore {
}

default_health_status_indicator!(RefStore);

impl mlua::UserData for RefStore {
fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
use futures::{stream, StreamExt};
use nativelink_util::common::DigestInfo;
methods.add_async_method(
"get_many",
|_lua, this, (digests, window, func): (Vec<DigestInfo>, usize, mlua::Function)| async move {
let store = this.get_store()?;
let mut data_stream = stream::iter(digests)
.map(move |digest| async move {
let key = StoreKey::from(digest);
store.get_part_unchunked(key, 0, None).await
})
.buffered(window);
while let Some(result) = data_stream.next().await {
func.call::<()>(result.map(|b| b.to_vec()).ok())?;
}
Ok(())
},
);
}
}
1 change: 1 addition & 0 deletions nativelink-util/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ rust_library(
"@crates//:hyper-1.4.1",
"@crates//:hyper-util",
"@crates//:lru",
"@crates//:mlua",
"@crates//:mock_instant",
"@crates//:parking_lot",
"@crates//:pin-project",
Expand Down
1 change: 1 addition & 0 deletions nativelink-util/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ hex = { version = "0.4.3", default-features = false, features = ["std"] }
hyper = "1.4.1"
hyper-util = "0.1.9"
lru = { version = "0.12.4", default-features = false }
mlua = { version = "0.10.0", features = ["macros"] }
parking_lot = "0.12.3"
pin-project-lite = "0.2.14"
prost = { version = "0.13.3", default-features = false }
Expand Down
4 changes: 3 additions & 1 deletion nativelink-util/src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use tracing::{event, Level};

pub use crate::fs;

#[derive(Default, Clone, Copy, Eq, PartialEq, Hash)]
#[derive(Default, Clone, Copy, Eq, PartialEq, Hash, mlua::FromLua)]
#[repr(C)]
pub struct DigestInfo {
/// Raw hash in packed form.
Expand All @@ -43,6 +43,8 @@ pub struct DigestInfo {
size_bytes: u64,
}

impl mlua::UserData for DigestInfo {}

impl MetricsComponent for DigestInfo {
fn publish(
&self,
Expand Down
34 changes: 34 additions & 0 deletions src/bin/nativelink.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ struct Args {
/// Config file to use.
#[clap(value_parser)]
config_file: String,
/// Lua script to run.
#[clap(value_parser)]
lua_file: Option<String>,
}

/// The root metrics collector struct. All metrics will be
Expand Down Expand Up @@ -953,6 +956,28 @@ async fn inner_main(
root_metrics.write().workers = worker_metrics;
}

let lua = mlua::Lua::new_with(mlua::StdLib::ALL_SAFE, mlua::LuaOptions::new())?;
if let Some(script) = futures::executor::block_on(get_lua())? {
use futures::FutureExt;
let store_manager = store_manager.clone();
let get_store = lua.create_function(move |_, name: String| {
Ok(Arc::into_inner(nativelink_store::ref_store::RefStore::new(
&nativelink_config::stores::RefSpec { name: name.clone() },
Arc::downgrade(&store_manager),
))
.unwrap())
})?;
let digest = lua.create_function(|_, (hash, length): (String, usize)| {
Ok(nativelink_util::common::DigestInfo::try_new(&hash, length)?)
})?;
let globals = lua.globals();
globals.set("get_store", get_store)?;
globals.set("digest", digest)?;
let fut = lua.load(script).call_async::<()>(()).map_err(Error::from);
let spawn_fut = spawn!("lua", fut.map(|_| std::process::exit(0)));
root_futures.push(Box::pin(spawn_fut.map_ok_or_else(|e| Err(e.into()), |v| v)));
}

if let Err(e) = try_join_all(root_futures).await {
panic!("{e:?}");
};
Expand All @@ -969,6 +994,15 @@ async fn get_config() -> Result<CasConfig, Box<dyn std::error::Error>> {
Ok(serde_json5::from_str(&json_contents)?)
}

async fn get_lua() -> Result<Option<Vec<u8>>, Box<dyn std::error::Error>> {
let args = Args::parse();
Ok(if let Some(script) = args.lua_file {
Some(std::fs::read(&script).err_tip(|| format!("Could not open Lua script {script}"))?)
} else {
None
})
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
init_tracing()?;

Expand Down
Loading