Commit 47bf5d63 authored by Michael Salim
Browse files

refactor: ZMQ-DBWriter logic moved out of models

parent e45ef83f
......@@ -38,6 +38,7 @@ from balsam.launcher import transitions
from balsam.launcher import worker
from balsam.launcher import runners
from balsam.launcher.exceptions import *
from balsam.service import db_writer
ALMOST_RUNNABLE_STATES = ['READY','STAGED_IN']
RUNNABLE_STATES = ['PREPROCESSED', 'RESTART_READY']
......@@ -213,7 +214,9 @@ def on_exit(runner_group, transition_pool, job_source, writer_proc):
logger.debug("on_exit: send end message to transition threads")
transition_pool.end_and_wait()
logger.debug("on_exit: Launcher exit graceful\n\n")
writer_proc.terminate()
client = db_writer.ZMQClient()
client.term_server()
sys.exit(0)
......@@ -248,13 +251,12 @@ def detect_dead_runners(job_source):
def launch_db_writer_process():
    """Start the ZMQ database-writer server in a daemon child process.

    Any server-address file left in ``settings.INSTALL_PATH`` is removed
    before the new server starts, so clients cannot pick up an address
    written by an earlier writer.

    Returns:
        multiprocessing.Process: the started daemon process, whose target is
        ``db_writer.server_main``.
    """
    import multiprocessing
    from balsam.service import db_writer

    path = os.path.join(settings.INSTALL_PATH, db_writer.SOCKFILE_NAME)
    # Remove-and-ignore instead of exists()+remove(): the file may vanish
    # between the check and the removal.
    try:
        os.remove(path)
    except FileNotFoundError:
        pass

    writer_proc = multiprocessing.Process(target=db_writer.server_main)
    writer_proc.daemon = True
    writer_proc.start()
    return writer_proc
......
......@@ -3,48 +3,145 @@ import os
import logging
import zmq
from socket import gethostname
import signal
import django
from django.conf import settings
os.environ['DJANGO_SETTINGS_MODULE'] = 'balsam.django_config.settings'
django.setup()
from balsam.service.models import BalsamJob
INSTALL_PATH = settings.INSTALL_PATH
logger = logging.getLogger('balsam.service')
def setup():
    """Advertise this host's address and open the reply socket.

    Writes ``tcp://<hostname>:5556`` to the ``db_writer_socket`` file under
    ``INSTALL_PATH``, then binds a ZMQ REP socket on port 5556 on all
    interfaces.

    Returns:
        The bound ZMQ REP socket.
    """
    port = "5556"
    advertised_address = f'tcp://{gethostname()}:{port}'
    address_file = os.path.join(INSTALL_PATH, 'db_writer_socket')
    with open(address_file, 'w') as handle:
        handle.write(advertised_address)

    zmq_context = zmq.Context()
    reply_socket = zmq_context.socket(zmq.REP)
    reply_socket.bind(f'tcp://*:{port}')
    return reply_socket
def save_job(job_msg):
    """Decode a JSON job message and save the job it describes.

    The message must carry the keys ``force_insert``, ``force_update``,
    ``using`` and ``update_fields``; they are forwarded, in that order, to
    the job's ``_save_direct``.

    Args:
        job_msg: JSON text describing the job and the save options.
    """
    payload = json.loads(job_msg)
    option_names = ('force_insert', 'force_update', 'using', 'update_fields')
    save_options = [payload[name] for name in option_names]
    job = BalsamJob.from_dict(payload)
    job._save_direct(*save_options)
    logger.info(f"db_writer Saved {job.cute_id}")
def main():
socket = setup()
from django.db.utils import OperationalError
logger = logging.getLogger('balsam.service.db_writer')
# Directory that holds the server-address file. None at import time;
# ZMQProxy.__init__ and ZMQClient.__init__ assign settings.INSTALL_PATH to it.
SOCKFILE_PATH = None
# Name of the file to which the server writes its tcp:// address.
SOCKFILE_NAME = 'db_writer_socket'
# TCP port the server's REP socket binds to.
PORT = "5556"
# Timeout passed to the server's socket.poll(); presumably milliseconds
# (i.e. 1 second), the same unit as CLIENT_TIMEOUT -- TODO confirm.
SERVER_PERIOD = 1000
# Timeout passed to the client's socket.poll() while waiting for a reply.
CLIENT_TIMEOUT = 10000 # 10 seconds
class ZMQProxy:
    """Server end of the DB-writer channel.

    Owns a ZMQ REP socket, advertises its address through a file in
    ``SOCKFILE_PATH`` and saves jobs received as JSON messages.
    """

    def __init__(self):
        """Configure Django, record the install path and open the socket."""
        import django
        os.environ['DJANGO_SETTINGS_MODULE'] = 'balsam.django_config.settings'
        django.setup()
        # Deferred until after django.setup(); presumably the models module
        # cannot be imported before Django is configured.
        from balsam.service.models import BalsamJob
        self.BalsamJob = BalsamJob

        global SOCKFILE_PATH
        SOCKFILE_PATH = settings.INSTALL_PATH
        self.setup()

    def setup(self):
        """Bind the REP socket, then advertise its address.

        The socket is bound *before* the address file is written, so that a
        failed bind does not leave behind a file advertising a server that
        never came up.

        Returns:
            The bound ZMQ REP socket (also stored as ``self.socket``).
        """
        hostname = gethostname()
        self.address = f'tcp://{hostname}:{PORT}'
        self.sock_file = os.path.join(SOCKFILE_PATH, SOCKFILE_NAME)

        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.REP)
        self.socket.bind(f'tcp://*:{PORT}')

        with open(self.sock_file, 'w') as fp:
            fp.write(self.address)
        return self.socket

    def recv_request(self):
        """Wait up to ``SERVER_PERIOD`` for a request.

        Returns:
            The request decoded as UTF-8 text, or None when the poll timed
            out without any message arriving.
        """
        events = self.socket.poll(timeout=SERVER_PERIOD)
        if not events:
            return None
        return self.socket.recv().decode('utf-8')

    def send_reply(self, msg):
        """Send the text ``msg`` back to the requesting client."""
        self.socket.send_string(msg)

    def _django_save(self, job_msg):
        """Rebuild a job from its JSON message and save it.

        Args:
            job_msg: JSON text holding the job fields plus the keys
                ``force_insert``, ``force_update``, ``using`` and
                ``update_fields``, which are forwarded to ``_save_direct``.
        """
        d = json.loads(job_msg)
        job = self.BalsamJob.from_dict(d)
        force_insert = d['force_insert']
        force_update = d['force_update']
        using = d['using']
        update_fields = d['update_fields']
        job._save_direct(force_insert, force_update, using, update_fields)
        logger.info(f"db_writer Saved {job.cute_id}")
class ZMQClient:
    """Client end of the DB-writer channel.

    ``zmq_server`` holds the server's ``tcp://`` address when a live server
    was discovered, and None otherwise.
    """

    def __init__(self):
        """Record the install path and look for a running server."""
        global SOCKFILE_PATH
        SOCKFILE_PATH = settings.INSTALL_PATH
        self.discover_zmq_proxy()

    def discover_zmq_proxy(self):
        """Set ``self.zmq_server`` from the address file, if a server answers.

        ``zmq_server`` ends up None when the address file is missing, does
        not contain ``tcp://``, or the server does not reply ``ACK`` to a
        ``TEST_ALIVE`` request.
        """
        path = os.path.join(SOCKFILE_PATH, SOCKFILE_NAME)
        self.zmq_server = None
        try:
            # 'with' closes the handle; the original left that to the GC.
            with open(path) as fp:
                address = fp.read().strip()
        except FileNotFoundError:
            return
        if 'tcp://' not in address:
            return

        # send_request() connects to self.zmq_server, so set it first.
        self.zmq_server = address
        response = self.send_request('TEST_ALIVE')
        if response == 'ACK':
            logger.info(f"save() going to server @ {self.zmq_server}")
        else:
            logger.info(f"save() going directly to local db")
            self.zmq_server = None

    def send_request(self, msg):
        """Send ``msg`` to the server and wait for its reply.

        A fresh REQ socket is used for every request and is always closed
        before returning, whether or not a reply arrived.

        Args:
            msg: request text to send.

        Returns:
            The reply decoded as UTF-8 text, or None if nothing arrived
            within ``CLIENT_TIMEOUT``.
        """
        context = zmq.Context()
        socket = context.socket(zmq.REQ)
        try:
            socket.setsockopt(zmq.LINGER, 2000)
            socket.connect(self.zmq_server)
            socket.send_string(msg)
            response = socket.poll(timeout=CLIENT_TIMEOUT)
            if response and response > 0:
                return socket.recv().decode('utf-8')
            return None
        finally:
            # Release the socket and context deterministically instead of
            # relying on garbage collection.
            socket.close()
            context.term()

    def save(self, job, force_insert=False, force_update=False, using=None, update_fields=None):
        """Ask the server to save ``job``.

        Args:
            job: object whose ``serialize(**options)`` returns the request
                text.
            force_insert, force_update, using, update_fields: save options
                embedded in the serialized request.

        Raises:
            OperationalError: if the request times out, or if the server
                replies with anything other than ``ACK_SAVE``.
        """
        serial_data = job.serialize(force_insert=force_insert,
                                    force_update=force_update, using=using,
                                    update_fields=update_fields)
        response = self.send_request(serial_data)
        if response is None:
            raise OperationalError("ZMQ DB write request timed out")
        # Explicit check rather than 'assert', which is stripped under -O.
        if response != 'ACK_SAVE':
            raise OperationalError(
                f"ZMQ DB write request got unexpected reply: {response!r}")

    def term_server(self):
        """Send ``TERM`` to the server, if one was discovered."""
        if self.zmq_server:
            self.send_request('TERM')
def server_main():
    """Run the DB-writer request loop until told to stop.

    SIGINT and SIGTERM are given a no-op handler, so the loop ends only when
    a ``TERM`` message arrives or when the parent PID changes (checked each
    time a poll times out). Messages containing ``job_id`` are saved and
    answered with ``ACK_SAVE``; any other message is answered with ``ACK``.
    The server-address file is removed on the way out.
    """
    parent_pid = os.getppid()

    def handler(signum, frame):
        return 0

    signal.signal(signal.SIGINT, handler)
    signal.signal(signal.SIGTERM, handler)

    proxy = ZMQProxy()
    try:
        while True:
            message = proxy.recv_request()
            if message is None:
                # Poll timed out: use the idle moment to check the parent.
                if os.getppid() != parent_pid:
                    logger.info("db_writer detected parent PID died; quitting")
                    break
            elif 'job_id' in message:
                proxy._django_save(message)
                proxy.send_reply("ACK_SAVE")
            elif 'TERM' in message:
                logger.info("db_writer got TERM message; quitting")
                proxy.send_reply("ACK_TERM")
                break
            else:
                proxy.send_reply("ACK")
    finally:
        # A missing file must not mask an exception propagating from the loop.
        try:
            os.remove(os.path.join(SOCKFILE_PATH, SOCKFILE_NAME))
        except FileNotFoundError:
            pass
# Script entry point: run the DB-writer server loop.
if __name__ == "__main__":
    server_main()
......@@ -3,19 +3,16 @@ import json
import logging
import sys
from datetime import datetime
from socket import gethostname
import uuid
import time
import zmq
from django.core.exceptions import ValidationError,ObjectDoesNotExist
from django.conf import settings
from django.db import models
from django.db.utils import OperationalError
from concurrency.fields import IntegerVersionField
from concurrency.exceptions import RecordModifiedError
from balsam.common import Serializer
from balsam.service import db_writer
logger = logging.getLogger('balsam.service')
......@@ -245,7 +242,7 @@ class BalsamJob(models.Model):
help_text="Chronological record of the job's states",
default=history_line)
zmq_server_addr = None
db_write_client = None
def _save_direct(self, force_insert=False, force_update=False, using=None,
update_fields=None):
......@@ -256,55 +253,16 @@ class BalsamJob(models.Model):
update_fields = None
models.Model.save(self, force_insert, force_update, using, update_fields)
def _save_zmq_writer(self, force_insert=False, force_update=False, using=None, update_fields=None):
    """Send this job, as JSON, to the ZMQ write server and wait for its ACK.

    Every instance attribute except ``_state`` and ``version`` is sent,
    together with ``job_id`` as a string and the four save options.

    Raises:
        OperationalError: if no reply arrives within the 10000 poll timeout.
        AssertionError: if the reply is not ``ACK``.
    """
    payload = {
        name: value
        for name, value in self.__dict__.items()
        if name not in ('_state', 'version')
    }
    payload.update(
        job_id=str(self.job_id),
        force_insert=force_insert,
        force_update=force_update,
        using=using,
        update_fields=update_fields,
    )
    request_text = json.dumps(payload)

    zmq_context = zmq.Context()
    request_socket = zmq_context.socket(zmq.REQ)
    request_socket.connect(BalsamJob.zmq_server_addr)
    request_socket.send_string(request_text)

    if not request_socket.poll(timeout=10000):
        raise OperationalError("ZMQ DB write request timedout")
    reply = request_socket.recv().decode('utf-8')
    assert reply == 'ACK'
def check_zmq_write_server(self):
    """Return the address of a responsive ZMQ write server, or ''.

    The address is read from the ``db_writer_socket`` file under
    ``settings.INSTALL_PATH``. '' is returned when that file is missing,
    when it does not contain ``tcp://``, or when nothing comes back within
    the 10000 poll timeout after sending ``TEST_ALIVE``.
    """
    path = os.path.join(settings.INSTALL_PATH, 'db_writer_socket')
    try:
        # 'with' closes the handle; the original left that to the GC.
        with open(path) as fp:
            zmq_server_addr = fp.read().strip()
    except FileNotFoundError:
        return ''
    if 'tcp://' not in zmq_server_addr:
        return ''

    context = zmq.Context()
    socket = context.socket(zmq.REQ)
    socket.connect(zmq_server_addr)
    socket.send_string('TEST_ALIVE')
    # Any reply within the timeout counts as alive; its content is not read.
    if socket.poll(timeout=10000):
        logger.info(f"model save() going to server @ {zmq_server_addr}")
        return zmq_server_addr
    logger.info("model save() direct to disk")
    return ''
def save(self, force_insert=False, force_update=False, using=None, update_fields=None):
    """Save this job, through the ZMQ write server when one is available.

    A single ``db_writer.ZMQClient`` is created on first use and cached on
    ``BalsamJob.db_write_client``. If that client found no server
    (``zmq_server is None``) the job is saved directly; otherwise the save
    is delegated to the client.
    """
    if BalsamJob.db_write_client is None:
        BalsamJob.db_write_client = db_writer.ZMQClient()
    client = BalsamJob.db_write_client

    if client.zmq_server is None:
        logger.info(f"direct save of {self.cute_id}")
        self._save_direct(force_insert, force_update, using, update_fields)
    else:
        logger.info(f"sending request for save of {self.cute_id}")
        client.save(self, force_insert, force_update, using, update_fields)
@staticmethod
def from_dict(d):
......@@ -526,13 +484,26 @@ auto timeout retry: {self.auto_timeout_retry}
self.save(update_fields=['working_directory'])
return path
# Stub: does nothing and returns None.
# NOTE(review): a second serialize(self, **kwargs) is defined a few lines
# below; if both sit in the same class body, that later definition replaces
# this one -- confirm.
def serialize(self):
pass
def to_dict(self):
    """Return a new dict of the instance attributes, minus bookkeeping.

    The ``_state`` and ``version`` entries of ``self.__dict__`` are left
    out; every other attribute is copied in its original order.
    """
    excluded = ('_state', 'version')
    return {
        name: value
        for name, value in self.__dict__.items()
        if name not in excluded
    }
def serialize(self, **kwargs):
    """Return this job as JSON text, with extra entries merged in.

    Starts from ``self.to_dict()``, overlays ``kwargs`` and finally sets
    ``job_id`` to ``str(self.job_id)``.

    Returns:
        str: the JSON encoding of the merged dict.
    """
    fields = self.to_dict()
    # kwargs are applied first, so the string job_id always wins.
    fields.update(kwargs, job_id=str(self.job_id))
    return json.dumps(fields)
@classmethod
def deserialize(cls, serial_data):
    """Build a job from serialized data.

    Args:
        serial_data: UTF-8 encoded JSON bytes, JSON text, or an already
            decoded object (passed to ``from_dict`` unchanged).

    Returns:
        Whatever ``cls.from_dict`` returns for the decoded data.
    """
    # isinstance() also accepts subclasses, unlike an exact type() match.
    if isinstance(serial_data, (bytes, bytearray)):
        serial_data = serial_data.decode('utf-8')
    if isinstance(serial_data, str):
        serial_data = json.loads(serial_data)
    return cls.from_dict(serial_data)
class ApplicationDefinition(models.Model):
''' application definition, each DB entry is a task that can be run
......
......@@ -797,7 +797,7 @@ class TestDAG(BalsamTestCase):
class TestThreadPlacement(BalsamTestCase):
def setUp(self):
self.app_path = os.path.dirname(find_spec("tests.c_apps").origin)
self.app_path = os.path.dirname(find_spec("tests.ft_apps.c_apps").origin)
self.app = create_app(name='omp')
self.job0 = create_job(name='job0', app='omp', num_nodes=2, ranks_per_node=32, threads_per_rank=2)
......@@ -864,11 +864,11 @@ class TestThreadPlacement(BalsamTestCase):
self.check_omp_exe_output(self.job1)
self.check_omp_exe_output(self.job2)
class TestUserKill(BalsamTestCase):
    """Fixture for user-kill tests: one 'omp' app and three jobs that use it."""

    def setUp(self):
        # Directory of the tests.ft_apps.c_apps module file.
        self.app_path = os.path.dirname(find_spec("tests.ft_apps.c_apps").origin)
        self.app = create_app(name='omp')

        # Three jobs with different node / rank / thread layouts.
        self.job0 = create_job(name='job0', app='omp', num_nodes=2, ranks_per_node=32, threads_per_rank=2)
        self.job1 = create_job(name='job1', app='omp', num_nodes=2, ranks_per_node=64, threads_per_rank=1)
        self.job2 = create_job(name='job2', app='omp', num_nodes=1, ranks_per_node=2, threads_per_rank=64, threads_per_core=2)
#class TestUserKill(BalsamTestCase):
# def setUp(self):
# self.app_path = find_spec("tests.ft_apps.c_apps").origin)
# self.app = create_app(name='omp')
#
# self.job0 = create_job(name='job0', app='omp', num_nodes=2, ranks_per_node=32, threads_per_rank=2)
# self.job1 = create_job(name='job1', app='omp', num_nodes=2, ranks_per_node=64, threads_per_rank=1)
# self.job2 = create_job(name='job2', app='omp', num_nodes=1, ranks_per_node=2, threads_per_rank=64, threads_per_core=2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment