]> git.kaiwu.me - klib.git/commitdiff
added Brent's root finding; not tested
authorHeng Li <lh3@me.com>
Wed, 6 Dec 2017 20:28:11 +0000 (15:28 -0500)
committerHeng Li <lh3@me.com>
Wed, 6 Dec 2017 20:28:11 +0000 (15:28 -0500)
kmath.c

diff --git a/kmath.c b/kmath.c
index 1a90044e942c49c00a6133957efc4b570237d056..e51b3a78bc07c8b20d3e32879d8397e0aa40bea9 100644 (file)
--- a/kmath.c
+++ b/kmath.c
@@ -286,6 +286,76 @@ double kmin_brent(kmin1_f func, double a, double b, void *data, double tol, doub
        return fb;
 }
 
+static inline float SIGN(float a, float b)
+{
+       return b >= 0 ? (a >= 0 ? a : -a) : (a >= 0 ? -a : a);
+}
+
+double krf_brent(double x1, double x2, double tol, double (*func)(double, void*), void *data, int *err)
+{
+       const int max_iter = 100;
+       const double eps = 3e-8f;
+       int i;
+       double a = x1, b = x2, c = x2, d, e, min1, min2;
+       double fa, fb, fc, p, q, r, s, tol1, xm;
+
+       *err = 0;
+       fa = func(a, data), fb = func(b, data);
+       if ((fa > 0.0f && fb > 0.0f) || (fa < 0.0f && fb < 0.0f)) {
+               *err = -1;
+               return 0.0f;
+       }
+       fc = fb;
+       for (i = 0; i < max_iter; ++i) {
+               if ((fb > 0.0f && fc > 0.0f) || (fb < 0.0f && fc < 0.0f)) {
+                       c = a;
+                       fc = fa;
+                       e = d = b - a;
+               }
+               if (fabs(fc) < fabs(fb)) {
+                       a = b, b = c, c = a;
+                       fa = fb, fb = fc, fc = fa;
+               }
+               tol1 = 2.0f * eps * fabs(b) + 0.5f * tol;
+               xm = 0.5f * (c - b);
+               if (fabs(xm) <= tol1 || fb == 0.0f)
+                       return b;
+               if (fabs(e) >= tol1 && fabs(fa) > fabs(fb)) {
+                       s = fb / fa;
+                       if (a == c) {
+                               p = 2.0f * xm * s;
+                               q = 1.0f - s;
+                       } else {
+                               q = fa / fc;
+                               r = fb / fc;
+                               p = s * (2.0f * xm * q * (q - r) - (b - a) * (r - 1.0f));
+                               q = (q - 1.0f) * (r - 1.0f) * (s - 1.0f);
+                       }
+                       if (p > 0.0f) q = -q;
+                       p = fabs(p);
+                       min1 = 3.0f * xm * q - fabs(tol1 * q);
+                       min2 = fabs(e * q);
+                       if (2.0f * p < (min1 < min2 ? min1 : min2)) {
+                               e = d;
+                               d = p / q;
+                       } else {
+                               d = xm;
+                               e = d;
+                       }
+               } else {
+                       d = xm;
+                       e = d;
+               }
+               a = b;
+               fa = fb;
+               if (fabs(d) > tol1) b += d;
+               else b += SIGN(tol1, xm);
+               fb = func(b, data);
+       }
+       *err = -2;
+       return 0.0;
+}
+
 /*************************
  *** Special functions ***
  *************************/