-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
103 lines (82 loc) · 2.85 KB
/
main.py
File metadata and controls
103 lines (82 loc) · 2.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
import json
import multiprocessing
import torch
import torch.nn as nn
import hparams
import dataset
import lightning as pl
from models.LangID import LangID
from torch.utils.data import DataLoader
# Global hyperparameter namespace shared by the model, data, and export code below.
hp = hparams.get_hparams()
# Clip lengths expressed in samples: duration (seconds) * sample rate (Hz).
num_frames = hp.clip_duration * hp.sample_rate
min_num_frames = hp.min_clip_duration * hp.sample_rate
class LitLangID(pl.LightningModule):
    """Lightning wrapper around the LangID classifier.

    Trains with cross-entropy on (audio, language-index) batches and logs
    per-step training loss plus per-epoch validation loss/accuracy.
    Hyperparameters are read from the module-level ``hp`` namespace.
    """

    def __init__(self):
        super().__init__()
        self.model = LangID(num_lang=hp.num_lang, sample_rate=hp.sample_rate)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        # x: raw audio batch — presumably (batch, num_frames); confirm against dataset.
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log("train_loss", loss, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        # Keep accuracy as a tensor: same per-batch value as the old
        # `correct.item() / y.size(0)`, but without forcing a GPU->CPU
        # sync on every validation batch.
        acc = (y_hat.argmax(dim=1) == y).float().mean()
        self.log("val_loss", loss, on_step=False, on_epoch=True)
        self.log("val_acc", acc, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """Adam with per-epoch exponential LR decay (StepLR, step_size=1)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=hp.lr)
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=1, gamma=hp.lr_decay
        )
        return {"optimizer": optimizer, "lr_scheduler": scheduler}
# Configure the float32 matmul precision trade-off (value comes from hparams).
torch.set_float32_matmul_precision(hp.float32_matmul_precision)

# Load the train/test split as configured by the hyperparameters.
ds = dataset.load_dataset(
    path=hp.data_path,
    sample_rate=hp.sample_rate,
    test_size=hp.test_size,
    num_frames=num_frames,
    min_num_frames=min_num_frames,
)

# One DataLoader worker per CPU core; only the training split is shuffled.
worker_count = multiprocessing.cpu_count()
train_loader = DataLoader(
    ds["train"],
    batch_size=hp.batch_size,
    num_workers=worker_count,
    shuffle=True,
)
test_loader = DataLoader(
    ds["test"],
    batch_size=hp.batch_size,
    num_workers=worker_count,
)
model = LitLangID()

# Keep only the single checkpoint with the lowest validation loss.
checkpoint_callback = pl.pytorch.callbacks.ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    save_top_k=1,
)

trainer = pl.Trainer(
    max_epochs=hp.epochs,
    log_every_n_steps=1,
    precision=hp.precision,
    callbacks=[checkpoint_callback],
)

# The held-out test split doubles as the validation set during training.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=test_loader)
# --- Export the best checkpoint to ONNX and save the language-id mapping. ---
best_model_path = checkpoint_callback.best_model_path
print(f"Best model saved at: {best_model_path}")
best_model = LitLangID.load_from_checkpoint(best_model_path).model

# Append softmax so the exported graph emits probabilities rather than logits.
onnx_model = nn.Sequential(best_model, nn.Softmax(dim=1))
onnx_model.to(device="cpu")
onnx_model.eval()

# Fix: create the output directory up front — previously the export raised
# FileNotFoundError whenever hp.save_path did not already exist.
os.makedirs(hp.save_path, exist_ok=True)

dummy_input = torch.randn(1, num_frames)
# NOTE(review): `verify=True` and the returned program object are only
# meaningful for the dynamo-based exporter (torch >= 2.5) — confirm the
# installed torch version; with the legacy exporter they are ignored.
onnx_program = torch.onnx.export(
    onnx_model,
    dummy_input,
    os.path.join(hp.save_path, "lang_id.onnx"),
    export_params=True,
    opset_version=20,
    verify=True,
)

# Persist the label mapping next to the model so inference can decode outputs.
with open(os.path.join(hp.save_path, "lang_id.json"), "w") as f:
    json.dump(ds["lang_id"], f)