Sinkhorn算法实战:用Python快速计算最优传输距离(附完整代码)

张开发
2026/4/18 22:34:11 15 分钟阅读

分享文章

Sinkhorn算法实战:用Python快速计算最优传输距离(附完整代码)
Sinkhorn算法实战用Python快速计算最优传输距离附完整代码当我们需要衡量两个概率分布之间的差异时最优传输距离提供了一个优雅而强大的数学框架。想象一下你有一堆沙子分布在不同的位置源分布需要将它们搬运到另一组目标位置目标分布而每个搬运路径都有不同的成本。最优传输问题就是要找到最经济的搬运方案使得总成本最小。这就是最优传输距离的直观解释。在机器学习领域最优传输距离被广泛应用于图像处理、自然语言处理、生成模型等领域。然而直接计算最优传输距离的计算复杂度很高对于大规模问题来说几乎不可行。这就是Sinkhorn算法大显身手的地方——它通过引入熵正则化将问题转化为可以通过迭代快速求解的形式。本文将带你从零开始实现Sinkhorn算法解决实际应用中可能遇到的各种问题。无论你是想将最优传输距离应用于图像风格迁移、文档匹配还是其他机器学习任务这里的代码和技巧都能为你提供实用参考。1. 环境准备与基础概念在开始编码之前我们需要确保开发环境配置正确并理解几个关键概念。最优传输问题的数学表述看起来可能有些抽象但我们会用直观的例子来解释每个部分的意义。首先安装必要的Python库pip install numpy scipy matplotlib最优传输问题的核心要素包括源分布和目标分布两个需要比较的概率分布代价矩阵表示将质量从源位置移动到目标位置的成本传输计划描述如何将质量从源分布转移到目标分布的方案让我们用一个简单的例子来说明。假设有三个仓库源分布需要向三个商店目标分布配送商品import numpy as np # 仓库的库存分布源分布 r np.array([0.5, 0.3, 0.2]) # 商店的需求分布目标分布 c np.array([0.4, 0.4, 0.2]) # 运输成本矩阵仓库到商店的距离 C np.array([ [4, 8, 6], # 仓库1到各商店的距离 [3, 7, 5], # 仓库2到各商店的距离 [2, 4, 6] # 仓库3到各商店的距离 ])在这个例子中我们的目标是找到一个运输方案π使得总运输成本∑πᵢⱼCᵢⱼ最小同时满足仓库的供应限制和商店的需求限制。2. Sinkhorn算法原理与实现Sinkhorn算法的核心思想是通过引入熵正则化将原始的最优传输问题转化为可以通过迭代行列归一化来求解的形式。这种方法大大降低了计算复杂度使得算法可以处理中等规模的问题。2.1 算法数学原理Sinkhorn算法通过以下步骤求解最优传输问题构造核矩阵KK exp(-C/ε)其中ε是正则化参数初始化缩放因子u 1, v 1迭代更新u r / (Kv)v c / (Kᵀu)计算传输计划π diag(u) K diag(v)计算传输成本π, C ∑πᵢⱼCᵢⱼ正则化参数ε的选择很重要ε越大解越平滑计算越快但偏离原始最优传输解ε越小解越接近原始最优传输解但计算越慢2.2 Python完整实现下面是Sinkhorn算法的完整Python实现包含了详细的注释和错误处理import numpy as np def sinkhorn(cost_matrix, source_dist, target_dist, epsilon0.1, max_iter1000, tol1e-6, verboseFalse): Sinkhorn算法计算最优传输距离 参数: cost_matrix (np.ndarray): 代价矩阵 (n x m) source_dist (np.ndarray): 源分布 (n,) target_dist (np.ndarray): 目标分布 (m,) epsilon (float): 正则化参数 max_iter (int): 最大迭代次数 tol (float): 收敛阈值 verbose (bool): 是否打印迭代信息 返回: transport_plan (np.ndarray): 传输计划矩阵 (n x m) total_cost (float): 最优传输距离 # 输入验证 assert cost_matrix.ndim 2, 代价矩阵必须是二维数组 n, m cost_matrix.shape assert len(source_dist) n, 源分布维度与代价矩阵不匹配 assert len(target_dist) m, 目标分布维度与代价矩阵不匹配 assert np.all(source_dist 0), 源分布必须非负 assert np.all(target_dist 0), 目标分布必须非负 # 归一化分布 source_dist source_dist / source_dist.sum() target_dist target_dist / target_dist.sum() # 初始化核矩阵 K np.exp(-cost_matrix / epsilon) # 初始化缩放因子 u np.ones(n) v np.ones(m) # 迭代更新 for i in range(max_iter): u_prev u.copy() # 更新u和v v target_dist / (K.T u 1e-16) # 防止除以零 u source_dist / (K v 1e-16) # 检查收敛 if np.linalg.norm(u - u_prev) tol: if verbose: print(f在第{i}次迭代后收敛) break else: if verbose: print(f达到最大迭代次数{max_iter}仍未收敛) # 计算传输计划和总成本 transport_plan np.diag(u) K np.diag(v) total_cost np.sum(transport_plan * cost_matrix) return transport_plan, total_cost2.3 算法应用示例让我们用前面的仓库配送问题来测试我们的实现# 定义问题参数 C np.array([[4, 8, 6], [3, 7, 5], [2, 4, 6]]) r np.array([0.5, 0.3, 0.2]) c np.array([0.4, 0.4, 0.2]) # 计算最优传输 pi, cost sinkhorn(C, r, c, epsilon0.1, verboseTrue) print(最优传输计划:) print(pi) print(f最优传输距离: {cost:.4f})运行结果会显示如何最优地从仓库向商店配送商品以及最小的总运输成本。3. 参数选择与性能优化在实际应用中Sinkhorn算法的性能很大程度上取决于参数的选择和实现细节。本节将讨论如何调优算法以获得最佳性能。3.1 正则化参数ε的选择ε的选择需要在精度和计算效率之间权衡ε值计算速度近似精度适用场景较大(1.0)快低需要快速估计的场景中等(0.1)中等中等大多数应用场景较小(0.01)慢高需要高精度的场景建议从ε0.1开始根据需求调整。可以通过观察传输计划的变化来确定合适的ε值。3.2 收敛阈值与迭代次数默认的tol1e-6和max_iter1000适用于大多数情况但对于特别大或特别小的问题可能需要调整# 对于大型问题可以放宽收敛阈值以加快计算 pi, cost sinkhorn(large_C, large_r, large_c, tol1e-4) # 对于高精度需求的小问题可以减小阈值 pi, cost sinkhorn(small_C, small_r, small_c, tol1e-8, max_iter5000)3.3 数值稳定性技巧Sinkhorn算法在数值计算上可能会遇到一些问题特别是当ε很小时。以下是一些提高稳定性的技巧对数域计算避免直接计算指数改用对数操作裁剪处理防止数值溢出添加小常数防止除以零下面是改进后的稳定版本def sinkhorn_stable(cost_matrix, source_dist, target_dist, epsilon0.1, max_iter1000, tol1e-6): 数值稳定的Sinkhorn实现 # 对数域计算 log_K -cost_matrix / epsilon log_u np.zeros_like(source_dist) log_v np.zeros_like(target_dist) for _ in range(max_iter): log_u_prev log_u.copy() # 更新log_v和log_u log_v np.log(target_dist 1e-16) - np.log(np.exp(log_u[:, None] log_K).sum(0) 1e-16) log_u np.log(source_dist 1e-16) - np.log(np.exp(log_K log_v[None, :]).sum(1) 1e-16) if np.max(np.abs(log_u - log_u_prev)) tol: break # 计算传输计划 transport_plan np.exp(log_u[:, None] log_K log_v[None, :]) total_cost np.sum(transport_plan * cost_matrix) return transport_plan, total_cost4. 实际应用案例Sinkhorn算法在机器学习中有广泛的应用。让我们看几个实际案例了解如何将算法应用于真实问题。4.1 图像颜色迁移颜色迁移是将一张图像的色彩风格应用到另一张图像上的任务。我们可以将图像像素看作概率分布使用最优传输来匹配颜色分布。import cv2 import matplotlib.pyplot as plt def color_transfer(source_img, target_img, epsilon0.01): 使用Sinkhorn算法进行颜色迁移 # 将图像转换为Lab颜色空间更好的色彩感知 source_lab cv2.cvtColor(source_img, cv2.COLOR_RGB2LAB) target_lab cv2.cvtColor(target_img, cv2.COLOR_RGB2LAB) # 提取ab通道色彩信息 source_ab source_lab[:, :, 1:].reshape(-1, 2) target_ab target_lab[:, :, 1:].reshape(-1, 2) # 随机采样以减少计算量 n_samples 500 source_ab source_ab[np.random.choice(len(source_ab), n_samples)] target_ab target_ab[np.random.choice(len(target_ab), n_samples)] # 计算颜色之间的距离矩阵 C np.sqrt(((source_ab[:, None] - target_ab[None, :])**2).sum(2)) # 均匀分布假设 r np.ones(len(source_ab)) / len(source_ab) c np.ones(len(target_ab)) / len(target_ab) # 计算最优传输 pi, _ sinkhorn(C, r, c, epsilonepsilon) # 计算传输后的颜色 transferred_ab target_ab[np.argmax(pi, axis1)] # 重建图像 result_lab source_lab.copy() result_lab[:, :, 1:] transferred_ab.reshape(result_lab.shape[0], result_lab.shape[1], 2) result cv2.cvtColor(result_lab, cv2.COLOR_LAB2RGB) return result # 示例使用 source_img plt.imread(source.jpg) target_img plt.imread(target.jpg) result_img color_transfer(source_img, target_img) plt.figure(figsize(15, 5)) plt.subplot(131); plt.imshow(source_img); plt.title(源图像) plt.subplot(132); plt.imshow(target_img); plt.title(目标风格) plt.subplot(133); plt.imshow(result_img); plt.title(迁移结果) plt.show()4.2 文档相似度计算在自然语言处理中我们可以使用最优传输距离来衡量两个文档的相似度。将文档表示为词嵌入的分布然后计算它们之间的最优传输距离。from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.metrics.pairwise import cosine_similarity def document_distance(doc1, doc2, epsilon0.1): 使用最优传输计算文档距离 # 创建TF-IDF向量化器 vectorizer TfidfVectorizer().fit([doc1, doc2]) # 获取词向量和词汇表 vocab vectorizer.get_feature_names_out() vec1 vectorizer.transform([doc1]).toarray()[0] vec2 vectorizer.transform([doc2]).toarray()[0] # 归一化为概率分布 vec1 vec1 / vec1.sum() vec2 vec2 / vec2.sum() # 计算词嵌入这里简化使用TF-IDF向量本身 embeddings np.eye(len(vocab)) # 简化示例 # 计算词之间的代价矩阵 C 1 - cosine_similarity(embeddings) # 计算最优传输距离 _, distance sinkhorn(C, vec1, vec2, epsilonepsilon) return distance # 示例使用 doc1 机器学习算法包括监督学习和无监督学习 doc2 深度学习是机器学习的一个分支 doc3 太阳系有八大行星围绕太阳旋转 print(f文档1和2的距离: {document_distance(doc1, doc2):.4f}) print(f文档1和3的距离: {document_distance(doc1, doc3):.4f})4.3 生成模型评估在生成对抗网络(GANs)中最优传输距离可以作为评估生成样本质量的指标。我们可以计算生成样本分布和真实数据分布之间的最优传输距离。def evaluate_gans(real_samples, generated_samples, epsilon0.1, n_subsample500): 使用最优传输距离评估GAN生成质量 # 随机子采样以减少计算量 real_samples real_samples[np.random.choice(len(real_samples), n_subsample)] gen_samples generated_samples[np.random.choice(len(generated_samples), n_subsample)] # 计算样本之间的距离矩阵 C np.sqrt(((real_samples[:, None] - gen_samples[None, :])**2).sum(2)) # 均匀分布假设 r np.ones(len(real_samples)) / len(real_samples) c np.ones(len(gen_samples)) / len(gen_samples) # 计算最优传输距离 _, distance sinkhorn(C, r, c, epsilonepsilon) return distance # 示例使用假设real_data和generated_data是numpy数组 # real_data ... # 真实数据样本 # generated_data ... # GAN生成样本 # ot_distance evaluate_gans(real_data, generated_data) # print(f最优传输距离: {ot_distance:.4f})在实际项目中我发现Sinkhorn算法对于中等规模的数据集几千个点计算效率已经相当不错但对于更大规模的问题可能需要考虑近似方法或分布式计算。另一个实用技巧是对于相同代价矩阵的多次计算可以预先计算并缓存核矩阵K这样可以节省大量重复计算时间。

更多文章