Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions scripts/export_to_onnx.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
import os
import sys
import argparse
from pathlib import Path

import torch

from src.inference.utils.inference_factory import InferenceFactory, InferenceConfig, Backend, ModelArch
from src.inference.utils.inference_factory import Backend, InferenceConfig, InferenceFactory, ModelArch
from src.path_utils import ensure_clean_directory

NUM_CLASSES = 83
INPUT_SIZE = (256, 256)
DEVICE = "cpu"
MODELS_DIR_PATH = Path("models")


def main(model_name: str):
arch = ModelArch.RESNET if "resnet" in model_name.lower() else ModelArch.MOBILENET
backend = Backend.PYTORCH
Expand Down
3 changes: 1 addition & 2 deletions scripts/onnx_validation.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import os
import argparse
from pathlib import Path

import numpy as np
import torch
from PIL import Image

from dataset.optimal_class_mapping import MODEL_NAMES as ID_TO_NAME
from src.inference.utils.inference_factory import InferenceFactory, InferenceConfig, Backend, ModelArch
from src.inference.base.classifier_inference_base import ClassifierInferenceBase
from src.inference.base.classifier_inference_base_onnx import OnnxClassifierInferenceBase
from src.inference.utils.inference_factory import Backend, InferenceConfig, InferenceFactory, ModelArch

NUM_CLASSES = 83
INPUT_SIZE = (256, 256)
Expand Down
16 changes: 10 additions & 6 deletions services/consumer/consumer.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,32 @@
import paho.mqtt.client as mqtt
import pymongo
import json
from datetime import datetime

import paho.mqtt.client as mqtt
import pymongo


def on_connect(client, userdata, flags, rc):
    """MQTT connect callback: log the CONNACK result code and subscribe
    to the DeepStream predictions topic so messages start flowing."""
    topic = "deepstream/predictions"
    print(f"🔗 Connected: {rc}")
    client.subscribe(topic)


def on_message(client, userdata, msg):
    """Persist an incoming MQTT prediction message to MongoDB.

    Decodes the JSON payload, stamps it with a ``received_at`` ISO
    timestamp, and inserts it into the ``agstream.predictions``
    collection. Fix: the pasted diff left both the old f-string print
    and its plain replacement — only the plain print is kept.
    """
    try:
        payload = json.loads(msg.payload.decode())
        payload["received_at"] = datetime.now().isoformat()

        # NOTE(review): opening a new MongoClient per message is wasteful;
        # consider a single long-lived client if throughput matters.
        mongo_client = pymongo.MongoClient("mongodb://agstream_mongo:27017/")
        db = mongo_client["agstream"]
        collection = db["predictions"]

        collection.insert_one(payload)
        print("💾 Saved to MongoDB!")
        mongo_client.close()

    except Exception as e:
        # Broad catch is deliberate: one bad message must not kill the consumer loop.
        print(f"❌ Error: {e}")


client = mqtt.Client()
client.on_connect = on_connect
client.on_message = on_message
Expand Down
33 changes: 17 additions & 16 deletions services/consumer/mqtt_consumer.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,60 @@
#!/usr/bin/env python3
import paho.mqtt.client as mqtt
import pymongo
import json
from datetime import datetime

import paho.mqtt.client as mqtt
import pymongo

# MQTT broker settings — hostnames presumably resolve as service names
# inside the docker network; TODO confirm against the compose file.
MQTT_BROKER = "agstream_mosquitto"
MQTT_PORT = 1883
MQTT_TOPIC = "deepstream/predictions"
# MongoDB destination for consumed prediction messages.
MONGO_URI = "mongodb://agstream_mongo:27017/"
MONGO_DB = "agstream"
MONGO_COLLECTION = "predictions"


def on_connect(client, userdata, flags, rc):
    """MQTT on_connect callback.

    Logs the connection result code, then subscribes to the
    prediction topic so message delivery begins.
    """
    banner = f"✅ Connected to MQTT broker with result code {rc}"
    print(banner)
    client.subscribe(MQTT_TOPIC)


def on_message(client, userdata, msg):
    """Handle one DeepStream prediction message and store it in MongoDB.

    Decodes the JSON payload, stamps it with ``received_at``, extracts
    classification info from the nvmsgbroker ``object`` field when
    present, and inserts the enriched document. Fixes: the pasted diff
    left duplicated old/new lines (the multi-line and one-line
    ``classification`` assignments, and both variants of each status
    print) — only the corrected versions are kept; the unused local
    ``result`` is dropped.
    """
    try:
        payload = json.loads(msg.payload.decode())
        payload["received_at"] = datetime.now().isoformat()

        # Extract classification data from nvmsgbroker format.
        # "0" appears to be the sentinel for "no detection" — TODO confirm.
        obj = payload.get("object", {})
        if obj.get("id") != "0" and obj.get("id"):
            class_id = int(obj["id"])
            confidence = obj.get("confidence", 0)
            payload["classification"] = {"class_id": class_id, "confidence": confidence}
            print(f"🌱 FOUND CLASSIFICATION: ID {class_id}, confidence {confidence:.3f}")

        # NOTE(review): a new MongoClient per message is wasteful;
        # consider one long-lived client if throughput matters.
        mongo_client = pymongo.MongoClient(MONGO_URI)
        db = mongo_client[MONGO_DB]
        collection = db[MONGO_COLLECTION]

        collection.insert_one(payload)

        if "classification" in payload:
            print("✅ Saved classification to MongoDB!")
        else:
            print("✅ Saved: No classification")

        mongo_client.close()

    except Exception as e:
        # Broad catch is deliberate: one bad message must not kill the consumer loop.
        print(f"❌ Error: {e}")


if __name__ == "__main__":
    # Wire the callbacks, connect to the broker, and block forever
    # dispatching incoming messages.
    mqtt_client = mqtt.Client()
    mqtt_client.on_connect = on_connect
    mqtt_client.on_message = on_message

    print("🚀 Starting Enhanced MQTT Consumer...")
    mqtt_client.connect(MQTT_BROKER, MQTT_PORT, 60)
    mqtt_client.loop_forever()
1 change: 1 addition & 0 deletions src/deepstream/helpers/load_class_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

PLANT_LABELS = "/workspace/configs/crop_and_weed_83_classes.txt"


# Load class labels
def load_class_labels() -> List[str]:
try:
Expand Down
9 changes: 4 additions & 5 deletions src/deepstream/helpers/meta_tensor_extractor.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import ctypes
import numpy.typing as npt

import numpy as np
import numpy.typing as npt
import pyds


class TensorExtractor:
def extract_logits(self, tensor_meta) -> npt.NDArray[np.float32]:
"""
Expand All @@ -13,10 +15,7 @@ def extract_logits(self, tensor_meta) -> npt.NDArray[np.float32]:
dims = [layer.dims.d[i] for i in range(layer.dims.numDims)]
numel = int(np.prod(dims))

ptr = ctypes.cast(
pyds.get_ptr(layer.buffer),
ctypes.POINTER(ctypes.c_float)
)
ptr = ctypes.cast(pyds.get_ptr(layer.buffer), ctypes.POINTER(ctypes.c_float))
logits = np.ctypeslib.as_array(ptr, shape=(numel,))

# Copy so we are not tied to DeepStream's memory lifetime
Expand Down
6 changes: 2 additions & 4 deletions src/deepstream/helpers/plant_msg_meta_builder.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import sys
from pydantic import BaseModel

import gi
gi.require_version("Gst", "1.0")

from gi.repository import Gst
import pyds
from pydantic import BaseModel

from src.deepstream.helpers.softmax_topk_classifier import ClassificationPrediction


class PlantEvent(BaseModel):
frame_id: int
plant_id: str
Expand Down
8 changes: 3 additions & 5 deletions src/deepstream/helpers/remove_background.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import numpy as np
import cv2
import numpy as np

MORPH_KERNEL = 34


def remove_background(frame_bgr: np.ndarray) -> np.ndarray:
rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB).astype(np.float32)
R, G, B = rgb[..., 0], rgb[..., 1], rgb[..., 2]
Expand All @@ -12,10 +13,7 @@ def remove_background(frame_bgr: np.ndarray) -> np.ndarray:

_, mask = cv2.threshold(exg_norm, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

kernel = cv2.getStructuringElement(
cv2.MORPH_ELLIPSE,
(MORPH_KERNEL, MORPH_KERNEL)
)
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (MORPH_KERNEL, MORPH_KERNEL))
mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)

out = np.zeros_like(frame_bgr)
Expand Down
15 changes: 7 additions & 8 deletions src/deepstream/helpers/should_skip_frame.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
from typing import Any
from enum import Enum
from typing import Any

import cv2
import gi
import numpy as np
import pyds
import cv2

gi.require_version("Gst", "1.0")
from gi.repository import Gst

from src.frame_comparison.frame_change_detector import FrameChangeDetector


class FrameProcessDecision(str, Enum):
PROCESS = "process"
SKIP = "skip"


def should_skip_frame(info: Any, frame_meta: Any, batch_meta: Any, frame_change_detector: FrameChangeDetector) -> int:
"""Pad probe to drop frames based on frame difference analysis."""
gst_buffer = info.get_buffer()
Expand All @@ -25,12 +28,8 @@ def should_skip_frame(info: Any, frame_meta: Any, batch_meta: Any, frame_change_
frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGBA2BGR)

should_process, metrics = frame_change_detector.should_process(frame_bgr)

decision = (
FrameProcessDecision.PROCESS
if should_process
else FrameProcessDecision.SKIP
)

decision = FrameProcessDecision.PROCESS if should_process else FrameProcessDecision.SKIP

print(
f"Frame {frame_meta.frame_num:05d}: {decision.value} "
Expand Down
5 changes: 3 additions & 2 deletions src/deepstream/helpers/softmax_topk_classifier.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
from typing import List

import numpy as np
from pydantic import BaseModel


Expand All @@ -19,7 +20,7 @@ def predict_from_logits(self, logits: np.ndarray) -> List[ClassificationPredicti
exp = np.exp(logits - np.max(logits))
probs = exp / np.sum(exp)

top_idx = np.argsort(probs)[-self.top_k:][::-1]
top_idx = np.argsort(probs)[-self.top_k :][::-1]

results: List[ClassificationPrediction] = []
for idx in top_idx:
Expand Down
9 changes: 4 additions & 5 deletions src/deepstream/pipelines/access_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import os

import pyds
from gi.repository import GLib, Gst
from gi.repository import Gst

from src.deepstream.helpers.pipeline_runner import run_pipeline

Expand All @@ -22,6 +22,7 @@
MESSAGE_TEMPLATE = "Frame processed: {}"
os.makedirs(OUTPUT_DIR, exist_ok=True)


def modify_metadata(frame_meta, batch_meta, message_template=MESSAGE_TEMPLATE):
"""
Modify the metadata by adding user-defined metadata.
Expand Down Expand Up @@ -79,9 +80,7 @@ def read_metadata(frame_meta):
class_id = label_info.result_class_id
confidence = label_info.result_prob

print(
f"[Frame {frame_meta.frame_num}] Class ID: {class_id}, Confidence: {confidence:.4f}"
)
print(f"[Frame {frame_meta.frame_num}] Class ID: {class_id}, Confidence: {confidence:.4f}")


def access_metadata(pad, info, u_data):
Expand Down Expand Up @@ -189,8 +188,8 @@ def on_pad_added_decode(src, pad):

return pipeline


if __name__ == "__main__":
Gst.init(None)
pipeline = build_pipeline()
run_pipeline(pipeline)

8 changes: 4 additions & 4 deletions src/deepstream/pipelines/access_raw_frames_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import gi
import numpy as np
import pyds
from gi.repository import GLib, Gst
from gi.repository import Gst

gi.require_version("Gst", "1.0")

Expand All @@ -36,6 +36,7 @@ def frames_manipulation(pad: Gst.Pad, info: Gst.PadProbeInfo) -> Gst.PadProbeRet
np.copyto(np.array(surface, copy=False, order="C"), frame)
return Gst.PadProbeReturn.OK


def build_pipeline() -> Gst.Pipeline:
pipeline: Gst.Pipeline = Gst.Pipeline.new("simple-pipeline")

Expand Down Expand Up @@ -71,9 +72,7 @@ def build_pipeline() -> Gst.Pipeline:

rtspsrc.set_property("location", RTSP_URL)
rtspsrc.set_property("latency", 200)
capsfilter.set_property(
"caps", Gst.Caps.from_string("video/x-raw(memory:NVMM), format=RGBA")
)
capsfilter.set_property("caps", Gst.Caps.from_string("video/x-raw(memory:NVMM), format=RGBA"))
streammux.set_property("batch-size", 1)
streammux.set_property("width", 640)
streammux.set_property("height", 480)
Expand Down Expand Up @@ -121,6 +120,7 @@ def on_decode_pad_added(src: Gst.Element, new_pad: Gst.Pad) -> None:
srcpad.add_probe(Gst.PadProbeType.BUFFER, frames_manipulation)
return pipeline


def main() -> None:
os.makedirs(OUTPUT_DIR, exist_ok=True)
Gst.init(None)
Expand Down
10 changes: 3 additions & 7 deletions src/deepstream/pipelines/cpu_frames_skipping_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
import sys
import os

from typing import Any

import gi
import numpy as np
import pyds

gi.require_version("Gst", "1.0")
import cv2
from gi.repository import GLib, Gst
from gi.repository import Gst

from src.deepstream.helpers.pipeline_runner import run_pipeline
from src.deepstream.probes.frame_comparison.cpu.frame_skipping_probe import frame_skip_probe
from src.frame_comparison.frame_change_detector import FrameChangeDetector
from src.deepstream.helpers.pipeline_runner import run_pipeline

rtsp_port = os.environ.get("RTSP_PORT", "8554")
RTSP_URL = f"rtsp://127.0.0.1:{rtsp_port}/test"
Expand Down Expand Up @@ -108,6 +103,7 @@ def on_pad_added_decode(src: Any, pad: Any) -> None:

return pipeline


if __name__ == "__main__":
Gst.init(None)
pipeline = build_pipeline()
Expand Down
1 change: 1 addition & 0 deletions src/deepstream/pipelines/deepstream_image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pathlib import Path

from gi.repository import Gst

from src.path_utils import ensure_clean_directory # noqa: E402

Gst.init(None)
Expand Down
Loading
Loading