Commit e1cd9f8d authored by Michael Salim

* Removed multiprocessing Lock: Django model save() calls can happen

concurrently; they serialize by queueing up at the DB writer, which is a
ZMQ proxy in front of the SQLite database
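
To illustrate the pattern (a minimal sketch, not Balsam's actual code; apply_save and the port number are placeholder assumptions): a single-threaded ZMQ REP loop answers one save request at a time, so concurrent clients simply block until their turn and SQLite never sees two writers at once.

    import zmq

    def apply_save(serial_job):
        # placeholder: deserialize the payload and call the Django model's save()
        pass

    def run_db_writer(port=5556):  # port is illustrative
        # REP socket: requests are handled strictly one at a time
        context = zmq.Context()
        socket = context.socket(zmq.REP)
        socket.bind(f"tcp://*:{port}")
        while True:
            serial_job = socket.recv().decode()   # blocks until a client sends
            if serial_job == 'TERM':
                socket.send(b"ACK")
                break
            apply_save(serial_job)
            socket.send(b"ACK")   # the client unblocks only after the write has landed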

* dag.spawn_child now clones jobs in a saner way; fixed a bug where
job_id was sent over the wire as None
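
The idea, roughly (a generic sketch; clone_model, ModelClass, and the exclude tuple are illustrative, the real field list is in the spawn_child diff below): build a brand-new instance so it keeps its own freshly generated primary key, then copy the parent's fields over by dict, instead of fetching the parent and nulling its pk.

    def clone_model(instance, ModelClass, exclude=('_state', 'version', 'job_id')):
        # exclude tuple is an assumption for illustration
        clone = ModelClass()              # fresh pk is generated here...
        for name, value in instance.__dict__.items():
            if name not in exclude:
                clone.__dict__[name] = value
        return clone                      # ...and is never overwritten by the copy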

* Improved logging

* Fixed a deserialization bug when job_id is None (for new jobs)
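
The fix amounts to treating job_id as optional on the wire; a standalone sketch of that round trip under that assumption (function names are hypothetical; only uuid/json, model details elided): serialize a real UUID as a string, leave None alone, and on deserialization fall back to the receiver's freshly generated id.

    import json
    import uuid

    def serialize_job_id(job_id):
        # a real UUID crosses the wire as its string form; None stays None
        return json.dumps({'job_id': str(job_id) if isinstance(job_id, uuid.UUID) else None})

    def deserialize_job_id(serial_data, fallback_id):
        d = json.loads(serial_data)
        # new jobs arrive with job_id == None: adopt the fresh local UUID instead
        return uuid.UUID(d['job_id']) if d['job_id'] is not None else fallback_id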

* Added a test case for concurrent BalsamJob insertions into the database
from a single mpi4py application, where all ranks concurrently call
save() with new jobs
parent 47bf5d63
@@ -204,16 +204,22 @@ def spawn_child(clone=False, **kwargs):
         kwargs['workflow'] = current_job.workflow
     if clone:
-        child = BalsamJob.objects.get(pk=current_job.pk)
-        child.pk = None
+        child = BalsamJob()
+        new_pk = child.pk
+        exclude_fields = '_state version job_id working_directory'.split()
+        fields = [f for f in current_job.__dict__ if f not in exclude_fields]
+        for f in fields: child.__dict__[f] = current_job.__dict__[f]
+        assert child.pk == new_pk
         for k,v in kwargs.items():
-            try:
-                getattr(child, k)
-            except AttributeError:
-                raise ValueError(f"Invalid field {k}")
+            if k in fields:
+                child.__dict__[k] = v
             else:
-                setattr(child, k, v)
-        child.working_directory = '' # This is essential; awful BUG if not here
+                raise ValueError(f"Invalid field {k}")
+        child.working_directory = '' # This is essential
+        child.db_write_client = None
+        child.save()
     else:
         child = add_job(**kwargs)
......
@@ -213,10 +213,10 @@ def on_exit(runner_group, transition_pool, job_source, writer_proc):
     logger.debug("on_exit: send end message to transition threads")
     transition_pool.end_and_wait()
-    logger.debug("on_exit: Launcher exit graceful\n\n")
+    client = db_writer.ZMQClient()
+    client.term_server()
+    logger.debug("on_exit: Launcher exit graceful\n\n")
     sys.exit(0)
......
@@ -45,16 +45,18 @@ logger = logging.getLogger('balsam.launcher.transitions')
 # DB writes become a bottleneck, we have to go to a DB that supports better
 # concurrency -- but SQLite makes it significantly easier for users to deploy
 # Balsam, because it's built in and requires zero user configuration
+class DummyLock:
+    def acquire(self): pass
+    def release(self): pass
-if sys.platform.startswith('darwin'):
-    LockClass = multiprocessing.Lock
-elif sys.platform.startswith('win32'):
-    LockClass = multiprocessing.Lock
-else:
-    class DummyLock:
-        def acquire(self): pass
-        def release(self): pass
-    LockClass = DummyLock
-LockClass = multiprocessing.Lock # TODO: replace with better solution!
+#LockClass = multiprocessing.Lock
+LockClass = DummyLock # With db_writer proxy; no need for lock!
 logger.debug(f'Using lock: {LockClass}')

 PREPROCESS_TIMEOUT_SECONDS = 300
@@ -330,7 +332,8 @@ def preprocess(job, lock):
         lock.acquire()
         proc = subprocess.Popen(args, stdout=fp,
                                 stderr=subprocess.STDOUT, env=envs,
-                                cwd=job.working_directory)
+                                cwd=job.working_directory,
+                                )
         retcode = proc.wait(timeout=PREPROCESS_TIMEOUT_SECONDS)
         proc.communicate()
         lock.release()
@@ -411,7 +414,8 @@ def postprocess(job, lock, *, error_handling=False, timeout_handling=False):
         lock.acquire()
         proc = subprocess.Popen(args, stdout=fp,
                                 stderr=subprocess.STDOUT, env=envs,
-                                cwd=job.working_directory)
+                                cwd=job.working_directory,
+                                )
         retcode = proc.wait(timeout=POSTPROCESS_TIMEOUT_SECONDS)
         proc.communicate()
         lock.release()
@@ -421,7 +425,7 @@ def postprocess(job, lock, *, error_handling=False, timeout_handling=False):
             raise BalsamTransitionError(message) from e
         if retcode != 0:
-            tail = get_tail(out)
+            tail = get_tail(out, nlines=30)
             message = f"{job.cute_id} postprocess returned {retcode}:\n{tail}"
             raise BalsamTransitionError(message)
......
+from io import StringIO
+from traceback import print_exc
 import json
 import os
 import logging
@@ -40,6 +42,8 @@ class ZMQProxy:
         self.context = zmq.Context()
         self.socket = self.context.socket(zmq.REP)
         self.socket.bind(f'tcp://*:{PORT}')
+        logger.info(f"db_writer proxy listening at {self.address}")
+        logger.info(f"db_writer address written to {self.sock_file}")
         return self.socket

     def recv_request(self):
@@ -68,25 +72,33 @@ class ZMQClient:
         global SOCKFILE_PATH
         SOCKFILE_PATH = settings.INSTALL_PATH
         self.discover_zmq_proxy()
+        if self.zmq_server is not None:
+            logger.info(f"save() going to server @ {self.zmq_server}")
+        else:
+            logger.info(f"No db_writer detected; save() going directly to local db")

     def discover_zmq_proxy(self):
         path = os.path.join(SOCKFILE_PATH, SOCKFILE_NAME)
         if os.path.exists(path):
             self.zmq_server = open(path).read().strip()
+            logger.debug(f"client discover: {self.zmq_server}")
         else:
+            logger.debug(f"client discover: no db_socket_file exists")
             self.zmq_server = None
             return

         if 'tcp://' not in self.zmq_server:
+            logger.debug(f"client discover: invalid address")
             self.zmq_server = None
             return

+        logger.debug(f"client discover: sending request TEST_ALIVE")
         response = self.send_request('TEST_ALIVE')
-        if response == 'ACK':
-            logger.info(f"save() going to server @ {self.zmq_server}")
-        else:
+        if response != 'ACK':
             self.zmq_server = None
+            logger.debug(f"client discover: no response; dead server")
+        else:
+            logger.debug(f"client discover: the server is alive!")
@@ -106,6 +118,7 @@ class ZMQClient:
                              force_update=force_update, using=using,
                              update_fields=update_fields)
+        logger.info(f"client: sending request for save of {job.cute_id}")
         response = self.send_request(serial_data)
         if response is None:
             raise OperationalError("ZMQ DB write request timed out")
@@ -117,16 +130,21 @@ class ZMQClient:
         response = self.send_request('TERM')

 def server_main():
+    logger.debug("hello from server_main")
     parent_pid = os.getppid()
     handler = lambda a,b: 0
     signal.signal(signal.SIGINT, handler)
     signal.signal(signal.SIGTERM, handler)
+    logger.debug("making zmq proxy class")
     proxy = ZMQProxy()
     try:
         logger.info("db_writer starting up")
         while True:
+            logger.info(f"proxy waiting for message")
             message = proxy.recv_request()
+            logger.info(f"proxy received message")
             if message is None:
                 if os.getppid() != parent_pid:
                     logger.info("db_writer detected parent PID died; quitting")
@@ -140,7 +158,12 @@ def server_main():
                     break
             else:
                 proxy.send_reply("ACK")
+    except:
+        buf = StringIO()
+        print_exc(file=buf)
+        logger.exception(f"db_writer Uncaught exception:\n%s", buf.getvalue())
     finally:
+        logger.info("exiting server main; deleting sock_file now")
         os.remove(os.path.join(SOCKFILE_PATH, SOCKFILE_NAME))

 if __name__ == "__main__":
......
@@ -273,9 +273,14 @@ class BalsamJob(models.Model):
         if type(d['job_id']) is str:
             d['job_id'] = uuid.UUID(d['job_id'])
+        else:
+            assert d['job_id'] is None
+            d['job_id'] = job.job_id

         for field in SERIAL_FIELDS:
             job.__dict__[field] = d[field]
+        assert type(job.job_id) == uuid.UUID
         return job
@@ -367,7 +372,7 @@ auto timeout retry: {self.auto_timeout_retry}
         for i, parent in enumerate(parents_list):
             pk = parent.pk if isinstance(parent,BalsamJob) else parent
             if not BalsamJob.objects.filter(pk=pk).exists():
-                raise InvalidParentsError("Job PK {pk} is not in the BalsamJob DB")
+                raise InvalidParentsError(f"Job PK {pk} is not in the BalsamJob DB")
             parents_list[i] = str(pk)
         self.parents = json.dumps(parents_list)
         self.save(update_fields=['parents'])
@@ -492,7 +497,11 @@ auto timeout retry: {self.auto_timeout_retry}
     def serialize(self, **kwargs):
         d = self.to_dict()
         d.update(kwargs)
-        d['job_id'] = str(self.job_id)
+        if type(self.job_id) == uuid.UUID:
+            d['job_id'] = str(self.job_id)
+        else:
+            assert self.job_id == d['job_id'] == None
         serial_data = json.dumps(d)
         return serial_data
......
from mpi4py import MPI
import balsam.launcher.dag as dag
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
job_name = f"hello{rank}"
dag.add_job(name=job_name, workflow="test", application="hello", num_nodes=1,
ranks_per_node=1)
print(f"Rank {rank} added job: success")
@@ -92,9 +92,9 @@ class TestSingleJobTransitions(BalsamTestCase):
         self.apps = {}
         for name in aliases:
             interpreter = sys.executable
-            exe_path = interpreter + " " + find_spec(f'tests.ft_apps.{name}').origin
-            pre_path = interpreter + " " + find_spec(f'tests.ft_apps.{name}_pre').origin
-            post_path = interpreter + " " + find_spec(f'tests.ft_apps.{name}_post').origin
+            exe_path = interpreter + " " + find_spec(f'tests.ft_apps.dag.{name}').origin
+            pre_path = interpreter + " " + find_spec(f'tests.ft_apps.dag.{name}_pre').origin
+            post_path = interpreter + " " + find_spec(f'tests.ft_apps.dag.{name}_post').origin
             app = create_app(name=name, executable=exe_path, preproc=pre_path,
                              postproc=post_path)
             self.apps[name] = app
@@ -347,9 +347,9 @@ class TestDAG(BalsamTestCase):
         self.apps = {}
         for name in aliases:
             interpreter = sys.executable
-            exe_path = interpreter + " " + find_spec(f'tests.ft_apps.{name}').origin
-            pre_path = interpreter + " " + find_spec(f'tests.ft_apps.{name}_pre').origin
-            post_path = interpreter + " " + find_spec(f'tests.ft_apps.{name}_post').origin
+            exe_path = interpreter + " " + find_spec(f'tests.ft_apps.dag.{name}').origin
+            pre_path = interpreter + " " + find_spec(f'tests.ft_apps.dag.{name}_pre').origin
+            post_path = interpreter + " " + find_spec(f'tests.ft_apps.dag.{name}_post').origin
             app = create_app(name=name, executable=exe_path, preproc=pre_path,
                              postproc=post_path)
             self.apps[name] = app
@@ -864,11 +864,40 @@ class TestThreadPlacement(BalsamTestCase):
         self.check_omp_exe_output(self.job1)
         self.check_omp_exe_output(self.job2)

-#class TestUserKill(BalsamTestCase):
-#    def setUp(self):
-#        self.app_path = find_spec("tests.ft_apps.c_apps").origin)
-#        self.app = create_app(name='omp')
-#
-#        self.job0 = create_job(name='job0', app='omp', num_nodes=2, ranks_per_node=32, threads_per_rank=2)
-#        self.job1 = create_job(name='job1', app='omp', num_nodes=2, ranks_per_node=64, threads_per_rank=1)
-#        self.job2 = create_job(name='job2', app='omp', num_nodes=1, ranks_per_node=2, threads_per_rank=64, threads_per_core=2)
+class TestConcurrentDB(BalsamTestCase):
+    def setUp(self):
+        from balsam.service.schedulers import Scheduler
+        scheduler = Scheduler.scheduler_main
+        if scheduler.num_workers:
+            self.num_nodes = scheduler.num_workers
+        else:
+            self.num_nodes = 1
+
+        hello_path = find_spec("tests.ft_apps.concurrent.hello").origin
+        insert_path = find_spec("tests.ft_apps.concurrent.mpi_insert").origin
+        interpreter = sys.executable
+        hello_path = f"{sys.executable} {hello_path}"
+        insert_path = f"{sys.executable} {insert_path}"
+        create_app(name="hello", executable=hello_path)
+        create_app(name="mpi4py-insert", executable=insert_path)
+
+    def test_many_write(self):
+        '''Many ranks can simultaneously add a job to the DB'''
+        job = create_job(name="mpi_insert", app='mpi4py-insert',
+                         num_nodes=self.num_nodes, ranks_per_node=16)
+        num_ranks = job.num_ranks
+        success = run_launcher_until_state(job, 'JOB_FINISHED', period=2)
+        self.assertTrue(success)
+        created_jobs = BalsamJob.objects.filter(name__icontains='hello')
+        self.assertEqual(created_jobs.count(), num_ranks)
+
+class TestUserKill(BalsamTestCase):
+    def setUp(self):
+        self.app_path = find_spec("tests.ft_apps.dynamic_kill.killer").origin
+        self.app_path = find_spec("tests.ft_apps.dynamic_kill.slow").origin
+
+    def test_kill_during_preprocess(self):
+        self.skipTest("implement me!")
+
+    def test_kill_during_execution(self):
+        self.skipTest("implement me!")