forked from DLLXW/baby-llama2-chinese
-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset.py
60 lines (55 loc) · 1.9 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import random
import pandas as pd
import numpy as np
from torch.utils.data import Dataset,DataLoader
import torch
from sklearn.model_selection import train_test_split
class PretrainDataset(Dataset):
    """Next-token-prediction dataset over flat uint16 token streams.

    Each file in ``data_path_lst`` is a raw binary dump of uint16 token ids.
    The concatenated token stream is chunked into rows of ``max_length``
    tokens; any trailing partial chunk is dropped.

    Args:
        data_path_lst: list of paths to raw uint16 token files. When
            ``memmap=True`` only the first path is used.
        max_length: tokens per training row.
        memmap: if True, memory-map the first file instead of loading
            everything into RAM (useful for very large corpora).
    """
    def __init__(self,data_path_lst,max_length=256,memmap=False):
        super().__init__()
        #
        if memmap:
            # Measure the file length to derive the row count without
            # reading the data. BUGFIX: open in binary mode — in text mode
            # tell() may return an opaque cookie rather than a byte offset,
            # and a binary token dump is not valid text anyway.
            with open(data_path_lst[0],'rb') as f:
                f.seek(0,2)  # seek to end of file
                flen = f.tell() // np.dtype('uint16').itemsize  # total token count
            # Lazily-mapped 2-D view: one row per max_length-token chunk.
            self.data = np.memmap(data_path_lst[0],dtype=np.dtype('uint16'),shape=(flen//max_length,max_length))
        else:
            # Eager path: load every file fully, concatenate, then chunk.
            data_lst=[]
            for data_path in data_path_lst:
                with open(data_path,'rb') as f:
                    data=np.fromfile(f,dtype=np.uint16)
                data_lst.append(data)
            data = np.concatenate(data_lst)
            # Truncate the tail so the stream divides evenly into rows.
            data = data[:max_length*int(len(data)/max_length)]
            #np.random.shuffle(data)
            self.data = data.reshape(-1,max_length)
        #
        print("memmap:{} train data.shape:{}".format(memmap,self.data.shape))
        print("downloading finished.....")
    def __len__(self):
        # Number of max_length-token rows.
        return self.data.shape[0]
    def __getitem__(self, index: int):
        #
        sample = self.data[index]
        # Shifted input/target pair for next-token prediction. np.array()
        # copies out of the (possibly memmapped) buffer; int64 is what
        # embedding / cross-entropy layers expect.
        X=np.array(sample[:-1]).astype(np.int64)
        Y=np.array(sample[1:]).astype(np.int64)
        return torch.from_numpy(X),torch.from_numpy(Y)
#
if __name__=="__main__":
data_path_lst=[
'./data/diagnosis/train.csv'
]
train_ds = PretrainDataset(data_path_lst, max_length=256)
train_loader = torch.utils.data.DataLoader(
train_ds,
batch_size=2,
pin_memory=False,
drop_last=False,
shuffle=False,
num_workers=0,
)
for i, (X, Y) in enumerate(train_loader):
print(X.shape,Y.shape)
print(X[0])
print(Y[0])
break