Commit f60d282a authored by Michael Salim's avatar Michael Salim
Browse files

* rewrote dag.spawn_child -- more robust and fewer calls to database

* user kill works even when a kill save() takes a long time in flight
  and some other processing happens first
* updated functional tests
* all tests passing on Theta, Cooley
parent 036cf45b
......@@ -61,6 +61,7 @@ follows::
'''
import django
import json
import os
import uuid
......@@ -72,7 +73,7 @@ __all__ = ['JOB_ID', 'TIMEOUT', 'ERROR',
os.environ['DJANGO_SETTINGS_MODULE'] = 'balsam.django_config.settings'
django.setup()
from balsam.service.models import BalsamJob
from balsam.service.models import BalsamJob, history_line
from django.conf import settings
current_job = None
......@@ -202,31 +203,36 @@ def spawn_child(clone=False, **kwargs):
if 'workflow' not in kwargs:
kwargs['workflow'] = current_job.workflow
if 'allowed_work_sites' not in kwargs:
kwargs['allowed_work_sites'] = settings.BALSAM_SITE
if clone:
child = BalsamJob()
new_pk = child.pk
child = BalsamJob()
new_pk = child.pk
exclude_fields = '_state version job_id working_directory'.split()
fields = [f for f in current_job.__dict__ if f not in exclude_fields]
for f in fields: child.__dict__[f] = current_job.__dict__[f]
exclude_fields = '_state version state_history job_id working_directory'.split()
fields = [f for f in current_job.__dict__ if f not in exclude_fields]
if clone:
for f in fields:
child.__dict__[f] = current_job.__dict__[f]
assert child.pk == new_pk
for k,v in kwargs.items():
if k in fields:
child.__dict__[k] = v
else:
raise ValueError(f"Invalid field {k}")
for k,v in kwargs.items():
if k in fields:
child.__dict__[k] = v
else:
raise ValueError(f"Invalid field {k}")
child.working_directory = '' # This is essential
child.db_write_client = None
child.save()
else:
child = add_job(**kwargs)
child.working_directory = '' # This is essential
child.db_write_client = None
add_dependency(current_job, child)
child.state_history = ''
child.update_state("CREATED", f"spawned by {current_job.cute_id}")
newparents = json.loads(current_job.parents)
newparents.append(str(current_job.job_id))
child.parents = json.dumps(newparents)
child.state = "CREATED"
child.state_history = history_line("CREATED", f"spawned by {current_job.cute_id}")
child.save()
return child
def kill(job, recursive=True):
......
......@@ -425,7 +425,8 @@ auto timeout retry: {self.auto_timeout_retry}
if new_state not in STATES:
raise InvalidStateError(f"{new_state} is not a job state in balsam.models")
self.refresh_from_db()
try: self.refresh_from_db()
except ObjectDoesNotExist: pass
if self.state == 'USER_KILLED': return
self.state_history += history_line(new_state, message)
......
......@@ -6,6 +6,7 @@ import getpass
import sys
import signal
import subprocess
import time
import tempfile
from importlib.util import find_spec
......@@ -54,6 +55,7 @@ def stop_launcher_processes():
print("\n".join(processes))
print("Sending SIGKILL")
sig_processes(processes, signal.SIGKILL)
time.sleep(3)
def run_launcher_until(function, args=(), period=1.0, timeout=60.0):
......@@ -937,13 +939,16 @@ class TestUserKill(BalsamTestCase):
def test_kill_during_execution_mpi(self):
'''Parallel MPIRunner job is properly terminated'''
from balsam.service.schedulers import Scheduler
from balsam.launcher import worker
from balsam.launcher.launcher import get_args
config = get_args('--consume-all'.split())
scheduler = Scheduler.scheduler_main
if scheduler.num_workers:
num_workers = scheduler.num_workers
if num_workers < 2:
self.skipTest("Need at least 2 workers to run this test")
else:
self.skipTest("Need environment with multiple workers to run this test")
group = worker.WorkerGroup(config, host_type=scheduler.host_type,
workers_str=scheduler.workers_str,
workers_file=scheduler.workers_file)
if len(group.workers) < 2:
self.skipTest("Need at least 2 workers to run this test")
killer_job = create_job(name="killer", app="killer")
slow_job = create_job(name="slow_job", app="slow", ranks_per_node=2,
......
......@@ -23,7 +23,8 @@ from balsam.launcher import runners
from balsam.launcher.launcher import get_args, create_runner
def ls_procs(keywords):
if type(keywords) == str: keywords = [keywords]
if type(keywords) == str:
keywords = keywords.split()
username = getpass.getuser()
......@@ -58,6 +59,7 @@ def stop_processes(name):
processes = ls_procs(name)
if processes:
sig_processes(processes, signal.SIGKILL)
time.sleep(3)
class TestMPIRunner(BalsamTestCase):
'''start, update_jobs, finished, error/timeout handling'''
......@@ -276,8 +278,7 @@ class TestMPIEnsemble(BalsamTestCase):
self.assertTrue(all(j.state=='RUN_ERROR' for j in jobs['fail']))
# Kill the sleeping jobs in case they do not terminate
stop_processes('mpi_ensemble')
stop_processes('mock_serial')
stop_processes('mpi_ensemble mock_serial')
class TestRunnerGroup(BalsamTestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment