-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathclassification.py
More file actions
51 lines (40 loc) · 1.21 KB
/
classification.py
File metadata and controls
51 lines (40 loc) · 1.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import equimo.models as em
import equinox as eqx
import jax.numpy as jnp
import jax.random as jr
import tensorflow as tf
from loguru import logger
from hackathon import classification
from hackathon.config import CONFIG
from hackathon.utils import (
cast_floating_to,
)
# Make no GPU visible to TensorFlow, so TF never places work on (or reserves
# memory from) the GPU. Presumably TF is only needed for host-side data
# handling here, leaving the accelerator to JAX -- confirm against callers.
tf.config.experimental.set_visible_devices(devices=[], device_type="GPU")
def main(dataset: str = "imagenette"):
    """Build a ReduceFormer-B1 classifier and run the classification benchmark.

    The model is constructed from a PRNG subkey, cast to float32, split into
    array leaves and static structure with ``eqx.partition``, and handed to
    ``classification.evaluate`` together with the global ``CONFIG``.

    Args:
        dataset: Name of the benchmark dataset. It selects the
            ``num_classes_<dataset>`` entry of ``CONFIG`` and is forwarded to
            ``classification.evaluate``. Defaults to ``"imagenette"``, the
            value this script previously hard-coded.
    """
    key = jr.PRNGKey(CONFIG["seed"])
    # Give every consumer its own subkey. A JAX key must not be used again
    # after it has been split, so the root ``key`` itself is not passed on.
    key_backbone, key_eval = jr.split(key, 2)

    # Initialize student model
    logger.info("Initializing student...")
    # NOTE(review): dropout and drop_path_rate both read ``student_dpr`` --
    # confirm that sharing one rate is intended.
    model = em.reduceformer_backbone_b1(
        in_channels=3,
        dropout=CONFIG["student_dpr"],
        drop_path_rate=CONFIG["student_dpr"],
        num_classes=CONFIG[f"num_classes_{dataset}"],
        key=key_backbone,
    )
    model = cast_floating_to(model, jnp.float32)
    # Split the module into its array leaves (params) and everything else
    # (static), the form ``classification.evaluate`` receives below.
    model_params, model_static = eqx.partition(model, eqx.is_array)

    classification.evaluate(
        dataset=dataset,
        params=model_params,
        static=model_static,
        key=key_eval,
        config=CONFIG,
        seed=CONFIG["seed"],
    )
    logger.info("All evaluations completed!")
if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not when imported.
    main()