Commit 3cedc3d2 authored by Michael Salim's avatar Michael Salim
Browse files

working on launcher

parent e40262db
class BalsamLauncherError(Exception):
    '''Base error for the launcher service.'''
    pass

class BalsamRunnerError(Exception):
    '''Base error raised by Runner / RunnerGroup.'''
    pass

class ExceededMaxConcurrentRunners(BalsamRunnerError):
    # was: BalsamRunnerException (undefined name -> NameError at import time)
    pass

class NoAvailableWorkers(BalsamRunnerError):
    # was: BalsamRunnerException (undefined name)
    pass

class BalsamTransitionError(Exception):
    pass

class TransitionNotFoundError(BalsamTransitionError, ValueError):
    # was: BalsamTransitionException (undefined name)
    pass

# runners.py and launcher.py raise/catch the shorter spelling
# 'ExceededMaxRunners'; alias it so both names resolve.
ExceededMaxRunners = ExceededMaxConcurrentRunners
import balsam.models
from balsam.models import BalsamJob
class JobReader(dict):
    '''Interface with BalsamJob DB & pull relevant jobs.

    Maps job pk -> BalsamJob instance; subclasses implement _get_jobs().
    '''

    @staticmethod
    def from_config(config):
        '''Factory: pick the reader type from the launcher config.

        Declared without self in the original and invoked on the class
        (JobReader.from_config(args)), so it must be a staticmethod.
        '''
        if config.job_file:
            return FileJobReader(config.job_file)
        else:
            return WFJobReader(config.wf_name)

    def by_states(self):
        '''Return a dict of job lists keyed by state.'''
        result = defaultdict(list)
        for job in self.values():
            # NOTE(review): the loop body was lost in this view; grouping by
            # job.state is the only reading consistent with the docstring -- confirm
            result[job.state].append(job)
        return result

    def jobs(self):
        return self.values()

    def pks(self):
        return self.keys()

    def refresh_from_db(self):
        '''Caller invokes this to read from DB.'''
        jobs = self._get_jobs()
        jobs = self._filter(jobs)
        # NOTE(review): the dict-comp key was garbled in this view; keys must
        # match job_id/pk (see _filter's job_id__in=self.keys()) -- TODO confirm
        job_dict = {job.pk: job for job in jobs}
        self.update(job_dict)

    def _get_jobs(self):
        # Subclass hook: return a BalsamJob queryset
        raise NotImplementedError

    def _filter(self, job_queryset):
        '''Restrict to non-terminal, locally-runnable jobs not already tracked.'''
        jobs = job_queryset.exclude(state__in=balsam.models.END_STATES)
        jobs = jobs.filter(allowed_work_sites__icontains=settings.BALSAM_SITE)
        jobs = jobs.exclude(job_id__in=self.keys())
        # job.idle() hits the scheduler; only done for jobs that passed the
        # cheap DB filters above
        return [j for j in jobs if j.idle()]
class FileJobReader(JobReader):
'''Limit to job PKs specified in a file'''
def __init__(self, job_file):
self.job_file = job_file
self.pk_list = None
def _get_jobs(self):
if self.pk_list is None:
pk_strings = open(self.job_file).read().split()
self.pk_list = [uuid.UUID(pk) for pk in pk_strings]
jobs = BalsamJob.objects.filter(job_id__in=self.pk_list)
return jobs
class WFJobReader(JobReader):
    '''Consume all jobs from DB, optionally matching a Workflow name.'''

    def __init__(self, wf_name):
        self.wf_name = wf_name

    def _get_jobs(self):
        '''Queryset for this workflow, or every job when wf_name is falsy.'''
        manager = BalsamJob.objects
        if self.wf_name:
            return manager.filter(workflow=self.wf_name)
        return manager.all()
......@@ -2,128 +2,42 @@
scheduling service and submits directly to a local job queue, or by the
Balsam service metascheduler'''
import argparse
from collections import defaultdict
import os
import multiprocessing
import queue
import time
from django.conf import settings
from django.db import transaction
import balsam.models
from balsam.models import BalsamJob
from balsam import scheduler
from balsam.launcher import jobreader
from balsam.launcher import transitions
from balsam.launcher import worker
from balsam.launcher.exceptions import *
START_TIME = time.time() + 10.0
class BalsamLauncherException(Exception): pass
class Worker:
def __init__(self, id, *, shape=None, block=None, corner=None,
ranks_per_worker=None): = id
self.shape = shape
self.block = block
self.corner = corner
self.ranks_per_worker = ranks_per_worker
self.idle = True
class WorkerGroup:
    '''Build the list of Worker objects for this host type / partition.'''

    def __init__(self, config):
        self.host_type = config.host_type
        self.partition = config.partition
        self.workers = []
        # Dispatch to setup_CRAY / setup_BGQ / setup_DEFAULT by host type
        self.setup = getattr(self, f"setup_{self.host_type}")
        if self.host_type == 'DEFAULT':
            self.num_workers = config.num_workers
        else:
            # NOTE(review): else-branch was lost in this view -- TODO confirm
            self.num_workers = None
        # NOTE(review): the setup() call site was elided in this view -- TODO confirm
        self.setup()

    def setup_CRAY(self):
        '''Parse a Cray node list like "1-5,7,9-12" into one Worker per node id.'''
        node_ids = []
        ranges = self.partition.split(',')
        for node_range in ranges:
            lo, *hi = node_range.split('-')
            lo = int(lo)
            if hi:
                hi = int(hi[0])
                node_ids.extend(list(range(lo, hi+1)))
            else:
                # single node id with no '-' (branch missing from this view)
                node_ids.append(lo)
        for id in node_ids:
            self.workers.append(Worker(id))

    def setup_BGQ(self):
        # Boot blocks
        # Get (block, corner, shape) args for each sub-block
        # NOTE(review): implementation missing from this view
        raise NotImplementedError("BGQ worker setup not yet implemented")

    def setup_DEFAULT(self):
        # NOTE(review): loop body reconstructed; Worker kwargs unknown -- TODO confirm
        for i in range(self.num_workers):
            self.workers.append(Worker(i))

    def get_idle_workers(self):
        '''Workers not currently assigned to a Runner.'''
        return [w for w in self.workers if w.idle]
signal.SIGINT: 'SIG_INT',
class JobRetriever:
'''Use the get_jobs method to pull valid jobs for this run'''
def __init__(self, config):
self.job_pk_list = None
self._job_file = config.job_file
self.wf_name = config.wf_name
self.host_type = config.host_type
def get_jobs(self):
if self._job_file:
jobs = self._jobs_from_file()
def delay(period=10.0):
    '''Generator: each next() blocks until the next `period`-second tick.

    Used as the launcher main-loop throttle: `delay_timer = delay()` then
    `next(delay_timer)` when there was no work this iteration.
    NOTE(review): sleep/yield lines were lost in this view; reconstructed
    from those call sites -- TODO confirm.
    '''
    nexttime = time.time() + period
    while True:
        now = time.time()
        tosleep = nexttime - now
        if tosleep <= 0:
            # We're already past the deadline: don't sleep, just rearm
            nexttime = now + period
        else:
            time.sleep(tosleep)
            nexttime = now + tosleep + period
        yield
def _filter(self, jobs):
jobs = jobs.exclude(state__in=balsam.models.END_STATES)
jobs = jobs.filter(allowed_work_sites__icontains=settings.BALSAM_SITE)
# Exclude jobs that are already in LauncherConfig pulled_jobs
# Otherwise, you'll be calling job.idle() and qstating too much
return [j for j in jobs if j.idle()]
def _jobs_from_file(self):
if self._job_pk_list is None:
pk_strings = open(self._job_file).read().split()
except IOError as e:
raise BalsamLauncherException(f"Can't read {self._job_file}") from e
self._job_pk_list = [uuid.UUID(pk) for pk in pk_strings]
except ValueError:
raise BalsamLauncherException(f"{self._job_file} contains bad UUID strings")
jobs = BalsamJob.objects.filter(job_id__in=self._job_file_pk_list)
except Exception as e:
raise BalsamLauncherException("Failed to query BalsamJobDB") from e
return jobs
def _jobs_from_wf(self, wf=''):
objects = BalsamJob.objects
jobs = objects.filter(workflow=wf) if wf else objects.all()
except Exception as e:
raise BalsamLauncherException(f"Failed to query BalsamJobDB for '{wf}'") from e
self._job_pk_list = [ for job in jobs]
return jobs
class LauncherConfig:
class HostEnvironment:
'''Set user- and environment-specific settings for this run'''
'BGQ' : 'vesta cetus mira'.split(),
......@@ -187,79 +101,68 @@ class LauncherConfig:
def sufficient_time(self, job):
    '''True if the job's requested walltime (minutes) fits in the
    seconds remaining in this launcher allocation.'''
    return 60 * job.wall_time_minutes < self.remaining_time_seconds()
def check_timeout(self):
    '''True when fewer than 1 second remains in the allocation.

    NOTE(review): the old signature check_timeout(self, active_runners),
    which also delivered SIGTIMEOUT to each runner, was removed in this
    revision; signalling runners is now the caller's job (see main()).
    '''
    if self.remaining_time_seconds() < 1.0:
        return True
    return False
class TransitionProcessPool:
'CREATED': check_parents,
'LAUNCHER_QUEUED': check_parents,
'AWAITING_PARENTS': check_parents,
'READY': stage_in,
'STAGED_IN': preprocess,
'RUN_DONE': postprocess,
'RUN_TIMEOUT': postprocess,
'RUN_ERROR': postprocess,
'POSTPROCESSED': stage_out
def __init__(self, num_transitions=None):
if not num_transitions:
num_transitions = settings.BALSAM_MAX_CONCURRENT_TRANSITIONS
self.job_queue = multiprocessing.Queue()
self.status_queue = multiprocessing.Queue()
self.procs = [
multiprocessing.Process( target=transitions.main,
args=(self.job_queue, self.status_queue))
for i in range(num_transitions)
for proc in self.procs:
def add_job(self, pk, transition_function):
m = transitions.JobMsg(pk, transition_function)
def get_statuses():
while not self.status_queue.empty():
yield self.status_queue.get_nowait()
except queue.Empty:
def stop_processes(self):
while not self.job_queue.empty():
except queue.Empty:
m = transitions.JobMsg('end', None)
for proc in self.procs:
def get_runnable_jobs(jobs, running_pks, host_env):
runnable_jobs = [job for job in
if not in running_pks and
job.state in RUNNABLE_STATES and
return runnable_jobs
def create_new_runners(jobs, runner_group, worker_group, host_env):
running_pks = runner_group.running_job_pks
runnable_jobs = get_runnable_jobs(jobs, running_pks, host_env)
while runnable_jobs:
runner_group.create_next_runner(runnable_jobs, worker_group)
except (ExceededMaxRunners, NoAvailableWorkers) as e:
running_pks = runner_group.running_job_pks
runnable_jobs = get_runnable_jobs(jobs, running_pks, host_env)
def main(args):
launcher_config = LauncherConfig(args)
job_retriever = JobRetriever(launcher_config)
workers = WorkerGroup(launcher_config)
transitions_pool = TransitionProcessPool()
while not launcher_config.check_timeout():
# keep a list of jobs I'm handling
# get_jobs() should only fetch new ones
# ping jobs I'm handling using job.service_ping
jobs = job_retriever.get_jobs()
host_env = HostEnvironment(args)
worker_group = worker.WorkerGroup(host_env)
jobsource = jobreader.JobReader.from_config(args)
transition_pool = transitions.TransitionProcessPool()
runner_group = runners.RunnerGroup()
delay_timer = delay()
# Main Launcher Service Loop
while not host_env.check_timeout():
wait = True
for stat in transitions_pool.get_statuses():
logger.debug(f'Transition: {} {stat.state}: {stat.msg}')
wait = False
transitionable_jobs = [
job for job in
if job not in transitions_pool
and job.state in transitions_pool.TRANSITIONS
for job in transitionable_jobs:
wait = False
any_finished = create_new_runners(, runner_group, worker_group, host_env
if any_finished: wait = False
if wait: next(delay_timer)
for runner in runner_group:
runner.timeout(SIGTIMEOUT, None)
# Maintain up to 50 active runners (1 runner tracks 1 subprocess-aprun)
# Add transitions to error_handle all the RUN_TIMEOUT jobs
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Start Balsam Job Launcher.")
......@@ -268,7 +171,7 @@ if __name__ == "__main__":
group.add_argument('--job-file', help="File of Balsam job IDs")
group.add_argument('--consume-all', action='store_true',
help="Continuously run all jobs from DB")
help="Continuously run jobs of specified workflow")
parser.add_argument('--num-workers', type=int, default=1,
......@@ -279,4 +182,6 @@ if __name__ == "__main__":
help="Override auto-detected walltime limit (runs
forever if no limit is detected or specified)")
args = parser.parse_args()
# TODO: intercept KeyboardInterrupt and all INT,TERM signals
# Cleanup actions; mark jobs as idle
'''A Runner is constructed with a list of jobs and a list of idle workers. It
creates and monitors the execution subprocess, updating job states in the DB as
necessary. RunnerGroup contains the list of Runner objects, logic for creating
the next Runner (i.e. assigning jobs to nodes), and the public interface'''
import functools
from math import ceil
import os
from pathlib import Path
import signal
......@@ -9,12 +15,13 @@ from tempfile import NamedTemporaryFile
from threading import Thread
from queue import Queue, Empty
from django.conf import settings
import balsam.models
from balsam.launcher.launcher import SIGNALS
from balsam.launcher import mpi_commands
from balsam.launcher import mpi_ensemble
class BalsamRunnerException(Exception): pass
from balsam.launcher.exceptions import *
class cd:
'''Context manager for changing cwd'''
......@@ -51,7 +58,10 @@ class MonitorStream(Thread):
class Runner:
    '''Spawns ONE subprocess to run specified job(s) and monitor their execution'''

    def __init__(self, job_list, worker_list):
        # All workers in a runner must share one host type; it selects
        # the MPI command builder (e.g. CRAYMPICommand)
        host_type = worker_list[0].host_type
        assert all(w.host_type == host_type for w in worker_list)
        self.worker_list = worker_list
        mpi_cmd_class = getattr(mpi_commands, f"{host_type}MPICommand")
        self.mpi_cmd = mpi_cmd_class()
        # Restored ' = job_list': assignment target garbled in this view = job_list
        # NOTE(review): the rest of __init__ (popen args, monitor setup) is
        # elided from this view -- TODO restore from the full file
......@@ -72,6 +82,9 @@ class Runner:
def update_jobs(self):
    '''Subclass hook: poll the launch subprocess and push job state
    updates into the BalsamJob DB.'''
    raise NotImplementedError
def finished(self):
    '''True once the launch subprocess has exited (poll() non-None).'''
    return self.process.poll() is not None
def get_app_cmd(job):
if job.application:
......@@ -91,7 +104,7 @@ class Runner:
class MPIRunner(Runner):
'''One subprocess, one job'''
def __init__(self, job_list, worker_list, host_type):
def __init__(self, job_list, worker_list):
super().__init__(job_list, worker_list)
if len( != 1:
......@@ -114,6 +127,7 @@ class MPIRunner(Runner):
def update_jobs(self):
job =[0]
#job.refresh_from_db() # TODO: handle RecordModified
retcode = self.process.poll()
if retcode == None:
curstate = 'RUNNING'
......@@ -126,11 +140,12 @@ class MPIRunner(Runner):
msg = str(retcode)
if job.state != curstate:
job.update_state(curstate, msg) # TODO: handle RecordModified
class MPIEnsembleRunner(Runner):
'''One subprocess: an ensemble of serial jobs run in an mpi4py wrapper'''
def __init__(self, job_list, worker_list, host_type):
def __init__(self, job_list, worker_list):
mpi_ensemble_exe = os.path.abspath(mpi_ensemble.__file__)
......@@ -145,7 +160,7 @@ class MPIEnsembleRunner(Runner):
with NamedTemporaryFile(prefix='mpi-ensemble', dir=root_dir,
delete=False, mode='w') as fp:
self.ensemble_filename =
for job in self.job_list:
for job in
cmd = self.get_app_cmd(job)
fp.write(f"{} {job.working_directory} {cmd}\n")
......@@ -156,10 +171,83 @@ class MPIEnsembleRunner(Runner):
self.popen_args['args'] = shlex.split(command)
def update_jobs(self):
    '''Relies on stdout of the mpi_ensemble wrapper: each line is
    "<pk> <state> <msg...>"; push each update into the BalsamJob DB.'''
    for line in self.monitor.available_lines():
        pk, state, *msg = line.split()
        msg = ' '.join(msg)
        if pk in self.jobs_by_pk and state in balsam.models.STATES:
            job = self.jobs_by_pk[pk]
            job.update_state(state, msg)  # TODO: handle RecordModified exception
        else:
            # was: BalsamRunnerException (not the imported BalsamRunnerError)
            # and referenced an undefined name `status`
            raise BalsamRunnerError(f"Invalid status update: {line}")
class RunnerGroup:
def __init__(self):
self.runners = []
def create_next_runner(runnable_jobs, workers):
'''Implements one particular strategy for choosing the next job, assuming
all jobs are either single-process or MPI-parallel. Will return the serial
ensemble job or single MPI job that occupies the largest possible number of
idle nodes'''
if len(self.runners) == MAX_CONCURRENT_RUNNERS:
raise ExceededMaxRunners(
f"Cannot have more than {MAX_CONCURRENT_RUNNERS} simultaneous runners"
idle_workers = [w for w in workers if w.idle]
nidle = len(idle_workers)
rpw = workers[0].ranks_per_worker
assert all(w.ranks_per_worker == rpw for w in idle_workers)
serial_jobs = [j for j in runnable_jobs if j.num_nodes == 1 and
j.processes_per_node == 1]
nserial = len(serial_jobs)
mpi_jobs = [j for j in runnable_jobs if 1 < j.num_nodes <= nidle or
(1==j.num_nodes<=nidle and j.processes_per_node > 1)]
largest_mpi_job = (max(mpi_jobs, key=lambda job: job.num_nodes)
if mpi_jobs else None)
if nserial >= nidle*rpw:
jobs = serial_jobs[:nidle*rpw]
assigned_workers = idle_workers
runner_class = MPIEnsembleRunner
elif largest_mpi_job and largest_mpi_job.num_nodes > nserial // rpw:
jobs = [largest_mpi_job]
assigned_workers = idle_workers[:largest_mpi_job.num_nodes]
runner_class = MPIRunner
jobs = serial_jobs
assigned_workers = idle_workers[:ceil(float(nserial)/rpw)]
runner_class = MPIEnsembleRunner
if not jobs: raise NoAvailableWorkers
runner = runner_class(jobs, assigned_workers)
for worker in assigned_workers: worker.idle = False
def update_and_remove_finished(self):
# TODO: Benchmark performance overhead; does grouping into one
# transaction save significantly?
any_finished = False
with transaction.atomic():
for runner in self.runners[:]:
if runner.finished():
any_finished = True
for worker in runner.worker_list:
worker.idle = True
return any_finished
def running_job_pks(self):
active_runners = [r for r in self.runners if not r.finished()]
return [ for runner in active_runners for j in]
......@@ -3,11 +3,77 @@ from collections import namedtuple
import logging
from django.core.exceptions import ObjectDoesNotExist
from django.conf import settings
from common import transfer
from balsam.launcher.exceptions import *
logger = logging.getLogger(__name__)
class ProcessingError(Exception):
    '''A job transition failed while being processed.'''
    pass
StatusMsg = namedtuple('Status', ['pk', 'state', 'msg'])
JobMsg = namedtuple('JobMsg', ['pk', 'transition_function'])
def main(job_queue, status_queue):
while True:
job, process_function = job_queue.get()
if job == 'end':
except BalsamTransitionError as e:
s = StatusMsg(, 'FAILED', str(e))
s = StatusMsg(, job.state, 'success')
class TransitionProcessPool:
def __init__(self):
self.job_queue = multiprocessing.Queue()
self.status_queue = multiprocessing.Queue()
self.transitions_pk_list = []
self.procs = [
multiprocessing.Process( target=main,
args=(self.job_queue, self.status_queue))
for i in range(NUM_PROC)
for proc in self.procs: proc.start()
def __contains__(self, job):
return in self.transitions_pk_list
def add_job(self, job):
if job in self: raise BalsamTransitionError("already in transition")
if job.state not in TRANSITIONS: raise TransitionNotFoundError
pk =
transition_function = TRANSITIONS[job.state]
m = JobMsg(pk, transition_function)
def get_statuses():
while not self.status_queue.empty():
stat = self.status_queue.get_nowait()
yield stat
except queue.Empty:
def stop_processes(self):