摘要
论文地址:RecRecNet: Rectangling Rectified Wide-Angle Images by Thin-Plate Spline Model and DoF-based Curriculum Learning
源码地址:https://github.com/KangLiao929/RecRecNet
广角镜头在VR技术等领域有着诱人的应用,但它会使拍摄的图像产生严重的径向畸变。为了还原真实场景,以往的工作致力于校正广角图像的内容。然而,这种校正方法不可避免地会扭曲图像边界,改变相关的几何分布,并误导当前的视觉感知模型。在这项工作中,我们提出一种新的学习模型——矩形校正网络(RecRecNet),探索在内容和边界上同时取得良好效果的双赢表示。具体而言,我们提出了一个薄板样条(TPS)模块,用于构建图像矩形化所需的非线性、非刚性变换。通过学习校正后图像上的控制点,我们的模型可以灵活地将源结构扭曲到目标域,并实现端到端的无监督变形。为了降低结构逼近的难度,我们进一步借助基于自由度(DoF)的课程学习,引导RecRecNet逐步掌握由简到繁的变形规律。通过在每个课程阶段增加自由度,即从相似变换(4自由度)到单应变换(8自由度),网络能够逐步学习更精细的变形,并在最终的矩形化任务上实现快速收敛。实验表明,我们的方法在定量和定性评估上都优于对比方法。
算法实现
1.模型训练
环境安装:
conda create -n recrecnet python=3.6
git clone https://github.com/KangLiao929/RecRecNet.git
conda activate recrecnet
cd RecRecNet
pip install -r requirements.txt
数据集下载:train.zip(训练集)、test.zip(测试集)。
预训练模型下载地址:https://drive.google.com/file/d/1y9iTfWCycS3BAFViMsClbur11IY-HgXf/view 。
下载之后将其放入 ./checkpoint 文件夹中。
生成数据:
sh scripts/curriculum_gen.sh
模型训练:
sh scripts/train.sh
模型测试:
sh scripts/test.sh
2.模型推理
2.1 C++推理
#define _CRT_SECURE_NO_WARNINGS
#include <math.h>

#include <algorithm>
#include <cmath>
#include <fstream>
#include <iostream>
#include <string>

#include <opencv2/dnn.hpp>
#include <opencv2/opencv.hpp>

using namespace cv;
using namespace std;
using namespace dnn;
/// Returns a 1 x number CV_32FC1 row vector of evenly spaced values from
/// `begin` to `finish` (both inclusive), like numpy.linspace(begin, finish, number).
///
/// @param begin   First value.
/// @param finish  Last value; included whenever number > 1.
/// @param number  Number of samples; must be >= 0.
/// @return 1 x number float matrix (zero columns when number == 0).
Mat linspace(float begin, float finish, int number)
{
    CV_Assert(number >= 0);
    Mat f(1, number, CV_32FC1);
    if (number == 0)
    {
        return f;
    }
    float* row = f.ptr<float>(0);
    if (number == 1)
    {
        // The general formula would divide by (number - 1) == 0 and produce NaN.
        row[0] = begin;
        return f;
    }
    const float interval = (finish - begin) / (number - 1);
    for (int j = 0; j < number; j++)
    {
        row[j] = begin + j * interval;
    }
    // Pin the endpoint exactly (as numpy does) instead of relying on rounding.
    row[number - 1] = finish;
    return f;
}
// Spacing, in pixels, between adjacent rigid-mesh vertices along one axis.
// Float division so that vertex `cells` lands exactly on `extent` even when
// `extent` is not a multiple of `cells`; integer division truncated the spacing
// and left the last mesh row/column short of the image edge.
static inline float mesh_step(int extent, int cells)
{
    return static_cast<float>(extent) / static_cast<float>(cells);
}

// Thin-plate-spline radial basis U(d2) = d2 * log(d2 + 1e-6) for a squared
// distance d2; the 1e-6 keeps log() finite when d2 == 0 (U(0) == 0).
static inline float tps_kernel(float d2)
{
    return d2 * std::log(static_cast<float>(d2 + 1e-6));
}

/// Builds the thin-plate-spline (TPS) system for a regular control mesh and
/// the TPS basis evaluated at every pixel of the input image.
///
/// The control mesh has (grid_h + 1) x (grid_w + 1) vertices spread evenly
/// over the image, in normalized [-1, 1] coordinates. Let
/// length = (grid_h + 1) * (grid_w + 1).
///
/// @param[out] grid   (length + 3) x (input_height * input_width) CV_32FC1.
///                    Rows 0..2 hold [1; x; y] for each pixel in row-major
///                    order (x, y evenly spaced over [-1, 1]); row 3 + k holds
///                    tps_kernel(|c_k - pixel|^2) for control vertex c_k.
/// @param[out] W_inv  Inverse of the (length + 3) x (length + 3) matrix
///                    W = [[P, R], [0, P^T]], with P = [1 x y] per control
///                    vertex and R(a, b) = tps_kernel(|c_a - c_b|^2).
/// @param input_height, input_width  Image size in pixels.
/// @param grid_h, grid_w             Number of mesh cells along each axis.
void get_norm_rigid_mesh_inv_grid(Mat& grid, Mat& W_inv, const int input_height, const int input_width, const int grid_h, const int grid_w)
{
    const int h = grid_h + 1;
    const int w = grid_w + 1;
    const int length = h * w;
    const float step_x = mesh_step(input_width, grid_w);
    const float step_y = mesh_step(input_height, grid_h);

    // Control vertices in normalized coordinates: column 0 = x, column 1 = y.
    Mat norm_rigid_mesh(length, 2, CV_32FC1);
    Mat W(length + 3, length + 3, CV_32FC1);
    for (int i = 0; i < h; i++)
    {
        for (int j = 0; j < w; j++)
        {
            const int k = i * w + j;
            const float x = (j * step_x) * 2.0 / float(input_width) - 1.0;
            const float y = (i * step_y) * 2.0 / float(input_height) - 1.0;
            norm_rigid_mesh.at<float>(k, 0) = x;
            norm_rigid_mesh.at<float>(k, 1) = y;
            // Top-left block: P = [1 x y].
            W.at<float>(k, 0) = 1;
            W.at<float>(k, 1) = x;
            W.at<float>(k, 2) = y;
            // Bottom-right block: P^T.
            W.at<float>(length, 3 + k) = 1;
            W.at<float>(length + 1, 3 + k) = x;
            W.at<float>(length + 2, 3 + k) = y;
        }
    }
    // Top-right block: radial kernel between every pair of control vertices.
    for (int a = 0; a < length; a++)
    {
        const float xa = norm_rigid_mesh.at<float>(a, 0);
        const float ya = norm_rigid_mesh.at<float>(a, 1);
        float* w_row = W.ptr<float>(a);
        for (int b = 0; b < length; b++)
        {
            const float dx = xa - norm_rigid_mesh.at<float>(b, 0);
            const float dy = ya - norm_rigid_mesh.at<float>(b, 1);
            w_row[3 + b] = tps_kernel(dx * dx + dy * dy);
        }
    }
    // Bottom-left 3x3 block is zero.
    for (int i = 0; i < 3; i++)
    {
        for (int j = 0; j < 3; j++)
        {
            W.at<float>(length + i, j) = 0;
        }
    }
    W_inv = W.inv();

    // Per-pixel sampling grid: pixel (i, j) goes to column i * input_width + j.
    const int num_pixels = input_height * input_width;
    grid.create(length + 3, num_pixels, CV_32FC1);
    // A 1-pixel axis has a single sample at -1 (avoids 2 / 0 -> NaN).
    const float px_step_x = input_width > 1 ? static_cast<float>(2.0 / (input_width - 1)) : 0.0f;
    const float px_step_y = input_height > 1 ? static_cast<float>(2.0 / (input_height - 1)) : 0.0f;
    float* ones_row = grid.ptr<float>(0);
    float* xs = grid.ptr<float>(1);
    float* ys = grid.ptr<float>(2);
    for (int i = 0; i < input_height; i++)
    {
        for (int j = 0; j < input_width; j++)
        {
            const int col = i * input_width + j;
            ones_row[col] = 1;
            xs[col] = -1.0 + j * px_step_x;
            ys[col] = -1.0 + i * px_step_y;
        }
    }
    for (int k = 0; k < length; k++)
    {
        const float cx = norm_rigid_mesh.at<float>(k, 0);
        const float cy = norm_rigid_mesh.at<float>(k, 1);
        float* r_row = grid.ptr<float>(3 + k);
        for (int col = 0; col < num_pixels; col++)
        {
            const float dx = cx - xs[col];
            const float dy = cy - ys[col];
            r_row[col] = tps_kernel(dx * dx + dy * dy);
        }
    }
}

/// Displaces the regular control mesh by `offset` and returns the displaced
/// vertices both in normalized and in pixel coordinates. Uses the same vertex
/// spacing as get_norm_rigid_mesh_inv_grid, so a zero offset reproduces
/// that function's control mesh.
///
/// @param[out] tp             (length + 3) x 2 CV_32FC1: row k is displaced
///                            vertex k as (x, y) in [-1, 1] coordinates; the
///                            last three rows are zero, padding to the size of
///                            the TPS system (length + 3).
/// @param[out] ori_mesh_np_x  (grid_h + 1) x (grid_w + 1) CV_32FC1: displaced x, in pixels.
/// @param[out] ori_mesh_np_y  (grid_h + 1) x (grid_w + 1) CV_32FC1: displaced y, in pixels.
/// @param offset              2 * length floats, (dx, dy) in pixels per vertex,
///                            vertices in row-major order.
/// @param input_height, input_width  Image size in pixels.
/// @param grid_h, grid_w             Number of mesh cells along each axis.
void get_ori_rigid_mesh_tp(Mat& tp, Mat& ori_mesh_np_x, Mat& ori_mesh_np_y, const float* offset, const int input_height, const int input_width, const int grid_h, const int grid_w)
{
    const float step_x = mesh_step(input_width, grid_w);
    const float step_y = mesh_step(input_height, grid_h);
    const int h = grid_h + 1;
    const int w = grid_w + 1;
    const int length = h * w;
    tp.create(length + 3, 2, CV_32FC1);
    ori_mesh_np_x.create(h, w, CV_32FC1);
    ori_mesh_np_y.create(h, w, CV_32FC1);
    for (int i = 0; i < h; i++)
    {
        for (int j = 0; j < w; j++)
        {
            const int k = i * w + j;
            const float x = j * step_x + offset[k * 2];
            const float y = i * step_y + offset[k * 2 + 1];
            tp.at<float>(k, 0) = x * 2.0 / float(input_width) - 1.0;
            tp.at<float>(k, 1) = y * 2.0 / float(input_height) - 1.0;
            ori_mesh_np_x.at<float>(i, j) = x;
            ori_mesh_np_y.at<float>(i, j) = y;
        }
    }
    for (int i = 0; i < 3; i++)
    {
        tp.at<float>(length + i, 0) = 0;
        tp.at<float>(length + i, 1) = 0;
    }
}
/// Bilinearly samples a planar 3-channel float image at normalized coordinates.
///
/// @param im        4-D CV_32F continuous blob; plane c of batch 0 is read as
///                  channel c, i.e. assumed shape (1, 3, H, W), e.g. (1,3,256,256).
/// @param xy_flat   CV_32F matrix with >= 2 rows and >= out_size.area() columns:
///                  row 0 = x, row 1 = y for each output pixel in row-major
///                  order, normalized so that -1 maps to 0 and +1 maps to W (or H).
///                  E.g. (2, 65536) for a 256x256 output.
/// @param out_size  Size of the returned image.
/// @return out_size CV_32FC3 image; channel c of each pixel is sampled from plane c.
///         Neighbour indices are clamped to the image; when both neighbours along
///         an axis clamp to the same index their weights cancel and the sample is 0.
Mat _interpolate(Mat im, Mat xy_flat, Size out_size)
{
    CV_Assert(im.dims == 4 && im.depth() == CV_32F && im.isContinuous() && im.size[1] >= 3);
    CV_Assert(xy_flat.depth() == CV_32F && xy_flat.rows >= 2 && xy_flat.cols >= out_size.area());

    const int height = im.size[2];
    const int width = im.size[3];
    const int max_x = width - 1;
    const int max_y = height - 1;
    const float height_f = float(height);
    const float width_f = float(width);
    const int area = height * width;

    // Channel planes of batch 0 (NCHW layout).
    const float* plane0 = im.ptr<float>();
    const float* plane1 = plane0 + area;
    const float* plane2 = plane1 + area;
    const float* xs = xy_flat.ptr<float>(0);
    const float* ys = xy_flat.ptr<float>(1);

    Mat output(out_size.height, out_size.width, CV_32FC3);
    for (int i = 0; i < out_size.height; i++)
    {
        Vec3f* out_row = output.ptr<Vec3f>(i);
        for (int j = 0; j < out_size.width; j++)
        {
            // Index by the OUTPUT size; xy_flat has one column per output pixel.
            const int k = i * out_size.width + j;
            const float x = (xs[k] + 1.0) * width_f * 0.5;
            const float y = (ys[k] + 1.0) * height_f * 0.5;
            // floor, not int(): truncation toward zero maps x in (-1, 0) to
            // x0 = 0, x1 = 1, giving a negative weight (extrapolation).
            const int fx = static_cast<int>(std::floor(x));
            const int fy = static_cast<int>(std::floor(y));
            const int x0 = std::min(std::max(fx, 0), max_x);
            const int x1 = std::min(std::max(fx + 1, 0), max_x);
            const int y0 = std::min(std::max(fy, 0), max_y);
            const int y1 = std::min(std::max(fy + 1, 0), max_y);

            const int idx_a = y0 * width + x0;
            const int idx_b = y1 * width + x0;
            const int idx_c = y0 * width + x1;
            const int idx_d = y1 * width + x1;

            const float x0_f = float(x0);
            const float x1_f = float(x1);
            const float y0_f = float(y0);
            const float y1_f = float(y1);
            const float wa = (x1_f - x) * (y1_f - y);
            const float wb = (x1_f - x) * (y - y0_f);
            const float wc = (x - x0_f) * (y1_f - y);
            const float wd = (x - x0_f) * (y - y0_f);

            auto sample = [&](const float* plane) {
                return wa * plane[idx_a] + wb * plane[idx_b] + wc * plane[idx_c] + wd * plane[idx_d];
            };
            out_row[j] = Vec3f(sample(plane0), sample(plane1), sample(plane2));
        }
    }
    return output;
}
Mat draw_mesh_on_warp(const Mat warp, const Mat f_local_x, const Mat f_local_y)