Soft Thresholding CUDA implementation

前端 未结 1 1143
日久生厌
日久生厌 2021-01-16 04:42

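I am wondering how I should implement an efficient soft-thresholding kernel in CUDA. The soft-thresholding function with threshold $\lambda \ge 0$ is defined as follows: $S_\lambda(x) = \operatorname{sign}(x)\,\max(|x| - \lambda,\ 0)$, i.e. $x - \lambda$ for $x > \lambda$, $x + \lambda$ for $x < -\lambda$, and $0$ otherwise.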

1条回答
  •  抹茶落季
    2021-01-16 05:38

    The two solutions proposed in the comments above, set up for elementwise processing, are the following:

    // In-place soft thresholding of x[0..N-1], division-based variant.
    //
    // Launch with a 1D grid of 1D blocks, one thread per element; threads whose
    // global index is >= N do nothing. Every element whose magnitude exceeds
    // lambda is rescaled by (|x| - lambda) / |x|; every other element becomes 0.
    __global__ void myKernel1(float* __restrict__ x, float lambda, const int N)
    {
        const int idx = blockIdx.x * blockDim.x + threadIdx.x;

        // Threads past the end of the array have no element to process.
        if (idx >= N) return;

        const float value     = x[idx];
        const float magnitude = fabsf(value);

        if (magnitude > lambda) {
            // Shrink toward zero while keeping the sign of the input.
            x[idx] = value * ((magnitude - lambda) / magnitude);
        } else {
            x[idx] = 0;
        }
    }
    

    and

    // In-place soft thresholding of x[0..N-1], division-free variant.
    //
    // Launch with a 1D grid of 1D blocks, one thread per element; threads whose
    // global index is >= N do nothing. The result is the product of a gate
    // taken from the sign bit of (lambda - |x|) and the value (|x| - lambda)
    // carrying the sign of the input.
    __global__ void myKernel2(float* __restrict__ x, float lambda, const int N)
    {
        const int idx = blockIdx.x * blockDim.x + threadIdx.x;

        // Threads past the end of the array have no element to process.
        if (idx >= N) return;

        const float value     = x[idx];
        const float magnitude = fabsf(value);

        // Sign bit of (lambda - |x|): set when |x| is above the threshold.
        const int aboveThreshold = signbit(lambda - magnitude);

        // (|x| - lambda) with the sign of the original element.
        const float shrunk = copysignf(magnitude - lambda, value);

        x[idx] = aboveThreshold * shrunk;
    }
    

    The disassembled code for the two solutions is reported below. As also noticed by @njuffa, the second solution appears, in principle, less expensive than the first one because it avoids the x/|x| division. However, as @njuffa also pointed out, this scenario is likely to be memory bound rather than compute bound. Nevertheless, this analysis perhaps indicates that the second solution is preferable when implemented as a __device__ function for non-elementwise computations.

    DISASSEMBLED CODE FOR THE FIRST SOLUTION

    code for sm_21
        Function : _Z9myKernel1Pffi
    .headerflags    @"EF_CUDA_SM20 EF_CUDA_PTX_SM(EF_CUDA_SM20)"
        /*0000*/        MOV R1, c[0x1][0x100];                       /* 0x2800440400005de4 */
        /*0008*/        S2R R0, SR_CTAID.X;                          /* 0x2c00000094001c04 */
        /*0010*/        S2R R3, SR_TID.X;                            /* 0x2c0000008400dc04 */
        /*0018*/        IMAD R0, R0, c[0x0][0x8], R3;                /* 0x2006400020001ca3 */
        /*0020*/        ISETP.GE.AND P0, PT, R0, c[0x0][0x2c], PT;   /* 0x1b0e4000b001dc23 */
        /*0028*/    @P0 EXIT ;                                       /* 0x80000000000001e7 */
        /*0030*/        MOV32I R3, 0x4;                              /* 0x180000001000dde2 */
        /*0038*/        SSY 0x90;                                    /* 0x6000000140000007 */
        /*0040*/        IMAD R16.CC, R0, R3, c[0x0][0x20];           /* 0x2007800080041ca3 */
        /*0048*/        IMAD.HI.X R17, R0, R3, c[0x0][0x24];         /* 0x2086800090045ce3 */
        /*0050*/        LD.E R2, [R16];                              /* 0x8400000001009c85 */
        /*0058*/        FSETP.GT.AND P0, PT, |R2|, c[0x0][0x28], PT; /* 0x220e4000a021dc80 */
        /*0060*/        F2F.F32.F32 R5, |R2|;                        /* 0x1000000009215c44 */
        /*0068*/    @P0 BRA 0x78;                                    /* 0x40000000200001e7 */
        /*0070*/        MOV.S R0, RZ;                                /* 0x28000000fc001df4 */
        /*0078*/        FADD R4, |R2|, -c[0x0][0x28];                /* 0x50004000a0211d80 */
        /*0080*/        JCAL 0x0;                                    /* 0x1000000000010007 */
        /*0088*/        FMUL.S R0, R2, R4;                           /* 0x5800000010201c10 */
        /*0090*/        ST.E [R16], R0;                              /* 0x9400000001001c85 */
        /*0098*/        EXIT ;                                       /* 0x8000000000001de7 */
        .................................
    
    
        Function : __cuda_sm20_div_rn_noftz_f32_slowpath
    .headerflags    @"EF_CUDA_SM20 EF_CUDA_PTX_SM(EF_CUDA_SM20)"
        /*0000*/        SHL R0, R4, 0x1;                                   /* 0x6000c00004401c03 */
        /*0008*/        MOV32I R6, 0x1;                                    /* 0x1800000004019de2 */
        /*0010*/        SHL R3, R5, 0x1;                                   /* 0x6000c0000450dc03 */
        /*0018*/        IMAD.U32.U32.HI R0, R0, 0x100, -R6;                /* 0x200cc00400001d43 */
        /*0020*/        ISETP.GT.U32.AND P0, PT, R0, 0xfd, PT;             /* 0x1a0ec003f401dc03 */
        /*0028*/        IMAD.U32.U32.HI R3, R3, 0x100, -R6;                /* 0x200cc0040030dd43 */
        /*0030*/        ISETP.GT.U32.OR P0, PT, R3, 0xfd, P0;              /* 0x1a20c003f431dc03 */
        /*0038*/   @!P0 BRA 0x178;                                         /* 0x40000004e00021e7 */
        /*0040*/        FSETP.LE.AND P0, PT, |R4|, +INF , PT;              /* 0x218edfe00041dc80 */
        /*0048*/   @!P0 BRA 0x60;                                          /* 0x40000000400021e7 */
        /*0050*/        FSETP.LE.AND P0, PT, |R5|, +INF , PT;              /* 0x218edfe00051dc80 */
        /*0058*/    @P0 BRA 0x70;                                          /* 0x40000000400001e7 */
        /*0060*/        FADD R4, R4, R5;                                   /* 0x5000000014411c00 */
        /*0068*/        BRA 0x370;                                         /* 0x4000000c00001de7 */
        /*0070*/        SHL R7, R5, 0x1;                                   /* 0x6000c0000451dc03 */
        /*0078*/        SHL R6, R4, 0x1;                                   /* 0x6000c00004419c03 */
        /*0080*/        ISETP.EQ.U32.AND P2, PT, R7, RZ, PT;               /* 0x190e0000fc75dc03 */
        /*0088*/        ISETP.EQ.U32.AND P1, PT, R6, RZ, PT;               /* 0x190e0000fc63dc03 */
        /*0090*/        PSETP.AND.AND P0, PT, P1, P2, PT;                  /* 0x0c0e00000811dc04 */
        /*0098*/    @P0 BRA 0xc0;                                          /* 0x40000000800001e7 */
        /*00a0*/        FSETP.EQ.AND P3, PT, |R4|, +INF , PT;              /* 0x210edfe00047dc80 */
        /*00a8*/        FSETP.EQ.AND P0, PT, |R5|, +INF , PT;              /* 0x210edfe00051dc80 */
        /*00b0*/   @!P3 BRA 0xd8;                                          /* 0x4000000080002de7 */
        /*00b8*/   @!P0 BRA 0xd8;                                          /* 0x40000000600021e7 */
        /*00c0*/        MOV32I R0, 0xffc00000;                             /* 0x1bff000000001de2 */
        /*00c8*/        MUFU.RSQ R4, R0;                                   /* 0xc800000014011c00 */
        /*00d0*/        BRA 0x370;                                         /* 0x4000000a60001de7 */
        /*00d8*/        PSETP.OR.AND P0, PT, P0, P1, PT;                   /* 0x0c0e00004401dc04 */
        /*00e0*/   @!P0 BRA 0x100;                                         /* 0x40000000600021e7 */
        /*00e8*/        LOP.XOR R0, R5, R4;                                /* 0x6800000010501c83 */
        /*00f0*/        LOP32I.AND R4, R0, 0x80000000;                     /* 0x3a00000000011c02 */
        /*00f8*/        BRA 0x370;                                         /* 0x40000009c0001de7 */
        /*0100*/        PSETP.OR.AND P0, PT, P3, P2, PT;                   /* 0x0c0e00004831dc04 */
        /*0108*/   @!P0 BRA 0x130;                                         /* 0x40000000800021e7 */
        /*0110*/        LOP.XOR R0, R5, R4;                                /* 0x6800000010501c83 */
        /*0118*/        LOP32I.AND R0, R0, 0x80000000;                     /* 0x3a00000000001c02 */
        /*0120*/        LOP32I.OR R4, R0, 0x7f800000;                      /* 0x39fe000000011c42 */
        /*0128*/        BRA 0x370;                                         /* 0x4000000900001de7 */
        /*0130*/        ISETP.GE.AND P1, PT, R0, RZ, PT;                   /* 0x1b0e0000fc03dc23 */
        /*0138*/        ISETP.GE.AND P0, PT, R3, RZ, PT;                   /* 0x1b0e0000fc31dc23 */
        /*0140*/   @!P1 MOV32I R6, 0xffffffc0;                             /* 0x1bffffff0001a5e2 */
        /*0148*/   @!P1 FFMA R4, R4, 1.84467440737095520000e+019, RZ;      /* 0x307ed7e000412400 */
        /*0150*/    @P1 MOV R6, RZ;                                        /* 0x28000000fc0185e4 */
        /*0158*/    @P0 BRA 0x180;                                         /* 0x40000000800001e7 */
        /*0160*/        FFMA R5, R5, 1.84467440737095520000e+019, RZ;      /* 0x307ed7e000515c00 */
        /*0168*/        IADD R6, R6, 0x40;                                 /* 0x4800c00100619c03 */
        /*0170*/        BRA 0x180;                                         /* 0x4000000020001de7 */
        /*0178*/        MOV R6, RZ;                                        /* 0x28000000fc019de4 */
        /*0180*/        IADD R7, R3, -0x7e;                                /* 0x4800fffe0831dc03 */
        /*0188*/        MOV32I R9, 0x3f800000;                             /* 0x18fe000000025de2 */
        /*0190*/        ISCADD R7, -R7, R5, 0x17;                          /* 0x410000001471dee3 */
        /*0198*/        ISUB R3, R0, R3;                                   /* 0x480000000c00dd03 */
        /*01a0*/        MUFU.RCP R8, R7;                                   /* 0xc800000010721c00 */
        /*01a8*/        IADD R5, R0, -0x7e;                                /* 0x4800fffe08015c03 */
        /*01b0*/        FFMA R9, -R7, R8, R9;                              /* 0x3012000020725e00 */
        /*01b8*/        ISCADD R4, -R5, R4, 0x17;                          /* 0x4100000010511ee3 */
        /*01c0*/        FFMA R5, R8, R9, R8;                               /* 0x3010000024815c00 */
        /*01c8*/        FFMA R8, R4, R5, RZ;                               /* 0x307e000014421c00 */
        /*01d0*/        FFMA R9, -R7, R8, R4;                              /* 0x3008000020725e00 */
        /*01d8*/        FFMA R8, R9, R5, R8;                               /* 0x3010000014921c00 */
        /*01e0*/        FFMA R7, -R7, R8, R4;                              /* 0x300800002071de00 */
        /*01e8*/        FFMA R4, R7, R5, R8;                               /* 0x3010000014711c00 */
        /*01f0*/        SHL R9, R4, 0x1;                                   /* 0x6000c00004425c03 */
        /*01f8*/        SHR.U32 R9, R9, 0x18;                              /* 0x5800c00060925c03 */
        /*0200*/        IADD R0, R3, R9;                                   /* 0x4800000024301c03 */
        /*0208*/        IADD R6, R6, R0;                                   /* 0x4800000000619c03 */
        /*0210*/        IADD R0, R6, -0x1;                                 /* 0x4800fffffc601c03 */
        /*0218*/        ISETP.GT.U32.AND P0, PT, R0, 0xfd, PT;             /* 0x1a0ec003f401dc03 */
        /*0220*/    @P0 BRA 0x240;                                         /* 0x40000000600001e7 */
        /*0228*/        ISUB R0, R6, R9;                                   /* 0x4800000024601d03 */
        /*0230*/        ISCADD R4, R0, R4, 0x17;                           /* 0x4000000010011ee3 */
        /*0238*/        BRA 0x370;                                         /* 0x40000004c0001de7 */
        /*0240*/        ISETP.LE.AND P0, PT, R6, 0xfe, PT;                 /* 0x198ec003f861dc23 */
        /*0248*/    @P0 BRA 0x268;                                         /* 0x40000000600001e7 */
        /*0250*/        LOP32I.AND R0, R4, 0x80000000;                     /* 0x3a00000000401c02 */
        /*0258*/        LOP32I.OR R4, R0, 0x7f800000;                      /* 0x39fe000000011c42 */
        /*0260*/        BRA 0x370;                                         /* 0x4000000420001de7 */
        /*0268*/        ISETP.GT.AND P0, PT, R6, RZ, PT;                   /* 0x1a0e0000fc61dc23 */
        /*0270*/    @P0 BRA 0x370;                                         /* 0x40000003e00001e7 */
        /*0278*/        ISETP.GE.AND P0, PT, R6, -0x18, PT;                /* 0x1b0effffa061dc23 */
        /*0280*/    @P0 BRA 0x298;                                         /* 0x40000000400001e7 */
        /*0288*/        LOP32I.AND R4, R4, 0x80000000;                     /* 0x3a00000000411c02 */
        /*0290*/        BRA 0x370;                                         /* 0x4000000360001de7 */
        /*0298*/        FFMA.RP R3, R7, R5, R8;                            /* 0x311000001470dc00 */
        /*02a0*/        FFMA.RM R0, R7, R5, R8;                            /* 0x3090000014701c00 */
        /*02a8*/        FFMA.RZ R5, R7, R5, R8;                            /* 0x3190000014715c00 */
        /*02b0*/        FSET.NEU.AND R3, R0, R3, PT;                       /* 0x168e00000c00dc00 */
        /*02b8*/        I2I.S32.S32 R7, -R6;                               /* 0x1c0000001921df84 */
        /*02c0*/        LOP32I.AND R5, R5, 0x7fffff;                       /* 0x3801fffffc515c02 */
        /*02c8*/        ISETP.EQ.AND P0, PT, R7, RZ, PT;                   /* 0x190e0000fc71dc23 */
        /*02d0*/        LOP32I.AND R0, R4, 0x80000000;                     /* 0x3a00000000401c02 */
        /*02d8*/        I2I.S32.S32 R3, -R3;                               /* 0x1c0000000d20df84 */
        /*02e0*/        I2I.S32.S32 R4, -R6;                               /* 0x1c00000019211f84 */
        /*02e8*/        LOP32I.OR R7, R5, 0x800000;                        /* 0x380200000051dc42 */
        /*02f0*/    @P0 BRA.U 0x328;                                       /* 0x40000000c00081e7 */
        /*02f8*/   @!P0 IADD R5, R6, 0x20;                                 /* 0x4800c00080616003 */
        /*0300*/   @!P0 SHL R5, R7, R5;                                    /* 0x6000000014716003 */
        /*0308*/   @!P0 ICMP.EQ.U32 R5, RZ, 0x1, R5;                       /* 0x310ac00007f16003 */
        /*0310*/   @!P0 SHR.U32 R7, R7, R4;                                /* 0x580000001071e003 */
        /*0318*/   @!P0 LOP.OR R3, R3, R5;                                 /* 0x680000001430e043 */
        /*0320*/        NOP;                                               /* 0x4000000000001de4 */
        /*0328*/        SHL R4, R7, 0x1e;                                  /* 0x6000c00078711c03 */
        /*0330*/        SHR.U32 R5, R4, 0x1f;                              /* 0x5800c0007c415c03 */
        /*0338*/        LOP.AND R4, R7, 0x1;                               /* 0x6800c00004711c03 */
        /*0340*/        LOP.OR R3, R3, R5;                                 /* 0x680000001430dc43 */
        /*0348*/        LOP.AND R3, R4, R3;                                /* 0x680000000c40dc03 */
        /*0350*/        SHR.U32 R4, R7, 0x1;                               /* 0x5800c00004711c03 */
        /*0358*/        ISETP.NE.U32.AND P0, PT, R3, RZ, PT;               /* 0x1a8e0000fc31dc03 */
        /*0360*/    @P0 IADD R4, R4, 0x1;                                  /* 0x4800c00004410003 */
        /*0368*/        LOP.OR R4, R0, R4;                                 /* 0x6800000010011c43 */
        /*0370*/        RET ;                                              /* 0x9000000000001de7 */
        ......................................................
    
    
        Function : __cuda_sm20_div_rn_f32
    .headerflags    @"EF_CUDA_SM20 EF_CUDA_PTX_SM(EF_CUDA_SM20)"
        /*0000*/        MUFU.RCP R3, R5;                     /* 0xc80000001050dc00 */
        /*0008*/        MOV32I R6, 0x3f800000;               /* 0x18fe000000019de2 */
        /*0010*/        LOP32I.AND R0, R4, 0x7fffff;         /* 0x3801fffffc401c02 */
        /*0018*/        FFMA.FTZ R6, -R5, R3, R6;            /* 0x300c00000c519e40 */
        /*0020*/        LOP32I.OR R0, R0, 0x3f800000;        /* 0x38fe000000001c42 */
        /*0028*/        FFMA.FTZ R3, R3, R6, R3;             /* 0x300600001830dc40 */
        /*0030*/        FFMA.FTZ R6, R0, R3, RZ;             /* 0x307e00000c019c40 */
        /*0038*/        FFMA.FTZ R7, -R5, R6, R0;            /* 0x300000001851de40 */
        /*0040*/        FFMA.FTZ R6, R7, R3, R6;             /* 0x300c00000c719c40 */
        /*0048*/        FFMA.FTZ R0, -R5, R6, R0;            /* 0x3000000018501e40 */
        /*0050*/        LOP32I.AND R7, R4, 0xff800000;       /* 0x3bfe00000041dc02 */
        /*0058*/        FFMA.FTZ R6, R0, R3, R6;             /* 0x300c00000c019c40 */
        /*0060*/        FFMA.FTZ R0, R6, R7, RZ;             /* 0x307e00001c601c40 */
        /*0068*/        LOP32I.AND R3, R0, 0x7fffffff;       /* 0x39fffffffc00dc02 */
        /*0070*/        MOV32I R6, 0x7effffef;               /* 0x19fbffffbc019de2 */
        /*0078*/        IADD32I R3, R3, -0x800010;           /* 0x0bfdffffc030dc02 */
        /*0080*/        ISETP.GT.U32.AND P0, PT, R3, R6, PT; /* 0x1a0e00001831dc03 */
        /*0088*/   @!P0 BRA 0xa8;                            /* 0x40000000600021e7 */
        /*0090*/        JCAL 0x0;                            /* 0x1000000000010007 */
        /*0098*/        MOV R0, R4;                          /* 0x2800000010001de4 */
        /*00a0*/        NOP;                                 /* 0x4000000000001de4 */
        /*00a8*/        MOV R4, R0;                          /* 0x2800000000011de4 */
        /*00b0*/        RET ;                                /* 0x9000000000001de7 */
        .......................................
    

    DISASSEMBLED CODE FOR THE SECOND SOLUTION

    code for sm_21
        Function : _Z9myKernel2Pffi
    .headerflags    @"EF_CUDA_SM20 EF_CUDA_PTX_SM(EF_CUDA_SM20)"
        /*0000*/        MOV R1, c[0x1][0x100];                     /* 0x2800440400005de4 */
        /*0008*/        S2R R0, SR_CTAID.X;                        /* 0x2c00000094001c04 */
        /*0010*/        S2R R2, SR_TID.X;                          /* 0x2c00000084009c04 */
        /*0018*/        IMAD R0, R0, c[0x0][0x8], R2;              /* 0x2004400020001ca3 */
        /*0020*/        ISETP.GE.AND P0, PT, R0, c[0x0][0x2c], PT; /* 0x1b0e4000b001dc23 */
        /*0028*/    @P0 BRA.U 0x98;                                /* 0x40000001a00081e7 */
        /*0030*/   @!P0 MOV32I R3, 0x4;                            /* 0x180000001000e1e2 */
        /*0038*/   @!P0 IMAD R2.CC, R0, R3, c[0x0][0x20];          /* 0x200780008000a0a3 */
        /*0040*/   @!P0 IMAD.HI.X R3, R0, R3, c[0x0][0x24];        /* 0x208680009000e0e3 */
        /*0048*/   @!P0 LD.E R0, [R2];                             /* 0x8400000000202085 */
        /*0050*/   @!P0 FADD R5, |R0|, -c[0x0][0x28];              /* 0x50004000a0016180 */
        /*0058*/   @!P0 FADD R4, -|R0|, c[0x0][0x28];              /* 0x50004000a0012280 */
        /*0060*/   @!P0 LOP32I.AND R0, R0, 0x80000000;             /* 0x3a00000000002002 */
        /*0068*/   @!P0 LOP32I.AND R5, R5, 0x7fffffff;             /* 0x39fffffffc516002 */
        /*0070*/   @!P0 SHR.U32 R4, R4, 0x1f;                      /* 0x5800c0007c412003 */
        /*0078*/   @!P0 LOP.OR R5, R0, R5;                         /* 0x6800000014016043 */
        /*0080*/   @!P0 I2F.F32.S32 R0, R4;                        /* 0x1800000011202204 */
        /*0088*/   @!P0 FMUL R0, R0, R5;                           /* 0x5800000014002000 */
        /*0090*/   @!P0 ST.E [R2], R0;                             /* 0x9400000000202085 */
        /*0098*/        EXIT ;                                     /* 0x8000000000001de7 */
        .................................
    

    EDIT

    A follow-up to this post has appeared in "Soft thresholding in CUDA".

    0 讨论(0)
提交回复
热议问题