How to enumerate x^2 + y^2 = z^2 - 1 (with additional constraints)

野性不改 2021-02-04 09:33

Let N be a number (10 <= N <= 10^5).

I have to break it into 3 numbers (x, y, z) that satisfy the following conditions:

1. x <= y <= z
2. x^2 + y^2 = z^2 - 1
3. x + y + z <= N

6 answers
  • 2021-02-04 10:09

    The bounds of x and y are an important part of the problem. I personally went with this Wolfram Alpha query and checked the exact forms of the variables.
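    The closed forms for maxX and maxY used in the code below can also be derived by hand from x + y + z <= N with z = sqrt(x^2 + y^2 + 1) (N is n in the code). This is a hand derivation sketched for clarity, not the Wolfram Alpha output itself. Requiring y_max(x) >= x keeps the inner loop non-empty, and the last step keeps the smaller root because x < N:

```latex
\begin{aligned}
x + y + \sqrt{x^2 + y^2 + 1} \le N
  &\iff x^2 + y^2 + 1 \le (N - x - y)^2 \quad (\text{with } N - x - y \ge 0)\\
  &\iff 2(N - x)\,y \le N^2 - 2Nx - 1\\
  &\iff y \le \frac{N^2 - 2Nx - 1}{2(N - x)} = y_{\max}(x),\\[4pt]
y_{\max}(x) \ge x
  &\iff 2x^2 - 4Nx + N^2 - 1 \ge 0\\
  &\iff x \le N - \sqrt{\frac{N^2 + 1}{2}} = \frac{1}{2}\left(2N - \sqrt{2}\,\sqrt{N^2 + 1}\right) = x_{\max}.
\end{aligned}
```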

    Thanks to @Bleep-Bloop and comments, a very elegant bound optimization was found, which is x < n and x <= y < n - x. The results are the same and the times are nearly identical.

    Also, since the only possible values for x and y are positive even integers, we can cut the number of loop iterations in half.

    To optimize even further, since we compute the upper bound of x, we build a list of all possible values for x and make the computation parallel. That saves a massive amount of time on higher values of N but it's a bit slower for smaller values because of the overhead of the parallelization.

    Here's the final code:

    Non-parallel version, with int values:

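    List<string> res = new List<string>();
    int n2 = n * n;  // n is the input N; int only works while n * n fits
    
    // Same value as 0.5 * (2n - sqrt(2) * sqrt(n^2 + 1)), but exact when the root is an integer
    double maxX = 0.5 * (2.0 * n - Math.Sqrt(2.0 * (n2 + 1.0)));
    
    // <= so that a triple with x == y and x + y + z == n is not skipped
    for (int x = 2; x <= maxX; x += 2)
    {
        int maxY = (int)Math.Floor((n2 - 2.0 * n * x - 1.0) / (2.0 * n - 2.0 * x));
    
        for (int y = x; y <= maxY; y += 2)
        {
            int z2 = x * x + y * y + 1;
            int z = (int)Math.Sqrt(z2);
    
            if (z * z == z2 && x + y + z <= n)
                res.Add(x + "," + y + "," + z);
        }
    }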
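    For cross-checking counts outside C#, here is a hypothetical Python translation of the bounded loop (the name enumerate_bounded is mine, not the author's). It assumes Python 3.8+ for math.isqrt and computes maxY with exact integer floor division instead of Math.Floor on doubles:

```python
import math


def enumerate_bounded(n):
    """Return all (x, y, z) with x <= y <= z, x^2 + y^2 = z^2 - 1, x + y + z <= n.

    Mirrors the C# loop: x and y step over even values only, with closed-form limits.
    """
    res = []
    # x <= n - sqrt((n^2 + 1) / 2), written so the root is exact when it is an integer
    max_x = n - math.sqrt(2 * (n * n + 1)) / 2
    x = 2
    while x <= max_x:
        # y <= (n^2 - 2nx - 1) / (2(n - x)); // is exact floor division on integers
        max_y = (n * n - 2 * n * x - 1) // (2 * (n - x))
        for y in range(x, max_y + 1, 2):
            z2 = x * x + y * y + 1
            z = math.isqrt(z2)  # exact integer square root
            if z * z == z2 and x + y + z <= n:
                res.append((x, y, z))
        x += 2
    return res


print(len(enumerate_bounded(100)))  # 6
```

    For n = 41 the x == y triple (12, 12, 17) sits exactly on the x bound, which is why the comparisons use <= rather than <.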
    

    Parallel version, with long values:

    using System.Collections.Concurrent;  // ConcurrentBag
    using System.Threading.Tasks;         // Parallel.ForEach
    
    ...
    
    // Use ConcurrentBag for thread safety
    ConcurrentBag<string> res = new ConcurrentBag<string>();
    long n2 = (long)n * n;  // widen first: n * n in int arithmetic overflows for n > 46340
    
    double maxX = 0.5 * (2.0 * n - Math.Sqrt(2.0 * (n2 + 1L)));
    
    // Build list to parallelize
    List<int> xList = new List<int>();
    for (int x = 2; x <= maxX; x += 2)
        xList.Add(x);
    
    Parallel.ForEach(xList, x =>
    {
        long maxY = (long)Math.Floor((n2 - 2.0 * n * x - 1.0) / (2.0 * n - 2.0 * x));
    
        for (long y = x; y <= maxY; y += 2)
        {
            long z2 = (long)x * x + y * y + 1L;
            long z = (long)Math.Sqrt(z2);
    
            if (z * z == z2 && x + y + z <= n)
                res.Add(x + "," + y + "," + z);
        }
    });
    

    When run individually on an i5-8400 CPU, I get these results:

    N: 10; Solutions: 1; Time elapsed: 0.03 ms (Not parallel, int)

    N: 100; Solutions: 6; Time elapsed: 0.05 ms (Not parallel, int)

    N: 1000; Solutions: 55; Time elapsed: 0.3 ms (Not parallel, int)

    N: 10000; Solutions: 543; Time elapsed: 13.1 ms (Not parallel, int)

    N: 100000; Solutions: 5512; Time elapsed: 849.4 ms (Parallel, long)


    You must use long when N is greater than 46340, because N squared then overflows int's maximum value (2^31 - 1). Finally, the parallel version starts to beat the simple one when N is around 23000, with ints.
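    The threshold follows from C#'s int.MaxValue = 2^31 - 1: N * N fits only while N <= floor(sqrt(2^31 - 1)). A quick check in Python, assuming 3.8+ for math.isqrt:

```python
import math

INT_MAX = 2**31 - 1  # value of C#'s int.MaxValue

# Largest N whose square still fits in a signed 32-bit int
largest_safe_n = math.isqrt(INT_MAX)

print(largest_safe_n)  # 46340
print(largest_safe_n ** 2, (largest_safe_n + 1) ** 2)  # 2147395600 2147488281
```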

  • 2021-02-04 10:13
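    #include <cmath>
    #include <iostream>
    
    int main()
    {
        // long long keeps x*x + y*y + 1 from overflowing for N up to 10^5
        long long N = 10000;
        long long c = 0;
        for (long long x = 2; x < N; x += 2)
        {
            for (long long y = x; y < (N - x); y += 2)
            {
                double z = std::sqrt(static_cast<double>(x * x + y * y + 1));
                if (x + y + z > N)
                {
                    break;  // x + y + z only grows with y
                }
                if (z == std::floor(z))  // integer root means a perfect square
                {
                    c++;
                }
            }
        }
        std::cout << c << '\n';
    }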
    

    This is my solution. While testing the previous solutions for this problem, I found that x and y are always even and z is always odd. I don't know the mathematical reason behind this yet; I am currently trying to figure it out.
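    The parity pattern follows from reading x^2 + y^2 + 1 = z^2 modulo 4: every square is 0 mod 4 (even root) or 1 mod 4 (odd root), and a square can never be 2 or 3 mod 4:

```latex
\begin{aligned}
x, y \text{ both odd:} &\quad x^2 + y^2 + 1 \equiv 1 + 1 + 1 = 3 \pmod 4 \quad \text{(never a square)}\\
\text{exactly one odd:} &\quad x^2 + y^2 + 1 \equiv 1 + 0 + 1 = 2 \pmod 4 \quad \text{(never a square)}\\
x, y \text{ both even:} &\quad x^2 + y^2 + 1 \equiv 0 + 0 + 1 = 1 \pmod 4 \;\Rightarrow\; z \text{ odd}
\end{aligned}
```

    So only even x and y can work, z is then odd, and stepping x and y by 2 loses no solutions.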

  • 2021-02-04 10:15

    I had no time to test it properly, but it seems to yield the same results as your code (100 -> 6 results and 1000 -> 55 results).

    With N=1000 it takes 2 ms versus your 144 ms (also without using a List), and with N=10000 it takes 28 ms.

    var N = 1000;
    var c = 0;
    
    for (int x = 2; x < N; x+=2)
    {
        for (int y = x; y < (N - x); y+=2)
        {
            // widen before multiplying so x * x + y * y cannot overflow int
            long z2 = (long)x * x + (long)y * y + 1;
            long z = (long)Math.Sqrt(z2);
            if (x + y + z > N)
                break;
            if (z * z == z2)
                c++;
        }
    }
    
    Console.WriteLine(c);
    
  • 2021-02-04 10:16

    Here is a simple improvement in Python (converting to the faster equivalent in C-based code is left as an exercise for the reader). To get accurate timing for the computation, I removed printing the solutions themselves (after validating them in a previous run).

    • Use an outer loop for one free variable (I chose z), constrained only by its relation to N.
    • Use an inner loop (I chose y) constrained by the outer loop index.
    • The third variable is directly computed per requirement 2.

    Timing results:

    -------------------- 10 
     1 solutions found in 2.3365020751953125e-05  sec.
    -------------------- 100 
     6 solutions found in 0.00040078163146972656  sec.
    -------------------- 1000 
     55 solutions found in 0.030081748962402344  sec.
    -------------------- 10000 
     543 solutions found in 2.2078349590301514  sec.
    -------------------- 100000 
     5512 solutions found in 214.93411707878113  sec.
    

    That's 3:35 for the large case, plus your time to collate and/or print the results.

    If you need faster code (this is still pretty brute-force), look into Diophantine equations and parameterizations to generate (y, x) pairs, given the target value of z^2 - 1.
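    One such approach, sketched here as an alternative and not necessarily the parameterization Prune had in mind, fixes x instead of z. Rewriting the equation as x^2 + 1 = (z - y)(z + y), every factor pair d * e = x^2 + 1 with d < e and d, e of the same parity gives y = (e - d) / 2 and z = (e + d) / 2, and then x + y + z = x + e. The function name is hypothetical, and the divisor search is plain trial division:

```python
def count_by_factoring(N):
    """Count triples x <= y <= z with x^2 + y^2 = z^2 - 1 and x + y + z <= N.

    Uses x^2 + 1 = (z - y)(z + y): each factor pair d * e = x^2 + 1 with d < e
    and d, e of the same parity gives y = (e - d) / 2 and z = (e + d) / 2.
    """
    count = 0
    # y >= x and z > y imply x + y + z >= 3x + 1, so x <= (N - 1) / 3.
    for x in range(1, (N - 1) // 3 + 1):
        m = x * x + 1
        # x^2 + 1 is never a perfect square for x >= 1, so the smaller factor d is <= x.
        for d in range(1, x + 1):
            if m % d:
                continue
            e = m // d
            if (e - d) % 2:
                continue  # y and z would not be integers (this rules out odd x)
            y = (e - d) // 2
            # z = (e + d) // 2, so x + y + z == x + e
            if y >= x and x + e <= N:
                count += 1
    return count


print(count_by_factoring(100))  # 6
```

    Each triple corresponds to exactly one factor pair, so nothing is double counted; the trial division costs O(x) per x, so N = 10^5 would call for a faster way to factor x^2 + 1.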

    import math
    import time
    
    def break3(N):
        """
        10 <= N <= 10^5
        return x, y, z triples such that:
            x <= y <= z
            x^2 + y^2 = z^2 - 1        
            x + y + z <= N
        """
    
        """
        Observations:
        z <= x + y
        z < N/2
        """
    
        count = 0
        z_limit = N // 2
        for z in range(3, z_limit):
    
            # Since y >= x, there's a lower bound on y
            target = z*z - 1
            ymin = int(math.sqrt(target/2))
            for y in range(ymin, z):
                # Given y and z, compute x.
                # That's a solution iff x is integer.
                x_target = target - y*y
                x = int(math.sqrt(x_target))
                if x*x == x_target and x+y+z <= N:
                    # print("solution", x, y, z)
                    count += 1
    
        return count
    
    
    test = [10, 100, 1000, 10**4, 10**5]
    border = "-"*20
    
    for case in test: 
        print(border, case)
        start = time.time()
        print(break3(case), "solutions found in", time.time() - start, "sec.")
    
  • 2021-02-04 10:22

    Here's a method that enumerates the triples, rather than exhaustively testing for them, using number theory as described here: https://mathoverflow.net/questions/29644/enumerating-ways-to-decompose-an-integer-into-the-sum-of-two-squares

    Since the math took me a while to comprehend and a while to implement (gathering some code that's credited above it), and since I don't feel like much of an authority on the subject, I'll leave it to the reader to research. This is based on expressing numbers as products of Gaussian integer conjugates: (a + bi)*(a - bi) = a^2 + b^2. We first factor the number z^2 - 1 into primes, decompose the primes into Gaussian conjugates, and find different expressions that we expand and simplify to get a + bi. Then a^2 + b^2 = z^2 - 1, so |a| and |b| give x and y (in some order).
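    The key fact is that the norm N(a + bi) = a^2 + b^2 is multiplicative, so multiplying Gaussian factors whose norms are the prime-power parts of z^2 - 1 yields a + bi with a^2 + b^2 = z^2 - 1. A minimal sketch with plain integer pairs (no sympy), reproducing the z = 19 case from the output below; the helper names gmul and norm are hypothetical:

```python
def gmul(g, h):
    """Multiply Gaussian integers stored as pairs: (a, b) means a + b*i."""
    a, b = g
    c, d = h
    return (a * c - b * d, a * d + b * c)


def norm(g):
    """N(a + b*i) = a^2 + b^2, which satisfies N(g * h) = N(g) * N(h)."""
    return g[0] * g[0] + g[1] * g[1]


# z = 19: z^2 - 1 = 360 = 2^3 * 3^2 * 5
# One Gaussian factor per prime power, each with the matching norm:
#   2^3 -> 2 + 2i (norm 8), 3^2 -> 3 (norm 9), 5 -> 2 + i (norm 5)
factors = [(2, 2), (3, 0), (2, 1)]

g = (1, 0)
for f in factors:
    g = gmul(g, f)

x, y = sorted((abs(g[0]), abs(g[1])))
print(g, norm(g), (x, y))  # (6, 18) 360 (6, 18)
```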

    A perk of reading about the Sum of Squares Function is discovering that we can rule out any candidate z^2 - 1 that contains a prime of form 4k + 3 with an odd power. Using that check alone, I was able to reduce Prune's loop on 10^5 from 214 seconds to 19 seconds (on repl.it) using the Rosetta prime factoring code below.
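    A minimal sketch of that filter bolted onto Prune's z-outer loop. It uses plain trial division instead of the Rosetta fac() below (an assumption made to keep it self-contained) and math.isqrt (Python 3.8+). By the sum of two squares theorem, a number is a sum of two squares exactly when every prime of the form 4k + 3 divides it to an even power, so skipping the other z values loses no solutions:

```python
import math


def is_sum_of_two_squares(m):
    """True iff m >= 0 can be written as a^2 + b^2.

    Sum of two squares theorem: every prime p = 4k + 3 must divide m to an even power.
    """
    p = 2
    while p * p <= m:
        if m % p == 0:
            exponent = 0
            while m % p == 0:
                m //= p
                exponent += 1
            if p % 4 == 3 and exponent % 2 == 1:
                return False
        p += 1
    # Any leftover m > 1 is a prime appearing exactly once.
    return m % 4 != 3


def count_filtered(N):
    """Prune's z-outer loop, skipping z whose z^2 - 1 fails the 4k + 3 test."""
    count = 0
    for z in range(3, N // 2):
        target = z * z - 1
        if not is_sum_of_two_squares(target):
            continue  # no (x, y) exists for this z at all
        for y in range(math.isqrt(target // 2), z):
            x_target = target - y * y
            x = math.isqrt(x_target)
            if x * x == x_target and x + y + z <= N:
                count += 1
    return count


print(count_filtered(100))  # 6
```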

    The implementation here is just a demonstration. It does not handle or optimise the limits on x and y; it just enumerates as it goes. Play with it here.

    Python code:

    # https://math.stackexchange.com/questions/5877/efficiently-finding-two-squares-which-sum-to-a-prime
    def mods(a, n):
        if n <= 0:
            return "negative modulus"
        a = a % n
        if (2 * a > n):
            a -= n
        return a
    
    def powmods(a, r, n):
        out = 1
        while r > 0:
            if (r % 2) == 1:
                r -= 1
                out = mods(out * a, n)
            r //= 2
            a = mods(a * a, n)
        return out
    
    def quos(a, n):
        if n <= 0:
            return "negative modulus"
        return (a - mods(a, n)) // n
    
    def grem(w, z):
        # remainder in Gaussian integers when dividing w by z
        (w0, w1) = w
        (z0, z1) = z
        n = z0 * z0 + z1 * z1
        if n == 0:
            return "division by zero"
        u0 = quos(w0 * z0 + w1 * z1, n)
        u1 = quos(w1 * z0 - w0 * z1, n)
        return(w0 - z0 * u0 + z1 * u1,
               w1 - z0 * u1 - z1 * u0)
    
    def ggcd(w, z):
        while z != (0,0):
            w, z = z, grem(w, z)
        return w
    
    def root4(p):
        # 4th root of 1 modulo p
        if p <= 1:
            return "too small"
        if (p % 4) != 1:
            return "not congruent to 1"
        k = p // 4
        j = 2
        while True:
            a = powmods(j, k, p)
            b = mods(a * a, p)
            if b == -1:
                return a
            if b != 1:
                return "not prime"
            j += 1
    
    def sq2(p):
        if p % 4 != 1:
          return "not congruent to 1 modulo 4"
        a = root4(p)
        return ggcd((p,0),(a,1))
    
    # https://rosettacode.org/wiki/Prime_decomposition#Python:_Using_floating_point
    from math import floor, sqrt
    
    def fac(n):
        step = lambda x: 1 + (x<<2) - ((x>>1)<<1)
        maxq = int(floor(sqrt(n)))
        d = 1
        q = n % 2 == 0 and 2 or 3 
        while q <= maxq and n % q != 0:
            q = step(d)
            d += 1
        return q <= maxq and [q] + fac(n//q) or [n]
    
    # My code...
    # An answer for  https://stackoverflow.com/questions/54110614/
    
    from collections import Counter
    from itertools import product
    from sympy import I, expand, Add
    
    def valid(ps):
      for (p, e) in ps.items():
        if (p % 4 == 3) and (e & 1):
          return False
      return True
    
    def get_sq2(p, e):
      if p == 2:
        if e & 1:
          return [2**(e // 2), 2**(e // 2)]
        else:
          return [2**(e // 2), 0]
      elif p % 4 == 3:
        return [p, 0]
      else:
        a,b = sq2(p)
        return [abs(a), abs(b)]
    
    def get_terms(cs, e):
      if e == 1:
        return [Add(cs[0], cs[1] * I)]
      res = [Add(cs[0], cs[1] * I)**e]
      for t in range(1, e // 2 + 1):
        res.append(
          Add(cs[0] + cs[1]*I)**(e-t) * Add(cs[0] - cs[1]*I)**t)
      return res
    
    def get_lists(ps):
      items = ps.items()
      lists = []
      for (p, e) in items:
        if p == 2:
          a,b = get_sq2(2, e)
          lists.append([Add(a, b*I)])
        elif p % 4 == 3:
          a,b = get_sq2(p, e)
          lists.append([Add(a, b*I)**(e // 2)])
        else:
          lists.append(get_terms(get_sq2(p, e), e))
      return lists
    
    
    def f(n):
      for z in range(2, n // 2):
        zz = (z + 1) * (z - 1)
        ps = Counter(fac(zz))
        is_valid = valid(ps)
        if is_valid:
          print("valid (does not contain a prime of form\n4k + 3 with an odd power)")
          print("z: %s, primes: %s" % (z, dict(ps)))
          lists = get_lists(ps)
          cartesian = product(*lists)
          for element in cartesian:
            print("prime square decomposition: %s" % list(element))
            p = 1
            for item in element:
              p *= item
            print("complex conjugates: %s" % p)
            vals = list(p.expand(complex=True, evaluate=True).as_coefficients_dict().values())
            x, y = vals[0], vals[1] if len(vals) > 1 else 0
            print("x, y, z: %s, %s, %s" % (x, y, z))
            print("x^2 + y^2, z^2-1: %s, %s" % (x**2 + y**2, z**2 - 1))
          print('')
    
    if __name__ == "__main__":
      f(100)
    

    Output:

    valid (does not contain a prime of form
    4k + 3 with an odd power)
    z: 3, primes: {2: 3}
    prime square decomposition: [2 + 2*I]
    complex conjugates: 2 + 2*I
    x, y, z: 2, 2, 3
    x^2 + y^2, z^2-1: 8, 8
    
    valid (does not contain a prime of form
    4k + 3 with an odd power)
    z: 9, primes: {2: 4, 5: 1}
    prime square decomposition: [4, 2 + I]
    complex conjugates: 8 + 4*I
    x, y, z: 8, 4, 9
    x^2 + y^2, z^2-1: 80, 80
    
    valid (does not contain a prime of form
    4k + 3 with an odd power)
    z: 17, primes: {2: 5, 3: 2}
    prime square decomposition: [4 + 4*I, 3]
    complex conjugates: 12 + 12*I
    x, y, z: 12, 12, 17
    x^2 + y^2, z^2-1: 288, 288
    
    valid (does not contain a prime of form
    4k + 3 with an odd power)
    z: 19, primes: {2: 3, 3: 2, 5: 1}
    prime square decomposition: [2 + 2*I, 3, 2 + I]
    complex conjugates: (2 + I)*(6 + 6*I)
    x, y, z: 6, 18, 19
    x^2 + y^2, z^2-1: 360, 360
    
    valid (does not contain a prime of form
    4k + 3 with an odd power)
    z: 33, primes: {17: 1, 2: 6}
    prime square decomposition: [4 + I, 8]
    complex conjugates: 32 + 8*I
    x, y, z: 32, 8, 33
    x^2 + y^2, z^2-1: 1088, 1088
    
    valid (does not contain a prime of form
    4k + 3 with an odd power)
    z: 35, primes: {17: 1, 2: 3, 3: 2}
    prime square decomposition: [4 + I, 2 + 2*I, 3]
    complex conjugates: 3*(2 + 2*I)*(4 + I)
    x, y, z: 18, 30, 35
    x^2 + y^2, z^2-1: 1224, 1224
    
  • 2021-02-04 10:25

    I want to get it done in C#, and it should cover all the test cases based on the conditions provided in the question.

    The basic code, converted to long to handle the N <= 100000 upper limit, with every optimization I could think of thrown in. I used alternate forms from @Mat's (+1) Wolfram Alpha query to precompute as much as possible. I also did a minimal perfect square test to avoid millions of sqrt() calls at the upper limit:

    public static void Main()
    {
        int c = 0;
    
        long N = long.Parse(Console.ReadLine());
        long N_squared = N * N;
    
        double half_N_squared = N_squared / 2.0 - 0.5;
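        double x_limit = N - Math.Sqrt(2.0 * (N_squared + 1)) / 2.0;  // exact when the root is an integer
    
        for (long x = 2; x <= x_limit; x += 2)
        {
            long x_squared = x * x + 1;
    
            double y_limit = (half_N_squared - N * x) / (N - x);
    
            // <= so a triple with x + y + z == N (e.g. 4, 8, 9 for N = 21) is still counted
            for (long y = x; y <= y_limit; y += 2)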
            {
                long z_squared = x_squared + y * y;
                int digit = (int) (z_squared % 10);  // take the digit first, then cast
    
                if (digit == 3 || digit == 7)
                {
                    continue;  // minimalist non-perfect square elimination
                }
    
                long z = (long) Math.Sqrt(z_squared);
    
                if (z * z == z_squared)
                {
                    c++;
                }
            }
        }
    
        Console.WriteLine(c);
    }
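    The digit test in the code above is enough because z_squared = x^2 + y^2 + 1 is odd when x and y are even, and odd squares can only end in 1, 5 or 9, so a final digit of 3 or 7 rules out a perfect square. A quick Python check of those digit sets (the helper name is hypothetical):

```python
# The last digit of k*k depends only on the last digit of k, so k in 0..9 covers every case.
square_digits = {k * k % 10 for k in range(10)}
odd_square_digits = {k * k % 10 for k in range(1, 10, 2)}

print(sorted(square_digits))      # [0, 1, 4, 5, 6, 9]
print(sorted(odd_square_digits))  # [1, 5, 9]


def may_be_square_if_odd(n):
    """For odd n, mirrors the C# `digit == 3 || digit == 7` rejection."""
    return n % 10 not in (3, 7)
```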
    

    I followed the trend and left out "the degenerate solution" as implied by the OP's code but not explicitly stated.
