找回密码
立即注册
搜索
热搜: Java Python Linux Go
发回帖 发新帖

3464

积分

0

好友

474

主题
发表于 昨天 04:02 | 查看: 4| 回复: 0

在机器学习和深度学习中,余弦相似度是衡量两个向量方向一致性(即夹角大小)的常用指标。其数学计算公式为两个向量的点积除以它们各自模(范数)的乘积:

[
\text{similarity} = \cos(\theta) = \frac{\mathbf{A} \cdot \mathbf{B}}{|\mathbf{A}| |\mathbf{B}|}
]

其中,|\mathbf{A}| 和 |\mathbf{B}| 分别代表向量 \mathbf{A} 和 \mathbf{B} 的 L2 范数。这个公式简洁明了,但在实际的代码实现中,如果处理不当,可能会导致意想不到的错误。

基础实现与潜在问题

一个直接的 PyTorch 实现可能如下所示:

class CosineSimilarity(nn.Module):
  def forward(self, tensor_0,tensor_1):
    normalized_tensor_0 = tensor_0 / tensor_0.norm(dim = -1,keepdim = True)
    normalized_tensor_1 = tensor_1 / tensor_1.norm(dim = -1,keepdim = True)
    return (normalized_tensor_0 * normalized_tensor_1).sum(dim = -1)

这段代码看起来没有问题,但它隐藏着两个可能导致计算结果无效甚至程序崩溃的安全隐患:

  1. 除零错误 (Division by Zero)tensor.norm() 计算的 L2 范数可能为 0(例如,当向量是一个全零向量时)。在这种情况下,normalized_tensor = tensor / norm_tensor 的计算会出现 0 做分母,导致计算结果为 NaN(Not a Number)。

  2. 数值溢出 (Numerical Overflow)tensor.norm() 计算的 L2 范数可能非常大,以至于超出输入张量数据类型的表示范围,从而得到 inf(无穷大)。那么,在后续的归一化计算中,用 tensor / inf 会得到一个不正确的 0 值,这并非是数学上应有的结果,而是由数值溢出导致的错误。

这两个问题在生产环境的模型推理或训练中非常危险,因为它们会悄无声息地污染损失函数或模型输出,导致训练失败或预测结果不可靠。

更安全的实现方案

为了解决上述问题,我们需要一个更健壮的余弦相似度计算方案。核心思路是:

  • 规避零值:在计算归一化时,确保分母不为零。
  • 防止溢出:对于可能溢出的范数计算,考虑进行数据类型提升。

一种改进的实现如下所示:

class CosineSimilarity(nn.Module):
  def forward(self, tensor_0, tensor_1):
    norm_tensor_0 = tensor_0.norm(dim = -1, keepdim = True)
    norm_tensor_1 = tensor_1.norm(dim = -1, keepdim = True)
    norm_tensor_0 = norm_tensor_0.numpy()
    norm_tensor_1 = norm_tensor_1.numpy()
    for  i, vec2 in enumerate(norm_tensor_0[0]):
      if vec2 == 0:
        norm_tensor_0[0][i] = 1
    for i, vec2 in enumerate(norm_tensor_1[0]):
      if vec2 == 0:
        norm_tensor_1[0][i] = 1
    norm_tensor_0 = torch.tensor(norm_tensor_0)
    norm_tensor_1 = torch.tensor(norm_tensor_1)
    normalized_tensor_0 = tensor_0 / norm_tensor_0
    normalized_tensor_1 = tensor_1 / norm_tensor_1
    return (normalized_tensor_0*normalized_tensor_1).sum(dim = -1)

这个实现通过将张量转换为 NumPy 数组,手动检查并将范数中的零值替换为 1,从而避免了除零错误。将分母置为1意味着对于零向量,其“归一化”结果仍是零向量,最终其与任何向量的余弦相似度将被正确地计算为 0。处理数值溢出可能需要更复杂的策略,例如在计算范数前进行张量缩放或使用更高精度的数据类型,这里展示了处理零值的基本思路。

在进行复杂的深度学习项目开发时,关注这些底层计算的健壮性至关重要。看似微小的数值问题,往往会导致模型在训练后期难以排查的失败。希望这个关于 torch.cosine_similarity 安全计算的讨论,能帮助你在构建更稳定的模型时多一份考量。更多关于实用编程技巧和问题排查的经验,欢迎在 云栈社区 与大家交流探讨。




上一篇:Java分库分表实战:业务拆分与范围分片策略对比,附小红书面试经验
下一篇:CPU、GPU、APU核心区别解析:从原理到实际应用场景选择指南
您需要登录后才可以回帖 登录 | 立即注册

手机版|小黑屋|网站地图|云栈社区 ( 苏ICP备2022046150号-2 )

GMT+8, 2026-2-23 10:27 , Processed in 0.557692 second(s), 41 queries , Gzip On.

Powered by Discuz! X3.5

© 2025-2026 云栈社区.

快速回复 返回顶部 返回列表