计算[-1,1]中c的sqrt((b²*c²)/(1-c²))的数值稳定方法

对于一些实际价值bcin [-1, 1],我需要计算

sqrt( (b²*c²) / (1-c²) ) = (|b|*|c|) / sqrt((1-c)*(1+c))

c接近 1 或 -1时,分母中出现灾难性抵消。平方根可能也无济于事。

我想知道是否有一个聪明的技巧可以在这里应用来避免 c=1 和 c=-1 周围的困难区域?

回答

这种稳定性方面最有趣的部分是分母sqrt(1 - c*c)。为此,您需要做的就是将其扩展为sqrt(1 - c) * sqrt(1 + c). 我不认为这真的是一个“聪明的把戏”,但这就是所需要的。

对于典型的二进制浮点格式(例如 IEEE 754 binary64,但其他常见格式应该表现得同样好,除了像double-double格式这样令人不快的东西),如果c接近1then1 - c将被精确计算,通过Sterbenz' Lemma,虽然1 + c没有任何稳定性问题。同样,如果c接近-1then1 + c将被精确计算,并且1 - c将被精确计算。平方根和乘法运算不会引入重大的新错误。

这是一个数值演示,在具有 IEEE 754 binary64 浮点和正确舍入sqrt运算的机器上使用 Python 。

让我们c接近(但小于)1

>>> c = float.fromhex('0x1.ffffffff24190p-1')
>>> c
0.9999999999

在这里我们必须小心一点:请注意,显示的十进制值0.999999999是的精确值的近似值c。确切值如十六进制字符串的构造所示,或以分数形式显示,562949953365017/562949953421312我们关心的是获得良好结果的确切值。

表达式 的精确值(sqrt(1 - c*c)四舍五入到小数点后 100 位)是:

0.0000141421362084401590649378320134409069878639187055610216016949959890888003204161068184484972504813

我使用 Python 的decimal模块计算了这个,并使用Pari/GP仔细检查了结果。这是Python计算:

>>> from decimal import Decimal, getcontext
>>> getcontext().prec = 1000
>>> good = (1 - Decimal(c) * Decimal(c)).sqrt().quantize(Decimal("1e-100"))
>>> print(good)
0.0000141421362084401590649378320134409069878639187055610216016949959890888003204161068184484972504813

如果我们天真地计算,我们会得到这个结果:

>>> from math import sqrt
>>> naive = sqrt(1 - c*c)
>>> naive
1.4142136208793713e-05

我们可以很容易地计算出 ulps 错误的近似数量(对正在进行的类型转换的数量表示歉意 -float并且Decimal实例不能直接在算术运算中混合):

>>> from math import ulp
>>> float((Decimal(naive) - good) / Decimal(ulp(float(good))))
208701.28298527992

所以朴素的结果是几十万 ulp——粗略地说,我们已经失去了大约 5 个小数位的准确性。

现在让我们尝试使用扩展版本:

>>> better = sqrt(1 - c) * sqrt(1 + c)
>>> better
1.4142136208440158e-05
>>> float((Decimal(better) - good) / Decimal(ulp(float(good))))
-0.7170147200803595

所以在这里我们准确到优于 1 ulp 错误。不完全正确四舍五入,但下一个最好的事情。

通过更多的工作,假设 IEEE 754 二进制浮点,舍入到偶数舍入模式,应该可以声明并证明表达式中 ulps 错误数量的绝对上限sqrt(1 - c) * sqrt(1 + c),在域上-1 < c < 1,和正确四舍五入的操作贯穿始终。我还没有这样做,但如果该上限结果超过 10 ulps,我会感到非常惊讶。

  • Is `sqrt(1-c)*sqrt(1+c)` better than `sqrt((1-c)*(1+c))`?
  • @chtz: Ah, good point. `sqrt((1 - c) * (1 + c))` ought to be even better, since `sqrt` is a contracting operation that tends to reduce relative error. I'll edit in a bit.

回答

Mark Dickinson 为一般情况提供了一个很好的答案,我将使用更专业的方法对其进行补充。

如今,许多计算环境都提供了一种称为融合乘加或简称 FMA 的操作,它是专门针对此类情况而设计的。在fma(a, b, c)全积a * b(未截断和未c舍入)的计算中,与 相加,然后在最后应用一个舍入。

当前出货的 GPU 和 CPU,包括基于 ARM64、x86-64 和 Power 架构的 GPU 和 CPU,通常包括 FMA 的快速硬件实现,它以 C 和 C++ 系列的编程语言以及许多其他语言作为标准公开数学函数fma()。一些(通常是较旧的)软件环境使用 FMA 的软件仿真,并且发现其中一些仿真有问题。此外,这种模拟往往很慢。

在 FMA 可用的情况下,表达式可以在数值上稳定且没有过早上溢和下溢的风险fabs (b * c) / sqrt (fma (c, -c, 1.0)),其中fabs()是浮点操作数的绝对值运算并sqrt()计算平方根。某些环境还提供倒数平方根运算,通常称为rsqrt(),在这种情况下,潜在的替代方法是使用fabs (b * c) * rsqrt (fma (c, -c, 1.0))。使用rsqrt()避免了相对昂贵的除法,因此通常更快。但是,许多 的实现rsqrt()都没有像 那样正确舍入sqrt(),因此准确性可能会差一些。

对下面代码的快速实验似乎表明,只要b正常的浮点数,基于 FMA 的表达式的最大误差约为 3 ulps 。我强调这并不能证明任何错误界限。自动Herbie 工具,它试图找到给定浮点表达式在数值上有利的重写,建议使用fabs (b * c) * sqrt (1.0 / fma (c, -c, 1.0)). 然而,这似乎是一个虚假的结果,因为我既无法想到任何特定的优势,也无法通过实验找到一个。

#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
#include <math.h>

#define USE_ORIGINAL  (0)
#define USE_HERBIE    (1)

/* function under test */
float func (float b, float c)
{
#if USE_HERBIE
     return fabsf (b * c) * sqrtf (1.0f / fmaf (c, -c, 1.0f));
#else USE_HERBIE
     return fabsf (b * c) / sqrtf (fmaf (c, -c, 1.0f));
#endif // USE_HERBIE
}

/* reference */
double funcd (double b, double c)
{
#if USE_ORIGINAL
    double b2 = b * b;
    double c2 = c * c;
    return sqrt ((b2 * c2) / (1.0 - c2));
#else
    return fabs (b * c) / sqrt (fma (c, -c, 1.0));
#endif
}

uint32_t float_as_uint32 (float a)
{
    uint32_t r;
    memcpy (&r, &a, sizeof r);
    return r;
}

float uint32_as_float (uint32_t a)
{
    float r;
    memcpy (&r, &a, sizeof r);
    return r;
}

uint64_t double_as_uint64 (double a)
{
    uint64_t r;
    memcpy (&r, &a, sizeof r);
    return r;
}

double floatUlpErr (float res, double ref)
{
    uint64_t i, j, err, refi;
    int expoRef;
    
    /* ulp error cannot be computed if either operand is NaN, infinity, zero */
    if (isnan (res) || isnan (ref) || isinf (res) || isinf (ref) ||
        (res == 0.0f) || (ref == 0.0f)) {
        return 0.0;
    }
    /* Convert the float result to an "extended float". This is like a float
       with 56 instead of 24 effective mantissa bits.
    */
    i = ((uint64_t)float_as_uint32(res)) << 32;
    /* Convert the double reference to an "extended float". If the reference is
       >= 2^129, we need to clamp to the maximum "extended float". If reference
       is < 2^-126, we need to denormalize because of the float types's limited
       exponent range.
    */
    refi = double_as_uint64(ref);
    expoRef = (int)(((refi >> 52) & 0x7ff) - 1023);
    if (expoRef >= 129) {
        j = 0x7fffffffffffffffULL;
    } else if (expoRef < -126) {
        j = ((refi << 11) | 0x8000000000000000ULL) >> 8;
        j = j >> (-(expoRef + 126));
    } else {
        j = ((refi << 11) & 0x7fffffffffffffffULL) >> 8;
        j = j | ((uint64_t)(expoRef + 127) << 55);
    }
    j = j | (refi & 0x8000000000000000ULL);
    err = (i < j) ? (j - i) : (i - j);
    return err / 4294967296.0;
}

// Fixes via: Greg Rose, KISS: A Bit Too Simple. http://eprint.iacr.org/2011/007
static unsigned int z=362436069,w=521288629,jsr=362436069,jcong=123456789;
#define znew (z=36969*(z&0xffff)+(z>>16))
#define wnew (w=18000*(w&0xffff)+(w>>16))
#define MWC  ((znew<<16)+wnew)
#define SHR3 (jsr^=(jsr<<13),jsr^=(jsr>>17),jsr^=(jsr<<5)) /* 2^32-1 */
#define CONG (jcong=69069*jcong+13579)                     /* 2^32 */
#define KISS ((MWC^CONG)+SHR3)

#define N  (20)

int main (void)
{
    float b, c, errloc_b, errloc_c, res;
    double ref, err, maxerr = 0;
    
    c = -1.0f;
    while (c <= 1.0f) {
        /* try N random values of `b` per every value of `c` */
        for (int i = 0; i < N; i++) {
            /* allow only normals */
            do {
                b = uint32_as_float (KISS);
            } while (!isnormal (b));
            res = func (b, c);
            ref = funcd ((double)b, (double)c);
            err = floatUlpErr (res, ref);
            if (err > maxerr) {
                maxerr = err;
                errloc_b = b;
                errloc_c = c;
            }
        }
        c = nextafterf (c, INFINITY);
    }
#if USE_HERBIE
    printf ("HERBIE max ulp err = %.5f @ (b=% 15.8e c=% 15.8e)n", maxerr, errloc_b, errloc_c);
#else // USE_HERBIE
    printf ("SIMPLE max ulp err = %.5f @ (b=% 15.8e c=% 15.8e)n", maxerr, errloc_b, errloc_c);
#endif // USE_HERBIE
    
    return EXIT_SUCCESS;
}

  • I think it's hard to overstate how generally useful `fma` is for accurate computation. The most obvious advantage is that in the typical case it only rounds once rather than twice, but the bigger impact in my opinion is that the general cause of bad floating point accuracy is the combination of 2 operations, one of which affects relative error, and the next which turns that into an absolute error. `fma` deals with a large class of these problems.

以上是计算[-1,1]中c的sqrt((b²*c²)/(1-c²))的数值稳定方法的全部内容。
THE END
分享
二维码
< <上一篇
下一篇>>