Tools to use for conditional density estimation in Python [closed]

隐身守侯 submitted on 2019-12-18 07:31:44

Question


I have a large data set that contains 3 attributes per row: A,B,C

Column A: can take the values 1, 2, and 0. Column B and C: can take any values.

I'd like to perform density estimation using histograms for P(A = 2 | B,C) and plot the results using python.

I do not need the code to do it; I can try to figure that out on my own. I just need to know which procedures and tools I should use.


Answer 1:


To answer your overall question, we need to go through several steps, each with its own question:

  • How to read a CSV file (or text data)?

  • How to filter data?

  • How to plot data?

Each stage needs its own techniques and tools, and at each stage you have several choices (you can search online for alternatives).

1- How to read a CSV file:

Python has a built-in csv module for going through the CSV file where you store your data, but most people recommend Pandas for working with CSV files.

After installing the Pandas package, you can read your CSV file with the read_csv function.

import pandas as pd

df = pd.read_csv("file.csv")

As you didn't share the CSV file, I will make a random dataset to explain the upcoming steps.

import pandas as pd
import numpy as np

t = [1,1,1,2,0,1,1,0,0,2,1,1,2,0,0,0,0,1,1,1]
df = pd.DataFrame(np.random.randn(20, 2), columns=list('AC'))
df['B'] = t  # add a column that only holds the values 0, 1 and 2

Note: NumPy is a Python package for numerical operations. You don't strictly need it here; it is only used to generate the random values for columns A and C.

If you print df, you will get something like this (the random values differ on every run):

         A         C    B
0  -0.090162  0.035458  1
1   2.068328 -0.357626  1
2  -0.476045 -1.217848  1
3  -0.405150 -1.111787  2
4   0.502283  1.586743  0
5   1.822558 -0.398833  1
6   0.367663  0.305023  1
7   2.731756  0.563161  0
8   2.096459  1.323511  0
9   1.386778 -1.774599  2
10 -0.512147 -0.677339  1
11 -0.091165  0.587496  1
12 -0.264265  1.216617  2
13  1.731371 -0.906727  0
14  0.969974  1.305460  0
15 -0.795679 -0.707238  0
16  0.274473  1.842542  0
17  0.771794 -1.726273  1
18  0.126508 -0.206365  1
19  0.622025 -0.322115  1

2- How to filter data:

There are several ways to filter data. The simplest is boolean indexing: put a condition on a column inside the dataframe's square brackets. In this example dataset the discrete column is B (it plays the role of column A in the question), so the criterion is B equal to 2.

l = df[df['B'] == 2]
print(l)

You can also use other approaches, such as groupby or a lambda function, to go through the dataframe and apply different conditions.

# iterating over a groupby yields (group value, sub-dataframe) tuples
for key in df.groupby('B'):
    print(key)

If you run the above-mentioned scripts you'll get:

For the first one: Only data where B==2

           A         C  B
3  -0.405150 -1.111787  2
9   1.386778 -1.774599  2
12 -0.264265  1.216617  2

For the second one: Printing the results divided in groups.

(0,            A         C  B
4   0.502283  1.586743  0
7   2.731756  0.563161  0
8   2.096459  1.323511  0
13  1.731371 -0.906727  0
14  0.969974  1.305460  0
15 -0.795679 -0.707238  0
16  0.274473  1.842542  0)
(1,            A         C  B
0  -0.090162  0.035458  1
1   2.068328 -0.357626  1
2  -0.476045 -1.217848  1
5   1.822558 -0.398833  1
6   0.367663  0.305023  1
10 -0.512147 -0.677339  1
11 -0.091165  0.587496  1
17  0.771794 -1.726273  1
18  0.126508 -0.206365  1
19  0.622025 -0.322115  1)
(2,            A         C  B
3  -0.405150 -1.111787  2
9   1.386778 -1.774599  2
12 -0.264265  1.216617  2)
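
Since the goal is a probability estimate, the size of each group is often more useful than its rows. The following is a minimal sketch, reusing the list t from the example dataset above, of how the group sizes turn into empirical (unconditional) probabilities for each value of B. The names counts and share are illustrative only.

```python
import pandas as pd

# Same discrete column as in the example dataset above.
t = [1, 1, 1, 2, 0, 1, 1, 0, 0, 2, 1, 1, 2, 0, 0, 0, 0, 1, 1, 1]
df = pd.DataFrame({'B': t})

# Number of rows in each group, indexed by the value of B.
counts = df.groupby('B').size()

# Dividing by the total number of rows gives the empirical P(B = k).
share = counts / len(df)
print(share)
```

With this data, 3 of the 20 rows have B equal to 2, so the estimate of P(B = 2) is 0.15. Conditioning on the other columns means doing the same division separately inside each bin of those columns.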

3- How to plot data:

The simplest way to plot your data is with matplotlib.

The easiest way to plot the data in column B is to run:

import matplotlib.pyplot as plt

# histogram of the values in column B
plt.hist(df.B, bins=20, color='blue')
plt.show()

You'll get a histogram of the values in column B.

If you want to plot the columns together, use different colors and markers so the result stays readable.

import matplotlib.pyplot as plt

t = range(len(df))
# one call per column, so each line gets a label for the legend
plt.plot(t, df.A, 'r--', label='A')
plt.plot(t, df.B, 'bs--', label='B')
plt.plot(t, df.C, 'g^--', label='C')
plt.legend()
plt.show()

You'll get a line plot with one series per column.

How you plot data depends on what you need to show. You can explore the different options by going through the examples on the official matplotlib.org website.
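
The three steps above stop short of the conditional estimate the question asks for, so here is a minimal sketch of one way to combine them. It assumes the estimate is taken as a ratio of 2-D histogram counts: rows with the target value in a bin, divided by all rows in that bin. The function names conditional_histogram and plot_estimate, the bin count, and the synthetic data are all illustrative assumptions, not part of the original answer.

```python
import numpy as np


def conditional_histogram(x, y, is_target, bins=10):
    """Estimate P(target | x, y) on a 2-D grid of histogram bins.

    x, y      : 1-D arrays with the two continuous attributes
    is_target : boolean array, True where the discrete attribute has the
                value of interest
    Returns (estimate, x_edges, y_edges); bins without data are NaN.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    is_target = np.asarray(is_target, dtype=bool)

    # Count all rows per bin, then reuse the same edges for the target rows.
    total, x_edges, y_edges = np.histogram2d(x, y, bins=bins)
    hits, _, _ = np.histogram2d(x[is_target], y[is_target],
                                bins=[x_edges, y_edges])

    # Ratio of counts; empty bins have no estimate, so they stay NaN.
    estimate = np.full(total.shape, np.nan)
    np.divide(hits, total, out=estimate, where=total > 0)
    return estimate, x_edges, y_edges


def plot_estimate(estimate, x_edges, y_edges):
    """Draw the estimate as a heat map."""
    # Imported here so the estimate can be computed without matplotlib.
    import matplotlib.pyplot as plt

    # histogram2d puts x along the first axis; pcolormesh wants it along
    # the second, hence the transpose.
    plt.pcolormesh(x_edges, y_edges, estimate.T, vmin=0, vmax=1)
    plt.colorbar(label="estimated P(target | x, y)")
    plt.xlabel("x")
    plt.ylabel("y")
    plt.show()


# Synthetic data, made up for illustration only.
rng = np.random.default_rng(0)
x = rng.normal(size=5000)
y = rng.normal(size=5000)
label = rng.integers(0, 3, size=5000)  # takes the values 0, 1, 2

estimate, x_edges, y_edges = conditional_histogram(x, y, label == 2, bins=8)
```

Calling plot_estimate(estimate, x_edges, y_edges) shows the heat map. With the example dataframe from this answer, the equivalent call would be conditional_histogram(df['A'], df['C'], df['B'] == 2), although 20 rows are far too few for a meaningful estimate. The bin count trades resolution against noise: more bins give finer detail but fewer rows per bin.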




Answer 2:


If you're looking for tools that go beyond nonparametric density estimation with histograms, check the package's Python repository or install the package directly with

pip install cde

In addition to extensive documentation, the package implements

  • nonparametric methods (conditional & neighborhood kernel density estimation),
  • semiparametric methods (least-squares CDE), and
  • parametric neural-network-based methods (mixture density networks, kernel density estimation).

The package can also compute centered moments, statistical divergences (KL divergence, Hellinger, Jensen-Shannon), percentiles and expected shortfalls, and it provides data-generating processes (ARMA-jump, jump-diffusion, GMMs, etc.).

Disclaimer: I am one of the package developers.
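
To give a feel for what a kernel-based estimate does differently from a histogram, here is a minimal sketch in plain NumPy. It does not use the cde package or its API. It assumes a Gaussian kernel and a Nadaraya-Watson style weighted average of the indicator "discrete attribute equals the target value"; the function name, the bandwidth value and the synthetic data are illustrative assumptions.

```python
import numpy as np


def kernel_conditional_probability(query, points, is_target, bandwidth=0.5):
    """Kernel-weighted estimate of P(target | point).

    query     : array of shape (m, d), locations to evaluate
    points    : array of shape (n, d), observed continuous attributes
    is_target : boolean array of length n
    bandwidth : kernel width (an assumed value; tune it for real data)
    """
    query = np.atleast_2d(np.asarray(query, dtype=float))
    points = np.asarray(points, dtype=float)
    is_target = np.asarray(is_target, dtype=float)

    # Squared distance between every query and every observation: (m, n).
    diff = query[:, None, :] - points[None, :, :]
    sq_dist = np.sum(diff ** 2, axis=-1)

    # Gaussian weights: nearby observations count more than distant ones.
    weights = np.exp(-0.5 * sq_dist / bandwidth ** 2)

    # Weighted share of target rows around each query point.
    return weights @ is_target / weights.sum(axis=1)


# Synthetic data, made up for illustration: the target occurs exactly
# where the first attribute is positive.
rng = np.random.default_rng(0)
points = rng.normal(size=(2000, 2))
is_target = points[:, 0] > 0

p = kernel_conditional_probability([[1.5, 0.0], [-1.5, 0.0]], points, is_target)
```

Here p[0] comes out close to 1 and p[1] close to 0, as the construction of the data suggests. Unlike the histogram, the estimate varies smoothly with the query point, and the bandwidth plays the role the bin width plays for histograms. For query points far from all data the weights can underflow to zero, in which case the ratio is undefined.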



Source: https://stackoverflow.com/questions/26558576/tools-to-use-for-conditional-density-estimation-in-python
