Commit 3cedc3d2 authored by Michael Salim

working on launcher

parent e40262db
class BalsamLauncherError(Exception): pass
class BalsamRunnerError(Exception): pass
class ExceededMaxRunners(BalsamRunnerError): pass
class NoAvailableWorkers(BalsamRunnerError): pass
class BalsamTransitionError(Exception): pass
class TransitionNotFoundError(BalsamTransitionError, ValueError): pass
import uuid
from collections import defaultdict
from django.conf import settings
import balsam.models
from balsam.models import BalsamJob
class JobReader(dict):
'''Interface with BalsamJob DB & pull relevant jobs'''
@staticmethod
def from_config(config):
'''Constructor'''
if config.job_file: return FileJobReader(config.job_file)
else: return WFJobReader(config.wf_name)
@property
def by_states(self):
'''dict of jobs keyed by state'''
result = defaultdict(list)
for job in self.values():
result[job.state].append(job)
return result
@property
def jobs(self): return self.values()
@property
def pks(self): return self.keys()
def refresh_from_db(self):
'''caller invokes this to read from DB'''
jobs = self._get_jobs()
jobs = self._filter(jobs)
job_dict = {job.pk : job for job in jobs}
self.update(job_dict)
def _get_jobs(self): raise NotImplementedError
def _filter(self, job_queryset):
jobs = job_queryset.exclude(state__in=balsam.models.END_STATES)
jobs = jobs.filter(allowed_work_sites__icontains=settings.BALSAM_SITE)
jobs = jobs.exclude(job_id__in=self.keys())
return [j for j in jobs if j.idle()]
class FileJobReader(JobReader):
'''Limit to job PKs specified in a file'''
def __init__(self, job_file):
super().__init__()
self.job_file = job_file
self.pk_list = None
def _get_jobs(self):
if self.pk_list is None:
pk_strings = open(self.job_file).read().split()
self.pk_list = [uuid.UUID(pk) for pk in pk_strings]
jobs = BalsamJob.objects.filter(job_id__in=self.pk_list)
return jobs
class WFJobReader(JobReader):
'''Consume all jobs from DB, optionally matching a Workflow name'''
def __init__(self, wf_name):
super().__init__()
self.wf_name = wf_name
def _get_jobs(self):
objects = BalsamJob.objects
wf = self.wf_name
jobs = objects.filter(workflow=wf) if wf else objects.all()
return jobs
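# Usage sketch (not part of this commit): the launcher is expected to build a reader
# from its parsed arguments and poll it each service iteration. The attribute names
# below follow from_config(); the workflow name is illustrative and a configured
# Django database is assumed.
from argparse import Namespace
config = Namespace(job_file=None, wf_name='test_wf')
source = JobReader.from_config(config)      # -> WFJobReader, since no job_file is given
source.refresh_from_db()                    # pull idle, unfinished jobs for this Balsam site
for state, jobs in source.by_states.items():
    print(state, [job.pk for job in jobs])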
@@ -2,128 +2,42 @@
scheduling service and submits directly to a local job queue, or by the
Balsam service metascheduler'''
import argparse
from collections import defaultdict
import os
import logging
import multiprocessing
import queue
import signal
import time
from django.conf import settings
from django.db import transaction
import balsam.models
from balsam.models import BalsamJob
from balsam import scheduler
from balsam.launcher import jobreader
from balsam.launcher import transitions
from balsam.launcher import worker
from balsam.launcher import runners
from balsam.launcher.exceptions import *
logger = logging.getLogger(__name__)
START_TIME = time.time() + 10.0
class BalsamLauncherException(Exception): pass
class Worker:
def __init__(self, id, *, shape=None, block=None, corner=None,
ranks_per_worker=None):
self.id = id
self.shape = shape
self.block = block
self.corner = corner
self.ranks_per_worker = ranks_per_worker
self.idle = True
class WorkerGroup:
def __init__(self, config):
self.host_type = config.host_type
self.partition = config.partition
self.workers = []
self.setup = getattr(self, f"setup_{self.host_type}")
if self.host_type == 'DEFAULT':
self.num_workers = config.num_workers
else:
self.num_workers = None
self.setup()
def setup_CRAY(self):
node_ids = []
ranges = self.partition.split(',')
for node_range in ranges:
lo, *hi = node_range.split('-')
lo = int(lo)
if hi:
hi = int(hi[0])
node_ids.extend(list(range(lo, hi+1)))
else:
node_ids.append(lo)
for id in node_ids:
self.workers.append(Worker(id))
def setup_BGQ(self):
# Boot blocks
# Get (block, corner, shape) args for each sub-block
pass
def setup_DEFAULT(self):
for i in range(self.num_workers):
self.workers.append(Worker(i))
def get_idle_workers(self):
return [w for w in self.workers if w.idle]
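# Illustration of the CRAY partition parsing above (hypothetical node list): the
# partition string is a comma-separated list of node ids with optional ranges.
# A standalone sketch of the same expansion logic:
def expand_node_ranges(partition):
    node_ids = []
    for node_range in partition.split(','):
        lo, *hi = node_range.split('-')
        lo = int(lo)
        if hi:
            node_ids.extend(range(lo, int(hi[0]) + 1))
        else:
            node_ids.append(lo)
    return node_ids
assert expand_node_ranges("12-15,17") == [12, 13, 14, 15, 17]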
SIGTIMEOUT = 'TIMEOUT!'
SIGTIMEOUT = 'TIMEOUT'
SIGNALS = {
signal.SIGINT: 'SIG_INT',
signal.SIGTERM: 'SIG_TERM',
}
class JobRetriever:
'''Use the get_jobs method to pull valid jobs for this run'''
def __init__(self, config):
self._job_pk_list = None
self._job_file = config.job_file
self.wf_name = config.wf_name
self.host_type = config.host_type
    def get_jobs(self):
        if self._job_file:
            jobs = self._jobs_from_file()
        else:
            jobs = self._jobs_from_wf(wf=self.wf_name)
        return self._filter(jobs)
def delay(period=10.0):
    '''Generator: each next() blocks until the next `period`-second tick'''
    nexttime = time.time() + period
    while True:
        now = time.time()
        tosleep = nexttime - now
        if tosleep <= 0:
            nexttime = now + period
        else:
            time.sleep(tosleep)
            nexttime = now + tosleep + period
        yield
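# Usage sketch for the delay() generator above: the service loop calls next() only on
# iterations that did no work, so idle passes sleep up to `period` seconds while busy
# passes poll again immediately. (The 2-second period here is illustrative.)
idle_timer = delay(period=2.0)
for _ in range(3):
    did_work = False            # hypothetical flag standing in for the loop's `wait` logic
    if not did_work:
        next(idle_timer)        # blocks until the next tick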
def _filter(self, jobs):
jobs = jobs.exclude(state__in=balsam.models.END_STATES)
jobs = jobs.filter(allowed_work_sites__icontains=settings.BALSAM_SITE)
# Exclude jobs that are already in LauncherConfig pulled_jobs
# Otherwise, you'll be calling job.idle() and qstating too much
return [j for j in jobs if j.idle()]
def _jobs_from_file(self):
if self._job_pk_list is None:
try:
pk_strings = open(self._job_file).read().split()
except IOError as e:
raise BalsamLauncherException(f"Can't read {self._job_file}") from e
try:
self._job_pk_list = [uuid.UUID(pk) for pk in pk_strings]
except ValueError:
raise BalsamLauncherException(f"{self._job_file} contains bad UUID strings")
try:
jobs = BalsamJob.objects.filter(job_id__in=self._job_pk_list)
except Exception as e:
raise BalsamLauncherException("Failed to query BalsamJobDB") from e
else:
return jobs
def _jobs_from_wf(self, wf=''):
objects = BalsamJob.objects
try:
jobs = objects.filter(workflow=wf) if wf else objects.all()
except Exception as e:
raise BalsamLauncherException(f"Failed to query BalsamJobDB for '{wf}'") from e
else:
self._job_pk_list = [job.pk for job in jobs]
return jobs
class LauncherConfig:
class HostEnvironment:
'''Set user- and environment-specific settings for this run'''
RECOGNIZED_HOSTS = {
'BGQ' : 'vesta cetus mira'.split(),
@@ -187,79 +101,68 @@ class LauncherConfig:
def sufficient_time(self, job):
return 60*job.wall_time_minutes < self.remaining_time_seconds()
def check_timeout(self, active_runners):
def check_timeout(self):
if self.remaining_time_seconds() < 1.0:
for runner in active_runners:
runner.timeout(SIGTIMEOUT, None)
return True
return False
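# Worked example for sufficient_time(): a job requesting 30 wall-clock minutes needs
# 60*30 = 1800 seconds; with only 1000 seconds left in the allocation it is skipped.
# (Numbers are illustrative.)
wall_time_minutes, remaining = 30, 1000.0
assert not (60 * wall_time_minutes < remaining)   # 1800 < 1000 is False: not enough time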
class TransitionProcessPool:
TRANSITIONS = {
'CREATED': check_parents,
'LAUNCHER_QUEUED': check_parents,
'AWAITING_PARENTS': check_parents,
'READY': stage_in,
'STAGED_IN': preprocess,
'RUN_DONE': postprocess,
'RUN_TIMEOUT': postprocess,
'RUN_ERROR': postprocess,
'POSTPROCESSED': stage_out
}
def __init__(self, num_transitions=None):
if not num_transitions:
num_transitions = settings.BALSAM_MAX_CONCURRENT_TRANSITIONS
self.job_queue = multiprocessing.Queue()
self.status_queue = multiprocessing.Queue()
self.procs = [
multiprocessing.Process( target=transitions.main,
args=(self.job_queue, self.status_queue))
for i in range(num_transitions)
]
for proc in self.procs:
proc.start()
def add_job(self, pk, transition_function):
m = transitions.JobMsg(pk, transition_function)
self.job_queue.put(m)
def get_statuses(self):
while not self.status_queue.empty():
try:
yield self.status_queue.get_nowait()
except queue.Empty:
break
def stop_processes(self):
while not self.job_queue.empty():
try:
self.job_queue.get_nowait()
except queue.Empty:
break
m = transitions.JobMsg('end', None)
for proc in self.procs:
self.job_queue.put(m)
def get_runnable_jobs(jobs, running_pks, host_env):
runnable_jobs = [job for job in jobs
if job.pk not in running_pks and
job.state in RUNNABLE_STATES and
host_env.sufficient_time(job)]
return runnable_jobs
def create_new_runners(jobs, runner_group, worker_group, host_env):
running_pks = runner_group.running_job_pks
runnable_jobs = get_runnable_jobs(jobs, running_pks, host_env)
while runnable_jobs:
try:
runner_group.create_next_runner(runnable_jobs, worker_group)
except (ExceededMaxRunners, NoAvailableWorkers) as e:
break
else:
running_pks = runner_group.running_job_pks
runnable_jobs = get_runnable_jobs(jobs, running_pks, host_env)
def main(args):
launcher_config = LauncherConfig(args)
job_retriever = JobRetriever(launcher_config)
workers = WorkerGroup(launcher_config)
transitions_pool = TransitionProcessPool()
while not launcher_config.check_timeout():
# keep a list of jobs I'm handling
# get_jobs() should only fetch new ones
# ping jobs I'm handling using job.service_ping
jobs = job_retriever.get_jobs()
host_env = HostEnvironment(args)
worker_group = worker.WorkerGroup(host_env)
jobsource = jobreader.JobReader.from_config(args)
transitions_pool = transitions.TransitionProcessPool()
runner_group = runners.RunnerGroup()
delay_timer = delay()
# Main Launcher Service Loop
while not host_env.check_timeout():
wait = True
for stat in transitions_pool.get_statuses():
logger.debug(f'Transition: {stat.pk} {stat.state}: {stat.msg}')
wait = False
jobsource.refresh_from_db()
transitionable_jobs = [
job for job in jobsource.jobs
if job not in transitions_pool
and job.state in transitions_pool.TRANSITIONS
]
for job in transitionable_jobs:
transitions_pool.add_job(job)
wait = False
any_finished = runner_group.update_and_remove_finished()
create_new_runners(jobsource.jobs, runner_group, worker_group, host_env)
if any_finished: wait = False
if wait: next(delay_timer)
transitions_pool.stop_processes()
for runner in runner_group.runners:
runner.timeout(SIGTIMEOUT, None)
# Maintain up to 50 active runners (1 runner tracks 1 subprocess-aprun)
# Add transitions to error_handle all the RUN_TIMEOUT jobs
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Start Balsam Job Launcher.")
@@ -268,7 +171,7 @@ if __name__ == "__main__":
group.add_argument('--job-file', help="File of Balsam job IDs")
group.add_argument('--consume-all', action='store_true',
help="Continuously run all jobs from DB")
group.add_argument('--consume-wf',
group.add_argument('--wf-name',
help="Continuously run jobs of specified workflow")
parser.add_argument('--num-workers', type=int, default=1,
@@ -279,4 +182,6 @@
help="Override auto-detected walltime limit (runs
forever if no limit is detected or specified)")
args = parser.parse_args()
# TODO: intercept KeyboardInterrupt and all INT,TERM signals
# Cleanup actions; mark jobs as idle
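# One possible sketch for the TODO above (not part of this commit): translate the
# signals in SIGNALS into a KeyboardInterrupt so a try/finally around main(args)
# could run the same cleanup path (stop transition processes, time out runners).
def _handle_signal(signum, frame):
    raise KeyboardInterrupt(SIGNALS.get(signum, str(signum)))
for sig in SIGNALS:
    signal.signal(sig, _handle_signal)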
main(args)
'''A Runner is constructed with a list of jobs and a list of idle workers. It
creates and monitors the execution subprocess, updating job states in the DB as
necessary. RunnerGroup contains the list of Runner objects, logic for creating
the next Runner (i.e. assigning jobs to nodes), and the public interface'''
import functools
from math import ceil
import os
from pathlib import Path
import signal
@@ -9,12 +15,13 @@ from tempfile import NamedTemporaryFile
from threading import Thread
from queue import Queue, Empty
from django.conf import settings
import balsam.models
from balsam.launcher.launcher import SIGNALS
from balsam.launcher import mpi_commands
from balsam.launcher import mpi_ensemble
class BalsamRunnerException(Exception): pass
from balsam.launcher.exceptions import *
class cd:
'''Context manager for changing cwd'''
@@ -51,7 +58,10 @@ class MonitorStream(Thread):
class Runner:
'''Spawns ONE subprocess to run specified job(s) and monitor their execution'''
def __init__(self, job_list, worker_list, host_type):
def __init__(self, job_list, worker_list):
host_type = worker_list[0].host_type
assert all(w.host_type == host_type for w in worker_list)
self.worker_list = worker_list
mpi_cmd_class = getattr(mpi_commands, f"{host_type}MPICommand")
self.mpi_cmd = mpi_cmd_class()
self.jobs = job_list
@@ -72,6 +82,9 @@ class Runner:
def update_jobs(self):
raise NotImplementedError
def finished(self):
return self.process.poll() is not None
@staticmethod
def get_app_cmd(job):
if job.application:
@@ -91,7 +104,7 @@ class Runner:
class MPIRunner(Runner):
'''One subprocess, one job'''
def __init__(self, job_list, worker_list, host_type):
def __init__(self, job_list, worker_list):
super().__init__(job_list, worker_list)
if len(self.jobs) != 1:
@@ -114,6 +127,7 @@ class MPIRunner(Runner):
def update_jobs(self):
job = self.jobs[0]
#job.refresh_from_db() # TODO: handle RecordModified
retcode = self.process.poll()
if retcode is None:
curstate = 'RUNNING'
@@ -126,11 +140,12 @@
msg = str(retcode)
if job.state != curstate:
job.update_state(curstate, msg) # TODO: handle RecordModified
job.service_ping()
class MPIEnsembleRunner(Runner):
'''One subprocess: an ensemble of serial jobs run in an mpi4py wrapper'''
def __init__(self, job_list, worker_list, host_type):
def __init__(self, job_list, worker_list):
mpi_ensemble_exe = os.path.abspath(mpi_ensemble.__file__)
@@ -145,7 +160,7 @@ class MPIEnsembleRunner(Runner):
with NamedTemporaryFile(prefix='mpi-ensemble', dir=root_dir,
delete=False, mode='w') as fp:
self.ensemble_filename = fp.name
for job in self.job_list:
for job in self.jobs:
cmd = self.get_app_cmd(job)
fp.write(f"{job.pk} {job.working_directory} {cmd}\n")
@@ -156,10 +171,83 @@
self.popen_args['args'] = shlex.split(command)
def update_jobs(self):
'''Relies on stdout of mpi_ensemble.py'''
for line in self.monitor.available_lines():
pk, state, *msg = line.split()
msg = ' '.join(msg)
if pk in self.jobs_by_pk and state in balsam.models.STATES:
self.jobs_by_pk[id].update_state(state, msg) # TODO: handle RecordModified exception
job = self.jobs_by_pk[pk]
job.update_state(state, msg) # TODO: handle RecordModified exception
else:
raise BalsamRunnerError(f"Invalid status update: {line}")
for job in self.jobs:
job.service_ping()
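# update_jobs() above expects mpi_ensemble.py to print one status per line in the form
# "<pk> <state> <message...>", where <state> is a valid BalsamJob state. Illustrative line:
#   8f4d...  RUN_DONE  process exited with code 0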
class RunnerGroup:
MAX_CONCURRENT_RUNNERS = settings.BALSAM_MAX_CONCURRENT_RUNNERS
def __init__(self):
self.runners = []
def create_next_runner(self, runnable_jobs, workers):
'''Implements one particular strategy for choosing the next job, assuming
all jobs are either single-process or MPI-parallel. Will return the serial
ensemble job or single MPI job that occupies the largest possible number of
idle nodes'''
if len(self.runners) == self.MAX_CONCURRENT_RUNNERS:
raise ExceededMaxRunners(
f"Cannot have more than {MAX_CONCURRENT_RUNNERS} simultaneous runners"
)
idle_workers = [w for w in workers if w.idle]
nidle = len(idle_workers)
rpw = workers[0].ranks_per_worker
assert all(w.ranks_per_worker == rpw for w in idle_workers)
serial_jobs = [j for j in runnable_jobs if j.num_nodes == 1 and
j.processes_per_node == 1]
nserial = len(serial_jobs)
mpi_jobs = [j for j in runnable_jobs if 1 < j.num_nodes <= nidle or
(1==j.num_nodes<=nidle and j.processes_per_node > 1)]
largest_mpi_job = (max(mpi_jobs, key=lambda job: job.num_nodes)
if mpi_jobs else None)
if nserial >= nidle*rpw:
jobs = serial_jobs[:nidle*rpw]
assigned_workers = idle_workers
runner_class = MPIEnsembleRunner
elif largest_mpi_job and largest_mpi_job.num_nodes > nserial // rpw:
jobs = [largest_mpi_job]
assigned_workers = idle_workers[:largest_mpi_job.num_nodes]
runner_class = MPIRunner
else:
jobs = serial_jobs
assigned_workers = idle_workers[:ceil(float(nserial)/rpw)]
runner_class = MPIEnsembleRunner
if not jobs: raise NoAvailableWorkers
runner = runner_class(jobs, assigned_workers)
self.runners.append(runner)
for worker in assigned_workers: worker.idle = False
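# Worked example of the selection above (illustrative numbers): with nidle = 8 idle
# workers at rpw = 16 ranks each, 40 runnable serial jobs, and a largest MPI job of
# 5 nodes: 40 < 8*16 and 5 > 40 // 16 == 2, so an MPIRunner takes the 5-node job on
# 5 workers. With 200 serial jobs instead, 200 >= 128, so an MPIEnsembleRunner packs
# the first 128 serial jobs across all 8 workers.
nidle, rpw, nserial, largest_mpi_nodes = 8, 16, 40, 5
assert not (nserial >= nidle * rpw) and largest_mpi_nodes > nserial // rpw   # MPIRunner branch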
def update_and_remove_finished(self):
# TODO: Benchmark performance overhead; does grouping into one
# transaction save significantly?
any_finished = False
with transaction.atomic():
for runner in self.runners[:]:
runner.update_jobs()
if runner.finished():
any_finished = True
self.runners.remove(runner)
for worker in runner.worker_list:
worker.idle = True
return any_finished
@property
def running_job_pks(self):
active_runners = [r for r in self.runners if not r.finished()]
return [j.pk for runner in active_runners for j in runner.jobs]
@@ -3,11 +3,77 @@ from collections import namedtuple
import logging
from django.core.exceptions import ObjectDoesNotExist
from django.conf import settings
from common import transfer
from balsam.launcher.exceptions import *
logger = logging.getLogger(__name__)
class ProcessingError(Exception): pass
StatusMsg = namedtuple('Status', ['pk', 'state', 'msg'])
JobMsg = namedtuple('JobMsg', ['pk', 'transition_function'])
def main(job_queue, status_queue):
    while True:
        pk, process_function = job_queue.get()
        if pk == 'end':
            return
        job = BalsamJob.objects.get(pk=pk)  # JobMsg carries a pk; look up the job here
        try:
            process_function(job)
        except BalsamTransitionError as e:
            s = StatusMsg(pk, 'FAILED', str(e))
            status_queue.put(s)
        else:
            s = StatusMsg(pk, job.state, 'success')
            status_queue.put(s)
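# add_job() below indexes a module-level TRANSITIONS dict that lies outside this hunk.
# Presumably it mirrors the mapping the old launcher.py carried, roughly:
#   TRANSITIONS = {'CREATED': check_parents, 'LAUNCHER_QUEUED': check_parents,
#                  'AWAITING_PARENTS': check_parents, 'READY': stage_in,
#                  'STAGED_IN': preprocess, 'RUN_DONE': postprocess,
#                  'RUN_TIMEOUT': postprocess, 'RUN_ERROR': postprocess,
#                  'POSTPROCESSED': stage_out}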
class TransitionProcessPool:
NUM_PROC = settings.BALSAM_MAX_CONCURRENT_TRANSITIONS
def __init__(self):
self.job_queue = multiprocessing.Queue()
self.status_queue = multiprocessing.Queue()
self.transitions_pk_list = []
self.procs = [
multiprocessing.Process( target=main,
args=(self.job_queue, self.status_queue))
for i in range(self.NUM_PROC)
]
for proc in self.procs: proc.start()
def __contains__(self, job):
return job.pk in self.transitions_pk_list
def add_job(self, job):
if job in self: raise BalsamTransitionError("already in transition")
if job.state not in TRANSITIONS: raise TransitionNotFoundError
pk = job.pk
transition_function = TRANSITIONS[job.state]
m = JobMsg(pk, transition_function)
self.job_queue.put(m)
self.transitions_pk_list.append(pk)
def get_statuses(self):
while not self.status_queue.empty():
try:
stat = self.status_queue.get_nowait()
self.transitions_pk_list.remove(stat.pk)
yield stat
except queue.Empty:
break
def stop_processes(self):