Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# dna2vec

## This is a modified version of dna2vec for specific use within Trace Genomics.

[![Build Status](https://travis-ci.org/pnpnpn/dna2vec.svg?branch=master)](https://travis-ci.org/pnpnpn/dna2vec)

**Dna2vec** is an open-source library to train distributed representations
Expand Down
5 changes: 2 additions & 3 deletions attic_util/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,8 @@ def estimate_bytes(filenames):
return sum([os.stat(f).st_size for f in filenames])

def get_output_fileroot(dirpath, name, postfix):
return '{}/{}-{}-{}-{}'.format(
dirpath,
name,
return '{}-{}-{}-{}'.format(
os.path.join(dirpath, name),
arrow.utcnow().format('YYYYMMDD-HHmm'),
postfix,
random_str(3))
7 changes: 7 additions & 0 deletions configs/grinder_subset_test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
inputs: ../ncbi_ground_reads_subset/*.fa
k-low: 3
k-high: 8
vec-dim: 100
epoch: 10
context: 10
out-dir: results/
10 changes: 7 additions & 3 deletions dna2vec/multi_k_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import tempfile
import numpy as np

from gensim.models import word2vec
from gensim.models import word2vec, KeyedVectors
from gensim import matutils

class SingleKModel:
Expand All @@ -14,7 +14,7 @@ def __init__(self, model):

class MultiKModel:
def __init__(self, filepath):
self.aggregate = word2vec.Word2Vec.load_word2vec_format(filepath, binary=False)
self.aggregate = KeyedVectors.load_word2vec_format(filepath, binary=False)
self.logger = logbook.Logger(self.__class__.__name__)

vocab_lens = [len(vocab) for vocab in self.aggregate.vocab.keys()]
Expand All @@ -35,6 +35,10 @@ def model(self, k_len):
def vector(self, vocab):
return self.data[len(vocab)].model[vocab]

def most_similar(self, vocab, topn=10):
# Note this only works for returning k-mers of the same length.
return self.data[len(vocab)].model.most_similar(vocab, topn=topn)

def unitvec(self, vec):
return matutils.unitvec(vec)

Expand All @@ -56,4 +60,4 @@ def separate_out_model(self, k_len):
vec_str = ' '.join("%f" % val for val in self.aggregate[vocab])
print('{} {}'.format(vocab, vec_str), file=fptr)
fptr.flush()
return SingleKModel(word2vec.Word2Vec.load_word2vec_format(fptr.name, binary=False))
return SingleKModel(KeyedVectors.load_word2vec_format(fptr.name, binary=False))
67 changes: 48 additions & 19 deletions scripts/train_dna2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,38 +17,59 @@
from dna2vec.generators import DisjointKmerFragmenter, SlidingKmerFragmenter

from gensim.models import word2vec
from gensim.models.callbacks import CallbackAny2Vec

class InvalidArgException(Exception):
pass

class EpochSaver(CallbackAny2Vec):
'''Callback to save model after each epoch.
Example taken from here: https://radimrehurek.com/gensim/models/callbacks.html#gensim.models.callbacks.CallbackAny2Vec
'''
...
def __init__(self, path_prefix):
self.path_prefix = path_prefix
self.epoch = 0

def on_epoch_end(self, model):
output_path = '{}_epoch{}.w2v'.format(self.path_prefix, self.epoch)
model.wv.save_word2vec_format(output_path, binary=False)
self.epoch += 1

class Learner:
def __init__(self, out_fileroot, context_halfsize, gensim_iters, vec_dim):
def __init__(self, out_fileroot,
context_halfsize,
epochs,
vec_dim,
workers,
epoch_saver
):
self.logger = logbook.Logger(self.__class__.__name__)
assert(word2vec.FAST_VERSION >= 0)
self.logger.info('word2vec.FAST_VERSION (should be >= 0): {}'.format(word2vec.FAST_VERSION))
self.model = None
self.out_fileroot = out_fileroot
self.context_halfsize = context_halfsize
self.gensim_iters = gensim_iters
self.epochs = epochs
self.use_skipgram = 1
self.vec_dim = vec_dim
self.epoch_saver = epoch_saver
self.workers = workers

self.logger.info('Context window half size: {}'.format(self.context_halfsize))
self.logger.info('Use skipgram: {}'.format(self.use_skipgram))
self.logger.info('gensim_iters: {}'.format(self.gensim_iters))
self.logger.info('epochs: {}'.format(self.epochs))
self.logger.info('vec_dim: {}'.format(self.vec_dim))
self.logger.info('workers: {}'.format(self.workers))

def train(self, kmer_seq_generator):
self.model = word2vec.Word2Vec(
sentences=kmer_seq_generator,
size=self.vec_dim,
window=self.context_halfsize,
min_count=5,
workers=4,
workers=self.workers,
sg=self.use_skipgram,
iter=self.gensim_iters)

# self.logger.info(model.vocab)
iter=self.epochs,
callbacks=[self.epoch_saver])

def write_vec(self):
out_filename = '{}.w2v'.format(self.out_fileroot)
Expand All @@ -69,7 +90,7 @@ def run_main(args, inputs, out_fileroot):
elif args.kmer_fragmenter == 'sliding':
kmer_fragmenter = SlidingKmerFragmenter(args.k_low, args.k_high)
else:
raise InvalidArgException('Invalid kmer fragmenter: {}'.format(args.kmer_fragmenter))
raise ValueError('Invalid kmer fragmenter: {}'.format(args.kmer_fragmenter))

logbook.info('kmer fragmenter: {}'.format(args.kmer_fragmenter))

Expand All @@ -82,8 +103,15 @@ def run_main(args, inputs, out_fileroot):
kmer_fragmenter,
histogram,
)

learner = Learner(out_fileroot, args.context, args.gensim_iters, args.vec_dim)
# This is the callback object that will save the model after each epoch (in theory).
epoch_saver = EpochSaver(out_fileroot)
# This is the model.
learner = Learner(out_fileroot,
args.context,
args.epochs,
args.vec_dim,
args.workers,
epoch_saver)
learner.train(kmer_seq_iterable)
learner.write_vec()

Expand All @@ -103,8 +131,8 @@ def main():
argp.add_argument('--k-high', help='k-mer end range (inclusive)', type=int, default=5)
argp.add_argument('--context', help='half size of context window (the total size is 2*c+1)', type=int, default=4)
argp.add_argument('--epochs', help='number of epochs', type=int, default=1)
argp.add_argument('--gensim-iters', help="gensim's internal iterations", type=int, default=1)
argp.add_argument('--out-dir', help="output directory", default='../dataset/dna2vec/results')
argp.add_argument('--workers', help='number of workers', type=int, default=4)
argp.add_argument('--out-dir', help="output directory", default='.')
argp.add_argument('--debug', help='', action='store_true')
args = argp.parse_args()

Expand Down Expand Up @@ -132,11 +160,12 @@ def main():
args.kmer_fragmenter))

out_txt_filename = '{}.txt'.format(out_fileroot)
with open(out_txt_filename, 'w') as summary_fptr:
with Tee(summary_fptr):
logbook.StreamHandler(sys.stdout, level=log_level).push_application()
redirect_logging()
run_main(args, inputs, out_fileroot)
print(out_txt_filename)
# with open(out_txt_filename, 'w') as summary_fptr:
# with Tee(summary_fptr):
logbook.TimedRotatingFileHandler(out_txt_filename, level=log_level).push_application()
redirect_logging()
run_main(args, inputs, out_fileroot)

if __name__ == '__main__':
main()