"""
Copyright (c) 2021-2022 UCLouvain, ICTEAM
Licensed under GPL-3.0 [see LICENSE for details]
Written by Jonathan Samelson (2021-2022)
"""
from pytb.tracking.bboxes.bboxes_2d_tracker.bboxes_2d_tracker import BBoxes2DTracker
from pytb.tracking.bboxes.bboxes_2d_tracker.mbtracker.sort.sort_abewley import Sort as SortAbewley
from pytb.output.bboxes_2d import BBoxes2D
from pytb.output.bboxes_2d_track import BBoxes2DTrack
import numpy as np
from timeit import default_timer
import logging
# Module-level logger, registered under the "aptitude-toolbox" logger name.
log = logging.getLogger("aptitude-toolbox")
class SORT(BBoxes2DTracker):
    """SORT tracker operating on 2D bounding boxes.

    Wraps the ``SortAbewley`` implementation behind the ``BBoxes2DTracker``
    interface. Only the ``"Abewley"`` implementation is supported.
    """

    def __init__(self, proc_parameters: dict):
        """Initializes a SORT tracker with the given parameters.

        Args:
            proc_parameters (dict): A dictionary containing the related SORT's
                parameters. The ``"params"`` entry may define ``max_age``,
                ``min_hits``, ``iou_thresh`` and ``memory_fade``; defaults are
                used for any missing key.

        Raises:
            AssertionError: If the preferred implementation is not "Abewley".
        """
        # NOTE(review): self.pref_implem is not set here; presumably the parent
        # constructor derives it from proc_parameters -- confirm.
        super().__init__(proc_parameters)

        params = proc_parameters["params"]
        # An object that is not tracked for max_age frames is removed from the memory
        self.max_age = params.get("max_age", 10)
        # Minimum number of hits before the object starts being tracked
        self.min_hits = params.get("min_hits", 3)
        # The minimum IOU threshold to keep the association of a previously detected object
        self.iou_thresh = params.get("iou_thresh", 0.3)
        # Above a value of 1.0, it enables a fading memory which gives less
        # importance to the older tracks in the memory
        self.memory_fade = params.get("memory_fade", 1.0)

        log.debug("SORT {} implementation selected.".format(self.pref_implem))
        if self.pref_implem == "Abewley":
            self.tracker = SortAbewley(self.max_age, self.min_hits, self.iou_thresh, self.memory_fade)
        else:
            # Raised explicitly (instead of `assert False`) so that the check
            # is not stripped when Python runs with -O.
            raise AssertionError("[ERROR] Unknown implementation of SORT: {}".format(self.pref_implem))

    def track(self, detection: BBoxes2D) -> BBoxes2DTrack:
        """Updates the tracker with the given detection and returns the tracked objects.

        Note that a non-empty ``detection`` is converted in place to the
        x1_y1_x2_y2 format before being handed to the tracker.

        Args:
            detection (BBoxes2D): The detection used to infer IDs.

        Returns:
            BBoxes2DTrack: A set of 2D bounding boxes identifying detected objects
            with the tracking information added, converted to the xt_yt_w_h format.

        Raises:
            AssertionError: If the preferred implementation is not "Abewley".
        """
        if self.pref_implem != "Abewley":
            # Raised explicitly (instead of `assert False`) so that the method
            # cannot silently return None when Python runs with -O.
            raise AssertionError("[ERROR] Unknown implementation of SORT: {}".format(self.pref_implem))

        if detection.number_objects == 0:
            # Empty placeholder with the same six columns as the stacked array below
            dets = np.empty((0, 6))
        else:
            detection.to_x1_y1_x2_y2()
            # One row per object: bounding box, then confidence, then class ID
            # (six columns, assuming four values per bounding box -- TODO confirm)
            dets = np.column_stack((detection.bboxes, detection.det_confs, detection.class_IDs))

        # Update results based on the detections of current frame;
        # only the tracker update itself is timed.
        start = default_timer()
        res = self.tracker.update(dets)
        tracking_time = default_timer() - start

        # Format results: split columns into [0:4] boxes, [4] confidences,
        # [5] class IDs and [6] track IDs (any further columns are ignored).
        res_split = np.hsplit(res, np.array([4, 5, 6, 7]))
        bboxes = res_split[0]
        det_confs = res_split[1].flatten()
        class_IDs = res_split[2].flatten().astype(int)
        global_IDs = res_split[3].flatten().astype(int)

        output = BBoxes2DTrack(detection.detection_time, bboxes, class_IDs, det_confs,
                               detection.dim_width, detection.dim_height, tracking_time, global_IDs,
                               bboxes_format="x1_y1_x2_y2")
        output.to_xt_yt_w_h()
        return output

    def reset_state(self, reset_id: bool = False):
        """Reset the current state of the tracker.

        Args:
            reset_id (bool): Forwarded as-is to the underlying tracker's
                ``reset_state``; presumably controls whether the ID counter is
                reset as well -- confirm against the tracker implementation.
        """
        self.tracker.reset_state(reset_id)