Skip to content

Commit

Permalink
Merge pull request #5596 from freddy77/leak_fd
Browse files Browse the repository at this point in the history
Remove possible file descriptor leak if safe_close_and_exec fails
  • Loading branch information
robhoes authored May 22, 2024
2 parents 568323b + 5b86ed8 commit 09d0a3e
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 45 deletions.
7 changes: 6 additions & 1 deletion ocaml/forkexecd/lib/forkhelpers.ml
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,13 @@ let safe_close_and_exec ?tracing ?env stdin stdout stderr
let fds_to_close = ref [] in

let add_fd_to_close_list fd = fds_to_close := fd :: !fds_to_close in
(* let remove_fd_from_close_list fd = fds_to_close := List.filter (fun fd' -> fd' <> fd) !fds_to_close in *)
let remove_fd_from_close_list fd =
fds_to_close := List.filter (fun fd' -> fd' <> fd) !fds_to_close
in
let close_fds () = List.iter (fun fd -> Unix.close fd) !fds_to_close in

add_fd_to_close_list sock ;

finally
(fun () ->
let maybe_add_id_to_fd_map id_to_fd_map (uuid, fd, v) =
Expand Down Expand Up @@ -290,6 +294,7 @@ let safe_close_and_exec ?tracing ?env stdin stdout stderr
Fecomms.write_raw_rpc ?tracing sock Fe.Exec ;
match Fecomms.read_raw_rpc ?tracing sock with
| Ok (Fe.Execed pid) ->
remove_fd_from_close_list sock ;
(sock, pid)
| Ok status ->
let msg =
Expand Down
113 changes: 69 additions & 44 deletions ocaml/forkexecd/test/fe_test.ml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@ let min_fds = 7

let max_fds = 1024 - 13 (* fe daemon has a bunch for its own use *)

(* Raw failure reporter: stores [msg] in /tmp/fe-test.log, echoes it on
   stderr followed by a newline, then aborts with [assert false]. It is
   shadowed by the printf-style [fail] defined right below. *)
let fail msg =
  Xapi_stdext_unix.Unixext.write_string_to_file "/tmp/fe-test.log" msg ;
  Printf.eprintf "%s\n" msg ;
  assert false

(* [fail fmt arg1 ... argN] builds the message from a printf-style format
   string and its arguments, then hands it to the raw reporter above. Like
   the raw reporter it never returns normally. *)
let fail fmt = Format.ksprintf fail fmt

let all_combinations fds =
let y =
{
Expand Down Expand Up @@ -68,8 +75,26 @@ let shuffle x =
done ;
Array.to_list arr

(* [fds_fold f init] folds [f fd_num link acc] over the entries of
   /proc/self/fd, where [fd_num] is the directory entry name and [link] is
   the target returned by [Unix.readlink] for that entry. Entries are visited
   with [Array.fold_right], i.e. starting from the last element of the array
   returned by [Sys.readdir].

   Entries whose link cannot be read are skipped: this gets rid of the fd
   that was used to read the directory itself. Only [Unix.Unix_error] raised
   by [Unix.readlink] is treated this way; exceptions raised by [f] are
   propagated to the caller rather than silently dropping the entry. *)
let fds_fold f init =
  let path = "/proc/self/fd" in
  Array.fold_right
    (fun fd_num acc ->
      match Unix.readlink (Filename.concat path fd_num) with
      | link ->
          f fd_num link acc
      | exception Unix.Unix_error _ ->
          (* entry vanished or is not a readable link: ignore it *)
          acc
    )
    (Sys.readdir path) init

(* [fd_list ()] returns one [(fd_num, link)] pair for every entry that
   [fds_fold] visits. *)
let fd_list () =
  let collect num target acc = (num, target) :: acc in
  fds_fold collect []

(* [fd_count ()] returns the number of entries that [fds_fold] visits. *)
let fd_count () =
  let bump _fd_num _link count = succ count in
  fds_fold bump 0

(* Filler strings; [slave] drops any argument found in this list before it
   checks the remaining arguments against the open file descriptors. *)
let irrelevant_strings =
  [
    "irrelevant"
  ; "not"
  ; "important"
  ]

(* /proc/<pid>/exe path for this process, built once when the module is
   initialised from the pid returned by [Unix.getpid]. The tests pass it as
   the command to execute. *)
let exe = "/proc/" ^ string_of_int (Unix.getpid ()) ^ "/exe"

let one fds x =
(*Printf.fprintf stderr "named_fds = %d\n" x.named_fds;
Printf.fprintf stderr "extra = %d\n" x.extra;*)
Expand All @@ -82,7 +107,6 @@ let one fds x =
let number_of_extra = x.extra in
let other_names = make_names number_of_extra in

let exe = Printf.sprintf "/proc/%d/exe" (Unix.getpid ()) in
let table =
(fun x -> List.combine x (List.map (fun _ -> fd) x)) (names @ other_names)
in
Expand All @@ -107,7 +131,6 @@ let one fds x =

let test_delay () =
let start = Unix.gettimeofday () in
let exe = Printf.sprintf "/proc/%d/exe" (Unix.getpid ()) in
let args = ["sleep"] in
(* Need to have fractional part because some internal usage split integer
and fractional and do computation.
Expand All @@ -117,7 +140,7 @@ let test_delay () =
let timeout = 1.7 in
try
Forkhelpers.execute_command_get_output ~timeout exe args |> ignore ;
failwith "Failed to timeout"
fail "Failed to timeout"
with
| Forkhelpers.Subprocess_timeout ->
let elapsed = Unix.gettimeofday () -. start in
Expand All @@ -127,39 +150,25 @@ let test_delay () =
if elapsed > timeout +. 0.2 then
failwith "Excessive time elapsed"
| e ->
failwith
(Printf.sprintf "Failed with unexpected exception: %s"
(Printexc.to_string e)
)
fail "Failed with unexpected exception: %s" (Printexc.to_string e)

let test_notimeout () =
let exe = Printf.sprintf "/proc/%d/exe" (Unix.getpid ()) in
let args = ["sleep"] in
try
Forkhelpers.execute_command_get_output exe args |> ignore ;
()
with e ->
failwith
(Printf.sprintf "Failed with unexpected exception: %s"
(Printexc.to_string e)
)

let fail x =
Xapi_stdext_unix.Unixext.write_string_to_file "/tmp/fe-test.log" x ;
Printf.fprintf stderr "%s\n" x ;
assert false
with e -> fail "Failed with unexpected exception: %s" (Printexc.to_string e)

let expect expected s =
if s <> expected ^ "\n" then
fail (Printf.sprintf "output %s expected %s" s expected)
fail "output %s expected %s" s expected

let test_exitcode () =
let run_expect cmd expected =
try Forkhelpers.execute_command_get_output cmd [] |> ignore
with Forkhelpers.Spawn_internal_error (_, _, Unix.WEXITED n) ->
if n <> expected then
fail
(Printf.sprintf "%s exited with code %d, expected %d" cmd n expected)
fail "%s exited with code %d, expected %d" cmd n expected
in
run_expect "/bin/false" 1 ;
run_expect "/bin/xe-fe-test-no-command" 127 ;
Expand All @@ -168,7 +177,6 @@ let test_exitcode () =
Printf.printf "\nCompleted exitcode tests\n"

let test_output () =
let exe = Printf.sprintf "/proc/%d/exe" (Unix.getpid ()) in
let expected_out = "output string" in
let expected_err = "error string" in
let args = ["echo"; expected_out; expected_err] in
Expand All @@ -178,7 +186,6 @@ let test_output () =
print_endline "Completed output tests"

let test_input () =
let exe = Printf.sprintf "/proc/%d/exe" (Unix.getpid ()) in
let input = "input string" in
let args = ["replay"] in
let out, _ =
Expand All @@ -187,6 +194,38 @@ let test_input () =
expect input out ;
print_endline "Completed input tests"

(* This test exercises a failure inside Forkhelpers.safe_close_and_exec.
   Although this exact way of reproducing it is never supposed to happen in
   the real world, an internal failure could happen, for instance, if the
   forkexecd daemon is restarted for a moment, so make sure we are able to
   detect and handle these cases.

   The test passes an already-closed file descriptor as stdout, expects
   [Fd_send_recv.Unix_error] to be raised, and then checks that the number of
   open file descriptors of this process is the same as before the call. *)
let test_internal_failure_error () =
  let initial_fd_count = fd_count () in
  let leak_fd_detect () =
    let current_fd_count = fd_count () in
    if current_fd_count <> initial_fd_count then
      fail "File descriptor leak detected initially %d files, now %d"
        initial_fd_count current_fd_count
  in
  (* This odd function opens [num + 1] file descriptors on /dev/null, closes
     all of them again, and returns the last one opened (now closed). The
     point is to obtain an invalid file descriptor that has other closed
     descriptors before it. *)
  let rec waste_fds num =
    let fd = Unix.openfile "/dev/null" [Unix.O_WRONLY] 0o0 in
    let ret = if num = 0 then fd else waste_fds (num - 1) in
    Unix.close fd ; ret
  in
  let fd = waste_fds 20 in
  let args = ["sleep"] in
  (* [match ... with exception] keeps the "Expected an exception" failure
     outside the handlers: with a plain [try ... with e] the [Assert_failure]
     raised by [fail] would itself be caught and re-reported as an
     "unexpected exception", hiding the real reason the test failed. *)
  match Forkhelpers.safe_close_and_exec None (Some fd) None [] exe args with
  | _ ->
      fail "Expected an exception"
  | exception Fd_send_recv.Unix_error _ ->
      leak_fd_detect ()
  | exception e ->
      fail "Failed with unexpected exception: %s" (Printexc.to_string e)

let master fds =
Printf.printf "\nPerforming timeout tests\n%!" ;
test_delay () ;
Expand All @@ -196,6 +235,8 @@ let master fds =
Printf.printf "\nPerforming input/output tests\n%!" ;
test_output () ;
test_input () ;
Printf.printf "\nPerforming internal failure test\n%!" ;
test_internal_failure_error () ;
let combinations = shuffle (all_combinations fds) in
Printf.printf "Starting %d tests\n%!" (List.length combinations) ;
let i = ref 0 in
Expand All @@ -215,28 +256,14 @@ let master fds =

let slave = function
| [] ->
failwith "Error, at least one fd expected"
fail "Error, at least one fd expected"
| total_fds :: rest ->
let total_fds = int_of_string total_fds in
let fds =
List.filter (fun x -> not (List.mem x irrelevant_strings)) rest
in
(* Check that these fds are present *)
let pid = Unix.getpid () in
let path = Printf.sprintf "/proc/%d/fd" pid in
let raw =
List.filter (* get rid of the fd used to read the directory *)
(fun x ->
try
ignore (Unix.readlink (Filename.concat path x)) ;
true
with _ -> false
)
(Array.to_list (Sys.readdir path))
in
let pairs =
List.map (fun x -> (x, Unix.readlink (Filename.concat path x))) raw
in
let pairs = fd_list () in
(* Filter any of stdin,stdout,stderr which have been mapped to /dev/null *)
let filtered =
List.filter
Expand All @@ -257,18 +284,16 @@ let slave = function
List.iter
(fun fd ->
if not (List.mem fd (List.map fst filtered)) then
fail (Printf.sprintf "fd %s not in /proc/%d/fd [ %s ]" fd pid ls)
fail "fd %s not in /proc/self/fd [ %s ]" fd ls
)
fds ;
(* Check that we have the expected number *)
(*
Printf.fprintf stderr "%s %d\n" total_fds (List.length present - 1)
*)
if total_fds <> List.length filtered then
fail
(Printf.sprintf "Expected %d fds; /proc/%d/fd has %d: %s" total_fds
pid (List.length filtered) ls
)
fail "Expected %d fds; /proc/self/fd has %d: %s" total_fds
(List.length filtered) ls

(* [sleep ()] blocks for three seconds and then writes "Ok" plus a newline to
   standard output. *)
let sleep () =
  Unix.sleep 3 ;
  print_string "Ok\n"

Expand Down

0 comments on commit 09d0a3e

Please sign in to comment.