Discussion: Some improvements to numeric sqrt() and ln()


Some improvements to numeric sqrt() and ln()

From
Dean Rasheed
Date:
Attached is a WIP patch to improve the performance of numeric sqrt()
and ln(), which also makes a couple of related improvements to
div_var_fast(), all of which have knock-on benefits for other numeric
functions. The actual impact varies greatly depending on the inputs,
but the overall effect is to reduce the run time of the numeric_big
regression test by about 20%.

Additionally it improves the accuracy of sqrt() -- currently sqrt()
sometimes rounds the last digit of the result the wrong way, for
example sqrt(100000000000000010000000000000000) returns
10000000000000001, when the correct answer should be 10000000000000000
to zero decimal places. With this patch, sqrt() guarantees to return
the result correctly rounded to the last digit for all inputs.
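
The rounding claim in this example can be confirmed by hand. The input is n = 10^32 + 10^16, and it lies strictly between the squares of 10^16 and 10^16 + 1/2:

```latex
n = 10^{32} + 10^{16}, \qquad
\left(10^{16}\right)^2 = 10^{32} \le n
  < 10^{32} + 10^{16} + \tfrac{1}{4} = \left(10^{16} + \tfrac{1}{2}\right)^2
\;\Longrightarrow\;
10^{16} \le \sqrt{n} < 10^{16} + \tfrac{1}{2}
```

Rounded to zero decimal places, the square root is therefore 10^16 = 10000000000000000, not 10000000000000001.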

The main change is to sqrt_var(), which now uses a different algorithm
[1] for better performance than the Newton-Raphson method. Actually
I've re-cast the algorithm from [1] into an iterative form, rather
than doing it recursively, as it's presented in that paper. This
improves performance further, by avoiding overheads from function
calls and copying numeric variables around. Also, IMO, the iterative
form of the algorithm is much more elegant, since it works by making a
single pass over the input digits, consuming them one at a time from
most significant to least, producing a succession of increasingly
accurate approximations to the square root, until the desired
precision is reached.
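
To show the iterative form in isolation, here is a minimal C sketch of the same idea applied to a plain decimal integer of at most 18 digits. It is not the patch's code, and a few choices are assumptions made for illustration: `karatsuba_sqrtrem()` is a hypothetical name, each "NBASE digit" is a single decimal digit (`SKETCH_NBASE = 10`, whereas numeric.c uses NBASE = 10000), and `int64_t` stands in for NumericVar. The sketch plans the iterations the same way, with `nd[]` holding the number of input digits consumed after each iteration in reverse order. It then computes the innermost root of at most 4 digits by brute force and consumes 2*blen further digits per iteration using the SqrtRem recurrence from [1]. Because b = 10^blen is not a multiple of 4, the "a3 >= b/4" guard tests `digits[0] * 4 < NBASE` rather than `digits[0] < NBASE / 4`.

```c
#include <stdint.h>

#define SKETCH_NBASE 10         /* assumption: one decimal digit per "NBASE digit" */

/*
 * Hypothetical sketch: integer square root and remainder of the decimal
 * number digits[0..ndigits-1] (most significant first, digits[0] != 0,
 * ndigits <= 18 so that all intermediate values fit in int64_t).
 * Returns s = floor(sqrt(n)) and stores r = n - s*s in *rem.
 */
static int64_t
karatsuba_sqrtrem(const int *digits, int ndigits, int64_t *rem)
{
    int         nd[32];         /* input digits consumed after each iteration */
    int         step = 0;
    int         src_ndigits = ndigits;
    int         src_idx;
    int64_t     n0 = 0;
    int64_t     s;
    int64_t     r;

    /* Plan the iterations; nd[] ends up in reverse order */
    while ((nd[step] = src_ndigits) > 4)
    {
        int         blen = src_ndigits / 4;

        /* If a3 would have only blen digits, make sure a3 >= b/4 */
        if (blen * 4 == src_ndigits && digits[0] * 4 < SKETCH_NBASE)
            blen--;
        src_ndigits -= 2 * blen;    /* size of the inner square root */
        step++;
    }

    /* Innermost square root of at most 4 digits: brute force suffices */
    for (src_idx = 0; src_idx < src_ndigits; src_idx++)
        n0 = n0 * SKETCH_NBASE + digits[src_idx];
    s = 0;
    while ((s + 1) * (s + 1) <= n0)
        s++;
    r = n0 - s * s;

    /* Outer iterations: each one consumes the next 2*blen input digits */
    while (--step >= 0)
    {
        int         blen = (nd[step] - src_idx) / 2;
        int64_t     b = 1, a1 = 0, a0 = 0, numer, q, u;

        for (int i = 0; i < blen; i++, src_idx++)
        {
            b *= SKETCH_NBASE;
            a1 = a1 * SKETCH_NBASE + digits[src_idx];
        }
        for (int i = 0; i < blen; i++, src_idx++)
            a0 = a0 * SKETCH_NBASE + digits[src_idx];

        /* (q,u) = DivRem(r*b + a1, 2*s) */
        numer = r * b + a1;
        q = numer / (2 * s);
        u = numer % (2 * s);

        /* s = s*b + q and r = u*b + a0 - q^2 */
        s = s * b + q;
        r = u * b + a0 - q * q;
        if (r < 0)
        {
            /* s is too large by 1 */
            r += s;
            s--;
            r += s;
        }
    }

    *rem = r;
    return s;
}
```

For example, for the digits of 12345678 the plan is nd[] = {8, 6, 4}. The innermost step gives (35, 9) for 1234, the next gives (351, 255) for 123456, and the last gives (3513, 4509), and indeed 3513^2 + 4509 = 12345678.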

For inputs with a handful of digits, this is typically 3-5 times
faster, and for inputs with more digits the performance improvement is
larger (e.g. sqrt(2e131071) is around 10 times faster). If the input
is a perfect square, with a result having a lot of trailing zeros, the
new algorithm is much faster because it basically has nothing to do in
later iterations (e.g., sqrt(64e13070) is about 600 times faster).

Another change to sqrt_var() is that it now explicitly supports a
negative rscale, i.e., rounding before the decimal point. This is
exploited by ln_var() in its argument reduction stage -- ln_var()
reduces all inputs to the range (0.9, 1.1) by repeatedly taking the
square root. For very large inputs this can have an enormous impact,
for example log(1e131071) currently takes about 6.5 seconds on my
machine, whereas with this patch I can run it 1000 times in a plpgsql
loop in about 90ms, so it's around 70,000 times faster in that case. Of
course, that's an extreme example, and for most inputs it's a much
more modest difference (e.g., ln(2) is about 1.5 times faster).

In passing, I also made a couple of optimisations to div_var_fast(),
discovered while comparing its performance with div_var() for various
inputs.

It's possible that there are further gains to be had in the sqrt()
algorithm on platforms that support 128-bit integers, but I haven't
had a chance to investigate that yet.

Regards,
Dean

[1] https://hal.inria.fr/inria-00072854/document


Re: Some improvements to numeric sqrt() and ln()

From
Dean Rasheed
Date:
On Fri, 28 Feb 2020 at 08:15, Dean Rasheed <dean.a.rasheed@gmail.com> wrote:
>
> It's possible that there are further gains to be had in the sqrt()
> algorithm on platforms that support 128-bit integers, but I haven't
> had a chance to investigate that yet.
>

Rebased patch attached, now using 128-bit integers for part of
sqrt_var() on platforms that support them. This turned out to be well
worth it (1.5 to 2 times faster than the previous version if the
result has less than 30 or 40 digits).

Regards,
Dean


Re: Some improvements to numeric sqrt() and ln()

From
Tels
Date:
Dear Dean,

On 2020-03-01 20:47, Dean Rasheed wrote:
> On Fri, 28 Feb 2020 at 08:15, Dean Rasheed <dean.a.rasheed@gmail.com> wrote:
>> 
>> It's possible that there are further gains to be had in the sqrt()
>> algorithm on platforms that support 128-bit integers, but I haven't
>> had a chance to investigate that yet.
>> 
> 
> Rebased patch attached, now using 128-bit integers for part of
> sqrt_var() on platforms that support them. This turned out to be well
> worth it (1.5 to 2 times faster than the previous version if the
> result has less than 30 or 40 digits).

Thank you for these patches, these sound like really nice improvements.
One thing came to my mind while reading the patch:

+     *        If r < 0 Then
+     *            Let r = r + 2*s - 1
+     *            Let s = s - 1

+            /* s is too large by 1; let r = r + 2*s - 1 and s = s - 1 */
+            r_int64 += 2 * s_int64 - 1;
+            s_int64--;

This can be reformulated as:

+     *        If r < 0 Then
+     *            Let r = r + s
+     *            Let s = s - 1
+     *            Let r = r + s

+            /* s is too large by 1; let r = r + 2*s - 1 and s = s - 1 */
+            r_int64 += s_int64;
+            s_int64--;
+            r_int64 += s_int64;

which would remove one mul/shift and the temp. variable. Mind you, I have
not benchmarked this, so it might make little difference, but maybe it is
worth trying it.
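
Both forms leave n = s*s + r unchanged: since (s-1)^2 = s^2 - 2s + 1, decrementing s requires r to grow by 2s - 1, and that is exactly s + (s - 1). A small C sketch (hypothetical helper names, with `int64_t` standing in for the patch's variables) of the two correction steps side by side:

```c
#include <stdint.h>

/* Correction step as first written in the patch: r += 2*s - 1; s-- */
static void
adjust_original(int64_t *s, int64_t *r)
{
    *r += 2 * *s - 1;
    (*s)--;
}

/* Reformulated step: r += s; s--; r += s (no multiply, no temporary) */
static void
adjust_reformulated(int64_t *s, int64_t *r)
{
    *r += *s;
    (*s)--;
    *r += *s;
}
```

For any starting (s, r), both functions produce the same new (s, r), and s*s + r is the same before and after.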

Best regards,

Tels

Re: Some improvements to numeric sqrt() and ln()

From
Dean Rasheed
Date:
On Tue, 3 Mar 2020 at 00:17, Tels <nospam-pg-abuse@bloodgate.com> wrote:
>
> Thank you for these patches, these sound like really nice improvements.

Thanks for looking!

> One thing came to my mind while reading the patch:
>
> +        *              If r < 0 Then
> +        *                      Let r = r + 2*s - 1
> +        *                      Let s = s - 1
>
> This can be reformulated as:
>
> +        *              If r < 0 Then
> +        *                      Let r = r + s
> +        *                      Let s = s - 1
> +        *                      Let r = r + s
>
> which would remove one mul/shift and the temp. variable.

Good point, that's a neat little optimisation.

I wasn't able to detect any difference in performance, because those
corrections are only triggered about 1 time in every 50 or so, but it
looks neater to me, especially in the numeric iterations, where it
saves a sub_var() by const_one as well as not using the temporary
variable.

Regards,
Dean



Re: Some improvements to numeric sqrt() and ln()

From
David Steele
Date:
Hi Dean,

On 2/28/20 3:15 AM, Dean Rasheed wrote:
> Attached is a WIP patch to improve the performance of numeric sqrt()
> and ln(), which also makes a couple of related improvements to
> div_var_fast(), all of which have knock-on benefits for other numeric
> functions. The actual impact varies greatly depending on the inputs,
> but the overall effect is to reduce the run time of the numeric_big
> regression test by about 20%.

Are these improvements targeted at PG13 or PG14?  This seems a pretty 
big change for the last CF of PG13.

Regards,
-- 
-David
david@pgmasters.net



Re: Some improvements to numeric sqrt() and ln()

From
Dean Rasheed
Date:
On Wed, 4 Mar 2020 at 14:41, David Steele <david@pgmasters.net> wrote:
>
> Are these improvements targeted at PG13 or PG14?  This seems a pretty
> big change for the last CF of PG13.
>

Well of course that's not entirely up to me, but I was hoping to
commit it for PG13.

It's very well covered by a large number of regression tests in both
numeric.sql and numeric_big.sql, since nearly anything that calls
ln(), log() or pow() ends up going through sqrt_var(). Also, the
changes are local to functions in numeric.c, which makes them easy to
revert if something proves to be wrong.

Regards,
Dean



Re: Some improvements to numeric sqrt() and ln()

From
Tom Lane
Date:
Dean Rasheed <dean.a.rasheed@gmail.com> writes:
> On Wed, 4 Mar 2020 at 14:41, David Steele <david@pgmasters.net> wrote:
>> Are these improvements targeted at PG13 or PG14?  This seems a pretty
>> big change for the last CF of PG13.

> Well of course that's not entirely up to me, but I was hoping to
> commit it for PG13.

> It's very well covered by a large number of regression tests in both
> numeric.sql and numeric_big.sql, since nearly anything that calls
> ln(), log() or pow() ends up going through sqrt_var(). Also, the
> changes are local to functions in numeric.c, which makes them easy to
> revert if something proves to be wrong.

FWIW, I agree that this is a reasonable thing to consider committing
for v13.  It's not adding any new user-visible behavior, so there are
no definitional issues to quibble over, which is usually what I worry
about regretting after an overly-hasty commit.  And it's only touching
a few functions in one file, so even if the patch is a bit long, the
complexity seems pretty well controlled.

I've not read the patch in detail so this isn't meant as a review,
but from a process standpoint I see no reason not to go forward.

            regards, tom lane



Re: Some improvements to numeric sqrt() and ln()

From
Tom Lane
Date:
Tels <nospam-pg-abuse@bloodgate.com> writes:
> This can be reformulated as:
> +     *        If r < 0 Then
> +     *            Let r = r + s
> +     *            Let s = s - 1
> +     *            Let r = r + s

Here's a v3 that

* incorporates Tels' idea;

* improves some of the comments (IMO anyway, though some are just fixes for clear typos);

* adds some XXX comments about things that could be further improved
and/or need better explanations.

I also ran it through pgindent, just cause I'm like that.

With resolutions of the XXX items, I think this'd be committable.
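
Two of the XXX items in the patch ask whether floor((double) ...) and ceil((double) ...) are really needed to compute res_weight and res_ndigits in sqrt_var(). A minimal sketch of integer-only equivalents follows. It assumes a positive divisor and C99 integer division, which truncates toward zero; `floor_div()` and `ceil_div()` are hypothetical helper names, not functions that exist in numeric.c.

```c
/* Hypothetical helper: floor(a / b) for b > 0, using only integer math */
static int
floor_div(int a, int b)
{
    int         q = a / b;      /* C99: truncates toward zero */

    if (a % b != 0 && a < 0)
        q--;                    /* truncation rounded a negative quotient up */
    return q;
}

/* Hypothetical helper: ceil(a / b) for b > 0, using only integer math */
static int
ceil_div(int a, int b)
{
    int         q = a / b;

    if (a % b != 0 && a > 0)
        q++;                    /* truncation rounded a positive quotient down */
    return q;
}
```

With these helpers, the two assignments could read `res_weight = floor_div(arg->weight, 2)` and `res_ndigits = res_weight + 1 + ceil_div(rscale + 1, DEC_DIGITS)`. This is only a suggestion for resolving the XXX items, not part of the posted patch.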

            regards, tom lane

diff --git a/src/backend/utils/adt/numeric.c b/src/backend/utils/adt/numeric.c
index 10229eb..afbc2b0 100644
--- a/src/backend/utils/adt/numeric.c
+++ b/src/backend/utils/adt/numeric.c
@@ -393,16 +393,6 @@ static const NumericVar const_ten =
 #endif

 #if DEC_DIGITS == 4
-static const NumericDigit const_zero_point_five_data[1] = {5000};
-#elif DEC_DIGITS == 2
-static const NumericDigit const_zero_point_five_data[1] = {50};
-#elif DEC_DIGITS == 1
-static const NumericDigit const_zero_point_five_data[1] = {5};
-#endif
-static const NumericVar const_zero_point_five =
-{1, -1, NUMERIC_POS, 1, NULL, (NumericDigit *) const_zero_point_five_data};
-
-#if DEC_DIGITS == 4
 static const NumericDigit const_zero_point_nine_data[1] = {9000};
 #elif DEC_DIGITS == 2
 static const NumericDigit const_zero_point_nine_data[1] = {90};
@@ -518,6 +508,8 @@ static void div_var_fast(const NumericVar *var1, const NumericVar *var2,
 static int    select_div_scale(const NumericVar *var1, const NumericVar *var2);
 static void mod_var(const NumericVar *var1, const NumericVar *var2,
                     NumericVar *result);
+static void div_mod_var(const NumericVar *var1, const NumericVar *var2,
+                        NumericVar *quot, NumericVar *rem);
 static void ceil_var(const NumericVar *var, NumericVar *result);
 static void floor_var(const NumericVar *var, NumericVar *result);

@@ -7712,6 +7704,7 @@ div_var_fast(const NumericVar *var1, const NumericVar *var2,
              NumericVar *result, int rscale, bool round)
 {
     int            div_ndigits;
+    int            load_ndigits;
     int            res_sign;
     int            res_weight;
     int           *div;
@@ -7766,9 +7759,6 @@ div_var_fast(const NumericVar *var1, const NumericVar *var2,
     div_ndigits += DIV_GUARD_DIGITS;
     if (div_ndigits < DIV_GUARD_DIGITS)
         div_ndigits = DIV_GUARD_DIGITS;
-    /* Must be at least var1ndigits, too, to simplify data-loading loop */
-    if (div_ndigits < var1ndigits)
-        div_ndigits = var1ndigits;

     /*
      * We do the arithmetic in an array "div[]" of signed int's.  Since
@@ -7781,9 +7771,16 @@ div_var_fast(const NumericVar *var1, const NumericVar *var2,
      * (approximate) quotient digit and stores it into div[], removing one
      * position of dividend space.  A final pass of carry propagation takes
      * care of any mistaken quotient digits.
+     *
+     * Note that div[] doesn't necessarily contain all of the digits from the
+     * dividend --- the desired precision plus guard digits might be less than
+     * the dividend's precision.  This happens, for example, in the square
+     * root algorithm, where we typically divide a 2N-digit number by an
+     * N-digit number, and only require a result with N digits of precision.
      */
     div = (int *) palloc0((div_ndigits + 1) * sizeof(int));
-    for (i = 0; i < var1ndigits; i++)
+    load_ndigits = Min(div_ndigits, var1ndigits);
+    for (i = 0; i < load_ndigits; i++)
         div[i + 1] = var1digits[i];

     /*
@@ -7844,9 +7841,15 @@ div_var_fast(const NumericVar *var1, const NumericVar *var2,
             maxdiv += Abs(qdigit);
             if (maxdiv > (INT_MAX - INT_MAX / NBASE - 1) / (NBASE - 1))
             {
-                /* Yes, do it */
+                /*
+                 * Yes, do it.  Note that if var2ndigits is much smaller than
+                 * div_ndigits, we can save a significant amount of effort
+                 * here by noting that we only need to normalise those div[]
+                 * entries touched where prior iterations subtracted multiples
+                 * of the divisor.
+                 */
                 carry = 0;
-                for (i = div_ndigits; i > qi; i--)
+                for (i = Min(qi + var2ndigits - 2, div_ndigits); i > qi; i--)
                 {
                     newdig = div[i] + carry;
                     if (newdig < 0)
@@ -8095,6 +8098,76 @@ mod_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result)


 /*
+ * div_mod_var() -
+ *
+ *    Calculate the truncated integer quotient and numeric remainder of two
+ *    numeric variables.  The remainder is precise to var2's dscale.
+ */
+static void
+div_mod_var(const NumericVar *var1, const NumericVar *var2,
+            NumericVar *quot, NumericVar *rem)
+{
+    NumericVar    q;
+    NumericVar    r;
+
+    init_var(&q);
+    init_var(&r);
+
+    /*
+     * Use div_var_fast() to get an initial estimate for the integer quotient.
+     * This might be inaccurate (per the warning in div_var_fast's comments),
+     * but we can correct it below.
+     */
+    div_var_fast(var1, var2, &q, 0, false);
+
+    /* Compute initial estimate of remainder using the quotient estimate. */
+    mul_var(var2, &q, &r, var2->dscale);
+    sub_var(var1, &r, &r);
+
+    /*
+     * Adjust the results if necessary --- the remainder should have the same
+     * sign as var1, and its absolute value should be less than the absolute
+     * value of var2.
+     */
+    while (r.ndigits != 0 && r.sign != var1->sign)
+    {
+        /* The absolute value of the quotient is too large */
+        if (var1->sign == var2->sign)
+        {
+            sub_var(&q, &const_one, &q);
+            add_var(&r, var2, &r);
+        }
+        else
+        {
+            add_var(&q, &const_one, &q);
+            sub_var(&r, var2, &r);
+        }
+    }
+
+    while (cmp_abs(&r, var2) >= 0)
+    {
+        /* The absolute value of the quotient is too small */
+        if (var1->sign == var2->sign)
+        {
+            add_var(&q, &const_one, &q);
+            sub_var(&r, var2, &r);
+        }
+        else
+        {
+            sub_var(&q, &const_one, &q);
+            add_var(&r, var2, &r);
+        }
+    }
+
+    set_var_from_var(&q, quot);
+    set_var_from_var(&r, rem);
+
+    free_var(&q);
+    free_var(&r);
+}
+
+
+/*
  * ceil_var() -
  *
  *    Return the smallest integer greater than or equal to the argument
@@ -8213,18 +8286,30 @@ gcd_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result)
 /*
  * sqrt_var() -
  *
- *    Compute the square root of x using Newton's algorithm
+ *    Compute the square root of x using the Karatsuba Square Root algorithm.
+ *    NOTE: we allow rscale < 0 here, implying rounding before the decimal
+ *    point.
  */
 static void
 sqrt_var(const NumericVar *arg, NumericVar *result, int rscale)
 {
-    NumericVar    tmp_arg;
-    NumericVar    tmp_val;
-    NumericVar    last_val;
-    int            local_rscale;
     int            stat;
-
-    local_rscale = rscale + 8;
+    int            res_weight;
+    int            res_ndigits;
+    int            src_ndigits;
+    int            step;
+    int            ndigits[32];
+    int            blen;
+    int64        arg_int64;
+    int            src_idx;
+    int64        s_int64;
+    int64        r_int64;
+    NumericVar    s_var;
+    NumericVar    r_var;
+    NumericVar    a0_var;
+    NumericVar    a1_var;
+    NumericVar    q_var;
+    NumericVar    u_var;

     stat = cmp_var(arg, &const_zero);
     if (stat == 0)
@@ -8243,43 +8328,412 @@ sqrt_var(const NumericVar *arg, NumericVar *result, int rscale)
                 (errcode(ERRCODE_INVALID_ARGUMENT_FOR_POWER_FUNCTION),
                  errmsg("cannot take square root of a negative number")));

-    init_var(&tmp_arg);
-    init_var(&tmp_val);
-    init_var(&last_val);
+    init_var(&s_var);
+    init_var(&r_var);
+    init_var(&a0_var);
+    init_var(&a1_var);
+    init_var(&q_var);
+    init_var(&u_var);

-    /* Copy arg in case it is the same var as result */
-    set_var_from_var(arg, &tmp_arg);
+    /*
+     * The result weight is half the input weight, rounded towards minus
+     * infinity.
+     *
+     * XXX do we really need floor(double) for that, rather than plain integer
+     * math?
+     */
+    res_weight = (int) floor((double) arg->weight / 2);

     /*
-     * Initialize the result to the first guess
+     * Number of NBASE digits to compute.  To ensure correct rounding, compute
+     * at least 1 extra decimal digit.  We explicitly allow rscale to be
+     * negative here, but must always compute at least 1 NBASE digit.
+     *
+     * XXX likewise seems like ceil(double) is unnecessary expense.
      */
-    alloc_var(result, 1);
-    result->digits[0] = tmp_arg.digits[0] / 2;
-    if (result->digits[0] == 0)
-        result->digits[0] = 1;
-    result->weight = tmp_arg.weight / 2;
-    result->sign = NUMERIC_POS;
+    res_ndigits = res_weight + 1 + (int) ceil((double) (rscale + 1) / DEC_DIGITS);
+    res_ndigits = Max(res_ndigits, 1);

-    set_var_from_var(result, &last_val);
+    /*
+     * Number of source NBASE digits logically required to produce a result
+     * with this precision --- every digit before the decimal point, plus 2
+     * for each result digit after the decimal point (or minus 2 for each
+     * result digit we round before the decimal point).
+     */
+    src_ndigits = arg->weight + 1 + (res_ndigits - res_weight - 1) * 2;
+    src_ndigits = Max(src_ndigits, 1);

-    for (;;)
+    /* ----------
+     * From this point on, we treat the input and the result as integers and
+     * compute the integer square root and remainder using the Karatsuba
+     * Square Root algorithm, which may be written recursively as follows:
+     *
+     *    SqrtRem(n = a3*b^3 + a2*b^2 + a1*b + a0):
+     *        [ for some base b, and coefficients a0,a1,a2,a3 chosen so that
+     *          0 <= a0,a1,a2 < b and a3 >= b/4 ]
+     *        Let (s,r) = SqrtRem(a3*b + a2)
+     *        Let (q,u) = DivRem(r*b + a1, 2*s)
+     *        Let s = s*b + q
+     *        Let r = u*b + a0 - q^2
+     *        If r < 0 Then
+     *            Let r = r + s
+     *            Let s = s - 1
+     *            Let r = r + s
+     *        Return (s,r)
+     *
+     * See "Karatsuba Square Root", Paul Zimmermann, INRIA Research Report
+     * RR-3805, November 1999.  At the time of writing this was available
+     * on the net at <https://hal.inria.fr/inria-00072854>.
+     *
+     * The way to read the assumption "n = a3*b^3 + a2*b^2 + a1*b + a0" is
+     * "choose a base b such that n requires at least four base-b digits to
+     * express; then those digits are a3,a2,a1,a0, with a3 possibly larger
+     * than b".  For optimal performance, b should have approximately a
+     * quarter the number of digits in the input, so that the outer square
+     * root computes roughly twice as many digits as the inner one.  For
+     * simplicity, we choose b = NBASE^blen, an integer power of NBASE.
+     *
+     * We implement the algorithm iteratively rather than recursively, to
+     * allow the working variables to be reused.  With this approach, each
+     * digit of the input is read precisely once --- src_idx tracks the number
+     * of input digits used so far.
+     *
+     * The array ndigits[] holds the number of NBASE digits of the input that
+     * will have been used at the end of each iteration, which roughly doubles
+     * each time.  Note that the array elements are stored in reverse order,
+     * so if the final iteration requires src_ndigits = 37 input digits, the
+     * array will contain [37,19,11,7,5,3], and we would start by computing
+     * the square root of the 3 most significant NBASE digits.
+     *
+     * XXX I don't understand how this works.  Why is it correct to consider
+     * arg->digits[0] at every step?  Can we prove rigorously that the ndigits
+     * array won't be overrun?  (I can see that src_ndigits is roughly halved
+     * by each iteration, but only roughly, so it's not entirely clear that
+     * the worst-case situation couldn't involve more than 31 steps.)
+     * ----------
+     */
+    step = 0;
+    while ((ndigits[step] = src_ndigits) > 4)
     {
-        div_var_fast(&tmp_arg, result, &tmp_val, local_rscale, true);
+        /* Choose b so that a3 >= b/4 */
+        blen = src_ndigits / 4;
+        if (blen * 4 == src_ndigits && arg->digits[0] < NBASE / 4)
+            blen--;

-        add_var(result, &tmp_val, result);
-        mul_var(result, &const_zero_point_five, result, local_rscale);
+        /* Number of digits in the next step (inner square root) */
+        src_ndigits -= 2 * blen;
+        step++;
+    }

-        if (cmp_var(&last_val, result) == 0)
-            break;
-        set_var_from_var(result, &last_val);
+    /*
+     * First iteration (innermost square root and remainder):
+     *
+     * Here src_ndigits <= 4, and the input fits in an int64.  Its square root
+     * has at most 9 decimal digits, so estimate it using double precision
+     * arithmetic, which will in fact almost certainly return the correct
+     * result with no further correction required.
+     */
+    arg_int64 = arg->digits[0];
+    for (src_idx = 1; src_idx < src_ndigits; src_idx++)
+    {
+        arg_int64 *= NBASE;
+        if (src_idx < arg->ndigits)
+            arg_int64 += arg->digits[src_idx];
     }

-    free_var(&last_val);
-    free_var(&tmp_val);
-    free_var(&tmp_arg);
+    s_int64 = (int64) sqrt((double) arg_int64);
+    r_int64 = arg_int64 - s_int64 * s_int64;
+
+    /* Use Newton's method to correct the result, if necessary */
+    /* XXX is this guaranteed to converge?  integer division truncates... */
+    while (r_int64 < 0 || r_int64 > 2 * s_int64)
+    {
+        s_int64 = (s_int64 + arg_int64 / s_int64) / 2;
+        r_int64 = arg_int64 - s_int64 * s_int64;
+    }
+
+    /*
+     * Iterations with src_ndigits <= 8:
+     *
+     * The next 1 or 2 iterations compute larger (outer) square roots with
+     * src_ndigits <= 8, so the result still fits in an int64 (even though the
+     * input no longer does) and we can continue to compute using int64
+     * variables to avoid more expensive numeric computations.
+     *
+     * It is fairly easy to see that there is no risk of the intermediate
+     * values below overflowing 64-bit integers.  In the worst case, the
+     * previous iteration will have computed a 3-digit square root (of a
+     * 6-digit input less than NBASE^6 / 4), so at the start of this
+     * iteration, s will be less than NBASE^3 / 2 = 10^12 / 2, and r will be
+     * less than 10^12.  In this case, blen will be 1, so numer will be less
+     * than 10^17, and denom will be less than 10^12 (and hence u will also be
+     * less than 10^12).  Finally, since q^2 = u*b + a0 - r, we can also be
+     * sure that q^2 < 10^17.  Therefore all these quantities fit comfortably
+     * in 64-bit integers.
+     */
+    step--;
+    while (step >= 0 && (src_ndigits = ndigits[step]) <= 8)
+    {
+        int            b;
+        int            a0;
+        int            a1;
+        int            i;
+        int64        numer;
+        int64        denom;
+        int64        q;
+        int64        u;
+
+        blen = (src_ndigits - src_idx) / 2;
+
+        /* Extract a1 and a0, and compute b */
+        a0 = 0;
+        a1 = 0;
+        b = 1;
+
+        for (i = 0; i < blen; i++, src_idx++)
+        {
+            b *= NBASE;
+            a1 *= NBASE;
+            if (src_idx < arg->ndigits)
+                a1 += arg->digits[src_idx];
+        }
+
+        for (i = 0; i < blen; i++, src_idx++)
+        {
+            a0 *= NBASE;
+            if (src_idx < arg->ndigits)
+                a0 += arg->digits[src_idx];
+        }

-    /* Round to requested precision */
+        /* Compute (q,u) = DivRem(r*b + a1, 2*s) */
+        numer = r_int64 * b + a1;
+        denom = 2 * s_int64;
+        q = numer / denom;
+        u = numer - q * denom;
+
+        /* Compute s = s*b + q and r = u*b + a0 - q^2 */
+        s_int64 = s_int64 * b + q;
+        r_int64 = u * b + a0 - q * q;
+
+        if (r_int64 < 0)
+        {
+            /* s is too large by 1; set r += s, s--, r += s */
+            r_int64 += s_int64;
+            s_int64--;
+            r_int64 += s_int64;
+        }
+
+        Assert(src_idx == src_ndigits); /* All input digits consumed */
+        step--;
+    }
+
+    /*
+     * On platforms with 128-bit integer support, we can further delay the
+     * need to use numeric variables.
+     */
+#ifdef HAVE_INT128
+    if (step >= 0)
+    {
+        int128        s_int128;
+        int128        r_int128;
+
+        s_int128 = s_int64;
+        r_int128 = r_int64;
+
+        /*
+         * Iterations with src_ndigits <= 16:
+         *
+         * The result fits in an int128 (even though the input doesn't) so we
+         * use int128 variables to avoid more expensive numeric computations.
+         */
+        while (step >= 0 && (src_ndigits = ndigits[step]) <= 16)
+        {
+            int64        b;
+            int64        a0;
+            int64        a1;
+            int64        i;
+            int128        numer;
+            int128        denom;
+            int128        q;
+            int128        u;
+
+            blen = (src_ndigits - src_idx) / 2;
+
+            /* Extract a1 and a0, and compute b */
+            a0 = 0;
+            a1 = 0;
+            b = 1;
+
+            for (i = 0; i < blen; i++, src_idx++)
+            {
+                b *= NBASE;
+                a1 *= NBASE;
+                if (src_idx < arg->ndigits)
+                    a1 += arg->digits[src_idx];
+            }
+
+            for (i = 0; i < blen; i++, src_idx++)
+            {
+                a0 *= NBASE;
+                if (src_idx < arg->ndigits)
+                    a0 += arg->digits[src_idx];
+            }
+
+            /* Compute (q,u) = DivRem(r*b + a1, 2*s) */
+            numer = r_int128 * b + a1;
+            denom = 2 * s_int128;
+            q = numer / denom;
+            u = numer - q * denom;
+
+            /* Compute s = s*b + q and r = u*b + a0 - q^2 */
+            s_int128 = s_int128 * b + q;
+            r_int128 = u * b + a0 - q * q;
+
+            if (r_int128 < 0)
+            {
+                /* s is too large by 1; set r += s, s--, r += s */
+                r_int128 += s_int128;
+                s_int128--;
+                r_int128 += s_int128;
+            }
+
+            Assert(src_idx == src_ndigits); /* All input digits consumed */
+            step--;
+        }
+
+        /*
+         * All remaining iterations require numeric variables.  Convert the
+         * integer values to NumericVar and continue.  Note that in the final
+         * iteration we don't need the remainder, so we can save a few cycles
+         * there by not fully computing it.
+         */
+        int128_to_numericvar(s_int128, &s_var);
+        if (step >= 0)
+            int128_to_numericvar(r_int128, &r_var);
+    }
+    else
+    {
+        int64_to_numericvar(s_int64, &s_var);
+        /* step < 0, so we certainly don't need r */
+    }
+#else                            /* !HAVE_INT128 */
+    int64_to_numericvar(s_int64, &s_var);
+    if (step >= 0)
+        int64_to_numericvar(r_int64, &r_var);
+#endif                            /* HAVE_INT128 */
+
+    /*
+     * The remaining iterations with src_ndigits > 8 (or 16, if have int128)
+     * use numeric variables.
+     */
+    while (step >= 0)
+    {
+        int            tmp_len;
+
+        src_ndigits = ndigits[step];
+        blen = (src_ndigits - src_idx) / 2;
+
+        /* Extract a1 and a0 */
+        if (src_idx < arg->ndigits)
+        {
+            tmp_len = Min(blen, arg->ndigits - src_idx);
+            alloc_var(&a1_var, tmp_len);
+            memcpy(a1_var.digits, arg->digits + src_idx,
+                   tmp_len * sizeof(NumericDigit));
+            a1_var.weight = blen - 1;
+            a1_var.sign = NUMERIC_POS;
+            a1_var.dscale = 0;
+            strip_var(&a1_var);
+        }
+        else
+        {
+            zero_var(&a1_var);
+            a1_var.dscale = 0;
+        }
+        src_idx += blen;
+
+        if (src_idx < arg->ndigits)
+        {
+            tmp_len = Min(blen, arg->ndigits - src_idx);
+            alloc_var(&a0_var, tmp_len);
+            memcpy(a0_var.digits, arg->digits + src_idx,
+                   tmp_len * sizeof(NumericDigit));
+            a0_var.weight = blen - 1;
+            a0_var.sign = NUMERIC_POS;
+            a0_var.dscale = 0;
+            strip_var(&a0_var);
+        }
+        else
+        {
+            zero_var(&a0_var);
+            a0_var.dscale = 0;
+        }
+        src_idx += blen;
+
+        /* Compute (q,u) = DivRem(r*b + a1, 2*s) */
+        set_var_from_var(&r_var, &q_var);
+        q_var.weight += blen;
+        add_var(&q_var, &a1_var, &q_var);
+        add_var(&s_var, &s_var, &u_var);
+        div_mod_var(&q_var, &u_var, &q_var, &u_var);
+
+        /* Compute s = s*b + q */
+        s_var.weight += blen;
+        add_var(&s_var, &q_var, &s_var);
+
+        /*
+         * Compute r = u*b + a0 - q^2.
+         *
+         * In the final iteration, we don't actually need r; we just need to
+         * know whether it is negative, so that we know whether to adjust s.
+         * So instead of the final subtraction we can just compare.
+         */
+        u_var.weight += blen;
+        add_var(&u_var, &a0_var, &u_var);
+        mul_var(&q_var, &q_var, &q_var, 0);
+
+        if (step > 0)
+        {
+            /* Need r for later iterations */
+            sub_var(&u_var, &q_var, &r_var);
+            if (r_var.sign == NUMERIC_NEG)
+            {
+                /* s is too large by 1; set r += s, s--, r += s */
+                add_var(&r_var, &s_var, &r_var);
+                sub_var(&s_var, &const_one, &s_var);
+                add_var(&r_var, &s_var, &r_var);
+            }
+        }
+        else
+        {
+            /* Don't need r anymore, except to test if s is too large by 1 */
+            if (cmp_var(&u_var, &q_var) < 0)
+                sub_var(&s_var, &const_one, &s_var);
+        }
+
+        Assert(src_idx == src_ndigits); /* All input digits consumed */
+        step--;
+    }
+
+    /*
+     * Construct the final result, rounding it to the requested precision.
+     */
+    set_var_from_var(&s_var, result);
+    result->weight = res_weight;
+    result->sign = NUMERIC_POS;
+
+    /* Round to target rscale (and set result->dscale) */
     round_var(result, rscale);
+
+    /* Strip leading and trailing zeroes */
+    strip_var(result);
+
+    free_var(&s_var);
+    free_var(&r_var);
+    free_var(&a0_var);
+    free_var(&a1_var);
+    free_var(&q_var);
+    free_var(&u_var);
 }


@@ -8530,12 +8984,18 @@ ln_var(const NumericVar *arg, NumericVar *result, int rscale)
      * Each sqrt() will roughly halve the weight of x, so adjust the local
      * rscale as we work so that we keep this many significant digits at each
      * step (plus a few more for good measure).
+     *
+     * Note that we allow local_rscale < 0 during this input reduction
+     * process, which implies rounding before the decimal point.  sqrt_var()
+     * explicitly supports this, and it significantly reduces the work
+     * required to reduce very large inputs to the required range.  Once the
+     * input reduction is complete, x.weight will be 0 and its display scale
+     * will be non-negative again.
      */
     nsqrt = 0;
     while (cmp_var(&x, &const_zero_point_nine) <= 0)
     {
         local_rscale = rscale - x.weight * DEC_DIGITS / 2 + 8;
-        local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
         sqrt_var(&x, &x, local_rscale);
         mul_var(&fact, &const_two, &fact, 0);
         nsqrt++;
@@ -8543,7 +9003,6 @@ ln_var(const NumericVar *arg, NumericVar *result, int rscale)
     while (cmp_var(&x, &const_one_point_one) >= 0)
     {
         local_rscale = rscale - x.weight * DEC_DIGITS / 2 + 8;
-        local_rscale = Max(local_rscale, NUMERIC_MIN_DISPLAY_SCALE);
         sqrt_var(&x, &x, local_rscale);
         mul_var(&fact, &const_two, &fact, 0);
         nsqrt++;
diff --git a/src/test/regress/expected/numeric.out b/src/test/regress/expected/numeric.out
index 23a4c6d..c7fe63d 100644
--- a/src/test/regress/expected/numeric.out
+++ b/src/test/regress/expected/numeric.out
@@ -1580,6 +1580,57 @@ select div(12345678901234567890, 123) * 123 + 12345678901234567890 % 123;
 (1 row)

 --
+-- Test some corner cases for square root
+--
+select sqrt(1.000000000000003::numeric);
+       sqrt
+-------------------
+ 1.000000000000001
+(1 row)
+
+select sqrt(1.000000000000004::numeric);
+       sqrt
+-------------------
+ 1.000000000000002
+(1 row)
+
+select sqrt(96627521408608.56340355805::numeric);
+        sqrt
+---------------------
+ 9829929.87811248648
+(1 row)
+
+select sqrt(96627521408608.56340355806::numeric);
+        sqrt
+---------------------
+ 9829929.87811248649
+(1 row)
+
+select sqrt(515549506212297735.073688290367::numeric);
+          sqrt
+------------------------
+ 718017761.766585921184
+(1 row)
+
+select sqrt(515549506212297735.073688290368::numeric);
+          sqrt
+------------------------
+ 718017761.766585921185
+(1 row)
+
+select sqrt(8015491789940783531003294973900306::numeric);
+       sqrt
+-------------------
+ 89529278953540017
+(1 row)
+
+select sqrt(8015491789940783531003294973900307::numeric);
+       sqrt
+-------------------
+ 89529278953540018
+(1 row)
+
+--
 -- Test code path for raising to integer powers
 --
 select 10.0 ^ -2147483648 as rounds_to_zero;
diff --git a/src/test/regress/sql/numeric.sql b/src/test/regress/sql/numeric.sql
index c5c8d76..41475a9 100644
--- a/src/test/regress/sql/numeric.sql
+++ b/src/test/regress/sql/numeric.sql
@@ -883,6 +883,19 @@ select div(12345678901234567890, 123);
 select div(12345678901234567890, 123) * 123 + 12345678901234567890 % 123;

 --
+-- Test some corner cases for square root
+--
+
+select sqrt(1.000000000000003::numeric);
+select sqrt(1.000000000000004::numeric);
+select sqrt(96627521408608.56340355805::numeric);
+select sqrt(96627521408608.56340355806::numeric);
+select sqrt(515549506212297735.073688290367::numeric);
+select sqrt(515549506212297735.073688290368::numeric);
+select sqrt(8015491789940783531003294973900306::numeric);
+select sqrt(8015491789940783531003294973900307::numeric);
+
+--
 -- Test code path for raising to integer powers
 --
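
For readers following the sqrt_var() hunk above: each pass of its loop is one step of Karatsuba square root. It starts from (s, r) with s^2 + r equal to the input consumed so far and r <= 2s, then folds in the next two chunks a1, a0 of base b.

The following is a minimal sketch of that same step on a plain uint64_t, not the patch's code. It differs from the patch in three ways:
- it works on decimal digits rather than NBASE digits;
- it uses brute force instead of Newton's method for the leading four digits;
- it chooses blen by a simple rule of its own (b = 10^blen never exceeds s), which keeps s >= b/2, the condition under which at most one "s--" correction is needed.

All function names are hypothetical.

```c
#include <stdint.h>

/* 10^k as a uint64_t; only called with 0 <= k <= 19 */
static uint64_t pow10_u64(int k)
{
    uint64_t p = 1;

    while (k-- > 0)
        p *= 10;
    return p;
}

/* Number of decimal digits in x (at least 1) */
static int ndigits_u64(uint64_t x)
{
    int n = 1;

    while (x >= 10)
    {
        x /= 10;
        n++;
    }
    return n;
}

/* Digits [pos, pos + len) of n, counted from the most significant digit of
 * its ndig-digit representation (which may include a leading zero). */
static uint64_t digit_chunk(uint64_t n, int ndig, int pos, int len)
{
    return (n / pow10_u64(ndig - pos - len)) % pow10_u64(len);
}

/* Hypothetical helper: returns floor(sqrt(n)) and stores n - s^2 in *rem */
static uint64_t karatsuba_isqrt(uint64_t n, uint64_t *rem)
{
    int      ndig = ndigits_u64(n);
    int      pos;
    uint64_t head, s = 0;
    int64_t  r;

    if (ndig % 2 != 0)
        ndig++;                 /* pad with a leading zero digit */

    /* Initial (s, r) for the leading (up to) 4 digits, by brute force */
    pos = ndig < 4 ? ndig : 4;
    head = digit_chunk(n, ndig, 0, pos);
    while ((s + 1) * (s + 1) <= head)
        s++;
    r = (int64_t) (head - s * s);

    while (pos < ndig)
    {
        /* b = 10^blen <= s, so s >= b/2 holds for this step */
        int      blen = ndigits_u64(s) - 1;
        uint64_t b, a1, a0, num, q, u;

        if (blen > (ndig - pos) / 2)
            blen = (ndig - pos) / 2;
        b = pow10_u64(blen);
        a1 = digit_chunk(n, ndig, pos, blen);
        a0 = digit_chunk(n, ndig, pos + blen, blen);
        pos += 2 * blen;

        /* (q, u) = DivRem(r*b + a1, 2*s) */
        num = (uint64_t) r * b + a1;
        q = num / (2 * s);
        u = num % (2 * s);

        /* s = s*b + q;  r = u*b + a0 - q^2 */
        s = s * b + q;
        r = (int64_t) (u * b + a0) - (int64_t) (q * q);

        /* s is too large by 1: r += s, s--, r += s (as in the patch) */
        if (r < 0)
        {
            r += (int64_t) s;
            s--;
            r += (int64_t) s;
        }
    }

    *rem = (uint64_t) r;
    return s;
}
```

For example, karatsuba_isqrt(12345678, &rem) passes through (s, r) = (35, 9), (351, 255), (3513, 4509) and returns 3513 with rem = 4509. Unlike the patch, this sketch always computes r, even on the final step. There the patch only compares u*b + a0 against q^2.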

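The ln_var() hunk only relaxes the clamp on local_rscale, but the loop it sits in is easier to follow with the arithmetic stripped down. Each square root halves ln(x), so after nsqrt roots ln(arg) = 2^nsqrt * ln(x_reduced), which is what the doubling of fact tracks.

The sketch below shows that reduction in double precision, with these caveats:
- it cannot model local_rscale, because a double has no display scale;
- it uses its own Newton sqrt so that it needs no libm;
- once x is in (0.9, 1.1) it applies ln(x) = 2(z + z^3/3 + z^5/5 + ...) with z = (x - 1)/(x + 1). This is the standard atanh series, assumed here rather than taken from the diff.

The function names are hypothetical.

```c
/* Newton's method for sqrt(x), x > 0.  Starting at or above the root, the
 * iterates decrease; stop as soon as they no longer do. */
static double newton_sqrt(double x)
{
    double y = x > 1.0 ? x : 1.0;          /* y >= sqrt(x) */
    double next = 0.5 * (y + x / y);

    while (next < y)
    {
        y = next;
        next = 0.5 * (y + x / y);
    }
    return y;
}

/*
 * Double-precision sketch of the input reduction in ln_var(): take square
 * roots until 0.9 < x < 1.1, doubling fact each time (ln(sqrt(x)) = ln(x)/2),
 * then sum ln(x) = 2 * (z + z^3/3 + z^5/5 + ...) with z = (x - 1)/(x + 1).
 * Requires x > 0.
 */
static double ln_by_sqrt_reduction(double x)
{
    double fact = 2.0;      /* leading 2 of the series, doubled per sqrt */
    double z, z2, term, sum = 0.0;
    int    k;

    while (x <= 0.9)
    {
        x = newton_sqrt(x);
        fact *= 2.0;
    }
    while (x >= 1.1)
    {
        x = newton_sqrt(x);
        fact *= 2.0;
    }

    /* |z| < 0.053 here, so 30 odd terms are far more than a double needs */
    z = (x - 1.0) / (x + 1.0);
    z2 = z * z;
    term = z;               /* holds z^k for the current odd k */
    for (k = 1; k < 60; k += 2)
    {
        sum += term / k;
        term *= z2;
    }
    return fact * sum;
}
```

For instance, with x = 10 the loop takes five square roots (10, 3.16, 1.78, 1.33, 1.15, 1.07), so fact ends at 2 * 2^5 = 64. In the NumericVar version, the early roots of a very large input only need their leading significant digits. That is why letting local_rscale go negative saves work there.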

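The new regression tests pair up inputs that sit just either side of a rounding boundary. To zero decimal places, sqrt(n) for an integer n rounds up exactly when n - s^2 > s, where s = floor(sqrt(n)), because (s + 1/2)^2 = s^2 + s + 1/4.

The last pair is built this way. 8015491789940783531003294973900306 = 89529278953540017 * 89529278953540018 = s(s + 1), so it must round down to s, and adding 1 tips it up to s + 1. The other pairs appear to probe similar boundaries at nonzero scales.

Below is a small uint64_t sketch of the rule. The helper names are hypothetical, and floor_isqrt is a plain binary search, chosen for clarity rather than speed.

```c
#include <stdint.h>

/* floor(sqrt(n)) by binary search over [0, UINT32_MAX] */
static uint64_t floor_isqrt(uint64_t n)
{
    uint64_t lo = 0, hi = UINT32_MAX;   /* floor(sqrt(UINT64_MAX)) */

    while (lo < hi)
    {
        /* round the midpoint up so that lo = mid always makes progress */
        uint64_t mid = lo + (hi - lo + 1) / 2;

        if (mid * mid <= n)             /* mid <= 2^32 - 1, so no overflow */
            lo = mid;
        else
            hi = mid - 1;
    }
    return lo;
}

/*
 * sqrt(n) rounded to the nearest integer.  (s + 1/2)^2 = s^2 + s + 1/4 is
 * never an integer, so there are no ties: round up iff n - s^2 > s.
 */
static uint64_t round_isqrt(uint64_t n)
{
    uint64_t s = floor_isqrt(n);

    return (n - s * s > s) ? s + 1 : s;
}
```

With s = 3037000499, round_isqrt(s * (s + 1)) gives s and round_isqrt(s * (s + 1) + 1) gives s + 1. This mirrors the ...306/...307 pair at a size that fits in 64 bits.
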
Re: Some improvements to numeric sqrt() and ln()

From
Dean Rasheed
Date:
On Sun, 22 Mar 2020 at 22:16, Tom Lane <tgl@sss.pgh.pa.us> wrote:
>
> With resolutions of the XXX items, I think this'd be committable.
>

Thanks for looking at this!

Here is an updated patch with the following updates based on your comments:

* Now uses integer arithmetic to compute res_weight and res_ndigits,
instead of floor() and ceil().

* New comment giving a more detailed explanation of how blen is
chosen, and why it must sometimes examine the first digit of the input
and reduce blen by 1 (which can occur at any step, as shown in the
example given).

* New comment giving a proof that the number of steps required is
guaranteed to be less than 32.

* New comment explaining why the initial integer square root using
Newton's method is guaranteed to converge. I couldn't find a formal
reference for this, but there's a Wikipedia article on it
(https://en.wikipedia.org/wiki/Integer_square_root), and I think it's a
well-known result in the field.
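
The convergence argument is short enough to write out. Below is a minimal sketch of the integer Newton iteration in the form given on that Wikipedia page, x_{k+1} = floor((x_k + floor(n / x_k)) / 2), started from a value no smaller than sqrt(n), with the argument in the comment. It is not the patch's code, which per the message above uses the iteration only for the initial approximation, and isqrt_newton is a hypothetical name.

```c
#include <stdint.h>

/*
 * floor(sqrt(n)) by integer Newton iteration:
 *     x_{k+1} = floor((x_k + floor(n / x_k)) / 2)
 *
 * Why it stops at the right answer (m = floor(sqrt(n))):
 *  1. For any x >= 1, x + n/x >= 2*sqrt(n) (AM-GM), and the floors cannot
 *     push the result below the integer m, so every x_{k+1} >= m.
 *  2. If x_k > m then x_k^2 > n, so n/x_k < x_k and x_{k+1} < x_k.
 *  3. So the iterates decrease strictly while above m and never drop below
 *     it; the first x_k whose successor is not smaller is exactly m.
 */
static uint64_t isqrt_newton(uint64_t n)
{
    uint64_t x, y;

    if (n < 2)
        return n;

    x = (n >> 1) + 1;               /* x_0 >= (n + 1)/2 >= sqrt(n) */
    y = (x + n / x) >> 1;
    while (y < x)
    {
        x = y;
        y = (x + n / x) >> 1;
    }
    return x;
}
```

Starting from (n >> 1) + 1 rather than n keeps x + n / x from overflowing when n is close to UINT64_MAX.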

Regards,
Dean

Attachments

Re: Some improvements to numeric sqrt() and ln()

From
Tom Lane
Date:
Dean Rasheed <dean.a.rasheed@gmail.com> writes:
> Here is an updated patch with the following updates based on your comments:

This resolves all my concerns.  I've marked it RFC in the CF app.

            regards, tom lane