diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..02a2378 --- /dev/null +++ b/.flake8 @@ -0,0 +1,15 @@ +[flake8] +max-line-length = 88 +extend-ignore = E203,E501,W503,E402,F401,F403,F841,B006,B007,B008,B009,C416,E262 +exclude = + .git, + __pycache__, + .venv, + venv, + build, + dist, + *.egg-info, + .mypy_cache, + .pytest_cache +per-file-ignores = + __init__.py:F401,F403,E402 diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index 42b5c4a..b6c2183 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -13,4 +13,4 @@ contact_links: about: Browse code examples and tutorials - name: π Existing Issues url: https://github.com/DeepKnowledge1/AnomaVision/issues?q=is%3Aissue - about: Search existing issues to see if your question was already answered \ No newline at end of file + about: Search existing issues to see if your question was already answered diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 40187b5..c351ade 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -13,9 +13,44 @@ concurrency: cancel-in-progress: true jobs: + # NEW JOB: Code Quality Enforcement + code-quality: + name: Code Quality Checks + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Install uv + uses: astral-sh/setup-uv@v4 + with: + version: "latest" + + - name: Install dependencies with dev extras + run: uv sync --extra dev --extra cpu --locked + + - name: Check formatting with Black + run: uv run black --check --diff . + + # - name: Check import sorting with isort + # run: uv run --with isort isort --check-only --diff . 
+ + - name: Lint with flake8 + run: | + uv run flake8 \ + --max-line-length=88 \ + --extend-ignore=E203,E501,W503,E402,F401,F403,F841,B006,B007,B008,B009,C416,E262 \ + anomavision/ anodet/ apps/ tests/ + tests: name: Tests (Python ${{ matrix.python-version }}) runs-on: ubuntu-latest + # Run tests only after code quality passes + needs: code-quality strategy: matrix: python-version: ["3.10", "3.11", "3.12"] @@ -44,6 +79,8 @@ jobs: cuda-matrix-check: name: Verify CUDA Resolvers runs-on: ubuntu-latest + # Run CUDA checks only after code quality passes + needs: code-quality steps: - uses: actions/checkout@v4 diff --git a/.gitignore b/.gitignore index da17a86..bc45113 100644 --- a/.gitignore +++ b/.gitignore @@ -359,3 +359,4 @@ images/ !docs/images/archti.png !docs/AnomaVision_vs_Anomalib.pdf poetry publish --build --username __toke.txt +fix_flake.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5d4cde7..b31c9f5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,45 +1,49 @@ repos: - # Black: auto-format Python code (fixes E501 line too long) + # Black: auto-format Python code - repo: https://github.com/psf/black - rev: 24.3.0 + rev: 25.1.0 hooks: - id: black - args: [--line-length=88] + args: [--line-length=88, --target-version=py310] - # isort: automatically sort and group imports (fixes E402) + # isort: automatically sort and group imports - repo: https://github.com/pycqa/isort - rev: 5.13.2 + rev: 6.0.1 hooks: - id: isort - args: [--profile=black] + args: + - --profile=black + - --line-length=88 + - --force-single-line-imports - # # flake8: linting for style & errors - # - repo: https://github.com/pycqa/flake8 - # rev: 7.0.0 - # hooks: - # - id: flake8 - # additional_dependencies: [ - # flake8-bugbear, - # flake8-comprehensions - # ] - - # # mypy: static type checking - # - repo: https://github.com/pre-commit/mirrors-mypy - # rev: v1.8.0 - # hooks: - # - id: mypy + # flake8: ONLY block on
critical errors (F811 redefinitions) + # All other warnings are ignored to allow legacy code to pass + - repo: https://github.com/pycqa/flake8 + rev: 7.3.0 + hooks: + - id: flake8 + args: + - --max-line-length=88 + - --extend-ignore=E203,E501,W503,E402,F401,F403,F841,B006,B007,B008,B009,C416,E262 + additional_dependencies: + - flake8-bugbear + - flake8-comprehensions # Basic cleanup hooks - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 + rev: v6.0.0 hooks: - id: end-of-file-fixer - id: trailing-whitespace + # Run pytest LAST - repo: local hooks: - id: pytest name: Run pytest - # entry: poetry run pytest + entry: uv run pytest language: system pass_filenames: false + +# Stop on first failure (don't run pytest if flake8 fails) +fail_fast: true diff --git a/anodet/__init__.py b/anodet/__init__.py index 31cec34..72b2596 100644 --- a/anodet/__init__.py +++ b/anodet/__init__.py @@ -278,7 +278,7 @@ except ImportError as e: print(f"β FATAL: Failed to import from anomavision: {e}") raise ImportError( - f"Failed to import from 'anomavision' package. " + "Failed to import from 'anomavision' package. 
" f"Please ensure AnomaVision is properly installed: {e}" ) from e diff --git a/anomavision/__init__.py b/anomavision/__init__.py index fcdb0cd..ed87261 100644 --- a/anomavision/__init__.py +++ b/anomavision/__init__.py @@ -12,7 +12,6 @@ from .datasets.mvtec_dataset import MVTecDataset from .feature_extraction import ResnetEmbeddingsExtractor from .padim import Padim - from .sampling_methods.kcenter_greedy import kCenterGreedy from .test import optimal_threshold, visualize_eval_data, visualize_eval_pair from .utils import get_logger # Export for users diff --git a/anomavision/cli.py b/anomavision/cli.py index f37d949..babe599 100644 --- a/anomavision/cli.py +++ b/anomavision/cli.py @@ -49,6 +49,7 @@ def create_parser() -> argparse.ArgumentParser: try: from anomavision import __version__ + version_str = f"AnomaVision {__version__}" except ImportError: version_str = "AnomaVision" @@ -87,8 +88,10 @@ def create_parser() -> argparse.ArgumentParser: # works immediately with no changes needed here. 
# ============================================================ + def _add_train_parser(subparsers) -> None: from anomavision.train import create_parser as _cp + subparsers.add_parser( "train", help="Train a new anomaly detection model", @@ -99,6 +102,7 @@ def _add_train_parser(subparsers) -> None: def _add_export_parser(subparsers) -> None: from anomavision.export import create_parser as _cp + subparsers.add_parser( "export", help="Export trained model to different formats", @@ -109,6 +113,7 @@ def _add_export_parser(subparsers) -> None: def _add_detect_parser(subparsers) -> None: from anomavision.detect import create_parser as _cp + subparsers.add_parser( "detect", help="Run inference on images", @@ -119,6 +124,7 @@ def _add_detect_parser(subparsers) -> None: def _add_eval_parser(subparsers) -> None: from anomavision.eval import create_parser as _cp + subparsers.add_parser( "eval", help="Evaluate model performance", @@ -132,23 +138,28 @@ def _add_eval_parser(subparsers) -> None: # No sys.argv manipulation. No double-parsing. 
# ============================================================ + def _dispatch_train(args: argparse.Namespace) -> None: from anomavision import train + train.main(args) def _dispatch_export(args: argparse.Namespace) -> None: from anomavision import export + export.main(args) def _dispatch_detect(args: argparse.Namespace) -> None: from anomavision import detect + detect.main(args) def _dispatch_eval(args: argparse.Namespace) -> None: from anomavision import eval as eval_module # 'eval' shadows the Python builtin + eval_module.main(args) @@ -156,6 +167,7 @@ def _dispatch_eval(args: argparse.Namespace) -> None: # Entry point # ============================================================ + def main() -> None: parser = create_parser() args = parser.parse_args() diff --git a/anomavision/datasets/MQTTSource.py b/anomavision/datasets/MQTTSource.py index 865576a..9a54f49 100644 --- a/anomavision/datasets/MQTTSource.py +++ b/anomavision/datasets/MQTTSource.py @@ -1,13 +1,13 @@ -from anomavision.datasets.StreamSource import StreamSource - -from typing import Optional import threading import time -from queue import Queue, Empty, Full +from queue import Empty, Full, Queue +from typing import Optional import cv2 import numpy as np +from anomavision.datasets.StreamSource import StreamSource + try: import paho.mqtt.client as mqtt except ImportError as e: diff --git a/anomavision/datasets/StreamDataset.py b/anomavision/datasets/StreamDataset.py index 7dd9e22..fef913b 100644 --- a/anomavision/datasets/StreamDataset.py +++ b/anomavision/datasets/StreamDataset.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Union, List +from typing import List, Optional, Tuple, Union import numpy as np import torch @@ -6,10 +6,7 @@ from torch.utils.data import IterableDataset from anomavision.datasets.StreamSource import StreamSource -from anomavision.utils import ( - create_image_transform, - create_mask_transform, -) +from anomavision.utils import create_image_transform, 
create_mask_transform class StreamDataset(IterableDataset): diff --git a/anomavision/datasets/StreamSource.py b/anomavision/datasets/StreamSource.py index ee11828..3b0780f 100644 --- a/anomavision/datasets/StreamSource.py +++ b/anomavision/datasets/StreamSource.py @@ -1,8 +1,8 @@ from abc import ABC, abstractmethod - # ========== BASE STRATEGY INTERFACE ========== + class StreamSource(ABC): @abstractmethod def connect(self) -> None: @@ -27,4 +27,3 @@ def disconnect(self) -> None: def is_connected(self) -> bool: """Check if source is connected.""" pass - diff --git a/anomavision/datasets/StreamSourceFactory.py b/anomavision/datasets/StreamSourceFactory.py index c7a0127..50e11b9 100644 --- a/anomavision/datasets/StreamSourceFactory.py +++ b/anomavision/datasets/StreamSourceFactory.py @@ -1,11 +1,11 @@ -from typing import Dict, Any - -from anomavision.datasets.StreamDataset import StreamDataset -from anomavision.datasets.StreamSource import StreamSource +import contextlib +from typing import Any, Dict from torch.utils.data import IterableDataset from anomavision.datasets.MQTTSource import MQTTSource +from anomavision.datasets.StreamDataset import StreamDataset +from anomavision.datasets.StreamSource import StreamSource from anomavision.datasets.TCPsource import TCPSource from anomavision.datasets.VideoSource import VideoSource from anomavision.datasets.WebcamSource import WebcamSource @@ -65,24 +65,3 @@ def create(config: Dict[str, Any]) -> StreamSource: ) raise ValueError(f"Unknown StreamSource type: {source_type}") - - -# Optional: convenience ctor on StreamDataset -class StreamDataset(IterableDataset): - # your existing __init__ here... - - @classmethod - def from_config( - cls, - source_config: Dict[str, Any], - **dataset_kwargs: Any, - ) -> "StreamDataset": - """ - Build StreamDataset from a source config dict + dataset kwargs. 
- - Example: - source_cfg = {"type": "webcam", "camera_id": 0} - ds = StreamDataset.from_config(source_cfg, max_frames=100) - """ - source = StreamSourceFactory.create(source_config) - return cls(source=source, **dataset_kwargs) diff --git a/anomavision/datasets/TCPsource.py b/anomavision/datasets/TCPsource.py index 5f0191b..6f1b49c 100644 --- a/anomavision/datasets/TCPsource.py +++ b/anomavision/datasets/TCPsource.py @@ -1,14 +1,14 @@ -from anomavision.datasets.StreamSource import StreamSource - -from typing import Optional import socket import threading import time -from queue import Queue, Empty, Full +from queue import Empty, Full, Queue +from typing import Optional import cv2 import numpy as np +from anomavision.datasets.StreamSource import StreamSource + class TCPSource(StreamSource): """ @@ -132,7 +132,10 @@ def _receiver_loop(self) -> None: if payload_len <= 0: continue - if self.max_message_size is not None and payload_len > self.max_message_size: + if ( + self.max_message_size is not None + and payload_len > self.max_message_size + ): # Drain and skip (best-effort) _ = self._recv_exact(payload_len) continue @@ -238,4 +241,8 @@ def disconnect(self) -> None: def is_connected(self) -> bool: with self._lock: - return self._connected and not self._stop_event.is_set() and self.socket is not None + return ( + self._connected + and not self._stop_event.is_set() + and self.socket is not None + ) diff --git a/anomavision/datasets/VideoSource.py b/anomavision/datasets/VideoSource.py index c3a26ce..968557d 100644 --- a/anomavision/datasets/VideoSource.py +++ b/anomavision/datasets/VideoSource.py @@ -1,12 +1,13 @@ -from anomavision.datasets.StreamSource import StreamSource - -import cv2 -import threading import queue +import threading import time -import numpy as np from typing import Optional +import cv2 +import numpy as np + +from anomavision.datasets.StreamSource import StreamSource + class VideoSource(StreamSource): """ diff --git 
a/anomavision/datasets/WebcamSource.py b/anomavision/datasets/WebcamSource.py index aa001a0..b75d82f 100644 --- a/anomavision/datasets/WebcamSource.py +++ b/anomavision/datasets/WebcamSource.py @@ -1,12 +1,13 @@ -from anomavision.datasets.StreamSource import StreamSource - -import cv2 -import threading import queue +import threading import time -import numpy as np from typing import Optional +import cv2 +import numpy as np + +from anomavision.datasets.StreamSource import StreamSource + class WebcamSource(StreamSource): """ diff --git a/anomavision/detect.py b/anomavision/detect.py index 02dae24..58d0dab 100644 --- a/anomavision/detect.py +++ b/anomavision/detect.py @@ -21,10 +21,9 @@ from torch.utils.data import DataLoader import anomavision +from anomavision.config import _shape, load_config from anomavision.datasets.StreamDataset import StreamDataset from anomavision.datasets.StreamSourceFactory import StreamSourceFactory - -from anomavision.config import _shape, load_config from anomavision.general import Profiler, determine_device, increment_path from anomavision.inference.model.wrapper import ModelWrapper from anomavision.inference.modelType import ModelType @@ -210,7 +209,11 @@ def run_inference(args): # Parse visualization color try: - viz_color = tuple(map(int, config.viz_color.split(","))) if config.viz_color else (128, 0, 128) + viz_color = ( + tuple(map(int, config.viz_color.split(","))) + if config.viz_color + else (128, 0, 128) + ) if len(viz_color) != 3: raise ValueError except (ValueError, AttributeError): @@ -224,11 +227,15 @@ def run_inference(args): crop_size = _shape(config.crop_size) normalize = config.get("normalize", True) - logger.info("Image processing: resize=%s, crop=%s, norm=%s", resize, crop_size, normalize) + logger.info( + "Image processing: resize=%s, crop=%s, norm=%s", resize, crop_size, normalize + ) # Validation if not config.get("img_path") and not stream_mode: - raise ValueError("img_path is required (via --img_path or config) 
when stream_mode is False") + raise ValueError( + "img_path is required (via --img_path or config) when stream_mode is False" + ) if not config.get("model"): raise ValueError("model is required (via --model or config)") @@ -247,7 +254,7 @@ def run_inference(args): "scores": [], "classifications": [], # We only store images/maps if needed for downstream tasks to avoid memory issues - "images": [] if not stream_mode else None + "images": [] if not stream_mode else None, } total_start_time = time.time() @@ -271,7 +278,13 @@ def run_inference(args): # --- Model Loading Phase --- with profilers["model_loading"]: - model_path = os.path.join(MODEL_DATA_PATH, config.algorithm, config.class_name, config.run_name, config.model) + model_path = os.path.join( + MODEL_DATA_PATH, + config.algorithm, + config.class_name, + config.run_name, + config.model, + ) logger.info(f"Loading model: {model_path}") if not os.path.exists(model_path): @@ -291,7 +304,11 @@ def run_inference(args): run_name = config.run_name viz_output_dir = config.get("viz_output_dir", "./visualizations/") RESULTS_PATH = increment_path( - Path(viz_output_dir) / config.algorithm / config.class_name / model_type.value.upper() / run_name, + Path(viz_output_dir) + / config.algorithm + / config.class_name + / model_type.value.upper() + / run_name, exist_ok=config.get("overwrite", False), mkdir=True, ) @@ -383,18 +400,24 @@ def run_inference(args): # 2. 
Post-processing with profilers["postprocessing"]: try: - score_maps = adaptive_gaussian_blur(score_maps, kernel_size=33, sigma=4) + score_maps = adaptive_gaussian_blur( + score_maps, kernel_size=33, sigma=4 + ) # Classify if config.thresh is not None: - is_anomaly = anomavision.classification(image_scores, config.thresh) + is_anomaly = anomavision.classification( + image_scores, config.thresh + ) else: is_anomaly = np.zeros_like(image_scores) # Accumulate Results (Offline only) if not stream_mode: results_accumulator["scores"].extend(image_scores.tolist()) - results_accumulator["classifications"].extend(is_anomaly.tolist()) + results_accumulator["classifications"].extend( + is_anomaly.tolist() + ) results_accumulator["images"].extend(images) except Exception as e: @@ -410,7 +433,13 @@ def run_inference(args): boundary_images = ( anomavision.visualization.framed_boundary_images( test_images, - anomavision.classification(score_maps, config.thresh)if config.thresh else np.zeros_like(score_maps), + ( + anomavision.classification( + score_maps, config.thresh + ) + if config.thresh + else np.zeros_like(score_maps) + ), is_anomaly, padding=config.get("viz_padding", 40), ) @@ -424,7 +453,11 @@ def run_inference(args): highlighted_images = anomavision.visualization.highlighted_images( [images[i] for i in range(len(images))], # Dummy mask if threshold not set - anomavision.classification(score_maps, config.thresh) if config.thresh else np.zeros_like(score_maps), + ( + anomavision.classification(score_maps, config.thresh) + if config.thresh + else np.zeros_like(score_maps) + ), color=viz_color, ) @@ -434,7 +467,10 @@ def run_inference(args): if config.save_visualizations and RESULTS_PATH: try: fig, axs = plt.subplots(1, 4, figsize=(16, 8)) - fig.suptitle(f"Result - Batch {batch_idx} Img {img_id}", fontsize=14) + fig.suptitle( + f"Result - Batch {batch_idx} Img {img_id}", + fontsize=14, + ) axs[0].imshow(images[img_id]) axs[0].set_title("Original") @@ -452,7 +488,10 @@ def 
run_inference(args): axs[3].set_title("Highlighted") axs[3].axis("off") - save_path = os.path.join(RESULTS_PATH, f"batch_{batch_idx}_img_{img_id}.png") + save_path = os.path.join( + RESULTS_PATH, + f"batch_{batch_idx}_img_{img_id}.png", + ) plt.savefig(save_path, dpi=100, bbox_inches="tight") plt.close(fig) except Exception as e: @@ -465,10 +504,12 @@ def run_inference(args): logger.info("Closing model...") model.close() if stream_mode: - # Clean up stream source - try: - test_dataset.close() - except: pass + # Clean up stream source + try: + test_dataset.close() + except Exception: + + pass # --- Metrics & Summary --- total_pipeline_time = time.time() - total_start_time @@ -482,12 +523,24 @@ def run_inference(args): logger.info("=" * 60) logger.info("ANOMAVISION PERFORMANCE SUMMARY") logger.info("=" * 60) - logger.info(f"Setup time: {profilers['setup'].accumulated_time * 1000:.2f} ms") - logger.info(f"Model loading time: {profilers['model_loading'].accumulated_time * 1000:.2f} ms") - logger.info(f"Data loading time: {profilers['data_loading'].accumulated_time * 1000:.2f} ms") - logger.info(f"Inference time: {profilers['inference'].accumulated_time * 1000:.2f} ms") - logger.info(f"Postprocessing time: {profilers['postprocessing'].accumulated_time * 1000:.2f} ms") - logger.info(f"Visualization time: {profilers['visualization'].accumulated_time * 1000:.2f} ms") + logger.info( + f"Setup time: {profilers['setup'].accumulated_time * 1000:.2f} ms" + ) + logger.info( + f"Model loading time: {profilers['model_loading'].accumulated_time * 1000:.2f} ms" + ) + logger.info( + f"Data loading time: {profilers['data_loading'].accumulated_time * 1000:.2f} ms" + ) + logger.info( + f"Inference time: {profilers['inference'].accumulated_time * 1000:.2f} ms" + ) + logger.info( + f"Postprocessing time: {profilers['postprocessing'].accumulated_time * 1000:.2f} ms" + ) + logger.info( + f"Visualization time: {profilers['visualization'].accumulated_time * 1000:.2f} ms" + ) 
logger.info(f"Total pipeline time: {total_pipeline_time * 1000:.2f} ms") logger.info("=" * 60) @@ -501,16 +554,18 @@ def run_inference(args): logger.info(f"Average inference time: {avg_ms:.2f} ms/batch") if batch_count > 0: - batch_size = config.get('batch_size', 1) or 1 + batch_size = config.get("batch_size", 1) or 1 throughput = fps * (final_count / batch_count) if batch_count else 0 - logger.info(f"Throughput: {throughput:.1f} images/sec (batch size: {batch_size})") + logger.info( + f"Throughput: {throughput:.1f} images/sec (batch size: {batch_size})" + ) logger.info("=" * 60) metrics = { "fps": fps, "avg_inference_ms": avg_ms, "total_time_s": total_pipeline_time, - "total_images": final_count + "total_images": final_count, } return metrics, results_accumulator diff --git a/anomavision/eval.py b/anomavision/eval.py index e0dbc2a..1283099 100644 --- a/anomavision/eval.py +++ b/anomavision/eval.py @@ -8,8 +8,8 @@ import numpy as np import torch from easydict import EasyDict as edict +from sklearn.metrics import auc, precision_recall_curve, roc_auc_score from torch.utils.data import DataLoader -from sklearn.metrics import roc_auc_score, precision_recall_curve, auc import anomavision from anomavision.config import load_config @@ -18,12 +18,12 @@ from anomavision.inference.modelType import ModelType from anomavision.utils import ( adaptive_gaussian_blur, + compute_metrics, + find_best_threshold_f1, + find_optimal_threshold, get_logger, merge_config, setup_logging, - find_best_threshold_f1, - compute_metrics, - find_optimal_threshold ) @@ -156,7 +156,9 @@ def evaluate_model_with_wrapper( logger.info(f"Starting evaluation on {len(test_dataloader.dataset)} images") try: - for batch_idx, (batch, images, image_targets, mask_targets) in enumerate(test_dataloader): + for batch_idx, (batch, images, image_targets, mask_targets) in enumerate( + test_dataloader + ): batch = batch.to(device_str) with evaluation_profiler: @@ -170,7 +172,9 @@ def evaluate_model_with_wrapper( # 
Collect results all_images.extend(images) all_image_classifications_target.extend( - image_targets.numpy() if hasattr(image_targets, "numpy") else image_targets + image_targets.numpy() + if hasattr(image_targets, "numpy") + else image_targets ) all_masks_target.extend( mask_targets.numpy() if hasattr(mask_targets, "numpy") else mask_targets @@ -192,7 +196,11 @@ def evaluate_model_with_wrapper( return ( np.array(all_images), np.array(all_image_classifications_target), - np.squeeze(np.array(all_masks_target), axis=1) if len(all_masks_target) > 0 else np.array([]), + ( + np.squeeze(np.array(all_masks_target), axis=1) + if len(all_masks_target) > 0 + else np.array([]) + ), np.array(all_image_scores), np.array(all_score_maps), ) @@ -227,7 +235,9 @@ def run_evaluation(args): # Setup Phase with profilers["setup"]: - DATASET_PATH = os.path.realpath(config.dataset_path) if config.dataset_path else None + DATASET_PATH = ( + os.path.realpath(config.dataset_path) if config.dataset_path else None + ) MODEL_DATA_PATH = os.path.realpath(config.model_data_path) device_str = determine_device(config.device) @@ -239,7 +249,13 @@ def run_evaluation(args): # Load Model Phase with profilers["model_loading"]: - model_path = os.path.join(MODEL_DATA_PATH, config.algorithm, config.class_name, config.run_name, config.model) + model_path = os.path.join( + MODEL_DATA_PATH, + config.algorithm, + config.class_name, + config.run_name, + config.model, + ) logger.info(f"Loading model: {model_path}") if not os.path.exists(model_path): @@ -273,7 +289,7 @@ def run_evaluation(args): batch_size=batch_size, num_workers=config.num_workers if config.num_workers else 0, pin_memory=config.pin_memory and device_str == "cuda", - shuffle=False + shuffle=False, ) except Exception as e: logger.error(f"Failed to create dataloader: {e}") @@ -308,8 +324,8 @@ def run_evaluation(args): # Add timing metrics total_images = len(test_dataset) - metrics['inference_fps'] = profilers["evaluation"].get_fps(total_images) - 
metrics['inference_time_total_s'] = profilers["evaluation"].accumulated_time + metrics["inference_fps"] = profilers["evaluation"].get_fps(total_images) + metrics["inference_time_total_s"] = profilers["evaluation"].accumulated_time # Visualization Phase if config.enable_visualization: @@ -344,11 +360,21 @@ def run_evaluation(args): logger.info("=" * 60) logger.info("ANOMAVISION EVALUATION PERFORMANCE SUMMARY") logger.info("=" * 60) - logger.info(f"Setup time: {profilers['setup'].accumulated_time * 1000:.2f} ms") - logger.info(f"Model loading time: {profilers['model_loading'].accumulated_time * 1000:.2f} ms") - logger.info(f"Data loading time: {profilers['data_loading'].accumulated_time * 1000:.2f} ms") - logger.info(f"Evaluation time: {profilers['evaluation'].accumulated_time * 1000:.2f} ms") - logger.info(f"Visualization time: {profilers['visualization'].accumulated_time * 1000:.2f} ms") + logger.info( + f"Setup time: {profilers['setup'].accumulated_time * 1000:.2f} ms" + ) + logger.info( + f"Model loading time: {profilers['model_loading'].accumulated_time * 1000:.2f} ms" + ) + logger.info( + f"Data loading time: {profilers['data_loading'].accumulated_time * 1000:.2f} ms" + ) + logger.info( + f"Evaluation time: {profilers['evaluation'].accumulated_time * 1000:.2f} ms" + ) + logger.info( + f"Visualization time: {profilers['visualization'].accumulated_time * 1000:.2f} ms" + ) # logger.info("=" * 60) # 2. PERFORMANCE METRICS @@ -362,7 +388,9 @@ def run_evaluation(args): if len(test_dataloader) > 0: images_per_batch = total_images / len(test_dataloader) - logger.info(f"Evaluation throughput: {evaluation_fps * images_per_batch:.1f} images/sec (batch size: {batch_size})") + logger.info( + f"Evaluation throughput: {evaluation_fps * images_per_batch:.1f} images/sec (batch size: {batch_size})" + ) # logger.info("=" * 60) # 3. 
EVALUATION SUMMARY @@ -373,7 +401,9 @@ def run_evaluation(args): logger.info(f"Total images evaluated: {total_images}") logger.info(f"Model type: {model_type.value.upper() if model_type else 'UNKNOWN'}") logger.info(f"Device: {device_str}") - logger.info(f"Image processing: resize={config.resize}, crop_size={config.crop_size}, normalize={config.normalize}") + logger.info( + f"Image processing: resize={config.resize}, crop_size={config.crop_size}, normalize={config.normalize}" + ) # logger.info("=" * 60) logger.info("=" * 60) @@ -396,7 +426,7 @@ def run_evaluation(args): "labels": labels, "masks": masks, "scores": scores, - "maps": maps + "maps": maps, } return metrics, raw_results diff --git a/anomavision/export.py b/anomavision/export.py index 0d15609..21e6da9 100644 --- a/anomavision/export.py +++ b/anomavision/export.py @@ -722,8 +722,19 @@ def main(args=None): setup_logging(enabled=True, log_level=config.log_level, log_to_file=True) logger = get_logger("anomavision.export") - model_path = Path(config.model_data_path) / config.algorithm / config.class_name / config.run_name / config.model - output_dir = Path(config.model_data_path) / config.algorithm / config.class_name / config.run_name + model_path = ( + Path(config.model_data_path) + / config.algorithm + / config.class_name + / config.run_name + / config.model + ) + output_dir = ( + Path(config.model_data_path) + / config.algorithm + / config.class_name + / config.run_name + ) model_stem = Path(config.model).stem # Generate output names diff --git a/anomavision/general.py b/anomavision/general.py index 4c36a83..3791d29 100644 --- a/anomavision/general.py +++ b/anomavision/general.py @@ -1,6 +1,8 @@ +import contextlib import os import subprocess import sys +import time from pathlib import Path import matplotlib.pyplot as plt @@ -52,12 +54,6 @@ def save_visualization(images, filename, output_dir): plt.imsave(filepath, images) -import contextlib -import time - -import torch - - class 
Profiler(contextlib.ContextDecorator): """ AnomaVision Performance Profiler for accurate timing measurements. diff --git a/anomavision/padim.py b/anomavision/padim.py index 4dc22bb..6615b36 100644 --- a/anomavision/padim.py +++ b/anomavision/padim.py @@ -137,7 +137,7 @@ def forward( layer_hook=self.layer_hook, layer_indices=self.layer_indices, ) - embedding_vectors= embedding_vectors.to(dtype=x.dtype) + embedding_vectors = embedding_vectors.to(dtype=x.dtype) patch_scores = self.mahalanobisDistance( features=embedding_vectors, width=w, height=h, export=export, chunk=256 ) # (B, w, h) diff --git a/anomavision/sampling_methods/kcenter_greedy.py b/anomavision/sampling_methods/kcenter_greedy.py index 964252d..f5d60e0 100644 --- a/anomavision/sampling_methods/kcenter_greedy.py +++ b/anomavision/sampling_methods/kcenter_greedy.py @@ -89,7 +89,8 @@ def select_batch_(self, model, already_selected, N, **kwargs): self.features = model.transform(self.X) print("Calculating distances...") self.update_distances(already_selected, only_new=False, reset_dist=True) - except: + except Exception: + print("Using flat_X as features.") self.update_distances(already_selected, only_new=True, reset_dist=False) diff --git a/anomavision/train.py b/anomavision/train.py index f0dbeb9..b786a89 100644 --- a/anomavision/train.py +++ b/anomavision/train.py @@ -20,7 +20,9 @@ def create_parser(add_help: bool = True) -> argparse.ArgumentParser: - parser = argparse.ArgumentParser(description="Train PaDiM (args OR config).", add_help=add_help) + parser = argparse.ArgumentParser( + description="Train PaDiM (args OR config).", add_help=add_help + ) # meta parser.add_argument( "--config", type=str, default="config.yml", help="Path to config.yml/.json" @@ -176,13 +178,16 @@ def run_training(args): config.normalize, ) if config.normalize: - logger.info( - "Normalization: mean=%s, std=%s", config.norm_mean, config.norm_std - ) + logger.info("Normalization: mean=%s, std=%s", config.norm_mean, config.norm_std) # 
Resolve output run dir once run_dir = increment_path( - Path(config.model_data_path) / config.algorithm / config.class_name / config.run_name, exist_ok=True, mkdir=True + Path(config.model_data_path) + / config.algorithm + / config.class_name + / config.run_name, + exist_ok=True, + mkdir=True, ) # === Dataset === @@ -199,9 +204,11 @@ def run_training(args): if not os.path.isdir(root): # Fallback check: maybe dataset_path ALREADY points to the class folder? # This makes it more robust for different input styles - potential_root = os.path.join(os.path.realpath(config.dataset_path), "train", "good") + potential_root = os.path.join( + os.path.realpath(config.dataset_path), "train", "good" + ) if os.path.isdir(potential_root): - root = potential_root + root = potential_root else: logger.error('Expected folder "%s" does not exist.', root) raise FileNotFoundError(f"Dataset root not found: {root}") @@ -263,13 +270,11 @@ def run_training(args): # snapshot the effective configuration save_args_to_yaml(config, str(Path(run_dir) / "config.yml")) - logger.info( - "saved: model=%s, config=%s", model_path, Path(run_dir) / "config.yml" - ) + logger.info("saved: model=%s, config=%s", model_path, Path(run_dir) / "config.yml") logger.info("=== Training done in %.2fs ===", time.perf_counter() - t0) # Return objects for external usage (e.g. 
MLOps pipeline) - return padim, config, run_dir, {'train': dl} + return padim, config, run_dir, {"train": dl} def main(args=None): @@ -282,6 +287,7 @@ def main(args=None): except Exception: get_logger(__name__).exception("Fatal error during training.") sys.exit(1) - + + if __name__ == "__main__": main() diff --git a/anomavision/utils.py b/anomavision/utils.py index d6d0bc7..34d0a66 100644 --- a/anomavision/utils.py +++ b/anomavision/utils.py @@ -12,11 +12,9 @@ import torch import yaml from PIL import Image +from sklearn.metrics import auc, precision_recall_curve, roc_auc_score, roc_curve from torchvision import transforms as T -from sklearn.metrics import roc_curve -from sklearn.metrics import roc_auc_score, precision_recall_curve, auc - # Default standard transforms - kept for backward compatibility standard_image_transform = T.Compose( [ @@ -347,8 +345,6 @@ def split_tensor_and_run_function( return output_tensor - - def setup_logging( log_level: str = "INFO", log_to_file: bool = False, @@ -790,6 +786,7 @@ def adaptive_gaussian_blur(input_array, kernel_size=33, sigma=4): except ImportError: raise ImportError("SciPy is required when PyTorch is not available") + def find_best_threshold_f1(labels, scores): precision, recall, thresholds = precision_recall_curve(labels, scores) @@ -798,7 +795,6 @@ def find_best_threshold_f1(labels, scores): return thresholds[best_idx], f1[best_idx] -from sklearn.metrics import roc_curve def find_best_threshold_roc(labels, scores): fpr, tpr, thresholds = roc_curve(labels, scores) @@ -825,6 +821,7 @@ def find_best_threshold_accuracy(labels, scores): return best_thresh, best_acc + def compute_metrics(labels, scores, thresh=None): """ Calculate standard anomaly detection metrics. 
@@ -833,26 +830,26 @@ def compute_metrics(labels, scores, thresh=None): # AUROC try: - metrics['auc_score'] = float(roc_auc_score(labels, scores)) + metrics["auc_score"] = float(roc_auc_score(labels, scores)) except ValueError: - metrics['auc_score'] = 0.0 + metrics["auc_score"] = 0.0 # PR-AUC try: precision, recall, _ = precision_recall_curve(labels, scores) - metrics['pr_auc'] = float(auc(recall, precision)) + metrics["pr_auc"] = float(auc(recall, precision)) except ValueError: - metrics['pr_auc'] = 0.0 + metrics["pr_auc"] = 0.0 # Statistics - metrics['mean_anomaly_score'] = float(np.mean(scores)) - metrics['std_anomaly_score'] = float(np.std(scores)) + metrics["mean_anomaly_score"] = float(np.mean(scores)) + metrics["std_anomaly_score"] = float(np.std(scores)) # Accuracy (if thresh provided) if thresh is not None: predictions = (scores > thresh).astype(int) - metrics['accuracy'] = float(np.mean(predictions == labels)) - metrics['thresh'] = thresh + metrics["accuracy"] = float(np.mean(predictions == labels)) + metrics["thresh"] = thresh return metrics diff --git a/app.py b/app.py index b038c07..9b86be1 100644 --- a/app.py +++ b/app.py @@ -86,7 +86,9 @@ async def process_image(file: UploadFile = File(...)): np_image = np.array(image) # Preprocess and run inference - batch = to_batch([np_image], anomavision.standard_image_transform, torch.device("cpu")) + batch = to_batch( + [np_image], anomavision.standard_image_transform, torch.device("cpu") + ) image_scores, score_maps = model.predict(batch) # Postprocess diff --git a/apps/anomavision_gui_tkinter.py b/apps/anomavision_gui_tkinter.py index f0cd67d..8e6d501 100644 --- a/apps/anomavision_gui_tkinter.py +++ b/apps/anomavision_gui_tkinter.py @@ -37,6 +37,7 @@ except Exception: _DND_AVAILABLE = False + # ----- Your package imports (existing project modules) ----- import anomavision from anomavision.config import load_config @@ -45,7 +46,7 @@ from anomavision.inference.modelType import ModelType from anomavision.padim 
import Padim from anomavision.utils import get_logger -from PIL import Image, ImageTk + # ----------------------------------------------------------------------------- # Global fast paths and helpers # ----------------------------------------------------------------------------- @@ -208,7 +209,7 @@ def make_logo(size=44, bg="#38BDF8", fg="#001225"): text = "AV" # Try common fonts; fall back to default font = None - for name in ("arial.ttf", "SegoeUI.ttf", "DejaVuSans.ttf"): + for name in ("arial.tt", "SegoeUI.tt", "DejaVuSans.tt"): try: font = ImageFont.truetype(name, int(size * 0.44)) break @@ -764,8 +765,8 @@ def _build_style(self): self.style.configure( "Green.Horizontal.TProgressbar", troughcolor=self.colors["panel"], - background="#10B981", # β green bar - thickness=14 + background="#10B981", # β green bar + thickness=14, ) def toggle_theme(self): @@ -836,10 +837,10 @@ def _build_header(self): # --- Create logo image BEFORE using it (safe fallback if PIL font missing) --- - - try: - icon_image = Image.open("av.png") # use PNG for clean scaling; ICO also works + icon_image = Image.open( + "av.png" + ) # use PNG for clean scaling; ICO also works icon_image = icon_image.resize((44, 44), Image.LANCZOS) self.logo_tk = ImageTk.PhotoImage(icon_image) except Exception: @@ -1214,11 +1215,13 @@ def row(lbl): variable=self.train_progress_var, maximum=100, mode="determinate", - style="Green.Horizontal.TProgressbar" + style="Green.Horizontal.TProgressbar", ) self.train_progress.pack(fill=tk.X, padx=14, pady=(0, 12)) self.train_progress.pack_forget() - self.train_percent_label = tk.Label(cfg, text="0%", bg=self.colors["panel"], fg=self.colors["fg"]) + self.train_percent_label = tk.Label( + cfg, text="0%", bg=self.colors["panel"], fg=self.colors["fg"] + ) self.train_percent_label.pack(anchor="w", padx=14) def update_training_progress(self, value): diff --git a/apps/api/fastapi_app.py b/apps/api/fastapi_app.py index a9a00a4..cf27ef7 100644 --- a/apps/api/fastapi_app.py +++ 
b/apps/api/fastapi_app.py @@ -29,7 +29,9 @@ RESIZE_SIZE = (224, 224) # You can override these via environment variables -MODEL_DATA_PATH = os.getenv("ANOMAVISION_MODEL_DATA_PATH", "distributions/padim/bottle/anomav_exp") +MODEL_DATA_PATH = os.getenv( + "ANOMAVISION_MODEL_DATA_PATH", "distributions/padim/bottle/anomav_exp" +) MODEL_FILE = os.getenv("ANOMAVISION_MODEL_FILE", "model.onnx") DEVICE = os.getenv("ANOMAVISION_DEVICE", "auto") # "auto"|"cpu"|"cuda" @@ -81,6 +83,7 @@ async def cleanup(): model_type = None print("Model cleanup completed.") + @asynccontextmanager async def lifespan(app: FastAPI): await load_model() diff --git a/apps/ui/gradio_app.py b/apps/ui/gradio_app.py index 715fb23..a51c3fc 100644 --- a/apps/ui/gradio_app.py +++ b/apps/ui/gradio_app.py @@ -33,6 +33,7 @@ from anomavision.general import determine_device from anomavision.inference.model.wrapper import ModelWrapper from anomavision.inference.modelType import ModelType + ANOMAVISION_AVAILABLE = True except ImportError: ANOMAVISION_AVAILABLE = False @@ -41,14 +42,16 @@ # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ # Config # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ -MODEL_DATA_PATH = os.getenv("ANOMAVISION_MODEL_DATA_PATH", "distributions/padim/bottle/anomav_exp") -MODEL_FILE = os.getenv("ANOMAVISION_MODEL_FILE", "model.onnx") -DEVICE_ENV = os.getenv("ANOMAVISION_DEVICE", "auto") +MODEL_DATA_PATH = os.getenv( + "ANOMAVISION_MODEL_DATA_PATH", "distributions/padim/bottle/anomav_exp" +) +MODEL_FILE = os.getenv("ANOMAVISION_MODEL_FILE", "model.onnx") +DEVICE_ENV = os.getenv("ANOMAVISION_DEVICE", "auto") THRESHOLD_DEFAULT = float(os.getenv("ANOMAVISION_THRESHOLD", "13.0")) -VIZ_PADDING = int(os.getenv("ANOMAVISION_VIZ_PADDING", "40")) -VIZ_ALPHA = float(os.getenv("ANOMAVISION_VIZ_ALPHA", "0.5")) -VIZ_COLOR = tuple(map(int, os.getenv("ANOMAVISION_VIZ_COLOR", "128,0,128").split(","))) -SAMPLE_DIR = os.getenv("SAMPLE_IMAGES_DIR", 
"D:/01-DATA/sample_images") +VIZ_PADDING = int(os.getenv("ANOMAVISION_VIZ_PADDING", "40")) +VIZ_ALPHA = float(os.getenv("ANOMAVISION_VIZ_ALPHA", "0.5")) +VIZ_COLOR = tuple(map(int, os.getenv("ANOMAVISION_VIZ_COLOR", "128,0,128").split(","))) +SAMPLE_DIR = os.getenv("SAMPLE_IMAGES_DIR", "D:/01-DATA/sample_images") # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ # Model β loaded once at startup @@ -72,11 +75,13 @@ def _load_model() -> str: try: _device_str = determine_device(DEVICE_ENV) _model_type = ModelType.from_extension(model_path) - _model = ModelWrapper(model_path, _device_str) + _model = ModelWrapper(model_path, _device_str) # Optional warmup try: - dummy = torch.zeros((1, 3, 224, 224), dtype=torch.float32, device=_device_str) + dummy = torch.zeros( + (1, 3, 224, 224), dtype=torch.float32, device=_device_str + ) _model.warmup(batch=dummy, runs=1) except Exception: pass @@ -108,7 +113,9 @@ def _demo_predict(image_np: np.ndarray): def _real_predict(image_np: np.ndarray, threshold: float): """Run anomavision inference and return (score, score_map_np, boundary_np, heatmap_np, highlighted_np).""" device = torch.device(_device_str) - batch = anomavision.to_batch([image_np], anomavision.standard_image_transform, device) + batch = anomavision.to_batch( + [image_np], anomavision.standard_image_transform, device + ) if _device_str == "cuda": batch = batch.half() @@ -116,8 +123,8 @@ def _real_predict(image_np: np.ndarray, threshold: float): with torch.no_grad(): image_scores, score_maps = _model.predict(batch) - score_map_cls = anomavision.classification(score_maps, threshold) - image_cls = anomavision.classification(image_scores, threshold) + score_map_cls = anomavision.classification(score_maps, threshold) + image_cls = anomavision.classification(image_scores, threshold) test_images = np.array([image_np]) @@ -187,7 +194,7 @@ def _collect_samples() -> list: for p in sorted(base.rglob("*")): if p.suffix.lower() in SUPPORTED_EXT: - rel = 
p.relative_to(base) + rel = p.relative_to(base) parts = rel.parts if len(parts) >= 3: label = f"{parts[0]}/{parts[1]}" @@ -235,22 +242,24 @@ def run_inference( if image is None: return "β Please upload or select an image.", None, None, None, None - resize = (int(resize_w), int(resize_h)) + resize = (int(resize_w), int(resize_h)) image_np = _pil_to_np(image) t0 = time.time() if _model is not None and ANOMAVISION_AVAILABLE: try: - score, score_map_np, boundary_np, heatmap_np, highlighted_np = _real_predict( - image_np, threshold + score, score_map_np, boundary_np, heatmap_np, highlighted_np = ( + _real_predict(image_np, threshold) ) is_anomaly = score >= threshold - original_pil = image.resize(resize, Image.BILINEAR) - heatmap_pil = _np_to_pil(heatmap_np, resize) if include_viz else None - boundary_pil = _np_to_pil(boundary_np, resize) if include_viz else None - highlighted_pil = _np_to_pil(highlighted_np, resize) if include_viz else None + original_pil = image.resize(resize, Image.BILINEAR) + heatmap_pil = _np_to_pil(heatmap_np, resize) if include_viz else None + boundary_pil = _np_to_pil(boundary_np, resize) if include_viz else None + highlighted_pil = ( + _np_to_pil(highlighted_np, resize) if include_viz else None + ) except Exception as e: return f"β οΈ Inference error: {e}", None, None, None, None @@ -263,38 +272,45 @@ def run_inference( if include_viz: import matplotlib.cm as cm + heatmap_norm = heatmap_raw / heatmap_raw.max() - cmap = cm.get_cmap("jet") + cmap = cm.get_cmap("jet") heatmap_rgba = (cmap(heatmap_norm) * 255).astype(np.uint8) - heatmap_rgb = heatmap_rgba[:, :, :3] - blend = (0.5 * image_np + 0.5 * heatmap_rgb).astype(np.uint8) - heatmap_pil = _np_to_pil(heatmap_rgb, resize) - boundary_pil = _np_to_pil(blend, resize) - highlighted_pil = _np_to_pil(image_np, resize) + heatmap_rgb = heatmap_rgba[:, :, :3] + blend = (0.5 * image_np + 0.5 * heatmap_rgb).astype(np.uint8) + heatmap_pil = _np_to_pil(heatmap_rgb, resize) + boundary_pil = _np_to_pil(blend, 
resize) + highlighted_pil = _np_to_pil(image_np, resize) else: heatmap_pil = boundary_pil = highlighted_pil = None elapsed = time.time() - t0 - label = "π¨ ANOMALY DETECTED" if is_anomaly else "β NORMAL" - status = f"Model: {Path(MODEL_FILE).stem} | Score: {score:.4f} | {label}" - detail = f"Threshold: {threshold:.2f} | Inference time: {elapsed:.2f}s" + label = "π¨ ANOMALY DETECTED" if is_anomaly else "β NORMAL" + status = f"Model: {Path(MODEL_FILE).stem} | Score: {score:.4f} | {label}" + detail = f"Threshold: {threshold:.2f} | Inference time: {elapsed:.2f}s" - return f"{status}\n{detail}", original_pil, heatmap_pil, boundary_pil, highlighted_pil + return ( + f"{status}\n{detail}", + original_pil, + heatmap_pil, + boundary_pil, + highlighted_pil, + ) # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ # CSS β Clean Light Theme with Indigo/Violet Accents # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ -_ACCENT = "#6366f1" # indigo -_ACCENT_H = "#4f46e5" # indigo hover -_ACCENT2 = "#ef4444" # red for anomaly alerts -_ACCENT3 = "#22c55e" # green for normal result -_BG = "#f5f6fa" # off-white page background -_SURFACE = "#ffffff" # card surface -_SURFACE2 = "#f0f1f8" # slightly tinted input background -_BORDER = "#e2e4f0" # soft lavender border -_TEXT = "#1e1b4b" # deep indigo text -_MUTED = "#7c82a8" # muted blue-grey +_ACCENT = "#6366f1" # indigo +_ACCENT_H = "#4f46e5" # indigo hover +_ACCENT2 = "#ef4444" # red for anomaly alerts +_ACCENT3 = "#22c55e" # green for normal result +_BG = "#f5f6fa" # off-white page background +_SURFACE = "#ffffff" # card surface +_SURFACE2 = "#f0f1f8" # slightly tinted input background +_BORDER = "#e2e4f0" # soft lavender border +_TEXT = "#1e1b4b" # deep indigo text +_MUTED = "#7c82a8" # muted blue-grey custom_css = f""" @import url('https://fonts.googleapis.com/css2?family=Plus+Jakarta+Sans:wght@400;500;600;700;800&family=JetBrains+Mono:wght@400;600&display=swap'); @@ -716,7 
+732,8 @@ def run_inference( with gr.Blocks(title="AnomaVision β Industrial Anomaly Detection") as demo: # ββ Header ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ - gr.HTML(f""" + gr.HTML( + f"""