
Commit

AutoEncoder
YaeSakuraL committed Nov 22, 2022
1 parent 83be4cf commit 94509a7
Showing 3 changed files with 31 additions and 12 deletions.
10 changes: 7 additions & 3 deletions AutoEncoder/data.py
@@ -4,15 +4,19 @@
 from Bio import SeqIO
 import numpy as np
 from typing import Any
+import glob

 class Sequence(object):
     def __init__(self, path):
-        self.train = self.tokenize(os.path.join(path, 'virus-2014.fasta'))
-        self.valid = self.tokenize(os.path.join(path, 'virus-2014.fasta'))
-        self.test = self.tokenize(os.path.join(path, 'virus-2014.fasta'))
+        # train_list=[]
+        # train_list.append(glob.glob(r'../dataset/virus/-2014/*.py'))
+        self.train = self.tokenize(os.path.join(path, 'virus/-2014/virus-2014.fasta'))
+        self.valid = self.tokenize(os.path.join(path, 'virus/2014-2015/virus2014-2015.fasta'))
+        self.test = self.tokenize(os.path.join(path, 'virus/2015-/virus2015-.fasta'))

     def tokenize(self, path):
         """Tokenizes a fasta file."""
+        print(path)
         assert os.path.exists(path)
         # read the fasta file
         records = list(SeqIO.parse(path, "fasta"))
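Only the head of tokenize is visible in this hunk; the rest of the method is collapsed. For orientation, a minimal sketch of a FASTA tokenizer in this style is shown below. It is not the committed code: the base-to-id mapping, the per-sequence LongTensor output, and the object-dtype NumPy array are assumptions, chosen to be consistent with how main.py slices and concatenates the splits.

# Hypothetical sketch of a FASTA tokenizer; not the committed code.
import os
import numpy as np
import torch
from Bio import SeqIO

BASE2ID = {'A': 0, 'C': 1, 'G': 2, 'T': 3}

def tokenize(path):
    """Tokenize a fasta file into an object array of per-sequence id tensors."""
    assert os.path.exists(path)
    records = list(SeqIO.parse(path, "fasta"))
    sequences = []
    for record in records:
        # Drop ambiguous bases (N, etc.) and map each base to an integer id.
        ids = [BASE2ID[b] for b in str(record.seq).upper() if b in BASE2ID]
        sequences.append(torch.tensor(ids, dtype=torch.long))
    # dtype=object keeps variable-length sequences in one NumPy array.
    return np.array(sequences, dtype=object)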
29 changes: 21 additions & 8 deletions AutoEncoder/main.py
@@ -6,23 +6,24 @@
 import torch
 import torch.nn as nn
 import torch.onnx
+from tensorboardX import SummaryWriter

 import data
 import model
 import numpy as np

 parser = argparse.ArgumentParser(description='PyTorch Transformer Model')
-parser.add_argument('--data', type=str, default='./data/',
+parser.add_argument('--data', type=str, default='../dataset',
                     help='location of the data sequence')
 parser.add_argument('--model', type=str, default='Transformer',
                     help='type of network (RNN_TANH, RNN_RELU, LSTM, GRU, Transformer)')
 parser.add_argument('--emsize', type=int, default=200,
                     help='size of word embeddings')
 parser.add_argument('--nhid', type=int, default=200,
                     help='number of hidden units per layer')
-parser.add_argument('--nlayers', type=int, default=2,
+parser.add_argument('--nlayers', type=int, default=3,
                     help='number of layers')
-parser.add_argument('--lr', type=float, default=20,
+parser.add_argument('--lr', type=float, default=1e-5,
                     help='initial learning rate')
 parser.add_argument('--clip', type=float, default=0.25,
                     help='gradient clipping')
@@ -42,7 +43,7 @@
                     help='use CUDA')
 parser.add_argument('--mps', action='store_true', default=False,
                     help='enables macOS GPU training')
-parser.add_argument('--log-interval', type=int, default=200, metavar='N',
+parser.add_argument('--log-interval', type=int, default=1, metavar='N',
                     help='report interval')
 parser.add_argument('--save', type=str, default='model.pt',
                     help='path to save the final model')
@@ -105,6 +106,10 @@
 train_data = sequence.train
 val_data = sequence.valid
 test_data = sequence.test
+train_data=np.concatenate((train_data,val_data[0:100]))
+val_data=val_data[101:-1]
+train_data=np.concatenate((train_data,test_data[0:500]))
+test_data=test_data[501:-1]
 ###############################################################################
 # Build the model
 ###############################################################################
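The four added lines fold the first 100 validation sequences and the first 500 test sequences into the training split. Because the remainders are taken with [101:-1] and [501:-1], the sequences at indices 100 and 500 and the last sequence of each array are dropped rather than kept. A toy illustration of the resulting sizes (the split sizes 1000/300/800 are invented, not taken from the dataset):

# Toy illustration of the re-split above; the split sizes are invented.
import numpy as np

train = np.arange(1000)   # stand-ins for per-sequence tensors
val = np.arange(300)
test = np.arange(800)

train = np.concatenate((train, val[0:100]))    # +100 sequences
val = val[101:-1]                              # drops index 100 and the last element -> 198 left
train = np.concatenate((train, test[0:500]))   # +500 sequences
test = test[501:-1]                            # drops index 500 and the last element -> 298 left

print(len(train), len(val), len(test))         # 1600 198 298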
@@ -144,7 +149,7 @@ def repackage_hidden(h):
 def get_batch(source, i):
     seq_len = min(args.bptt, len(source) - 1 - i)
     data = source[i:i+seq_len]
-    target = source[i+1:i+1+seq_len]
+    target = source[i:i+seq_len]
     return data.to(device), target.to(device)

 def evaluate(data_source):
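After this change the target window is identical to the input window, so each batch asks the model to reconstruct its own tokens rather than predict the next ones, in line with the commit's AutoEncoder naming. A toy comparison of the two target styles (the token values are made up):

# Minimal illustration of the target change; `seq` stands in for one tokenized sequence.
import torch

seq = torch.tensor([0, 1, 2, 3, 2, 1, 0])  # made-up token ids
i, seq_len = 0, 4

data = seq[i:i + seq_len]                # tensor([0, 1, 2, 3])
lm_target = seq[i + 1:i + 1 + seq_len]   # tensor([1, 2, 3, 2])  next-token (language-model) target
ae_target = seq[i:i + seq_len]           # tensor([0, 1, 2, 3])  reconstruction (autoencoder) target
print(data, lm_target, ae_target)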
@@ -155,15 +160,15 @@ def evaluate(data_source):
     total_length = np.sum([len(x) for x in data_source])
     with torch.no_grad():
         for i in range(0, data_source.size):
-            for j in range(j,data_source[i].shape[0]-1, args.bptt):
+            for j in range(0,data_source[i].shape[0]-1, args.bptt):
                 data, targets = get_batch(data_source[i], j)
                 output = model(data)
                 output_flat = output.view(-1, ntokens)
                 total_loss += len(data) * criterion(output_flat, targets).item()
     return total_loss / total_length


-def train():
+def train(writer):
     # Turn on training mode which enables dropout.
     model.train()
     total_loss = 0.
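Two changes land in this hunk: the inner loop now starts at 0 rather than at j, which is undefined at that point unless it was set in the collapsed lines above, and train gains a writer argument for TensorBoard logging. Note that evaluate returns a length-weighted mean: each window's loss is scaled by the number of tokens it covers and the sum is divided by the total token count, as the toy numbers below show (values invented):

# Sketch of the length-weighted averaging used in evaluate(); losses/lengths are invented.
window_losses = [1.40, 1.35, 1.50]   # mean criterion value per BPTT window
window_lengths = [5000, 5000, 1200]  # number of tokens in each window

total_loss = sum(l * n for l, n in zip(window_losses, window_lengths))
total_length = sum(window_lengths)
print(total_loss / total_length)     # per-token loss over the whole split, approx. 1.39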
@@ -181,6 +186,7 @@ def train():

         loss = criterion(output, targets)
         loss.backward()
+

         # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
         torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
@@ -197,9 +203,13 @@
                 epoch, batch, len(train_data), lr,
                 elapsed * 1000 / args.log_interval, cur_loss, math.exp(cur_loss)))
             total_loss = 0
+            global global_iteration
+            global_iteration+=1
+            writer.add_scalar("loss/train", loss.item(),global_iteration)
             start_time = time.time()
         if args.dry_run:
             break
+


 def export_onnx(path, batch_size, seq_len):
@@ -211,20 +221,23 @@


 # Loop over epochs.
+global_iteration=0
 lr = args.lr
 best_val_loss = None
+writer = SummaryWriter(logdir="../checkpoints/lr_1e-5_5000bp_3layers_test")

 # At any point you can hit Ctrl + C to break out of training early.
 try:
     for epoch in range(1, args.epochs+1):
         epoch_start_time = time.time()
-        train()
+        train(writer)
         val_loss = evaluate(val_data)
         print('-' * 89)
         print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
               'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                          val_loss, math.exp(val_loss)))
         print('-' * 89)
+        writer.add_scalar("loss/epoch", val_loss,epoch)
         # Save the model if the validation loss is the best we've seen so far.
         if not best_val_loss or val_loss < best_val_loss:
             with open(args.save, 'wb') as f:
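The logging added across main.py writes a per-batch training loss and a per-epoch validation loss through tensorboardX. A stripped-down sketch of the same pattern, with a placeholder log directory and made-up loop bounds:

# Minimal tensorboardX logging sketch; the directory and loop bounds are placeholders.
from tensorboardX import SummaryWriter

writer = SummaryWriter(logdir="../checkpoints/example_run")

global_step = 0
for epoch in range(1, 3):
    for batch in range(5):
        loss = 1.0 / (global_step + 1)            # stand-in for loss.item()
        writer.add_scalar("loss/train", loss, global_step)
        global_step += 1
    writer.add_scalar("loss/epoch", loss, epoch)  # stand-in for the validation loss
writer.close()
# Inspect with: tensorboard --logdir ../checkpoints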
4 changes: 3 additions & 1 deletion AutoEncoder/model.py
@@ -148,4 +148,6 @@ def forward(self, src, has_mask=True):
         src = self.pos_encoder(src)
         output = self.transformer_encoder(src, self.src_mask)
         output = self.decoder(output)
-        return F.log_softmax(output, dim=-2)
+        output = output.view(output.shape[0],4)
+        #output = F.log_softmax(output, dim=-2)
+        return F.softmax(output)
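The new tail reshapes the decoder output to (sequence_length, 4), matching the four-nucleotide vocabulary, and returns probabilities instead of log-probabilities. Calling F.softmax without a dim argument emits a deprecation warning on recent PyTorch releases and picks the dimension implicitly; the sketch below spells the dimension out. The (seq_len, batch, 4) input shape is an assumption for illustration, not taken from the commit.

# Hedged sketch of the same forward tail with an explicit softmax dimension;
# the variable names mirror the diff, the shapes are assumptions.
import torch
import torch.nn.functional as F

output = torch.randn(10, 1, 4)             # (seq_len, batch, 4 nucleotide classes), made up
output = output.view(output.shape[0], 4)   # collapse to (seq_len, 4) as in the commit
probs = F.softmax(output, dim=-1)          # explicit dim avoids the implicit-dim warning
print(probs.sum(dim=-1))                   # each row sums to 1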
