Commit 0541d3ae authored by Murali Emani

adding Uno codes

parent f70af482
Data/
models/
# UnoMT in PyTorch
A multi-tasking version of Uno (drug response regression, cell line classification, etc.) implemented in PyTorch.
https://github.com/xduan7/UnoPytorch
## Todos
* More labels for the network, e.g. drug labels;
* Fix the dataloader hanging problem when num_workers is set to more than 0;
* Better pre-processing for drug descriptor integer features;
* Network regularization with weight decay and/or dropout (see the sketch after this list);
* Hyper-parameter search;
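For the regularization item, the existing `drop` and `l2_regularization` hyper-parameters map naturally onto dropout layers and the optimizer's weight decay. A minimal sketch (illustrative only, not the benchmark code):
```
import torch
import torch.nn as nn

# Dropout inside a dense block ('drop' defaults to 0.1) and L2 regularization
# via the optimizer's weight_decay ('l2_regularization' defaults to 1e-05).
block = nn.Sequential(nn.Linear(1024, 1024), nn.Dropout(p=0.1), nn.ReLU())
optimizer = torch.optim.SGD(block.parameters(), lr=0.01, weight_decay=1e-5)
```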
## Prerequisites
```
Python 3.6.4
PyTorch 0.4.1
SciPy 1.1.0
pandas 0.23.4
Scikit-Learn 0.19.1
urllib3 1.23
joblib 0.12.2
```
The default network structure is shown below:
<img src="./images/default_network.jpg" width="100%">
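A minimal PyTorch sketch of this shared-encoder, multi-head layout, with layer sizes taken from the example run below (the actual RespNet additionally uses residual blocks, dropout, and the cell line / drug classification heads):
```
import torch
import torch.nn as nn

# Gene (RNA-seq) and drug feature encoders shared across all tasks.
gene_encoder = nn.Sequential(
    nn.Linear(942, 1024), nn.ReLU(),
    nn.Linear(1024, 1024), nn.ReLU(),
    nn.Linear(1024, 512))
drug_encoder = nn.Sequential(
    nn.Linear(4688, 4096), nn.ReLU(),
    nn.Linear(4096, 4096), nn.ReLU(),
    nn.Linear(4096, 1024))

# Drug response head on the concatenated [gene latent, drug latent, concentration].
resp_head = nn.Sequential(
    nn.Linear(512 + 1024 + 1, 2048), nn.ReLU(),
    nn.Linear(2048, 1))
```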
An example of the program output for training on NCI60 and validation on all other data sources is shown below:
```
python unoMT_baseline_pytorch.py --resp_val_start_epoch 2 --epochs 5
Importing candle utils for pytorch
Created unoMT benchmark
Configuration file: ./unoMT_default_model.txt
{'autoencoder_init': True,
'cl_clf_layer_dim': 256,
'cl_clf_lr': 0.008,
'cl_clf_num_layers': 2,
'cl_clf_opt': 'SGD',
'disjoint_cells': True,
'disjoint_drugs': False,
'drop': 0.1,
'drug_feature_usage': 'both',
'drug_latent_dim': 1024,
'drug_layer_dim': 4096,
'drug_num_layers': 2,
'drug_qed_activation': 'sigmoid',
'drug_qed_layer_dim': 1024,
'drug_qed_loss_func': 'mse',
'drug_qed_lr': 0.01,
'drug_qed_num_layers': 2,
'drug_qed_opt': 'SGD',
'drug_target_layer_dim': 1024,
'drug_target_lr': 0.002,
'drug_target_num_layers': 2,
'drug_target_opt': 'SGD',
'dscptr_nan_threshold': 0.0,
'dscptr_scaling': 'std',
'early_stop_patience': 5,
'epochs': 1000,
'gene_latent_dim': 512,
'gene_layer_dim': 1024,
'gene_num_layers': 2,
'grth_scaling': 'none',
'l2_regularization': 1e-05,
'lr_decay_factor': 0.98,
'max_num_batches': 1000,
'qed_scaling': 'none',
'resp_activation': 'none',
'resp_layer_dim': 2048,
'resp_loss_func': 'mse',
'resp_lr': 1e-05,
'resp_num_blocks': 4,
'resp_num_layers': 2,
'resp_num_layers_per_block': 2,
'resp_opt': 'SGD',
'resp_val_start_epoch': 0,
'rnaseq_feature_usage': 'combat',
'rnaseq_scaling': 'std',
'rng_seed': 0,
'save_path': 'save/unoMT',
'solr_root': '',
'timeout': 3600,
'train_sources': 'NCI60',
'trn_batch_size': 32,
'val_batch_size': 256,
'val_sources': ['NCI60', 'CTRP', 'GDSC', 'CCLE', 'gCSI'],
'val_split': 0.2}
Params:
{'autoencoder_init': True,
'cl_clf_layer_dim': 256,
'cl_clf_lr': 0.008,
'cl_clf_num_layers': 2,
'cl_clf_opt': 'SGD',
'datatype': <class 'numpy.float32'>,
'disjoint_cells': True,
'disjoint_drugs': False,
'drop': 0.1,
'drug_feature_usage': 'both',
'drug_latent_dim': 1024,
'drug_layer_dim': 4096,
'drug_num_layers': 2,
'drug_qed_activation': 'sigmoid',
'drug_qed_layer_dim': 1024,
'drug_qed_loss_func': 'mse',
'drug_qed_lr': 0.01,
'drug_qed_num_layers': 2,
'drug_qed_opt': 'SGD',
'drug_target_layer_dim': 1024,
'drug_target_lr': 0.002,
'drug_target_num_layers': 2,
'drug_target_opt': 'SGD',
'dscptr_nan_threshold': 0.0,
'dscptr_scaling': 'std',
'early_stop_patience': 5,
'epochs': 5,
'experiment_id': 'EXP000',
'gene_latent_dim': 512,
'gene_layer_dim': 1024,
'gene_num_layers': 2,
'gpus': [],
'grth_scaling': 'none',
'l2_regularization': 1e-05,
'logfile': None,
'lr_decay_factor': 0.98,
'max_num_batches': 1000,
'multi_gpu': False,
'no_cuda': False,
'output_dir': '/home/jamal/Code/ECP/CANDLE/Benchmarks/Pilot1/UnoMT/Output/EXP000/RUN000',
'qed_scaling': 'none',
'resp_activation': 'none',
'resp_layer_dim': 2048,
'resp_loss_func': 'mse',
'resp_lr': 1e-05,
'resp_num_blocks': 4,
'resp_num_layers': 2,
'resp_num_layers_per_block': 2,
'resp_opt': 'SGD',
'resp_val_start_epoch': 2,
'rnaseq_feature_usage': 'combat',
'rnaseq_scaling': 'std',
'rng_seed': 0,
'run_id': 'RUN000',
'save_path': 'save/unoMT',
'shuffle': False,
'solr_root': '',
'timeout': 3600,
'train_bool': True,
'train_sources': 'NCI60',
'trn_batch_size': 32,
'val_batch_size': 256,
'val_sources': ['NCI60', 'CTRP', 'GDSC', 'CCLE', 'gCSI'],
'val_split': 0.2,
'verbose': None}
Parameters initialized
Failed to split NCI60 cells in stratified way. Splitting randomly ...
Failed to split NCI60 cells in stratified way. Splitting randomly ...
Failed to split CCLE cells in stratified way. Splitting randomly ...
Failed to split CCLE drugs stratified on growth and correlation. Splitting solely on avg growth ...
Failed to split gCSI drugs stratified on growth and correlation. Splitting solely on avg growth ...
RespNet(
(_RespNet__gene_encoder): Sequential(
(dense_0): Linear(in_features=942, out_features=1024, bias=True)
(relu_0): ReLU()
(dense_1): Linear(in_features=1024, out_features=1024, bias=True)
(relu_1): ReLU()
(dense_2): Linear(in_features=1024, out_features=512, bias=True)
)
(_RespNet__drug_encoder): Sequential(
(dense_0): Linear(in_features=4688, out_features=4096, bias=True)
(relu_0): ReLU()
(dense_1): Linear(in_features=4096, out_features=4096, bias=True)
(relu_1): ReLU()
(dense_2): Linear(in_features=4096, out_features=1024, bias=True)
)
(_RespNet__resp_net): Sequential(
(dense_0): Linear(in_features=1537, out_features=2048, bias=True)
(activation_0): ReLU()
(residual_block_0): ResBlock(
(block): Sequential(
(res_dense_0): Linear(in_features=2048, out_features=2048, bias=True)
(res_dropout_0): Dropout(p=0.1)
(res_relu_0): ReLU()
(res_dense_1): Linear(in_features=2048, out_features=2048, bias=True)
(res_dropout_1): Dropout(p=0.1)
)
(activation): ReLU()
)
(residual_block_1): ResBlock(
(block): Sequential(
(res_dense_0): Linear(in_features=2048, out_features=2048, bias=True)
(res_dropout_0): Dropout(p=0.1)
(res_relu_0): ReLU()
(res_dense_1): Linear(in_features=2048, out_features=2048, bias=True)
(res_dropout_1): Dropout(p=0.1)
)
(activation): ReLU()
)
(residual_block_2): ResBlock(
(block): Sequential(
(res_dense_0): Linear(in_features=2048, out_features=2048, bias=True)
(res_dropout_0): Dropout(p=0.1)
(res_relu_0): ReLU()
(res_dense_1): Linear(in_features=2048, out_features=2048, bias=True)
(res_dropout_1): Dropout(p=0.1)
)
(activation): ReLU()
)
(residual_block_3): ResBlock(
(block): Sequential(
(res_dense_0): Linear(in_features=2048, out_features=2048, bias=True)
(res_dropout_0): Dropout(p=0.1)
(res_relu_0): ReLU()
(res_dense_1): Linear(in_features=2048, out_features=2048, bias=True)
(res_dropout_1): Dropout(p=0.1)
)
(activation): ReLU()
)
(dense_1): Linear(in_features=2048, out_features=2048, bias=True)
(dropout_1): Dropout(p=0.1)
(res_relu_1): ReLU()
(dense_2): Linear(in_features=2048, out_features=2048, bias=True)
(dropout_2): Dropout(p=0.1)
(res_relu_2): ReLU()
(dense_out): Linear(in_features=2048, out_features=1, bias=True)
)
)
Data sizes:
Train:
Data set: NCI60 Size: 882873
Validation:
Data set: NCI60 Size: 260286
Data set: CTRP Size: 1040021
Data set: GDSC Size: 235812
Data set: CCLE Size: 17510
Data set: gCSI Size: 10323
================================================================================
Training Epoch 1:
Drug Weighted QED Regression Loss: 0.022274
Drug Response Regression Loss: 1881.89
Epoch Running Time: 13.2 Seconds.
================================================================================
Training Epoch 2:
Drug Weighted QED Regression Loss: 0.019416
Drug Response Regression Loss: 1348.13
Epoch Running Time: 12.9 Seconds.
================================================================================
Training Epoch 3:
Drug Weighted QED Regression Loss: 0.015868
Drug Response Regression Loss: 1123.27
Cell Line Classification:
Category Accuracy: 99.01%;
Site Accuracy: 94.11%;
Type Accuracy: 94.18%
Drug Target Family Classification Accuracy: 44.44%
Drug Weighted QED Regression
MSE: 0.018845 MAE: 0.111807 R2: +0.45
Drug Response Regression:
NCI60 MSE: 973.04 MAE: 22.18 R2: +0.69
CTRP MSE: 2404.64 MAE: 34.04 R2: +0.32
GDSC MSE: 2717.81 MAE: 36.53 R2: +0.19
CCLE MSE: 2518.47 MAE: 36.60 R2: +0.38
gCSI MSE: 2752.33 MAE: 36.97 R2: +0.35
Epoch Running Time: 54.6 Seconds.
================================================================================
Training Epoch 4:
Drug Weighted QED Regression Loss: 0.014096
Drug Response Regression Loss: 933.27
Cell Line Classification:
Category Accuracy: 99.34%;
Site Accuracy: 96.12%;
Type Accuracy: 96.18%
Drug Target Family Classification Accuracy: 44.44%
Drug Weighted QED Regression
MSE: 0.018467 MAE: 0.110287 R2: +0.46
Drug Response Regression:
NCI60 MSE: 844.51 MAE: 20.41 R2: +0.73
CTRP MSE: 2314.19 MAE: 33.76 R2: +0.35
GDSC MSE: 2747.73 MAE: 36.65 R2: +0.18
CCLE MSE: 2482.03 MAE: 35.89 R2: +0.39
gCSI MSE: 2665.35 MAE: 36.27 R2: +0.37
Epoch Running Time: 54.9 Seconds.
================================================================================
Training Epoch 5:
Drug Weighted QED Regression Loss: 0.013514
Drug Response Regression Loss: 846.06
Cell Line Classification:
Category Accuracy: 99.38%;
Site Accuracy: 95.89%;
Type Accuracy: 95.30%
Drug Target Family Classification Accuracy: 44.44%
Drug Weighted QED Regression
MSE: 0.017026 MAE: 0.106697 R2: +0.50
Drug Response Regression:
NCI60 MSE: 835.82 MAE: 21.33 R2: +0.74
CTRP MSE: 2653.04 MAE: 37.98 R2: +0.25
GDSC MSE: 2892.86 MAE: 39.76 R2: +0.13
CCLE MSE: 2412.75 MAE: 36.82 R2: +0.41
gCSI MSE: 2888.99 MAE: 38.70 R2: +0.32
Epoch Running Time: 55.5 Seconds.
Program Running Time: 191.1 Seconds.
================================================================================
Overall Validation Results:
Best Results from Different Models (Epochs):
Cell Line Categories Best Accuracy: 99.375% (Epoch = 5)
Cell Line Sites Best Accuracy: 96.118% (Epoch = 4)
Cell Line Types Best Accuracy: 96.184% (Epoch = 4)
Drug Target Family Best Accuracy: 44.444% (Epoch = 3)
Drug Weighted QED Best R2 Score: +0.5034 (Epoch = 5, MSE = 0.017026, MAE = 0.106697)
NCI60 Best R2 Score: +0.7369 (Epoch = 5, MSE = 835.82, MAE = 21.33)
CTRP Best R2 Score: +0.3469 (Epoch = 4, MSE = 2314.19, MAE = 33.76)
GDSC Best R2 Score: +0.1852 (Epoch = 3, MSE = 2717.81, MAE = 36.53)
CCLE Best R2 Score: +0.4094 (Epoch = 5, MSE = 2412.75, MAE = 36.82)
gCSI Best R2 Score: +0.3693 (Epoch = 4, MSE = 2665.35, MAE = 36.27)
Best Results from the Same Model (Epoch = 5):
Cell Line Categories Accuracy: 99.375%
Cell Line Sites Accuracy: 95.888%
Cell Line Types Accuracy: 95.296%
Drug Target Family Accuracy: 44.444%
Drug Weighted QED R2 Score: +0.5034 (MSE = 0.017026, MAE = 0.106697)
NCI60 R2 Score: +0.7369 (MSE = 835.82, MAE = 21.33)
CTRP R2 Score: +0.2513 (MSE = 2653.04, MAE = 37.98)
GDSC R2 Score: +0.1327 (MSE = 2892.86, MAE = 39.76)
CCLE R2 Score: +0.4094 (MSE = 2412.75, MAE = 36.82)
gCSI R2 Score: +0.3164 (MSE = 2888.99, MAE = 38.70)
```
For the default hyper-parameters, the transfer learning matrix results are shown below:
<p align="center">
<img src="./images/default_results.jpg" width="80%">
</p>
Note that green cells represent R2 scores higher than 0.1, red cells represent R2 scores lower than -0.1, and yellow cells cover all values in between.
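The coloring is simply a threshold rule applied to each R2 score; a tiny helper that reproduces it (illustrative, not part of the repository):
```
def cell_color(r2):
    # Color used for one cell of the transfer learning matrix above.
    if r2 > 0.1:
        return 'green'
    if r2 < -0.1:
        return 'red'
    return 'yellow'
```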
from __future__ import absolute_import
#__version__ = '0.0.0'
#import from data_utils
from data_utils import load_csv_data
from data_utils import load_Xy_one_hot_data2
from data_utils import load_Xy_data_noheader
#import from file_utils
from file_utils import get_file
#import from default_utils
from default_utils import ArgumentStruct
from default_utils import Benchmark
from default_utils import str2bool
from default_utils import initialize_parameters
from default_utils import fetch_file
from default_utils import verify_path
from default_utils import keras_default_config
from default_utils import set_up_logger
from generic_utils import Progbar
# import from viz_utils
from viz_utils import plot_history
from viz_utils import plot_scatter
# import benchmark-dependent utils
import sys
if 'torch' in sys.modules:
print ('Importing candle utils for pytorch')
from pytorch_utils import set_seed
from pytorch_utils import build_optimizer
from pytorch_utils import build_activation
from pytorch_utils import get_function
from pytorch_utils import initialize
from pytorch_utils import xent
from pytorch_utils import mse
from pytorch_utils import set_parallelism_threads # for compatibility
else:
raise Exception('No backend has been specified.')
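# Usage note (sketch): the backend check above only passes if torch has already
# been imported, so a typical caller does something like (package name shown for
# illustration only):
#   import torch
#   import candle   # prints 'Importing candle utils for pytorch'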
from __future__ import absolute_import
from __future__ import print_function
import tarfile
import os
import sys
import shutil
import hashlib
from six.moves.urllib.request import urlopen
from six.moves.urllib.error import URLError, HTTPError
from generic_utils import Progbar
# Under Python 2, 'urlretrieve' relies on FancyURLopener from legacy
# urllib module, known to have issues with proxy management
if sys.version_info[0] == 2:
def urlretrieve(url, filename, reporthook=None, data=None):
def chunk_read(response, chunk_size=8192, reporthook=None):
total_size = response.info().get('Content-Length').strip()
total_size = int(total_size)
count = 0
while 1:
chunk = response.read(chunk_size)
count += 1
if not chunk:
reporthook(count, total_size, total_size)
break
if reporthook:
reporthook(count, chunk_size, total_size)
yield chunk
response = urlopen(url, data)
with open(filename, 'wb') as fd:
for chunk in chunk_read(response, reporthook=reporthook):
fd.write(chunk)
else:
from six.moves.urllib.request import urlretrieve
def get_file(fname, origin, untar=False,
md5_hash=None, cache_subdir='common'):
""" Downloads a file from a URL if it not already in the cache.
Passing the MD5 hash will verify the file after download as well
as if it is already present in the cache.
Parameters
----------
fname : string
name of the file
origin : string
original URL of the file
untar : boolean
whether the file should be decompressed
md5_hash : string
MD5 hash of the file for verification
cache_subdir : string
directory being used as the cache
Returns
----------
Path to the downloaded file
"""
file_path = os.path.dirname(os.path.realpath(__file__))
datadir_base = os.path.expanduser(os.path.join(file_path, '..', 'Data'))
datadir = os.path.join(datadir_base, cache_subdir)
if not os.path.exists(datadir):
os.makedirs(datadir)
#if untar:
# fnamesplit = fname.split('.tar.gz')
# untar_fpath = os.path.join(datadir, fnamesplit[0])
if fname.endswith('.tar.gz'):
fnamesplit = fname.split('.tar.gz')
untar_fpath = os.path.join(datadir, fnamesplit[0])
untar = True
elif fname.endswith('.tgz'):
fnamesplit = fname.split('.tgz')
untar_fpath = os.path.join(datadir, fnamesplit[0])
untar = True
fpath = os.path.join(datadir, fname)
download = False
if os.path.exists(fpath):
# file found; verify integrity if a hash was provided
if md5_hash is not None:
if not validate_file(fpath, md5_hash):
print('A local file was found, but it seems to be '
'incomplete or outdated.')
download = True
else:
download = True
if download:
print('Downloading data from', origin)
global progbar
progbar = None
def dl_progress(count, block_size, total_size):
global progbar
if progbar is None:
progbar = Progbar(total_size)
else:
progbar.update(count * block_size)
error_msg = 'URL fetch failure on {}: {} -- {}'
try:
try:
urlretrieve(origin, fpath, dl_progress)
            # HTTPError is a subclass of URLError, so it must be handled first.
            except HTTPError as e:
                raise Exception(error_msg.format(origin, e.code, e.msg))
            except URLError as e:
                raise Exception(error_msg.format(origin, e.errno, e.reason))
except (Exception, KeyboardInterrupt) as e:
if os.path.exists(fpath):
os.remove(fpath)
raise
progbar = None
print()
if untar:
if not os.path.exists(untar_fpath):
print('Untarring file...')
tfile = tarfile.open(fpath, 'r:gz')
try:
tfile.extractall(path=datadir)
except (Exception, KeyboardInterrupt) as e:
if os.path.exists(untar_fpath):
if os.path.isfile(untar_fpath):
os.remove(untar_fpath)
else:
shutil.rmtree(untar_fpath)
raise
tfile.close()
return untar_fpath
print()
return fpath
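# Example usage of get_file (file name and URL are placeholders for illustration):
#   path = get_file('example_data.tar.gz',
#                   origin='http://example.org/example_data.tar.gz',
#                   untar=True)
#   # 'path' points to the extracted directory under ../Data/common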
def validate_file(fpath, md5_hash):
""" Validates a file against a MD5 hash
Parameters
----------
fpath : string
path to the file being validated
md5_hash : string
the MD5 hash being validated against
Returns
----------
boolean
Whether the file is valid
"""
hasher = hashlib.md5()
with open(fpath, 'rb') as f:
buf = f.read()
hasher.update(buf)
if str(hasher.hexdigest()) == str(md5_hash):
return True
else:
return False
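# Example (path and hash value are placeholders for illustration):
#   ok = validate_file('Data/common/example_data.tar.gz', md5_hash='<expected md5>')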
from __future__ import absolute_import
from __future__ import print_function
import numpy as np
import time
import sys
import os
import six
import marshal
import types as python_types
import logging
def get_from_module(identifier, module_params, module_name,
instantiate=False, kwargs=None):
if isinstance(identifier, six.string_types):
res = module_params.get(identifier)
if not res:
raise Exception('Invalid ' + str(module_name) + ': ' +
str(identifier))
if instantiate and not kwargs:
return res()
elif instantiate and kwargs:
return res(**kwargs)
else:
return res
elif type(identifier) is dict:
name = identifier.pop('name')
res = module_params.get(name)
if res:
return res(**identifier)
else:
raise Exception('Invalid ' + str(module_name) + ': ' +
str(identifier))
return identifier
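# Example: resolving a registered optimizer by name (registry contents are
# illustrative and assume torch has been imported; instantiate=False returns
# the class itself rather than an instance):
#   optimizers = {'sgd': torch.optim.SGD}
#   opt_cls = get_from_module('sgd', optimizers, 'optimizer')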
def make_tuple(*args):
return args
def func_dump(func):
""" Serialize user defined function. """
code = marshal.dumps(func.__code__).decode('raw_unicode_escape')
defaults = func.__defaults__
if func.__closure__:
closure = tuple(c.cell_contents for c in func.__closure__)
else:
closure = None
return code, defaults, closure
def func_load(code, defaults=None, closure=None, globs=None):
""" Deserialize user defined function. """
if isinstance(code, (tuple, list)): # unpack previous dump
code, defaults, closure = code
code = marshal.loads(code.encode('raw_unicode_escape'))
if closure is not None:
closure = func_reconstruct_closure(closure)
if globs is None:
globs = globals()
return python_types.FunctionType(code, globs, name=code.co_name, argdefs=defaults, closure=closure)
def func_reconstruct_closure(values):
""" Deserialization helper that reconstructs a closure. """
nums = range(len(values))
src = ["def func(arg):"]
src += [" _%d = arg[%d]" % (n, n) for n in nums]
src += [" return lambda:(%s)" % ','.join(["_%d" % n for n in nums]), ""]
src = '\n'.join(src)
try:
exec(src, globals())
except:
raise SyntaxError(src)
return func(values).__closure__
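# Round-trip sketch for func_dump / func_load (illustrative only):
#   code, defaults, closure = func_dump(lambda x, k=2: x * k)
#   f = func_load(code, defaults=defaults, closure=closure)
#   assert f(3) == 6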
class Progbar(object):
def __init__(self, target, width=30, verbose=1, interval=0.01):
"""