......@@ -58,11 +58,16 @@ def create_new_runners(jobs, runner_group, worker_group):
created_one = False
running_pks = runner_group.running_job_pks
runnable_jobs = get_runnable_jobs(jobs, running_pks)
logger.debug(f"Have {len(runnable_jobs)} new runnable jobs")
while runnable_jobs:
logger.debug(f"Have {len(runnable_jobs)} new runnable jobs (out of "
runner_group.create_next_runner(runnable_jobs, worker_group)
except (ExceededMaxRunners, NoAvailableWorkers) as e:
except ExceededMaxRunners:"Exceeded max concurrent runners; waiting")
except NoAvailableWorkers:"Not enough idle workers to start any new runs")
created_one = True
......@@ -253,26 +253,25 @@ class RunnerGroup:
jobs = serial_jobs[:nidle_ranks] # TODO: try putting ALL serial jobs into one MPIEnsemble
assigned_workers = idle_workers
runner_class = MPIEnsembleRunner"Running {len(jobs)} serial jobs on {nidle_workers} workers "
msg = (f"Running {len(jobs)} serial jobs on {nidle_workers} workers "
f"with {nodes_per_worker} nodes-per-worker and {rpn} ranks per node")
elif largest_mpi_job and largest_mpi_job.num_nodes > nserial // rpw:
elif largest_mpi_job and largest_mpi_job.num_nodes > nserial // rpn:
jobs = [largest_mpi_job]
num_workers = ceil(largest_mpi_job.num_nodes / nodes_per_worker)
assigned_workers = idle_workers[:num_workers]
runner_class = MPIRunner"Running {largest_mpi_job.num_nodes}-node MPI job")
msg = (f"Running {largest_mpi_job.num_nodes}-node MPI job")
jobs = serial_jobs
nworkers = ceil(nserial/rpn/nodes_per_worker)
assigned_workers = idle_workers[:nworkers]
runner_class = MPIEnsembleRunner"Running {len(jobs)} serial jobs on {nworkers} workers "
msg = (f"Running {len(jobs)} serial jobs on {nworkers} workers "
f"totalling {nworkers*nodes_per_worker} nodes "
f"with {rpn} ranks per worker")
if not jobs:"Not enough idle workers to handle the runnable jobs")
raise NoAvailableWorkers
if not jobs: raise NoAvailableWorkers
runner = runner_class(jobs, assigned_workers)
from collections import namedtuple
import os
import random
from multiprocessing import Lock
import sys
import time
from uuid import UUID
from importlib.util import find_spec
from tests.BalsamTestCase import BalsamTestCase, cmdline
......@@ -246,11 +248,125 @@ class TestMPIEnsemble(BalsamTestCase):
self.assertTrue(all(j.state=='RUN_ERROR' for j in jobs['fail']))
class TestRunnerGroup:
class TestRunnerGroup(BalsamTestCase):
def setUp(self):
scheduler = Scheduler.scheduler_main
self.host_type = scheduler.host_type
if self.host_type == 'DEFAULT':
config = get_args('--consume-all --num-workers 1 --max-ranks-per-node 8'.split())
config = get_args('--consume-all')
self.worker_group = worker.WorkerGroup(config, host_type=self.host_type,
app_path = f"{sys.executable} {find_spec('tests.mock_mpi_app').origin}"
self.mpiapp = ApplicationDefinition() = "mock_mpi"
self.mpiapp.description = "print and sleep"
self.mpiapp.executable = app_path
app_path = f"{sys.executable} {find_spec('tests.mock_serial_app').origin}"
self.serialapp = ApplicationDefinition() = "mock_serial"
self.serialapp.description = "square a number"
self.serialapp.executable = app_path
def test_create_runners(self):
# Create sets of jobs intended to exercise each code path
# in a single call to launcher.create_new_runners()
'''sanity check launcher.create_new_runners()
Don't test implementation details here; just ensuring consistency'''
num_workers = len(self.worker_group)
num_nodes = sum(w.num_nodes for w in self.worker_group)
num_ranks = sum(w.num_nodes*w.max_ranks_per_node for w in
max_rpn = self.worker_group[0].max_ranks_per_node
num_serialjobs = random.randint(0, num_ranks+2)
num_mpijobs = random.randint(0, num_workers+2)
serialjobs = []
mpijobs = []
# Create a big shuffled assortment of jobs
runner_group = runners.RunnerGroup(Lock())
for i in range(num_serialjobs):
job = BalsamJob()
job.allowed_work_sites = settings.BALSAM_SITE = f"serial{i}"
job.application =
job.application_args = str(i)
job.state = 'PREPROCESSED'
for i in range(num_mpijobs):
job = BalsamJob()
job.allowed_work_sites = settings.BALSAM_SITE = f"mpi{i}"
job.application =
job.num_nodes = random.randint(1,num_nodes)
job.ranks_per_node = random.randint(2, max_rpn)
job.state = 'PREPROCESSED'
all_jobs = serialjobs + mpijobs
# None are running yet!
running_pks = runner_group.running_job_pks
self.assertListEqual(running_pks, [])
# Invoke create_new_runners once
# Some set of jobs will start running under control of the RunnerGroup
# Nondeterministic, due to random() used above, but we just want to
# check for consistency
create_new_runners(all_jobs, runner_group, self.worker_group)
# Get the list of running PKs from the RunnerGroup
# At least some jobs are running now
running_pks = runner_group.running_job_pks
self.assertGreater(len(running_pks), 0)
running_jobs = list(BalsamJob.objects.filter(pk__in=running_pks))
self.assertGreater(len(running_jobs), 0)
# Make sure that the aggregate runner PKs agree with the RunnerGroup
pks_from_runners = [UUID(pk) for runner in runner_group for pk in
self.assertListEqual(sorted(running_pks), sorted(pks_from_runners))
# Make sure that the busy workers are correctly marked not idle
busy_workers = [worker for runner in runner_group for worker in
self.assertTrue(all(w.idle == False for w in busy_workers))
# And the worker instances in each Runner are the same as the worker
# instances maintained in the calling code
busy_workers_ids = [id(w) for w in self.worker_group
if w in busy_workers]
sorted([id(w) for w in busy_workers]))
# Workers not busy are still idle
self.assertTrue(all(w.idle == True for w in self.worker_group
if w not in busy_workers))
# Now let all the jobs finish
# Update and remove runners with update_and_remove_finished()
def check_done():
return all(r.finished() for r in runner_group)
# Now there should be no runners, PKs, or busy workers left
self.assertListEqual(list(runner_group), [])
self.assertListEqual(runner_group.running_job_pks, [])
self.assertTrue(all(w.idle==True for w in self.worker_group))
# And all of the jobs that started running are now marked RUN_DONE
finished_jobs = list(BalsamJob.objects.filter(pk__in=running_pks))
self.assertTrue(all(j.state == 'RUN_DONE' for j in finished_jobs))
