How to get raw pointer from cusp library matrix format

后端 未结 1 1053
说谎
说谎 2021-01-27 13:23

I need to get a raw pointer from a cusp library matrix format. For example:

cusp::coo_matrix A(3,3,4);

A.values[0] = 1;
A.row_         


        
相关标签:
1条回答
  • 2021-01-27 13:56

    There are multiple ways to accomplish this. For example, if you wish to start with the raw device data representation instead of the cusp data representation, you could use the methodology in the cusp views functionality.

    If you already have cusp data and want to convert it to a raw data representation, you can use the fact that cusp is built on top of thrust (so `thrust::raw_pointer_cast` works on its arrays). Here's a fully worked example:

    $ cat t346.cu
    #include <cstdio>
    #include <cusp/coo_matrix.h>
    #include <cusp/print.h>
    
    // Swaps the first `size` elements of `a` with the first `size` elements of
    // `b`, element by element: afterwards a[i] holds the old b[i] and b[i] holds
    // the old a[i] for every i < size.
    //
    // Launch: 1-D grid of 1-D blocks, no shared memory. The grid-stride loop
    // makes any grid/block configuration correct (even <<<1,1>>>); a grid of
    // ceil(size / blockDim.x) blocks handles every element in a single pass.
    // Preconditions: `a` and `b` are device pointers to at least `size`
    // elements each.
    // NOTE(review): if [a, a+size) and [b, b+size) overlap, the outcome depends
    // on thread scheduling — callers are expected to pass disjoint ranges.
    template <typename T>
    __global__ void my_swap_kernel(T *a, T *b, unsigned size){
      // A 64-bit unsigned index avoids the signed/unsigned comparison against
      // `size` and cannot wrap around when stepping by the grid stride.
      const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
      for (size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
           idx < size; idx += stride){
        T temp = b[idx];
        b[idx] = a[idx];
        a[idx] = temp;
      }
    }
    
    
    
    // Builds a small 4x3 COO matrix in device memory, prints it, swaps the first
    // three stored values with the last three via my_swap_kernel operating on
    // raw device pointers, and prints the matrix again.
    // Returns 0 on success, 1 if the kernel launch or its execution reports an
    // error.
    int main(void)
    {
        // allocate storage for a (4,3) matrix with 6 nonzeros in device memory
        cusp::coo_matrix<int,float,cusp::device_memory> A(4,3,6);
    
        // Initialize matrix entries from host code. A lives in device memory, so
        // each element assignment below goes through a thrust device reference
        // (one small host-to-device copy per element: fine for a demo, slow for
        // real data).
        A.row_indices[0] = 0; A.column_indices[0] = 0; A.values[0] = 10;
        A.row_indices[1] = 0; A.column_indices[1] = 2; A.values[1] = 20;
        A.row_indices[2] = 2; A.column_indices[2] = 2; A.values[2] = 30;
        A.row_indices[3] = 3; A.column_indices[3] = 0; A.values[3] = 40;
        A.row_indices[4] = 3; A.column_indices[4] = 1; A.values[4] = 50;
        A.row_indices[5] = 3; A.column_indices[5] = 2; A.values[5] = 60;
    
        // cusp arrays are built on thrust vectors: taking the address of an
        // element yields a thrust::device_ptr, and raw_pointer_cast turns it into
        // a plain device pointer that ordinary CUDA kernels can use.
        float *val0 = thrust::raw_pointer_cast(&A.values[0]);
        float *val3 = thrust::raw_pointer_cast(&A.values[3]);
    
        // A now represents the following matrix
        //    [10  0 20]
        //    [ 0  0  0]
        //    [ 0  0 30]
        //    [40 50 60]
    
        // print matrix entries
        cusp::print(A);
    
        // Swap values[0..2] with values[3..5]. Row/column indices are untouched,
        // so the values move to different (row, column) positions.
        const unsigned num_swapped = 3;
        const unsigned threads_per_block = 128;
        const unsigned num_blocks =
            (num_swapped + threads_per_block - 1) / threads_per_block;
        my_swap_kernel<<<num_blocks, threads_per_block>>>(val0, val3, num_swapped);
    
        // Kernel launches do not return errors directly: check the launch itself,
        // then synchronize to surface any error raised while the kernel ran.
        cudaError_t err = cudaGetLastError();
        if (err == cudaSuccess)
            err = cudaDeviceSynchronize();
        if (err != cudaSuccess) {
            fprintf(stderr, "my_swap_kernel failed: %s\n", cudaGetErrorString(err));
            return 1;
        }
    
        cusp::print(A);
    
        return 0;
    }
    
    $ nvcc -arch=sm_20 -o t346 t346.cu
    $ cuda-memcheck ./t346
    ========= CUDA-MEMCHECK
    sparse matrix <4, 3> with 6 entries
                  0              0             10
                  0              2             20
                  2              2             30
                  3              0             40
                  3              1             50
                  3              2             60
    sparse matrix <4, 3> with 6 entries
                  0              0             40
                  0              2             50
                  2              2             60
                  3              0             10
                  3              1             20
                  3              2             30
    ========= ERROR SUMMARY: 0 errors
    $
    
    0 讨论(0)
提交回复
热议问题