Commit 21af9c7c authored by Michael Salim's avatar Michael Salim
Browse files

Merge branch 'feature/db-init' into develop

parents 5250ce12 aabcfbd1
......@@ -19,3 +19,5 @@ experiments
docs/_build/*
docs/_static/*
*.egg-info
default_balsamdb
import argparse
from importlib.util import find_spec
import glob
import os
import sys
import signal
import subprocess
import time
# Must be set BEFORE importing balsam settings: it marks this process as the
# server daemon, so settings skips sqlite-client setup and ServerInfo guards
# against a double-started daemon.
os.environ['IS_SERVER_DAEMON']="True"
from balsam.django_config.settings import resolve_db_path
# NOTE(review): bare module import — presumably relies on this script running
# from its own directory; verify against how the daemon is launched
from serverinfo import ServerInfo

CHECK_PERIOD = 4    # seconds between liveness checks on the server subprocess
TERM_LINGER = 30    # seconds to linger after a termination request
PYTHON = sys.executable
SQLITE_SERVER = find_spec('balsam.django_config.sqlite_server').origin

# Shell command template per DB backend; postgres/mysql are placeholders
DB_COMMANDS = {
    'sqlite3' : f'{PYTHON} {SQLITE_SERVER}',
    'postgres': f'',
    'mysql' : f'',
}

# timestamp of the first termination request; 0 means "not terminating"
term_start = 0
def run(cmd):
    """Launch *cmd* through the shell, discarding all of its output.

    Returns the started subprocess.Popen handle; stdout and stderr are
    both routed to /dev/null.
    """
    return subprocess.Popen(
        cmd,
        shell=True,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.STDOUT,
    )
def stop(proc):
    """Terminate the server subprocess, escalating to SIGKILL if needed.

    Sends SIGTERM and waits up to 30 seconds for a graceful exit; if the
    process is still alive it is killed and then reaped, so no zombie is
    left behind (the original kill() path never waited on the process).
    """
    proc.terminate()
    print("Balsam server shutdown...", flush=True)
    try:
        proc.wait(timeout=30)
    except subprocess.TimeoutExpired:
        print("Warning: server did not quit gracefully")
        proc.kill()
        proc.wait()  # reap the killed child to avoid a zombie
def main(db_path):
    """Run the DB server daemon loop for the database directory *db_path*.

    Posts a fresh server address, launches the backend-specific server
    subprocess, restarts it (on a new address) if it dies unexpectedly, and
    shuts it down cleanly on SIGINT/SIGTERM/SIGUSR1.
    """
    serverinfo = ServerInfo(db_path)
    serverinfo.reset_server_address()
    server_type = serverinfo['db_type']

    db_cmd = f"BALSAM_DB_PATH={db_path} " + DB_COMMANDS[server_type].format(**serverinfo.data)
    print(f"\nStarting balsam DB server daemon for DB at {db_path}")
    proc = run(db_cmd)

    # On SIGUSR1, stop immediately ("balsam server --stop" does this)
    def handle_stop(signum, stack):
        # Stop the server, clear the posted address, and exit the daemon
        stop(proc)
        serverinfo.update({'address': None})
        sys.exit(0)

    signal.signal(signal.SIGINT, handle_stop)
    signal.signal(signal.SIGTERM, handle_stop)
    signal.signal(signal.SIGUSR1, handle_stop)

    # NOTE(review): term_start is never assigned anywhere in this module, so
    # this condition is always True and the loop only exits via a signal
    # handler — confirm whether a handler was meant to set term_start.
    while not term_start or time.time() - term_start < TERM_LINGER:
        try:
            # retcode is unused; wait() here doubles as the liveness check
            retcode = proc.wait(timeout=CHECK_PERIOD)
        except subprocess.TimeoutExpired:
            pass  # still running — keep watching
        else:
            # server died on its own: post a new address and relaunch
            print("\nserver process stopped unexpectedly; restarting")
            serverinfo.reset_server_address()
            db_cmd = f"BALSAM_DB_PATH={db_path} " + DB_COMMANDS[server_type].format(**serverinfo.data)
            proc = run(db_cmd)

    stop(proc)
    serverinfo.update({'address': None})
if __name__ == "__main__":
    # CLI: an optional single argument names the DB directory; otherwise the
    # path is resolved from BALSAM_DB_PATH / defaults by resolve_db_path().
    input_path = sys.argv[1] if len(sys.argv) == 2 else None
    db_path = resolve_db_path(input_path)
    main(db_path)
import json
import os
import socket
ADDRESS_FNAME = 'dbwriter_address'


class ServerInfo:
    """Persisted record of the DB write server's connection info.

    The info lives in a small JSON file (ADDRESS_FNAME) inside the Balsam DB
    directory, mapping keys such as 'db_type' and 'address'.  The file is
    re-read before every write so concurrent processes see the latest state.
    """

    def __init__(self, balsam_db_path):
        self.path = os.path.join(balsam_db_path, ADDRESS_FNAME)
        self.data = {}
        if os.path.exists(self.path):
            self.refresh()
        else:
            self.update(self.data)  # create an empty info file
        # A daemon must refuse to start while another server's address is posted
        if self.data.get('address') and os.environ.get('IS_SERVER_DAEMON')=='True':
            raise RuntimeError(f"A running server address is already posted at {self.path}\n"
                ' Use "balsam dbserver --stop" to shut it down.\n'
                ' If you are sure there is no running server process, the'
                ' daemon did not have a clean shutdown.\n Use "balsam'
                ' dbserver --reset <balsam_db_directory>" to reset the server file'
            )

    def get_free_port_and_address(self):
        """Bind an OS-assigned free port, release it, and return 'tcp://host:port'."""
        host = socket.gethostname()
        probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        probe.bind(('', 0))
        free_port = int(probe.getsockname()[1])
        probe.close()
        return f'tcp://{host}:{free_port}'

    def get_sqlite3_info(self):
        """Return fresh connection info for a sqlite3-backed write server."""
        return dict(db_type='sqlite3', address=self.get_free_port_and_address())

    def reset_server_address(self):
        """Compute a new address for this DB type and persist it."""
        db = self['db_type']
        make_info = getattr(self, f'get_{db}_info')
        self.update(make_info())

    def update(self, update_dict):
        """Merge *update_dict* into the record and rewrite the JSON file."""
        self.refresh()
        self.data.update(update_dict)
        with open(self.path, 'w') as fp:
            json.dump(self.data, fp)

    def get(self, key, default=None):
        """Dict-style lookup with a default (does not re-read the file)."""
        return self.data.get(key, default)

    def refresh(self):
        """Reload self.data from disk, if the info file exists."""
        if not os.path.exists(self.path):
            return
        with open(self.path, 'r') as fp:
            self.data = json.load(fp)

    def __getitem__(self, key):
        # Defensive: reload if data was somehow cleared to None
        if self.data is None:
            self.refresh()
        return self.data[key]

    def __setitem__(self, key, value):
        self.update({key: value})
......@@ -10,18 +10,159 @@ For the full list of settings and their values, see
https://docs.djangoproject.com/en/1.9/ref/settings/
"""
import os,logging
logger = logging.getLogger(__name__)
import os
import sys
from balsam.django_config import serverinfo
from balsam.user_settings import *
# ---------------
# DATABASE SETUP
# ---------------
def resolve_db_path(path=None):
    """Decide which Balsam DB directory to use.

    Precedence: an explicit *path* argument, then the BALSAM_DB_PATH
    environment variable, then the package default.  Explicit and
    environment-supplied paths must already exist on disk.
    """
    if path:
        assert os.path.exists(path)
        return path
    env_path = os.environ.get('BALSAM_DB_PATH')
    if env_path:
        assert os.path.exists(env_path)
        return env_path
    return default_db_path
def configure_db_backend(db_path):
    """Build the Django DATABASES dict for the Balsam DB at *db_path*.

    The engine, file name, and options are selected by the 'db_type'
    recorded in the server-info file; credentials default to empty strings.
    Raises KeyError for a db_type with no registered backend.
    """
    backends = {
        'sqlite3': {
            'engine': 'django.db.backends.sqlite3',
            'name': 'db.sqlite3',
            'options': {'timeout': 5000},
        },
    }
    info = serverinfo.ServerInfo(db_path)
    spec = backends[info['db_type']]
    db = dict(
        ENGINE=spec['engine'],
        NAME=os.path.join(db_path, spec['name']),
        OPTIONS=spec['options'],
        USER=info.get('user', ''),
        PASSWORD=info.get('password', ''),
    )
    return {'default': db}
CONCURRENCY_ENABLED = True

BALSAM_PATH = resolve_db_path()
DATABASES = configure_db_backend(BALSAM_PATH)

# Build paths inside the project like this: os.path.join(BASE_DIR, ...)

# --------------------
#  SUBDIRECTORY SETUP
# --------------------
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# NOTE(review): duplicate of the CONCURRENCY_ENABLED assignment above
# (likely a merge artifact) -- confirm which one is intended to stay
CONCURRENCY_ENABLED = True
LOGGING_DIRECTORY = os.path.join(BALSAM_PATH , 'log')
DATA_PATH = os.path.join(BALSAM_PATH ,'data')
BALSAM_WORK_DIRECTORY = DATA_PATH

# Create the DB/data/log directories on first import if missing
for d in [
    BALSAM_PATH ,
    DATA_PATH,
    LOGGING_DIRECTORY,
    BALSAM_WORK_DIRECTORY,
]:
    if not os.path.exists(d):
        os.makedirs(d)

# ----------------
#  LOGGING SETUP
# ----------------
# LOG_FILENAME, LOG_HANDLER_LEVEL, LOG_FILE_SIZE_LIMIT, LOG_BACKUP_COUNT
# come from balsam.user_settings (star-imported above)
HANDLER_FILE = os.path.join(LOGGING_DIRECTORY, LOG_FILENAME)
BALSAM_DB_CONFIG_LOG = os.path.join(LOGGING_DIRECTORY, "balsamdb-config.log")

LOGGING = {
    'version': 1,
    'disable_existing_loggers': False,
    'formatters': {
        'standard': {
            'format' : '%(asctime)s|%(process)d|%(levelname)8s|%(name)s:%(lineno)s] %(message)s',
            'datefmt' : "%d-%b-%Y %H:%M:%S"
        },
    },
    'handlers': {
        # stream handler; not attached to any logger below
        'console': {
            'class':'logging.StreamHandler',
            'formatter': 'standard',
            'level' : 'DEBUG'
        },
        # main rotating balsam log file
        'default': {
            'level':LOG_HANDLER_LEVEL,
            'class':'logging.handlers.RotatingFileHandler',
            'filename': HANDLER_FILE,
            'maxBytes': LOG_FILE_SIZE_LIMIT,
            'backupCount': LOG_BACKUP_COUNT,
            'formatter': 'standard',
        },
        # separate rotating file for DB-config / DB-writer messages
        'balsam-db-config': {
            'level':LOG_HANDLER_LEVEL,
            'class':'logging.handlers.RotatingFileHandler',
            'filename': BALSAM_DB_CONFIG_LOG,
            'maxBytes': LOG_FILE_SIZE_LIMIT,
            'backupCount': LOG_BACKUP_COUNT,
            'formatter': 'standard',
        },
        'django': {
            'level': LOG_HANDLER_LEVEL,
            'class':'logging.handlers.RotatingFileHandler',
            'filename': os.path.join(LOGGING_DIRECTORY, 'django.log'),
            'maxBytes': LOG_FILE_SIZE_LIMIT,
            'backupCount': LOG_BACKUP_COUNT,
            'formatter': 'standard',
        },
    },
    'loggers': {
        'django': {
            'handlers': ['django'],
            'level': 'DEBUG',
            'propagate': True,
        },
        'balsam': {
            'handlers': ['default'],
            'level': 'DEBUG',
            'propagate': True,
        },
        # DB plumbing goes only to the db-config file (no propagation to 'balsam')
        'balsam.django_config': {
            'handlers': ['balsam-db-config'],
            'level': 'DEBUG',
            'propagate': False,
        },
        'balsam.service.models': {
            'handlers': ['balsam-db-config'],
            'level': 'DEBUG',
            'propagate': False,
        },
    }
}
import logging
logger = logging.getLogger(__name__)


def log_uncaught_exceptions(exctype, value, tb, logger=logger):
    """Report an uncaught exception to both the module and 'console' loggers.

    Installed below as sys.excepthook so any crash is captured in the logs.
    """
    exc_info = (exctype, value, tb)
    logger.error(f"Uncaught Exception {exctype}: {value}", exc_info=exc_info)
    # Report a second time on the logger named 'console'
    console_logger = logging.getLogger('console')
    console_logger.error(f"Uncaught Exception {exctype}: {value}", exc_info=exc_info)


sys.excepthook = log_uncaught_exceptions
# -----------------------
#  SQLITE CLIENT SETUP
# ------------------------
# These env flags are set by the server daemon / write server before they
# import this settings module, so the client path is skipped for them.
is_server = os.environ.get('IS_BALSAM_SERVER')=='True'
is_daemon = os.environ.get('IS_SERVER_DAEMON')=='True'
using_sqlite = DATABASES['default']['ENGINE'].endswith('sqlite3')

# When set, SAVE_CLIENT routes ORM saves through the sqlite write server
SAVE_CLIENT = None
if using_sqlite and not (is_server or is_daemon):
    from balsam.django_config import sqlite_client
    SAVE_CLIENT = sqlite_client.Client(serverinfo.ServerInfo(BALSAM_PATH))
    # No posted address means no running write server: fall back to direct saves
    if SAVE_CLIENT.serverAddr is None:
        SAVE_CLIENT = None

# Quick-start development settings - unsuitable for production
# See https://docs.djangoproject.com/en/1.9/howto/deployment/checklist/

# SECURITY WARNING: keep the secret key used in production secret!
SECRET_KEY = '=gyp#o9ac0@w3&-^@a)j&f#_n-o=k%z2=g5u@z5+klmh_*hebj'
......@@ -31,7 +172,6 @@ DEBUG = True
ALLOWED_HOSTS = []
# Application definition
INSTALLED_APPS = [
......
from io import StringIO
from traceback import print_exc
import json
import os
import uuid
import zmq
from django.db.utils import OperationalError
from concurrency.exceptions import RecordModifiedError
REQ_TIMEOUT = 10000 # 10 seconds
REQ_RETRY = 3


class Client:
    """ZMQ REQ client that forwards BalsamJob saves to the sqlite write server.

    The server address is read from a ServerInfo record.  Each request is
    polled for up to REQ_TIMEOUT ms and retried up to REQ_RETRY times,
    re-reading the posted address between attempts (the daemon may have
    restarted the server on a new port).
    """

    def __init__(self, server_info):
        import logging
        self.logger = logging.getLogger(__name__)
        self.server_info = server_info
        self.serverAddr = self.server_info.get('address')
        self.first_message = True
        if self.serverAddr:
            try:
                response = self.send_request('TEST_ALIVE', timeout=300)
            # BUG FIX: was a bare "except:"; also the two RuntimeError
            # messages below were missing the f-prefix, so they printed
            # the literal text "{self.serverAddr}".
            except Exception:
                raise RuntimeError(f"Cannot reach server at {self.serverAddr}")
            else:
                if response != 'ACK':
                    self.logger.exception("sqlite client cannot reach DB write server")
                    raise RuntimeError(f"Cannot reach server at {self.serverAddr}")

    def send_request(self, msg, timeout=None):
        """Send *msg* to the write server and return the decoded reply.

        Raises OperationalError if no reply arrives after REQ_RETRY attempts.
        """
        if timeout is None:
            timeout = REQ_TIMEOUT
        if self.first_message:
            self.first_message = False
            self.logger.debug(f"Connected to DB write server at {self.serverAddr}")
        context = zmq.Context(1)
        poll = zmq.Poller()
        for retry in range(REQ_RETRY):
            client = context.socket(zmq.REQ)
            client.connect(self.serverAddr)
            poll.register(client, zmq.POLLIN)
            client.send_string(msg)
            socks = dict(poll.poll(timeout))
            if socks.get(client) == zmq.POLLIN:
                reply = client.recv()
                client.close()
                poll.unregister(client)
                context.term()
                self.logger.debug(f"received reply: {reply}")
                return reply.decode('utf-8')
            else:
                self.logger.debug("No response from server, retrying...")
                # LINGER=0: drop the unsent request immediately on close()
                client.setsockopt(zmq.LINGER, 0)
                client.close()
                poll.unregister(client)
                # The server may have restarted on a new address; re-read it
                self.server_info.refresh()
                self.serverAddr = self.server_info['address']
                self.logger.debug(f"Connecting to DB write server at {self.serverAddr}")
        context.term()
        raise OperationalError(f"Sqlite client save request failed after "
                               f"{REQ_RETRY} retries: is the server down?")

    def save(self, job, force_insert=False, force_update=False, using=None, update_fields=None):
        """Serialize *job*, send it to the write server, and apply the assigned id.

        Raises RecordModifiedError on an optimistic-concurrency conflict.
        """
        serial_data = job.serialize(force_insert=force_insert,
                                    force_update=force_update, using=using,
                                    update_fields=update_fields)
        self.logger.info(f"client: sending request for save of {job.cute_id}")
        response = self.send_request(serial_data)
        if response == 'ACK_RECORD_MODIFIED':
            raise RecordModifiedError(target=job)
        else:
            assert response.startswith('ACK_SAVE')
            job_id = uuid.UUID(response.split()[1])
            if job.job_id is None:
                job.job_id = job_id
            else:
                assert job.job_id == job_id
from io import StringIO
from traceback import print_exc
import json
import os
import logging
import time
import zmq
import signal
# These flags must be set BEFORE balsam settings are imported, so the
# settings module treats this process as the write server (no sqlite
# client setup, no daemon double-start guard).
os.environ['IS_BALSAM_SERVER']="True"
os.environ['IS_SERVER_DAEMON']="False"
os.environ['DJANGO_SETTINGS_MODULE'] = 'balsam.django_config.settings'
import django
django.setup()

from balsam.service.models import BalsamJob
from balsam.django_config import serverinfo
from concurrency.exceptions import RecordModifiedError

logger = logging.getLogger('balsam.django_config.sqlite_server')

SERVER_PERIOD = 1000  # ms poll interval for incoming client requests
TERM_LINGER = 3 # wait 3 sec after final save() to exit

# flipped by the signal handler or parent-death detection in server_main
terminate = False
class ZMQServer:
    """REP-socket server that applies serialized BalsamJob saves via the ORM."""

    def __init__(self, db_path):
        # connect to local sqlite DB thru ORM
        self.BalsamJob = BalsamJob
        self.info = serverinfo.ServerInfo(db_path)
        self.address = self.info['address']
        # posted address is 'tcp://hostname:port'; bind that port on all interfaces
        port = int(self.address.split(':')[2])
        self.context = zmq.Context(1)
        self.socket = self.context.socket(zmq.REP)
        self.socket.bind(f'tcp://*:{port}')
        logger.info(f"db_writer bound to socket @ {self.address}")

    def recv_request(self):
        """Return the next request string, or None if none arrived within SERVER_PERIOD."""
        if not self.socket.poll(timeout=SERVER_PERIOD):
            return None
        message = self.socket.recv().decode('utf-8')
        logger.debug(f'request: {message}')
        return message

    def send_reply(self, msg):
        """Send *msg* back to the waiting client."""
        self.socket.send_string(msg)
        logger.debug(f"Sent reply {msg}")

    def save(self, job_msg):
        """Deserialize *job_msg*, save the job through the ORM, and return (now, pk)."""
        payload = json.loads(job_msg)
        job = self.BalsamJob.from_dict(payload)
        job.save(payload['force_insert'],
                 payload['force_update'],
                 payload['using'],
                 payload['update_fields'])
        logger.info(f"db_writer Saved {job.cute_id}")
        return time.time(), job.pk
def server_main(db_path):
    """Run the DB write server event loop until terminated.

    Replies 'ACK_SAVE <id>' or 'ACK_RECORD_MODIFIED' to save requests and a
    plain 'ACK' to anything else (e.g. TEST_ALIVE).  After a termination
    condition, keeps serving TERM_LINGER seconds past the last save so
    in-flight clients still get replies.
    """
    logger.debug("hello from server_main")
    parent_pid = os.getppid()
    global terminate

    def handler(signum, stack):
        # Signal handler: request a graceful shutdown
        global terminate
        terminate = True
        logger.debug("Got sigterm; will shut down soon")

    signal.signal(signal.SIGINT, handler)
    signal.signal(signal.SIGTERM, handler)

    server = ZMQServer(db_path)
    last_save_time = time.time()

    while not terminate or time.time() - last_save_time < TERM_LINGER:
        message = server.recv_request()
        if terminate:
            logger.debug(f"shut down in {TERM_LINGER - (time.time()-last_save_time)} seconds")
        if message is None:
            # No request this poll period; quit soon if the daemon parent died
            if os.getppid() != parent_pid:
                logger.info("detected parent died; server quitting soon")
                terminate = True
        elif 'job_id' in message:
            # Serialized job payload: persist it and acknowledge
            try:
                last_save_time, job_id = server.save(message)
            except RecordModifiedError:
                server.send_reply("ACK_RECORD_MODIFIED")
                logger.debug("sending ACK_RECORD_MODIFIED")
            else:
                server.send_reply(f"ACK_SAVE {job_id}")
                logger.debug("sending ACK_SAVE")
        else:
            # Any other request (e.g. TEST_ALIVE handshake) gets a bare ACK
            logger.debug("sending ACK")
            server.send_reply("ACK")
if __name__ == "__main__":
    # Entry point: DB directory comes from the daemon via BALSAM_DB_PATH
    db_path = os.environ['BALSAM_DB_PATH']
    try:
        server_main(db_path)
    # FIX: was a bare "except:" that also swallowed SystemExit and
    # KeyboardInterrupt; the log call also carried a pointless f-prefix
    # on a %-style format string.
    except Exception:
        buf = StringIO()
        print_exc(file=buf)
        logger.exception("db_writer Uncaught exception:\n%s", buf.getvalue())
    finally:
        logger.info("exiting server main")
......@@ -38,7 +38,6 @@ from balsam.launcher import transitions
from balsam.launcher import worker
from balsam.launcher import runners
from balsam.launcher.exceptions import *
from balsam.service import db_writer
ALMOST_RUNNABLE_STATES = ['READY','STAGED_IN']
RUNNABLE_STATES = ['PREPROCESSED', 'RESTART_READY']
......@@ -202,7 +201,7 @@ def main(args, transition_pool, runner_group, job_source):
logger.info("No jobs to process. Exiting main loop now.")
break
def on_exit(runner_group, transition_pool, job_source, writer_proc):
def on_exit(runner_group, transition_pool, job_source):
'''Exit cleanup'''
global HANDLING_EXIT
if HANDLING_EXIT: return
......@@ -215,8 +214,6 @@ def on_exit(runner_group, transition_pool, job_source, writer_proc):
logger.debug("on_exit: send end message to transition threads")
transition_pool.end_and_wait()
client = db_writer.ZMQClient()
client.term_server()
logger.debug("on_exit: Launcher exit graceful\n\n")
sys.exit(0)
......@@ -250,22 +247,10 @@ def detect_dead_runners(job_source):
logger.info(f'Picked up dead running job {job.cute_id}: marking RESTART_READY')
job.update_state('RESTART_READY', 'Detected dead runner')
def launch_db_writer_process():
import multiprocessing
INSTALL_PATH = settings.INSTALL_PATH
path = os.path.join(INSTALL_PATH, db_writer.SOCKFILE_NAME)
if os.path.exists(path):
os.remove(path)
writer_proc = multiprocessing.Process(target=db_writer.server_main)
writer_proc.daemon = True
writer_proc.start()
return writer_proc
if __name__ == "__main__":
args = get_args()
writer_proc = launch_db_writer_process()
job_source = jobreader.JobReader.from_config(args)
job_source.refresh_from_db()
transition_pool = transitions.TransitionProcessPool()
......@@ -276,10 +261,10 @@ if __name__ == "__main__":
detect_dead_runners(job_source)
handl = lambda a,b: on_exit(runner_group, transition_pool, job_source, writer_proc)
handl = lambda a,b: on_exit(runner_group, transition_pool, job_source)
signal.signal(signal.SIGINT, handl)
signal.signal(signal.SIGTERM, handl)
signal.signal(signal.SIGHUP, handl)
main(args, transition_pool, runner_group, job_source)
on_exit(runner_group, transition_pool, job_source, writer_proc)
on_exit(runner_group, transition_pool, job_source)
......@@ -48,15 +48,12 @@ logger = logging.getLogger('balsam.launcher.transitions')
class DummyLock:
    """No-op stand-in for multiprocessing.Lock on platforms that don't need one.

    Both methods do nothing and return None, matching the Lock call sites.
    """

    def acquire(self):
        """Do nothing."""

    def release(self):
        """Do nothing."""
if sys.platform.startswith('darwin'):
LockClass = multiprocessing.Lock
elif sys.platform.startswith('win32'):
LockClass = multiprocessing.Lock