Thrust reduce not working when the input and output types differ

Front-end · Unresolved · 1 · 2052
遥遥无期
遥遥无期 2021-01-24 23:50

I'm attempting to reduce the min and max of an array of values using Thrust, and I seem to be stuck. Given an array of floats, what I would like is to reduce their min and max va… [question text truncated in the source]

Related tags:
1 answer
  • 2021-01-25 00:39

    As talonmies notes, your reduction does not compile because thrust::reduce expects the binary operator's argument types to match its result type, but ReduceMinMax's argument type is float, while its result type is float2.

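    thrust::minmax_element implements this operation directly. If necessary, you could instead implement your reduction with thrust::inner_product, which generalizes thrust::reduce. It first combines each pair of corresponding elements from two input ranges with a second binary operator, and then reduces those intermediate results with the first operator. Passing the same range twice lets minmax_float turn each float x into the float2 (x, x). minmax_float2 then reduces float2 values to a float2, so its argument and result types agree: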

    #include <thrust/inner_product.h>
    #include <thrust/device_vector.h>
    #include <thrust/extrema.h>
    #include <cassert>
    #include <limits>
    
    // "Product" functor for thrust::inner_product: maps an element pair to a
    // float2 holding (min, max) of the two values. When the same range is
    // passed twice (as in minmax1), lhs == rhs and the result is (x, x).
    struct minmax_float
    {
      // const: Thrust copies functors and may call them through const objects.
      __host__ __device__
      float2 operator()(float lhs, float rhs) const
      {
        return make_float2(thrust::min(lhs, rhs), thrust::max(lhs, rhs));
      }
    };
    
    // Reduction functor: merges two partial (min, max) results.
    // .x carries the running minimum and .y the running maximum, so the
    // result takes the smaller .x and the larger .y of the two inputs.
    struct minmax_float2
    {
      // const: Thrust copies functors and may call them through const objects.
      __host__ __device__
      float2 operator()(float2 lhs, float2 rhs) const
      {
        return make_float2(thrust::min(lhs.x, rhs.x), thrust::max(lhs.y, rhs.y));
      }
    };
    
    /// Computes (min, max) of x using thrust::inner_product.
    ///
    /// inner_product applies minmax_float to each pair (x[i], x[i]), yielding
    /// the float2 (x[i], x[i]), and folds those with minmax_float2 starting
    /// from the init value.
    ///
    /// @param x  values to scan (a device_vector).
    /// @return   float2 whose .x is the minimum and .y is the maximum of x.
    ///           For an empty x this is the seed (+inf, -inf), i.e. an
    ///           inverted/empty range.
    float2 minmax1(const thrust::device_vector<float> &x)
    {
      // The seed must be the identity of the reduction: +inf for the min
      // component and -inf for the max component. A finite seed such as
      // (4, 4) would clamp the result whenever every element is > 4 (min
      // reported as 4) or < 4 (max reported as 4).
      const float2 identity = make_float2(std::numeric_limits<float>::infinity(),
                                          -std::numeric_limits<float>::infinity());
      return thrust::inner_product(x.begin(), x.end(), x.begin(), identity,
                                   minmax_float2(), minmax_float());
    }
    
    /// Computes (min, max) of x using thrust::minmax_element.
    ///
    /// @param x  values to scan (a device_vector).
    /// @return   float2 whose .x is the minimum and .y is the maximum of x.
    ///           For an empty x, returns (+inf, -inf) rather than
    ///           dereferencing end().
    float2 minmax2(const thrust::device_vector<float> &x)
    {
      // minmax_element yields (end, end) for an empty range; dereferencing
      // those iterators would be undefined behavior.
      if (x.empty())
      {
        return make_float2(std::numeric_limits<float>::infinity(),
                           -std::numeric_limits<float>::infinity());
      }

      // .first points at the smallest element, .second at the largest.
      const auto extrema = thrust::minmax_element(x.begin(), x.end());

      // Each dereference copies one float back from the device.
      return make_float2(*extrema.first, *extrema.second);
    }
    
    int main()
    {
      thrust::device_vector<float> hat(4);
      hat[0] = 3;
      hat[1] = 5;
      hat[2] = 6;
      hat[3] = 1;
    
      float2 result1 = minmax1(hat);
      float2 result2 = minmax2(hat);
    
      assert(result1.x == result2.x);
      assert(result1.y == result2.y);
    }
    
    0 Discussion (0)
Submit reply
Hot questions