pandas 和 csv 的使用最为频繁。保存数据集时，应尽量使用 csv 格式存储，而不是 txt（csv 可直接用 pd.read_csv / DataFrame.to_csv 读写，列结构不会丢失）。
对于训练集中的数据（content 和 labels 两个列表），可以先将原始的 list 封装成 dict，再直接转换为 DataFrame：
# Build a two-column DataFrame ("samples", "labels") from the parallel lists.
data = pd.DataFrame(data={"samples": content, "labels": labels})
def generate_data(random_state=24, is_pse_label=True, n_splits=5):
    """Write stratified K-fold train/dev/test CSV splits to disk.

    Splits the module-level ``X``/``y`` with ``StratifiedKFold`` and, for each
    fold ``i``, writes ``train.csv``, ``dev.csv`` and ``test.csv`` into
    ``./data_StratifiedKFold_{random_state}/data_origin_{i}/``. Train/dev rows
    are selected positionally (``iloc``) from the module-level ``train_df``;
    the module-level ``test_df`` is copied unchanged into every fold directory.

    Args:
        random_state: Seed for the shuffled fold assignment; also embedded in
            the output directory name.
        is_pse_label: If true, append the rows read from
            ``data_pse_{i}/train.csv`` to fold ``i``'s training split before
            saving it.
        n_splits: Number of folds (default 5, the previously hard-coded value).
    """
    # NOTE(review): relies on globals X, y, train_df and test_df defined
    # elsewhere; assumes X/y are row-aligned with train_df, since the split
    # positions are applied to train_df via iloc.
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    for i, (train_index, dev_index) in enumerate(skf.split(X, y)):
        print(i, "TRAIN:", train_index, "TEST:", dev_index)
        data_dir = "./data_StratifiedKFold_{}/data_origin_{}/".format(random_state, i)
        # exist_ok=True replaces the racy "check exists, then create" pair.
        os.makedirs(data_dir, exist_ok=True)

        tmp_train_df = train_df.iloc[train_index]
        tmp_dev_df = train_df.iloc[dev_index]
        # NOTE(review): to_csv writes the index as a leading column by default;
        # pass index=False if downstream readers do not expect it.
        test_df.to_csv(os.path.join(data_dir, "test.csv"))

        if is_pse_label:
            pse_dir = "data_pse_{}/".format(i)
            # NOTE(review): read without index_col -- if this CSV was saved with
            # its index, it will bring an extra "Unnamed: 0" column; confirm.
            pse_df = pd.read_csv(os.path.join(pse_dir, "train.csv"))
            # ignore_index gives the combined frame a fresh 0..n-1 index;
            # sort=False keeps the column order of the inputs.
            tmp_train_df = pd.concat([tmp_train_df, pse_df], ignore_index=True, sort=False)

        tmp_train_df.to_csv(os.path.join(data_dir, "train.csv"))
        tmp_dev_df.to_csv(os.path.join(data_dir, "dev.csv"))
        print(tmp_train_df.shape, tmp_dev_df.shape)
来源:https://www.cnblogs.com/demo-deng/p/12517771.html