Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions include/ByteTrack/BYTETracker.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@ class BYTETracker
const int& track_buffer = 30,
const float& track_thresh = 0.5,
const float& high_thresh = 0.6,
const float& match_thresh = 0.8);
const float& match_thresh = 0.8,
const bool& giou_enabled=true);
~BYTETracker();

std::vector<STrackPtr> update(const std::vector<Object>& objects);
std::vector<STrackPtr> update(const std::vector<Object>& objects, const bool force_activate);

private:
std::vector<STrackPtr> jointStracks(const std::vector<STrackPtr> &a_tlist,
Expand Down Expand Up @@ -63,6 +64,7 @@ class BYTETracker
const float track_thresh_;
const float high_thresh_;
const float match_thresh_;
const bool giou_enabled_;
const size_t max_time_lost_;

size_t frame_id_;
Expand Down
2 changes: 2 additions & 0 deletions include/ByteTrack/Rect.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ class Rect
Xyah<T> getXyah() const;

float calcIoU(const Rect<T>& other) const;
// Add GIoU base on Bytrack-Telespazio
float calcGIoU(const Rect<T>& other) const;
};

template<typename T>
Expand Down
8 changes: 6 additions & 2 deletions include/ByteTrack/STrack.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,22 @@ enum class STrackState {
class STrack
{
public:
STrack(const Rect<float>& rect, const float& score);
STrack(const Rect<float>& rect, const float& score, const int& category);
~STrack();


const Rect<float>& getRect() const;
const STrackState& getSTrackState() const;

const bool& isActivated() const;
const float& getScore() const;
const int& getCategory() const;
const size_t& getTrackId() const;
const size_t& getFrameId() const;
const size_t& getStartFrameId() const;
const size_t& getTrackletLength() const;

void activate(const size_t& frame_id, const size_t& track_id);
void activate(const size_t& frame_id, const size_t& track_id,const bool force_activate);
void reActivate(const STrack &new_track, const size_t &frame_id, const int &new_track_id = -1);

void predict();
Expand All @@ -49,11 +51,13 @@ class STrack

bool is_activated_;
float score_;
int category_;
size_t track_id_;
size_t frame_id_;
size_t start_frame_id_;
size_t tracklet_len_;


void updateRect();
};
}
21 changes: 15 additions & 6 deletions src/BYTETracker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,31 @@
#include <utility>
#include <vector>

#include <iostream>
using namespace std;

byte_track::BYTETracker::BYTETracker(const int& frame_rate,
const int& track_buffer,
const float& track_thresh,
const float& high_thresh,
const float& match_thresh) :
const float& match_thresh,
const bool& giou_enabled) :
track_thresh_(track_thresh),
high_thresh_(high_thresh),
match_thresh_(match_thresh),
giou_enabled_(giou_enabled),
max_time_lost_(static_cast<size_t>(frame_rate / 30.0 * track_buffer)),
frame_id_(0),
track_id_count_(0)

{
}

byte_track::BYTETracker::~BYTETracker()
{
}

std::vector<byte_track::BYTETracker::STrackPtr> byte_track::BYTETracker::update(const std::vector<Object>& objects)
std::vector<byte_track::BYTETracker::STrackPtr> byte_track::BYTETracker::update(const std::vector<Object>& objects, const bool force_activate)
{
////////////////// Step 1: Get detections //////////////////
frame_id_++;
Expand All @@ -37,7 +43,8 @@ std::vector<byte_track::BYTETracker::STrackPtr> byte_track::BYTETracker::update(

for (const auto &object : objects)
{
const auto strack = std::make_shared<STrack>(object.rect, object.prob);
const auto strack = std::make_shared<STrack>(object.rect, object.prob,object.label);

if (object.prob >= track_thresh_)
{
det_stracks.push_back(strack);
Expand Down Expand Up @@ -189,8 +196,9 @@ std::vector<byte_track::BYTETracker::STrackPtr> byte_track::BYTETracker::update(
{
continue;
}

track_id_count_++;
track->activate(frame_id_, track_id_count_);
track->activate(frame_id_, track_id_count_,force_activate);
current_tracked_stracks.push_back(track);
}
}
Expand All @@ -216,7 +224,7 @@ std::vector<byte_track::BYTETracker::STrackPtr> byte_track::BYTETracker::update(

std::vector<STrackPtr> output_stracks;
for (const auto &track : tracked_stracks_)
{
{
if (track->isActivated())
{
output_stracks.push_back(track);
Expand All @@ -225,6 +233,7 @@ std::vector<byte_track::BYTETracker::STrackPtr> byte_track::BYTETracker::update(

return output_stracks;
}

std::vector<byte_track::BYTETracker::STrackPtr> byte_track::BYTETracker::jointStracks(const std::vector<STrackPtr> &a_tlist,
const std::vector<STrackPtr> &b_tlist) const
{
Expand Down Expand Up @@ -392,7 +401,7 @@ std::vector<std::vector<float>> byte_track::BYTETracker::calcIous(const std::vec
{
for (size_t ai = 0; ai < a_rect.size(); ai++)
{
ious[ai][bi] = b_rect[bi].calcIoU(a_rect[ai]);
ious[ai][bi] = giou_enabled_ ? b_rect[bi].calcGIoU(a_rect[ai]): b_rect[bi].calcIoU(a_rect[ai]);
}
}
return ious;
Expand Down
34 changes: 33 additions & 1 deletion src/Rect.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#include "ByteTrack/Rect.h"

#include <iostream>
#include <algorithm>

template <typename T>
Expand Down Expand Up @@ -124,6 +124,38 @@ float byte_track::Rect<T>::calcIoU(const Rect<T>& other) const
}
return iou;
}
// Add GIoU base on Bytrack-Telespazio
template<typename T>
float byte_track::Rect<T>::calcGIoU(const Rect<T>& other) const
{
float epsilon = 1e-10; // Define a small value

const float box_area = (other.tlwh[2] + 1) * (other.tlwh[3] + 1);
const float iw = std::min(tlwh[0] + tlwh[2], other.tlwh[0] + other.tlwh[2]) - std::max(tlwh[0], other.tlwh[0]) + 1;
const float ih = std::min(tlwh[1] + tlwh[3], other.tlwh[1] + other.tlwh[3]) - std::max(tlwh[1], other.tlwh[1]) + 1;
float iou = 0;
float giou = 0;
float box_area_track = (tlwh[0] + tlwh[2] - tlwh[0] + 1) * (tlwh[1] + tlwh[3] - tlwh[1] + 1);

float overlap = 0.0;
if (iw > 0 && ih > 0)
overlap = iw * ih;

const float ua = box_area_track + box_area - overlap;

iou = overlap / (ua + epsilon);

float enclosed_lt_x = std::min(tlwh[0], other.tlwh[0]);
float enclosed_lt_y = std::min(tlwh[1], other.tlwh[1]);
float enclosed_rb_x = std::max(tlwh[0] + tlwh[2], other.tlwh[0] + other.tlwh[2]);
float enclosed_rb_y = std::max(tlwh[1] + tlwh[3], other.tlwh[1] + other.tlwh[3]);
float enclosed_w = std::max(0.0f, enclosed_rb_x - enclosed_lt_x + 1);
float enclosed_h = std::max(0.0f, enclosed_rb_y - enclosed_lt_y + 1);

float enclosed_area = enclosed_w * enclosed_h + epsilon;
giou = iou - (enclosed_area - ua) / enclosed_area;
return giou;
}

template<typename T>
byte_track::Rect<T> byte_track::generate_rect_by_tlbr(const byte_track::Tlbr<T>& tlbr)
Expand Down
15 changes: 12 additions & 3 deletions src/STrack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,29 @@

#include <cstddef>

byte_track::STrack::STrack(const Rect<float>& rect, const float& score) :
byte_track::STrack::STrack(const Rect<float>& rect, const float& score,const int& category) :
kalman_filter_(),
mean_(),
covariance_(),
rect_(rect),
state_(STrackState::New),
is_activated_(false),
score_(score),
category_(category),
track_id_(0),
frame_id_(0),
start_frame_id_(0),
tracklet_len_(0)

{
}

byte_track::STrack::~STrack()
{
}



const byte_track::Rect<float>& byte_track::STrack::getRect() const
{
return rect_;
Expand All @@ -40,6 +44,11 @@ const float& byte_track::STrack::getScore() const
return score_;
}

const int& byte_track::STrack::getCategory() const
{
return category_;
}

const size_t& byte_track::STrack::getTrackId() const
{
return track_id_;
Expand All @@ -60,14 +69,14 @@ const size_t& byte_track::STrack::getTrackletLength() const
return tracklet_len_;
}

void byte_track::STrack::activate(const size_t& frame_id, const size_t& track_id)
void byte_track::STrack::activate(const size_t& frame_id, const size_t& track_id, const bool force_activate)
{
kalman_filter_.initiate(mean_, covariance_, rect_.getXyah());

updateRect();

state_ = STrackState::Tracked;
if (frame_id == 1)
if (frame_id == 1 || force_activate)
{
is_activated_ = true;
}
Expand Down