1. Prepare the data
dataset.py
''' Prepare the data '''
from torch.utils.data import DataLoader, Dataset
import torch
import os
import utils
import config


class ImdbDataset(Dataset):
    def __init__(self, train=True):
        super(ImdbDataset, self).__init__()
        data_path = r"H:\073-nlp自然语言处理-v5.bt38[周大伟]\073-nlp自然语言处理-v5.bt38[周大伟]\第四天\代码\data\aclImdb_v1\aclImdb"
        data_path += r"\train" if train else r"\test"
        self.total_path = []  # absolute paths of every review file
        for temp_path in [r"\pos", r"\neg"]:
            cur_path = data_path + temp_path
            self.total_path += [os.path.join(cur_path, i) for i in os.listdir(cur_path) if i.endswith(".txt")]

    def __getitem__(self, idx):
        file = self.total_path[idx]
        with open(file, encoding='utf-8') as f:
            review = utils.tokenlize(f.read())
        # file names look like "123_9.txt": the number after "_" is the rating
        label = int(file.split("_")[-1].split(".")[0])
        # label = 0 if label < 5 else 1
        return review, label

    def __len__(self):
        return len(self.total_path)


# def collate_fn(batch):
#     # batch is a list of tuples, each tuple being one __getitem__ result
#     batch = list(zip(*batch))
#     labels = torch.tensor(batch[1], dtype=torch.int32)
#     texts = batch[0]
#     del batch
#     return labels, texts

def collate_fn(batch):
    """
    Process one batch of data.
    :param batch: [one __getitem__ result, another __getitem__ result, ...]
    :return: tuple (reviews, labels)
    """
    reviews, labels = zip(*batch)
    reviews = torch.LongTensor([config.ws.transform(i, max_len=config.max_len) for i in reviews])
    labels = torch.LongTensor(labels)
    return reviews, labels


def get_dataloader(train=True):
    dataset = ImdbDataset(train)
    batch_size = config.train_batch_size if train else config.test_batch_size
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)


if __name__ == '__main__':
    dataset = ImdbDataset()
    dataloader = DataLoader(dataset=dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)
    # 3. Inspect one batch of output
    # collate_fn returns (reviews, labels), so unpack in that order
    for idx, (text, label) in enumerate(dataloader):
        print("idx:", idx)
        print("label:", label)
        print("text:", text)
        break
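Once the vocabulary from step 5 exists, it can be worth checking the batch shapes that collate_fn produces. A minimal sketch (assuming ./model/ws.pkl has already been built and the config values from step 2 are in effect):

from dataset import get_dataloader

loader = get_dataloader(train=True)
reviews, labels = next(iter(loader))  # one batch via collate_fn
print(reviews.shape)  # expected: torch.Size([512, 80]), i.e. (train_batch_size, max_len)
print(labels.shape)   # expected: torch.Size([512])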
2. config.py file (imported as config in dataset.py)
""" 配置文件 """ import pickle train_batch_size = 512 test_batch_size = 500 ws = pickle.load(open("./model/ws.pkl","rb")) max_len = 80
3. utils.py tokenization file
import re


def tokenlize(sentence):
    '''
    Tokenize a piece of text.
    :param sentence: raw review text
    :return: list of lowercase tokens
    '''
    filters = ['!', '"', '#', '$', '%', '&', '(', ')', '*', '+', ',', '-', '.', '/',
               ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`',
               '{', '|', '}', '~', '\t', '\n', '\x97', '\x96', '”', '“']
    sentence = sentence.lower()
    sentence = re.sub("<br />", " ", sentence)
    # re.escape each symbol so characters like \ [ ] are matched literally;
    # the original list mixed hand-escaped entries with a bare '\\', which
    # produced a regex that silently skipped the backslash itself
    sentence = re.sub("|".join(re.escape(i) for i in filters), " ", sentence)
    # result = sentence.split(" ")
    # drop empty strings
    result = [i for i in sentence.split(" ") if len(i) > 0]
    return result
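For reference, a quick check of the tokenizer with an arbitrary sample sentence (the expected output is shown as a comment):

from utils import tokenlize

print(tokenlize("This movie is GREAT!<br />10/10 would watch again."))
# ['this', 'movie', 'is', 'great', '10', '10', 'would', 'watch', 'again']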
4. word_sequence.py
''' Text serialization '''


class Word2Sequence:
    UNK_TAG = "<UNK>"
    PAD_TAG = "<PAD>"
    UNK = 0
    PAD = 1

    def __init__(self):
        self.dict = {  # maps each word to its index
            self.UNK_TAG: self.UNK,
            self.PAD_TAG: self.PAD
        }
        self.count = {}  # word frequencies

    def fit(self, sentence):
        '''
        Accept one tokenized sentence and accumulate word frequencies.
        :param sentence: [str, str, ...]
        :return:
        '''
        for word in sentence:
            self.count[word] = self.count.get(word, 0) + 1

    def build_vocab(self, min_count=1, max_count=None, max_feature=None):
        '''
        Build the vocabulary under the given constraints.
        :param min_count: minimum word frequency
        :param max_count: maximum word frequency
        :param max_feature: maximum vocabulary size; keeps the most frequent words
        :return:
        '''
        if min_count is not None:
            self.count = {word: count for word, count in self.count.items() if count >= min_count}
        if max_count is not None:
            self.count = {word: count for word, count in self.count.items() if count <= max_count}
        if max_feature is not None:
            # sorted() takes key as a keyword-only argument; the original passed it positionally
            self.count = dict(sorted(self.count.items(), key=lambda x: x[-1], reverse=True)[:max_feature])
        for word in self.count.keys():
            self.dict[word] = len(self.dict)  # assign each word the next free index
        # invert the dict: index -> word
        self.inverse_dict = dict(zip(self.dict.values(), self.dict.keys()))

    def transform(self, sentence, max_len=None):
        '''
        Convert a sentence to a sequence of indices.
        :param sentence: [str, str, ...]
        :param max_len: pad/truncate to this length (skipped if None)
        :return: [num, num, ...]
        '''
        if max_len is not None:
            if len(sentence) > max_len:
                sentence = sentence[:max_len]
            else:
                sentence = sentence + [self.PAD_TAG] * (max_len - len(sentence))
        return [self.dict.get(i, self.UNK) for i in sentence]

    def inverse_transform(self, indices):
        '''
        Convert a sequence of indices back to words.
        :param indices: [num, num, ...]
        :return: [str, str, ...]
        '''
        return [self.inverse_dict.get(i, self.UNK_TAG) for i in indices]


if __name__ == '__main__':
    sentences = [['今天', '天气', '很', '好'],
                 ['今天', '去', '吃', '什么']]
    ws = Word2Sequence()
    for sentence in sentences:
        ws.fit(sentence)
    ws.build_vocab()
    print(ws.dict)
    ret = ws.transform(["好", "好", "好", "好", "好", "好", "好", "热", "呀"], max_len=20)
    print(ret)
    ret = ws.inverse_transform(ret)
    print(ret)
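Running the demo above should print something like the following (exact indices depend on dict insertion order, which is stable on Python 3.7+):

{'<UNK>': 0, '<PAD>': 1, '今天': 2, '天气': 3, '很': 4, '好': 5, '去': 6, '吃': 7, '什么': 8}
[5, 5, 5, 5, 5, 5, 5, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
['好', '好', '好', '好', '好', '好', '好', '<UNK>', '<UNK>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>']

好 maps to index 5, the two unseen words 热/呀 fall back to <UNK> (0), and the sequence is padded with <PAD> (1) up to max_len=20.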
5. main.py: build the word-to-index vocabulary from the corpus and save it
''' Text serialization: build and save the vocabulary '''
from word_sequence import Word2Sequence
from dataset import get_dataloader
import pickle
from tqdm import tqdm

if __name__ == '__main__':
    # NOTE: the first time this script runs, ./model/ws.pkl does not exist yet,
    # so dataset.py has to use the commented-out collate_fn that returns raw
    # token lists as (labels, texts); the transforming collate_fn needs config.ws.
    ws = Word2Sequence()
    dl_train = get_dataloader(True)
    dl_test = get_dataloader(False)
    for label, reviews in tqdm(dl_train, total=len(dl_train)):
        for review in reviews:
            ws.fit(review)
    for label, reviews in tqdm(dl_test, total=len(dl_test)):  # original had total=len(dl_train)
        for review in reviews:
            ws.fit(review)
    ws.build_vocab()
    pickle.dump(ws, open("./model/ws.pkl", "wb"))
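Once the script has finished, the pickled vocabulary can be sanity-checked; a small sketch (the sample tokens are arbitrary):

import pickle
from word_sequence import Word2Sequence  # needed so pickle can resolve the class

ws = pickle.load(open("./model/ws.pkl", "rb"))
print(len(ws.dict))  # vocabulary size, including <UNK> and <PAD>
print(ws.transform(["this", "movie", "rocks"], max_len=10))  # indices padded to length 10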
Source: 博客园 (cnblogs)
Author: 高颜值的杀生丸
Link: https://www.cnblogs.com/LiuXinyu12378/p/11425245.html