-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlocal_train.py
More file actions
89 lines (64 loc) · 3.12 KB
/
local_train.py
File metadata and controls
89 lines (64 loc) · 3.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import argparse
from pathlib import Path
import json
import os.path
import shutil
import socket
import torch
from tqdm import trange, tqdm
import sys
APP_PATH = Path('federatedhealth_mlm_job') / 'app'
from federatedhealth.nlp_models import XLMRobertaModel
from federatedhealth.config import load_config
def main():
    """Run local (non-federated) training of the XLMRoberta MLM model.

    Mirrors the federated NVFLARE setup by reading the app's server,
    client and training configs, then trains locally, saving a "latest"
    checkpoint every epoch and a "best" checkpoint whenever the dev-set
    perplexity improves.
    """
    parser = argparse.ArgumentParser(description="Train the local version of the XLMRoberta model for federated health")
    parser.add_argument('--app-dir',
                        help="Directory where the NVFLARE app resides.",
                        type=Path,
                        default=APP_PATH)
    parser.add_argument('--site',
                        help="The identifier for the site, will mainly be used for organizing output.",
                        default=socket.gethostname())
    parser.add_argument('--workspace-dir',
                        help="Directory to save training output to",
                        type=Path,
                        default=Path("local_training"))
    args = parser.parse_args()

    # Keep per-site output separate so several sites can share one
    # workspace directory without clobbering each other.
    workspace_dir = args.workspace_dir / args.site
    workspace_dir.mkdir(exist_ok=True, parents=True)

    server_config_path = args.app_dir / 'config' / 'config_fed_server.json'
    client_config_path = args.app_dir / 'config' / 'config_fed_client.json'

    # Extract server arguments: run as many local epochs as the
    # federated setup runs rounds, so the two are comparable.
    with open(server_config_path) as fp:
        server_fed_config = json.load(fp)
    max_epochs = server_fed_config["num_rounds"]

    # NOTE(review): the client config is parsed but not currently used;
    # the read is kept so a missing/invalid file still fails early.
    with open(client_config_path) as fp:
        client_fed_config = json.load(fp)  # noqa: F841

    model = XLMRobertaModel()
    training_config = args.app_dir / 'config' / 'train_config.json'
    config = load_config(override_path=str(training_config))
    training_data_path = config.data_config.training_data
    dev_data_path = config.data_config.dev_data
    test_data_path = config.data_config.test_data
    model.initialize(workspace_dir, training_data_path, dev_data_path, test_data_path, training_override=config.training_args)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    best_perplexity = float("inf")
    best_model_path = None
    # Save an untrained checkpoint so a "latest" model always exists,
    # even if training is interrupted during the first epoch.
    latest_model_path = workspace_dir / "latest_model_epoch0.pt"
    torch.save(model.state_dict(), latest_model_path)

    for epoch in trange(max_epochs, desc="Epoch"):
        model.train()
        # One outer epoch corresponds to one federated round, which in
        # the federated job runs several local aggregation epochs.
        for inner_epoch in range(model.aggregation_epochs):
            for batch_data in tqdm(model.train_dataloader, desc='Batch'):
                model.fit_batch(batch_data)
        eval_loss, perplexity = model.local_valid()
        if perplexity < best_perplexity:
            # BUGFIX: best_perplexity was never updated, so every epoch
            # counted as a new "best" and overwrote the best checkpoint
            # regardless of whether dev perplexity actually improved.
            best_perplexity = perplexity
            best_model_path = workspace_dir / f"best_model_epoch-{epoch+1}.pt"
            torch.save(model.state_dict(), best_model_path)
        latest_model_path = workspace_dir / f"latest_model_epoch-{epoch+1}.pt"
        torch.save(model.state_dict(), latest_model_path)


if __name__ == '__main__':
    main()