绑定完请刷新页面
取消
刷新

分享好友

×
取消 复制
pytorch 使用lmdb加载数据
2022-04-15 14:40:25
1. 在pytorch中使用lmdb

lmdb数据由键值对组成,可以实现将所有键取出来,根据键去读取对应的值。lmdb的具体操作和原理不在这里说明。这里的示例使用由caffe生成的lmdb,也可以自己定义键值的形式并写入


示例:

class LmdbDataset_train(Dataset):
    def __init__(self,lmdb_path,optimizer,keys_path):
        # super().__init__()
        self.optimizer = optimizer

        self.datum=caffe_pb2.Datum()
        self.lmdb_path = lmdb_path
        keys = np.load(keys_path)
        self.keys = keys.tolist()
        self.length = len(self.keys)
    def open_lmdb(self):
        self.env = lmdb.open(self.lmdb_path, max_readers=1, readonly=True, lock=False,
                             readahead=False, meminit=False)
        self.txn = self.env.begin(buffers=True,write=False)
   
    def __getitem__(self, index):
        if not hasattr(self, 'txn'):
            self.open_lmdb()

        serialized_str = self.txn.get(self.keys[index])
        self.datum.ParseFromString(serialized_str)
        size=self.datum.width*self.datum.height
        
        pixles1=self.datum.data[0:size]
        pixles2=self.datum.data[size:2*size]
        pixles3=self.datum.data[2*size:3*size]

        image1=Image.frombytes('L', (self.datum.width, self.datum.height), pixles1)
        image2=Image.frombytes('L', (self.datum.width, self.datum.height), pixles2)
        image3=Image.frombytes('L', (self.datum.width, self.datum.height), pixles3)

        img=Image.merge("RGB",(image3,image2,image1))

        img =self.optimizer(img)

        label=self.datum.label
        return img, label

    def __len__(self):
        return self.length


注意lmdb.open()部分不能放在初始化中,否则会报错。

2. 使用两个数据集

场景:有两个lmdb数据集,而且很大,进行合并比较耗时,此时可以分别加载然后训练时叠加使用。

'''
train_data1,train_data2为 lmdb路径
keys_path_train1,keys_path_train2 为两个数据的键的npy文件路径
'''
training_data1 = LmdbDataset_train(train_data1,transform,keys_path_train1)
training_data2 = LmdbDataset_train(train_data2,transform,keys_path_train2)


training_data = training_data1 + training_data2
train_loader = torch.utils.data.DataLoader(training_data,
                                               batch_size=opt.batch_size,
                                               shuffle=False,
                                               num_workers=opt.n_threads,
                                               pin_memory=True,
                                               sampler=train_sampler)


假设数据1的长度为1000 数据2的长度为10000。则training_data 的长度为11000


3. 使用两个dataloader

场景:有一大一小两个数据集,想要在训练过程中每个iteration里在两个数据集中各取一部分数据(这里各取一半)。

区别上面的方法,上面是按数据集顺序读取,这里是每个batch中在两个数据集中各取一部分。


假设设置batchsize为128,则在两个数据集中各取64个样本(batchsize为64)

假设training_data1为


training_data1 = LmdbDataset_train(train_data1,transform,keys_path_train1)
training_data2 = LmdbDataset_train(train_data2,transform,keys_path_train2)

train_loader1 = torch.utils.data.DataLoader(training_data1,
                                           batch_size=64,
                                           shuffle=False,
                                           num_workers=opt.n_threads,
                                           pin_memory=True,
                                           worker_init_fn=worker_init_fn,
                                           sampler=train_sampler1)
train_loader2 = torch.utils.data.DataLoader(training_data2,
                                           batch_size=64,
                                           shuffle=False,
                                           num_workers=opt.n_threads,
                                           pin_memory=True,
                                           worker_init_fn=worker_init_fn,
                                           sampler=train_sampler2)

由于数据集大小不同,在训练时,可以按照不同的策略进行取样。

  1. 以小数据集为基准,即当把小数据集的数据取完时,结束当前epoch。
'''
每个epoch中训练时加载数据及标签的过程
'''

for i,(data1,data2) in enumerate(zip(data_loader1,data_loader2)):

    inputs1,targets1 = data1
    inputs2,targets2 = data2
    
    inputs = torch.cat((inputs1,inputs2),0)
    targets = torch.cat((targets1,targets2),0)
 
  1. 以大数据集为基准,当小数据取完时从头开始重新取,直到大数据集取完。
dataloader_iterator = iter(data_loader2)
for i, data1 in enumerate(data_loader1):

  try:  
    data2 = next(dataloader_iterator)
  except StopIteration:
    dataloader_iterator = iter(data_loader2)
    data2 = next(dataloader_iterator)

  inputs1,targets1 = data1
  inputs2,targets2 = data2

  inputs = torch.cat((inputs1,inputs2),0)
  targets = torch.cat((targets1,targets2),0)


注意 网上有提到在小数据上循环时可以使用cycle代替zip。 实测 发现这种方法会导致内存泄露

4. 按照权重取样

数据集中样本数量不均衡时可以采取按照权重取样的方法进行数据均衡。使用pytorch中的WeightedRandomSampler方法实现,原理不在此处说明。


此方法的关键是给出数据中每个类别所占的权重(数量少的给一个大权重,数量多的权重小),然后将各类的权重分配给每个样本,后按照每个样本的权重进行取样。

其中,根据每个样本的权重进行取样的过程使用sampler实现。


有两种方法可以实现每个样本所占权重的计算,先给出代码再比较异同。

方法一:

def make_weights_for_balanced_classes(images, nclasses):                        
    count = [0] * nclasses                                                      
    for item in images:                                                         
        count[item[1]] += 1        # 统计每类的数量
    print('数量统计完成')                                          
    weight_per_class = [0.] * nclasses                                      
    N = float(sum(count))          # 计算所有样本总数                                      
    for i in range(nclasses):                                                   
        weight_per_class[i] = N/float(count[i])                                 
    weight = [0] * len(images)                                              
    for idx, val in enumerate(images):                                          
        weight[idx] = weight_per_class[val[1]]                                  
    return weight     

方法2 :

def make_weights_for_balanced_classes(txt, nclasses):          
    with open(txt,'r') as f:
        lines = f.readlines()

    count = [0] * nclasses                                                      
    for item in lines: 
        
        label = int(item.strip().split('\t')[1])
        count[label] += 1        # 统计每类的数量
        
    print('数量统计完成')                                       
    weight_per_class = [0.] * nclasses                                      
    # N = float(sum(count))          # 计算所有样本总数                                      
    for i in range(nclasses): 
        try:                                         
            weight_per_class[i] = 1/float(count[i])  # 每一类的权重
        except:
            continue
                            
    weight = [0] * len(lines)    
                       
    for idx, val in enumerate(lines): 
        label = int(val.strip().split('\t')[1])                          
        weight[idx] = weight_per_class[label]                           
    return weight     

方法1中的image 是实例化的torch.utils.data.Dataset . 方法2中的txt是保存了所有样本标签的txt文件。两种方法的原理都是按顺序读取数据集中每个样本的标签,然后计算权重。很明显方法2会更快,但前提是事先按顺序保存了数据集中样本的标签。两种方法等价的条件是image和txt每个位置(index)是同一个样本。


得到权重后就可以按照权重取样了

training_data = LmdbDataset_train(opt.train_data1,transform,opt.keys_path_train1)

weights = make_weights_for_balanced_classes(opt.train_txt, 27249)

weights = torch.DoubleTensor(weights)

assert len(training_data) == len(weights)

train_sampler = torch.utils.data.WeightedRandomSampler(weights, len(weights))

train_loader = torch.utils.data.DataLoader(training_data,
                                               batch_size=opt.batch_size,
                                               shuffle=False,
                                               num_workers=opt.n_threads,
                                               pin_memory=True,
                                               worker_init_fn=worker_init_fn,
                                               sampler=train_sampler)

来源 https://zhuanlan.zhihu.com/p/374875094
分享好友

分享这个小栈给你的朋友们,一起进步吧。

LMDB
创建时间:2022-04-15 14:36:38
LMDB
展开
订阅须知

• 所有用户可根据关注领域订阅专区或所有专区

• 付费订阅:虚拟交易,一经交易不退款;若特殊情况,可3日内客服咨询

• 专区发布评论属默认订阅所评论专区(除付费小栈外)

技术专家

查看更多
  • itt0918
    专家
戳我,来吐槽~