MindSpore从入门到精通:梯度截断、Stop Gradient 与辅助数据梯度处理最佳实践
默认情况下,如果前向函数只返回 loss 一个值,mindspore.grad 只会计算「loss 对指定参数的梯度」,这也是我们训练模型的核心诉求。 但如果前向函数返回多个输出项(如 loss + logits 预测值),MindSpore 的微分函数会默认计算:所有输出项对指定参数的梯度之和,这会导致最终的梯度值失真,与我们需要的「仅 loss 求梯度」的结果不一致! 实战验证:多输出函数的梯度失真问题 运行结果: 结果对比: 只需要对不需要参与梯度计算的输出项(本例中的 z)包裹stop_gradient,即可实现「仅 loss 求梯度」: 运行结果: 结果验证:此时的梯度值与「单输出函数仅返回 loss」的梯度值完全一致,问题完美解决! 运行结果:本文将讲解 MindSpore 中两个高频核心知识点:
问题引入:多输出函数的梯度计算陷阱
# 定义返回 loss + z(预测值) 的多输出函数
def function_with_logits(x, y, w, b):
z = ops.matmul(x, w) + b
loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
return loss, z # 输出项1:loss,输出项2:预测值z
# 生成微分函数,依旧对w(2)、b(3)求导
grad_fn = mindspore.grad(function_with_logits, (2, 3))
grads = grad_fn(x, y, w, b)
print("多输出函数的梯度值:\n", grads)多输出函数的梯度值:
(Tensor(shape=[5, 3], dtype=Float32, value=
[[ 1.32618928e+00, 1.01589143e+00, 1.04216456e+00],
[ 1.32618928e+00, 1.01589143e+00, 1.04216456e+00],
[ 1.32618928e+00, 1.01589143e+00, 1.04216456e+00],
[ 1.32618928e+00, 1.01589143e+00, 1.04216456e+00],
[ 1.32618928e+00, 1.01589143e+00, 1.04216456e+00]]), Tensor(shape=[3], dtype=Float32, value= [ 1.32618928e+00, 1.01589143e+00, 1.04216456e+00]))解决方案一:Stop Gradient 手动梯度截断【核心 API】
Stop Gradient 核心作用
实战:使用 Stop Gradient 修正梯度计算
def function_stop_gradient(x, y, w, b):
z = ops.matmul(x, w) + b
loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
return loss, ops.stop_gradient(z) # 对z进行梯度截断
# 生成微分函数并求梯度
grad_fn = mindspore.grad(function_stop_gradient, (2, 3))
grads = grad_fn(x, y, w, b)
print("梯度截断后的梯度值:\n", grads)梯度截断后的梯度值:
(Tensor(shape=[5, 3], dtype=Float32, value=
[[ 3.26189250e-01, 1.58914644e-02, 4.21645455e-02],
[ 3.26189250e-01, 1.58914644e-02, 4.21645455e-02],
[ 3.26189250e-01, 1.58914644e-02, 4.21645455e-02],
[ 3.26189250e-01, 1.58914644e-02, 4.21645455e-02],
[ 3.26189250e-01, 1.58914644e-02, 4.21645455e-02]]), Tensor(shape=[3], dtype=Float32, value= [ 3.26189250e-01, 1.58914644e-02, 4.21645455e-02]))解决方案二:has_aux=True 自动处理辅助数据【推荐最佳实践】
辅助数据(Auxiliary data)定义
has_aux 参数的核心能力
实战:has_aux=True 优雅实现梯度计算 + 辅助数据返回
# 复用未做任何修改的多输出函数 function_with_logits
def function_with_logits(x, y, w, b):
z = ops.matmul(x, w) + b
loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
return loss, z
# 仅需添加 has_aux=True,无需手动截断梯度
grad_fn = mindspore.grad(function_with_logits, (2, 3), has_aux=True)
grads, (z,) = grad_fn(x, y, w, b) # 解构:梯度 + 辅助数据
print("梯度值(与单输出一致):\n", grads)
print("辅助数据z(预测值):\n", z)梯度值(与单输出一致):
(Tensor(shape=[5, 3], dtype=Float32, value=
[[ 3.26189250e-01, 1.58914644e-02, 4.21645455e-02],
[ 3.26189250e-01, 1.58914644e-02, 4.21645455e-02],
[ 3.26189250e-01, 1.58914644e-02, 4.21645455e-02],
[ 3.26189250e-01, 1.58914644e-02, 4.21645455e-02],
[ 3.26189250e-01, 1.58914644e-02, 4.21645455e-02]]), Tensor(shape=[3], dtype=Float32, value= [ 3.26189250e-01, 1.58914644e-02, 4.21645455e-02]))
辅助数据z(预测值):
[ 3.8211915 -2.994512 -1.932323 ]两大方案对比与选型建议
核心总结