-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathutils.py
More file actions
135 lines (109 loc) · 3.99 KB
/
utils.py
File metadata and controls
135 lines (109 loc) · 3.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import pickle
import re
from typing import Tuple
import flax
import jax
import jax.numpy as jnp
import numpy as np
import tensorflow_datasets as tfds
from flax import serialization
from jaxtyping import AbstractDtype, Array, Float32, jaxtyped
from ml_collections import config_dict
from typeguard import typechecked as typechecker
# Datasets handled by this module; NOTE(review): presumably all 3-channel RGB
# image datasets — confirm against callers, as the list is not used in this chunk.
RGB_DATASETS = ["cifar10", "cifar100", "imagenet", "imagenet_lt"]
# Model architecture identifiers accepted elsewhere in the project — TODO confirm usage.
MODELS = ["Custom", "ResNet18", "ResNet34", "ResNet50", "ResNet101", "ViT"]
# Shorthand alias for ml_collections' immutable configuration dict.
FrozenDict = config_dict.FrozenConfigDict
class UInt8orFP32(AbstractDtype):
    """jaxtyping dtype category that accepts arrays of either uint8 or float32."""

    # Accepted dtypes for annotations such as UInt8orFP32[Array, "..."].
    dtypes = ["uint8", "float32"]
def get_data(dataset: str, split: str) -> Tuple[np.ndarray, np.ndarray]:
    """Load an entire dataset split into memory as (images, one-hot labels).

    Args:
        dataset: TFDS dataset name (e.g. "cifar10").
        split: logical split name; "train"/"val" are mapped onto TFDS
            sub-splits by ``get_tf_split``.

    Returns:
        Tuple of the images as a JAX array and the labels one-hot encoded.
    """
    images, labels = tfds.as_numpy(
        tfds.load(
            dataset,
            split=get_tf_split(split),
            batch_size=-1,  # -1 loads the full split in a single batch
            as_supervised=True,
        )
    )
    # NOTE(review): num_classes is inferred from the labels actually present in
    # this split — assumes every class occurs at least once; confirm for small splits.
    num_classes = np.unique(labels).shape[0]
    return jnp.asarray(images), jax.nn.one_hot(x=labels, num_classes=num_classes)
def get_tf_split(split: str) -> str:
    """Translate a logical split name into a TFDS split specification.

    "train" and "val" are carved out of the TFDS training split (80/20);
    any other name is passed through unchanged.
    """
    mapping = {"train": "train[:80%]", "val": "train[80%:]"}
    return mapping.get(split, split)
def get_data_statistics(
    dataset: str,
) -> Tuple[Float32[Array, "3"], Float32[Array, "3"]]:
    """Get per-channel means and stds of CIFAR-10, CIFAR-100, or ImageNet training data.

    Args:
        dataset: one of "cifar10", "cifar100", "imagenet".

    Returns:
        Tuple of (means, stds), each a float32 array of 3 RGB channel values.

    Raises:
        ValueError: if statistics for ``dataset`` are not available.
    """
    if dataset == "cifar10":
        means = jnp.array([0.4914, 0.4822, 0.4465], dtype=jnp.float32)
        stds = jnp.array([0.2023, 0.1994, 0.2010], dtype=jnp.float32)
    elif dataset == "cifar100":
        means = jnp.array([0.5071, 0.4865, 0.44092], dtype=jnp.float32)
        stds = jnp.array([0.2673, 0.2564, 0.2761], dtype=jnp.float32)
    elif dataset == "imagenet":
        means = jnp.array([0.485, 0.456, 0.406], dtype=jnp.float32)
        stds = jnp.array([0.229, 0.224, 0.225], dtype=jnp.float32)
    else:
        # ValueError (a subclass of Exception) is the idiomatic error for an
        # unsupported argument value; existing broad handlers still catch it.
        raise ValueError(f"\nDataset statistics for {dataset} are not available.\n")
    return means, stds
@jaxtyped
@typechecker
def normalize_images(
    images: UInt8orFP32[Array, "#batchk h w c"],
    data_config: FrozenDict,
) -> UInt8orFP32[Array, "#batchk h w c"]:
    """Rescale pixel values to [0, 1], then standardize with per-channel stats."""
    rescaled = images / data_config.max_pixel_value
    return (rescaled - data_config.means) / data_config.stds
def load_metrics(metric_path):
    """Load the most recent pickled metrics checkpoint into memory.

    Args:
        metric_path: directory searched (recursively) for checkpoint binaries.

    Returns:
        The unpickled metrics object from the last checkpoint found.
    """
    binary = find_binaries(metric_path)
    # NOTE: pickle is only safe on trusted files — do not point this at
    # untrusted input. The context manager guarantees the handle is closed
    # (the original leaked the open file object).
    with open(os.path.join(metric_path, binary), "rb") as f:
        metrics = pickle.load(f)
    return metrics
def save_params(out_path, params, epoch):
    """Encode parameters of network as bytes and save as binary file.

    Args:
        out_path: output directory (created if missing).
        params: Flax parameter pytree to serialize.
        epoch: epoch number embedded in the output filename.
    """
    # exist_ok=True already tolerates a pre-existing directory, so the
    # former os.path.exists() guard was redundant (and race-prone).
    os.makedirs(out_path, exist_ok=True)
    bytes_output = serialization.to_bytes(params)
    with open(os.path.join(out_path, f"pretrained_params_{epoch}.pkl"), "wb") as f:
        pickle.dump(bytes_output, f)
def save_opt_state(out_path, opt_state, epoch):
    """Encode the optimizer state as bytes and save as binary file.

    Args:
        out_path: output directory (created if missing).
        opt_state: optimizer state pytree to serialize.
        epoch: epoch number embedded in the output filename.
    """
    # exist_ok=True already tolerates a pre-existing directory, so the
    # former os.path.exists() guard was redundant (and race-prone).
    os.makedirs(out_path, exist_ok=True)
    bytes_output = serialization.to_bytes(opt_state)
    with open(os.path.join(out_path, f"opt_state_{epoch}.pkl"), "wb") as f:
        pickle.dump(bytes_output, f)
def find_binaries(param_path):
    """Search for the last checkpoint (highest epoch number) under ``param_path``.

    Args:
        param_path: directory walked recursively for ``*<digits>*pkl`` files.

    Returns:
        The filename of the checkpoint with the largest embedded epoch number.

    Raises:
        FileNotFoundError: if no checkpoint binary is found.
    """
    candidates = [
        f
        for _, _, files in os.walk(param_path)
        for f in files
        if re.search(r"(?=.*\d+)(?=.*pkl$)", f)
    ]
    if not candidates:
        raise FileNotFoundError(f"No checkpoint binaries found in {param_path}")
    # Compare by numeric epoch (filename as tie-breaker): a plain lexicographic
    # sort would rank "params_9.pkl" after "params_10.pkl" and return the
    # wrong "last" checkpoint.
    return max(candidates, key=lambda f: (int(re.search(r"\d+", f).group()), f))
def merge_params(pretrained_params, current_params):
    """Combine a pretrained encoder with the current classifier head."""
    merged = {
        "encoder": pretrained_params["encoder"],
        "clf": current_params["clf"],
    }
    return flax.core.FrozenDict(merged)
def get_subset(y, hist):
    """Sample a random subset of indices with ``hist[k]`` examples per class ``k``.

    Args:
        y: integer class labels.
        hist: desired number of samples for each class (indexed by class id).

    Returns:
        A permuted array of the sampled indices into ``y``.
    """
    chosen = []
    for label, count in enumerate(hist):
        candidates = np.where(y == label)[0]
        picks = np.random.choice(candidates, size=count, replace=False)
        chosen.extend(picks.tolist())
    return np.random.permutation(chosen)