How to speed up Pandas multilevel dataframe shift by group?

离开以前 2021-01-22 03:36

I am trying to shift the column data of a Pandas DataFrame within groups defined by the first index level. Here is the demo code:

In [8]: df = mul_df(5,4,3)

In [9]: df
Out[9]:
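The definition of mul_df is not shown. As a minimal sketch, assuming it builds a DataFrame with a two-level MultiIndex of random floats, here is a hypothetical stand-in (make_multi_df) together with the baseline per-group shift that the question is trying to speed up:

```python
import numpy as np
import pandas as pd


def make_multi_df(n_groups, n_rows, n_cols):
    """Hypothetical stand-in for the question's mul_df (its definition is not shown):
    n_groups outer keys, n_rows rows per key, n_cols random float columns."""
    index = pd.MultiIndex.from_product(
        [range(n_groups), range(n_rows)], names=["grp", "row"]
    )
    data = np.random.randn(n_groups * n_rows, n_cols)
    columns = [f"c{i}" for i in range(n_cols)]
    return pd.DataFrame(data, index=index, columns=columns)


df = make_multi_df(5, 4, 3)

# Baseline: shift every column down by one row within each group of the
# first index level; the first row of every group becomes NaN.
shifted = df.groupby(level=0).shift(1)
print(shifted.head(8))
```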
4 Answers
  • 2021-01-22 04:02

    There is a similar question, to which I added an answer that works for a shift in either direction and of any magnitude: pandas: setting last N rows of multi-index to Nan for speeding up groupby with shift

    The code (including test setup) is:

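    import numpy as np
    import pandas as pd

    #
    # the function to use in apply
    #
    def replace_shift_overlap(grp, col, N, value):
        # Work on a copy so the group passed in by apply is not mutated in place.
        grp = grp.copy()
        loc = grp.columns.get_loc(col)
        if N > 0:
            # Shifted down: the first N rows received values from the previous group.
            grp.iloc[:N, loc] = value
        else:
            # Shifted up: the last |N| rows received values from the next group.
            grp.iloc[N:, loc] = value
        return grp


    length = 5
    groups = 3
    rng1 = pd.date_range('1/1/1990', periods=length, freq='D')
    frames = []
    for x in range(groups):
        tmpdf = pd.DataFrame({'date': rng1,
                              'category': int(10000000 * abs(np.random.randn())),
                              'colA': np.random.randn(length),
                              'colB': np.random.randn(length)})
        frames.append(tmpdf)
    df = pd.concat(frames)

    df.sort_values(['category', 'date'], inplace=True)
    df.set_index(['category', 'date'], inplace=True, drop=True)
    shiftBy = -1
    # One fast shift over the whole column, ignoring group boundaries.
    df['tmpShift'] = df['colB'].shift(shiftBy)

    #
    # the apply: blank out the values that leaked across group boundaries
    #
    df = df.groupby(level=0, group_keys=False).apply(replace_shift_overlap, 'tmpShift', shiftBy, np.nan)

    # Yay, this is so much faster.
    df['newColumn'] = df['tmpShift'] / df['colA']
    df.drop(columns='tmpShift', inplace=True)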
    

    EDIT: Note that the initial sort eats significantly into the speedup, so in some cases the original answer is more effective.
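    The same shift-then-blank idea can also be done without apply. The sketch below (the helper name shift_within_groups is hypothetical) assumes the rows of each group are contiguous, as they are after the sort above: it shifts the whole column once and keeps a value only where the group label, shifted by the same amount, still matches. That check works for positive and negative shifts of any size.

```python
import numpy as np
import pandas as pd


def shift_within_groups(s: pd.Series, n: int) -> pd.Series:
    """Hypothetical helper: shift s by n rows within groups of the first index level.

    Assumes each group's rows are contiguous (e.g. after sorting the index)."""
    keys = pd.Series(s.index.get_level_values(0), index=s.index)
    shifted = s.shift(n)
    # Keep a value only if it was shifted in from a row of the same group;
    # keys.shift(n) is NaN at the frame edges, which compares as not equal.
    same_group = keys.eq(keys.shift(n))
    return shifted.where(same_group)


idx = pd.MultiIndex.from_tuples(
    [("a", 0), ("a", 1), ("a", 2), ("b", 0), ("b", 1)], names=["category", "date"]
)
colB = pd.Series([1.0, 2.0, 3.0, 4.0, 5.0], index=idx)

up = shift_within_groups(colB, -1)   # same direction as shiftBy = -1
down = shift_within_groups(colB, 2)
print(up.tolist())    # [2.0, 3.0, nan, 5.0, nan]
print(down.tolist())  # [nan, nan, 1.0, nan, nan]
```

    Unlike the apply version, this never calls back into Python per group, but it still relies on the sort, so the caveat in the EDIT applies here too.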

  • 2021-01-22 04:12

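    How about shifting the whole DataFrame and then setting the first row of every group to NaN? This assumes the rows of each group are contiguous and ordered by the first index level (e.g. after df.sort_index()), and it handles a shift of 1 only:

    # Shift all rows down by one, then blank each group's first row.
    dfs = df.shift(1)
    # cumsum of group sizes = start of each following group; the last entry is len(df).
    dfs.iloc[df.groupby(level=0).size().cumsum().to_numpy()[:-1]] = np.nan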
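    A self-contained check of this approach against the built-in GroupBy.shift, on a small frame whose first index level is already sorted (the data values are arbitrary):

```python
import numpy as np
import pandas as pd

idx = pd.MultiIndex.from_tuples(
    [("a", 0), ("a", 1), ("a", 2), ("b", 0), ("b", 1), ("c", 0)],
    names=["grp", "row"],
)
df = pd.DataFrame({"x": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]}, index=idx)

# Shift all rows down by one, then blank the first row of every group after the first.
dfs = df.shift(1)
starts = df.groupby(level=0).size().cumsum().to_numpy()[:-1]  # group sizes 3, 2, 1 -> [3, 5]
dfs.iloc[starts] = np.nan

# The built-in per-group shift should give an identical frame.
expected = df.groupby(level=0).shift(1)
pd.testing.assert_frame_equal(dfs, expected)
print(dfs["x"].tolist())  # [nan, 1.0, 2.0, nan, 4.0, nan]
```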
    
  • 2021-01-22 04:12

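    Try this. It assumes the rows are sorted so that each group of groupcol is contiguous, and it supports non-negative shift_n only: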

    import numpy as np
    import pandas as pd
    df = pd.DataFrame({'A': [10, 20, 15, 30, 45,43,67,22,12,14,54],
                       'B': [13, 23, 18, 33, 48, 1,7, 56,66,45,32],
                       'C': [17, 27, 22, 37, 52,77,34,21,22,90,8],
                       'D':    ['a','a','a','a','b','b','b','c','c','c','c']
                       })
    df
    #>      A   B   C  D
    #> 0   10  13  17  a
    #> 1   20  23  27  a
    #> 2   15  18  22  a
    #> 3   30  33  37  a
    #> 4   45  48  52  b
    #> 5   43   1  77  b
    #> 6   67   7  34  b
    #> 7   22  56  21  c
    #> 8   12  66  22  c
    #> 9   14  45  90  c
    #> 10  54  32   8  c
    def groupby_shift(df, col, groupcol, shift_n, fill_na=np.nan):
        '''df: dataframe
           col: column to be shifted
           groupcol: grouping column (rows must be sorted/contiguous by it)
           shift_n: number of rows to shift down (non-negative)
           fill_na: value for rows with no predecessor in their group, default np.nan
        '''
        # Cumulative group sizes = start position of each following group.
        rowno = list(df.groupby(groupcol).size().cumsum())
        lagged_col = df[col].shift(shift_n)
        # The first shift_n rows of the frame and of every later group have no predecessor.
        na_rows = list(range(shift_n))
        for i in rowno[:-1]:  # the last cumsum equals len(df), not a group start
            na_rows.extend(i + j for j in range(shift_n))
        # Drop duplicates and positions past the end (groups shorter than shift_n).
        na_rows = sorted(i for i in set(na_rows) if i < len(lagged_col))
        lagged_col.iloc[na_rows] = fill_na
        return lagged_col
        
    df['A_lag_1'] = groupby_shift(df, 'A', 'D', 1)
    df
    #>      A   B   C  D  A_lag_1
    #> 0   10  13  17  a      NaN
    #> 1   20  23  27  a     10.0
    #> 2   15  18  22  a     20.0
    #> 3   30  33  37  a     15.0
    #> 4   45  48  52  b      NaN
    #> 5   43   1  77  b     45.0
    #> 6   67   7  34  b     43.0
    #> 7   22  56  21  c      NaN
    #> 8   12  66  22  c     22.0
    #> 9   14  45  90  c     12.0
    #> 10  54  32   8  c     14.0
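    A more vectorized variant of the same idea, as a sketch: groupby(...).cumcount() gives each row's position within its group, so after one whole-column shift the rows with position < shift_n are exactly those that received a value from another group. Like groupby_shift, it assumes contiguous groups and a non-negative shift_n; the name groupby_shift_vec is hypothetical. The built-in GroupBy.shift is included as a cross-check, using the A and D columns from above.

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({'A': [10, 20, 15, 30, 45, 43, 67, 22, 12, 14, 54],
                   'D': ['a', 'a', 'a', 'a', 'b', 'b', 'b', 'c', 'c', 'c', 'c']})


def groupby_shift_vec(df, col, groupcol, shift_n, fill_na=np.nan):
    """Hypothetical vectorized variant; assumes contiguous groups and shift_n >= 0."""
    lagged_col = df[col].shift(shift_n)
    # Position of each row within its group: 0, 1, 2, ...
    pos_in_group = df.groupby(groupcol).cumcount()
    # These rows got their shifted-in value from a previous group (or from before the frame).
    return lagged_col.mask(pos_in_group < shift_n, fill_na)


lag1 = groupby_shift_vec(df, 'A', 'D', 1)
builtin = df.groupby('D')['A'].shift(1)
print(lag1.equals(builtin))  # True
```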
    
  • 2021-01-22 04:21

    The problem is that the groupby shift operation is not Cython-optimized, so it involves a callback into Python. Compare its timing with grp.sum():

    In [84]: %timeit grp.shift(1)
    1 loops, best of 3: 1.77 s per loop
    
    In [85]: %timeit grp.sum()
    1 loops, best of 3: 202 ms per loop
    

    I added an issue for this: https://github.com/pydata/pandas/issues/4095
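    Timings like these depend heavily on the pandas version and the data size, so they are worth re-measuring. A rough sketch with the standard-library timeit that compares the built-in groupby shift with a shift-then-mask approach; the synthetic frame and its sizes are arbitrary assumptions, and no particular result is implied:

```python
import timeit

import numpy as np
import pandas as pd

# Arbitrary synthetic data: 2,000 groups of 50 rows, contiguous and sorted by group key.
n_groups, n_rows = 2_000, 50
idx = pd.MultiIndex.from_product([range(n_groups), range(n_rows)], names=["grp", "row"])
df = pd.DataFrame({"x": np.random.randn(n_groups * n_rows)}, index=idx)


def builtin_shift():
    return df.groupby(level=0)["x"].shift(1)


def shift_then_mask():
    # One whole-column shift, then blank the first row of each group.
    pos = df.groupby(level=0).cumcount()
    return df["x"].shift(1).mask(pos == 0)


assert builtin_shift().equals(shift_then_mask())
for fn in (builtin_shift, shift_then_mask):
    secs = min(timeit.repeat(fn, number=3, repeat=3)) / 3
    print(f"{fn.__name__}: {secs * 1000:.1f} ms per call")
```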
