Commit 8931606d authored by Michael Salim's avatar Michael Salim
Browse files

prototype zmq server which receives all write messages and serializes writing to sqlite DB

parent 546f2309
......@@ -200,7 +200,7 @@ def main(args, transition_pool, runner_group, job_source):
logger.info("No jobs to process. Exiting main loop now.")
break
def on_exit(runner_group, transition_pool, job_source):
def on_exit(runner_group, transition_pool, job_source, writer_proc):
'''Exit cleanup'''
global HANDLING_EXIT
if HANDLING_EXIT: return
......@@ -213,6 +213,7 @@ def on_exit(runner_group, transition_pool, job_source):
logger.debug("on_exit: send end message to transition threads")
transition_pool.end_and_wait()
logger.debug("on_exit: Launcher exit graceful\n\n")
writer_proc.terminate()
exit(0)
......@@ -245,9 +246,19 @@ def detect_dead_runners(job_source):
logger.info(f'Picked up dead running job {job.cute_id}: marking RESTART_READY')
job.update_state('RESTART_READY', 'Detected dead runner')
def launch_db_writer_process():
    '''Spawn the ZMQ DB-writer server as a child process.

    Runs balsam.service.db_writer under a fresh interpreter so that all
    sqlite writes are serialized through a single process.

    Returns:
        subprocess.Popen: handle for the writer, so the caller can
        terminate it during shutdown.

    Raises:
        RuntimeError: if the db_writer module cannot be located.
    '''
    from importlib.util import find_spec
    import subprocess
    spec = find_spec("balsam.service.db_writer")
    if spec is None or spec.origin is None:
        # Fail fast with a clear message instead of an AttributeError
        # on `None.origin`
        raise RuntimeError("Cannot locate balsam.service.db_writer module")
    # stderr is folded into stdout, which is discarded: the writer logs
    # through the balsam logger, not its std streams
    writer_proc = subprocess.Popen([sys.executable, spec.origin],
                                   stdout=subprocess.DEVNULL,
                                   stderr=subprocess.STDOUT)
    return writer_proc
if __name__ == "__main__":
args = get_args()
writer_proc = launch_db_writer_process()
job_source = jobreader.JobReader.from_config(args)
job_source.refresh_from_db()
transition_pool = transitions.TransitionProcessPool()
......@@ -258,10 +269,10 @@ if __name__ == "__main__":
detect_dead_runners(job_source)
handl = lambda a,b: on_exit(runner_group, transition_pool, job_source)
handl = lambda a,b: on_exit(runner_group, transition_pool, job_source, writer_proc)
signal.signal(signal.SIGINT, handl)
signal.signal(signal.SIGTERM, handl)
signal.signal(signal.SIGHUP, handl)
main(args, transition_pool, runner_group, job_source)
on_exit(runner_group, transition_pool, job_source)
on_exit(runner_group, transition_pool, job_source, writer_proc)
import json
import os
import logging
import zmq
from socket import gethostname
import django
from django.conf import settings
os.environ['DJANGO_SETTINGS_MODULE'] = 'balsam.django_config.settings'
django.setup()
from balsam.service.models import BalsamJob
INSTALL_PATH = settings.INSTALL_PATH
logger = logging.getLogger('balsam.service')
def setup(port="5556"):
    '''Bind the writer's ZMQ REP socket and advertise its address.

    Writes the `tcp://host:port` address into INSTALL_PATH/db_writer_socket
    so clients (BalsamJob.save) can discover the running server.

    Args:
        port: TCP port to listen on, as a string. Defaults to "5556"
            (the previously hard-coded value, so existing callers are
            unaffected).

    Returns:
        The bound zmq REP socket.
    '''
    hostname = gethostname()
    addr_file = os.path.join(INSTALL_PATH, 'db_writer_socket')
    with open(addr_file, 'w') as fp:
        fp.write(f'tcp://{hostname}:{port}')
    context = zmq.Context()
    socket = context.socket(zmq.REP)
    # Listen on all interfaces; the advertised name is this host's name
    socket.bind(f'tcp://*:{port}')
    return socket
def save_job(job_msg):
    '''Deserialize one JSON job message and persist it straight to the DB.

    The message carries the serialized job fields plus the original
    save() arguments (force_insert/force_update/using/update_fields).
    '''
    payload = json.loads(job_msg)
    job = BalsamJob.from_dict(payload)
    job._save_direct(payload['force_insert'],
                     payload['force_update'],
                     payload['using'],
                     payload['update_fields'])
    logger.info(f"db_writer Saved {job.cute_id}")
def main():
    '''Serve DB-write requests until interrupted.

    Receives messages on a REP socket and replies "ACK" to every request.
    Only messages containing a 'job_id' are written to the DB; liveness
    probes ('TEST_ALIVE') just get the ACK. On shutdown the advertised
    socket file is removed so clients fall back to direct disk writes.
    '''
    socket = setup()
    addr_file = os.path.join(INSTALL_PATH, 'db_writer_socket')
    try:
        while True:
            message = socket.recv().decode('utf-8')
            if 'job_id' in message:
                save_job(message)
            socket.send_string("ACK")
    finally:
        # The file may never have been written if setup() failed partway;
        # don't let a FileNotFoundError mask the original exception
        if os.path.exists(addr_file):
            os.remove(addr_file)

if __name__ == "__main__":
    main()
......@@ -6,6 +6,7 @@ from datetime import datetime
from socket import gethostname
import uuid
import time
import zmq
from django.core.exceptions import ValidationError,ObjectDoesNotExist
from django.conf import settings
......@@ -16,7 +17,7 @@ from concurrency.exceptions import RecordModifiedError
from balsam.common import Serializer
logger = logging.getLogger(__name__)
logger = logging.getLogger('balsam.service')
class InvalidStateError(ValidationError): pass
class InvalidParentsError(ValidationError): pass
......@@ -243,30 +244,82 @@ class BalsamJob(models.Model):
'Job State History',
help_text="Chronological record of the job's states",
default=history_line)
def save(self, force_insert=False, force_update=False, using=None,
zmq_server_addr = None
def _save_direct(self, force_insert=False, force_update=False, using=None,
                 update_fields=None):
    '''Write this job straight to the DB via Django's Model.save.

    Ensures the optimistic-concurrency 'version' column is included in
    any partial update, and forces a full save for rows being inserted.
    '''
    if update_fields is not None:
        # Copy rather than append in place: the caller's list must not be
        # mutated as a side effect of saving (and repeated saves must not
        # accumulate duplicate 'version' entries)
        update_fields = list(update_fields) + ['version']
    if self._state.adding:
        # New row: every field must be written, not just update_fields
        update_fields = None
    models.Model.save(self, force_insert, force_update, using, update_fields)
def _save_zmq_writer(self, force_insert=False, force_update=False, using=None,
                     update_fields=None):
    '''Send this job to the db_writer server instead of writing directly.

    Serializes the instance __dict__ (plus the save() arguments) to JSON
    and sends it over a REQ socket to BalsamJob.zmq_server_addr, waiting
    up to 10 s for the server's ACK.

    Raises:
        OperationalError: if the server does not reply within the timeout.
    '''
    SERIAL_FIELDS = [f for f in self.__dict__ if f not in ['_state', 'version']]
    d = {field: self.__dict__[field] for field in SERIAL_FIELDS}
    d['job_id'] = str(self.job_id)  # UUID is not JSON-serializable
    d['force_insert'] = force_insert
    d['force_update'] = force_update
    d['using'] = using
    d['update_fields'] = update_fields
    message = json.dumps(d)

    context = zmq.Context()
    socket = context.socket(zmq.REQ)
    try:
        socket.connect(BalsamJob.zmq_server_addr)
        socket.send_string(message)
        response = socket.poll(timeout=10000)
        if response:
            ack = socket.recv()
            assert ack.decode('utf-8') == 'ACK'
        else:
            raise OperationalError("ZMQ DB write request timedout")
    finally:
        # Original code leaked one Context (and its I/O thread) per save;
        # linger=0 drops any unsent/unreceived messages immediately
        socket.close(linger=0)
        context.term()
def check_zmq_write_server(self):
    '''Discover and probe the db_writer server.

    Reads the advertised address from INSTALL_PATH/db_writer_socket and
    sends a TEST_ALIVE probe over a throwaway REQ socket.

    Returns:
        The server address string if it answers within 10 s, otherwise ''
        (meaning: save directly to disk).
    '''
    path = os.path.join(settings.INSTALL_PATH, 'db_writer_socket')
    if not os.path.exists(path):
        return ''
    # Close the file handle explicitly (original leaked it via open().read())
    with open(path) as fp:
        zmq_server_addr = fp.read().strip()
    if 'tcp://' not in zmq_server_addr:
        return ''
    context = zmq.Context()
    socket = context.socket(zmq.REQ)
    try:
        socket.connect(zmq_server_addr)
        socket.send_string('TEST_ALIVE')
        response = socket.poll(timeout=10000)
    finally:
        # The probe reply is never recv'd; close with linger=0 so the
        # socket and context don't linger waiting for it
        socket.close(linger=0)
        context.term()
    if response:
        logger.info(f"model save() going to server @ {zmq_server_addr}")
        return zmq_server_addr
    else:
        logger.info("model save() direct to disk")
        return ''
def save(self, force_insert=False, force_update=False, using=None, update_fields=None):
    '''Persist the job, routing through the ZMQ writer when one is alive.

    The writer address is probed at most once and cached on the class;
    an empty string means "no server, write directly to disk".
    '''
    if BalsamJob.zmq_server_addr is None:
        BalsamJob.zmq_server_addr = self.check_zmq_write_server()
    if BalsamJob.zmq_server_addr:
        writer = self._save_zmq_writer
    else:
        writer = self._save_direct
    writer(force_insert, force_update, using, update_fields)
@staticmethod
def from_dict(d):
    '''Reconstruct a BalsamJob from a deserialized message dict.

    Copies every serializable instance field from d onto a fresh job,
    converting a string job_id back into a UUID. Transport-only keys
    (force_insert/force_update/using/update_fields) and internal state
    are ignored.
    '''
    job = BalsamJob()
    EXCLUDE = {'_state', 'version', 'force_insert', 'force_update',
               'using', 'update_fields'}
    SERIAL_FIELDS = [f for f in job.__dict__ if f not in EXCLUDE]
    # isinstance, not `type(...) is str`, per Python idiom
    if isinstance(d['job_id'], str):
        d['job_id'] = uuid.UUID(d['job_id'])
    for field in SERIAL_FIELDS:
        job.__dict__[field] = d[field]
    return job
# Work around sqlite3 DB locked error
while True:
try: models.Model.save(self, force_insert, force_update, using, update_fields)
except OperationalError:
try:
time.sleep(5)
newjob = BalsamJob.objects.get(pk=self.pk)
if newjob.version == self.version: break
except ObjectDoesNotExist: pass
except RecordModifiedError:
newjob = BalsamJob.objects.get(pk=self.pk)
logger.error(f'RecordModifiedError when saving {self.cute_id}')
logger.error(f'Trying to save:\n{str(self)}\nIn DB:\n{str(newjob)}\n')
raise
else: break
def __str__(self):
return f'''
......
......@@ -54,7 +54,7 @@ setup(
packages=find_packages(exclude=['docs','__pycache__','data','experiments','log',]),
install_requires=['django', 'django-concurrency'],
install_requires=['django', 'django-concurrency', 'pyzmq'],
include_package_data=True,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment