I have a huge dataset that does not fit in memory (150 GB) and I'm looking for the best way to work with it in PyTorch. The dataset is composed of several .npz files of 10k samples each. I tried to build a Dataset class:
import os

import numpy as np
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self, path):
        self.path = path
        self.files = os.listdir(self.path)
        self.file_length = {}
        for f in self.files:
            # Load file in as a memmap to get its length
            d = np.load(os.path.join(self.path, f), mmap_mode='r')
            self.file_length[f] = len(d['y'])

    def __len__(self):
        return sum(self.file_length.values())

    def __getitem__(self, idx):
        # Find the file where idx belongs to
        count = 0
        f_key = ''
        local_idx = 0
        for k in self.file_length:
            if count <= idx < count + self.file_length[k]:
                f_key = k
                local_idx = idx - count
                break
            else:
                count += self.file_length[k]
        # Open file as numpy.memmap
        d = np.load(os.path.join(self.path, f_key), mmap_mode='r')
        # Actually fetch the data
        X = np.expand_dims(d['X'][local_idx], axis=1)
        y = np.expand_dims((d['y'][local_idx] == 2).astype(np.float32), axis=1)
        return X, y
but when a sample is actually fetched, it takes more than 30 s. It looks like the entire .npz is opened and loaded into RAM, and only then is the right index accessed.
How can I make this more efficient?
EDIT
It appears to be a misunderstanding of how .npz files work (see post): arrays stored inside an .npz archive are read in full when a key is accessed, so mmap_mode='r' has no effect there. Is there a better approach?
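A quick way to see the difference (the file names here are just throwaway examples):

import numpy as np

np.save('example.npy', np.zeros((1000, 10)))
np.savez('example.npz', X=np.zeros((1000, 10)))

# A standalone .npy file is genuinely memory-mapped:
# indexing touches only the requested rows
a = np.load('example.npy', mmap_mode='r')
print(type(a))        # <class 'numpy.memmap'>

# Inside an .npz archive, mmap_mode is ignored: accessing a key reads and
# materialises the whole array in RAM, which explains the 30 s per sample
d = np.load('example.npz', mmap_mode='r')
print(type(d['X']))   # <class 'numpy.ndarray'>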
SOLUTION PROPOSAL
As proposed by @covariantmonkey, lmdb can be a good choice. For now, since the problem comes from the .npz format and not from memmap itself, I remodelled my dataset by splitting the .npz archives into several .npy files. I can now use the same lookup logic, where memmap makes complete sense and is really fast (a few ms to load a sample).
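For completeness, a minimal sketch of the conversion I mean (the directory arguments and the stem + '_X.npy' / '_y.npy' naming scheme are placeholders, assuming each archive holds 'X' and 'y' arrays as above):

import os

import numpy as np

def split_npz_to_npy(src_dir, dst_dir):
    # Split every .npz archive into standalone .npy files so that
    # np.load(..., mmap_mode='r') returns a real numpy.memmap later on
    os.makedirs(dst_dir, exist_ok=True)
    for f in os.listdir(src_dir):
        if not f.endswith('.npz'):
            continue
        d = np.load(os.path.join(src_dir, f))  # one 10k-sample archive at a time
        stem = os.path.splitext(f)[0]
        np.save(os.path.join(dst_dir, stem + '_X.npy'), d['X'])
        np.save(os.path.join(dst_dir, stem + '_y.npy'), d['y'])

__getitem__ then keeps exactly the same file/offset lookup; the only change is that the loads hit the split files (f_key now being the archive stem):

X_mm = np.load(os.path.join(self.path, f_key + '_X.npy'), mmap_mode='r')
y_mm = np.load(os.path.join(self.path, f_key + '_y.npy'), mmap_mode='r')
X = np.expand_dims(X_mm[local_idx], axis=1)
y = np.expand_dims((y_mm[local_idx] == 2).astype(np.float32), axis=1)

Only the rows at local_idx are read from disk, which is why a sample now loads in a few milliseconds.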