Solving a cubic equation

后端 未结 5 1379
后悔当初
后悔当初 2020-12-16 03:24

As part of a program I'm writing, I need to solve a cubic equation exactly (rather than using a numerical root finder):

a*x**3 + b*x**2 + c*x + d = 0.


        
相关标签:
5条回答
  • 2020-12-16 03:55

    Here is a solver for cubic equations with complex coefficients.

    #include <cmath>
    #include <cstdlib>
    #include <fstream>
    #include <iostream>
    #include <string>
    
    using namespace std;
    
    // Pi written out beyond long double precision. The solver works in
    // long double, so a 7-digit constant would cap the accuracy of every
    // angle-based result at roughly 1e-7.
    const long double PI = 3.141592653589793238462643383279502884L;
    
    // Real part of the product (xr + i*xi) * (yr + i*yi).
    long double complex_multiply_r(long double xr, long double xi, long double yr, long double yi) {
        const long double real_times_real = xr * yr;
        const long double imag_times_imag = xi * yi;
        return real_times_real - imag_times_imag;
    }
    
    // Imaginary part of the product (xr + i*xi) * (yr + i*yi).
    long double complex_multiply_i(long double xr, long double xi, long double yr, long double yi) {
        const long double real_times_imag = xr * yi;
        const long double imag_times_real = xi * yr;
        return real_times_imag + imag_times_real;
    }
    
    // Real part of the product x * y * z of three complex numbers, each given
    // as a (real, imaginary) pair.
    long double complex_triple_multiply_r(long double xr, long double xi, long double yr, long double yi, long double zr, long double zi) {
        // The four monomials of the expanded product that contain an even
        // number of imaginary factors.
        const long double rrr = xr * yr * zr;
        const long double iir = xi * yi * zr;
        const long double rii = xr * yi * zi;
        const long double iri = xi * yr * zi;
        return rrr - iir - rii - iri;
    }
    
    // Imaginary part of the product x * y * z of three complex numbers, each
    // given as a (real, imaginary) pair.
    long double complex_triple_multiply_i(long double xr, long double xi, long double yr, long double yi, long double zr, long double zi) {
        // The four monomials of the expanded product that contain an odd
        // number of imaginary factors.
        const long double rri = xr * yr * zi;
        const long double iii = xi * yi * zi;
        const long double rir = xr * yi * zr;
        const long double irr = xi * yr * zr;
        return rri - iii + rir + irr;
    }
    
    // Real part of the product x * y * z * w of four complex numbers, each
    // given as a (real, imaginary) pair.
    long double complex_quadraple_multiply_r(long double xr, long double xi, long double yr, long double yi, long double zr, long double zi, long double wr, long double wi) {
        // Pair the factors as (x*y) and (z*w), then multiply the partial products.
        const long double xy_re = complex_multiply_r(xr, xi, yr, yi);
        const long double xy_im = complex_multiply_i(xr, xi, yr, yi);
        const long double zw_re = complex_multiply_r(zr, zi, wr, wi);
        const long double zw_im = complex_multiply_i(zr, zi, wr, wi);
        return complex_multiply_r(xy_re, xy_im, zw_re, zw_im);
    }
    
    // Imaginary part of the product x * y * z * w of four complex numbers,
    // each given as a (real, imaginary) pair.
    long double complex_quadraple_multiply_i(long double xr, long double xi, long double yr, long double yi, long double zr, long double zi, long double wr, long double wi) {
        // Pair the factors as (x*y) and (z*w), then multiply the partial products.
        const long double xy_re = complex_multiply_r(xr, xi, yr, yi);
        const long double xy_im = complex_multiply_i(xr, xi, yr, yi);
        const long double zw_re = complex_multiply_r(zr, zi, wr, wi);
        const long double zw_im = complex_multiply_i(zr, zi, wr, wi);
        return complex_multiply_i(xy_re, xy_im, zw_re, zw_im);
    }
    
    // Real part of the quotient (xr + i*xi) / (yr + i*yi).
    // No check is made for a zero divisor.
    long double complex_divide_r(long double xr, long double xi, long double yr, long double yi) {
        // x * conj(y) has real part xr*yr + xi*yi; |y|^2 is the divisor.
        const long double numerator = xr * yr + xi * yi;
        const long double divisor_norm = yr * yr + yi * yi;
        return numerator / divisor_norm;
    }
    
    // Imaginary part of the quotient (xr + i*xi) / (yr + i*yi).
    // No check is made for a zero divisor.
    long double complex_divide_i(long double xr, long double xi, long double yr, long double yi) {
        // x * conj(y) has imaginary part -xr*yi + xi*yr; |y|^2 is the divisor.
        const long double numerator = -xr * yi + xi * yr;
        const long double divisor_norm = yr * yr + yi * yi;
        return numerator / divisor_norm;
    }
    
    // Real part of a square root of xr + i*xi.
    //
    // The argument theta of the input is taken in [0, 2*pi), so the root
    // returned is sqrt(r) * exp(i*theta/2) with argument in [0, pi).
    // Returns 0 when the input is zero.
    long double complex_root_r(long double xr, long double xi) {
        const long double r = std::sqrt(xr * xr + xi * xi);
        if (r == 0.0L) {
            return 0.0L;
        }
        // atan2 resolves the quadrant directly, including xr == 0 and signed
        // zeros, which a quadrant test on atan(xi / xr) gets wrong for
        // xr == -0.0.
        long double theta = std::atan2(xi, xr);
        if (theta < 0.0L) {
            // Shift (-pi, 0) to (pi, 2*pi); acos(-1) is pi at full long double precision.
            theta += 2.0L * std::acos(-1.0L);
        }
        return std::sqrt(r) * std::cos(theta / 2.0L);
    }
    
    // Imaginary part of a square root of xr + i*xi.
    //
    // The argument theta of the input is taken in [0, 2*pi), so the root
    // returned is sqrt(r) * exp(i*theta/2) with argument in [0, pi); the
    // imaginary part is therefore never negative. Returns 0 when the input
    // is zero.
    long double complex_root_i(long double xr, long double xi) {
        const long double r = std::sqrt(xr * xr + xi * xi);
        if (r == 0.0L) {
            return 0.0L;
        }
        // atan2 resolves the quadrant directly, including xr == 0 and signed
        // zeros, which a quadrant test on atan(xi / xr) gets wrong for
        // xr == -0.0.
        long double theta = std::atan2(xi, xr);
        if (theta < 0.0L) {
            // Shift (-pi, 0) to (pi, 2*pi); acos(-1) is pi at full long double precision.
            theta += 2.0L * std::acos(-1.0L);
        }
        return std::sqrt(r) * std::sin(theta / 2.0L);
    }
    
    // Real part of a cube root of xr + i*xi.
    //
    // The argument theta of the input is taken in [0, 2*pi), so the root
    // returned is r^(1/3) * exp(i*theta/3) with argument in [0, 2*pi/3).
    // Returns 0 when the input is zero.
    long double complex_cuberoot_r(long double xr, long double xi) {
        const long double r = std::sqrt(xr * xr + xi * xi);
        if (r == 0.0L) {
            return 0.0L;
        }
        // atan2 resolves the quadrant directly, including xr == 0 and signed
        // zeros, which a quadrant test on atan(xi / xr) gets wrong for
        // xr == -0.0.
        long double theta = std::atan2(xi, xr);
        if (theta < 0.0L) {
            // Shift (-pi, 0) to (pi, 2*pi); acos(-1) is pi at full long double precision.
            theta += 2.0L * std::acos(-1.0L);
        }
        // Long double exponent so the modulus is not limited to double precision.
        return std::pow(r, 1.0L / 3.0L) * std::cos(theta / 3.0L);
    }
    
    // Imaginary part of a cube root of xr + i*xi.
    //
    // The argument theta of the input is taken in [0, 2*pi), so the root
    // returned is r^(1/3) * exp(i*theta/3) with argument in [0, 2*pi/3).
    // Returns 0 when the input is zero.
    long double complex_cuberoot_i(long double xr, long double xi) {
        const long double r = std::sqrt(xr * xr + xi * xi);
        if (r == 0.0L) {
            return 0.0L;
        }
        // atan2 resolves the quadrant directly, including xr == 0 and signed
        // zeros, which a quadrant test on atan(xi / xr) gets wrong for
        // xr == -0.0.
        long double theta = std::atan2(xi, xr);
        if (theta < 0.0L) {
            // Shift (-pi, 0) to (pi, 2*pi); acos(-1) is pi at full long double precision.
            theta += 2.0L * std::acos(-1.0L);
        }
        // Long double exponent so the modulus is not limited to double precision.
        return std::pow(r, 1.0L / 3.0L) * std::sin(theta / 3.0L);
    }
    
    // Interactive solver for a*x^3 + b*x^2 + c*x + d = 0 with complex
    // coefficients.
    //
    // Prompts for the real and imaginary part of a, b, c and d on standard
    // input and prints the real and imaginary parts of the three roots as
    // "x1r"/"x1i", "x2r"/"x2i", "x3r"/"x3i". Finally waits for one more
    // token on standard input before exiting.
    //
    // Returns 0 on success and 1 when a == 0 (the equation is not cubic).
    //
    // Cases handled:
    //   * b == c == 0:          x^3 = -d/a, the three cube roots of -d/a.
    //   * term1 == term2 == 0:  triple root x = -b/(3a).
    //   * otherwise:            the general cubic formula
    //         x_k = w^k * u - w^(2k) * v - b/(3a),  k = 0, 1, 2
    //     with C = cuberoot(term1 + sqrt(term1^2 + 4*term2^3)),
    //          u = C / (3 * 2^(1/3) * a),  v = 2^(1/3) * term2 / (3 * a * C),
    //          w = -1/2 + i*sqrt(3)/2.
    int main() {
        long double a[2] = {0.0L, 0.0L};
        long double b[2] = {0.0L, 0.0L};
        long double c[2] = {0.0L, 0.0L};
        long double d[2] = {0.0L, 0.0L};
        std::cout << "ar?";
        std::cin >> a[0];
        std::cout << "ai?";
        std::cin >> a[1];
        std::cout << "br?";
        std::cin >> b[0];
        std::cout << "bi?";
        std::cin >> b[1];
        std::cout << "cr?";
        std::cin >> c[0];
        std::cout << "ci?";
        std::cin >> c[1];
        std::cout << "dr?";
        std::cin >> d[0];
        std::cout << "di?";
        std::cin >> d[1];

        // Every formula below divides by a; with a == 0 the equation is not cubic.
        if (a[0] == 0.0L && a[1] == 0.0L) {
            std::cerr << "a must be non-zero: the equation is not cubic\n";
            return 1;
        }

        // Primitive cube root of unity w = -1/2 + i*sqrt(3)/2.
        // Its square is the conjugate (w_re, -w_im).
        const long double w_re = -0.5L;
        const long double w_im = std::sqrt(3.0L) / 2.0L;

        if (b[0] == 0.0L && b[1] == 0.0L && c[0] == 0.0L && c[1] == 0.0L) {
            if (d[0] == 0.0L && d[1] == 0.0L) {
                // a*x^3 = 0: zero is a triple root.
                std::cout << "x1r: 0.0 \n";
                std::cout << "x1i: 0.0 \n";
                std::cout << "x2r: 0.0 \n";
                std::cout << "x2i: 0.0 \n";
                std::cout << "x3r: 0.0 \n";
                std::cout << "x3i: 0.0 \n";
            }
            else {
                // a*x^3 + d = 0  =>  x^3 = -d/a.
                // "0 - value" is used instead of unary minus so a zero part
                // stays +0.0; a negative zero could select the wrong quadrant
                // when the argument of the number is computed.
                long double q[2], x1[2];
                q[0] = 0.0L - complex_divide_r(d[0], d[1], a[0], a[1]);
                q[1] = 0.0L - complex_divide_i(d[0], d[1], a[0], a[1]);
                x1[0] = complex_cuberoot_r(q[0], q[1]);
                x1[1] = complex_cuberoot_i(q[0], q[1]);
                // The other two cube roots are x1*w and x1*w^2.
                std::cout << "x1r: " << x1[0] << "\n";
                std::cout << "x1i: " << x1[1] << "\n";
                std::cout << "x2r: " << complex_multiply_r(x1[0], x1[1], w_re, w_im) << "\n";
                std::cout << "x2i: " << complex_multiply_i(x1[0], x1[1], w_re, w_im) << "\n";
                std::cout << "x3r: " << complex_multiply_r(x1[0], x1[1], w_re, -w_im) << "\n";
                std::cout << "x3i: " << complex_multiply_i(x1[0], x1[1], w_re, -w_im) << "\n";
            }
        }
        else {
            long double term1[2], term2[2];

            // term1 = 9abc - 27a^2 d - 2b^3
            term1[0] = -27.0L * complex_triple_multiply_r(a[0], a[1], a[0], a[1], d[0], d[1]);
            term1[1] = -27.0L * complex_triple_multiply_i(a[0], a[1], a[0], a[1], d[0], d[1]);
            term1[0] += 9.0L * complex_triple_multiply_r(a[0], a[1], b[0], b[1], c[0], c[1]);
            term1[1] += 9.0L * complex_triple_multiply_i(a[0], a[1], b[0], b[1], c[0], c[1]);
            term1[0] -= 2.0L * complex_triple_multiply_r(b[0], b[1], b[0], b[1], b[0], b[1]);
            term1[1] -= 2.0L * complex_triple_multiply_i(b[0], b[1], b[0], b[1], b[0], b[1]);

            // term2 = 3ac - b^2
            term2[0] = 3.0L * complex_multiply_r(a[0], a[1], c[0], c[1]);
            term2[1] = 3.0L * complex_multiply_i(a[0], a[1], c[0], c[1]);
            term2[0] -= complex_multiply_r(b[0], b[1], b[0], b[1]);
            term2[1] -= complex_multiply_i(b[0], b[1], b[0], b[1]);

            // shift = b / (3a), subtracted from every root.
            const long double shift_re = complex_divide_r(b[0], b[1], 3.0L * a[0], 3.0L * a[1]);
            const long double shift_im = complex_divide_i(b[0], b[1], 3.0L * a[0], 3.0L * a[1]);

            if (term1[0] == 0.0L && term1[1] == 0.0L && term2[0] == 0.0L && term2[1] == 0.0L) {
                // Both invariants vanish: a*(x + b/(3a))^3 = 0, a triple root.
                const long double root_re = 0.0L - shift_re;
                const long double root_im = 0.0L - shift_im;
                std::cout << "x1r: " << root_re << "\n";
                std::cout << "x1i: " << root_im << "\n";
                std::cout << "x2r: " << root_re << "\n";
                std::cout << "x2i: " << root_im << "\n";
                std::cout << "x3r: " << root_re << "\n";
                std::cout << "x3i: " << root_im << "\n";
            }
            else {
                // root = sqrt(term1^2 + 4*term2^3)
                long double radicand[2], root[2];
                radicand[0] = complex_multiply_r(term1[0], term1[1], term1[0], term1[1]);
                radicand[1] = complex_multiply_i(term1[0], term1[1], term1[0], term1[1]);
                radicand[0] += 4.0L * complex_triple_multiply_r(term2[0], term2[1], term2[0], term2[1], term2[0], term2[1]);
                radicand[1] += 4.0L * complex_triple_multiply_i(term2[0], term2[1], term2[0], term2[1], term2[0], term2[1]);
                root[0] = complex_root_r(radicand[0], radicand[1]);
                root[1] = complex_root_i(radicand[0], radicand[1]);

                // C^3 = term1 + root. Either sign of the square root is valid
                // in the formula, so when this choice is exactly zero (which
                // would make v below divide by zero) take term1 - root.
                long double s[2] = {term1[0] + root[0], term1[1] + root[1]};
                if (s[0] == 0.0L && s[1] == 0.0L) {
                    s[0] = term1[0] - root[0];
                    s[1] = term1[1] - root[1];
                }
                const long double C_re = complex_cuberoot_r(s[0], s[1]);
                const long double C_im = complex_cuberoot_i(s[0], s[1]);
                const long double cbrt2 = std::pow(2.0L, 1.0L / 3.0L);

                // u = C / (3 * 2^(1/3) * a)
                const long double u_re = complex_divide_r(C_re, C_im, 3.0L * cbrt2 * a[0], 3.0L * cbrt2 * a[1]);
                const long double u_im = complex_divide_i(C_re, C_im, 3.0L * cbrt2 * a[0], 3.0L * cbrt2 * a[1]);

                // v = 2^(1/3) * term2 / (3 * a * C)
                const long double aC_re = 3.0L * complex_multiply_r(a[0], a[1], C_re, C_im);
                const long double aC_im = 3.0L * complex_multiply_i(a[0], a[1], C_re, C_im);
                const long double v_re = complex_divide_r(cbrt2 * term2[0], cbrt2 * term2[1], aC_re, aC_im);
                const long double v_im = complex_divide_i(cbrt2 * term2[0], cbrt2 * term2[1], aC_re, aC_im);

                // x1 = u - v - shift
                std::cout << "x1r: " << u_re - v_re - shift_re << "\n";
                std::cout << "x1i: " << u_im - v_im - shift_im << "\n";

                // x2 = w*u - w^2*v - shift
                std::cout << "x2r: " << complex_multiply_r(w_re, w_im, u_re, u_im) - complex_multiply_r(w_re, -w_im, v_re, v_im) - shift_re << "\n";
                std::cout << "x2i: " << complex_multiply_i(w_re, w_im, u_re, u_im) - complex_multiply_i(w_re, -w_im, v_re, v_im) - shift_im << "\n";

                // x3 = w^2*u - w*v - shift
                std::cout << "x3r: " << complex_multiply_r(w_re, -w_im, u_re, u_im) - complex_multiply_r(w_re, w_im, v_re, v_im) - shift_re << "\n";
                std::cout << "x3i: " << complex_multiply_i(w_re, -w_im, u_re, u_im) - complex_multiply_i(w_re, w_im, v_re, v_im) - shift_im << "\n";
            }
        }

        // Wait for one more token so a console window stays open until the
        // user has read the output.
        int end;
        std::cin >> end;
        return 0;
    }
    
    0 讨论(0)
  • 2020-12-16 03:59

    Wikipedia's notation (rho^(1/3), theta/3) does not mean that rho^(1/3) is the real part and theta/3 is the imaginary part. Rather, this is in polar coordinates. Thus, if you want the real part, you would take rho^(1/3) * cos(theta/3).

    I made these changes to your code and it worked for me:

    # Angle of the polar form; r and rho come from the surrounding program (not shown here).
    theta = arccos(r/rho)
    # Real part of a number in polar form (rho^(1/3), +/-theta/3) is rho^(1/3) * cos(+/-theta/3).
    s_real = rho**(1./3.) * cos( theta/3)
    t_real = rho**(1./3.) * cos(-theta/3)
    

    (Of course, s_real = t_real here because cos is even.)

    0 讨论(0)
  • 2020-12-16 04:15

    I've looked at the Wikipedia article and your program.

    I also solved the equation using Wolfram Alpha and the results there don't match what you get.

    I'd just go through your program at each step, use a lot of print statements, and get each intermediate result. Then go through with a calculator and do it yourself.

    I can't find what's happening, but where your hand calculations and the program diverge is a good place to look.

    0 讨论(0)
  • 2020-12-16 04:19

    In case someone needs C++ code, you can use this piece of OpenCV:

    https://github.com/opencv/opencv/blob/master/modules/calib3d/src/polynom_solver.cpp

    0 讨论(0)
  • 2020-12-16 04:22

    Here's A. Rex's solution in JavaScript:

    // Finds one real root of a*x^3 + b*x^2 + c*x + d = 0 using the polar
    // (trigonometric) form of Cardano's formula, then prints the residual.
    // Every name is declared with const: assigning to undeclared names creates
    // implicit globals and throws a ReferenceError in strict mode / ES modules.
    const a =  1.0;
    const b =  0.0;
    const c =  0.2 - 1.0;
    const d = -0.7 * 0.2;

    const q = (3*a*c - Math.pow(b, 2)) / (9 * Math.pow(a, 2));
    const r = (9*a*b*c - 27*Math.pow(a, 2)*d - 2*Math.pow(b, 3)) / (54*Math.pow(a, 3));
    console.log("q = "+q);
    console.log("r = "+r);

    const delta = Math.pow(q, 3) + Math.pow(r, 2);
    console.log("delta = "+delta);

    // For these coefficients delta is less than zero, so the polar-form
    // equations apply. With delta >= 0 the expressions below are not valid:
    // r/rho leaves the domain of acos (or q^3 >= 0 makes rho NaN or zero).
    const rho = Math.pow((-Math.pow(q, 3)), 0.5);
    const theta = Math.acos(r/rho);

    // For x1 the imaginary parts of s and t cancel, so only the real parts are needed.
    const s_real = Math.pow(rho, (1./3.)) * Math.cos( theta/3);
    const t_real = Math.pow(rho, (1./3.)) * Math.cos(-theta/3);

    console.log("s [real] = "+s_real);
    console.log("t [real] = "+t_real);

    const x1 = s_real + t_real - b / (3. * a);

    console.log("x1 = "+x1);
    // Residual of the cubic at x1; printed as a sanity check.
    console.log("should be zero: "+(a*Math.pow(x1, 3)+b*Math.pow(x1, 2)+c*x1+d));
    
    0 讨论(0)
提交回复
热议问题