-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathutils.py
More file actions
61 lines (45 loc) · 1.45 KB
/
utils.py
File metadata and controls
61 lines (45 loc) · 1.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import os
import random
import numpy as np
import torch
import importlib
def get_obj_from_str(string, reload=False):
    """Resolve a dotted import path (e.g. ``"pkg.mod.ClassName"``) to the object it names.

    Args:
        string: Dotted path; everything before the last dot is the module,
            the final component is the attribute looked up on it.
        reload: When True, re-import the module before resolving the
            attribute (picks up source changes in a live session).

    Returns:
        The attribute (class, function, ...) named by ``string``.
    """
    module_name, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_name))
    return getattr(importlib.import_module(module_name, package=None), attr_name)
def lerp(a, b, t):
    """Linearly interpolate from ``a`` to ``b`` by fraction ``t`` (t=0 -> a, t=1 -> b)."""
    delta = b - a
    return a + t * delta
def size_of_model(model: torch.nn.Module) -> int:
    """Return the total number of scalar parameters in ``model``."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
def seed_all(seed):
    """
    Seed every random number generator used by the project for reproducibility.

    Covers Python's ``random``, the hash seed (via PYTHONHASHSEED, effective
    only for subprocesses), NumPy, and PyTorch on both CPU and all CUDA
    devices. The CUDA calls are safe no-ops when no GPU is available.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Also seed every CUDA device, not just the current one, so multi-GPU
    # runs are reproducible as well.
    torch.cuda.manual_seed_all(seed)
# Compute the average dictionary
def compute_average_dict(dicts):
    """
    Average a sequence of dictionaries key-by-key.

    A key missing from some dictionaries is treated as contributing 0 for
    those entries: its sum is still divided by the total number of
    dictionaries (matching the original accumulate-then-divide behavior).

    Args:
        dicts: Sequence of dictionaries with numeric values.

    Returns:
        Dict mapping each key that appears in any input to its average.
        Empty dict for empty input (instead of relying on the loop never
        running with count == 0).
    """
    if not dicts:
        return {}
    count = len(dicts)
    # Sum first, divide once per key: avoids one rounding error per
    # accumulation step that the per-item `value / count` form incurs.
    totals = {}
    for d in dicts:
        for key, value in d.items():
            totals[key] = totals.get(key, 0) + value
    return {key: total / count for key, total in totals.items()}
def loss_logging(writer, loss_dict, step, prefix="loss", do_print=False):
    """
    Write each loss value to TensorBoard under ``prefix/<name>``.

    Args:
        writer: TensorBoard-style writer exposing ``add_scalar(tag, value, step)``.
        loss_dict: Mapping of loss name to float or scalar ``torch.Tensor``.
        step: Global step recorded with every scalar.
        prefix: Tag namespace prepended to each loss name.
        do_print: When True, also echo each value to stdout with 6 decimals.
    """
    for name, raw in loss_dict.items():
        scalar = raw.item() if isinstance(raw, torch.Tensor) else raw
        writer.add_scalar(f"{prefix}/{name}", scalar, step)
        if do_print:
            print(f"{name}: {scalar:.6f}")