-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
35 lines (28 loc) · 790 Bytes
/
train.py
File metadata and controls
35 lines (28 loc) · 790 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from anomalib import TaskType
from anomalib.data import MVTec
from anomalib.data.utils import ValSplitMode
from anomalib.engine import Engine
from anomalib.models import Padim
def train(
    *,
    root: str = r"F:\demo\mvtec_anomaly_detection",
    category: str = "bottle",
) -> None:
    """Train a PaDiM anomaly-detection model on one MVTec AD category.

    The task is image-level classification. Validation data is generated
    synthetically (``ValSplitMode.SYNTHETIC`` with ``val_split_ratio=0.2``).

    Args:
        root: Directory containing the extracted MVTec AD dataset. Defaults to
            the original hard-coded location so ``train()`` behaves as before.
        category: MVTec AD category to train on (e.g. ``"bottle"``).
    """
    # Single source of truth for the task: the datamodule and the engine must agree.
    task = TaskType.CLASSIFICATION

    datamodule = MVTec(
        root=root,
        category=category,
        task=task,
        val_split_mode=ValSplitMode.SYNTHETIC,  # synthetically generate validation data
        image_size=(256, 256),
        val_split_ratio=0.2,
        train_batch_size=32,
        eval_batch_size=32,
        num_workers=4,
    )
    datamodule.setup()

    # Model & engine.
    model = Padim()
    # anomalib 1.x's Engine has its own ``task`` argument (default: segmentation)
    # and overwrites the datamodule's task inside ``fit``. Without passing it
    # here, the CLASSIFICATION setting above would be silently replaced.
    engine = Engine(task=task)

    # Train the model.
    engine.fit(
        datamodule=datamodule,
        model=model,
    )
if __name__ == "__main__":
    # Run training only when executed as a script, never on import. This
    # matters with num_workers > 0: under the "spawn" start method (the
    # Windows default), DataLoader worker processes re-import this module.
    train()