aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorDean Rasheed <dean.a.rasheed@gmail.com>2024-08-15 10:36:17 +0100
committerDean Rasheed <dean.a.rasheed@gmail.com>2024-08-15 10:36:17 +0100
commit8dc28d7eb868b6ce5a51614628bf46fc63c7e90c (patch)
tree12e3a1da781232d64eaae46700b76c95a228a432 /src
parentc4e44224cf617c8cd33a734f888c045ac9575226 (diff)
downloadpostgresql-8dc28d7eb868b6ce5a51614628bf46fc63c7e90c.tar.gz
postgresql-8dc28d7eb868b6ce5a51614628bf46fc63c7e90c.zip
Optimise numeric multiplication using base-NBASE^2 arithmetic.
Currently mul_var() uses the schoolbook multiplication algorithm, which is O(n^2) in the number of NBASE digits. To improve performance for large inputs, convert the inputs to base NBASE^2 before multiplying, which effectively halves the number of digits in each input, theoretically speeding up the computation by a factor of 4. In practice, the actual speedup for large inputs varies between around 3 and 6 times, depending on the system and compiler used. In turn, this significantly reduces the runtime of the numeric_big regression test. For this to work, 64-bit integers are required for the products of base-NBASE^2 digits, so this works best on 64-bit machines, on which it is faster whenever the shorter input has more than 4 or 5 NBASE digits. On 32-bit machines, the additional overheads, especially during carry propagation and the final conversion back to base-NBASE, are significantly higher, and it is only faster when the shorter input has more than around 50 NBASE digits. When the shorter input has more than 6 NBASE digits (so that mul_var_short() cannot be used), but fewer than around 50 NBASE digits, there may be a noticeable slowdown on 32-bit machines. That seems to be an acceptable tradeoff, given the performance gains for other inputs, and the effort that would be required to maintain code specifically targeting 32-bit machines. Joel Jacobson and Dean Rasheed. Discussion: https://postgr.es/m/9d8a4a42-c354-41f3-bbf3-199e1957db97%40app.fastmail.com
Diffstat (limited to 'src')
-rw-r--r--src/backend/utils/adt/numeric.c224
1 files changed, 150 insertions, 74 deletions
diff --git a/src/backend/utils/adt/numeric.c b/src/backend/utils/adt/numeric.c
index 2a74312d354..77f64331f36 100644
--- a/src/backend/utils/adt/numeric.c
+++ b/src/backend/utils/adt/numeric.c
@@ -101,6 +101,8 @@ typedef signed char NumericDigit;
typedef int16 NumericDigit;
#endif
+#define NBASE_SQR (NBASE * NBASE)
+
/*
* The Numeric type as stored on disk.
*
@@ -8668,21 +8670,30 @@ mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result,
int rscale)
{
int res_ndigits;
+ int res_ndigitpairs;
int res_sign;
int res_weight;
+ int pair_offset;
int maxdigits;
- int *dig;
- int carry;
- int maxdig;
- int newdig;
+ int maxdigitpairs;
+ uint64 *dig,
+ *dig_i1_off;
+ uint64 maxdig;
+ uint64 carry;
+ uint64 newdig;
int var1ndigits;
int var2ndigits;
+ int var1ndigitpairs;
+ int var2ndigitpairs;
NumericDigit *var1digits;
NumericDigit *var2digits;
+ uint32 var1digitpair;
+ uint32 *var2digitpairs;
NumericDigit *res_digits;
int i,
i1,
- i2;
+ i2,
+ i2limit;
/*
* Arrange for var1 to be the shorter of the two numbers. This improves
@@ -8723,86 +8734,164 @@ mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result,
return;
}
- /* Determine result sign and (maximum possible) weight */
+ /* Determine result sign */
if (var1->sign == var2->sign)
res_sign = NUMERIC_POS;
else
res_sign = NUMERIC_NEG;
- res_weight = var1->weight + var2->weight + 2;
/*
- * Determine the number of result digits to compute. If the exact result
- * would have more than rscale fractional digits, truncate the computation
- * with MUL_GUARD_DIGITS guard digits, i.e., ignore input digits that
- * would only contribute to the right of that. (This will give the exact
+ * Determine the number of result digits to compute and the (maximum
+ * possible) result weight. If the exact result would have more than
+ * rscale fractional digits, truncate the computation with
+ * MUL_GUARD_DIGITS guard digits, i.e., ignore input digits that would
+ * only contribute to the right of that. (This will give the exact
* rounded-to-rscale answer unless carries out of the ignored positions
* would have propagated through more than MUL_GUARD_DIGITS digits.)
*
* Note: an exact computation could not produce more than var1ndigits +
- * var2ndigits digits, but we allocate one extra output digit in case
- * rscale-driven rounding produces a carry out of the highest exact digit.
+ * var2ndigits digits, but we allocate at least one extra output digit in
+ * case rscale-driven rounding produces a carry out of the highest exact
+ * digit.
+ *
+ * The computation itself is done using base-NBASE^2 arithmetic, so we
+ * actually process the input digits in pairs, producing a base-NBASE^2
+ * intermediate result. This significantly improves performance, since
+ * schoolbook multiplication is O(N^2) in the number of input digits, and
+ * working in base NBASE^2 effectively halves "N".
+ *
+ * Note: in a truncated computation, we must compute at least one extra
+ * output digit to ensure that all the guard digits are fully computed.
*/
- res_ndigits = var1ndigits + var2ndigits + 1;
+ /* digit pairs in each input */
+ var1ndigitpairs = (var1ndigits + 1) / 2;
+ var2ndigitpairs = (var2ndigits + 1) / 2;
+
+ /* digits in exact result */
+ res_ndigits = var1ndigits + var2ndigits;
+
+ /* digit pairs in exact result with at least one extra output digit */
+ res_ndigitpairs = res_ndigits / 2 + 1;
+
+ /* pair offset to align result to end of dig[] */
+ pair_offset = res_ndigitpairs - var1ndigitpairs - var2ndigitpairs + 1;
+
+ /* maximum possible result weight (odd-length inputs shifted up below) */
+ res_weight = var1->weight + var2->weight + 1 + 2 * res_ndigitpairs -
+ res_ndigits - (var1ndigits & 1) - (var2ndigits & 1);
+
+ /* rscale-based truncation with at least one extra output digit */
maxdigits = res_weight + 1 + (rscale + DEC_DIGITS - 1) / DEC_DIGITS +
MUL_GUARD_DIGITS;
- res_ndigits = Min(res_ndigits, maxdigits);
+ maxdigitpairs = maxdigits / 2 + 1;
+
+ res_ndigitpairs = Min(res_ndigitpairs, maxdigitpairs);
+ res_ndigits = 2 * res_ndigitpairs;
- if (res_ndigits < 3)
+ /*
+ * In the computation below, digit pair i1 of var1 and digit pair i2 of
+ * var2 are multiplied and added to digit i1+i2+pair_offset of dig[]. Thus
+ * input digit pairs with index >= res_ndigitpairs - pair_offset don't
+ * contribute to the result, and can be ignored.
+ */
+ if (res_ndigitpairs <= pair_offset)
{
/* All input digits will be ignored; so result is zero */
zero_var(result);
result->dscale = rscale;
return;
}
+ var1ndigitpairs = Min(var1ndigitpairs, res_ndigitpairs - pair_offset);
+ var2ndigitpairs = Min(var2ndigitpairs, res_ndigitpairs - pair_offset);
/*
- * We do the arithmetic in an array "dig[]" of signed int's. Since
- * INT_MAX is noticeably larger than NBASE*NBASE, this gives us headroom
- * to avoid normalizing carries immediately.
+ * We do the arithmetic in an array "dig[]" of unsigned 64-bit integers.
+ * Since PG_UINT64_MAX is much larger than NBASE^4, this gives us a lot of
+ * headroom to avoid normalizing carries immediately.
*
* maxdig tracks the maximum possible value of any dig[] entry; when this
- * threatens to exceed INT_MAX, we take the time to propagate carries.
- * Furthermore, we need to ensure that overflow doesn't occur during the
- * carry propagation passes either. The carry values could be as much as
- * INT_MAX/NBASE, so really we must normalize when digits threaten to
- * exceed INT_MAX - INT_MAX/NBASE.
+ * threatens to exceed PG_UINT64_MAX, we take the time to propagate
+ * carries. Furthermore, we need to ensure that overflow doesn't occur
+ * during the carry propagation passes either. The carry values could be
+ * as much as PG_UINT64_MAX / NBASE^2, so really we must normalize when
+ * digits threaten to exceed PG_UINT64_MAX - PG_UINT64_MAX / NBASE^2.
*
- * To avoid overflow in maxdig itself, it actually represents the max
- * possible value divided by NBASE-1, ie, at the top of the loop it is
- * known that no dig[] entry exceeds maxdig * (NBASE-1).
+ * To avoid overflow in maxdig itself, it actually represents the maximum
+ * possible value divided by NBASE^2-1, i.e., at the top of the loop it is
+ * known that no dig[] entry exceeds maxdig * (NBASE^2-1).
+ *
+ * The conversion of var1 to base NBASE^2 is done on the fly, as each new
+ * digit is required. The digits of var2 are converted upfront, and
+ * stored at the end of dig[]. To avoid loss of precision, the input
+ * digits are aligned with the start of digit pair array, effectively
+ * shifting them up (multiplying by NBASE) if the inputs have an odd
+ * number of NBASE digits.
*/
- dig = (int *) palloc0(res_ndigits * sizeof(int));
- maxdig = 0;
+ dig = (uint64 *) palloc(res_ndigitpairs * sizeof(uint64) +
+ var2ndigitpairs * sizeof(uint32));
+
+ /* convert var2 to base NBASE^2, shifting up if its length is odd */
+ var2digitpairs = (uint32 *) (dig + res_ndigitpairs);
+
+ for (i2 = 0; i2 < var2ndigitpairs - 1; i2++)
+ var2digitpairs[i2] = var2digits[2 * i2] * NBASE + var2digits[2 * i2 + 1];
+
+ if (2 * i2 + 1 < var2ndigits)
+ var2digitpairs[i2] = var2digits[2 * i2] * NBASE + var2digits[2 * i2 + 1];
+ else
+ var2digitpairs[i2] = var2digits[2 * i2] * NBASE;
/*
- * The least significant digits of var1 should be ignored if they don't
- * contribute directly to the first res_ndigits digits of the result that
- * we are computing.
+ * Start by multiplying var2 by the least significant contributing digit
+ * pair from var1, storing the results at the end of dig[], and filling
+ * the leading digits with zeros.
*
- * Digit i1 of var1 and digit i2 of var2 are multiplied and added to digit
- * i1+i2+2 of the accumulator array, so we need only consider digits of
- * var1 for which i1 <= res_ndigits - 3.
+ * The loop here is the same as the inner loop below, except that we set
+ * the results in dig[], rather than adding to them. This is the
+ * performance bottleneck for multiplication, so we want to keep it simple
+ * enough so that it can be auto-vectorized. Accordingly, process the
+ * digits left-to-right even though schoolbook multiplication would
+ * suggest right-to-left. Since we aren't propagating carries in this
+ * loop, the order does not matter.
+ */
+ i1 = var1ndigitpairs - 1;
+ if (2 * i1 + 1 < var1ndigits)
+ var1digitpair = var1digits[2 * i1] * NBASE + var1digits[2 * i1 + 1];
+ else
+ var1digitpair = var1digits[2 * i1] * NBASE;
+ maxdig = var1digitpair;
+
+ i2limit = Min(var2ndigitpairs, res_ndigitpairs - i1 - pair_offset);
+ dig_i1_off = &dig[i1 + pair_offset];
+
+ memset(dig, 0, (i1 + pair_offset) * sizeof(uint64));
+ for (i2 = 0; i2 < i2limit; i2++)
+ dig_i1_off[i2] = (uint64) var1digitpair * var2digitpairs[i2];
+
+ /*
+ * Next, multiply var2 by the remaining digit pairs from var1, adding the
+ * results to dig[] at the appropriate offsets, and normalizing whenever
+ * there is a risk of any dig[] entry overflowing.
*/
- for (i1 = Min(var1ndigits - 1, res_ndigits - 3); i1 >= 0; i1--)
+ for (i1 = i1 - 1; i1 >= 0; i1--)
{
- NumericDigit var1digit = var1digits[i1];
-
- if (var1digit == 0)
+ var1digitpair = var1digits[2 * i1] * NBASE + var1digits[2 * i1 + 1];
+ if (var1digitpair == 0)
continue;
/* Time to normalize? */
- maxdig += var1digit;
- if (maxdig > (INT_MAX - INT_MAX / NBASE) / (NBASE - 1))
+ maxdig += var1digitpair;
+ if (maxdig > (PG_UINT64_MAX - PG_UINT64_MAX / NBASE_SQR) / (NBASE_SQR - 1))
{
- /* Yes, do it */
+ /* Yes, do it (to base NBASE^2) */
carry = 0;
- for (i = res_ndigits - 1; i >= 0; i--)
+ for (i = res_ndigitpairs - 1; i >= 0; i--)
{
newdig = dig[i] + carry;
- if (newdig >= NBASE)
+ if (newdig >= NBASE_SQR)
{
- carry = newdig / NBASE;
- newdig -= carry * NBASE;
+ carry = newdig / NBASE_SQR;
+ newdig -= carry * NBASE_SQR;
}
else
carry = 0;
@@ -8810,50 +8899,37 @@ mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result,
}
Assert(carry == 0);
/* Reset maxdig to indicate new worst-case */
- maxdig = 1 + var1digit;
+ maxdig = 1 + var1digitpair;
}
- /*
- * Add the appropriate multiple of var2 into the accumulator.
- *
- * As above, digits of var2 can be ignored if they don't contribute,
- * so we only include digits for which i1+i2+2 < res_ndigits.
- *
- * This inner loop is the performance bottleneck for multiplication,
- * so we want to keep it simple enough so that it can be
- * auto-vectorized. Accordingly, process the digits left-to-right
- * even though schoolbook multiplication would suggest right-to-left.
- * Since we aren't propagating carries in this loop, the order does
- * not matter.
- */
- {
- int i2limit = Min(var2ndigits, res_ndigits - i1 - 2);
- int *dig_i1_2 = &dig[i1 + 2];
+ /* Multiply and add */
+ i2limit = Min(var2ndigitpairs, res_ndigitpairs - i1 - pair_offset);
+ dig_i1_off = &dig[i1 + pair_offset];
- for (i2 = 0; i2 < i2limit; i2++)
- dig_i1_2[i2] += var1digit * var2digits[i2];
- }
+ for (i2 = 0; i2 < i2limit; i2++)
+ dig_i1_off[i2] += (uint64) var1digitpair * var2digitpairs[i2];
}
/*
- * Now we do a final carry propagation pass to normalize the result, which
- * we combine with storing the result digits into the output. Note that
- * this is still done at full precision w/guard digits.
+ * Now we do a final carry propagation pass to normalize back to base
+ * NBASE^2, and construct the base-NBASE result digits. Note that this is
+ * still done at full precision w/guard digits.
*/
alloc_var(result, res_ndigits);
res_digits = result->digits;
carry = 0;
- for (i = res_ndigits - 1; i >= 0; i--)
+ for (i = res_ndigitpairs - 1; i >= 0; i--)
{
newdig = dig[i] + carry;
- if (newdig >= NBASE)
+ if (newdig >= NBASE_SQR)
{
- carry = newdig / NBASE;
- newdig -= carry * NBASE;
+ carry = newdig / NBASE_SQR;
+ newdig -= carry * NBASE_SQR;
}
else
carry = 0;
- res_digits[i] = newdig;
+ res_digits[2 * i + 1] = (NumericDigit) ((uint32) newdig % NBASE);
+ res_digits[2 * i] = (NumericDigit) ((uint32) newdig / NBASE);
}
Assert(carry == 0);