Skip to content

Commit

Permalink
template matching and TrackViz
Browse files Browse the repository at this point in the history
  • Loading branch information
kalwalt committed Jun 4, 2024
1 parent 8c74bcb commit 393533e
Show file tree
Hide file tree
Showing 2 changed files with 258 additions and 15 deletions.
209 changes: 194 additions & 15 deletions WebARKit/WebARKitTrackers/WebARKitOpticalTracking/WebARKitTracker.cpp
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
#include <WebARKitTrackers/WebARKitOpticalTracking/WebARKitConfig.h>
#include <WebARKitTrackers/WebARKitOpticalTracking/WebARKitHomographyInfo.h>
#include <WebARKitTrackers/WebARKitOpticalTracking/TrackerVisualization.h>
#include <WebARKitTrackers/WebARKitOpticalTracking/TrackingPointSelector.h>
#include <WebARKitTrackers/WebARKitOpticalTracking/WebARKitTracker.h>

namespace webarkit {

class WebARKitTracker::WebARKitTrackerImpl {
public:
bool _trackVizActive;
TrackerVisualization _trackViz;

WebARKitTrackerImpl()
: corners(4), initialized(false), output(17, 0.0), _valid(false), _isDetected(false), _isTracking(false),
numMatches(0), minNumMatches(MIN_NUM_MATCHES), _nn_match_ratio(0.7f) {
numMatches(0), minNumMatches(MIN_NUM_MATCHES), _nn_match_ratio(0.7f),_trackVizActive(false), _trackViz(TrackerVisualization()) {
m_camMatrix = cv::Matx33d::zeros();
m_distortionCoeff = cv::Mat::zeros(4, 1, cv::DataType<double>::type);
};
Expand Down Expand Up @@ -58,6 +62,7 @@ class WebARKitTracker::WebARKitTrackerImpl {
WEBARKIT_LOGi("Init Tracker!\n");

cv::Mat refGray = convert2Grayscale(refData, refCols, refRows, colorSpace);
refGray.copyTo(_image);

cv::Mat trackerFeatureMask = createTrackerFeatureMask(refGray);

Expand Down Expand Up @@ -160,6 +165,8 @@ class WebARKitTracker::WebARKitTrackerImpl {

_isDetected = false;

cv::Mat _image;

cv::Mat frameDescr;
std::vector<cv::KeyPoint> frameKeyPts;
bool valid;
Expand Down Expand Up @@ -312,6 +319,113 @@ class WebARKitTracker::WebARKitTrackerImpl {
}
}

bool RunTemplateMatching(cv::Mat frame, int trackableId)
{
//std::cout << "Starting template match" << std::endl;
std::vector<cv::Point2f> finalTemplatePoints, finalTemplateMatchPoints;
//Get a handle on the corresponding points from current image and the marker
//std::vector<cv::Point2f> trackablePoints = _trackables[trackableId]._trackSelection.GetTrackedFeatures();
//std::vector<cv::Point2f> trackablePointsWarped = _trackables[trackableId]._trackSelection.GetTrackedFeaturesWarped();
std::vector<cv::Point2f> trackablePoints = _trackSelection.GetTrackedFeatures();
std::vector<cv::Point2f> trackablePointsWarped = _trackSelection.GetTrackedFeaturesWarped();
//Create an empty result image - May be able to pre-initialize this container

int n = (int)trackablePointsWarped.size();
if (_trackVizActive) {
_trackViz.templateMatching = {};
_trackViz.templateMatching.templateMatchingCandidateCount = n;
}

for (int j = 0; j < n; j++) {
auto pt = trackablePointsWarped[j];
//if (cv::pointPolygonTest(_trackables[trackableId]._bBoxTransformed, trackablePointsWarped[j], true) > 0) {
if (cv::pointPolygonTest(_bBoxTransformed, trackablePointsWarped[j], true) > 0) {
auto ptOrig = trackablePoints[j];

cv::Rect templateRoi = GetTemplateRoi(pt);
cv::Rect frameROI(0, 0, frame.cols, frame.rows);
if (IsRoiValidForFrame(frameROI, templateRoi)) {
// cv::Rect markerRoi(0, 0, _trackables[trackableId]._image.cols, _trackables[trackableId]._image.rows);
cv::Rect markerRoi(0, 0, _image.cols, _image.rows);

std::vector<cv::Point2f> vertexPoints = GetVerticesFromPoint(ptOrig);
std::vector<cv::Point2f> vertexPointsResults;
// perspectiveTransform(vertexPoints, vertexPointsResults, _trackables[trackableId]._trackSelection.GetHomography());
perspectiveTransform(vertexPoints, vertexPointsResults, _trackSelection.GetHomography());

cv::Rect srcBoundingBox = cv::boundingRect(cv::Mat(vertexPointsResults));

vertexPoints.clear();
vertexPoints = GetVerticesFromTopCorner(srcBoundingBox.x, srcBoundingBox.y, srcBoundingBox.width, srcBoundingBox.height);
// perspectiveTransform(vertexPoints, vertexPointsResults, _trackables[trackableId]._trackSelection.GetHomography().inv());
perspectiveTransform(vertexPoints, vertexPointsResults, _trackSelection.GetHomography().inv());

std::vector<cv::Point2f> testVertexPoints = FloorVertexPoints(vertexPointsResults);
std::vector<cv::Point2f> finalWarpPoints = GetVerticesFromTopCorner(0, 0, srcBoundingBox.width, srcBoundingBox.height);
cv::Mat templateHomography = findHomography(testVertexPoints, finalWarpPoints, cv::RANSAC, ransac_thresh);

if (!templateHomography.empty()) {
cv::Rect templateBoundingBox = cv::boundingRect(cv::Mat(vertexPointsResults));
cv::Rect searchROI = InflateRoi(templateRoi, searchRadius);
if (IsRoiValidForFrame(frameROI, searchROI)) {
searchROI = searchROI & frameROI;
templateBoundingBox = templateBoundingBox & markerRoi;

if (templateBoundingBox.area() > 0 && searchROI.area() > templateBoundingBox.area()) {
cv::Mat searchImage = frame(searchROI);
// cv::Mat templateImage = _trackables[trackableId]._image(templateBoundingBox);
cv::Mat templateImage = _image(templateBoundingBox);
cv::Mat warpedTemplate;

warpPerspective(templateImage, warpedTemplate, templateHomography, srcBoundingBox.size());
cv::Mat matchResult = MatchTemplateToImage(searchImage, warpedTemplate);

if (!matchResult.empty()) {
double minVal; double maxVal;
cv::Point minLoc, maxLoc, matchLoc;
minMaxLoc( matchResult, &minVal, &maxVal, &minLoc, &maxLoc, cv::Mat() );
if (minVal < 0.5) {
matchLoc = minLoc;
matchLoc.x+=searchROI.x + (warpedTemplate.cols/2);
matchLoc.y+=searchROI.y + (warpedTemplate.rows/2);
finalTemplatePoints.push_back(ptOrig);
finalTemplateMatchPoints.push_back(matchLoc);
} else {
if (_trackVizActive) _trackViz.templateMatching.failedTemplateMinimumCorrelationCount++;
}
} else {
if (_trackVizActive) _trackViz.templateMatching.failedTemplateMatchCount++;
}
} else {
if (_trackVizActive) _trackViz.templateMatching.failedTemplateBigEnoughTestCount++;
}
} else {
if (_trackVizActive) _trackViz.templateMatching.failedSearchROIInFrameTestCount++;
}
} else {
if (_trackVizActive) _trackViz.templateMatching.failedGotHomogTestCount++;
}
} else {
if (_trackVizActive) _trackViz.templateMatching.failedROIInFrameTestCount++;
}
} else {
if (_trackVizActive) _trackViz.templateMatching.failedBoundsTestCount++;
}
}
bool gotHomography = UpdateTrackableHomography(trackableId, finalTemplatePoints, finalTemplateMatchPoints);
if (!gotHomography) {
//_trackables[trackableId]._isTracking = false;
// _trackables[trackableId]._isDetected = false;
_isTracking = false;
_isDetected = false;
_currentlyTrackedMarkers--;
}
if (_trackVizActive) {
_trackViz.templateMatching.templateMatchingOK = gotHomography;
}
return gotHomography;
}

void processFrame(cv::Mat& frame) {
if (!this->_valid) {
this->_valid = resetTracking(frame);
Expand Down Expand Up @@ -386,11 +500,11 @@ class WebARKitTracker::WebARKitTrackerImpl {
std::cout << "Optical flow failed." << std::endl;
// return true;
} else {
// if (_trackVizActive) _trackViz.opticalFlowOK = true;
if (_trackVizActive) _trackViz.opticalFlowOK = true;
// Refine optical flow with template match.
/*if (!RunTemplateMatching(frame, i)) {
if (!RunTemplateMatching(frame, i)) {
//std::cout << "Template matching failed." << std::endl;
}*/
}
// std::cout << "Optical flow ok." << std::endl;
// return false;
}
Expand Down Expand Up @@ -531,12 +645,14 @@ class WebARKitTracker::WebARKitTrackerImpl {
// This will be refined by the optical flow pass.
//perspectiveTransform(_trackables[bestMatchIndex]._bBox, _trackables[bestMatchIndex]._bBoxTransformed, homoInfo.homography);
perspectiveTransform(_bBox, _bBoxTransformed, homoInfo.homography);
/*if (_trackVizActive) {
if (_trackVizActive) {
for (int i = 0; i < 4; i++) {
_trackViz.bounds[i][0] = _trackables[bestMatchIndex]._bBoxTransformed[i].x;
_trackViz.bounds[i][1] = _trackables[bestMatchIndex]._bBoxTransformed[i].y;
// _trackViz.bounds[i][0] = _trackables[bestMatchIndex]._bBoxTransformed[i].x;
// _trackViz.bounds[i][1] = _trackables[bestMatchIndex]._bBoxTransformed[i].y;
_trackViz.bounds[i][0] = _bBoxTransformed[i].x;
_trackViz.bounds[i][1] = _bBoxTransformed[i].y;
}
}*/
}
//_currentlyTrackedMarkers++;
}
}
Expand Down Expand Up @@ -566,10 +682,10 @@ class WebARKitTracker::WebARKitTrackerImpl {
filteredTrackedPoints.push_back(flowResultPoints[j]);
}
// std::cout << "Optical Flow ok!!!!" << std::endl;
/*if (_trackVizActive) {
if (_trackVizActive) {
_trackViz.opticalFlowTrackablePoints = filteredTrackablePoints;
_trackViz.opticalFlowTrackedPoints = filteredTrackedPoints;
}*/
}
// std::cout << "Optical flow discarded " << killed1 << " of " << flowResultPoints.size() << " points" <<
// std::endl;

Expand All @@ -596,21 +712,49 @@ class WebARKitTracker::WebARKitTrackerImpl {
// Update the bounding box.
perspectiveTransform(_bBox, _bBoxTransformed, homoInfo.homography);
fill_output2(m_H);
/*if (_trackVizActive) {
if (_trackVizActive) {
for (int i = 0; i < 4; i++) {
_trackViz.bounds[i][0] = _trackables[trackableId]._bBoxTransformed[i].x;
_trackViz.bounds[i][1] = _trackables[trackableId]._bBoxTransformed[i].y;
// _trackViz.bounds[i][0] = _trackables[trackableId]._bBoxTransformed[i].x;
// _trackViz.bounds[i][1] = _trackables[trackableId]._bBoxTransformed[i].y;
_trackViz.bounds[i][0] = _bBoxTransformed[i].x;
_trackViz.bounds[i][1] = _bBoxTransformed[i].y;
}
}
if (_frameCount > 1) {
_trackables[trackableId]._trackSelection.ResetSelection();
}*/
// _trackables[trackableId]._trackSelection.ResetSelection();
_trackSelection.ResetSelection();
}
return true;
}
}
return false;
}

std::vector<cv::Point2f> GetVerticesFromPoint(cv::Point ptOrig, int width = markerTemplateWidth, int height = markerTemplateWidth)
{
std::vector<cv::Point2f> vertexPoints;
vertexPoints.push_back(cv::Point2f(ptOrig.x - width/2, ptOrig.y - height/2));
vertexPoints.push_back(cv::Point2f(ptOrig.x + width/2, ptOrig.y - height/2));
vertexPoints.push_back(cv::Point2f(ptOrig.x + width/2, ptOrig.y + height/2));
vertexPoints.push_back(cv::Point2f(ptOrig.x - width/2, ptOrig.y + height/2));
return vertexPoints;
}

std::vector<cv::Point2f> GetVerticesFromTopCorner(int x, int y, int width, int height)
{
std::vector<cv::Point2f> vertexPoints;
vertexPoints.push_back(cv::Point2f(x, y));
vertexPoints.push_back(cv::Point2f(x + width, y));
vertexPoints.push_back(cv::Point2f(x + width, y + height));
vertexPoints.push_back(cv::Point2f(x, y + height));
return vertexPoints;
}

cv::Rect GetTemplateRoi(cv::Point2f pt)
{
return cv::Rect(pt.x - (markerTemplateWidth/2), pt.y - (markerTemplateWidth/2), markerTemplateWidth, markerTemplateWidth);
}

cv::Mat MatchTemplateToImage(cv::Mat searchImage, cv::Mat warpedTemplate)
{
int result_cols = searchImage.cols - warpedTemplate.cols + 1;
Expand All @@ -633,6 +777,41 @@ class WebARKitTracker::WebARKitTrackerImpl {
return cv::Mat();
}
}

bool IsRoiValidForFrame(cv::Rect frameRoi, cv::Rect roi)
{
return (roi & frameRoi) == roi;
}

cv::Rect InflateRoi(cv::Rect roi, int inflationFactor)
{
cv::Rect newRoi = roi;
newRoi.x -= inflationFactor;
newRoi.y -= inflationFactor;
newRoi.width += 2 * inflationFactor;
newRoi.height += 2 * inflationFactor;
return newRoi;
}

std::vector<cv::Point2f> FloorVertexPoints(const std::vector<cv::Point2f>& vertexPoints)
{
std::vector<cv::Point2f> testVertexPoints = vertexPoints;
float minX = std::numeric_limits<float>::max();
float minY = std::numeric_limits<float>::max();
for (int k = 0; k < testVertexPoints.size(); k++) {
if (testVertexPoints[k].x < minX) {
minX=testVertexPoints[k].x;
}
if (testVertexPoints[k].y < minY) {
minY=testVertexPoints[k].y;
}
}
for(int k = 0; k < testVertexPoints.size(); k++) {
testVertexPoints[k].x -= minX;
testVertexPoints[k].y -= minY;
}
return testVertexPoints;
}

void getMatches(const cv::Mat& frameDescr, std::vector<cv::KeyPoint>& frameKeyPts,
std::vector<cv::Point2f>& refPoints, std::vector<cv::Point2f>& framePoints) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* TrackerVisualization.h
* artoolkitX
*
* This file is part of artoolkitX.
*
* artoolkitX is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* artoolkitX is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with artoolkitX. If not, see <http://www.gnu.org/licenses/>.
*
* As a special exception, the copyright holders of this library give you
* permission to link this library with independent modules to produce an
* executable, regardless of the license terms of these independent modules, and to
* copy and distribute the resulting executable under terms of your choice,
* provided that you also meet, for each linked independent module, the terms and
* conditions of the license of that module. An independent module is a module
* which is neither derived from nor based on this library. If you modify this
* library, you may extend this exception to your version of the library, but you
* are not obligated to do so. If you do not wish to do so, delete this exception
* statement from your version.
*
* Copyright 2024 Eden Networks Ltd.
*
* Author(s): Philip Lamb.
*
*/

#ifndef TRACKER_VISUALIZATION_H
#define TRACKER_VISUALIZATION_H

#include <opencv2/core.hpp>

class TrackerVisualization
{
public:
int id;
float bounds[4][2];
std::vector<cv::Point2f> opticalFlowTrackablePoints;
std::vector<cv::Point2f> opticalFlowTrackedPoints;
bool opticalFlowOK;
struct templateMatching {
int templateMatchingCandidateCount;
int failedBoundsTestCount;
int failedROIInFrameTestCount;
int failedGotHomogTestCount;
int failedSearchROIInFrameTestCount;
int failedTemplateBigEnoughTestCount;
int failedTemplateMatchCount;
int failedTemplateMinimumCorrelationCount;
bool templateMatchingOK;
};
templateMatching templateMatching;
};

#endif // TRACKER_VISUALIZATION_H

0 comments on commit 393533e

Please sign in to comment.