常用操作
# -*- coding: utf-8 -*-
import lmdb
# map_size is the maximum size the database may grow to, measured in BYTES
# (not KB): 1099511627776 bytes == 1024**4 == 1 TiB.
env = lmdb.open("./train", map_size=1099511627776)
try:
    # Used as a context manager, the write transaction is committed when the
    # block exits normally and aborted if an exception escapes it.
    with env.begin(write=True) as txn:
        # Add key/value pairs. On Python 3, py-lmdb only accepts bytes (or
        # other buffer objects) for keys and values; plain str raises TypeError.
        txn.put(key=b'1', value=b'aaa')
        txn.put(key=b'2', value=b'bbb')
        txn.put(key=b'3', value=b'ccc')
        # Delete a record by its key.
        txn.delete(key=b'1')
        # Update a record: put() on an existing key overwrites its value.
        txn.put(key=b'3', value=b'ddd')
finally:
    env.close()
合并操作
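训练时可能涉及多个数据集。下面的代码会把后面几个 LMDB 的数据合并到第一个 LMDB(lmdb_paths[0])的写事务中。如果不打算真正把它们写成一个 LMDB,合并时不要调用 `txn.commit()`(LMDB 中并没有 `lmdb.commit()` 这个函数):未提交的写事务只在当前进程内可见,中止(abort)或进程退出后即被丢弃;一旦提交,合并结果就会直接写入原文件。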
# Merge several LMDB datasets into the write transactions of the first one.
# Each dataset directory holds two LMDB environments, "pairs" and "imgs", each
# storing an integer counter ("num_samples" / "num_images") as an ASCII string.
# NOTE(review): this snippet uses names it does not define -- `lmdb`, `tqdm`
# (from tqdm import tqdm) and a `mode` label used only in log messages --
# which presumably come from the surrounding code; confirm before running.
import os

lmdb_paths = [
    '/data/lmdb1',
    '/data/lmdb2',
    '/data/lmdb3',
]
lmdb_path = lmdb_paths[0]

# Options shared by every lmdb.open() call below: only open existing
# environments, without locking, read-ahead or buffer zero-initialisation.
_open_opts = dict(create=False, lock=False, readahead=False, meminit=False)


def _read_count(txn, key, where):
    """Return the integer counter stored under ``key`` in ``txn``.

    Raises KeyError naming ``where`` when the counter is missing, instead of
    the opaque TypeError that ``int(None)`` would produce.
    """
    raw = txn.get(key)
    if raw is None:
        raise KeyError(f"{key!r} not found in LMDB at {where}")
    # bytes() also accepts the memoryview returned by buffers=True transactions.
    return int(bytes(raw))


# Destination: the first dataset, opened writable.
lmdb_pairs = os.path.join(lmdb_path, "pairs")
lmdb_imgs = os.path.join(lmdb_path, "imgs")
env_pairs = lmdb.open(lmdb_pairs, readonly=False, **_open_opts)
# Enlarge the map so the write transaction can hold the merged data; py-lmdb
# only allows set_mapsize() while no transaction is active.
env_pairs.set_mapsize(1024**4)
txn_pairs = env_pairs.begin(write=True)
env_imgs = lmdb.open(lmdb_imgs, readonly=False, **_open_opts)
env_imgs.set_mapsize(1024**4)
txn_imgs = env_imgs.begin(write=True)

number_samples = _read_count(txn_pairs, b"num_samples", lmdb_pairs)
number_images = _read_count(txn_imgs, b"num_images", lmdb_imgs)
print(f"init {mode}, number_samples:{number_samples}, number_images:{number_images}")

for path in lmdb_paths[1:]:
    _lmdb_pairs = os.path.join(path, "pairs")
    _lmdb_imgs = os.path.join(path, "imgs")
    # Sources are only read, so open them read-only to rule out accidental writes.
    _env_pairs = lmdb.open(_lmdb_pairs, readonly=True, **_open_opts)
    _env_imgs = lmdb.open(_lmdb_imgs, readonly=True, **_open_opts)
    try:
        # buffers=True yields memoryviews that are only valid while the read
        # transaction is open; put() copies them into the destination first.
        with _env_pairs.begin(buffers=True) as _txn_pairs, \
                _env_imgs.begin(buffers=True) as _txn_imgs:
            _number_samples = _read_count(_txn_pairs, b"num_samples", _lmdb_pairs)
            _number_images = _read_count(_txn_imgs, b"num_images", _lmdb_imgs)
            # Re-key samples by offsetting them past the samples merged so far.
            # NOTE(review): assumes pair keys are the integer indices
            # 0..num_samples-1 of this dataset -- confirm.
            for key, value in tqdm(_txn_pairs.cursor(), desc=f"[{path}]-text"):
                key = bytes(key)
                if b"num_" not in key:  # skip counter entries like num_samples
                    new_key = str(number_samples + int(key)).encode('utf-8')
                    txn_pairs.put(new_key, value)
            # Images keep their original keys.
            # NOTE(review): images sharing a key across datasets overwrite each
            # other while number_images still counts both -- confirm keys are unique.
            for key, value in tqdm(_txn_imgs.cursor(), desc=f"[{path}]-image"):
                if b"num_" not in bytes(key):  # skip the num_images counter
                    txn_imgs.put(key, value)
    finally:
        _env_pairs.close()
        _env_imgs.close()
    number_samples += _number_samples
    number_images += _number_images

# Record the combined counters in the destination write transactions.
txn_pairs.put(b'num_samples', str(number_samples).encode('utf-8'))
txn_imgs.put(b'num_images', str(number_images).encode('utf-8'))
# txn_pairs / txn_imgs are deliberately left open and uncommitted: the merged
# view exists only inside these transactions. Call txn_pairs.commit() and
# txn_imgs.commit() to write the merge permanently into lmdb_paths[0].
print(f'merge {mode}, num_samples:{_read_count(txn_pairs, b"num_samples", lmdb_pairs)}, num_images:{_read_count(txn_imgs, b"num_images", lmdb_imgs)}')