Commit e4ea2070 authored by Michael Salim's avatar Michael Salim
Browse files

Python multiprocessing.Lock: manually serialized sqlite3 access

parent f3103f61
......@@ -10,3 +10,4 @@ env
......@@ -6,7 +6,7 @@ from datetime import datetime
from socket import gethostname
import uuid
from django.core.exceptions import ValidationError
from django.core.exceptions import ValidationError,ObjectDoesNotExist
from django.conf import settings
from django.db import models
from concurrency.fields import IntegerVersionField
......@@ -17,6 +17,7 @@ logger = logging.getLogger(__name__)
class InvalidStateError(ValidationError): pass
class InvalidParentsError(ValidationError): pass
class NoApplication(Exception): pass
TIME_FMT = '%m-%d-%Y %H:%M:%S'
......@@ -166,7 +167,7 @@ class BalsamJob(models.Model):
'Number of Compute Nodes',
help_text='The number of compute nodes requested for this job.',
processes_per_node = models.IntegerField(
ranks_per_node = models.IntegerField(
'Number of Processes per Node',
help_text='The number of MPI processes per node to schedule for this job.',
......@@ -272,7 +273,7 @@ actual_runtime: {self.runtime_str()}
num_nodes: {self.num_nodes}
threads per rank: {self.threads_per_rank}
threads per core: {self.threads_per_core}
processes_per_node: {self.processes_per_node}
ranks_per_node: {self.ranks_per_node}
scheduler_id: {self.scheduler_id}
application: {self.application if self.application else
......@@ -294,10 +295,22 @@ auto timeout retry: {self.auto_timeout_retry}
parent_ids = self.get_parents_by_id()
return BalsamJob.objects.filter(job_id__in=parent_ids)
def num_ranks(self):
return self.num_nodes * self.ranks_per_node
def cute_id(self):
return f"[{ str([:8] }]"
def app_cmd(self):
if self.application:
app = ApplicationDefinition.objects.get(name=job.application)
return f"{app.executable} {app.application_args}"
return self.direct_command
def get_children(self):
return BalsamJob.objects.filter(parents__icontains=str(
......@@ -319,9 +332,10 @@ auto timeout retry: {self.auto_timeout_retry}['parents'])
def get_application(self):
if not self.application:
return None
return ApplicationDefinition.objects.get(name=self.application)
if self.application:
return ApplicationDefinition.objects.get(name=self.application)
raise NoApplication
def parse_envstring(s):
......@@ -331,8 +345,12 @@ auto timeout retry: {self.auto_timeout_retry}
return {variable:value for (variable,value) in entries}
def get_envs(self, *, timeout=False, error=False):
envs = os.environ.copy()
app = self.get_application()
#envs = os.environ.copy()
envs = {}
app = self.get_application()
except NoApplication:
app = None
if app and app.environ_vars:
app_vars = self.parse_envstring(app.environ_vars)
......@@ -352,13 +370,13 @@ auto timeout retry: {self.auto_timeout_retry}
return envs
def update_state(self, new_state, message=''):
def update_state(self, new_state, message='',using=None):
if new_state not in STATES:
raise InvalidStateError(f"{new_state} is not a job state in balsam.models")
self.state_history += history_line(new_state, message)
self.state = new_state['state', 'state_history'])['state', 'state_history'],using=using)
def get_recent_state_str(self):
return self.state_history.split("\n")[-1].strip()
import os
class cd:
'''Context manager for changing cwd'''
def __init__(self, new_path):
......@@ -7,5 +9,7 @@ class cd:
self.saved_path = os.getcwd()
def __exit__(self):
def __exit__(self, exc_type, exc_value, traceback):
if exc_type is not None:
print(exc_type, exc_value, traceback)
class BalsamLauncherError(Exception): pass
class BalsamRunnerError(Exception): pass
class ExceededMaxConcurrentRunners(BalsamRunnerError): pass
class ExceededMaxRunners(BalsamRunnerError): pass
class NoAvailableWorkers(BalsamRunnerError): pass
class BalsamTransitionError(Exception): pass
......@@ -53,6 +53,7 @@ def create_new_runners(jobs, runner_group, worker_group):
created_one = False
running_pks = runner_group.running_job_pks
runnable_jobs = get_runnable_jobs(jobs, running_pks)
logger.debug(f"Have {len(runnable_jobs)} runnable jobs")
while runnable_jobs:
runner_group.create_next_runner(runnable_jobs, worker_group)
......@@ -68,11 +69,10 @@ def main(args, transition_pool, runner_group, job_source):
delay_timer = delay()
while not scheduler.remaining_time_seconds() <= 0.0:
logger.debug("Begin launcher service loop")
logger.debug("\n************\nSERVICE LOOP\n************")
wait = True
for stat in transition_pool.get_statuses():'[{str([:8]}] transitioned to {stat.state}: {stat.msg}')
wait = False
......@@ -85,9 +85,11 @@ def main(args, transition_pool, runner_group, job_source):
for job in transitionable_jobs:
wait = False"Queued trans: {job.cute_id} in state {job.state}")
fxn = transitions.TRANSITIONS[job.state]"Queued transition: {job.cute_id} will undergo {fxn}")
any_finished = runner_group.update_and_remove_finished()
created = create_new_runners(, runner_group, worker_group)
if any_finished or created: wait = False
if wait: next(delay_timer)
......@@ -119,7 +121,8 @@ def get_args():
help="Continuously run jobs of specified workflow")
parser.add_argument('--num-workers', type=int, default=1,
help="Theta: defaults to # nodes. BGQ: the # of subblocks")
parser.add_argument('--serial-jobs-per-worker', type=int, default=4,
parser.add_argument('--nodes-per-worker', type=int, default=1)
parser.add_argument('--max-ranks-per-node', type=int, default=1,
help="For non-MPI jobs, how many to pack per worker")
parser.add_argument('--time-limit-minutes', type=int,
help="Provide a walltime limit if not already imposed")
......@@ -136,8 +139,8 @@ if __name__ == "__main__":
job_source = jobreader.JobReader.from_config(args)
runner_group = runners.RunnerGroup()
transition_pool = transitions.TransitionProcessPool()
runner_group = runners.RunnerGroup(transition_pool.lock)
worker_group = worker.WorkerGroup(args, host_type=scheduler.host_type,
class DEFAULTMPICommand(object):
'''Single node OpenMPI: ppn == num_ranks'''
def __init__(self):
self.mpi = 'mpirun'
self.nproc = '-n'
self.ppn = '-ppn'
self.env = '-env'
self.ppn = '-npernode'
self.env = '-x'
self.cpu_binding = None
self.threads_per_rank = None
self.threads_per_core = None
......@@ -11,28 +12,27 @@ class DEFAULTMPICommand(object):
def worker_str(self, workers):
return ""
def env_str(self, job):
if job.environ_vars:
return f"{self.env} {job.environ_vars}"
return ""
def env_str(self, envs):
envstrs = (f"{self.env} {var}={val}" for var,val in envs.items())
return " ".join(envstrs)
def threads(self, job):
def threads(self, thread_per_rank, thread_per_core):
result= ""
if self.cpu_binding:
result += f"{self.cpu_binding} "
if self.threads_per_rank:
result += f"{self.threads_per_rank} {job.threads_per_rank} "
result += f"{self.threads_per_rank} {thread_per_rank} "
if self.threads_per_core:
result += f"{self.threads_per_core} {job.threads_per_core} "
result += f"{self.threads_per_core} {thread_per_core} "
return result
def __call__(self, job, workers, nproc=None):
if nproc is None:
nproc = job.num_nodes * job.processes_per_node
def __call__(self, workers, *, app_cmd, num_ranks, ranks_per_node, envs,threads_per_rank=1,threads_per_core=1):
'''Build the mpirun/aprun/runjob command line string'''
workers = self.worker_str(workers)
envs = self.env_str(job)
result = (f"{self.mpi} {self.nproc} {nproc} {self.ppn} "
"{job.processes_per_node} {envs} {workers} {threads} ")
envs = self.env_str(envs)
thread_str = self.threads(threads_per_rank, threads_per_core)
result = (f"{self.mpi} {self.nproc} {num_ranks} {self.ppn} "
f"{num_ranks} {envs} {workers} {thread_str} {app_cmd}")
return result
......@@ -46,15 +46,6 @@ class BGQMPICommand(DEFAULTMPICommand):
self.threads_per_rank = None
self.threads_per_core = None
def env_str(self, job):
if not job.environ_vars:
return ""
envs = job.environ_vars.split(':')
result = ""
for env in envs:
result += f"{self.env} {env} "
return result
def worker_str(self, workers):
if len(workers) != 1:
raise BalsamRunnerException("BGQ requires exactly 1 worker (sub-block)")
......@@ -68,20 +59,11 @@ class CRAYMPICommand(DEFAULTMPICommand):
self.mpi = 'aprun'
self.nproc = '-n'
self.ppn = '-N'
self.env = '-e' # VAR1=val1:VAR2=val2
self.env = '-e'
self.cpu_binding = '-cc depth'
self.threads_per_rank = '-d'
self.threads_per_core = '-j'
def env_str(self, job):
if not job.environ_vars:
return ""
envs = job.environ_vars.split(':')
result = ""
for env in envs:
result += f"{self.env} {env} "
return result
def worker_str(self, workers):
if not workers:
return ""
from collections import namedtuple
import os
import sys
import logging
import django
os.environ['DJANGO_SETTINGS_MODULE'] = 'argobalsam.settings'
logger = logging.getLogger('balsamlauncher.mpi_ensemble')
from subprocess import Popen, STDOUT
from mpi4py import MPI
from import cd
from balsamlauncher.exceptions import *
......@@ -21,15 +27,20 @@ def read_jobs(fp):
for line in fp:
id, workdir, *command = line.split()
logger.debug(f"Read Job {id} CMD: {command} DIR: {workdir}")
logger.debug("Invalid jobline")
if id and command and os.path.isdir(workdir):
yield Job(id, workdir, command)
logger.debug("Invalid workdir")
def run(job):
basename = os.path.basename(job.workdir)
outname = f"{basename}.out"
logger.debug(f"Running job {}")
with cd(job.workdir) as _, open(outname, 'wb') as outf:
status_msg(, "RUNNING", msg="executing from mpi_ensemble")
......@@ -39,10 +50,10 @@ def run(job):
status_msg(, "FAILED", msg=str(e))
raise MPIEnsembleError from e
if retcode == 0: status_msg(, "RUN_FINISHED")
if retcode == 0: status_msg(, "RUN_DONE")
else: status_msg(, "RUN_ERROR", msg=f"process return code {retcode}")
def main(jobs_path):
......@@ -53,8 +64,10 @@ def main(jobs_path):
job_list = list(read_jobs(fp))
job_list = COMM.bcast(job_list, root=0)
logger.debug(f"Broadcasted job list. Total {len(job_list)} jobs to run")
for job in job_list[RANK::COMM.size]: run(job)
if __name__ == "__main__":
path = sys.argv[1]
logger.debug(f"Starting Reading jobs from {path}")
......@@ -57,7 +57,7 @@ class Runner:
mpi_cmd_class = getattr(mpi_commands, f"{host_type}MPICommand")
self.mpi_cmd = mpi_cmd_class() = job_list
self.jobs_by_pk = { : job for job in}
self.jobs_by_pk = {str( : job for job in}
self.process = None
self.monitor = None
self.outfile = None
......@@ -75,19 +75,10 @@ class Runner:
def finished(self):
return self.process.poll() is not None
def get_app_cmd(job):
if job.application:
app = ApplicationDefinition.objects.get(name=job.application)
return f"{app.executable} {app.application_args}"
return job.direct_command
def timeout(self):
with transaction.atomic():
for job in
if job.state == 'RUNNING': job.update_state('RUN_TIMEOUT')
for job in
if job.state == 'RUNNING': job.update_state('RUN_TIMEOUT')
class MPIRunner(Runner):
'''One subprocess, one job'''
......@@ -95,18 +86,24 @@ class MPIRunner(Runner):
super().__init__(job_list, worker_list)
if len( != 1:
raise BalsamRunnerException('MPIRunner must take exactly 1 job')
raise BalsamRunnerError('MPIRunner must take exactly 1 job')
job =[0]
app_cmd = self.get_app_cmd(job)
mpi_str = self.mpi_cmd(job, worker_list)
envs = job.get_envs() # dict
app_cmd = job.app_cmd
nranks = job.num_ranks
rpn = job.ranks_per_node
tpr = job.threads_per_rank
tpc = job.threads_per_core
mpi_str = self.mpi_cmd(worker_list, app_cmd=app_cmd, envs=envs,
num_ranks=nranks, ranks_per_node=rpn,
threads_per_rank=tpr, threads_per_core=tpc)
basename = os.path.basename(job.working_directory)
outname = os.path.join(job.working_directory, f"{basename}.out")
self.outfile = open(outname, 'w+b')
command = f"{mpi_str} {app_cmd}"
self.popen_args['args'] = shlex.split(command)
self.popen_args['args'] = shlex.split(mpi_str)
self.popen_args['cwd'] = job.working_directory
self.popen_args['stdout'] = self.outfile
self.popen_args['stderr'] = STDOUT
......@@ -117,12 +114,15 @@ class MPIRunner(Runner):
#job.refresh_from_db() # TODO: handle RecordModified
retcode = self.process.poll()
if retcode == None:
logger.debug(f"Job {job.cute_id} still running")
curstate = 'RUNNING'
msg = ''
elif retcode == 0:
curstate = 'RUN_FINISHED'
logger.debug(f"Job {job.cute_id} return code 0: done")
curstate = 'RUN_DONE'
msg = ''
logger.debug(f"Job {job.cute_id} return code!=0: error")
curstate = 'RUN_ERROR'
msg = str(retcode)
if job.state != curstate: job.update_state(curstate, msg) # TODO: handle RecordModified
......@@ -142,83 +142,120 @@ class MPIEnsembleRunner(Runner):
self.popen_args['stderr'] = STDOUT
self.popen_args['cwd'] = root_dir
# reads jobs from this temp file
with NamedTemporaryFile(prefix='mpi-ensemble', dir=root_dir,
delete=False, mode='w') as fp:
self.ensemble_filename =
ensemble_filename = os.path.abspath(
for job in
cmd = self.get_app_cmd(job)
cmd = job.app_cmd
fp.write(f"{} {job.working_directory} {cmd}\n")
nproc = sum(w.ranks_per_worker for w in worker_list)
mpi_str = self.mpi_cmd([0], worker_list, nproc=nproc)
rpn = worker_list[0].max_ranks_per_node
nranks = sum(w.num_nodes*rpn for w in worker_list)
envs =[0].get_envs() # TODO: different envs for each job
app_cmd = f"{sys.executable} {mpi_ensemble_exe} {ensemble_filename}"
mpi_str = self.mpi_cmd(worker_list, app_cmd=app_cmd, envs=envs,
num_ranks=nranks, ranks_per_node=rpn)
command = f"{mpi_str} {mpi_ensemble_exe} {self.ensemble_filename}"
self.popen_args['args'] = shlex.split(command)
self.popen_args['args'] = shlex.split(mpi_str)
logger.debug(f"MPI Ensemble Popen args: {self.popen_args['args']}")
def update_jobs(self):
'''Relies on stdout of'''
retcode = self.process.poll()
if retcode not in [None, 0]:
msg = " had nonzero return code:\n"
msg += "".join(self.monitor.available_lines())
raise RuntimeError(msg)
logger.debug("Checking mpi_ensemble stdout for status updates...")
for line in self.monitor.available_lines():
logger.debug(f"Monitor stdout line: {line.strip()}")
pk, state, *msg = line.split()
msg = ' '.join(msg)
if pk in self.jobs_by_pk and state in balsam.models.STATES:
job = self.jobs_by_pk[pk]
job.update_state(state, msg) # TODO: handle RecordModified exception
logger.debug(f"MPIEnsemble job {job.cute_id} updated to {state}")
raise BalsamRunnerException(f"Invalid status update: {status}")
logger.error(f"Invalid status update: {line.strip()}")
class RunnerGroup:
def __init__(self):
def __init__(self, lock):
self.runners = []
self.lock = lock
def __iter__(self):
return iter(self.runners)
def create_next_runner(runnable_jobs, workers):
def create_next_runner(self, runnable_jobs, workers):
'''Implements one particular strategy for choosing the next job, assuming
all jobs are either single-process or MPI-parallel. Will return the serial
ensemble job or single MPI job that occupies the largest possible number of
idle nodes'''
if len(self.runners) == MAX_CONCURRENT_RUNNERS:
if len(self.runners) == self.MAX_CONCURRENT_RUNNERS:"Cannot create another runner: at max")
raise ExceededMaxRunners(
f"Cannot have more than {MAX_CONCURRENT_RUNNERS} simultaneous runners"
f"Cannot have more than {self.MAX_CONCURRENT_RUNNERS} simultaneous runners"
idle_workers = [w for w in workers if w.idle]
nidle = len(idle_workers)
rpw = workers[0].ranks_per_worker
assert all(w.ranks_per_worker == rpw for w in idle_workers)"{nidle} idle workers; {len(runnable_jobs)} runnable jobs")
serial_jobs = [j for j in runnable_jobs
if j.num_nodes == 1 and j.processes_per_node == 1]
nidle_workers = len(idle_workers)
nodes_per_worker = workers[0].num_nodes
rpn = workers[0].max_ranks_per_node
assert all(w.num_nodes == nodes_per_worker for w in idle_workers)
assert all(w.max_ranks_per_node == rpn for w in idle_workers)"Creating next runner: {nidle_workers} idle workers with "
f"{nodes_per_worker} nodes per worker; {len(runnable_jobs)} runnable jobs")
nidle_nodes = nidle_workers * nodes_per_worker
nidle_ranks = nidle_nodes * rpn
serial_jobs = [j for j in runnable_jobs if j.num_ranks == 1]
nserial = len(serial_jobs)
logger.debug(f"{nserial} single-process jobs can run")
mpi_jobs = [j for j in runnable_jobs if 1 < j.num_nodes <= nidle or
(1==j.num_nodes<=nidle and j.processes_per_node > 1)]
mpi_jobs = [j for j in runnable_jobs if 1 < j.num_nodes <= nidle_nodes or
(1==j.num_nodes<=nidle_nodes and j.ranks_per_node > 1)]
largest_mpi_job = (max(mpi_jobs, key=lambda job: job.num_nodes)
if mpi_jobs else None)
if largest_mpi_job:
logger.debug(f"{len(mpi_jobs)} MPI jobs can run; largest takes "
f"{largest_mpi_job.num_nodes} nodes")
logger.debug("No MPI jobs can run")
# Try to fill all available nodes with serial ensemble runner
# If there are not enough serial jobs; run the larger of:
# largest MPI job that fits, or the remaining serial jobs
if nserial >= nidle*rpw:
jobs = serial_jobs[:nidle*rpw]
if nserial >= nidle_ranks:
jobs = serial_jobs[:nidle_ranks]
assigned_workers = idle_workers
runner_class = MPIEnsembleRunner"Running {len(jobs)} serial jobs on {nidle_workers} workers "
f"with {nodes_per_worker} nodes-per-worker and {rpn} ranks per node")
elif largest_mpi_job and largest_mpi_job.num_nodes > nserial // rpw:
jobs = [largest_mpi_job]
assigned_workers = idle_workers[:largest_mpi_job.num_nodes]
num_workers = ceil(largest_mpi_job.num_nodes / nodes_per_worker)
assigned_workers = idle_workers[:num_workers]
runner_class = MPIRunner"Running {largest_mpi_job.num_nodes}-node MPI job")
jobs = serial_jobs
assigned_workers = idle_workers[:ceil(float(nserial)/rpw)]
nworkers = ceil(nserial/rpn/nodes_per_worker)
assigned_workers = idle_workers[:nworkers]
runner_class = MPIEnsembleRunner"Running {len(jobs)} serial jobs on {nworkers} workers "
f"totalling {nworkers*nodes_per_worker} nodes "
f"with {rpn} ranks per worker")
if not jobs: raise NoAvailableWorkers
if not jobs:"Not enough idle workers to handle the runnable jobs")
raise NoAvailableWorkers
runner = runner_class(jobs, assigned_workers)
......@@ -228,12 +265,20 @@ class RunnerGroup:
def update_and_remove_finished(self):
# TODO: Benchmark performance overhead; does grouping into one
# transaction save significantly?
logger.debug(f"Checking status of {len(self.runners)} active runners")
any_finished = False
with transaction.atomic():
for runner in self.runners: runner.update_jobs()
for runner in self.runners: runner.update_jobs()
for runner in self.runners[:]:
if runner.finished():
for job in
if job.state not in 'RUN_DONE RUN_ERROR RUN_TIMEOUT'.split():
msg = (f"Job {job.cute_id} runner process done, but "
"failed to update job state.")
raise RuntimeError(msg)
any_finished = True
for worker in runner.worker_list:
......@@ -12,9 +12,9 @@ from django.core.exceptions import ObjectDoesNotExist
from django.conf import settings
from django import db
from common import transfer
from common import transfer
from balsamlauncher.exceptions import *
from balsam.models import BalsamJob
from balsam.models import BalsamJob, NoApplication
import logging
logger = logging.getLogger('balsamlauncher.transitions')
......@@ -26,21 +26,25 @@ StatusMsg = namedtuple('StatusMsg', ['pk', 'state', 'msg'])
JobMsg = namedtuple('JobMsg', ['job', 'transition_function'])
def main(job_queue, status_queue):