aboutsummaryrefslogtreecommitdiff
path: root/ext/misc/decimal.c
diff options
context:
space:
mode:
authordrh <>2023-06-29 20:28:03 +0000
committerdrh <>2023-06-29 20:28:03 +0000
commitec3e57fa92df277f6a92c5232f1b5af581fe9232 (patch)
tree1f2e1ec5708222a50b0510a05e9987ecb75c7b7c /ext/misc/decimal.c
parent500ca334bd1d4f25aaa3e4ac58bdc8076dff345a (diff)
downloadsqlite-ec3e57fa92df277f6a92c5232f1b5af581fe9232.tar.gz
sqlite-ec3e57fa92df277f6a92c5232f1b5af581fe9232.zip
Enhancements to the DECIMAL extension:
(1) If the argument to decimal(X) is a floating point value (or an 8-byte blob), the floating point value is expanded into its exact decimal representation. (2) Function decimal_sci(X) works the same except it returns the result in scientific notation. (3) New function decimal_pow2(N) returns the full decimal expansion of the N-th integer power of 2. FossilOrigin-Name: 8baf8c10aecb261751f2b154356ab224b79d07230929ec9f123791278e601bba
Diffstat (limited to 'ext/misc/decimal.c')
-rw-r--r--ext/misc/decimal.c281
1 files changed, 214 insertions, 67 deletions
diff --git a/ext/misc/decimal.c b/ext/misc/decimal.c
index 865f3ce24..6c080c200 100644
--- a/ext/misc/decimal.c
+++ b/ext/misc/decimal.c
@@ -292,22 +292,6 @@ static void decimal_result_sci(sqlite3_context *pCtx, Decimal *p){
}
/*
-** SQL Function: decimal(X)
-**
-** Convert input X into decimal and then back into text
-*/
-static void decimalFunc(
- sqlite3_context *context,
- int argc,
- sqlite3_value **argv
-){
- Decimal *p = decimal_new(context, argv[0], 0, 0);
- UNUSED_PARAMETER(argc);
- decimal_result(context, p);
- decimal_free(p);
-}
-
-/*
** Compare to Decimal objects. Return negative, 0, or positive if the
** first object is less than, equal to, or greater than the second.
**
@@ -399,7 +383,7 @@ static void decimal_expand(Decimal *p, int nDigit, int nFrac){
}
/*
-** Add the value pB into pA.
+** Add the value pB into pA. A := A + B.
**
** Both pA and pB might become denormalized by this routine.
*/
@@ -469,6 +453,200 @@ static void decimal_add(Decimal *pA, Decimal *pB){
}
/*
+** Multiply A by B. A := A * B
+**
+** All significant digits after the decimal point are retained.
+** Trailing zeros after the decimal point are omitted as long as
+** the number of digits after the decimal point is no less than
+** either the number of digits in either input.
+*/
+static void decimalMul(Decimal *pA, Decimal *pB){
+ signed char *acc = 0;
+ int i, j, k;
+ int minFrac;
+
+ if( pA==0 || pA->oom || pA->isNull
+ || pB==0 || pB->oom || pB->isNull
+ ){
+ goto mul_end;
+ }
+ acc = sqlite3_malloc64( pA->nDigit + pB->nDigit + 2 );
+ if( acc==0 ){
+ pA->oom = 1;
+ goto mul_end;
+ }
+ memset(acc, 0, pA->nDigit + pB->nDigit + 2);
+ minFrac = pA->nFrac;
+ if( pB->nFrac<minFrac ) minFrac = pB->nFrac;
+ for(i=pA->nDigit-1; i>=0; i--){
+ signed char f = pA->a[i];
+ int carry = 0, x;
+ for(j=pB->nDigit-1, k=i+j+3; j>=0; j--, k--){
+ x = acc[k] + f*pB->a[j] + carry;
+ acc[k] = x%10;
+ carry = x/10;
+ }
+ x = acc[k] + carry;
+ acc[k] = x%10;
+ acc[k-1] += x/10;
+ }
+ sqlite3_free(pA->a);
+ pA->a = acc;
+ acc = 0;
+ pA->nDigit += pB->nDigit + 2;
+ pA->nFrac += pB->nFrac;
+ pA->sign ^= pB->sign;
+ while( pA->nFrac>minFrac && pA->a[pA->nDigit-1]==0 ){
+ pA->nFrac--;
+ pA->nDigit--;
+ }
+
+mul_end:
+ sqlite3_free(acc);
+}
+
+/*
+** Create a new Decimal object that contains an integer power of 2.
+*/
+static Decimal *decimalPow2(int N){
+ Decimal *pA = 0; /* The result to be returned */
+ Decimal *pX = 0; /* Multiplier */
+ if( N<-20000 || N>20000 ) goto pow2_fault;
+ pA = decimal_new(0, 0, 3, (unsigned char*)"1.0");
+ if( pA==0 || pA->oom ) goto pow2_fault;
+ if( N==0 ) return pA;
+ if( N>0 ){
+ pX = decimal_new(0, 0, 3, (unsigned char*)"2.0");
+ }else{
+ N = -N;
+ pX = decimal_new(0, 0, 3, (unsigned char*)"0.5");
+ }
+ if( pX==0 || pX->oom ) goto pow2_fault;
+ while( 1 /* Exit by break */ ){
+ if( N & 1 ){
+ decimalMul(pA, pX);
+ if( pA->oom ) goto pow2_fault;
+ }
+ N >>= 1;
+ if( N==0 ) break;
+ decimalMul(pX, pX);
+ }
+ decimal_free(pX);
+ return pA;
+
+pow2_fault:
+ decimal_free(pA);
+ decimal_free(pX);
+ return 0;
+}
+
+/*
+** Use an IEEE754 binary64 ("double") to generate a new Decimal object.
+*/
+static Decimal *decimalFromDouble(double r){
+ sqlite3_int64 m, a;
+ int e;
+ int isNeg;
+ Decimal *pA;
+ Decimal *pX;
+ char zNum[100];
+ if( r<0.0 ){
+ isNeg = 1;
+ r = -r;
+ }else{
+ isNeg = 0;
+ }
+ memcpy(&a,&r,sizeof(a));
+ if( a==0 ){
+ e = 0;
+ m = 0;
+ }else{
+ e = a>>52;
+ m = a & ((((sqlite3_int64)1)<<52)-1);
+ if( e==0 ){
+ m <<= 1;
+ }else{
+ m |= ((sqlite3_int64)1)<<52;
+ }
+ while( e<1075 && m>0 && (m&1)==0 ){
+ m >>= 1;
+ e++;
+ }
+ if( isNeg ) m = -m;
+ e = e - 1075;
+ if( e>971 ){
+ return 0; /* A NaN or an Infinity */
+ }
+ }
+
+ /* At this point m is the integer significand and e is the exponent */
+ sqlite3_snprintf(sizeof(zNum), zNum, "%lld", m);
+ pA = decimal_new(0, 0, (int)strlen(zNum), (unsigned char*)zNum);
+ pX = decimalPow2(e);
+ decimalMul(pA, pX);
+ decimal_free(pX);
+ return pA;
+}
+
+/*
+** SQL Function: decimal(X)
+**
+** Convert input X into decimal and then back into text.
+**
+** If X is originally a float, then a full decoding of that floating
+** point value is done. Or if X is an 8-byte blob, it is interpreted
+** as a float and similarly expanded.
+*/
+static void decimalFunc(
+ sqlite3_context *context,
+ int argc,
+ sqlite3_value **argv
+){
+ Decimal *p = 0;
+ UNUSED_PARAMETER(argc);
+ switch( sqlite3_value_type(argv[0]) ){
+ case SQLITE_TEXT:
+ case SQLITE_INTEGER: {
+ p = decimal_new(context, argv[0], 0, 0);
+ break;
+ }
+
+ case SQLITE_FLOAT: {
+ p = decimalFromDouble(sqlite3_value_double(argv[0]));
+ break;
+ }
+
+ case SQLITE_BLOB: {
+ const unsigned char *x;
+ unsigned int i;
+ sqlite3_uint64 v = 0;
+ double r;
+
+ if( sqlite3_value_bytes(argv[0])!=sizeof(r) ) break;
+ x = sqlite3_value_blob(argv[0]);
+ for(i=0; i<sizeof(r); i++){
+ v = (v<<8) | x[i];
+ }
+ memcpy(&r, &v, sizeof(r));
+ p = decimalFromDouble(r);
+ break;
+ }
+
+ case SQLITE_NULL: {
+ break;
+ }
+ }
+ if( p ){
+ if( sqlite3_user_data(context)!=0 ){
+ decimal_result_sci(context, p);
+ }else{
+ decimal_result(context, p);
+ }
+ decimal_free(p);
+ }
+}
+
+/*
** Compare text in decimal order.
*/
static int decimalCollFunc(
@@ -592,11 +770,6 @@ static void decimalSumFinalize(sqlite3_context *context){
** SQL Function: decimal_mul(X, Y)
**
** Return the product of X and Y.
-**
-** All significant digits after the decimal point are retained.
-** Trailing zeros after the decimal point are omitted as long as
-** the number of digits after the decimal point is no less than
-** either the number of digits in either input.
*/
static void decimalMulFunc(
sqlite3_context *context,
@@ -605,67 +778,39 @@ static void decimalMulFunc(
){
Decimal *pA = decimal_new(context, argv[0], 0, 0);
Decimal *pB = decimal_new(context, argv[1], 0, 0);
- signed char *acc = 0;
- int i, j, k;
- int minFrac;
UNUSED_PARAMETER(argc);
if( pA==0 || pA->oom || pA->isNull
|| pB==0 || pB->oom || pB->isNull
){
goto mul_end;
}
- acc = sqlite3_malloc64( pA->nDigit + pB->nDigit + 2 );
- if( acc==0 ){
- sqlite3_result_error_nomem(context);
+ decimalMul(pA, pB);
+ if( pA->oom ){
goto mul_end;
}
- memset(acc, 0, pA->nDigit + pB->nDigit + 2);
- minFrac = pA->nFrac;
- if( pB->nFrac<minFrac ) minFrac = pB->nFrac;
- for(i=pA->nDigit-1; i>=0; i--){
- signed char f = pA->a[i];
- int carry = 0, x;
- for(j=pB->nDigit-1, k=i+j+3; j>=0; j--, k--){
- x = acc[k] + f*pB->a[j] + carry;
- acc[k] = x%10;
- carry = x/10;
- }
- x = acc[k] + carry;
- acc[k] = x%10;
- acc[k-1] += x/10;
- }
- sqlite3_free(pA->a);
- pA->a = acc;
- acc = 0;
- pA->nDigit += pB->nDigit + 2;
- pA->nFrac += pB->nFrac;
- pA->sign ^= pB->sign;
- while( pA->nFrac>minFrac && pA->a[pA->nDigit-1]==0 ){
- pA->nFrac--;
- pA->nDigit--;
- }
decimal_result(context, pA);
mul_end:
- sqlite3_free(acc);
decimal_free(pA);
decimal_free(pB);
}
/*
-** SQL Function: decimal_sci(X)
+** SQL Function: decimal_pow2(N)
**
-** Convert decimal number X into scientific notation ("+N.NNNe+NN").
+** Return the N-th power of 2. N must be an integer.
*/
-static void decimalSciFunc(
+static void decimalPow2Func(
sqlite3_context *context,
int argc,
sqlite3_value **argv
){
- Decimal *pA = decimal_new(context, argv[0], 0, 0);
UNUSED_PARAMETER(argc);
- decimal_result_sci(context, pA);
- decimal_free(pA);
+ if( sqlite3_value_type(argv[0])==SQLITE_INTEGER ){
+ Decimal *pA = decimalPow2(sqlite3_value_int(argv[0]));
+ decimal_result_sci(context, pA);
+ decimal_free(pA);
+ }
}
#ifdef _WIN32
@@ -680,14 +825,16 @@ int sqlite3_decimal_init(
static const struct {
const char *zFuncName;
int nArg;
+ int iArg;
void (*xFunc)(sqlite3_context*,int,sqlite3_value**);
} aFunc[] = {
- { "decimal", 1, decimalFunc },
- { "decimal_cmp", 2, decimalCmpFunc },
- { "decimal_add", 2, decimalAddFunc },
- { "decimal_sub", 2, decimalSubFunc },
- { "decimal_mul", 2, decimalMulFunc },
- { "decimal_sci", 1, decimalSciFunc },
+ { "decimal", 1, 0, decimalFunc },
+ { "decimal_cmp", 2, 0, decimalCmpFunc },
+ { "decimal_add", 2, 0, decimalAddFunc },
+ { "decimal_sub", 2, 0, decimalSubFunc },
+ { "decimal_mul", 2, 0, decimalMulFunc },
+ { "decimal_sci", 1, 1, decimalFunc },
+ { "decimal_pow2", 1, 0, decimalPow2Func },
};
unsigned int i;
(void)pzErrMsg; /* Unused parameter */
@@ -697,7 +844,7 @@ int sqlite3_decimal_init(
for(i=0; i<(int)(sizeof(aFunc)/sizeof(aFunc[0])) && rc==SQLITE_OK; i++){
rc = sqlite3_create_function(db, aFunc[i].zFuncName, aFunc[i].nArg,
SQLITE_UTF8|SQLITE_INNOCUOUS|SQLITE_DETERMINISTIC,
- 0, aFunc[i].xFunc, 0, 0);
+ aFunc[i].iArg ? db : 0, aFunc[i].xFunc, 0, 0);
}
if( rc==SQLITE_OK ){
rc = sqlite3_create_window_function(db, "decimal_sum", 1,