Pandas - How to flatten a hierarchical index in columns

忘掉有多难 2020-11-22 02:55

I have a data frame with a hierarchical index on axis 1 (columns), produced by a groupby.agg operation:

     USAF   WBAN  year  month  day  s_PC  s_CL         


        
17 answers
  • 2020-11-22 03:16

    To flatten a MultiIndex inside a chain of other DataFrame methods, define a function like this:

    def flatten_index(df):
      df_copy = df.copy()
      df_copy.columns = ['_'.join(col).rstrip('_') for col in df_copy.columns.values]
      return df_copy.reset_index()
    

    Then use the pipe method to apply this function in the chain of DataFrame methods, after groupby and agg but before any other methods in the chain:

    my_df \
      .groupby('group') \
      .agg({'value': ['count']}) \
      .pipe(flatten_index) \
      .sort_values('value_count')
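
For reference, here is a self-contained sketch of the chain above. The sample data in my_df is made up for illustration, and a second aggregation ('sum') is added only to show how several statistics on one column get distinct names:

```python
import pandas as pd

def flatten_index(df):
    """Return a copy with MultiIndex columns joined by '_' and the index reset."""
    df_copy = df.copy()
    df_copy.columns = ['_'.join(col).rstrip('_') for col in df_copy.columns.values]
    return df_copy.reset_index()

# Hypothetical sample data, not taken from the question
my_df = pd.DataFrame({'group': ['a', 'a', 'b'], 'value': [1, 2, 3]})

result = (
    my_df
    .groupby('group')
    .agg({'value': ['count', 'sum']})
    .pipe(flatten_index)          # columns become value_count, value_sum
    .sort_values('value_count')   # works because the names are flat strings now
)
print(result.columns.tolist())
```

Because flatten_index works on a copy, the frame that enters the pipe is left unchanged.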
    
  • 2020-11-22 03:18

    The easiest and most intuitive solution for me was to combine the column names using get_level_values. This prevents duplicate column names when you do more than one aggregation on the same column:

    level_one = df.columns.get_level_values(0).astype(str)
    level_two = df.columns.get_level_values(1).astype(str)
    df.columns = level_one + level_two
    

    If you want a separator between the levels, you can do the following. It gives the same result as Seiji Armstrong's comment on the accepted answer, adding an underscore only for columns that have values in both index levels:

    level_one = df.columns.get_level_values(0).astype(str)
    level_two = df.columns.get_level_values(1).astype(str)
    column_separator = ['_' if x != '' else '' for x in level_two]
    df.columns = level_one + column_separator + level_two
    

    I know this does the same thing as Andy Hayden's great answer above, but I think it is a bit more intuitive this way and is easier to remember (so I don't have to keep referring to this thread), especially for novice pandas users.

    This method is also more extensible in the case where you may have 3 column levels.

    level_one = df.columns.get_level_values(0).astype(str)
    level_two = df.columns.get_level_values(1).astype(str)
    level_three = df.columns.get_level_values(2).astype(str)
    df.columns = level_one + level_two + level_three
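
If the number of levels is not known in advance, the same get_level_values idea can be wrapped in a loop. This is only a sketch: the join_levels name and the three-level sample labels are hypothetical, and empty labels are skipped so that no stray separators appear.

```python
import pandas as pd

def join_levels(columns, sep='_'):
    """Join all levels of a MultiIndex into flat string labels, skipping empty ones."""
    # One string-typed Index per level, as in the snippets above
    labels = [columns.get_level_values(i).astype(str) for i in range(columns.nlevels)]
    # zip(*labels) walks the levels column by column
    return [sep.join(part for part in parts if part != '') for parts in zip(*labels)]

# Hypothetical three-level column index
cols = pd.MultiIndex.from_tuples([
    ('USAF', '', ''),
    ('tempf', 'amax', 'x'),
    ('s_CL', 'sum', ''),
])
flat = join_levels(cols)
print(flat)
```

Assigning the result back with df.columns = join_levels(df.columns) would flatten a frame in the same way as the two-level version.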
    
  • 2020-11-22 03:21

    After reading through all the answers, I came up with this:

    def __my_flatten_cols(self, how="_".join, reset_index=True):
        how = (lambda iter: list(iter)[-1]) if how == "last" else how
        self.columns = [how(filter(None, map(str, levels))) for levels in self.columns.values] \
                        if isinstance(self.columns, pd.MultiIndex) else self.columns
        return self.reset_index() if reset_index else self
    pd.DataFrame.my_flatten_cols = __my_flatten_cols
    

    Usage:

    Given a data frame:

    df = pd.DataFrame({"grouper": ["x","x","y","y"], "val1": [0,2,4,6], 2: [1,3,5,7]}, columns=["grouper", "val1", 2])
    
      grouper  val1  2
    0       x     0  1
    1       x     2  3
    2       y     4  5
    3       y     6  7
    
    • Single aggregation method: resulting variables named the same as source:

      df.groupby(by="grouper").agg("min").my_flatten_cols()
      
      • Same as df.groupby(by="grouper", as_index=False) or .agg(...).reset_index()
      • ----- before -----
                   val1  2
          grouper         
        
        ------ after -----
          grouper  val1  2
        0       x     0  1
        1       y     4  5
        
    • Single source variable, multiple aggregations: resulting variables named after statistics:

      df.groupby(by="grouper").agg({"val1": [min,max]}).my_flatten_cols("last")
      
      • Same as a = df.groupby(..).agg(..); a.columns = a.columns.droplevel(0); a.reset_index().
      • ----- before -----
                    val1    
                   min max
          grouper         
        
        ------ after -----
          grouper  min  max
        0       x    0    2
        1       y    4    6
        
    • Multiple variables, multiple aggregations: resulting variables named (varname)_(statname):

      df.groupby(by="grouper").agg({"val1": min, 2:[sum, "size"]}).my_flatten_cols()
      # you can combine the names in other ways too, e.g. use a different delimiter:
      #df.groupby(by="grouper").agg({"val1": min, 2:[sum, "size"]}).my_flatten_cols(" ".join)
      
      • Runs a.columns = ["_".join(filter(None, map(str, levels))) for levels in a.columns.values] under the hood (since this form of agg() results in MultiIndex on columns).
      • If you don't have the my_flatten_cols helper, it might be easier to type in the solution suggested by @Seiji: a.columns = ["_".join(t).rstrip("_") for t in a.columns.values], which works similarly in this case (but fails if you have numeric labels on columns).
      • To handle the numeric labels on columns, you could use the solution suggested by @jxstanford and @Nolan Conaway (a.columns = ["_".join(tuple(map(str, t))).rstrip("_") for t in a.columns.values]), but I don't understand why the tuple() call is needed, and I believe rstrip() is only required if some columns have a descriptor like ("colname", "") (which can happen if you reset_index() before trying to fix up .columns)
      • ----- before -----
                   val1           2     
                   min       sum    size
          grouper              
        
        ------ after -----
          grouper  val1_min  2_sum  2_size
        0       x         0      4       2
        1       y         4     12       2
        
    • You want to name the resulting variables manually (the nested-dict renaming shown here is deprecated since pandas 0.20.0, with no adequate alternative as of 0.23):

      df.groupby(by="grouper").agg({"val1": {"sum_of_val1": "sum", "count_of_val1": "count"},
                                         2: {"sum_of_2":    "sum", "count_of_2":    "count"}}).my_flatten_cols("last")
      
      • Other suggestions include: setting the columns manually: res.columns = ['A_sum', 'B_sum', 'count'] or .join()ing multiple groupby statements.
      • ----- before -----
                           val1                      2         
                  count_of_val1 sum_of_val1 count_of_2 sum_of_2
          grouper                                              
        
        ------ after -----
          grouper  count_of_val1  sum_of_val1  count_of_2  sum_of_2
        0       x              2            2           2         4
        1       y              2           10           2        12
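
On newer pandas versions (0.25 and later), named aggregation covers the manual-naming case without a MultiIndex ever being created. The sketch below reuses the example frame from this answer and is an addition, not part of the my_flatten_cols helper:

```python
import pandas as pd

df = pd.DataFrame(
    {"grouper": ["x", "x", "y", "y"], "val1": [0, 2, 4, 6], 2: [1, 3, 5, 7]},
    columns=["grouper", "val1", 2],
)

# Each keyword is the output column name; the value is a
# (source column, aggregation) pair. The integer label 2 is fine here
# because it appears inside the tuple, not as a keyword.
result = df.groupby("grouper").agg(
    sum_of_val1=("val1", "sum"),
    count_of_val1=("val1", "count"),
    sum_of_2=(2, "sum"),
    count_of_2=(2, "count"),
).reset_index()
print(result)
```

The output columns come out in the order the keywords are written, so no flattening step is needed afterwards.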
        

    Cases handled by the helper function

    • level labels can be non-string, for example integer column names (see the question "Index pandas DataFrame by column numbers, when column names are integers"), so we have to convert with map(str, ..)
    • they can also be empty, so we have to filter(None, ..)
    • for single-level columns (i.e. anything except MultiIndex), columns.values returns the names (str, not tuples)
    • depending on how you used .agg() you may need to keep the bottom-most label for a column or concatenate multiple labels
    • (since I'm new to pandas?) more often than not, I want reset_index() to be able to work with the group-by columns in the regular way, so it does that by default
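
The cases listed above can be checked with a plain function instead of a method patched onto pd.DataFrame. This is a sketch under assumptions: flatten_cols is a hypothetical name, the sample columns are made up, and unlike the helper it returns a copy and does not reset the index.

```python
import pandas as pd

def flatten_cols(df, sep="_"):
    """Join MultiIndex column labels into strings; leave single-level columns alone."""
    if not isinstance(df.columns, pd.MultiIndex):
        return df  # single-level columns: nothing to flatten
    out = df.copy()
    # map(str, ...) handles non-string labels, filter(None, ...) drops empty ones
    out.columns = [sep.join(filter(None, map(str, levels))) for levels in out.columns.values]
    return out

# Hypothetical columns: an empty second label and an integer first label
cols = pd.MultiIndex.from_tuples([("grouper", ""), ("val1", "min"), (2, "sum"), (2, "size")])
wide = pd.DataFrame([["x", 0, 4, 2]], columns=cols)

flat = flatten_cols(wide)
single = flatten_cols(pd.DataFrame({"a": [1]}))
print(flat.columns.tolist())
```

Since the labels are converted to strings before filtering, a numeric label of 0 becomes "0" and is kept rather than dropped as empty.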
  • 2020-11-22 03:22

    I think the easiest way to do this would be to set the columns to the top level:

    df.columns = df.columns.get_level_values(0)
    

    Note: if the top level has a name, you can also access it by that name rather than by 0.

    If you want to combine/join your MultiIndex into one Index (assuming you have just string entries in your columns) you could:

    df.columns = [' '.join(col).strip() for col in df.columns.values]
    

    Note: we must strip the trailing whitespace for columns that have no second-level label.

    In [11]: [' '.join(col).strip() for col in df.columns.values]
    Out[11]: 
    ['USAF',
     'WBAN',
     'day',
     'month',
     's_CD sum',
     's_CL sum',
     's_CNT sum',
     's_PC sum',
     'tempf amax',
     'tempf amin',
     'year']
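
To make the difference between the two options concrete, here is a small sketch. The frame is hypothetical (its values are made up, only the labels mirror the question), and it uses MultiIndex.to_flat_index(), available since pandas 0.24, in place of columns.values:

```python
import pandas as pd

# Hypothetical frame shaped like a groupby/agg result
cols = pd.MultiIndex.from_tuples([
    ("USAF", ""), ("s_CL", "sum"), ("tempf", "amax"), ("tempf", "amin"),
])
df = pd.DataFrame([[702730, 12, 30.92, 24.98]], columns=cols)

# Option 1: keep only the top level. Short, but "tempf" now appears twice.
top_only = df.columns.get_level_values(0)

# Option 2: join both levels so every label stays unique.
joined = [" ".join(col).strip() for col in df.columns.to_flat_index()]

print(top_only.tolist())
print(joined)
```

Keeping only the top level is therefore safest when each column has a single aggregation.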
    
  • 2020-11-22 03:23

    The most pythonic way to do this is to use the map function.

    df.columns = df.columns.map(' '.join).str.strip()
    

    Output of print(df.columns):

    Index(['USAF', 'WBAN', 'day', 'month', 's_CD sum', 's_CL sum', 's_CNT sum',
           's_PC sum', 'tempf amax', 'tempf amin', 'year'],
          dtype='object')
    

    Update for Python 3.6+, using f-strings:

    df.columns = [f'{f} {s}' if s != '' else f'{f}' 
                  for f, s in df.columns]
    
    print(df.columns)
    

    Output:

    Index(['USAF', 'WBAN', 'day', 'month', 's_CD sum', 's_CL sum', 's_CNT sum',
           's_PC sum', 'tempf amax', 'tempf amin', 'year'],
          dtype='object')
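
One caveat: ' '.join only accepts strings, so the map version raises a TypeError when a level contains non-string labels such as integers. A possible workaround, sketched here with hypothetical column labels, is to convert each label inside the mapped function:

```python
import pandas as pd

# Hypothetical columns with an integer label on the first level
cols = pd.MultiIndex.from_tuples([("grouper", ""), ("val1", "min"), (2, "sum")])

# Convert every label to str before joining, then strip the trailing space
flat = cols.map(lambda levels: " ".join(map(str, levels)).strip())
print(flat.tolist())
```

Because the mapped function returns plain strings, the result is a flat Index rather than a MultiIndex.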
    
  • 2020-11-22 03:23

    Another simple routine, which joins the non-empty labels of each column with a separator and modifies df in place:

    def flatten_columns(df, sep='.'):
        def _remove_empty(column_name):
            return tuple(element for element in column_name if element)
        def _join(column_name):
            return sep.join(column_name)
    
        new_columns = [_join(_remove_empty(column)) for column in df.columns.values]
        df.columns = new_columns
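
A usage sketch follows, with the function repeated so that it runs on its own and with made-up sample data. Note that the routine assigns to df.columns directly and returns None, so it cannot be placed in a method chain as written:

```python
import pandas as pd

def flatten_columns(df, sep='.'):
    def _remove_empty(column_name):
        return tuple(element for element in column_name if element)
    def _join(column_name):
        return sep.join(column_name)

    new_columns = [_join(_remove_empty(column)) for column in df.columns.values]
    df.columns = new_columns

# Hypothetical sample data
df = pd.DataFrame({"grouper": ["x", "x", "y"], "val1": [0, 2, 4]})
agg = df.groupby("grouper").agg({"val1": ["min", "max"]}).reset_index()

returned = flatten_columns(agg)   # modifies agg in place
print(agg.columns.tolist())
```

Here reset_index() leaves a ("grouper", "") column, and _remove_empty is what keeps it from ending up as "grouper." with a trailing separator.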
    