How can an interaction design matrix be created from categorical variables?

[愿得一人] 2021-01-20 15:31

I'm coming from working mainly in R for statistical modeling / machine learning and am looking to improve my skills in Python. I am wondering about the best way to create a design matrix…

2 Answers
  •  鱼传尺愫
    2021-01-20 15:51

    If you use scikit-learn's OneHotEncoder on your design matrix to obtain a one-hot design matrix, then interactions are nothing other than products of columns. If X_1hot is your one-hot design matrix, with samples as rows, then for second-order interactions you can write

    X_2nd_order = (X_1hot[:, np.newaxis, :] * X_1hot[:, :, np.newaxis]).reshape(len(X_1hot), -1)
    

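    The result contains every interaction twice (column i·j and column j·i), and the diagonal products x_i·x_i reproduce the original features, because multiplying a 0/1 indicator by itself leaves it unchanged.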
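A small numpy check of those two remarks, using a made-up one-hot matrix (two hypothetical categorical variables: a with levels x/y, and b with levels u/v/w). It confirms that the flattened product has p*p columns, that the diagonal products reproduce the original indicators, and that columns (i, j) and (j, i) are identical. Keeping only the strict upper triangle with np.triu_indices is one way, among others, to drop both the duplicates and the diagonal.

```python
import numpy as np

# Hypothetical one-hot matrix: 3 samples; columns are
# [a=x, a=y, b=u, b=v, b=w] for two made-up categorical variables a and b
X_1hot = np.array([
    [1, 0, 1, 0, 0],
    [0, 1, 0, 1, 0],
    [1, 0, 0, 0, 1],
], dtype=float)
n, p = X_1hot.shape

# Broadcasting trick from the answer: entry [s, i, j] is X[s, j] * X[s, i]
X_2nd_order = (X_1hot[:, np.newaxis, :] * X_1hot[:, :, np.newaxis]).reshape(n, -1)
pairs = X_2nd_order.reshape(n, p, p)  # view each row as a p x p table

# Diagonal products x_i * x_i equal x_i for 0/1 indicators
diagonal = pairs[:, np.arange(p), np.arange(p)]

# Column (i, j) equals column (j, i), so each interaction appears twice
is_symmetric = np.array_equal(pairs, pairs.transpose(0, 2, 1))

# Keep each unordered pair i < j once: drops the duplicates and the diagonal
rows, cols = np.triu_indices(p, k=1)
X_unique = pairs[:, rows, cols]

print(X_2nd_order.shape)  # (3, 25)
print(X_unique.shape)     # (3, 10)
```

Products of two levels of the same variable, such as a=x times a=y (columns 0 and 1), are always zero because each row takes exactly one level, so those columns can be dropped as well.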

    Going to higher orders makes the design matrix explode: with p one-hot columns, order d produces p^d columns. If you really need that, look into kernelizing with a polynomial kernel, which lets you reach arbitrary degrees without ever building those columns explicitly.
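To illustrate the kernel remark: for the homogeneous degree-2 polynomial kernel k(x, y) = (x·y)^2, the kernel value equals the inner product of the explicit second-order feature rows produced by the broadcasting trick, so a kernel method can use all pairwise products without materializing the p*p columns. This sketch uses random 0/1 data and plain numpy; scikit-learn's sklearn.metrics.pairwise.polynomial_kernel computes the more general (gamma·x·y + coef0)^degree, but it is not used here.

```python
import numpy as np

rng = np.random.default_rng(0)
# Illustrative 0/1 design matrix: 4 samples, 6 indicator columns
X = rng.integers(0, 2, size=(4, 6)).astype(float)

# Explicit second-order features: all p*p = 36 pairwise products per row
X_2nd_order = (X[:, np.newaxis, :] * X[:, :, np.newaxis]).reshape(len(X), -1)
explicit_gram = X_2nd_order @ X_2nd_order.T  # (4, 4) Gram matrix from 36 columns

# Degree-2 homogeneous polynomial kernel: same Gram matrix, no 36 columns built
kernel_gram = (X @ X.T) ** 2

print(np.allclose(explicit_gram, kernel_gram))  # True
```

The identity behind it is sum_ij x_i x_j y_i y_j = (sum_i x_i y_i)^2; raising X @ X.T to the power d gives the degree-d analogue, whose explicit feature map would need p^d columns.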

    Using the data frame from your question, we can proceed as follows. First, a manual way to build a one-hot design matrix from the data frame:

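    import numpy as np

    indicators = []   # one block of 0/1 columns per categorical column
    state_names = []  # "column__level" label for every indicator column
    for column_name in df.columns:
        column = df[column_name].to_numpy()
        levels = np.unique(column)  # sorted distinct levels of this column
        # Compare every row against every level -> shape (n_rows, n_levels)
        one_hot = (column[:, np.newaxis] == levels).astype(float)
        indicators.append(one_hot)
        state_names.extend("%s__%s" % (column_name, level) for level in levels)
    
    # Place the blocks side by side to form the full one-hot design matrix
    X_1hot = np.hstack(indicators)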
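A self-contained run of that loop on a small, made-up DataFrame (the columns color and size are hypothetical, since the question's own data frame is not shown). As a cross-check, pandas.get_dummies with prefix_sep="__" and dtype=float yields the same columns in the same order here, because both approaches sort the levels of each string column; treat this as an observation for string-valued columns, not a guarantee for every dtype.

```python
import numpy as np
import pandas as pd

# Hypothetical example data with two categorical columns
df = pd.DataFrame({
    "color": ["red", "blue", "red", "green"],
    "size": ["S", "L", "L", "S"],
})

# Same manual one-hot construction as in the answer
indicators, state_names = [], []
for column_name in df.columns:
    column = df[column_name].to_numpy()
    levels = np.unique(column)
    indicators.append((column[:, np.newaxis] == levels).astype(float))
    state_names.extend("%s__%s" % (column_name, level) for level in levels)
X_1hot = np.hstack(indicators)

print(state_names)
# ['color__blue', 'color__green', 'color__red', 'size__L', 'size__S']
print(X_1hot.shape)  # (4, 5)

# Cross-check against pandas' built-in one-hot encoding
dummies = pd.get_dummies(df, prefix_sep="__", dtype=float)
print(list(dummies.columns) == state_names)        # True
print(np.array_equal(dummies.to_numpy(), X_1hot))  # True
```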
    

    The column names are now stored in state_names, and the indicator matrix is X_1hot. Next, compute the second-order features:

    X_2nd_order = (X_1hot[:, np.newaxis, :] * X_1hot[:, :, np.newaxis]).reshape(len(X_1hot), -1)
    

    To get the names of the second-order columns, build them with itertools.product, which yields the pairs in the same order as the reshaped columns:

    from itertools import product
    one_hot_interaction_names = ["%s___%s" % (column1, column2) 
                                 for column1, column2 in product(state_names, state_names)]
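One way to put names and values together while dropping the redundant columns: keep each unordered pair of indicators once (via itertools.combinations), skip pairs of two levels of the same variable (their product is always zero), and wrap the result in a DataFrame. The small X_1hot and state_names below are made-up stand-ins for the ones built above, and splitting names on "__" assumes the original column names do not themselves contain that separator.

```python
from itertools import combinations

import numpy as np
import pandas as pd

# Hypothetical stand-ins for the X_1hot / state_names built above
state_names = ["color__blue", "color__red", "size__L", "size__S"]
X_1hot = np.array([
    [0, 1, 0, 1],
    [1, 0, 1, 0],
    [0, 1, 1, 0],
], dtype=float)

columns = {}
for i, j in combinations(range(len(state_names)), 2):  # each pair i < j once
    var_i = state_names[i].split("__")[0]
    var_j = state_names[j].split("__")[0]
    if var_i == var_j:
        continue  # two levels of one variable never co-occur: always zero
    name = "%s___%s" % (state_names[i], state_names[j])
    columns[name] = X_1hot[:, i] * X_1hot[:, j]

interactions = pd.DataFrame(columns)
print(interactions.shape)  # (3, 4)
```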
    
