From 074aeb223ab496f23c2075eabd35e6e76241d1d8 Mon Sep 17 00:00:00 2001 From: Joel Jakobsson Date: Mon, 1 Jul 2024 07:17:50 +0200 Subject: [PATCH] Add SQL-callable numeric_mul_patched() to bench Simplified fast-path computation --- src/backend/utils/adt/numeric.c | 438 ++++++++++++++++++++++++++++++++ src/include/catalog/pg_proc.dat | 3 + src/include/utils/numeric.h | 2 + test-mul-var.sql | 48 ++++ 4 files changed, 491 insertions(+) create mode 100644 test-mul-var.sql diff --git a/src/backend/utils/adt/numeric.c b/src/backend/utils/adt/numeric.c index 5510a203b0..8f5d553f15 100644 --- a/src/backend/utils/adt/numeric.c +++ b/src/backend/utils/adt/numeric.c @@ -551,6 +551,9 @@ static void sub_var(const NumericVar *var1, const NumericVar *var2, static void mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result, int rscale); +static void mul_var_patched(const NumericVar *var1, const NumericVar *var2, + NumericVar *result, + int rscale); static void div_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result, int rscale, bool round); @@ -3115,6 +3118,131 @@ numeric_mul_opt_error(Numeric num1, Numeric num2, bool *have_error) } + +/* + * numeric_mul_patched() - + * + * This function multiplies two numeric values using the patched algorithm, + * designed for efficient handling of large numbers. It's introduced to allow + * direct benchmark comparisons with the standard numeric_mul() function. + */ +Datum +numeric_mul_patched(PG_FUNCTION_ARGS) +{ + Numeric num1 = PG_GETARG_NUMERIC(0); + Numeric num2 = PG_GETARG_NUMERIC(1); + int32 rscale_adjustment = PG_GETARG_INT32(2); + Numeric res; + + res = numeric_mul_patched_opt_error(num1, num2, rscale_adjustment, NULL); + + PG_RETURN_NUMERIC(res); } + + +/* + * numeric_mul_patched_opt_error() - + * + * Internal version of numeric_mul_patched(). + * If the "*have_error" flag is provided, on error it is set to true and NULL + * is returned. This is helpful when the caller needs to handle errors itself.
+ */ +Numeric +numeric_mul_patched_opt_error(Numeric num1, Numeric num2, int32 rscale_adjustment, bool *have_error) +{ + NumericVar arg1; + NumericVar arg2; + NumericVar result; + Numeric res; + + /* + * Handle NaN and infinities + */ + if (NUMERIC_IS_SPECIAL(num1) || NUMERIC_IS_SPECIAL(num2)) + { + if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2)) + return make_result(&const_nan); + if (NUMERIC_IS_PINF(num1)) + { + switch (numeric_sign_internal(num2)) + { + case 0: + return make_result(&const_nan); /* Inf * 0 */ + case 1: + return make_result(&const_pinf); + case -1: + return make_result(&const_ninf); + } + Assert(false); + } + if (NUMERIC_IS_NINF(num1)) + { + switch (numeric_sign_internal(num2)) + { + case 0: + return make_result(&const_nan); /* -Inf * 0 */ + case 1: + return make_result(&const_ninf); + case -1: + return make_result(&const_pinf); + } + Assert(false); + } + /* by here, num1 must be finite, so num2 is not */ + if (NUMERIC_IS_PINF(num2)) + { + switch (numeric_sign_internal(num1)) + { + case 0: + return make_result(&const_nan); /* 0 * Inf */ + case 1: + return make_result(&const_pinf); + case -1: + return make_result(&const_ninf); + } + Assert(false); + } + Assert(NUMERIC_IS_NINF(num2)); + switch (numeric_sign_internal(num1)) + { + case 0: + return make_result(&const_nan); /* 0 * -Inf */ + case 1: + return make_result(&const_ninf); + case -1: + return make_result(&const_pinf); + } + Assert(false); + } + + /* + * Unpack the values, let mul_var() compute the result and return it. + * Unlike add_var() and sub_var(), mul_var() will round its result. In the + * case of numeric_mul(), which is invoked for the * operator on numerics, + * we request exact representation for the product (rscale = sum(dscale of + * arg1, dscale of arg2)). If the exact result has more digits after the + * decimal point than can be stored in a numeric, we round it. 
Rounding + * after computing the exact result ensures that the final result is + * correctly rounded (rounding in mul_var() using a truncated product + * would not guarantee this). + */ + init_var_from_num(num1, &arg1); + init_var_from_num(num2, &arg2); + + init_var(&result); + + mul_var_patched(&arg1, &arg2, &result, arg1.dscale + arg2.dscale + rscale_adjustment); + + if (result.dscale > NUMERIC_DSCALE_MAX) + round_var(&result, NUMERIC_DSCALE_MAX); + + res = make_result_opt_error(&result, have_error); + + free_var(&result); + + return res; +} + + /* * numeric_div() - * @@ -8864,6 +8992,316 @@ mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result, strip_var(result); } +/* + * mul_var_patched() - + * + * Implements patched multiplication for large numbers, introduced + * alongside the unchanged original mul_var(). This function is part of + * an optimization effort, allowing direct benchmark comparisons with + * mul_var(). It selects full or half patched based on input size. + * This is a temporary measure before considering its replacement of + * mul_var() based on benchmark outcomes. + */ +static void +mul_var_patched(const NumericVar *var1, const NumericVar *var2, + NumericVar *result, int rscale) +{ + int res_ndigits; + int res_sign; + int res_weight; + int maxdigits; + int *dig; + int carry; + int maxdig; + int newdig; + int var1ndigits; + int var2ndigits; + NumericDigit *var1digits; + NumericDigit *var2digits; + NumericDigit *res_digits; + int i, + i1, + i2; + + /* + * Arrange for var1 to be the shorter of the two numbers. This improves + * performance because the inner multiplication loop is much simpler than + * the outer loop, so it's better to have a smaller number of iterations + * of the outer loop. This also reduces the number of times that the + * accumulator array needs to be normalized. 
+ */ + if (var1->ndigits > var2->ndigits) + { + const NumericVar *tmp = var1; + + var1 = var2; + var2 = tmp; + } + + /* copy these values into local vars for speed in inner loop */ + var1ndigits = var1->ndigits; + var2ndigits = var2->ndigits; + var1digits = var1->digits; + var2digits = var2->digits; + + if (var1ndigits == 0 || var2ndigits == 0) + { + /* one or both inputs is zero; so is result */ + zero_var(result); + result->dscale = rscale; + return; + } + + /* Determine result sign and (maximum possible) weight */ + if (var1->sign == var2->sign) + res_sign = NUMERIC_POS; + else + res_sign = NUMERIC_NEG; + res_weight = var1->weight + var2->weight + 2; + + /* + * Determine the number of result digits to compute. If the exact result + * would have more than rscale fractional digits, truncate the computation + * with MUL_GUARD_DIGITS guard digits, i.e., ignore input digits that + * would only contribute to the right of that. (This will give the exact + * rounded-to-rscale answer unless carries out of the ignored positions + * would have propagated through more than MUL_GUARD_DIGITS digits.) + * + * Note: an exact computation could not produce more than var1ndigits + + * var2ndigits digits, but we allocate one extra output digit in case + * rscale-driven rounding produces a carry out of the highest exact digit. + */ + res_ndigits = var1ndigits + var2ndigits + 1; + maxdigits = res_weight + 1 + (rscale + DEC_DIGITS - 1) / DEC_DIGITS + + MUL_GUARD_DIGITS; + res_ndigits = Min(res_ndigits, maxdigits); + + if (res_ndigits < 3) + { + /* All input digits will be ignored; so result is zero */ + zero_var(result); + result->dscale = rscale; + return; + } + + /* + * Simplified fast-path computation, if var1 has just one, two, or three + * digits. This is significantly faster, since it avoids allocating a + * separate digit array, making multiple passes over var2, and having + * separate carry-propagation passes.
+ */ + if (var1ndigits <= 3) + { + NumericDigit *res_buf; + + /* Allocate result digit array */ + res_buf = digitbuf_alloc(res_ndigits); + res_buf[0] = 0; /* spare digit for later rounding */ + res_digits = res_buf + 1; + + /* + * Compute the result digits directly, in one pass, propagating the + * carry up as we go. + */ + switch (var1ndigits) + { + case 1: + carry = 0; + for (i = res_ndigits - 3; i >= 0; i--) + { + newdig = (int) var1digits[0] * var2digits[i] + carry; + res_digits[i + 1] = (NumericDigit) (newdig % NBASE); + carry = newdig / NBASE; + } + res_digits[0] = (NumericDigit) carry; + break; + + case 2: + newdig = (int) var1digits[1] * var2digits[res_ndigits - 4]; + if (res_ndigits - 3 < var2ndigits) + newdig += (int) var1digits[0] * var2digits[res_ndigits - 3]; + res_digits[res_ndigits - 2] = (NumericDigit) (newdig % NBASE); + carry = newdig / NBASE; + for (i = res_ndigits - 4; i >= 1; i--) + { + newdig = (int) var1digits[0] * var2digits[i] + + (int) var1digits[1] * var2digits[i - 1] + carry; + res_digits[i + 1] = (NumericDigit) (newdig % NBASE); + carry = newdig / NBASE; + } + newdig = (int) var1digits[0] * var2digits[0] + carry; + res_digits[1] = (NumericDigit) (newdig % NBASE); + res_digits[0] = (NumericDigit) (newdig / NBASE); + break; + + case 3: + newdig = (int) var1digits[2] * var2digits[res_ndigits - 5]; + if (res_ndigits - 4 < var2ndigits) + newdig += (int) var1digits[1] * var2digits[res_ndigits - 4]; + if (res_ndigits - 3 < var2ndigits) + newdig += (int) var1digits[0] * var2digits[res_ndigits - 3]; + res_digits[res_ndigits - 2] = (NumericDigit) (newdig % NBASE); + carry = newdig / NBASE; + for (i = res_ndigits - 4; i >= 2; i--) + { + newdig = carry; + if (i < var2ndigits) + newdig += (int) var1digits[0] * var2digits[i]; + if (i - 1 >= 0 && i - 1 < var2ndigits) + newdig += (int) var1digits[1] * var2digits[i - 1]; + if (i - 2 >= 0 && i - 2 < var2ndigits) + newdig += (int) var1digits[2] * var2digits[i - 2]; + res_digits[i + 1] = (NumericDigit) 
(newdig % NBASE); + carry = newdig / NBASE; + } + newdig = carry; + if (var2ndigits > 1) + newdig += (int) var1digits[0] * var2digits[1]; + if (var2ndigits > 0) + newdig += (int) var1digits[1] * var2digits[0]; + res_digits[2] = (NumericDigit) (newdig % NBASE); + carry = newdig / NBASE; + newdig = (int) var1digits[0] * var2digits[0] + carry; + res_digits[1] = (NumericDigit) (newdig % NBASE); + res_digits[0] = (NumericDigit) (newdig / NBASE); + break; + } + + /* Store the product in result (minus extra rounding digit) */ + digitbuf_free(result->buf); + result->ndigits = res_ndigits - 1; + result->buf = res_buf; + result->digits = res_digits; + result->weight = res_weight - 1; + result->sign = res_sign; + + /* Round to target rscale (and set result->dscale) */ + round_var(result, rscale); + + /* Strip leading and trailing zeroes */ + strip_var(result); + + return; + } + + /* + * We do the arithmetic in an array "dig[]" of signed int's. Since + * INT_MAX is noticeably larger than NBASE*NBASE, this gives us headroom + * to avoid normalizing carries immediately. + * + * maxdig tracks the maximum possible value of any dig[] entry; when this + * threatens to exceed INT_MAX, we take the time to propagate carries. + * Furthermore, we need to ensure that overflow doesn't occur during the + * carry propagation passes either. The carry values could be as much as + * INT_MAX/NBASE, so really we must normalize when digits threaten to + * exceed INT_MAX - INT_MAX/NBASE. + * + * To avoid overflow in maxdig itself, it actually represents the max + * possible value divided by NBASE-1, ie, at the top of the loop it is + * known that no dig[] entry exceeds maxdig * (NBASE-1). + */ + dig = (int *) palloc0(res_ndigits * sizeof(int)); + maxdig = 0; + + /* + * The least significant digits of var1 should be ignored if they don't + * contribute directly to the first res_ndigits digits of the result that + * we are computing. 
+ * + * Digit i1 of var1 and digit i2 of var2 are multiplied and added to digit + * i1+i2+2 of the accumulator array, so we need only consider digits of + * var1 for which i1 <= res_ndigits - 3. + */ + for (i1 = Min(var1ndigits - 1, res_ndigits - 3); i1 >= 0; i1--) + { + NumericDigit var1digit = var1digits[i1]; + + if (var1digit == 0) + continue; + + /* Time to normalize? */ + maxdig += var1digit; + if (maxdig > (INT_MAX - INT_MAX / NBASE) / (NBASE - 1)) + { + /* Yes, do it */ + carry = 0; + for (i = res_ndigits - 1; i >= 0; i--) + { + newdig = dig[i] + carry; + if (newdig >= NBASE) + { + carry = newdig / NBASE; + newdig -= carry * NBASE; + } + else + carry = 0; + dig[i] = newdig; + } + Assert(carry == 0); + /* Reset maxdig to indicate new worst-case */ + maxdig = 1 + var1digit; + } + + /* + * Add the appropriate multiple of var2 into the accumulator. + * + * As above, digits of var2 can be ignored if they don't contribute, + * so we only include digits for which i1+i2+2 < res_ndigits. + * + * This inner loop is the performance bottleneck for multiplication, + * so we want to keep it simple enough so that it can be + * auto-vectorized. Accordingly, process the digits left-to-right + * even though schoolbook multiplication would suggest right-to-left. + * Since we aren't propagating carries in this loop, the order does + * not matter. + */ + { + int i2limit = Min(var2ndigits, res_ndigits - i1 - 2); + int *dig_i1_2 = &dig[i1 + 2]; + + for (i2 = 0; i2 < i2limit; i2++) + dig_i1_2[i2] += var1digit * var2digits[i2]; + } + } + + /* + * Now we do a final carry propagation pass to normalize the result, which + * we combine with storing the result digits into the output. Note that + * this is still done at full precision w/guard digits. 
+ */ + alloc_var(result, res_ndigits); + res_digits = result->digits; + carry = 0; + for (i = res_ndigits - 1; i >= 0; i--) + { + newdig = dig[i] + carry; + if (newdig >= NBASE) + { + carry = newdig / NBASE; + newdig -= carry * NBASE; + } + else + carry = 0; + res_digits[i] = newdig; + } + Assert(carry == 0); + + pfree(dig); + + /* + * Finally, round the result to the requested precision. + */ + result->weight = res_weight; + result->sign = res_sign; + + /* Round to target rscale (and set result->dscale) */ + round_var(result, rscale); + + /* Strip leading and trailing zeroes */ + strip_var(result); + +} + /* * div_var() - diff --git a/src/include/catalog/pg_proc.dat b/src/include/catalog/pg_proc.dat index d4ac578ae6..5b3024cb6d 100644 --- a/src/include/catalog/pg_proc.dat +++ b/src/include/catalog/pg_proc.dat @@ -4465,6 +4465,9 @@ { oid => '1726', proname => 'numeric_mul', prorettype => 'numeric', proargtypes => 'numeric numeric', prosrc => 'numeric_mul' }, +{ oid => '6347', + proname => 'numeric_mul_patched', prorettype => 'numeric', + proargtypes => 'numeric numeric int4', prosrc => 'numeric_mul_patched' }, { oid => '1727', proname => 'numeric_div', prorettype => 'numeric', proargtypes => 'numeric numeric', prosrc => 'numeric_div' }, diff --git a/src/include/utils/numeric.h b/src/include/utils/numeric.h index 43c75c436f..454a56da9a 100644 --- a/src/include/utils/numeric.h +++ b/src/include/utils/numeric.h @@ -97,6 +97,8 @@ extern Numeric numeric_sub_opt_error(Numeric num1, Numeric num2, bool *have_error); extern Numeric numeric_mul_opt_error(Numeric num1, Numeric num2, bool *have_error); +extern Numeric numeric_mul_patched_opt_error(Numeric num1, Numeric num2, + int32 rscale_adjustment, bool *have_error); extern Numeric numeric_div_opt_error(Numeric num1, Numeric num2, bool *have_error); extern Numeric numeric_mod_opt_error(Numeric num1, Numeric num2, diff --git a/test-mul-var.sql b/test-mul-var.sql new file mode 100644 index 0000000000..ee7e3855bc --- /dev/null 
+++ b/test-mul-var.sql @@ -0,0 +1,48 @@ +CREATE TABLE test_numeric_mul_patched ( + var1 numeric, + var2 numeric, + rscale_adjustment int, + result numeric +); + +DO $$ +DECLARE +var1 numeric; +var2 numeric; +BEGIN + FOR i IN 1..100 LOOP + RAISE NOTICE '%', i; + FOR var1ndigits IN 1..4 LOOP + FOR var2ndigits IN 1..4 LOOP + FOR var1dscale IN 0..(var1ndigits*4) LOOP + FOR var2dscale IN 0..(var2ndigits*4) LOOP + FOR rscale_adjustment IN 0..(var1dscale+var2dscale) LOOP + var1 := round(random( + format('1%s',repeat('0',(var1ndigits-1)*4-1))::numeric, + format('%s',repeat('9',var1ndigits*4))::numeric + ) / 10::numeric^var1dscale, var1dscale); + var2 := round(random( + format('1%s',repeat('0',(var2ndigits-1)*4-1))::numeric, + format('%s',repeat('9',var2ndigits*4))::numeric + ) / 10::numeric^var2dscale, var2dscale); + INSERT INTO test_numeric_mul_patched + (var1, var2, rscale_adjustment) + VALUES + (var1, var2, -rscale_adjustment); + END LOOP; + END LOOP; + END LOOP; + END LOOP; + END LOOP; + END LOOP; +END $$; + +-- First, set result with a numeric_mul_patched() version where +-- the Simplified fast-path computation code has been commented out. +UPDATE test_numeric_mul_patched SET result = numeric_mul_patched(var1, var2, rscale_adjustment); + +-- Then, recompile with the Simplified fast-path computation code, +-- and check if any differences can be found: +SELECT *, numeric_mul_patched(var1,var2,rscale_adjustment) +FROM test_numeric_mul_patched +WHERE result IS DISTINCT FROM numeric_mul_patched(var1,var2,rscale_adjustment); -- 2.45.1