Commit f42b6739 authored by Michael Salim

Init Postgres backend works; added "balsam which"; Can run sqlite/Postgres backends from "balsam dbserver"; Upped timeouts/queue depths for ZMQ Sqlite server
parent 91279287
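
For context, the workflow this commit enables looks roughly like the following. Only subcommands visible in the diff below are used; the DB path is a placeholder:

    balsam dbserver ~/myBalsamDB          # launch the server daemon for the sqlite3 or postgres backend at this path
    balsam which                          # pretty-print the active DATABASES['default'] settings
    balsam dbserver --stop                # send SIGUSR1 to every running db_daemon process
    balsam dbserver --reset ~/myBalsamDB  # clear a stale address from the server-info file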
 import argparse
 from importlib.util import find_spec
 import glob
+import getpass
 import os
 import sys
 import signal
@@ -18,26 +19,46 @@ PYTHON = sys.executable
 SQLITE_SERVER = find_spec('balsam.django_config.sqlite_server').origin
 DB_COMMANDS = {
     'sqlite3' : f'{PYTHON} {SQLITE_SERVER}',
-    'postgres': f'',
+    'postgres': f'pg_ctl -D {{pg_db_path}} -w start',
     'mysql'   : f'',
 }
 term_start = 0

 def run(cmd):
     proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.DEVNULL,
                             stderr=subprocess.STDOUT)
     return proc

-def stop(proc):
-    proc.terminate()
+def stop(proc, serverinfo):
     print("Balsam server shutdown...", flush=True)
-    try: retcode = proc.wait(timeout=30)
-    except subprocess.TimeoutExpired:
-        print("Warning: server did not quit gracefully")
-        proc.kill()
+    if serverinfo['db_type'] == 'postgres':
+        cmd = f'pg_ctl -D {{pg_db_path}} -w stop'.format(**serverinfo.data)
+        print(cmd)
+        proc = subprocess.Popen(cmd, shell=True)
+        time.sleep(2)
+    else:
+        proc.terminate()
+        try: retcode = proc.wait(timeout=30)
+        except subprocess.TimeoutExpired:
+            print("Warning: server did not quit gracefully")
+            proc.kill()
+
+def wait(proc, serverinfo):
+    if serverinfo['db_type'] == 'sqlite3':
+        retcode = proc.wait(timeout=CHECK_PERIOD)
+    elif serverinfo['db_type'] == 'postgres':
+        time.sleep(CHECK_PERIOD)
+        user = getpass.getuser()
+        proc = subprocess.Popen(f'ps aux | grep {user} | grep postgres | '
+                                'grep -v grep', shell=True, stdout=subprocess.PIPE,
+                                stderr=subprocess.STDOUT)
+        stdout, _ = proc.communicate()
+        lines = stdout.decode('utf-8').split('\n')
+        # A surviving postgres process means the server is still up: mimic
+        # proc.wait() by raising TimeoutExpired, which the caller treats as
+        # "still running"
+        if any(line.strip() for line in lines):
+            raise subprocess.TimeoutExpired('cmd', CHECK_PERIOD)

 def main(db_path):
     serverinfo = ServerInfo(db_path)
@@ -46,13 +67,14 @@ def main(db_path):
     db_cmd = f"BALSAM_DB_PATH={db_path} " + DB_COMMANDS[server_type].format(**serverinfo.data)
     print(f"\nStarting balsam DB server daemon for DB at {db_path}")
+    print(db_cmd)
     proc = run(db_cmd)

     # On SIGUSR1, stop immediately ("balsam server --stop" does this)
     def handle_stop(signum, stack):
-        stop(proc)
-        serverinfo.update({'address': None})
+        stop(proc, serverinfo)
+        serverinfo.update({'address': None, 'host': None, 'port': None})
         sys.exit(0)

     signal.signal(signal.SIGINT, handle_stop)
@@ -61,17 +83,18 @@ def main(db_path):
     while not term_start or time.time() - term_start < TERM_LINGER:
         try:
-            retcode = proc.wait(timeout=CHECK_PERIOD)
+            wait(proc, serverinfo)
         except subprocess.TimeoutExpired:
             pass
         else:
             print("\nserver process stopped unexpectedly; restarting")
             serverinfo.reset_server_address()
             db_cmd = f"BALSAM_DB_PATH={db_path} " + DB_COMMANDS[server_type].format(**serverinfo.data)
+            print(db_cmd)
             proc = run(db_cmd)

-    stop(proc)
-    serverinfo.update({'address': None})
+    stop(proc, serverinfo)
+    serverinfo.update({'address': None, 'host': None, 'port': None})

 if __name__ == "__main__":
     input_path = sys.argv[1] if len(sys.argv) == 2 else None
...
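
The wait() helper added above decides whether the postgres server is still alive by grepping ps output. For reference, a more direct probe is pg_ctl status, which exits 0 when a server is running in the data directory and nonzero when it is not (3 on modern PostgreSQL releases). A minimal sketch, assuming pg_db_path is recorded in the server-info file as in the serverinfo changes below:

    import subprocess

    def postgres_alive(pg_db_path):
        # pg_ctl's exit status already encodes liveness; no output parsing needed
        ret = subprocess.call(f'pg_ctl -D {pg_db_path} status', shell=True,
                              stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)
        return ret == 0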
@@ -24,6 +24,13 @@ class ServerInfo:
                 ' daemon did not have a clean shutdown.\n Use "balsam'
                 ' dbserver --reset <balsam_db_directory>" to reset the server file'
                 )
+        if self.data.get('host') and os.environ.get('IS_SERVER_DAEMON')=='True':
+            raise RuntimeError(f"A running server address is already posted at {self.path}\n"
+                ' Use "balsam dbserver --stop" to shut it down.\n'
+                ' If you are sure there is no running server process, the'
+                ' daemon did not have a clean shutdown.\n Use "balsam'
+                ' dbserver --reset <balsam_db_directory>" to reset the server file'
+                )

     def get_free_port_and_address(self):
         hostname = socket.gethostname()
@@ -50,13 +57,32 @@ class ServerInfo:
     def get_postgres_info(self):
         hostname = socket.gethostname()
         port = self.get_free_port()
-        info = dict(host=hostname, port=port)
+        pg_db_path = os.path.join(self['balsamdb_path'], 'balsamdb')
+        info = dict(host=hostname, port=port, pg_db_path=pg_db_path)
         return info

+    def update_sqlite3_config(self):
+        pass
+
+    def update_postgres_config(self):
+        conf_path = os.path.join(self['pg_db_path'], 'postgresql.conf')
+        config = open(conf_path).read()
+        with open(f"{conf_path}.new", 'w') as fp:
+            for line in config.split('\n'):
+                if line.startswith('port'):
+                    port_line = f"port={self['port']} # auto-set by balsam db\n"
+                    fp.write(port_line)
+                else:
+                    fp.write(line + "\n")
+        os.rename(f"{conf_path}.new", conf_path)
+
     def reset_server_address(self):
         db = self['db_type']
         info = getattr(self, f'get_{db}_info')()
         self.update(info)
+        getattr(self, f'update_{db}_config')()

     def update(self, update_dict):
         self.refresh()
...
@@ -58,7 +58,7 @@ def configure_db_backend(db_path):
     db = dict(ENGINE=ENGINES[db_type], NAME=db_name,
               OPTIONS=OPTIONS[db_type], USER=user, PASSWORD=password,
-              HOST=host, PORT=port)
+              HOST=host, PORT=port, CONN_MAX_AGE=60)
     DATABASES = {'default':db}
     return DATABASES
@@ -120,21 +120,21 @@ LOGGING = {
             'backupCount': LOG_BACKUP_COUNT,
             'formatter': 'standard',
         },
-        'django': {
-            'level': LOG_HANDLER_LEVEL,
-            'class':'logging.handlers.RotatingFileHandler',
-            'filename': os.path.join(LOGGING_DIRECTORY, 'django.log'),
-            'maxBytes': LOG_FILE_SIZE_LIMIT,
-            'backupCount': LOG_BACKUP_COUNT,
-            'formatter': 'standard',
-        },
+        #'django': {
+        #    'level': LOG_HANDLER_LEVEL,
+        #    'class':'logging.handlers.RotatingFileHandler',
+        #    'filename': os.path.join(LOGGING_DIRECTORY, 'django.log'),
+        #    'maxBytes': LOG_FILE_SIZE_LIMIT,
+        #    'backupCount': LOG_BACKUP_COUNT,
+        #    'formatter': 'standard',
+        #},
     },
     'loggers': {
-        'django': {
-            'handlers': ['django'],
-            'level': 'DEBUG',
-            'propagate': True,
-        },
+        #'django': {
+        #    'handlers': ['django'],
+        #    'level': 'DEBUG',
+        #    'propagate': True,
+        #},
         'balsam': {
             'handlers': ['default'],
             'level': 'DEBUG',
...
@@ -10,7 +10,7 @@ from concurrency.exceptions import RecordModifiedError
 # These are ridiculously high to benchmark
 # Should be more like 5-10 sec, 3-4 retry
-REQ_TIMEOUT = 60000 # 60 seconds
+REQ_TIMEOUT = 300000 # 5 minutes
 REQ_RETRY = 56
@@ -23,7 +23,7 @@ class Client:
         self.first_message = True
         if self.serverAddr:
             try:
-                response = self.send_request('TEST_ALIVE', timeout=3000)
+                response = self.send_request('TEST_ALIVE', timeout=30000)
             except:
                 raise RuntimeError(f"Cannot reach server at {self.serverAddr}")
             else:
...
@@ -34,7 +34,12 @@ class ZMQServer:
         self.address = self.info['address']
         port = int(self.address.split(':')[2])

-        self.context = zmq.Context(1)
+        self.context = zmq.Context(4)
+        self.context.setsockopt(zmq.BACKLOG, 32768)
+        self.context.setsockopt(zmq.SNDHWM, 32768)
+        self.context.setsockopt(zmq.RCVHWM, 32768)
+        self.context.setsockopt(zmq.SNDBUF, 1000000000)
+        self.context.setsockopt(zmq.RCVBUF, 1000000000)
         self.socket = self.context.socket(zmq.REP)
         self.socket.bind(f'tcp://*:{port}')
         logger.info(f"db_writer bound to socket @ {self.address}")
...
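
The context options above deepen the server's queues so bursts of concurrent inserts aren't dropped; they pair with the raised REQ_TIMEOUT on the client side. A minimal pyzmq sketch of the matching REQ-client pattern (the address is a placeholder):

    import zmq

    context = zmq.Context()
    socket = context.socket(zmq.REQ)
    socket.setsockopt(zmq.SNDHWM, 32768)      # deep send queue, mirroring the server
    socket.setsockopt(zmq.RCVTIMEO, 300000)   # 5 min receive timeout, like REQ_TIMEOUT
    socket.setsockopt(zmq.LINGER, 0)          # don't block process exit on unsent messages
    socket.connect('tcp://server-host:5555')  # placeholder address

    socket.send_string('TEST_ALIVE')
    try:
        reply = socket.recv_string()          # raises zmq.error.Again on timeout
    except zmq.error.Again:
        raise RuntimeError("server did not reply within the timeout")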
@@ -4,7 +4,7 @@ import argparse
 import sys
 from balsam.scripts.cli_commands import newapp,newjob,newdep,ls,modify,rm,qsub
 from balsam.scripts.cli_commands import kill,mkchild,launcher,service,make_dummies
-from balsam.scripts.cli_commands import dbserver, init
+from balsam.scripts.cli_commands import dbserver, init, which

 def main():
     parser = make_parser()
@@ -335,6 +335,11 @@ def make_parser():
     parser_dummy = subparsers.add_parser('make_dummies')
     parser_dummy.add_argument('num', type=int)
     parser_dummy.set_defaults(func=make_dummies)

+    # WHICH
+    # ---------
+    parser_which = subparsers.add_parser('which')
+    parser_which.set_defaults(func=which)
+
     return parser
...
@@ -231,11 +231,11 @@ def rm(args):
     # Are we removing jobs or apps?
     if objects_name.startswith('job'): cls = Job
     elif objects_name.startswith('app'): cls = AppDef
-    objects = cls.objects.all()
+    objects = cls.objects

     # Filter: all objects, by name-match (multiple), or by ID (unique)?
     if deleteall:
-        deletion_objs = objects
+        deletion_objs = objects.all()
         message = f"ALL {objects_name}"
     elif name:
         deletion_objs = objects.filter(name__icontains=name)
@@ -259,10 +259,8 @@ def rm(args):
         return

     # Actually delete things here
-    for obj in deletion_objs:
-        msg = f"Deleted {objects_name[:-1]} {obj.cute_id}"
-        obj.delete()
-        print(msg)
+    deletion_objs.delete()
+    print("Deleted.")

 def qsub(args):
@@ -371,7 +369,7 @@ def dbserver(args):
             sys.exit(0)
         else:
             info = serverinfo.ServerInfo(args.reset)
-            info.update({'address': None})
+            info.update({'address': None, 'host': None, 'port': None})
             print("Reset done")
             sys.exit(0)
@@ -380,10 +378,10 @@ def dbserver(args):
         if not server_pids:
             print(f"No db_daemon processes running under {getpass.getuser()}")
         else:
-            assert len(server_pids) == 1
-            pid = server_pids[0]
-            print(f"Stopping db_daemon {pid}")
-            os.kill(pid, signal.SIGUSR1)
+            assert len(server_pids) >= 1
+            for pid in server_pids:
+                print(f"Stopping db_daemon {pid}")
+                os.kill(pid, signal.SIGUSR1)
     else:
         path = args.path
         if path: cmd = [sys.executable, fname, path]
@@ -415,6 +413,13 @@ def init(args):
     p.wait()

+def which(args):
+    os.environ['DJANGO_SETTINGS_MODULE'] = 'balsam.django_config.settings'
+    django.setup()
+    from django.conf import settings
+    import pprint
+    pprint.pprint(settings.DATABASES['default'])
+
 def make_dummies(args):
     os.environ['DJANGO_SETTINGS_MODULE'] = 'balsam.django_config.settings'
     django.setup()
...
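
A note on the rm() change above: swapping the per-object loop for deletion_objs.delete() issues a single bulk SQL DELETE instead of one round-trip per row, at the cost of no longer printing each cute_id. On Django 1.9+ the count can be recovered, since QuerySet.delete() returns it (a sketch using the same deletion_objs queryset):

    # QuerySet.delete() returns (total_rows_deleted, {model_label: count, ...})
    total, per_model = deletion_objs.delete()
    print(f"Deleted {total} {objects_name}")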
 from getpass import getuser
 import os
 import sys
+from pprint import pprint
 import time
 import subprocess
 from balsam.django_config.serverinfo import ServerInfo
@@ -14,12 +15,21 @@ def postgres_init(serverInfo):
     p = subprocess.Popen(f'initdb -D {db_path} -U $USER', shell=True)
     retcode = p.wait()
     if retcode != 0: raise RuntimeError("initdb failed")

+    with open(os.path.join(db_path, 'postgresql.conf'), 'a') as fp:
+        fp.write("listen_addresses = '*' # appended from balsam init\n")
+        fp.write('port=0 # appended from balsam init\n')
+        fp.write('max_connections=128 # appended from balsam init\n')
+        fp.write('shared_buffers=2GB # appended from balsam init\n')
+        fp.write('synchronous_commit=off # appended from balsam init\n')
+        fp.write('wal_writer_delay=400ms # appended from balsam init\n')
+
+    with open(os.path.join(db_path, 'pg_hba.conf'), 'a') as fp:
+        fp.write(f"host all all 0.0.0.0/0 trust\n")
+
     serverInfo.update({'user' : getuser()})
     serverInfo.reset_server_address()
     port = serverInfo['port']

-    with open(os.path.join(db_path, 'postgresql.conf'), 'a') as fp:
-        fp.write(f'port={port} # appended from balsam init\n')
-
     serv_proc = subprocess.Popen(f'pg_ctl -D {db_path} -w start', shell=True)
     time.sleep(2)
@@ -31,8 +41,8 @@ def postgres_post(serverInfo):
     db_path = serverInfo['balsamdb_path']
     db_path = os.path.join(db_path, 'balsamdb')
     serv_proc = subprocess.Popen(f'pg_ctl -D {db_path} -w stop', shell=True)
-    serv_proc.wait()
+    time.sleep(1)
+    serverInfo.update({'host': None, 'port': None})

 def run_migrations():
     import django
@@ -45,16 +55,21 @@ def run_migrations():
     print(f"DB settings:", settings.DATABASES['default'])
-    db_path = db.connection.settings_dict['NAME']
-    print(f"Setting up new balsam database: {db_path}")
+    db_info = db.connection.settings_dict['NAME']
+    print(f"Setting up new balsam database:")
+    pprint(db_info, width=60)

     call_command('makemigrations', interactive=False, verbosity=0)
     call_command('migrate', interactive=False, verbosity=0)

-    new_path = settings.DATABASES['default']['NAME']
-    if os.path.exists(new_path):
-        print(f"Set up new DB at {new_path}")
+    try:
+        from balsam.service.models import BalsamJob
+        j = BalsamJob()
+        j.save()
+        j.delete()
+    except:
+        raise RuntimeError("BalsamJob table not properly created")
     else:
-        raise RuntimeError(f"Failed to created DB at {new_path}")
+        print("BalsamJob table created successfully")

 if __name__ == "__main__":
     serverInfo = ServerInfo(sys.argv[1])
@@ -70,3 +85,4 @@ if __name__ == "__main__":
     run_migrations()
     if db_type == 'postgres':
         postgres_post(serverInfo)
+    print("OK")
...
@@ -12,7 +12,6 @@ class BalsamTestCase(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         test_db_path = os.environ['BALSAM_DB_PATH']
-        assert test_db_path in db.connection.settings_dict['NAME']
         assert 'test' in test_db_path
         call_command('makemigrations',interactive=False,verbosity=0)
@@ -29,9 +28,6 @@ class BalsamTestCase(unittest.TestCase):
         pass # to be implemented by test cases

     def tearDown(self):
-        test_db_path = os.environ['BALSAM_DB_PATH']
-        if not test_db_path in db.connection.settings_dict['NAME']:
-            raise RuntimeError("Test DB not configured")
         call_command('flush',interactive=False,verbosity=0)
...
@@ -19,7 +19,10 @@ class TestInsertion(BalsamTestCase):
         self.launcherInfo = util.launcher_info()
         max_workers = self.launcherInfo.num_workers

-        worker_counts = takewhile(lambda x: x<=max_workers, (2**i for i in range(20)))
+        worker_counts = list(takewhile(lambda x: x<=max_workers, (2**i for i in range(20))))
+        if max_workers not in worker_counts:
+            worker_counts.append(max_workers)
+        worker_counts = list(reversed(worker_counts))
         #ranks_per_node = [4, 8, 16, 32]
         ranks_per_node = [32]
         self.experiments = product(worker_counts, ranks_per_node)
...
+# BENCHMARK: test_concurrent_mpi_insert
+# Host: thetamom1
+# COBALT_BLOCKNAME: 2810-2813,2816,3171,3178-3179,4253-4255,4318,4408-4409,4446,4579
+# COBALT_PARTNAME: 2810-2813,2816,3171,3178-3179,4253-4255,4318,4408-4409,4446,4579
+# COBALT_JOBID: 181696
+# COBALT_PARTSIZE: 16
+# COBALT_NODEFILE: /var/tmp/cobalt.181696
+# COBALT_JOBSIZE: 16
+# COBALT_BLOCKSIZE: 16
+# Each rank simultaneously calls dag.add_job (num_ranks simultaneous insertions)
+# measure total time for entire aprun (including all aprun/python overheads)
+# db_writer is running on thetalogin6, aprun from thetamom1
+# num_nodes  ranks_per_node  num_ranks  total_time_sec
+# -----------------------------------------------------
+  16         32              512        62.640
+  16         16              256        38.550
+  16          8              128        25.280
+  16          4               64        19.750
+   8         32              256        45.760
+   8         16              128        26.060
+   8          8               64        18.790
+   8          4               32        15.000
+   4         32              128        34.560
+   4         16               64        21.290
+   4          8               32        16.400
+   4          4               16        13.780
+   2         32               64        28.300
+   2         16               32        19.300
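
For reference, the measurement described in the header comments above amounts to every rank inserting one job simultaneously, with the total wall time taken around the whole aprun invocation. A minimal mpi4py sketch of the per-rank driver, with an internal timer for comparison (the add_job keyword fields are illustrative):

    from mpi4py import MPI
    from balsam.launcher import dag

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()

    comm.Barrier()                  # line up all ranks before the clock starts
    t0 = MPI.Wtime()
    dag.add_job(name=f'job{rank}')  # one simultaneous insertion per rank
    comm.Barrier()                  # wait until every insertion has completed
    t1 = MPI.Wtime()

    if rank == 0:
        print(f"insert time across {comm.Get_size()} ranks: {t1 - t0:.3f} sec")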