Question
Please try to give a parameterized solution (there are more than three alternatives).
I have a dict with beta values:
{'B_X1': 2.0, 'B_X2': -3.0}
And this data frame:
X1_123 X1_456 X1_789 X2_123 X2_456 X2_789
6.75 4.69 9.59 5.52 9.69 7.40
7.46 4.94 3.01 1.78 1.38 4.68
2.05 7.30 4.08 7.02 8.24 8.49
5.60 7.88 8.11 5.98 4.60 1.39
1.80 8.28 9.16 7.34 7.69 6.16
3.73 6.93 8.93 2.58 3.48 6.04
8.06 8.88 7.06 6.76 4.68 7.82
5.00 7.29 5.86 3.92 5.67 4.10
2.49 2.55 4.66 7.15 6.26 7.87
1.50 3.35 5.70 9.86 4.83 1.17
8.19 7.72 9.56 6.61 4.15 3.64
2.43 9.54 9.15 4.41 9.18 7.85
2.71 3.24 4.56 6.22 7.89 9.93
5.96 4.34 5.26 8.63 9.81 9.40
123, 456, and 789 are the alternatives.
I want to calculate the prediction probability using this formula:
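Judging from the expected results further down, the formula is presumably the standard multinomial logit (softmax) form. The following is a reconstruction under that assumption, where $V_a$ is a label introduced here for the linear score of alternative $a$:

```latex
\[
V_a = B_{X1}\, X1_a + B_{X2}\, X2_a,
\qquad
P_j = \frac{e^{V_j}}{e^{V_j} + e^{V_k} + e^{V_s}}
\]
```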
j, k, and s are the mentioned alternatives.
Expected result:
X1_123 X1_456 X1_789 X2_123 X2_456 X2_789 P_123 P_456 P_789
6.75 4.69 9.59 5.52 9.69 7.40 0.490 0.000 0.510
7.46 4.94 3.01 1.78 1.38 4.68 0.979 0.021 0.000
2.05 7.30 4.08 7.02 8.24 8.49 0.001 0.998 0.001
5.60 7.88 8.11 5.98 4.60 1.39 0.000 0.000 1.000
1.80 8.28 9.16 7.34 7.69 6.16 0.000 0.002 0.998
3.73 6.93 8.93 2.58 3.48 6.04 0.024 0.952 0.024
8.06 8.88 7.06 6.76 4.68 7.82 0.000 1.000 0.000
5.00 7.29 5.86 3.92 5.67 4.10 0.210 0.107 0.683
2.49 2.55 4.66 7.15 6.26 7.87 0.038 0.623 0.339
1.50 3.35 5.70 9.86 4.83 1.17 0.000 0.000 1.000
8.19 7.72 9.56 6.61 4.15 3.64 0.000 0.005 0.995
2.43 9.54 9.15 4.41 9.18 7.85 0.041 0.037 0.922
2.71 3.24 4.56 6.22 7.89 9.93 0.981 0.019 0.001
5.96 4.34 5.26 8.63 9.81 9.40 0.975 0.001 0.024
The probabilities should sum to 1 in every row.
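As a sanity check, the first row of the expected result can be reproduced by hand. This is a minimal sketch assuming the softmax form of the multinomial logit model; the variable names are introduced here for illustration only:

```python
import math

# Beta values from the question.
B_X1, B_X2 = 2.0, -3.0

# First data row: (X1, X2) for each alternative.
row = {"123": (6.75, 5.52), "456": (4.69, 9.69), "789": (9.59, 7.40)}

# Exponentiated linear score per alternative.
exp_scores = {alt: math.exp(B_X1 * x1 + B_X2 * x2) for alt, (x1, x2) in row.items()}
total = sum(exp_scores.values())

# Normalise so the three probabilities sum to 1.
raw_probs = {alt: v / total for alt, v in exp_scores.items()}
probs = {alt: round(p, 3) for alt, p in raw_probs.items()}
print(probs)  # {'123': 0.49, '456': 0.0, '789': 0.51}
```

The rounded values agree with P_123, P_456 and P_789 in the first row of the table above.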
Expected result with a constant for each alternative:
{'B_X1': 2.0, 'B_X2': -3.0, 'B_123': 0.1, 'B_456': 0.2, 'B_789': 0.3}
X1_123 X1_456 X1_789 X2_123 X2_456 X2_789 P_123 P_456 P_789
6.75 4.69 9.59 5.52 9.69 7.40 0.440 0.000 0.560
7.46 4.94 3.01 1.78 1.38 4.68 0.977 0.023 0.000
2.05 7.30 4.08 7.02 8.24 8.49 0.001 0.998 0.001
5.60 7.88 8.11 5.98 4.60 1.39 0.000 0.000 1.000
1.80 8.28 9.16 7.34 7.69 6.16 0.000 0.002 0.998
3.73 6.93 8.93 2.58 3.48 6.04 0.021 0.952 0.027
8.06 8.88 7.06 6.76 4.68 7.82 0.000 1.000 0.000
5.00 7.29 5.86 3.92 5.67 4.10 0.180 0.102 0.717
2.49 2.55 4.66 7.15 6.26 7.87 0.034 0.604 0.363
1.50 3.35 5.70 9.86 4.83 1.17 0.000 0.000 1.000
8.19 7.72 9.56 6.61 4.15 3.64 0.000 0.005 0.995
2.43 9.54 9.15 4.41 9.18 7.85 0.034 0.034 0.932
2.71 3.24 4.56 6.22 7.89 9.93 0.978 0.021 0.001
5.96 4.34 5.26 8.63 9.81 9.40 0.970 0.001 0.029
Answer 1:
IIUC:
Turn columns into a MultiIndex
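import numpy as np
import pandas as pd

# set_axis returns a new object; the inplace argument was removed in pandas 2.0
df = df.set_axis(df.columns.str.split('_', expand=True), axis=1)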
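To see what this reshaping does, here is a small sketch on two rows of the question's data. The names `wide` and `long` are introduced here for illustration; note that recent pandas versions may emit a FutureWarning for `stack()`.

```python
import pandas as pd

# Two rows from the question's data, enough to show the column reshaping.
df = pd.DataFrame(
    [[6.75, 4.69, 9.59, 5.52, 9.69, 7.40],
     [7.46, 4.94, 3.01, 1.78, 1.38, 4.68]],
    columns=["X1_123", "X1_456", "X1_789", "X2_123", "X2_456", "X2_789"],
)

# 'X1_123' -> ('X1', '123'): level 0 is the variable, level 1 the alternative.
wide = df.set_axis(df.columns.str.split("_", expand=True), axis=1)

# Stacking the last column level gives one row per (observation, alternative)
# and one column per variable (X1, X2).
long = wide.stack()
print(long)
```

In the long layout each variable is a single column, so multiplying by the betas and summing across columns yields the linear score of every (row, alternative) pair.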
And define your B such that the keys match the prefixes in df
B = {'X1': 2.0, 'X2': -3.0}
Then
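def f(b, x):
    # exp of the linear score: sum over variables of beta * x
    return np.exp(x.mul(pd.Series(b)).sum(axis=1))

# one row per (observation, alternative), then back to one column per alternative
parts = f(B, df.stack()).unstack()
# normalise each row so the probabilities sum to 1
preds = parts.div(parts.sum(axis=1), axis=0)
df.join(pd.concat({'P': preds}, axis=1).round(3)).pipe(
    lambda d: d.set_axis(list(map('_'.join, d.columns)), axis=1)
)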
X1_123 X1_456 X1_789 X2_123 X2_456 X2_789 P_123 P_456 P_789
0 6.75 4.69 9.59 5.52 9.69 7.40 0.490 0.000 0.510
1 7.46 4.94 3.01 1.78 1.38 4.68 0.979 0.021 0.000
2 2.05 7.30 4.08 7.02 8.24 8.49 0.001 0.998 0.001
3 5.60 7.88 8.11 5.98 4.60 1.39 0.000 0.000 1.000
4 1.80 8.28 9.16 7.34 7.69 6.16 0.000 0.002 0.998
5 3.73 6.93 8.93 2.58 3.48 6.04 0.024 0.952 0.024
6 8.06 8.88 7.06 6.76 4.68 7.82 0.000 1.000 0.000
7 5.00 7.29 5.86 3.92 5.67 4.10 0.210 0.107 0.683
8 2.49 2.55 4.66 7.15 6.26 7.87 0.038 0.623 0.339
9 1.50 3.35 5.70 9.86 4.83 1.17 0.000 0.000 1.000
10 8.19 7.72 9.56 6.61 4.15 3.64 0.000 0.005 0.995
11 2.43 9.54 9.15 4.41 9.18 7.85 0.041 0.037 0.922
12 2.71 3.24 4.56 6.22 7.89 9.93 0.981 0.019 0.001
13 5.96 4.34 5.26 8.63 9.81 9.40 0.975 0.001 0.024
Wrapped in one pretty function
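def f(df, b):
    # split 'X1_123' into ('X1', '123') column levels
    d = df.set_axis(df.columns.str.split('_', expand=True), axis=1)
    # exp(sum of beta * x) per observation and alternative
    parts = np.exp(d.stack().mul(pd.Series(b)).sum(axis=1).unstack())
    preds = pd.concat({'P': parts.div(parts.sum(axis=1), axis=0)}, axis=1).round(3)
    d = d.join(preds)
    # flatten the column levels back to 'X1_123', 'P_123', ...
    d.columns = list(map('_'.join, d.columns))
    return d

f(df, B)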
X1_123 X1_456 X1_789 X2_123 X2_456 X2_789 P_123 P_456 P_789
0 6.75 4.69 9.59 5.52 9.69 7.40 0.490 0.000 0.510
1 7.46 4.94 3.01 1.78 1.38 4.68 0.979 0.021 0.000
2 2.05 7.30 4.08 7.02 8.24 8.49 0.001 0.998 0.001
3 5.60 7.88 8.11 5.98 4.60 1.39 0.000 0.000 1.000
4 1.80 8.28 9.16 7.34 7.69 6.16 0.000 0.002 0.998
5 3.73 6.93 8.93 2.58 3.48 6.04 0.024 0.952 0.024
6 8.06 8.88 7.06 6.76 4.68 7.82 0.000 1.000 0.000
7 5.00 7.29 5.86 3.92 5.67 4.10 0.210 0.107 0.683
8 2.49 2.55 4.66 7.15 6.26 7.87 0.038 0.623 0.339
9 1.50 3.35 5.70 9.86 4.83 1.17 0.000 0.000 1.000
10 8.19 7.72 9.56 6.61 4.15 3.64 0.000 0.005 0.995
11 2.43 9.54 9.15 4.41 9.18 7.85 0.041 0.037 0.922
12 2.71 3.24 4.56 6.22 7.89 9.93 0.981 0.019 0.001
13 5.96 4.34 5.26 8.63 9.81 9.40 0.975 0.001 0.024
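The answer does not cover the second part of the question, the constant for each alternative. Below is a minimal sketch of one way to extend it, assuming each constant is simply added to its alternative's linear score before exponentiating. `predict_with_constants` and its argument names are hypothetical and not part of the original answer.

```python
import numpy as np
import pandas as pd


def predict_with_constants(df, betas, constants):
    """Multinomial logit probabilities with alternative-specific constants.

    Hypothetical helper: betas maps variable prefix -> coefficient
    (e.g. {'X1': 2.0}); constants maps alternative -> intercept
    (e.g. {'123': 0.1}).
    """
    wide = df.set_axis(df.columns.str.split("_", expand=True), axis=1)
    # Linear score per (row, alternative): sum over variables of beta * x.
    utility = sum(wide[var] * beta for var, beta in betas.items())
    # Add each alternative's constant (aligned on the column labels).
    utility = utility + pd.Series(constants)
    # Subtracting the row maximum leaves the ratios unchanged and
    # avoids overflow in exp for large scores.
    parts = np.exp(utility.sub(utility.max(axis=1), axis=0))
    probs = parts.div(parts.sum(axis=1), axis=0)
    return df.join(probs.add_prefix("P_").round(3))


# Two rows from the question's data.
df = pd.DataFrame(
    [[6.75, 4.69, 9.59, 5.52, 9.69, 7.40],
     [5.00, 7.29, 5.86, 3.92, 5.67, 4.10]],
    columns=["X1_123", "X1_456", "X1_789", "X2_123", "X2_456", "X2_789"],
)
params = {"B_X1": 2.0, "B_X2": -3.0, "B_123": 0.1, "B_456": 0.2, "B_789": 0.3}

# Split the parameter dict into slopes and constants, whatever the
# number of alternatives, by comparing keys with the column suffixes.
stripped = {key[2:]: value for key, value in params.items()}  # drop 'B_'
alternatives = {col.split("_")[1] for col in df.columns}
betas = {k: v for k, v in stripped.items() if k not in alternatives}
constants = {k: v for k, v in stripped.items() if k in alternatives}

out = predict_with_constants(df, betas, constants)
print(out[["P_123", "P_456", "P_789"]])
```

For these two rows the sketch gives 0.440, 0.000, 0.560 and 0.180, 0.102, 0.717, matching the corresponding rows of the question's second expected table. Because the alternatives are read from the column names, the same call should work with more than three alternatives.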
Source: https://stackoverflow.com/questions/59938586/calculating-multinomial-logit-model-prediction-probabilities