diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..02a2378 --- /dev/null +++ b/.flake8 @@ -0,0 +1,15 @@ +[flake8] +max-line-length = 88 +extend-ignore = E203,E501,W503,E402,F401,F403,F841,B006,B007,B008,B009,C416,E262 +exclude = + .git, + __pycache__, + .venv, + venv, + build, + dist, + *.egg-info, + .mypy_cache, + .pytest_cache +per-file-ignores = + __init__.py:F401,F403,E402 diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index 42b5c4a..b6c2183 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -13,4 +13,4 @@ contact_links: about: Browse code examples and tutorials - name: πŸ” Existing Issues url: https://github.com/DeepKnowledge1/AnomaVision/issues?q=is%3Aissue - about: Search existing issues to see if your question was already answered \ No newline at end of file + about: Search existing issues to see if your question was already answered diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 40187b5..c351ade 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -13,9 +13,44 @@ concurrency: cancel-in-progress: true jobs: + # NEW JOB: Code Quality Enforcement + code-quality: + name: Code Quality Checks + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Install uv + uses: astral-sh/setup-uv@v4 + with: + version: "latest" + + - name: Install dependencies with dev extras + run: uv sync --extra dev --extra cpu --locked + + - name: Check formatting with Black + run: uv run black --check --diff . + + # - name: Check import sorting with isort + # run: uv run --with isort isort --check-only --diff . 
+ + - name: Lint with flake8 + run: | + uv run flake8 \ + --max-line-length=88 \ + --extend-ignore=E203,E501,W503,E402,F401,F403,F841,B006,B007,B008,B009,C416,E262 \ + anomavision/ anodet/ apps/ tests/ + tests: name: Tests (Python ${{ matrix.python-version }}) runs-on: ubuntu-latest + # Run tests only after code quality passes + needs: code-quality strategy: matrix: python-version: ["3.10", "3.11", "3.12"] @@ -44,6 +79,8 @@ jobs: cuda-matrix-check: name: Verify CUDA Resolvers runs-on: ubuntu-latest + # Run CUDA checks only after code quality passes + needs: code-quality steps: - uses: actions/checkout@v4 diff --git a/.gitignore b/.gitignore index da17a86..bc45113 100644 --- a/.gitignore +++ b/.gitignore @@ -359,3 +359,4 @@ images/ !docs/images/archti.png !docs/AnomaVision_vs_Anomalib.pdf poetry publish --build --username __toke.txt +fix_flake.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5d4cde7..b31c9f5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,45 +1,49 @@ repos: - # Black: auto-format Python code (fixes E501 line too long) + # Black: auto-format Python code - repo: https://github.com/psf/black - rev: 24.3.0 + rev: 26.3.1 hooks: - id: black - args: [--line-length=88] + args: [--line-length=88, --target-version=py310] - # isort: automatically sort and group imports (fixes E402) + # isort: automatically sort and group imports - repo: https://github.com/pycqa/isort - rev: 5.13.2 + rev: 8.0.1 # FIXED: was indented too far hooks: - id: isort - args: [--profile=black] + args: + - --profile=black + - --line-length=88 + - --force-single-line-imports - # # flake8: linting for style & errors - # - repo: https://github.com/pycqa/flake8 - # rev: 7.0.0 - # hooks: - # - id: flake8 - # additional_dependencies: [ - # flake8-bugbear, - # flake8-comprehensions - # ] - - # # mypy: static type checking - # - repo: https://github.com/pre-commit/mirrors-mypy - # rev: v1.8.0 - # hooks: - # - id: mypy + # flake8: ONLY block on 
critical errors (F811 redefinitions) + # All other warnings are ignored to allow legacy code to pass + - repo: https://github.com/pycqa/flake8 + rev: 7.3.0 + hooks: + - id: flake8 + args: + - --max-line-length=88 + - --extend-ignore=E203,E501,W503,E402,F401,F403,F841,B006,B007,B008,B009,C416,E262 + additional_dependencies: + - flake8-bugbear + - flake8-comprehensions # Basic cleanup hooks - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 + rev: v6.0.0 hooks: - id: end-of-file-fixer - id: trailing-whitespace + # Run pytest LAST - repo: local hooks: - id: pytest name: Run pytest - # entry: poetry run pytest + entry: uv run pytest language: system pass_filenames: false + +# Stop on first failure (don't run pytest if flake8 fails) +fail_fast: true diff --git a/anodet/__init__.py b/anodet/__init__.py index 31cec34..72b2596 100644 --- a/anodet/__init__.py +++ b/anodet/__init__.py @@ -278,7 +278,7 @@ except ImportError as e: print(f"❌ FATAL: Failed to import from anomavision: {e}") raise ImportError( - f"Failed to import from 'anomavision' package. " + "Failed to import from 'anomavision' package. 
" f"Please ensure AnomaVision is properly installed: {e}" ) from e diff --git a/anomavision/__init__.py b/anomavision/__init__.py index fcdb0cd..ed87261 100644 --- a/anomavision/__init__.py +++ b/anomavision/__init__.py @@ -12,7 +12,6 @@ from .datasets.mvtec_dataset import MVTecDataset from .feature_extraction import ResnetEmbeddingsExtractor from .padim import Padim - from .sampling_methods.kcenter_greedy import kCenterGreedy from .test import optimal_threshold, visualize_eval_data, visualize_eval_pair from .utils import get_logger # Export for users diff --git a/anomavision/cli.py b/anomavision/cli.py index f37d949..babe599 100644 --- a/anomavision/cli.py +++ b/anomavision/cli.py @@ -49,6 +49,7 @@ def create_parser() -> argparse.ArgumentParser: try: from anomavision import __version__ + version_str = f"AnomaVision {__version__}" except ImportError: version_str = "AnomaVision" @@ -87,8 +88,10 @@ def create_parser() -> argparse.ArgumentParser: # works immediately with no changes needed here. 
# ============================================================ + def _add_train_parser(subparsers) -> None: from anomavision.train import create_parser as _cp + subparsers.add_parser( "train", help="Train a new anomaly detection model", @@ -99,6 +102,7 @@ def _add_train_parser(subparsers) -> None: def _add_export_parser(subparsers) -> None: from anomavision.export import create_parser as _cp + subparsers.add_parser( "export", help="Export trained model to different formats", @@ -109,6 +113,7 @@ def _add_export_parser(subparsers) -> None: def _add_detect_parser(subparsers) -> None: from anomavision.detect import create_parser as _cp + subparsers.add_parser( "detect", help="Run inference on images", @@ -119,6 +124,7 @@ def _add_detect_parser(subparsers) -> None: def _add_eval_parser(subparsers) -> None: from anomavision.eval import create_parser as _cp + subparsers.add_parser( "eval", help="Evaluate model performance", @@ -132,23 +138,28 @@ def _add_eval_parser(subparsers) -> None: # No sys.argv manipulation. No double-parsing. 
# ============================================================ + def _dispatch_train(args: argparse.Namespace) -> None: from anomavision import train + train.main(args) def _dispatch_export(args: argparse.Namespace) -> None: from anomavision import export + export.main(args) def _dispatch_detect(args: argparse.Namespace) -> None: from anomavision import detect + detect.main(args) def _dispatch_eval(args: argparse.Namespace) -> None: from anomavision import eval as eval_module # 'eval' shadows the Python builtin + eval_module.main(args) @@ -156,6 +167,7 @@ def _dispatch_eval(args: argparse.Namespace) -> None: # Entry point # ============================================================ + def main() -> None: parser = create_parser() args = parser.parse_args() diff --git a/anomavision/datasets/MQTTSource.py b/anomavision/datasets/MQTTSource.py index 865576a..9a54f49 100644 --- a/anomavision/datasets/MQTTSource.py +++ b/anomavision/datasets/MQTTSource.py @@ -1,13 +1,13 @@ -from anomavision.datasets.StreamSource import StreamSource - -from typing import Optional import threading import time -from queue import Queue, Empty, Full +from queue import Empty, Full, Queue +from typing import Optional import cv2 import numpy as np +from anomavision.datasets.StreamSource import StreamSource + try: import paho.mqtt.client as mqtt except ImportError as e: diff --git a/anomavision/datasets/StreamDataset.py b/anomavision/datasets/StreamDataset.py index 7dd9e22..fef913b 100644 --- a/anomavision/datasets/StreamDataset.py +++ b/anomavision/datasets/StreamDataset.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Union, List +from typing import List, Optional, Tuple, Union import numpy as np import torch @@ -6,10 +6,7 @@ from torch.utils.data import IterableDataset from anomavision.datasets.StreamSource import StreamSource -from anomavision.utils import ( - create_image_transform, - create_mask_transform, -) +from anomavision.utils import create_image_transform, 
create_mask_transform class StreamDataset(IterableDataset): diff --git a/anomavision/datasets/StreamSource.py b/anomavision/datasets/StreamSource.py index ee11828..3b0780f 100644 --- a/anomavision/datasets/StreamSource.py +++ b/anomavision/datasets/StreamSource.py @@ -1,8 +1,8 @@ from abc import ABC, abstractmethod - # ========== BASE STRATEGY INTERFACE ========== + class StreamSource(ABC): @abstractmethod def connect(self) -> None: @@ -27,4 +27,3 @@ def disconnect(self) -> None: def is_connected(self) -> bool: """Check if source is connected.""" pass - diff --git a/anomavision/datasets/StreamSourceFactory.py b/anomavision/datasets/StreamSourceFactory.py index c7a0127..50e11b9 100644 --- a/anomavision/datasets/StreamSourceFactory.py +++ b/anomavision/datasets/StreamSourceFactory.py @@ -1,11 +1,11 @@ -from typing import Dict, Any - -from anomavision.datasets.StreamDataset import StreamDataset -from anomavision.datasets.StreamSource import StreamSource +import contextlib +from typing import Any, Dict from torch.utils.data import IterableDataset from anomavision.datasets.MQTTSource import MQTTSource +from anomavision.datasets.StreamDataset import StreamDataset +from anomavision.datasets.StreamSource import StreamSource from anomavision.datasets.TCPsource import TCPSource from anomavision.datasets.VideoSource import VideoSource from anomavision.datasets.WebcamSource import WebcamSource @@ -65,24 +65,3 @@ def create(config: Dict[str, Any]) -> StreamSource: ) raise ValueError(f"Unknown StreamSource type: {source_type}") - - -# Optional: convenience ctor on StreamDataset -class StreamDataset(IterableDataset): - # your existing __init__ here... - - @classmethod - def from_config( - cls, - source_config: Dict[str, Any], - **dataset_kwargs: Any, - ) -> "StreamDataset": - """ - Build StreamDataset from a source config dict + dataset kwargs. 
- - Example: - source_cfg = {"type": "webcam", "camera_id": 0} - ds = StreamDataset.from_config(source_cfg, max_frames=100) - """ - source = StreamSourceFactory.create(source_config) - return cls(source=source, **dataset_kwargs) diff --git a/anomavision/datasets/TCPsource.py b/anomavision/datasets/TCPsource.py index 5f0191b..6f1b49c 100644 --- a/anomavision/datasets/TCPsource.py +++ b/anomavision/datasets/TCPsource.py @@ -1,14 +1,14 @@ -from anomavision.datasets.StreamSource import StreamSource - -from typing import Optional import socket import threading import time -from queue import Queue, Empty, Full +from queue import Empty, Full, Queue +from typing import Optional import cv2 import numpy as np +from anomavision.datasets.StreamSource import StreamSource + class TCPSource(StreamSource): """ @@ -132,7 +132,10 @@ def _receiver_loop(self) -> None: if payload_len <= 0: continue - if self.max_message_size is not None and payload_len > self.max_message_size: + if ( + self.max_message_size is not None + and payload_len > self.max_message_size + ): # Drain and skip (best-effort) _ = self._recv_exact(payload_len) continue @@ -238,4 +241,8 @@ def disconnect(self) -> None: def is_connected(self) -> bool: with self._lock: - return self._connected and not self._stop_event.is_set() and self.socket is not None + return ( + self._connected + and not self._stop_event.is_set() + and self.socket is not None + ) diff --git a/anomavision/datasets/VideoSource.py b/anomavision/datasets/VideoSource.py index c3a26ce..968557d 100644 --- a/anomavision/datasets/VideoSource.py +++ b/anomavision/datasets/VideoSource.py @@ -1,12 +1,13 @@ -from anomavision.datasets.StreamSource import StreamSource - -import cv2 -import threading import queue +import threading import time -import numpy as np from typing import Optional +import cv2 +import numpy as np + +from anomavision.datasets.StreamSource import StreamSource + class VideoSource(StreamSource): """ diff --git 
a/anomavision/datasets/WebcamSource.py b/anomavision/datasets/WebcamSource.py index aa001a0..b75d82f 100644 --- a/anomavision/datasets/WebcamSource.py +++ b/anomavision/datasets/WebcamSource.py @@ -1,12 +1,13 @@ -from anomavision.datasets.StreamSource import StreamSource - -import cv2 -import threading import queue +import threading import time -import numpy as np from typing import Optional +import cv2 +import numpy as np + +from anomavision.datasets.StreamSource import StreamSource + class WebcamSource(StreamSource): """ diff --git a/anomavision/detect.py b/anomavision/detect.py index 02dae24..58d0dab 100644 --- a/anomavision/detect.py +++ b/anomavision/detect.py @@ -21,10 +21,9 @@ from torch.utils.data import DataLoader import anomavision +from anomavision.config import _shape, load_config from anomavision.datasets.StreamDataset import StreamDataset from anomavision.datasets.StreamSourceFactory import StreamSourceFactory - -from anomavision.config import _shape, load_config from anomavision.general import Profiler, determine_device, increment_path from anomavision.inference.model.wrapper import ModelWrapper from anomavision.inference.modelType import ModelType @@ -210,7 +209,11 @@ def run_inference(args): # Parse visualization color try: - viz_color = tuple(map(int, config.viz_color.split(","))) if config.viz_color else (128, 0, 128) + viz_color = ( + tuple(map(int, config.viz_color.split(","))) + if config.viz_color + else (128, 0, 128) + ) if len(viz_color) != 3: raise ValueError except (ValueError, AttributeError): @@ -224,11 +227,15 @@ def run_inference(args): crop_size = _shape(config.crop_size) normalize = config.get("normalize", True) - logger.info("Image processing: resize=%s, crop=%s, norm=%s", resize, crop_size, normalize) + logger.info( + "Image processing: resize=%s, crop=%s, norm=%s", resize, crop_size, normalize + ) # Validation if not config.get("img_path") and not stream_mode: - raise ValueError("img_path is required (via --img_path or config) 
when stream_mode is False") + raise ValueError( + "img_path is required (via --img_path or config) when stream_mode is False" + ) if not config.get("model"): raise ValueError("model is required (via --model or config)") @@ -247,7 +254,7 @@ def run_inference(args): "scores": [], "classifications": [], # We only store images/maps if needed for downstream tasks to avoid memory issues - "images": [] if not stream_mode else None + "images": [] if not stream_mode else None, } total_start_time = time.time() @@ -271,7 +278,13 @@ def run_inference(args): # --- Model Loading Phase --- with profilers["model_loading"]: - model_path = os.path.join(MODEL_DATA_PATH, config.algorithm, config.class_name, config.run_name, config.model) + model_path = os.path.join( + MODEL_DATA_PATH, + config.algorithm, + config.class_name, + config.run_name, + config.model, + ) logger.info(f"Loading model: {model_path}") if not os.path.exists(model_path): @@ -291,7 +304,11 @@ def run_inference(args): run_name = config.run_name viz_output_dir = config.get("viz_output_dir", "./visualizations/") RESULTS_PATH = increment_path( - Path(viz_output_dir) / config.algorithm / config.class_name / model_type.value.upper() / run_name, + Path(viz_output_dir) + / config.algorithm + / config.class_name + / model_type.value.upper() + / run_name, exist_ok=config.get("overwrite", False), mkdir=True, ) @@ -383,18 +400,24 @@ def run_inference(args): # 2. 
Post-processing with profilers["postprocessing"]: try: - score_maps = adaptive_gaussian_blur(score_maps, kernel_size=33, sigma=4) + score_maps = adaptive_gaussian_blur( + score_maps, kernel_size=33, sigma=4 + ) # Classify if config.thresh is not None: - is_anomaly = anomavision.classification(image_scores, config.thresh) + is_anomaly = anomavision.classification( + image_scores, config.thresh + ) else: is_anomaly = np.zeros_like(image_scores) # Accumulate Results (Offline only) if not stream_mode: results_accumulator["scores"].extend(image_scores.tolist()) - results_accumulator["classifications"].extend(is_anomaly.tolist()) + results_accumulator["classifications"].extend( + is_anomaly.tolist() + ) results_accumulator["images"].extend(images) except Exception as e: @@ -410,7 +433,13 @@ def run_inference(args): boundary_images = ( anomavision.visualization.framed_boundary_images( test_images, - anomavision.classification(score_maps, config.thresh)if config.thresh else np.zeros_like(score_maps), + ( + anomavision.classification( + score_maps, config.thresh + ) + if config.thresh + else np.zeros_like(score_maps) + ), is_anomaly, padding=config.get("viz_padding", 40), ) @@ -424,7 +453,11 @@ def run_inference(args): highlighted_images = anomavision.visualization.highlighted_images( [images[i] for i in range(len(images))], # Dummy mask if threshold not set - anomavision.classification(score_maps, config.thresh) if config.thresh else np.zeros_like(score_maps), + ( + anomavision.classification(score_maps, config.thresh) + if config.thresh + else np.zeros_like(score_maps) + ), color=viz_color, ) @@ -434,7 +467,10 @@ def run_inference(args): if config.save_visualizations and RESULTS_PATH: try: fig, axs = plt.subplots(1, 4, figsize=(16, 8)) - fig.suptitle(f"Result - Batch {batch_idx} Img {img_id}", fontsize=14) + fig.suptitle( + f"Result - Batch {batch_idx} Img {img_id}", + fontsize=14, + ) axs[0].imshow(images[img_id]) axs[0].set_title("Original") @@ -452,7 +488,10 @@ def 
run_inference(args): axs[3].set_title("Highlighted") axs[3].axis("off") - save_path = os.path.join(RESULTS_PATH, f"batch_{batch_idx}_img_{img_id}.png") + save_path = os.path.join( + RESULTS_PATH, + f"batch_{batch_idx}_img_{img_id}.png", + ) plt.savefig(save_path, dpi=100, bbox_inches="tight") plt.close(fig) except Exception as e: @@ -465,10 +504,12 @@ def run_inference(args): logger.info("Closing model...") model.close() if stream_mode: - # Clean up stream source - try: - test_dataset.close() - except: pass + # Clean up stream source + try: + test_dataset.close() + except Exception: + + pass # --- Metrics & Summary --- total_pipeline_time = time.time() - total_start_time @@ -482,12 +523,24 @@ def run_inference(args): logger.info("=" * 60) logger.info("ANOMAVISION PERFORMANCE SUMMARY") logger.info("=" * 60) - logger.info(f"Setup time: {profilers['setup'].accumulated_time * 1000:.2f} ms") - logger.info(f"Model loading time: {profilers['model_loading'].accumulated_time * 1000:.2f} ms") - logger.info(f"Data loading time: {profilers['data_loading'].accumulated_time * 1000:.2f} ms") - logger.info(f"Inference time: {profilers['inference'].accumulated_time * 1000:.2f} ms") - logger.info(f"Postprocessing time: {profilers['postprocessing'].accumulated_time * 1000:.2f} ms") - logger.info(f"Visualization time: {profilers['visualization'].accumulated_time * 1000:.2f} ms") + logger.info( + f"Setup time: {profilers['setup'].accumulated_time * 1000:.2f} ms" + ) + logger.info( + f"Model loading time: {profilers['model_loading'].accumulated_time * 1000:.2f} ms" + ) + logger.info( + f"Data loading time: {profilers['data_loading'].accumulated_time * 1000:.2f} ms" + ) + logger.info( + f"Inference time: {profilers['inference'].accumulated_time * 1000:.2f} ms" + ) + logger.info( + f"Postprocessing time: {profilers['postprocessing'].accumulated_time * 1000:.2f} ms" + ) + logger.info( + f"Visualization time: {profilers['visualization'].accumulated_time * 1000:.2f} ms" + ) 
logger.info(f"Total pipeline time: {total_pipeline_time * 1000:.2f} ms") logger.info("=" * 60) @@ -501,16 +554,18 @@ def run_inference(args): logger.info(f"Average inference time: {avg_ms:.2f} ms/batch") if batch_count > 0: - batch_size = config.get('batch_size', 1) or 1 + batch_size = config.get("batch_size", 1) or 1 throughput = fps * (final_count / batch_count) if batch_count else 0 - logger.info(f"Throughput: {throughput:.1f} images/sec (batch size: {batch_size})") + logger.info( + f"Throughput: {throughput:.1f} images/sec (batch size: {batch_size})" + ) logger.info("=" * 60) metrics = { "fps": fps, "avg_inference_ms": avg_ms, "total_time_s": total_pipeline_time, - "total_images": final_count + "total_images": final_count, } return metrics, results_accumulator diff --git a/anomavision/eval.py b/anomavision/eval.py index e0dbc2a..1283099 100644 --- a/anomavision/eval.py +++ b/anomavision/eval.py @@ -8,8 +8,8 @@ import numpy as np import torch from easydict import EasyDict as edict +from sklearn.metrics import auc, precision_recall_curve, roc_auc_score from torch.utils.data import DataLoader -from sklearn.metrics import roc_auc_score, precision_recall_curve, auc import anomavision from anomavision.config import load_config @@ -18,12 +18,12 @@ from anomavision.inference.modelType import ModelType from anomavision.utils import ( adaptive_gaussian_blur, + compute_metrics, + find_best_threshold_f1, + find_optimal_threshold, get_logger, merge_config, setup_logging, - find_best_threshold_f1, - compute_metrics, - find_optimal_threshold ) @@ -156,7 +156,9 @@ def evaluate_model_with_wrapper( logger.info(f"Starting evaluation on {len(test_dataloader.dataset)} images") try: - for batch_idx, (batch, images, image_targets, mask_targets) in enumerate(test_dataloader): + for batch_idx, (batch, images, image_targets, mask_targets) in enumerate( + test_dataloader + ): batch = batch.to(device_str) with evaluation_profiler: @@ -170,7 +172,9 @@ def evaluate_model_with_wrapper( # 
Collect results all_images.extend(images) all_image_classifications_target.extend( - image_targets.numpy() if hasattr(image_targets, "numpy") else image_targets + image_targets.numpy() + if hasattr(image_targets, "numpy") + else image_targets ) all_masks_target.extend( mask_targets.numpy() if hasattr(mask_targets, "numpy") else mask_targets @@ -192,7 +196,11 @@ def evaluate_model_with_wrapper( return ( np.array(all_images), np.array(all_image_classifications_target), - np.squeeze(np.array(all_masks_target), axis=1) if len(all_masks_target) > 0 else np.array([]), + ( + np.squeeze(np.array(all_masks_target), axis=1) + if len(all_masks_target) > 0 + else np.array([]) + ), np.array(all_image_scores), np.array(all_score_maps), ) @@ -227,7 +235,9 @@ def run_evaluation(args): # Setup Phase with profilers["setup"]: - DATASET_PATH = os.path.realpath(config.dataset_path) if config.dataset_path else None + DATASET_PATH = ( + os.path.realpath(config.dataset_path) if config.dataset_path else None + ) MODEL_DATA_PATH = os.path.realpath(config.model_data_path) device_str = determine_device(config.device) @@ -239,7 +249,13 @@ def run_evaluation(args): # Load Model Phase with profilers["model_loading"]: - model_path = os.path.join(MODEL_DATA_PATH, config.algorithm, config.class_name, config.run_name, config.model) + model_path = os.path.join( + MODEL_DATA_PATH, + config.algorithm, + config.class_name, + config.run_name, + config.model, + ) logger.info(f"Loading model: {model_path}") if not os.path.exists(model_path): @@ -273,7 +289,7 @@ def run_evaluation(args): batch_size=batch_size, num_workers=config.num_workers if config.num_workers else 0, pin_memory=config.pin_memory and device_str == "cuda", - shuffle=False + shuffle=False, ) except Exception as e: logger.error(f"Failed to create dataloader: {e}") @@ -308,8 +324,8 @@ def run_evaluation(args): # Add timing metrics total_images = len(test_dataset) - metrics['inference_fps'] = profilers["evaluation"].get_fps(total_images) - 
metrics['inference_time_total_s'] = profilers["evaluation"].accumulated_time + metrics["inference_fps"] = profilers["evaluation"].get_fps(total_images) + metrics["inference_time_total_s"] = profilers["evaluation"].accumulated_time # Visualization Phase if config.enable_visualization: @@ -344,11 +360,21 @@ def run_evaluation(args): logger.info("=" * 60) logger.info("ANOMAVISION EVALUATION PERFORMANCE SUMMARY") logger.info("=" * 60) - logger.info(f"Setup time: {profilers['setup'].accumulated_time * 1000:.2f} ms") - logger.info(f"Model loading time: {profilers['model_loading'].accumulated_time * 1000:.2f} ms") - logger.info(f"Data loading time: {profilers['data_loading'].accumulated_time * 1000:.2f} ms") - logger.info(f"Evaluation time: {profilers['evaluation'].accumulated_time * 1000:.2f} ms") - logger.info(f"Visualization time: {profilers['visualization'].accumulated_time * 1000:.2f} ms") + logger.info( + f"Setup time: {profilers['setup'].accumulated_time * 1000:.2f} ms" + ) + logger.info( + f"Model loading time: {profilers['model_loading'].accumulated_time * 1000:.2f} ms" + ) + logger.info( + f"Data loading time: {profilers['data_loading'].accumulated_time * 1000:.2f} ms" + ) + logger.info( + f"Evaluation time: {profilers['evaluation'].accumulated_time * 1000:.2f} ms" + ) + logger.info( + f"Visualization time: {profilers['visualization'].accumulated_time * 1000:.2f} ms" + ) # logger.info("=" * 60) # 2. PERFORMANCE METRICS @@ -362,7 +388,9 @@ def run_evaluation(args): if len(test_dataloader) > 0: images_per_batch = total_images / len(test_dataloader) - logger.info(f"Evaluation throughput: {evaluation_fps * images_per_batch:.1f} images/sec (batch size: {batch_size})") + logger.info( + f"Evaluation throughput: {evaluation_fps * images_per_batch:.1f} images/sec (batch size: {batch_size})" + ) # logger.info("=" * 60) # 3. 
EVALUATION SUMMARY @@ -373,7 +401,9 @@ def run_evaluation(args): logger.info(f"Total images evaluated: {total_images}") logger.info(f"Model type: {model_type.value.upper() if model_type else 'UNKNOWN'}") logger.info(f"Device: {device_str}") - logger.info(f"Image processing: resize={config.resize}, crop_size={config.crop_size}, normalize={config.normalize}") + logger.info( + f"Image processing: resize={config.resize}, crop_size={config.crop_size}, normalize={config.normalize}" + ) # logger.info("=" * 60) logger.info("=" * 60) @@ -396,7 +426,7 @@ def run_evaluation(args): "labels": labels, "masks": masks, "scores": scores, - "maps": maps + "maps": maps, } return metrics, raw_results diff --git a/anomavision/export.py b/anomavision/export.py index 0d15609..21e6da9 100644 --- a/anomavision/export.py +++ b/anomavision/export.py @@ -722,8 +722,19 @@ def main(args=None): setup_logging(enabled=True, log_level=config.log_level, log_to_file=True) logger = get_logger("anomavision.export") - model_path = Path(config.model_data_path) / config.algorithm / config.class_name / config.run_name / config.model - output_dir = Path(config.model_data_path) / config.algorithm / config.class_name / config.run_name + model_path = ( + Path(config.model_data_path) + / config.algorithm + / config.class_name + / config.run_name + / config.model + ) + output_dir = ( + Path(config.model_data_path) + / config.algorithm + / config.class_name + / config.run_name + ) model_stem = Path(config.model).stem # Generate output names diff --git a/anomavision/general.py b/anomavision/general.py index 4c36a83..3791d29 100644 --- a/anomavision/general.py +++ b/anomavision/general.py @@ -1,6 +1,8 @@ +import contextlib import os import subprocess import sys +import time from pathlib import Path import matplotlib.pyplot as plt @@ -52,12 +54,6 @@ def save_visualization(images, filename, output_dir): plt.imsave(filepath, images) -import contextlib -import time - -import torch - - class 
Profiler(contextlib.ContextDecorator): """ AnomaVision Performance Profiler for accurate timing measurements. diff --git a/anomavision/padim.py b/anomavision/padim.py index 4dc22bb..6615b36 100644 --- a/anomavision/padim.py +++ b/anomavision/padim.py @@ -137,7 +137,7 @@ def forward( layer_hook=self.layer_hook, layer_indices=self.layer_indices, ) - embedding_vectors= embedding_vectors.to(dtype=x.dtype) + embedding_vectors = embedding_vectors.to(dtype=x.dtype) patch_scores = self.mahalanobisDistance( features=embedding_vectors, width=w, height=h, export=export, chunk=256 ) # (B, w, h) diff --git a/anomavision/sampling_methods/kcenter_greedy.py b/anomavision/sampling_methods/kcenter_greedy.py index 964252d..f5d60e0 100644 --- a/anomavision/sampling_methods/kcenter_greedy.py +++ b/anomavision/sampling_methods/kcenter_greedy.py @@ -89,7 +89,8 @@ def select_batch_(self, model, already_selected, N, **kwargs): self.features = model.transform(self.X) print("Calculating distances...") self.update_distances(already_selected, only_new=False, reset_dist=True) - except: + except Exception: + print("Using flat_X as features.") self.update_distances(already_selected, only_new=True, reset_dist=False) diff --git a/anomavision/train.py b/anomavision/train.py index f0dbeb9..b786a89 100644 --- a/anomavision/train.py +++ b/anomavision/train.py @@ -20,7 +20,9 @@ def create_parser(add_help: bool = True) -> argparse.ArgumentParser: - parser = argparse.ArgumentParser(description="Train PaDiM (args OR config).", add_help=add_help) + parser = argparse.ArgumentParser( + description="Train PaDiM (args OR config).", add_help=add_help + ) # meta parser.add_argument( "--config", type=str, default="config.yml", help="Path to config.yml/.json" @@ -176,13 +178,16 @@ def run_training(args): config.normalize, ) if config.normalize: - logger.info( - "Normalization: mean=%s, std=%s", config.norm_mean, config.norm_std - ) + logger.info("Normalization: mean=%s, std=%s", config.norm_mean, config.norm_std) # 
Resolve output run dir once run_dir = increment_path( - Path(config.model_data_path) / config.algorithm / config.class_name / config.run_name, exist_ok=True, mkdir=True + Path(config.model_data_path) + / config.algorithm + / config.class_name + / config.run_name, + exist_ok=True, + mkdir=True, ) # === Dataset === @@ -199,9 +204,11 @@ def run_training(args): if not os.path.isdir(root): # Fallback check: maybe dataset_path ALREADY points to the class folder? # This makes it more robust for different input styles - potential_root = os.path.join(os.path.realpath(config.dataset_path), "train", "good") + potential_root = os.path.join( + os.path.realpath(config.dataset_path), "train", "good" + ) if os.path.isdir(potential_root): - root = potential_root + root = potential_root else: logger.error('Expected folder "%s" does not exist.', root) raise FileNotFoundError(f"Dataset root not found: {root}") @@ -263,13 +270,11 @@ def run_training(args): # snapshot the effective configuration save_args_to_yaml(config, str(Path(run_dir) / "config.yml")) - logger.info( - "saved: model=%s, config=%s", model_path, Path(run_dir) / "config.yml" - ) + logger.info("saved: model=%s, config=%s", model_path, Path(run_dir) / "config.yml") logger.info("=== Training done in %.2fs ===", time.perf_counter() - t0) # Return objects for external usage (e.g. 
MLOps pipeline) - return padim, config, run_dir, {'train': dl} + return padim, config, run_dir, {"train": dl} def main(args=None): @@ -282,6 +287,7 @@ def main(args=None): except Exception: get_logger(__name__).exception("Fatal error during training.") sys.exit(1) - + + if __name__ == "__main__": main() diff --git a/anomavision/utils.py b/anomavision/utils.py index d6d0bc7..34d0a66 100644 --- a/anomavision/utils.py +++ b/anomavision/utils.py @@ -12,11 +12,9 @@ import torch import yaml from PIL import Image +from sklearn.metrics import auc, precision_recall_curve, roc_auc_score, roc_curve from torchvision import transforms as T -from sklearn.metrics import roc_curve -from sklearn.metrics import roc_auc_score, precision_recall_curve, auc - # Default standard transforms - kept for backward compatibility standard_image_transform = T.Compose( [ @@ -347,8 +345,6 @@ def split_tensor_and_run_function( return output_tensor - - def setup_logging( log_level: str = "INFO", log_to_file: bool = False, @@ -790,6 +786,7 @@ def adaptive_gaussian_blur(input_array, kernel_size=33, sigma=4): except ImportError: raise ImportError("SciPy is required when PyTorch is not available") + def find_best_threshold_f1(labels, scores): precision, recall, thresholds = precision_recall_curve(labels, scores) @@ -798,7 +795,6 @@ def find_best_threshold_f1(labels, scores): return thresholds[best_idx], f1[best_idx] -from sklearn.metrics import roc_curve def find_best_threshold_roc(labels, scores): fpr, tpr, thresholds = roc_curve(labels, scores) @@ -825,6 +821,7 @@ def find_best_threshold_accuracy(labels, scores): return best_thresh, best_acc + def compute_metrics(labels, scores, thresh=None): """ Calculate standard anomaly detection metrics. 
@@ -833,26 +830,26 @@ def compute_metrics(labels, scores, thresh=None): # AUROC try: - metrics['auc_score'] = float(roc_auc_score(labels, scores)) + metrics["auc_score"] = float(roc_auc_score(labels, scores)) except ValueError: - metrics['auc_score'] = 0.0 + metrics["auc_score"] = 0.0 # PR-AUC try: precision, recall, _ = precision_recall_curve(labels, scores) - metrics['pr_auc'] = float(auc(recall, precision)) + metrics["pr_auc"] = float(auc(recall, precision)) except ValueError: - metrics['pr_auc'] = 0.0 + metrics["pr_auc"] = 0.0 # Statistics - metrics['mean_anomaly_score'] = float(np.mean(scores)) - metrics['std_anomaly_score'] = float(np.std(scores)) + metrics["mean_anomaly_score"] = float(np.mean(scores)) + metrics["std_anomaly_score"] = float(np.std(scores)) # Accuracy (if thresh provided) if thresh is not None: predictions = (scores > thresh).astype(int) - metrics['accuracy'] = float(np.mean(predictions == labels)) - metrics['thresh'] = thresh + metrics["accuracy"] = float(np.mean(predictions == labels)) + metrics["thresh"] = thresh return metrics diff --git a/app.py b/app.py index b038c07..9b86be1 100644 --- a/app.py +++ b/app.py @@ -86,7 +86,9 @@ async def process_image(file: UploadFile = File(...)): np_image = np.array(image) # Preprocess and run inference - batch = to_batch([np_image], anomavision.standard_image_transform, torch.device("cpu")) + batch = to_batch( + [np_image], anomavision.standard_image_transform, torch.device("cpu") + ) image_scores, score_maps = model.predict(batch) # Postprocess diff --git a/apps/anomavision_gui_tkinter.py b/apps/anomavision_gui_tkinter.py index f0cd67d..8e6d501 100644 --- a/apps/anomavision_gui_tkinter.py +++ b/apps/anomavision_gui_tkinter.py @@ -37,6 +37,7 @@ except Exception: _DND_AVAILABLE = False + # ----- Your package imports (existing project modules) ----- import anomavision from anomavision.config import load_config @@ -45,7 +46,7 @@ from anomavision.inference.modelType import ModelType from anomavision.padim 
import Padim from anomavision.utils import get_logger -from PIL import Image, ImageTk + # ----------------------------------------------------------------------------- # Global fast paths and helpers # ----------------------------------------------------------------------------- @@ -208,7 +209,7 @@ def make_logo(size=44, bg="#38BDF8", fg="#001225"): text = "AV" # Try common fonts; fall back to default font = None - for name in ("arial.ttf", "SegoeUI.ttf", "DejaVuSans.ttf"): + for name in ("arial.ttf", "SegoeUI.ttf", "DejaVuSans.ttf"): try: font = ImageFont.truetype(name, int(size * 0.44)) break @@ -764,8 +765,8 @@ def _build_style(self): self.style.configure( "Green.Horizontal.TProgressbar", troughcolor=self.colors["panel"], - background="#10B981", # βœ… green bar - thickness=14 + background="#10B981", # βœ… green bar + thickness=14, ) def toggle_theme(self): @@ -836,10 +837,10 @@ def _build_header(self): # --- Create logo image BEFORE using it (safe fallback if PIL font missing) --- - - try: - icon_image = Image.open("av.png") # use PNG for clean scaling; ICO also works + icon_image = Image.open( + "av.png" + ) # use PNG for clean scaling; ICO also works icon_image = icon_image.resize((44, 44), Image.LANCZOS) self.logo_tk = ImageTk.PhotoImage(icon_image) except Exception: @@ -1214,11 +1215,13 @@ def row(lbl): variable=self.train_progress_var, maximum=100, mode="determinate", - style="Green.Horizontal.TProgressbar" + style="Green.Horizontal.TProgressbar", ) self.train_progress.pack(fill=tk.X, padx=14, pady=(0, 12)) self.train_progress.pack_forget() - self.train_percent_label = tk.Label(cfg, text="0%", bg=self.colors["panel"], fg=self.colors["fg"]) + self.train_percent_label = tk.Label( + cfg, text="0%", bg=self.colors["panel"], fg=self.colors["fg"] + ) self.train_percent_label.pack(anchor="w", padx=14) def update_training_progress(self, value): diff --git a/apps/api/fastapi_app.py b/apps/api/fastapi_app.py index a9a00a4..cf27ef7 100644 --- a/apps/api/fastapi_app.py
+++ b/apps/api/fastapi_app.py @@ -29,7 +29,9 @@ RESIZE_SIZE = (224, 224) # You can override these via environment variables -MODEL_DATA_PATH = os.getenv("ANOMAVISION_MODEL_DATA_PATH", "distributions/padim/bottle/anomav_exp") +MODEL_DATA_PATH = os.getenv( + "ANOMAVISION_MODEL_DATA_PATH", "distributions/padim/bottle/anomav_exp" +) MODEL_FILE = os.getenv("ANOMAVISION_MODEL_FILE", "model.onnx") DEVICE = os.getenv("ANOMAVISION_DEVICE", "auto") # "auto"|"cpu"|"cuda" @@ -81,6 +83,7 @@ async def cleanup(): model_type = None print("Model cleanup completed.") + @asynccontextmanager async def lifespan(app: FastAPI): await load_model() diff --git a/apps/ui/gradio_app.py b/apps/ui/gradio_app.py index 715fb23..a51c3fc 100644 --- a/apps/ui/gradio_app.py +++ b/apps/ui/gradio_app.py @@ -33,6 +33,7 @@ from anomavision.general import determine_device from anomavision.inference.model.wrapper import ModelWrapper from anomavision.inference.modelType import ModelType + ANOMAVISION_AVAILABLE = True except ImportError: ANOMAVISION_AVAILABLE = False @@ -41,14 +42,16 @@ # ───────────────────────────────────────────────────────────────────────────── # Config # ───────────────────────────────────────────────────────────────────────────── -MODEL_DATA_PATH = os.getenv("ANOMAVISION_MODEL_DATA_PATH", "distributions/padim/bottle/anomav_exp") -MODEL_FILE = os.getenv("ANOMAVISION_MODEL_FILE", "model.onnx") -DEVICE_ENV = os.getenv("ANOMAVISION_DEVICE", "auto") +MODEL_DATA_PATH = os.getenv( + "ANOMAVISION_MODEL_DATA_PATH", "distributions/padim/bottle/anomav_exp" +) +MODEL_FILE = os.getenv("ANOMAVISION_MODEL_FILE", "model.onnx") +DEVICE_ENV = os.getenv("ANOMAVISION_DEVICE", "auto") THRESHOLD_DEFAULT = float(os.getenv("ANOMAVISION_THRESHOLD", "13.0")) -VIZ_PADDING = int(os.getenv("ANOMAVISION_VIZ_PADDING", "40")) -VIZ_ALPHA = float(os.getenv("ANOMAVISION_VIZ_ALPHA", "0.5")) -VIZ_COLOR = tuple(map(int, os.getenv("ANOMAVISION_VIZ_COLOR", "128,0,128").split(","))) -SAMPLE_DIR = 
os.getenv("SAMPLE_IMAGES_DIR", "D:/01-DATA/sample_images") +VIZ_PADDING = int(os.getenv("ANOMAVISION_VIZ_PADDING", "40")) +VIZ_ALPHA = float(os.getenv("ANOMAVISION_VIZ_ALPHA", "0.5")) +VIZ_COLOR = tuple(map(int, os.getenv("ANOMAVISION_VIZ_COLOR", "128,0,128").split(","))) +SAMPLE_DIR = os.getenv("SAMPLE_IMAGES_DIR", "D:/01-DATA/sample_images") # ───────────────────────────────────────────────────────────────────────────── # Model β€” loaded once at startup @@ -72,11 +75,13 @@ def _load_model() -> str: try: _device_str = determine_device(DEVICE_ENV) _model_type = ModelType.from_extension(model_path) - _model = ModelWrapper(model_path, _device_str) + _model = ModelWrapper(model_path, _device_str) # Optional warmup try: - dummy = torch.zeros((1, 3, 224, 224), dtype=torch.float32, device=_device_str) + dummy = torch.zeros( + (1, 3, 224, 224), dtype=torch.float32, device=_device_str + ) _model.warmup(batch=dummy, runs=1) except Exception: pass @@ -108,7 +113,9 @@ def _demo_predict(image_np: np.ndarray): def _real_predict(image_np: np.ndarray, threshold: float): """Run anomavision inference and return (score, score_map_np, boundary_np, heatmap_np, highlighted_np).""" device = torch.device(_device_str) - batch = anomavision.to_batch([image_np], anomavision.standard_image_transform, device) + batch = anomavision.to_batch( + [image_np], anomavision.standard_image_transform, device + ) if _device_str == "cuda": batch = batch.half() @@ -116,8 +123,8 @@ def _real_predict(image_np: np.ndarray, threshold: float): with torch.no_grad(): image_scores, score_maps = _model.predict(batch) - score_map_cls = anomavision.classification(score_maps, threshold) - image_cls = anomavision.classification(image_scores, threshold) + score_map_cls = anomavision.classification(score_maps, threshold) + image_cls = anomavision.classification(image_scores, threshold) test_images = np.array([image_np]) @@ -187,7 +194,7 @@ def _collect_samples() -> list: for p in sorted(base.rglob("*")): if 
p.suffix.lower() in SUPPORTED_EXT: - rel = p.relative_to(base) + rel = p.relative_to(base) parts = rel.parts if len(parts) >= 3: label = f"{parts[0]}/{parts[1]}" @@ -235,22 +242,24 @@ def run_inference( if image is None: return "❌ Please upload or select an image.", None, None, None, None - resize = (int(resize_w), int(resize_h)) + resize = (int(resize_w), int(resize_h)) image_np = _pil_to_np(image) t0 = time.time() if _model is not None and ANOMAVISION_AVAILABLE: try: - score, score_map_np, boundary_np, heatmap_np, highlighted_np = _real_predict( - image_np, threshold + score, score_map_np, boundary_np, heatmap_np, highlighted_np = ( + _real_predict(image_np, threshold) ) is_anomaly = score >= threshold - original_pil = image.resize(resize, Image.BILINEAR) - heatmap_pil = _np_to_pil(heatmap_np, resize) if include_viz else None - boundary_pil = _np_to_pil(boundary_np, resize) if include_viz else None - highlighted_pil = _np_to_pil(highlighted_np, resize) if include_viz else None + original_pil = image.resize(resize, Image.BILINEAR) + heatmap_pil = _np_to_pil(heatmap_np, resize) if include_viz else None + boundary_pil = _np_to_pil(boundary_np, resize) if include_viz else None + highlighted_pil = ( + _np_to_pil(highlighted_np, resize) if include_viz else None + ) except Exception as e: return f"⚠️ Inference error: {e}", None, None, None, None @@ -263,38 +272,45 @@ def run_inference( if include_viz: import matplotlib.cm as cm + heatmap_norm = heatmap_raw / heatmap_raw.max() - cmap = cm.get_cmap("jet") + cmap = cm.get_cmap("jet") heatmap_rgba = (cmap(heatmap_norm) * 255).astype(np.uint8) - heatmap_rgb = heatmap_rgba[:, :, :3] - blend = (0.5 * image_np + 0.5 * heatmap_rgb).astype(np.uint8) - heatmap_pil = _np_to_pil(heatmap_rgb, resize) - boundary_pil = _np_to_pil(blend, resize) - highlighted_pil = _np_to_pil(image_np, resize) + heatmap_rgb = heatmap_rgba[:, :, :3] + blend = (0.5 * image_np + 0.5 * heatmap_rgb).astype(np.uint8) + heatmap_pil = _np_to_pil(heatmap_rgb, 
resize) + boundary_pil = _np_to_pil(blend, resize) + highlighted_pil = _np_to_pil(image_np, resize) else: heatmap_pil = boundary_pil = highlighted_pil = None elapsed = time.time() - t0 - label = "🚨 ANOMALY DETECTED" if is_anomaly else "βœ… NORMAL" - status = f"Model: {Path(MODEL_FILE).stem} | Score: {score:.4f} | {label}" - detail = f"Threshold: {threshold:.2f} | Inference time: {elapsed:.2f}s" + label = "🚨 ANOMALY DETECTED" if is_anomaly else "βœ… NORMAL" + status = f"Model: {Path(MODEL_FILE).stem} | Score: {score:.4f} | {label}" + detail = f"Threshold: {threshold:.2f} | Inference time: {elapsed:.2f}s" - return f"{status}\n{detail}", original_pil, heatmap_pil, boundary_pil, highlighted_pil + return ( + f"{status}\n{detail}", + original_pil, + heatmap_pil, + boundary_pil, + highlighted_pil, + ) # ───────────────────────────────────────────────────────────────────────────── # CSS β€” Clean Light Theme with Indigo/Violet Accents # ───────────────────────────────────────────────────────────────────────────── -_ACCENT = "#6366f1" # indigo -_ACCENT_H = "#4f46e5" # indigo hover -_ACCENT2 = "#ef4444" # red for anomaly alerts -_ACCENT3 = "#22c55e" # green for normal result -_BG = "#f5f6fa" # off-white page background -_SURFACE = "#ffffff" # card surface -_SURFACE2 = "#f0f1f8" # slightly tinted input background -_BORDER = "#e2e4f0" # soft lavender border -_TEXT = "#1e1b4b" # deep indigo text -_MUTED = "#7c82a8" # muted blue-grey +_ACCENT = "#6366f1" # indigo +_ACCENT_H = "#4f46e5" # indigo hover +_ACCENT2 = "#ef4444" # red for anomaly alerts +_ACCENT3 = "#22c55e" # green for normal result +_BG = "#f5f6fa" # off-white page background +_SURFACE = "#ffffff" # card surface +_SURFACE2 = "#f0f1f8" # slightly tinted input background +_BORDER = "#e2e4f0" # soft lavender border +_TEXT = "#1e1b4b" # deep indigo text +_MUTED = "#7c82a8" # muted blue-grey custom_css = f""" @import 
url('https://fonts.googleapis.com/css2?family=Plus+Jakarta+Sans:wght@400;500;600;700;800&family=JetBrains+Mono:wght@400;600&display=swap'); @@ -716,7 +732,8 @@ def run_inference( with gr.Blocks(title="AnomaVision β€” Industrial Anomaly Detection") as demo: # ── Header ────────────────────────────────────────────────────────────── - gr.HTML(f""" + gr.HTML( + f"""
@@ -744,7 +761,8 @@ def run_inference(
- """) + """ + ) # ── Tabs ───────────────────────────────────────────────────────────────── with gr.Tabs(): @@ -757,7 +775,9 @@ def run_inference( # ── Left column: controls ──────────────────────────────────── with gr.Column(scale=1, min_width=300): - gr.HTML('
Input
') + gr.HTML( + '
Input
' + ) input_img = gr.Image( type="pil", @@ -774,23 +794,47 @@ def run_inference( scale=1, ) category_dd = gr.Dropdown( - choices=["bottle", "cable", "carpet", "grid", - "hazelnut", "leather", "metal_nut", - "pill", "screw", "tile", "toothbrush", - "transistor", "wood", "zipper", "other"], + choices=[ + "bottle", + "cable", + "carpet", + "grid", + "hazelnut", + "leather", + "metal_nut", + "pill", + "screw", + "tile", + "toothbrush", + "transistor", + "wood", + "zipper", + "other", + ], value="bottle", label="Category", scale=1, ) threshold = gr.Slider( - 0.1, 50.0, THRESHOLD_DEFAULT, - step=0.1, label="Threshold" + 0.1, 50.0, THRESHOLD_DEFAULT, step=0.1, label="Threshold" ) with gr.Row(): - resize_w = gr.Number(value=224, label="Width", minimum=32, maximum=2048, precision=0) - resize_h = gr.Number(value=224, label="Height", minimum=32, maximum=2048, precision=0) + resize_w = gr.Number( + value=224, + label="Width", + minimum=32, + maximum=2048, + precision=0, + ) + resize_h = gr.Number( + value=224, + label="Height", + minimum=32, + maximum=2048, + precision=0, + ) viz_check = gr.Checkbox(value=True, label="Generate Visualizations") @@ -805,7 +849,7 @@ def run_inference( '
' 'Sample Images ' '' - '(click to select)
' + "(click to select)" ) _gallery_items = _sample_gallery_images() @@ -835,7 +879,9 @@ def run_inference( # ── Right column: results ───────────────────────────────────── with gr.Column(scale=2): - gr.HTML('
Results
') + gr.HTML( + '
Results
' + ) result_text = gr.Textbox( label="", @@ -846,10 +892,10 @@ def run_inference( ) with gr.Row(): - out_original = gr.Image(label="Original", type="pil") - out_heatmap = gr.Image(label="Anomaly Heatmap", type="pil") - out_overlay = gr.Image(label="Overlay", type="pil") - out_mask = gr.Image(label="Predicted Mask", type="pil") + out_original = gr.Image(label="Original", type="pil") + out_heatmap = gr.Image(label="Anomaly Heatmap", type="pil") + out_overlay = gr.Image(label="Overlay", type="pil") + out_mask = gr.Image(label="Predicted Mask", type="pil") # ── Event wiring ───────────────────────────────────────────────── @@ -862,6 +908,7 @@ def run_inference( # Sample gallery click β†’ load image into input_img if sample_gallery is not None: + def on_sample_select(evt: gr.SelectData) -> Image.Image: """Load the clicked sample image into the input component.""" if evt.index >= len(SAMPLES): @@ -877,7 +924,8 @@ def on_sample_select(evt: gr.SelectData) -> Image.Image: # ── Tab 2: Draw Defects ─────────────────────────────────────────────── with gr.Tab("🎨 Draw Defects"): - gr.HTML(""" + gr.HTML( + """
Synthetic Defect Testing
@@ -893,28 +941,39 @@ def on_sample_select(evt: gr.SelectData) -> Image.Image: ✦ Requires Gradio β‰₯ 4.x for the sketch editor

- """) + """ + ) with gr.Row(): with gr.Column(): sketch_img = gr.ImageEditor( type="pil", label="Draw Defects Here", - brush=gr.Brush(colors=["#ff0000", "#ffff00", "#ffffff"], default_size=8), + brush=gr.Brush( + colors=["#ff0000", "#ffff00", "#ffffff"], default_size=8 + ), + ) + sketch_threshold = gr.Slider( + 0.1, 50.0, THRESHOLD_DEFAULT, step=0.1, label="Threshold" ) - sketch_threshold = gr.Slider(0.1, 50.0, THRESHOLD_DEFAULT, step=0.1, label="Threshold") sketch_btn = gr.Button("πŸ” Analyze Drawn Image", variant="primary") with gr.Column(): - sketch_result = gr.Textbox(label="Result", lines=2) - sketch_heat = gr.Image(label="Heatmap", type="pil") - sketch_overlay = gr.Image(label="Overlay", type="pil") + sketch_result = gr.Textbox(label="Result", lines=2) + sketch_heat = gr.Image(label="Heatmap", type="pil") + sketch_overlay = gr.Image(label="Overlay", type="pil") def run_sketch(editor_val, thr): if editor_val is None: return "Please draw on the image first.", None, None - img = editor_val.get("composite") if isinstance(editor_val, dict) else editor_val + img = ( + editor_val.get("composite") + if isinstance(editor_val, dict) + else editor_val + ) if img is None: return "Please draw on the image first.", None, None - status, orig, heat, boundary, _ = run_inference(img, thr, 224, 224, True) + status, orig, heat, boundary, _ = run_inference( + img, thr, 224, 224, True + ) return status, heat, boundary sketch_btn.click( @@ -925,7 +984,8 @@ def run_sketch(editor_val, thr): # ── Tab 3: Compare Models ───────────────────────────────────────────── with gr.Tab("βš–οΈ Compare Models"): - gr.HTML(""" + gr.HTML( + """
βš–οΈ
@@ -941,7 +1001,8 @@ def run_sketch(editor_val, thr): font-size:0.72rem;font-weight:700;color:#6366f1;letter-spacing:0.1em; text-transform:uppercase;">Coming Soon
- """) + """ + ) if __name__ == "__main__": @@ -956,4 +1017,3 @@ def run_sketch(editor_val, thr): ), css=custom_css, ) - diff --git a/apps/ui/run_app.py b/apps/ui/run_app.py index 8231b59..0801a7e 100644 --- a/apps/ui/run_app.py +++ b/apps/ui/run_app.py @@ -2,18 +2,21 @@ AnomaVision - Hugging Face Spaces Entry Point Starts both FastAPI backend and Gradio frontend in a single process """ + import multiprocessing -import time -import sys import os +import sys +import time # Add the apps directory to Python path -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + def run_fastapi(): """Run FastAPI server in background process""" try: import uvicorn + # Import from apps.api package from apps.api.fastapi_app import app @@ -23,16 +26,17 @@ def run_fastapi(): host="0.0.0.0", port=8000, log_level="info", - access_log=False # Reduce log noise + access_log=False, # Reduce log noise ) except Exception as e: print(f"❌ FastAPI failed to start: {e}") sys.exit(1) + def run_gradio(): """Run Gradio interface in main process""" - import requests import gradio as gr + import requests # Import from apps.ui package from apps.ui.gradio_app import create_interface @@ -70,12 +74,13 @@ def run_gradio(): server_port=7860, share=False, show_error=True, - quiet=False + quiet=False, ) except Exception as e: print(f"❌ Gradio failed to start: {e}") sys.exit(1) + if __name__ == "__main__": print("=" * 60) print("πŸ” AnomaVision - Visual Anomaly Detection") @@ -86,12 +91,11 @@ def run_gradio(): print("=" * 60) # Set multiprocessing start method (important for some platforms) - multiprocessing.set_start_method('spawn', force=True) + multiprocessing.set_start_method("spawn", force=True) # Start FastAPI in a separate process fastapi_process = multiprocessing.Process( - target=run_fastapi, - daemon=True # Dies when main process dies + target=run_fastapi, daemon=True # Dies when main process dies ) 
fastapi_process.start() diff --git a/apps/ui/streamlit_app.py b/apps/ui/streamlit_app.py index a625500..4524124 100644 --- a/apps/ui/streamlit_app.py +++ b/apps/ui/streamlit_app.py @@ -451,7 +451,7 @@ def get_image_files_from_folder(uploaded_files) -> List: with st.expander("πŸ“Œ Quick Start (API)"): st.code( - f"""# Health + f"""# Health curl -s {HEALTH_ENDPOINT} # Predict @@ -657,9 +657,10 @@ def run_inference_on_current(): st.session_state.results_cache[cache_key] = { "_error": error_msg, - "_filename": current_file.name + "_filename": current_file.name, } + # Run inference on current image if st.session_state.image_files and api_online: run_inference_on_current() @@ -714,7 +715,7 @@ def run_inference_on_current(): if is_anomaly: sev = percent_diff(anomaly_score, float(new_threshold)) st.markdown( - f""" + f"""
🚨 Anomaly Detected
@@ -728,7 +729,7 @@ def run_inference_on_current(): else: margin = abs(percent_diff(anomaly_score, float(new_threshold))) st.markdown( - f""" + f"""
βœ… Normal Image
diff --git a/compare_with_anomalib.py b/compare_with_anomalib.py index d2c4f3e..b1ddfd8 100644 --- a/compare_with_anomalib.py +++ b/compare_with_anomalib.py @@ -278,7 +278,8 @@ def benchmark_your_implementation(self) -> ModelMetrics: model.save_statistics(str(stats_path)) stats_size = stats_path.stat().st_size / (1024 * 1024) print(f" Statistics file size: {stats_size:.2f} MB") - except: + except Exception: + pass # === 4. Export Sizes === @@ -1071,7 +1072,7 @@ def main(): print("\n" + "=" * 60) print("COMPARISON COMPLETE!") print("=" * 60) - print(f"Results saved in: benchmark_results/") + print("Results saved in: benchmark_results/") def generate_summary_report(all_results: Dict): diff --git a/config.yml b/config.yml index 75b672f..daa9bbf 100644 --- a/config.yml +++ b/config.yml @@ -105,4 +105,3 @@ stream_max_frames: null # Max frames to process (null = infinit stream_display_fps: true # Show FPS during streaming stream_save_detections: true # Save detected anomalies to disk stream_detection_dir: "./stream_detections/" # Directory for saved detections - diff --git a/docs/contributing.md b/docs/contributing.md index 837608f..e714b48 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -1,131 +1,293 @@ -# 🀝 Contributing to AnomaVision +# Contributing to AnomaVision -First off, thanks for taking the time to contribute! πŸŽ‰ -We welcome all kinds of contributions β€” whether it’s bug reports, feature requests, documentation improvements, or new code. +Thank you for your interest in contributing to AnomaVision! This document provides guidelines and instructions for contributing. ---- +## πŸš€ Quick Start -## πŸ“Œ How to Contribute +### 1. Development Setup -1. **Fork the repo** and create your branch from `main`: +```bash +# Fork and clone the repository +git clone https://github.com/YOUR_USERNAME/AnomaVision.git +cd AnomaVision - ```bash - git checkout -b feature/your-feature - ``` -2. 
**Install in development mode**: +# Install uv (if not already installed) +curl -LsSf https://astral.sh/uv/install.sh | sh - ```bash - uv venv --python 3.11 .venv - source .venv/bin/activate # Windows: .venv\Scripts\Activate.ps1 - uv sync --extra cpu - uv pip install -e ".[dev]" - ``` -3. **Write tests** for new functionality. -4. **Run tests & linters** before pushing: +# Install dependencies with dev tools +uv sync --extra dev --extra cpu - ```bash - pytest tests/ - black . - isort . - flake8 . - ``` -5. **Commit with clear messages** (see below). -6. **Open a Pull Request (PR)** and describe your changes. +# Install pre-commit hooks (REQUIRED) +uv run pre-commit install +``` ---- +### 2. Create a Branch -## πŸ§ͺ Code Style & Quality +```bash +# Create a new branch for your feature or fix +git checkout -b fix/your-fix-name +# or +git checkout -b feat/your-feature-name +``` -We follow these standards: +## βœ… Before Submitting a PR -* **Python β‰₯ 3.9** -* **Black** for formatting -* **isort** for import ordering -* **flake8** for linting -* **pytest** for testing +### Run Code Quality Checks Locally -Run all checks with: +**ALL of these must pass before submitting:** ```bash -make lint test +# 1. Format your code +uv run black . +uv run isort . + +# 2. Run linters +uv run flake8 anomavision/ anodet/ apps/ tests/ + +# 3. Run tests +uv run pytest -v + +# 4. (Optional) Run pre-commit on all files +pre-commit run --all-files ``` ---- +### CI Will Enforce These Checks + +⚠️ **Your PR will be blocked if:** +- Code is not formatted with Black +- Imports are not sorted with isort +- Flake8 finds critical issues (F811 redefinitions) +- Tests fail -## πŸ“ Commit Messages +## πŸ“ Code Quality Standards -Use clear and descriptive commit messages. 
-Format: +### Formatting +- **Black**: Line length = 88 characters +- **isort**: Import sorting with Black profile + +### Linting Rules + +We use **lenient flake8** configuration to allow legacy code while blocking critical issues: + +**Allowed (won't block PR):** +- E501: Line too long (Black handles this) +- E402: Module imports not at top +- F401: Unused imports (common in `__init__.py`) +- B006-B009: Bugbear warnings (require refactoring) + +**Blocked (will fail CI):** +- F811: Redefinition of variables (bug risk) +- E722: Bare `except:` (security risk) + +### Type Hints (Optional) + +Type hints are encouraged but not required. If you add them: +```python +def process_image(img: np.ndarray, thresh: float = 0.5) -> Tuple[np.ndarray, float]: + ... ``` -(): + +## πŸ§ͺ Testing Guidelines + +### Running Tests + +```bash +# Run all tests +uv run pytest + +# Run specific test file +uv run pytest tests/test_padim.py + +# Run with coverage +uv run pytest --cov=anomavision --cov-report=html ``` -Examples: +### Writing Tests -* `fix(detect): handle empty directory gracefully` -* `feat(export): add OpenVINO INT8 support` -* `docs(api): improve Padim usage examples` +- Place tests in `tests/` directory +- Name test files `test_*.py` +- Use descriptive test names: `test_padim_training_on_cpu()` -Types: `feat`, `fix`, `docs`, `refactor`, `test`, `chore`. +Example: +```python +import pytest +import torch +from anomavision import Padim ---- +def test_padim_initialization(): + model = Padim(backbone="resnet18", device="cpu") + assert model.backbone == "resnet18" -## πŸ”€ Pull Request Guidelines +def test_padim_training_basic(): + # Your test here + pass +``` -* Keep PRs focused and small. -* Update docs if behavior changes. -* Ensure all tests pass (`pytest`). -* Add benchmarks if performance-critical. 
+## πŸ“¦ Dependency Changes ---- +**If you modify `pyproject.toml`:** -## πŸ› Reporting Bugs +```bash +# Regenerate the lockfile +uv lock --python 3.10 -* Use the [GitHub Issues](https://github.com/DeepKnowledge1/AnomaVision/issues). -* Include: +# Commit both files +git add pyproject.toml uv.lock +git commit -m "deps: add new dependency X" +``` - * OS & Python version - * Installation method (pip/poetry) - * Steps to reproduce - * Expected vs actual behavior +⚠️ **CI will fail if `uv.lock` is not updated** (the `--locked` flag enforces this) ---- +## 🎯 Pull Request Process -## πŸ’‘ Suggesting Features +### 1. Fill Out the PR Template -* Open a [Discussion](https://github.com/DeepKnowledge1/AnomaVision/discussions). -* Clearly describe **why** the feature is useful and how it fits into the project. +When you create a PR, fill out all sections: ---- +- **Related Issue**: Link to the issue this fixes +- **Description**: Explain what changed and why +- **Type of Change**: Bug fix, feature, docs, etc. +- **Hardware Testing**: Which extras did you test? (`cpu`, `cu121`, etc.) +- **Developer Checklist**: Confirm all items -## πŸ§‘β€πŸ’» Development Workflow +### 2. PR Review Process -Typical workflow for contributors: +1. **Automated Checks** run first: + - Code Quality (Black, isort, flake8) + - Tests (Python 3.10, 3.11, 3.12) + - CUDA matrix verification -```bash -# Clone your fork -git clone https://github.com//AnomaVision.git -cd AnomaVision +2. **Maintainer Review**: + - Code quality + - Test coverage + - Documentation updates + +3. **Approval & Merge**: + - At least 1 maintainer approval required + - All CI checks must pass + - No merge conflicts + +## πŸ› Bug Reports + +Use the [Bug Report Template](.github/ISSUE_TEMPLATE/bug-report.yml): + +**Must include:** +- Operating System +- Python version +- Package manager used (uv, pip, poetry) +- Hardware bracket (`anomavision[cpu]`, `anomavision[cu121]`, etc.) 
+- Minimal reproducible example + +## πŸš€ Feature Requests + +Use the [Feature Request Template](.github/ISSUE_TEMPLATE/feature-request.yml): -# Create a branch -git checkout -b fix/bug-name +**Must include:** +- Feature description +- Use case & benefits +- Priority level -# Make changes and commit -git commit -m "fix(detect): corrected threshold application" +## πŸ“– Documentation -# Push and open PR -git push origin fix/bug-name +### Updating Documentation + +If your PR changes functionality: + +1. Update docstrings: +```python +def new_function(param: str) -> int: + """ + Brief description. + + Args: + param: Description of parameter + + Returns: + Description of return value + + Example: + >>> new_function("test") + 42 + """ + return 42 ``` ---- +2. Update README.md if needed +3. Add examples to `examples/` if it's a new feature + +## πŸ”§ Development Tips + +### Recommended IDE Setup + +**VSCode:** +```json +{ + "python.formatting.provider": "black", + "python.linting.flake8Enabled": true, + "editor.formatOnSave": true, + "[python]": { + "editor.codeActionsOnSave": { + "source.organizeImports": true + } + } +} +``` -## πŸ“œ License +### Pre-commit Hooks -By contributing, you agree that your contributions will be licensed under the [MIT License](LICENSE). +Pre-commit hooks run automatically on `git commit`: ---- +```bash +# Install hooks +pre-commit install -## πŸ™ Acknowledgments +# Run manually on all files +pre-commit run --all-files + +# Skip hooks (NOT RECOMMENDED) +git commit --no-verify +``` + +### Debugging CI Failures + +If CI fails: + +1. **Check the CI logs** on GitHub +2. **Reproduce locally**: + ```bash + # Run the same commands CI runs + uv run black --check . + uv run isort --check-only . + uv run flake8 anomavision/ anodet/ apps/ tests/ + uv run pytest -v + ``` +3. **Fix issues**: + ```bash + uv run black . + uv run isort . + git add . 
+ git commit -m "fix: resolve code quality issues" + git push + ``` + +## πŸ“ž Getting Help + +- **Questions**: [GitHub Discussions](https://github.com/DeepKnowledge1/AnomaVision/discussions) +- **Bugs**: [Bug Report](https://github.com/DeepKnowledge1/AnomaVision/issues/new?template=bug-report.yml) +- **Features**: [Feature Request](https://github.com/DeepKnowledge1/AnomaVision/issues/new?template=feature-request.yml) + +## πŸŽ‰ Recognition + +Contributors are recognized in: +- GitHub Contributors page +- Release notes (for significant contributions) +- README acknowledgments (for major features) + +## πŸ“„ License + +By contributing, you agree that your contributions will be licensed under the MIT License. + +--- -Big thanks to all contributors who help make AnomaVision better for the community! πŸš€ +**Thank you for contributing to AnomaVision! πŸš€** diff --git a/docs/troubleshooting.md b/docs/troubleshooting.md index a3aa8f8..a3d289e 100644 --- a/docs/troubleshooting.md +++ b/docs/troubleshooting.md @@ -163,4 +163,3 @@ A: Yes, as long as the dataset follows **MVTec-style folder structure**. βœ… With this guide, you should be able to quickly solve most common problems. If an issue persists, please [open a GitHub Issue](https://github.com/DeepKnowledge1/AnomaVision/issues) with details. - diff --git a/tests/conftest.py b/tests/conftest.py index a6f7724..3e8a19d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ """ Pytest configuration and shared fixtures for AnomaVision tests. """ + import os import sys diff --git a/tests/test_all_formats.py b/tests/test_all_formats.py index b610345..6d80be2 100644 --- a/tests/test_all_formats.py +++ b/tests/test_all_formats.py @@ -2,12 +2,14 @@ """ Simple test to compare predictions across all model formats. 
""" + from pathlib import Path import numpy as np import pytest import torch +from anomavision.export import ModelExporter from anomavision.inference.model.wrapper import ModelWrapper from anomavision.utils import ( adaptive_gaussian_blur, @@ -15,7 +17,6 @@ merge_config, setup_logging, ) -from anomavision.export import ModelExporter class TestAllFormats: @@ -46,7 +47,7 @@ def test_all_formats_match(self, temp_model_dir, trained_padim_model, sample_bat models["stats_fp16"] = pth_fp16_path # Export other formats - exporter = ModelExporter(pt_path, temp_model_dir, logger=logger,device="cpu") + exporter = ModelExporter(pt_path, temp_model_dir, logger=logger, device="cpu") # ONNX try: @@ -55,7 +56,8 @@ def test_all_formats_match(self, temp_model_dir, trained_padim_model, sample_bat ) if onnx_path: models["onnx"] = onnx_path - except: + except Exception: + raise RuntimeError("ONNX export failed") # TorchScript @@ -63,7 +65,8 @@ def test_all_formats_match(self, temp_model_dir, trained_padim_model, sample_bat ts_path = exporter.export_torchscript(input_shape, "model.torchscript") if ts_path: models["torchscript"] = ts_path - except: + except Exception: + raise RuntimeError("TorchScript export failed") # OpenVINO @@ -75,7 +78,8 @@ def test_all_formats_match(self, temp_model_dir, trained_padim_model, sample_bat xml_files = list(ov_dir.glob("*.xml")) if xml_files: models["openvino"] = xml_files[0] - except: + except Exception: + raise RuntimeError("OpenVINO export failed") print(f"\nTesting {len(models)} formats: {list(models.keys())}") @@ -112,7 +116,7 @@ def test_all_formats_match(self, temp_model_dir, trained_padim_model, sample_bat "openvino": 1e-3, } - print(f"\nComparing against PyTorch reference:") + print("\nComparing against PyTorch reference:") for name, pred in predictions.items(): if name == "pytorch": @@ -129,9 +133,9 @@ def test_all_formats_match(self, temp_model_dir, trained_padim_model, sample_bat # Check tolerance if score_diff <= tolerance and map_diff <= 
tolerance: - print(f" βœ“ PASS") + print(" βœ“ PASS") else: - print(f" ⚠ TOLERANCE EXCEEDED") + print(" ⚠ TOLERANCE EXCEEDED") # Don't fail test, just warn # 4. Cleanup diff --git a/tests/test_backends.py b/tests/test_backends.py index 0d8f405..3c46da8 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -2,6 +2,7 @@ """ Tests for multi-format backend functionality. """ + from pathlib import Path import numpy as np @@ -268,7 +269,7 @@ def test_pytorch_vs_stats_consistency( score_diff_fp32 = np.abs(full_scores - fp32_scores) map_diff_fp32 = np.abs(full_maps - fp32_maps) - print(f"\nFP32 Stats vs Full Model:") + print("\nFP32 Stats vs Full Model:") print( f"Score differences - max: {score_diff_fp32.max():.8f}, mean: {score_diff_fp32.mean():.8f}" ) @@ -280,7 +281,7 @@ def test_pytorch_vs_stats_consistency( score_diff_fp16 = np.abs(full_scores - fp16_scores) map_diff_fp16 = np.abs(full_maps - fp16_maps) - print(f"\nFP16 Stats vs Full Model:") + print("\nFP16 Stats vs Full Model:") print( f"Score differences - max: {score_diff_fp16.max():.8f}, mean: {score_diff_fp16.mean():.8f}" ) @@ -292,7 +293,7 @@ def test_pytorch_vs_stats_consistency( fp32_vs_fp16_scores = np.abs(fp32_scores - fp16_scores) fp32_vs_fp16_maps = np.abs(fp32_maps - fp16_maps) - print(f"\nFP32 Stats vs FP16 Stats:") + print("\nFP32 Stats vs FP16 Stats:") print( f"Score differences - max: {fp32_vs_fp16_scores.max():.8f}, mean: {fp32_vs_fp16_scores.mean():.8f}" ) @@ -337,7 +338,7 @@ def test_pytorch_vs_stats_consistency( fp32_size = os.path.getsize(str(stats_path_fp32)) / (1024 * 1024) fp16_size = os.path.getsize(str(stats_path_fp16)) / (1024 * 1024) - print(f"\nFile Sizes:") + print("\nFile Sizes:") print(f"Full model: {full_size:.2f} MB") print(f"FP32 stats: {fp32_size:.2f} MB") print(f"FP16 stats: {fp16_size:.2f} MB") diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 913ca99..eb60f8e 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -2,6 +2,7 @@ """ 
Tests for dataset loading functionality. """ + import numpy as np import pytest import torch diff --git a/tests/test_example.py b/tests/test_example.py index 61455af..c212527 100644 --- a/tests/test_example.py +++ b/tests/test_example.py @@ -1,4 +1,3 @@ - def test_function(): example_var = 3 assert example_var == 3 diff --git a/tests/test_export_load_model.py b/tests/test_export_load_model.py index 8c3612a..850bf2e 100644 --- a/tests/test_export_load_model.py +++ b/tests/test_export_load_model.py @@ -21,7 +21,9 @@ def test_exporter_loads_stats_pth_and_wraps(tmp_path, make_stats, patch_extracto from anomavision.export import ModelExporter, _ExportWrapper - exp = ModelExporter(model_path=pth_path, output_dir=tmp_path, logger=logger,device="cpu") + exp = ModelExporter( + model_path=pth_path, output_dir=tmp_path, logger=logger, device="cpu" + ) m = exp._load_model() # should be _ExportWrapper(PadimLite) assert isinstance(m, _ExportWrapper) assert hasattr(m, "forward") diff --git a/tests/test_exporter.py b/tests/test_exporter.py index 2e951ea..74c4fc2 100644 --- a/tests/test_exporter.py +++ b/tests/test_exporter.py @@ -4,11 +4,10 @@ import pytest import torch -from anomavision.utils import get_logger, setup_logging - # Adjust this import to your actual exporter module filename if needed # e.g., from export import ModelExporter, _ExportWrapper from anomavision.export import ModelExporter, _ExportWrapper # noqa: F401 +from anomavision.utils import get_logger, setup_logging setup_logging("INFO") logger = get_logger(__name__) @@ -55,7 +54,7 @@ def _save_tiny_model(tmp_path: Path) -> Path: def test_export_onnx_creates_file(tmp_path): model_path = _save_tiny_model(tmp_path) - exporter = ModelExporter(str(model_path), str(tmp_path), logger,device="cpu") + exporter = ModelExporter(str(model_path), str(tmp_path), logger, device="cpu") out = exporter.export_onnx( input_shape=(1, 3, 16, 16), @@ -71,7 +70,7 @@ def test_export_onnx_creates_file(tmp_path): def 
test_export_torchscript_creates_file_and_loads(tmp_path): model_path = _save_tiny_model(tmp_path) - exporter = ModelExporter(str(model_path), str(tmp_path), logger,device="cpu") + exporter = ModelExporter(str(model_path), str(tmp_path), logger, device="cpu") out = exporter.export_torchscript( input_shape=(1, 3, 16, 16), @@ -92,7 +91,7 @@ def test_export_openvino_returns_none_if_not_installed(tmp_path, monkeypatch): and return None. If it IS installed, we still accept a valid export. """ model_path = _save_tiny_model(tmp_path) - exporter = ModelExporter(str(model_path), str(tmp_path), logger,device="cpu") + exporter = ModelExporter(str(model_path), str(tmp_path), logger, device="cpu") try: import openvino # noqa: F401 diff --git a/tests/test_inference_utils.py b/tests/test_inference_utils.py index 13dfefe..57a071b 100644 --- a/tests/test_inference_utils.py +++ b/tests/test_inference_utils.py @@ -11,6 +11,7 @@ # import anomavision.detect from anomavision import detect + def test_determine_device_basic_roundtrip(): assert detect.determine_device("cpu") == "cpu" assert detect.determine_device("cuda") == "cuda" @@ -39,7 +40,6 @@ def test_save_visualization_single_and_batch(tmp_path): assert len(files) == 3 - def test_parse_args_defaults(monkeypatch): """Test create_parser with no CLI args to get the defaults.""" monkeypatch.setenv("PYTHONHASHSEED", "0") @@ -54,6 +54,7 @@ def test_parse_args_defaults(monkeypatch): assert hasattr(args, "device") assert hasattr(args, "enable_visualization") + def test_main_with_missing_model_file_raises(tmp_path, monkeypatch): """ Ensures the 'model file not found' path raises FileNotFoundError. diff --git a/tests/test_padim.py b/tests/test_padim.py index 066b954..e8bbc83 100644 --- a/tests/test_padim.py +++ b/tests/test_padim.py @@ -2,6 +2,7 @@ """ Tests for core PaDiM functionality. """ + import tempfile from pathlib import Path