Understanding the while loop in TensorFlow

挽巷 2021-02-03 12:42

I am using the Python API for TensorFlow. I am trying to implement the Rosenbrock function given below without using a Python loop:
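The formula itself did not survive in this copy of the question. Assuming it is the standard n-dimensional Rosenbrock function, which is what the answer's loop body computes term by term, it reads as follows with 1-based indices (a 0-based implementation sums over i = 0, ..., n-2 instead):

$$
f(\mathbf{x}) = \sum_{i=1}^{n-1} \left[ 100\,\left(x_{i+1} - x_i^2\right)^2 + \left(x_i - 1\right)^2 \right]
$$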

My current implementati

1 answer
  • 2021-02-03 13:02

    This can be achieved with tf.while_loop(), using a standard Python tuple for the loop variables, as in the second example in the documentation.

    import tensorflow as tf

    def rosenbrock(data_tensor):
        # Coordinates x_0 ... x_{n-1} lie along axis 0; any trailing axes (e.g. a batch) are kept
        x = tf.convert_to_tensor(data_tensor)
        n = tf.shape(x)[0]

        # Track both the loop index and the summation in a tuple of the form (index, summation).
        # Indexing is 0-based, so the loop starts at 0; the zero summation matches x[0] in shape and dtype
        index_summation = (tf.constant(0), tf.zeros_like(x[0]))

        # The loop condition; note that it is 'i < n-1', because the body reads x[i+1]
        def condition(index, summation):
            return tf.less(index, n - 1)

        # The loop body returns a result tuple in the same form (index, summation)
        def body(index, summation):
            x_i = tf.gather(x, index)
            x_ip1 = tf.gather(x, index + 1)

            first_term = tf.square(x_ip1 - tf.square(x_i))
            second_term = tf.square(x_i - 1.0)
            summand = 100.0 * first_term + second_term

            return index + 1, summation + summand

        # We do not care about the final index here, so return only the summation
        return tf.while_loop(condition, body, index_summation)[1]
    

    Note that, as in an ordinary while loop, the index must be incremented inside the body of the loop. In the solution above, the incremented index is the first item of the tuple returned by body().
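A stripped-down sketch of the same (index, summation) pattern may make the placement of the increment clearer. Here the Rosenbrock terms are replaced by a simple sum of squares; `sum_of_squares` is a hypothetical name used only for illustration, and TensorFlow 2 with eager execution is assumed.

```python
import tensorflow as tf

def sum_of_squares(n):
    # Loop variables in the form (index, summation), mirroring the answer's structure
    loop_vars = (tf.constant(0), tf.constant(0))

    def condition(index, summation):
        return index < n

    def body(index, summation):
        # The increment lives here, as the first element of the returned tuple;
        # without it the condition would never become false
        return index + 1, summation + index * index

    _, total = tf.while_loop(condition, body, loop_vars)
    return total

print(int(sum_of_squares(5)))  # 0 + 1 + 4 + 9 + 16 = 30
```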

    Additionally, the condition function must accept a parameter for the summation even though it does not use it, because tf.while_loop passes every loop variable to both the condition and the body.
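The reason the condition needs that parameter is that tf.while_loop calls both the condition and the body with all of the loop variables. A consequence is that the condition is free to use any of them. In this sketch (assuming TensorFlow 2 eager execution; `count_until` and its limit are hypothetical, for illustration only) the loop adds 1, 2, 3, ... and stops once the running sum is no longer below the limit, without looking at the index at all.

```python
import tensorflow as tf

def count_until(limit):
    # (index, summation): index is the next number to add
    loop_vars = (tf.constant(1), tf.constant(0))

    # Both loop variables are passed in; this condition reads only the summation
    def condition(index, summation):
        return summation < limit

    def body(index, summation):
        return index + 1, summation + index

    last_index, total = tf.while_loop(condition, body, loop_vars)
    return int(last_index), int(total)

print(count_until(10))  # 1 + 2 + 3 + 4 = 10, next index 5 -> (5, 10)
```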
