From 43c2bf4984487a850207d20df98f08bbc5355271 Mon Sep 17 00:00:00 2001
From: Eivind Jahren <ejah@equinor.com>
Date: Mon, 30 Sep 2024 14:01:50 +0200
Subject: [PATCH] Fix flaky kill_before_submit tests

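The kill_before_submit tests raced a fixed sleep window against the kill
command and then slept through a grace period before asserting that a
"survived" marker file had not appeared on disk. The job scripts now trap
SIGTERM and write a "was_killed" marker instead, so the tests assert
deterministically that the kill took effect, and the platform-dependent
timing windows, reruns and integration-test markers are no longer needed.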
---
 .../unit_tests/scheduler/test_lsf_driver.py   | 70 ++++++++++----------
 .../unit_tests/scheduler/test_slurm_driver.py | 68 ++++++++++----------
 2 files changed, 68 insertions(+), 70 deletions(-)

diff --git a/tests/ert/unit_tests/scheduler/test_lsf_driver.py b/tests/ert/unit_tests/scheduler/test_lsf_driver.py
index e730791eec3..6667fa854a2 100644
--- a/tests/ert/unit_tests/scheduler/test_lsf_driver.py
+++ b/tests/ert/unit_tests/scheduler/test_lsf_driver.py
@@ -6,7 +6,6 @@
 import re
 import stat
 import string
-import sys
 import time
 from contextlib import ExitStack as does_not_raise
 from pathlib import Path
@@ -38,7 +37,7 @@
     parse_bhist,
     parse_bjobs,
 )
-from tests.ert.utils import poll
+from tests.ert.utils import poll, wait_until
 
 from .conftest import mock_bin
 
@@ -1213,11 +1212,7 @@ def mock_poll_once_by_bhist(*args, **kwargs):
     assert driver._bhist_cache and job_id in driver._bhist_cache
 
 
-@pytest.mark.integration_test
-@pytest.mark.flaky(rerun=10)
-async def test_that_kill_before_submit_is_finished_works(
-    tmp_path, monkeypatch, caplog, pytestconfig
-):
+async def test_that_kill_before_submit_is_finished_works(tmp_path, monkeypatch, caplog):
     """This test asserts that it is possible to issue a kill command
     to a realization right after it has been submitted (as in driver.submit()).
 
@@ -1227,19 +1222,7 @@ async def test_that_kill_before_submit_is_finished_works(
-    The design of the test alludes to much more flakyness than what is probable in reality,
-    thus reruns are allowed to make this pass.
+    The job script traps SIGTERM and records the kill on disk, so the test
+    can assert deterministically that the kill took effect, without reruns.
     """
-    os.chdir(tmp_path)
-
-    if pytestconfig.getoption("lsf"):
-        # Allow more time when tested on a real compute cluster to avoid false positives.
-        job_kill_window = 10
-        test_grace_time = 20
-    elif sys.platform.startswith("darwin"):
-        # Mitigate flakiness on low-power test nodes
-        job_kill_window = 5
-        test_grace_time = 10
-    else:
-        job_kill_window = 2
-        test_grace_time = 4
+    monkeypatch.chdir(tmp_path)
 
     bin_path = tmp_path / "bin"
     bin_path.mkdir()
@@ -1254,16 +1237,38 @@ async def test_that_kill_before_submit_is_finished_works(
     caplog.set_level(logging.DEBUG)
     driver = LsfDriver(bsub_cmd="slow_bsub")
 
+    job_path = bin_path / "job.sh"
+    job_path.write_text(
+        dedent(
+            f"""\
+            #!/bin/bash
+
+            do_stop=0
+
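+            # On SIGTERM, record the kill on disk so the test can assert it, then exit.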
+            function handle()
+            {{
+                echo "killed" > {tmp_path}/was_killed
+                do_stop=1
+                exit
+            }}
+            trap handle SIGTERM
+            while [[ $do_stop == 0 ]]
+            do
+                sleep 0.1
+            done
+            """
+        ),
+        encoding="utf-8",
+    )
+    job_path.chmod(job_path.stat().st_mode | stat.S_IEXEC)
+
     # Allow submit and kill to be interleaved by asyncio by issuing
     # submit() in its own asyncio Task:
     asyncio.create_task(
         driver.submit(
-            # The sleep is the time window in which we can kill the job before
-            # the unwanted finish message appears on disk.
             0,
-            "sh",
-            "-c",
-            f"sleep {job_kill_window}; touch {tmp_path}/survived",
+            str(job_path),
         )
     )
     await asyncio.sleep(0.01)  # Allow submit task to start executing
@@ -1281,12 +1286,5 @@ async def finished(iens: int, returncode: int):
     await poll(driver, {0}, finished=finished)
     assert "ERROR" not in str(caplog.text)
 
-    # In case the return value of the killed job is correct but the submitted
-    # shell script is still running for whatever reason, a file called
-    # "survived" will appear on disk. Wait for it, and then ensure it is not
-    # there.
-    assert test_grace_time > job_kill_window, "Wrong test setup"
-    await asyncio.sleep(test_grace_time)
-    assert not Path(
-        "survived"
-    ).exists(), "The process children of the job should also have been killed"
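+    # The SIGTERM handler in job.sh writes this marker; wait for it to appear.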
+    wait_until((tmp_path / "was_killed").exists, timeout=10)
diff --git a/tests/ert/unit_tests/scheduler/test_slurm_driver.py b/tests/ert/unit_tests/scheduler/test_slurm_driver.py
index f239e634745..c2e01b5a580 100644
--- a/tests/ert/unit_tests/scheduler/test_slurm_driver.py
+++ b/tests/ert/unit_tests/scheduler/test_slurm_driver.py
@@ -4,16 +4,16 @@
 import random
 import stat
 import string
-import sys
 from contextlib import ExitStack as does_not_raise
 from pathlib import Path
+from textwrap import dedent
 
 import pytest
 from hypothesis import given
 from hypothesis import strategies as st
 
 from ert.scheduler import SlurmDriver
-from tests.ert.utils import poll
+from tests.ert.utils import poll, wait_until
 
 from .conftest import mock_bin
 
@@ -350,24 +350,8 @@ async def test_submit_with_num_cpu(pytestconfig, job_name):
     assert Path("test").read_text(encoding="utf-8") == "test\n"
 
 
-@pytest.mark.integration_test
-@pytest.mark.flaky(reruns=3)
-async def test_kill_before_submit_is_finished(
-    tmp_path, monkeypatch, caplog, pytestconfig
-):
-    os.chdir(tmp_path)
-
-    if pytestconfig.getoption("slurm"):
-        # Allow more time when tested on a real compute cluster to avoid false positives.
-        job_kill_window = 5
-        test_grace_time = 10
-    elif sys.platform.startswith("darwin"):
-        # Mitigate flakiness on low-power test nodes
-        job_kill_window = 5
-        test_grace_time = 10
-    else:
-        job_kill_window = 1
-        test_grace_time = 2
+async def test_kill_before_submit_is_finished(tmp_path, monkeypatch, caplog):
+    monkeypatch.chdir(tmp_path)
 
     bin_path = tmp_path / "bin"
     bin_path.mkdir()
@@ -382,16 +366,39 @@ async def test_kill_before_submit_is_finished(
     caplog.set_level(logging.DEBUG)
     driver = SlurmDriver(sbatch_cmd="slow_sbatch")
 
+    job_path = bin_path / "job.sh"
+    job_path.write_text(
+        dedent(
+            f"""\
+            #!/bin/bash
+
+            do_stop=0
+
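+            # On SIGTERM, record the kill on disk so the test can assert it, then exit.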
+            function handle()
+            {{
+                echo "killed" > {tmp_path}/was_killed
+                do_stop=1
+                exit 255  # same status as the non-portable "exit -1"
+            }}
+            trap handle SIGTERM
+            # (SIGKILL cannot be trapped, so only SIGTERM is handled)
+            while [[ $do_stop == 0 ]]
+            do
+                sleep 0.1
+            done
+            """
+        ),
+        encoding="utf-8",
+    )
+    job_path.chmod(job_path.stat().st_mode | stat.S_IEXEC)
+
     # Allow submit and kill to be interleaved by asyncio by issuing
     # submit() in its own asyncio Task:
     asyncio.create_task(
         driver.submit(
-            # The sleep is the time window in which we can kill the job before
-            # the unwanted finish message appears on disk.
             0,
-            "sh",
-            "-c",
-            f"sleep {job_kill_window}; touch {tmp_path}/survived",
+            str(job_path),
         )
     )
     await asyncio.sleep(0.01)  # Allow submit task to start executing
@@ -404,12 +411,5 @@ async def finished(iens: int, returncode: int):
 
     await poll(driver, {0}, finished=finished)
 
-    # In case the return value of the killed job is correct but the submitted
-    # shell script is still running for whatever reason, a file called
-    # "survived" will appear on disk. Wait for it, and then ensure it is not
-    # there.
-    assert test_grace_time > job_kill_window, "Wrong test setup"
-    await asyncio.sleep(test_grace_time)
-    assert not Path(
-        "survived"
-    ).exists(), "The process children of the job should also have been killed"
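+    # The SIGTERM handler in job.sh writes this marker; wait for it to appear.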
+    wait_until((tmp_path / "was_killed").exists, timeout=10)