jakaskerl
Sure thing, I've included the full script below (minus depth visualization). I haven't tested this with multiple NN models, just yolov6-nano. Detections look rock-solid though, consistently coming through with high confidence, at least as determined by the visualization.
#!/usr/bin/env python3
import cv2
import depthai as dai
import numpy as np

modelDescription = dai.NNModelDescription("yolov6-nano")
FPS = 30


class TrackerDisplay(dai.node.HostNode):
    def __init__(self, label_map):
        dai.node.HostNode.__init__(self)
        # This sends all the processing to the pipeline, where it's executed by `pipeline.runTasks()` or implicitly by the `pipeline.run()` method.
        # It's needed because the GUI window must be updated in the main thread, while the `process` method is by default called in a separate thread.
        # (A queue-based alternative that avoids HostNodes entirely is sketched after the script.)
        self.sendProcessingToPipeline(True)
        self.label_map = label_map

    def build(self, tracks: dai.Node.Output, rgb: dai.Node.Output):
        self.link_args(tracks, rgb)  # Has to match the inputs to the `process` method
        return self

    def onStart(self) -> None:  # Optional method
        print("Tracker Display started")
    def process(self, tracks: dai.Tracklets, rgbPreview: dai.ImgFrame):
        frame = rgbPreview.getCvFrame()
        tracklets_data = tracks.tracklets
        for t in tracklets_data:
            roi = t.roi.denormalize(frame.shape[1], frame.shape[0])
            x1 = int(roi.topLeft().x)
            y1 = int(roi.topLeft().y)
            x2 = int(roi.bottomRight().x)
            y2 = int(roi.bottomRight().y)
            try:
                label = self.label_map[t.label]
            except IndexError:
                label = t.label
            color = (255, 255, 255)
            cv2.putText(frame, str(label), (x1 + 10, y1 + 20), cv2.FONT_HERSHEY_TRIPLEX, 0.5, color)
            cv2.putText(frame, f"ID: {t.id}", (x1 + 10, y1 + 35), cv2.FONT_HERSHEY_TRIPLEX, 0.5, color)
            cv2.putText(frame, t.status.name, (x1 + 10, y1 + 50), cv2.FONT_HERSHEY_TRIPLEX, 0.5, color)
            # Age gives the number of frames since an association was made: 1 means the track is stable, and it starts incrementing if the object is lost.
            # There does not seem to be a built-in way to require a number of consecutive frames before registering a detection;
            # see the TrackletDebouncer sketch below for a possible host-side workaround.
            cv2.putText(frame, f"Age: {t.age}", (x1 + 10, y1 + 110), cv2.FONT_HERSHEY_TRIPLEX, 0.5, color)
            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 1)
            cv2.putText(frame, f"X: {int(t.spatialCoordinates.x)} mm", (x1 + 10, y1 + 65), cv2.FONT_HERSHEY_TRIPLEX, 0.5, color)
            cv2.putText(frame, f"Y: {int(t.spatialCoordinates.y)} mm", (x1 + 10, y1 + 80), cv2.FONT_HERSHEY_TRIPLEX, 0.5, color)
            cv2.putText(frame, f"Z: {int(t.spatialCoordinates.z)} mm", (x1 + 10, y1 + 95), cv2.FONT_HERSHEY_TRIPLEX, 0.5, color)
        cv2.imshow("Tracker", frame)
        key = cv2.waitKey(1)
        if key == ord('q'):
            print("Detected 'q' - stopping the pipeline...")
            self.stopPipeline()
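
# A possible host-side workaround for the "minimum consecutive frames" gap noted above.
# This is a hypothetical sketch, not part of the DepthAI API: it keeps a per-ID counter
# and only reports a tracklet once it has been in the TRACKED state for `min_frames`
# consecutive frames. It is not wired into the pipeline below; to try it, create one in
# TrackerDisplay.__init__ and iterate over `debouncer.filter(tracks.tracklets)` in
# `process` instead of the raw tracklets.
class TrackletDebouncer:
    def __init__(self, min_frames: int = 5):
        self.min_frames = min_frames
        self.counts = {}  # tracklet ID -> consecutive TRACKED frames

    def filter(self, tracklets):
        confirmed = []
        seen = set()
        for t in tracklets:
            seen.add(t.id)
            if t.status == dai.Tracklet.TrackingStatus.TRACKED:
                self.counts[t.id] = self.counts.get(t.id, 0) + 1
            else:
                self.counts[t.id] = 0  # NEW/LOST/REMOVED reset the streak
            if self.counts[t.id] >= self.min_frames:
                confirmed.append(t)
        # Forget IDs that did not appear in this frame's message at all
        for tid in list(self.counts):
            if tid not in seen:
                del self.counts[tid]
        return confirmed
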
class SpatialVisualizer(dai.node.HostNode):
    def __init__(self):
        dai.node.HostNode.__init__(self)
        self.sendProcessingToPipeline(True)

    def build(self, detections: dai.Node.Output, rgb: dai.Node.Output):
        self.link_args(detections, rgb)  # Must match the inputs to the `process` method

    def process(self, detections, rgbPreview):
        rgbPreview = rgbPreview.getCvFrame()
        self.displayResults(rgbPreview, detections.detections)

    def displayResults(self, rgbFrame, detections):
        height, width, _ = rgbFrame.shape
        for detection in detections:
            self.drawDetections(rgbFrame, detection, width, height)
        cv2.imshow("rgb", rgbFrame)
        if cv2.waitKey(1) == ord('q'):
            self.stopPipeline()
    def drawDetections(self, frame, detection, frameWidth, frameHeight):
        x1 = int(detection.xmin * frameWidth)
        x2 = int(detection.xmax * frameWidth)
        y1 = int(detection.ymin * frameHeight)
        y2 = int(detection.ymax * frameHeight)
        try:
            label = self.labelMap[detection.label]  # labelMap is assigned after pipeline creation below
        except IndexError:
            label = detection.label
        color = (255, 255, 255)
        cv2.putText(frame, str(label), (x1 + 10, y1 + 20), cv2.FONT_HERSHEY_TRIPLEX, 0.5, color)
        cv2.putText(frame, "{:.2f}".format(detection.confidence * 100), (x1 + 10, y1 + 35), cv2.FONT_HERSHEY_TRIPLEX, 0.5, color)
        cv2.putText(frame, f"X: {int(detection.spatialCoordinates.x)} mm", (x1 + 10, y1 + 50), cv2.FONT_HERSHEY_TRIPLEX, 0.5, color)
        cv2.putText(frame, f"Y: {int(detection.spatialCoordinates.y)} mm", (x1 + 10, y1 + 65), cv2.FONT_HERSHEY_TRIPLEX, 0.5, color)
        cv2.putText(frame, f"Z: {int(detection.spatialCoordinates.z)} mm", (x1 + 10, y1 + 80), cv2.FONT_HERSHEY_TRIPLEX, 0.5, color)
        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 1)

# Creates the pipeline and a default device implicitly
with dai.Pipeline() as p:
    # Define sources and outputs
    camRgb = p.create(dai.node.Camera).build(dai.CameraBoardSocket.CAM_A)
    monoLeft = p.create(dai.node.Camera).build(dai.CameraBoardSocket.CAM_B)
    monoRight = p.create(dai.node.Camera).build(dai.CameraBoardSocket.CAM_C)
    stereo = p.create(dai.node.StereoDepth)
    spatialDetectionNetwork = p.create(dai.node.SpatialDetectionNetwork).build(camRgb, stereo, modelDescription, fps=FPS)
    visualizer = p.create(SpatialVisualizer)
    objectTracker = p.create(dai.node.ObjectTracker)
    cam_visualizer = p.create(TrackerDisplay, label_map=spatialDetectionNetwork.getClasses()).build(
        objectTracker.out, camRgb.requestOutput((300, 300))
    )
    # Node configs
    stereo.setExtendedDisparity(True)
    platform = p.getDefaultDevice().getPlatform()
    if platform == dai.Platform.RVC2:
        # For RVC2, the output width must be divisible by 16
        stereo.setOutputSize(640, 400)

    # Object tracker config
    objectTracker.setDetectionLabelsToTrack([0])  # track only the "person" class
    # Possible tracking types: ZERO_TERM_COLOR_HISTOGRAM, ZERO_TERM_IMAGELESS, SHORT_TERM_IMAGELESS, SHORT_TERM_KCF
    objectTracker.setTrackerType(dai.TrackerType.SHORT_TERM_IMAGELESS)
    # Assign a unique ID to each newly tracked object; the other option, SMALLEST_ID, reuses the smallest available ID
    objectTracker.setTrackerIdAssignmentPolicy(dai.TrackerIdAssignmentPolicy.UNIQUE_ID)

    # Spatial detection network config
    spatialDetectionNetwork.input.setBlocking(False)
    # Average depth over the central half of each bounding box when computing spatial coordinates
    spatialDetectionNetwork.setBoundingBoxScaleFactor(0.5)
    # Ignore depth readings closer than 10 cm or farther than 5 m
    spatialDetectionNetwork.setDepthLowerThreshold(100)
    spatialDetectionNetwork.setDepthUpperThreshold(5000)
    # Linking
    monoLeft.requestOutput((640, 400)).link(stereo.left)
    monoRight.requestOutput((640, 400)).link(stereo.right)
    visualizer.labelMap = spatialDetectionNetwork.getClasses()
    spatialDetectionNetwork.passthrough.link(objectTracker.inputTrackerFrame)
    spatialDetectionNetwork.passthrough.link(objectTracker.inputDetectionFrame)
    spatialDetectionNetwork.out.link(objectTracker.inputDetections)
    visualizer.build(spatialDetectionNetwork.out, spatialDetectionNetwork.passthrough)

    p.run()
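
As an aside on the threading comment in TrackerDisplay.__init__: if you'd rather keep all the OpenCV calls on the main thread without HostNodes at all, the v3 API should also let you poll output queues directly. Here's a minimal sketch of that pattern; it's untested on my end, trackQ/rgbQ are just illustrative names, and it would replace both HostNode classes and the p.run() call (still inside the with dai.Pipeline() as p: block):

    # Untested sketch: poll output queues on the main thread instead of using HostNodes
    trackQ = objectTracker.out.createOutputQueue()
    rgbQ = camRgb.requestOutput((300, 300)).createOutputQueue()
    p.start()
    while p.isRunning():
        tracks = trackQ.get()            # dai.Tracklets
        frame = rgbQ.get().getCvFrame()  # dai.ImgFrame -> numpy array
        # ... draw tracklets on `frame` exactly as in TrackerDisplay.process ...
        cv2.imshow("Tracker", frame)
        if cv2.waitKey(1) == ord('q'):
            break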