从PSMNet到GwcNet:CVPR 2019立体匹配论文的代码级改进点解析

张开发
2026/4/18 13:07:28 15 分钟阅读

分享文章

从PSMNet到GwcNet:CVPR 2019立体匹配论文的代码级改进点解析
从PSMNet到GwcNet立体匹配网络的核心改进与代码实现深度剖析在计算机视觉领域立体匹配一直是三维重建和深度估计的基础任务之一。2017年提出的PSMNetPyramid Stereo Matching Network通过引入空间金字塔池化和3D沙漏网络在当时的主流数据集上取得了显著优势。而2019年CVPR发表的GwcNetGroup-wise Correlation Stereo Network则在PSMNet的基础上进行了多项关键改进这些改进不仅提升了模型精度更为后续的立体匹配网络设计提供了新的思路。本文将深入解析GwcNet相对于PSMNet的核心改进点特别是从代码实现层面剖析这些改进如何具体影响网络性能。1. 代价体构建的革新从简单级联到分组相关立体匹配网络的核心挑战之一是如何有效地构建代价体cost volume即衡量左右图像对应点匹配程度的四维数据结构。PSMNet采用了一种直观但计算量较大的方法——直接级联左右图像的特征图。1.1 PSMNet的concat_volume实现PSMNet中构建代价体的关键代码如下def build_concat_volume(refimg_fea, targetimg_fea, maxdisp): B, C, H, W refimg_fea.shape volume refimg_fea.new_zeros([B, 2 * C, maxdisp, H, W]) for i in range(maxdisp): if i 0: volume[:, :C, i, :, i:] refimg_fea[:, :, :, i:] volume[:, C:, i, :, i:] targetimg_fea[:, :, :, :-i] else: volume[:, :C, i, :, :] refimg_fea volume[:, C:, i, :, :] targetimg_fea volume volume.contiguous() return volume这种方法虽然简单直接但存在两个明显问题通道维度爆炸输出代价体的通道数是输入特征的两倍2*C增加了后续3D卷积的计算负担信息冗余简单的特征堆叠并不能有效捕捉左右图像特征间的相关性1.2 GwcNet的group-wise correlation创新GwcNet创新性地提出了分组相关的代价体构建方法其核心实现分为两个部分def groupwise_correlation(fea1, fea2, num_groups): B, C, H, W fea1.shape assert C % num_groups 0 channels_per_group C // num_groups cost (fea1 * fea2).view([B, num_groups, channels_per_group, H, W]).mean(dim2) return cost def build_gwc_volume(refimg_fea, targetimg_fea, maxdisp, num_groups): B, C, H, W refimg_fea.shape volume refimg_fea.new_zeros([B, num_groups, maxdisp, H, W]) for i in range(maxdisp): if i 0: volume[:, :, i, :, i:] groupwise_correlation( refimg_fea[:, :, :, i:], targetimg_fea[:, :, :, :-i], num_groups) else: volume[:, :, i, :, :] groupwise_correlation( refimg_fea, targetimg_fea, num_groups) return volume这种设计带来了三个关键优势计算效率提升通过分组计算相关性输出通道数从2C减少到num_groups通常设为40物理意义明确点积操作更接近传统立体匹配中的相关性计算信息互补GwcNet实际同时使用了分组相关和级联两种代价体形成互补实际应用中num_groups的选择需要权衡组数太少会导致特征区分度不足太多则增加计算量。论文实验表明40组在Scene Flow数据集上效果最佳。2. 3D聚合模块的优化沙漏网络结构调整代价体构建后需要通过3D卷积网络进行代价聚合cost aggregation这是立体匹配网络的另一个关键组件。PSMNet和GwcNet都采用了堆叠沙漏stacked hourglass结构但GwcNet做了重要改进。2.1 PSMNet的原始沙漏设计PSMNet使用了两个堆叠的3D沙漏模块每个沙漏网络包含以下特点对称的编码器-解码器结构跳跃连接保留多尺度信息中间监督提升训练效果这种设计虽然有效但存在梯度流动路径过长、信息冗余等问题。2.2 GwcNet的三项关键改进GwcNet对3D聚合模块进行了三项重要调整移除沙漏间的shortcut连接减少了冗余信息的传递增加1×1×1的3D卷积过渡更灵活地调节特征维度扩展为三个沙漏模块增强特征提取能力这些改变带来的实际效果对比如下改进点PSMNetGwcNet效果提升沙漏数量230.3% EPE沙漏间连接方式直接跳连1×1×1卷积0.2% EPE中间监督位置每个沙漏后每个沙漏后保持稳定EPEEnd-Point Error是立体匹配中常用的评估指标表示预测视差与真实视差之间的平均欧氏距离。3. 输出模块与损失函数的精细化设计GwcNet在输出预测和损失计算方面也做了细致优化这些改进往往容易被忽视但对最终精度有重要影响。3.1 多尺度输出与融合GwcNet的3D聚合模块包含四个输出点output0-output3分别对应不同深度的特征# 简化版的输出处理流程 for i in range(4): cost self.__getattr__(foutput{i})(x) # 各输出点的3D卷积 cost F.interpolate(cost, scale_factor2, modetrilinear) prob F.softmax(cost, dim1) # 视差维度softmax disp disparity_regression(prob, maxdisp) # 视差回归 disp_preds.append(disp)这种设计实现了渐进式优化浅层输出捕捉局部细节深层输出包含全局信息中间监督每个输出点都参与损失计算缓解梯度消失3.2 自适应加权损失函数GwcNet的损失函数考虑了不同输出点的贡献差异$$ L \sum_{i0}^{3} \lambda_i \cdot Smooth_{L1}(\widetilde{d_i} - d^*) $$其中权重$\lambda_i$的设置遵循两个原则深层输出权重更大$\lambda_30.8$总权重和为1$\lambda_00.1$, $\lambda_10.1$, $\lambda_20$这种加权方式既保证了深层特征的优化力度又避免了过强的中间监督干扰特征学习。4. 实际部署中的工程考量将理论改进转化为实际性能提升还需要考虑工程实现细节。以下是GwcNet实现中的几个关键点4.1 内存优化策略立体匹配网络面临的主要挑战之一是显存占用。GwcNet通过以下方式优化内存使用梯度检查点Gradient Checkpointingfrom torch.utils.checkpoint import checkpoint def forward(self, x): x checkpoint(self.block1, x) # 不保存中间激活值 x self.block2(x) return x混合精度训练with torch.cuda.amp.autocast(): output model(input) loss criterion(output, target)4.2 推理速度优化实际应用中我们常需要平衡精度和速度。GwcNet的推理过程可以针对不同场景调整配置项高精度模式快速模式差异输入分辨率原图1/2缩放30% FPS最大视差19212825% FPS沙漏模块数3220% FPS4.3 跨数据集泛化技巧在不同数据集上微调GwcNet时有几个实用技巧渐进式视差调整先在小视差范围训练再逐步扩大特征提取器冻结先固定特征提取部分只训练3D聚合模块自适应损失权重根据各数据集的视差分布调整$\lambda_i$这些改进虽然看似微小但在KITTI、ETH3D等真实场景数据上能带来2-3%的性能提升。

更多文章