Commit 75f61897 authored by Michael Salim
Browse files

User kill now works properly for both serial and mpi jobs

Added test cases
parent e1cd9f8d
......@@ -11,7 +11,7 @@ os.environ['DJANGO_SETTINGS_MODULE'] = 'balsam.django_config.settings'
django.setup()
logger = logging.getLogger('balsam.launcher.mpi_ensemble')
from subprocess import Popen, STDOUT
from subprocess import Popen, STDOUT, TimeoutExpired
from mpi4py import MPI
......@@ -61,9 +61,21 @@ def read_jobs(fp):
else:
logger.debug("Invalid workdir")
def poll_execution_or_killed(job, proc, period=10):
    '''Wait for proc to exit, checking every period seconds for a user kill.

    Args:
        job: object exposing refresh_from_db(), state and cute_id; its state
            is re-read from the database after every wait timeout.
        proc: Popen-like object exposing wait(timeout=...), terminate()
            and kill().
        period: seconds to wait between database checks; also used as the
            grace time granted to the process after terminate().

    Returns:
        The value returned by proc.wait() if the process exits by itself, or
        the string "USER_KILLED" if the job was found in that state and the
        process was stopped.
    '''
    while True:
        try:
            return proc.wait(timeout=period)
        except TimeoutExpired:
            pass  # still running: fall through to the kill check

        job.refresh_from_db()
        if job.state != 'USER_KILLED':
            continue

        logger.debug(f"{job.cute_id} USER_KILLED; terminating it now")
        proc.terminate()
        # Do not return while the child may still be alive: give it one more
        # period to exit on terminate(), then force-kill and reap it so it
        # cannot keep running after the job is reported as killed.
        try:
            proc.wait(timeout=period)
        except TimeoutExpired:
            proc.kill()
            proc.wait()
        return "USER_KILLED"
def run(job):
job_from_db = BalsamJob.objects.get(pk=job.id)
if job_from_db.state == 'USER_KILLED':
......@@ -85,7 +97,8 @@ def run(job):
signal.signal(signal.SIGINT, handler)
signal.signal(signal.SIGTERM, handler)
retcode = proc.wait()
retcode = poll_execution_or_killed(job_from_db, proc)
except Exception as e:
logger.exception(f"mpi_ensemble rank {RANK} job {job.id}: exception during Popen")
status_msg(job.id, "FAILED", msg=str(e))
......@@ -94,6 +107,8 @@ def run(job):
if retcode == 0:
logger.debug(f"mpi_ensemble rank {RANK}: job returned 0")
status_msg(job.id, "RUN_DONE")
elif retcode == "USER_KILLED":
status_msg(job.id, "USER_KILLED", msg="mpi_ensemble aborting job due to user request")
else:
outf.flush()
tail = get_tail(outf.name).replace('\n', '\\n')
......
......@@ -253,7 +253,7 @@ class MPIEnsembleRunner(Runner):
logger.info(f"MPIEnsemble {job.cute_id} updated to {state}: {msg}")
except (ValueError, KeyError, InvalidStateError) as e:
if 'resources: utime' not in line:
logger.error(f"Invalid statusMsg from mpi_ensemble: {line.strip()}")
logger.error(f"Unexpected statusMsg from mpi_ensemble: {line.strip()}")
retcode = None
if timeout:
......
......@@ -142,9 +142,9 @@ def server_main():
try:
logger.info("db_writer starting up")
while True:
logger.info(f"proxy waiting for message")
#logger.info(f"proxy waiting for message")
message = proxy.recv_request()
logger.info(f"proxy received message")
#logger.info(f"proxy received message")
if message is None:
if os.getppid() != parent_pid:
logger.info("db_writer detected parent PID died; quitting")
......
......@@ -425,6 +425,9 @@ auto timeout retry: {self.auto_timeout_retry}
if new_state not in STATES:
raise InvalidStateError(f"{new_state} is not a job state in balsam.models")
self.refresh_from_db()
if self.state == 'USER_KILLED': return
self.state_history += history_line(new_state, message)
self.state = new_state
try:
......
import balsam.launcher.dag as dag
import time
# Poll the database every 2 seconds until a job named "slow_job" exists,
# giving up with an error once more than 40 seconds have elapsed.
began_at = time.time()
while True:
    if dag.BalsamJob.objects.filter(name="slow_job").count() != 0:
        break
    time.sleep(2)
    elapsed = time.time() - began_at
    if elapsed > 40:
        raise RuntimeError("the slow job never started")

# The job exists now: look it up and kill it through the dag API.
slow_job = dag.BalsamJob.objects.get(name='slow_job')
dag.kill(slow_job)
import sys
import time

# Test application that simply sleeps.
#   argv[1] (optional): sleep time in seconds (default 10)
#   'parallel' anywhere in argv: query the MPI rank; otherwise rank is 0

if len(sys.argv) == 1:
    delay = 10
else:
    delay = int(sys.argv[1])

if 'parallel' in sys.argv:
    # mpi4py is imported only on request so that serial runs do not need MPI
    from mpi4py import MPI
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
else:
    rank = 0

print(f"Rank {rank} Sleeping for a long time...")
# Flush so the message is already in the output file if the job is killed
# while sleeping.
sys.stdout.flush()
time.sleep(delay)
if rank == 0: print("Done")
......@@ -891,13 +891,72 @@ class TestConcurrentDB(BalsamTestCase):
created_jobs = BalsamJob.objects.filter(name__icontains='hello')
self.assertEqual(created_jobs.count(), num_ranks)
class TestUserKill(BalsamTestCase):
    '''A "killer" job kills a "slow" job at various stages of its lifetime;
    each test checks that the slow job ends up marked USER_KILLED.'''

    def setUp(self):
        '''Register the killer and slow test apps, run with this interpreter'''
        killer_name = find_spec("tests.ft_apps.dynamic_kill.killer").origin
        slow_name = find_spec("tests.ft_apps.dynamic_kill.slow").origin
        interpreter = sys.executable
        self.killer_name = f"{interpreter} {killer_name}"
        self.slow_name = f"{interpreter} {slow_name}"
        create_app(name="killer", executable=self.killer_name)
        create_app(name="slow", executable=self.slow_name)

    def test_kill_during_preprocess(self):
        '''Job killed while pre-processing is properly marked'''
        killer_job = create_job(name="killer", app="killer")
        slow_job = create_job(name="slow_job", app="slow", preproc=self.slow_name, args="30")

        success = run_launcher_until_state(killer_job, 'JOB_FINISHED')
        self.assertTrue(success)

        slow_job.refresh_from_db()
        self.assertEqual(slow_job.state, "USER_KILLED")

        # The preprocess script started, but the job never reached RUNNING
        preproc_out = slow_job.read_file_in_workdir('preprocess.log')
        self.assertIn("Sleeping for a long time", preproc_out)
        self.assertNotIn("RUNNING", slow_job.state_history)
        self.assertIn("STAGED_IN", slow_job.state_history)

    def test_kill_during_execution_serial(self):
        '''Serial job running in mpi_ensemble is properly terminated'''
        killer_job = create_job(name="killer", app="killer")
        slow_job = create_job(name="slow_job", app="slow", args="30")

        success = run_launcher_until_state(killer_job, 'JOB_FINISHED')
        self.assertTrue(success)

        slow_job.refresh_from_db()
        self.assertEqual(slow_job.state, "USER_KILLED")

        # The job started running and was killed before it could finish
        stdout = slow_job.read_file_in_workdir('slow_job.out')
        self.assertIn("Sleeping for a long time", stdout)
        self.assertIn("RUNNING", slow_job.state_history)
        self.assertIn("USER_KILLED", slow_job.state_history)
        self.assertNotIn("RUN_DONE", slow_job.state_history)

    def test_kill_during_execution_mpi(self):
        '''Parallel MPIRunner job is properly terminated'''
        from balsam.service.schedulers import Scheduler
        scheduler = Scheduler.scheduler_main
        if scheduler.num_workers:
            num_workers = scheduler.num_workers
            if num_workers < 2:
                self.skipTest("Need at least 2 workers to run this test")
        else:
            self.skipTest("Need environment with multiple workers to run this test")

        killer_job = create_job(name="killer", app="killer")
        slow_job = create_job(name="slow_job", app="slow", ranks_per_node=2,
                              args="30 parallel")

        success = run_launcher_until_state(killer_job, 'JOB_FINISHED')
        self.assertTrue(success)

        slow_job.refresh_from_db()
        self.assertEqual(slow_job.state, "USER_KILLED")

        # Both ranks started sleeping before the job was killed
        stdout = slow_job.read_file_in_workdir('slow_job.out')
        self.assertIn("Rank 0 Sleeping for a long time", stdout)
        self.assertIn("Rank 1 Sleeping for a long time", stdout)
        self.assertIn("RUNNING", slow_job.state_history)
        self.assertIn("USER_KILLED", slow_job.state_history)
        self.assertNotIn("RUN_DONE", slow_job.state_history)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment