1 /** Algorithms for finding roots and extrema of one-argument real functions
2  * using bracketing.
3  *
4  * Copyright: Copyright (C) 2008 Don Clugston.
5  * License:   BSD style: $(LICENSE), Digital Mars.
6  * Authors:   Don Clugston.
7  *
8  */
9 module tango.math.Bracket;
10 import tango.math.Math;
11 import tango.math.IEEE;
13 private:
// Returns true when a and b have opposite sign.
// The decision is made purely from the values reported by signbit(), so no
// arithmetic is performed on the arguments themselves.
bool oppositeSigns(T)(T a, T b)
{
    auto signOfA = signbit(a);
    auto signOfB = signbit(b);
    return (signOfA ^ signOfB) != 0;
}
// Result of a bracketing search: the two ends of the final interval together
// with the function values recorded for them.
// TODO: This should be exposed publicly, but needs a better name.
struct BracketResult(T, R)
{
    T xlo, xhi;     // lower and upper end of the bracket
    R fxlo, fxhi;   // function values stored for xlo and xhi respectively
}
30 public:
32 /**  Find a real root of the real function f(x) via bracketing.
33  *
34  * Given a range [a..b] such that f(a) and f(b) have opposite sign,
35  * returns the value of x in the range which is closest to a root of f(x).
36  * If f(x) has more than one root in the range, one will be chosen arbitrarily.
37  * If f(x) returns $(NAN), $(NAN) will be returned; otherwise, this algorithm
38  * is guaranteed to succeed. 
39  *  
40  * Uses an algorithm based on TOMS748, which uses inverse cubic interpolation 
41  * whenever possible, otherwise reverting to parabolic or secant
42  * interpolation. Compared to TOMS748, this implementation improves worst-case
43  * performance by a factor of more than 100, and typical performance by a factor
44  * of 2. For 80-bit reals, most problems require 8 - 15 calls to f(x) to achieve
45  * full machine precision. The worst-case performance (pathological cases) is 
46  * approximately twice the number of bits. 
47  *
48  * References: 
49  * "On Enclosing Simple Roots of Nonlinear Equations", G. Alefeld, F.A. Potra, 
50  *   Yixun Shi, Mathematics of Computation 61, pp733-744 (1993).
51  *   Fortran code available from www.netlib.org as algorithm TOMS748.
52  *
53  */
T findRoot(T, R)(scope R delegate(T) f, T ax, T bx)
{
    // Termination test: stop once the two ends of the bracket are adjacent
    // values, i.e. the upper end is nextUp() of the lower end.
    bool endsAreAdjacent(BracketResult!(T,R) br)
    {
        return br.xhi==nextUp(br.xlo);
    }

    // Delegate the search to the bracketing overload, evaluating f at both
    // initial limits first.
    auto finalBracket = findRoot(f, ax, bx, f(ax), f(bx), &endsAreAdjacent);

    // Of the two ends, report the one whose function value is smaller in
    // magnitude (the lower end wins a tie).
    if (fabs(finalBracket.fxlo) <= fabs(finalBracket.fxhi))
        return finalBracket.xlo;
    return finalBracket.xhi;
}
61 private:
/** Find root by bracketing, allowing termination condition to be specified
 *
 * Params:
 * f           The function whose root is sought.
 * ax, bx      The initial bracket. Must satisfy ax <= bx; neither may be NaN.
 * fax, fbx    The function values for ax and bx, as supplied by the caller.
 *             They must have opposite sign.
 * tolerance   Defines the termination condition. Return true when acceptable
 *             bounds have been obtained.
 *
 * Returns:
 * The final bracket and the function values at its two ends. If a function
 * value of exactly zero was found, both ends of the result are set to that
 * point.
 */
BracketResult!(T, R) findRoot(T,R)(scope R delegate(T) f, T ax, T bx, R fax, R fbx,
    scope bool delegate(BracketResult!(T,R) r) tolerance)
in {
    assert(ax<=bx, "Parameters ax and bx out of order.");
    assert(!isNaN(ax) && !isNaN(bx), "Limits must not be NaN");
    assert(oppositeSigns(fax,fbx), "Parameters must bracket the root.");
}
body {
// This code is (heavily) modified from TOMS748 (www.netlib.org). Some ideas
// were borrowed from the Boost Mathematics Library.

    T a = ax, b = bx, d;  // [a..b] is our current bracket.
    R fa = fax, fb = fbx, fd; // d is the third best guess.

    // Test the function at point c; update brackets accordingly
    void bracket(T c)
    {
        // f returns R. Keep the value as R: converting it to T would change
        // it whenever R is wider than T, and every later decision (exact
        // zero test, sign test, interpolation) is based on this value.
        R fc = f(c);
        if (fc == 0) { // Exact solution
            a = c;
            fa = fc;
            d = c;
            fd = fc;
            return;
        }
        // Determine new enclosing interval
        if (oppositeSigns(fa, fc)) {
            // Sign change lies in [a..c]: c becomes the new upper end.
            d = b;
            fd = fb;
            b = c;
            fb = fc;
        } else {
            // Sign change lies in [c..b]: c becomes the new lower end.
            d = a;
            fd = fa;
            a = c;
            fa = fc;
        }
    }

   /* Perform a secant interpolation. If the result would lie on a or b, or if
     a and b differ so wildly in magnitude that the result would be meaningless,
     perform a bisection instead.
     The function values are taken as R (the return type of f), not T.
    */
    T secant_interpolate(T a, T b, R fa, R fb)
    {
        if (( ((a - b) == a) && b!=0) || (a!=0 && ((b - a) == b))) {
            // Catastrophic cancellation
            if (a == 0) a = copysign(0.0L, b);
            else if (b == 0) b = copysign(0.0L, a);
            else if (oppositeSigns(a, b)) return 0;
            T c = ieeeMean(a, b);
            return c;
        }
        // avoid overflow
        if (b - a > T.max)    return b / 2.0 + a / 2.0;
        // fb - fa is an R quantity, so test it against R.max. When it has
        // overflowed the secant formula is unusable; fall back to the
        // midpoint of [a..b] (b - a is known to be finite at this point), so
        // that the returned point stays inside the bracket.
        if (fb - fa > R.max)  return a + (b - a) / 2;
        T c = a - (fa / (fb - fa)) * (b - a);
        if (c == a || c == b) return (a + b) / 2;
        return c;
    }

    /* Uses 'numsteps' newton steps to approximate the zero in [a..b] of the
       quadratic polynomial interpolating f(x) at a, b, and d.
       Returns:
         The approximate zero in [a..b] of the quadratic polynomial.
    */
    T newtonQuadratic(int numsteps)
    {
        // Find the coefficients of the quadratic polynomial.
        T a0 = fa;
        T a1 = (fb - fa)/(b - a);
        T a2 = ((fd - fb)/(d - b) - a1)/(d - a);

        // Determine the starting point of newton steps.
        T c = oppositeSigns(a2, fa) ? a  : b;

        // start the safeguarded newton steps.
        for (int i = 0; i<numsteps; ++i) {
            T pc = a0 + (a1 + a2 * (c - b))*(c - a);
            T pdc = a1 + a2*((2.0 * c) - (a + b));
            // Zero derivative: a Newton step is impossible, so use the zero
            // of the linear part instead.
            if (pdc == 0) return a - a0 / a1;
            else c = c - pc / pdc;
        }
        return c;
    }

    // On the first iteration we take a secant step:
    if(fa != 0) {
        bracket(secant_interpolate(a, b, fa, fb));
    }
    // Starting with the second iteration, higher-order interpolation can
    // be used.
    int itnum = 1;   // Iteration number
    int baditer = 1; // Num bisections to take if an iteration is bad.
    T c, e;  // e is our fourth best guess
    R fe;
whileloop:
    while((fa != 0) && !tolerance(BracketResult!(T,R)(a, b, fa, fb))) {
        T a0 = a, b0 = b; // record the brackets

        // Do two higher-order (cubic or parabolic) interpolation steps.
        for (int QQ = 0; QQ < 2; ++QQ) {
            // Cubic inverse interpolation requires that
            // all four function values fa, fb, fd, and fe are distinct;
            // otherwise use quadratic interpolation.
            bool distinct = (fa != fb) && (fa != fd) && (fa != fe)
                         && (fb != fd) && (fb != fe) && (fd != fe);
            // The first time, cubic interpolation is impossible.
            if (itnum<2) distinct = false;
            bool ok = distinct;
            if (distinct) {
                // Cubic inverse interpolation of f(x) at a, b, d, and e
                real q11 = (d - e) * fd / (fe - fd);
                real q21 = (b - d) * fb / (fd - fb);
                real q31 = (a - b) * fa / (fb - fa);
                real d21 = (b - d) * fd / (fd - fb);
                real d31 = (a - b) * fb / (fb - fa);

                real q22 = (d21 - q11) * fb / (fe - fb);
                real q32 = (d31 - q21) * fa / (fd - fa);
                real d32 = (d31 - q21) * fd / (fd - fa);
                real q33 = (d32 - q22) * fa / (fe - fa);
                c = a + (q31 + q32 + q33);
                if (isNaN(c) || (c <= a) || (c >= b)) {
                    // DAC: If the interpolation predicts a or b, it's
                    // probable that it's the actual root. Only allow this if
                    // we're already close to the root.
                    if (c == a && a - b != a) {
                        c = nextUp(a);
                    }
                    else if (c == b && a - b != -b) {
                        c = nextDown(b);
                    } else {
                        ok = false;
                    }
                }
            }
            if (!ok) {
               c = newtonQuadratic(distinct ? 3 : 2);
               if(isNaN(c) || (c <= a) || (c >= b)) {
                  // Failure, try a secant step:
                  c = secant_interpolate(a, b, fa, fb);
               }
            }
            ++itnum;
            e = d;
            fe = fd;
            bracket(c);
            if((fa == 0) || tolerance(BracketResult!(T,R)(a, b, fa, fb)))
                break whileloop;
            // After the very first higher-order step, restart the outer loop
            // rather than taking the second step.
            if (itnum == 2)
                continue whileloop;
        }
        // Now we take a double-length secant step:
        T u;
        R fu;
        if(fabs(fa) < fabs(fb)) {
             u = a;
             fu = fa;
        } else {
             u = b;
             fu = fb;
        }
        c = u - 2 * (fu / (fb - fa)) * (b - a);
        // DAC: If the secant predicts a value equal to an endpoint, it's
        // probably false.
        if(c==a || c==b || isNaN(c) || fabs(c - u) > (b - a) / 2) {
            if ((a-b) == a || (b-a) == b) {
                // a and b differ hugely in magnitude: bisect in binary space.
                if ( (a>0 && b<0) || (a<0 && b>0) ) c = 0;
                else {
                   if (a==0) c = ieeeMean(copysign(0.0L, b), b);
                   else if (b==0) c = ieeeMean(copysign(0.0L, a), a);
                   else c = ieeeMean(a, b);
                }
            } else {
                c = a + (b - a) / 2;
            }
        }
        e = d;
        fe = fd;
        bracket(c);
        if((fa == 0) || tolerance(BracketResult!(T,R)(a, b, fa, fb)))
            break;

        // We must ensure that the bounds reduce by a factor of 2
        // (DAC: in binary space!) every iteration. If we haven't achieved this
        // yet (DAC: or if we don't yet know what the exponent is),
        // perform a binary chop.

        if( (a==0 || b==0 ||
            (fabs(a) >= 0.5 * fabs(b) && fabs(b) >= 0.5 * fabs(a)))
            &&  (b - a) < 0.25 * (b0 - a0))  {
                baditer = 1;
                continue;
            }
        // DAC: If this happens on consecutive iterations, we probably have a
        // pathological function. Perform a number of bisections equal to the
        // total number of consecutive bad iterations.

        if ((b - a) < 0.25 * (b0 - a0)) baditer=1;
        for (int QQ = 0; QQ < baditer ;++QQ) {
            e = d;
            fe = fd;

            T w;
            if ((a>0 && b<0) ||(a<0 && b>0)) w = 0;
            else {
                T usea = a;
                T useb = b;
                // Give a zero endpoint the sign of the other endpoint before
                // taking the mean.
                if (a == 0) usea = copysign(0.0L, b);
                else if (b == 0) useb = copysign(0.0L, a);
                w = ieeeMean(usea, useb);
            }
            bracket(w);
        }
        ++baditer;
    }

    // An exact zero at either end collapses the reported bracket onto it.
    if (fa == 0) return BracketResult!(T, R)(a, a, fa, fa);
    else if (fb == 0) return BracketResult!(T, R)(b, b, fb, fb);
    else return BracketResult!(T, R)(a, b, fa, fb);
}
292 public:
/**
 * Find the minimum value of the function func().
 *
 * Returns the value of x such that func(x) is minimised. Uses Brent's method,
 * which uses a parabolic fit to rapidly approach the minimum but reverts to a
 * Golden Section search where necessary.
 *
 * The minimum is located to an accuracy of feqrel(min, truemin) <
 * real.mant_dig/2.
 *
 * Parameters:
 *     func         The function to be minimized
 *     xinitial     Initial guess to be used. Must lie in [xlo, xhi].
 *     xlo, xhi     Lower and upper bounds on x.
 *                  func(xinitial) <= func(xlo) and func(xinitial) <= func(xhi)
 *     funcMin      The minimum value of func(x).
 *
 * Note: the precondition itself evaluates func (at xinitial, xlo and xhi)
 * when contracts are enabled.
 */
T findMinimum(T,R)(scope R delegate(T) func, T xlo, T xhi, T xinitial,
     out R funcMin)
in {
    assert(xlo <= xhi);
    assert(xinitial >= xlo);
    assert(xinitial <= xhi);
    assert(func(xinitial) <= func(xlo) && func(xinitial) <= func(xhi));
}
body{
    // Based on the original Algol code by R.P. Brent.
    enum real GOLDENRATIO = 0.3819660112501051; // (3 - sqrt(5))/2 = 1 - 1/phi
    enum real SQRTEPSILON = 3e-10L; // sqrt(real.epsilon)

    T stepBeforeLast = 0.0;
    T lastStep;
    T bestx = xinitial; // the best value so far (min value for f(x)).
    R fbest = func(bestx);
    T second = xinitial;  // the point with the second best value of f(x)
    R fsecond = fbest;
    T third = xinitial;  // the previous value of second.
    R fthird = fbest;
    for (;;) {
        T xmid = 0.5 * (xlo + xhi);
        // Tolerances are relative to the magnitude of the current best x.
        T tol1 = SQRTEPSILON * fabs(bestx);
        T tol2 = 2.0 * tol1;
        // Converged: bestx is close to the midpoint and the interval is small.
        if (fabs(bestx - xmid) <= (tol2 - 0.5*(xhi - xlo)) ) {
            funcMin = fbest;
            return bestx;
        }
        if (fabs(stepBeforeLast) > tol1) {
            // trial parabolic fit
            real r = (bestx - second) * (fbest - fthird);
            // DAC: This can be infinite, in which case lastStep will be NaN.
            real denom = (bestx - third) * (fbest - fsecond);
            real numerator = (bestx - third) * denom - (bestx - second) * r;
            denom = 2.0 * (denom-r);
            if ( denom > 0) numerator = -numerator;
            denom = fabs(denom);
            // is the parabolic fit good enough?
            // it must be a step that is less than half the movement
            // of the step before last, AND it must fall
            // into the bounding interval [xlo,xhi].
            if (fabs(numerator) >= fabs(0.5 * denom * stepBeforeLast)
                || numerator <= denom*(xlo-bestx)
                || numerator >= denom*(xhi-bestx)) {
                // No, use a golden section search instead.
                // Step into the larger of the two segments.
                stepBeforeLast = (bestx >= xmid) ? xlo - bestx : xhi - bestx;
                lastStep = GOLDENRATIO * stepBeforeLast;
            } else {
                // parabola is OK
                stepBeforeLast = lastStep;
                lastStep = numerator/denom;
                real xtest = bestx + lastStep;
                // Do not evaluate too close to either bound: step by tol1
                // towards the midpoint instead.
                if (xtest-xlo < tol2 || xhi-xtest < tol2) {
                    if (xmid-bestx > 0)
                        lastStep = tol1;
                    else lastStep = -tol1;
                }
            }
        } else {
            // Use a golden section search instead
            stepBeforeLast = bestx >= xmid ? xlo - bestx : xhi - bestx;
            lastStep = GOLDENRATIO * stepBeforeLast;
        }
        T xtest;
        // Never take a step shorter than tol1 (a NaN step is replaced too).
        if (fabs(lastStep) < tol1 || isNaN(lastStep)) {
            if (lastStep > 0) lastStep = tol1;
            else lastStep = - tol1;
        }
        xtest = bestx + lastStep;
        // Evaluate the function at point xtest.
        R ftest = func(xtest);

        if (ftest <= fbest) {
            // We have a new best point!
            // The previous best point becomes a limit.
            if (xtest >= bestx) xlo = bestx; else xhi = bestx;
            third = second;  fthird = fsecond;
            second = bestx;  fsecond = fbest;
            bestx = xtest;  fbest = ftest;
        } else {
            // This new point is now one of the limits.
            if (xtest < bestx)  xlo = xtest; else xhi = xtest;
            // Is it a new second best point?
            if (ftest < fsecond || second == bestx) {
                third = second;  fthird = fsecond;
                second = xtest;  fsecond = ftest;
            } else if (ftest <= fthird || third == bestx || third == second) {
                // At least it's our third best point!
                third = xtest;  fthird = ftest;
            }
        }
    }
}
408 private:
409 debug(UnitTest) {
unittest{

    int numProblems = 0;    // number of problems handed to testFindRoot
    int numCalls;           // function evaluations for the current problem

    // Runs the bracketing root finder on f over [x1..x2] until the bracket
    // ends are adjacent, then checks that the result still encloses a sign
    // change (unless the lower end is an exact zero).
    void testFindRoot(scope real delegate(real) f, real x1, real x2) {
        ++numProblems;
        numCalls = 0;
        assert(!(isNaN(x1) || isNaN(x2)));

        bool endsAreAdjacent(BracketResult!(real, real) br) {
            return br.xhi==nextUp(br.xlo);
        }
        auto found = findRoot(f, x1, x2, f(x1), f(x2), &endsAreAdjacent);

        real valueAtLo = f(found.xlo);
        real valueAtHi = f(found.xhi);
        assert(valueAtLo == 0 || oppositeSigns(valueAtLo, valueAtHi));
    }

    // Test functions
    real cubicfn (real x) {
       ++numCalls;
       // Clamp the argument before evaluating the polynomial.
       if (x>float.max) x = float.max;
       if (x<-double.max) x = -double.max;
       // This has a single real root at -59.286543284815
       return 0.386*x*x*x + 23*x*x + 15.7*x + 525.2;
    }
    // Test a function with more than one root.
    real multisine(real x) {
        ++numCalls;
        return sin(x);
    }
    testFindRoot(&multisine, 6, 90);
    testFindRoot(&cubicfn, -100, 100);
    testFindRoot(&cubicfn, -double.max, real.max);

/* Tests from the paper:
 * "On Enclosing Simple Roots of Nonlinear Equations", G. Alefeld, F.A. Potra,
 *   Yixun Shi, Mathematics of Computation 61, pp733-744 (1993).
 */
    // Parameters common to many alefeld tests.
    int n;
    real ale_a, ale_b;

    int powercalls = 0;

    real power(real x) {
        ++numCalls;
        ++powercalls;
        return pow(x, n) + double.min_normal;
    }
    int [] power_nvals = [3, 5, 7, 9, 19, 25];
    // Alefeld paper states that pow(x,n) is a very poor case, where bisection
    // outperforms his method, and gives total numcalls =
    // 921 for bisection (2.4 calls per bit), 1830 for Alefeld (4.76/bit),
    // 2624 for brent (6.8/bit)
    // ... but that is for double, not real80.
    // This poor performance seems mainly due to catastrophic cancellation,
    // which is avoided here by the use of ieeeMean().
    // I get: 231 (0.48/bit).
    // IE this is 10X faster in Alefeld's worst case
    numProblems=0;
    foreach(exponent; power_nvals) {
        n = exponent;
        testFindRoot(&power, -1, 10);
    }

    int powerProblems = numProblems;

    // Tests from Alefeld paper

    int [9] alefeldSums;    // per-function evaluation counters
    real alefeld0(real x){
        ++numCalls;
        ++alefeldSums[0];
        real q =  sin(x) - x/2;
        foreach (int i; 1 .. 20) {
            auto weight = 2*i-5.0;
            real gap = x-i*i;
            q += weight*weight/(gap*gap*gap);
        }
        return q;
    }
    real alefeld1(real x) {
        ++numCalls;
        ++alefeldSums[1];
        return ale_a*x + exp(ale_b * x);
    }
    real alefeld2(real x) {
        ++numCalls;
        ++alefeldSums[2];
        return pow(x, n) - ale_a;
    }
    real alefeld3(real x) {
        ++numCalls;
        ++alefeldSums[3];
        return (1.0 +pow(1.0L-n, 2))*x - pow(1.0L-n*x, 2);
    }
    real alefeld4(real x) {
        ++numCalls;
        ++alefeldSums[4];
        return x*x - pow(1-x, n);
    }

    real alefeld5(real x) {
        ++numCalls;
        ++alefeldSums[5];
        return (1+pow(1.0L-n, 4))*x - pow(1.0L-n*x, 4);
    }

    real alefeld6(real x) {
        ++numCalls;
        ++alefeldSums[6];
        return exp(-n*x)*(x-1.01L) + pow(x, n);
    }

    real alefeld7(real x) {
        ++numCalls;
        ++alefeldSums[7];
        return (n*x-1)/((n-1)*x);
    }

    numProblems=0;
    testFindRoot(&alefeld0, PI_2, PI);
    // One problem per interval between consecutive squares, kept 1e-9 away
    // from the squares themselves.
    for (n=1; n<=10; ++n) {
        testFindRoot(&alefeld0, n*n+1e-9L, (n+1)*(n+1)-1e-9L);
    }

    ale_a = -40;
    ale_b = -1;
    testFindRoot(&alefeld1, -9, 31);
    ale_a = -100;
    ale_b = -2;
    testFindRoot(&alefeld1, -9, 31);
    ale_a = -200;
    ale_b = -3;
    testFindRoot(&alefeld1, -9, 31);

    int [] nvals_3 = [1, 2, 5, 10, 15, 20];
    int [] nvals_5 = [1, 2, 4, 5, 8, 15, 20];
    int [] nvals_6 = [1, 5, 10, 15, 20];
    int [] nvals_7 = [2, 5, 15, 20];

    foreach(evenPower; [4, 6, 8, 10]) {
        n = evenPower;
        ale_a = 0.2;
        testFindRoot(&alefeld2, 0, 5);
        ale_a=1;
        testFindRoot(&alefeld2, 0.95, 4.05);
        testFindRoot(&alefeld2, 0, 1.5);
    }
    foreach(value; nvals_3) {
        n=value;
        testFindRoot(&alefeld3, 0, 1);
    }
    foreach(value; nvals_3) {
        n=value;
        testFindRoot(&alefeld4, 0, 1);
    }
    foreach(value; nvals_5) {
        n=value;
        testFindRoot(&alefeld5, 0, 1);
    }
    foreach(value; nvals_6) {
        n=value;
        testFindRoot(&alefeld6, 0, 1);
    }
    foreach(value; nvals_7) {
        n=value;
        testFindRoot(&alefeld7, 0.01L, 1);
    }

    // Step function over the whole range of real.
    real worstcase(real x) {
        ++numCalls;
        return x<0.3*real.max? -0.999e-3 : 1.0;
    }
    testFindRoot(&worstcase, -real.max, real.max);

/*
   Disabled statistics output:

   int grandtotal=0;
   foreach(calls; alefeldSums) {
       grandtotal+=calls;
   }
   grandtotal-=2*numProblems;
   printf("\nALEFELD TOTAL = %d avg = %f (alefeld avg=19.3 for double)\n",
   grandtotal, (1.0*grandtotal)/numProblems);
   powercalls -= 2*powerProblems;
   printf("POWER TOTAL = %d avg = %f ", powercalls,
        (1.0*powercalls)/powerProblems);
*/
}
unittest {
    // Starts at -4; the counter is only incremented here, never checked.
    int numcalls=-4;
    // Extremely well-behaved function.
    real parab(real bestx) {
        ++numcalls;
        real offset = bestx-7.14L;
        return 3 * offset * offset + 18;
    }
    real minval;
    real minx;
    // Note, performs extremely poorly if we have an overflow, so that the
    // function returns infinity. It might be better to explicitly deal with
    // that situation (all parabolic fits will fail whenever an infinity is
    // present).
    real searchLimit = sqrt(real.max);
    real startAt = cast(real)(float.max);
    minx = findMinimum(&parab, -searchLimit, searchLimit, startAt, minval);
    assert(minval==18);
    assert(feqrel(minx,7.14L)>=float.mant_dig);

    // Problems from Jack Crenshaw's "World's Best Root Finder"
    // http://www.embedded.com/columns/programmerstoolbox/9900609
    // This has a minimum of cbrt(0.5).
    real crenshawcos(real x) {
        return cos(2*PI*x*x*x);
    }
    minx = findMinimum(&crenshawcos, 0.0L, 1.0L, 0.1L, minval);
    real minxCubed = minx*minx*minx;
    assert(feqrel(minxCubed, 0.5L)<=real.mant_dig-4);

}
615 }