实现一个遍历文件夹,并排列组合该文件夹下图片的py脚本

一、背景需求

因测试AI人脸识别项目,从LFW图像库中,我们要准备两组数据:

1. 3万组预期结果为“匹配”的图像集合

2. 3万组预期结果为“不匹配”的图像集合

组长让我们写个python脚本自动生成测试图像数据。简单介绍一下LFW,它是无约束自然场景人脸识别数据集,主要测试人脸识别的准确率,该数据集由13000多张全世界知名人士互联网自然场景不同朝向、表情和光照环境人脸图片组成,共有5000多人,其中有1680人有2张或2张以上人脸图片。每张人脸图片都有其唯一的姓名ID和序号加以区分。

二、整理思路

LFW数据集已经在我本地电脑上,它的结构是:一个LFW主文件夹,下面有5000多个以人名命名的子文件夹子文件夹中是对应人名的人脸图片,图片命名规则是人名_0001.jpg.

创建3万组预期匹配的图像集合的过程如下:

  1. 从LFW子文件夹中,找出那些有2张或2张以上人脸图片的文件夹
  2. 从同一人名文件夹中取出不重复的2张人脸图片
  3. 循环取3万组不重复的数据(A_0001.jpg A_0002.jpg与A_0002.jpg A_0001.jpg视为重复数据)
  4. 写入一个txt文件

创建3万组预期不匹配的图像集合的过程如下:

  1. 让LFW子文件夹两两组合(5000多个人名文件夹共有一百多万种组合)
  2. 从组合的2个子文件夹中分别随机取出1张人脸图片
  3. 重复步骤2,3万次
  4. 写入txt文件

三、遇到的问题及处理

1.如何遍历文件夹?

咨询了吴老后,知道了Python内置的os模块有个walk()方法,用于通过在目录树中游走输出在目录中的文件名。

os.walk()方法需要传入一个待遍历的文件路径,它返回的是一个三元组(root,dirs,files):

  • root 指的是当前正在遍历的这个文件夹的本身的地址
  • dirs 是一个 list ,内容是该文件夹中所有的目录的名字(不包括子目录)
  • files 同样是 list , 内容是该文件夹中所有的文件(不包括子目录)
import os
root_path = "F:\Faker\光荣之路"
for root, dirs, files in os.walk(root_path, topdown=True):
    print(root, dirs, files)

运行结果:

2.如何实现找出那些人脸图片大于等于2的子文件夹?

根据os.walk()返回的files,我们可以通过len(files)判断该子文件夹下的图片张数是否满足条件

3.如何实现从同一个子文件中组合出不同的2张人脸图片?

咨询了组长,让我去了解下python内置itertools模块的combinations()方法。

combinations(iterable,r)创建一个迭代器,返回iterable中所有长度为r的子序列,返回的子序列中的项按输入iterable中的顺序不重复组合。

迭代器还有个优势就是延迟计算,按需使用,从而提高开发体验和运行效率。

举个例子:

>>> import itertools
>>> l=[1,2,3,4,5]
>>> type(itertools.combinations(l,2))
<class 'itertools.combinations'>
>>> for i in itertools.combinations(l,2):
...     print(i)
...
(1, 2)
(1, 3)
(1, 4)
(1, 5)
(2, 3)
(2, 4)
(2, 5)
(3, 4)
(3, 5)
(4, 5)
>>>

4.如何保证运行效率?

5749个人名,通过itertools.combinations()方法得到的人名组合很庞大,有16522626万种组合,而我们只需要3万种,所以需要加入一个判断,当我们取到3万条数据时,通过使用return语句结束遍历这个组合。

四、实现代码

import os
import itertools
import random

pic_file_name = "lfw"
root_path = os.path.join("/Users/lipan/Downloads/", pic_file_name)


# 获取文件夹名称与图片编号组成的字典
def get_file_dict(file_path):
    file_name_list = []  # 用于存放所有子文件夹名称
    file_dict = {}
    # 遍历lfw文件夹
    for root, dirs, files in os.walk(file_path, topdown=False):
        # 通过切片路径,获取子文件夹名称
        file_name = root.split(r'/')[-1]
        # 组成子文件夹(人名)列表
        if file_name not in file_name_list and file_name:
            file_name_list.append(file_name)
        file_pic = []  # 用于存放每个子文件夹下的图片编号
        for i in files:
            # 通过切片得到图片的编号,原始图片名称格式:人名_0001.jpg
            pic_num = i[-8:-4]
            file_pic.append(pic_num)
        # 若file_name不为空,避免出现'':[.DS_Store']的数据
        if file_name:
            # 组装{人名:[人名对应的图片编号0001,0002]}字典
            file_dict.update({file_name: file_pic})
    return file_dict


# 生成3万条匹配pairs
def create_match_pairs(dicts):
    match_pairs_list = []  # 用于存放匹配数据
    for key, value in dicts.items():
        # 若人名对应的人脸图片张数>=2
        if len(value) >= 2:
            # 排列组合同一人名下的人脸图片
            pairs = itertools.combinations(value, 2)
            # 遍历pairs迭代器
            for pair in pairs:
                p = "\t".join(pair)  # 组装图片编号,以\t分隔
                # 组装待写入的每一行的数据,格式为:人名+\t+人名对应的第一张图片编号+\t+人名对应的第二张图片编号
                each_line = key + "\t" + p + "\n"  
                if len(match_pairs_list) < 30000:
                    match_pairs_list.append(each_line)  # 将每一条待写入的数据存放至list中
    return match_pairs_list


# 生存3万条不匹配的pairs
def create_dismatch_pairs(dicts):
    dismatch_pairs_list = []  # 用于存放不匹配数据
    # 排列组合人名,这里需要注意:共有一千多万种组合情况,而我们只需要3万种
    name_comb = itertools.combinations(list(dicts.keys()), 2)
    # 遍历人名迭代器
    for i in name_comb:
        # 分别获取到每组中的2个人名
        name_1 = i[0]
        name_2 = i[1]
        # 组装待写入的数据,格式为:人名1+\t+人名1对应的图片编号+\t+人名2+人名2对应的图片编号
        each_line = name_1 + "\t" + random.choice(dicts[name_1]) + "\t" + name_2 + "\t" + random.choice(dicts[name_2]) + "\n"
        # 当取够3万条数据时,使用return语句结束程序,若遍历一千多万种情况,运行效率低
        if len(dismatch_pairs_list) >= 30000:
            return dismatch_pairs_list
        else:
            dismatch_pairs_list.append(each_line)


# 写入pair.txt文件
def write_pairs(match_data, not_match_data):
    with open("/Users/lipan/Downloads/%s_pairs.txt" % pic_file_name, "a", encoding="utf-8") as fp:
        fp.write("1\t30000\n")
        for i in match_data:
            fp.write(i)
        print("match pairs write successful")
        for j in not_match_data:
            fp.write(j)
        print("dismatch pairs write successful")


# 执行部分
# 生成名称:图片编号字典
files_dict = get_file_dict(root_path)
# 生成匹配pairs list
match_datas = create_match_pairs(files_dict)
# 生成不匹配pairs list
dismatch_datas = create_dismatch_pairs(files_dict)
# 写入txt文件
write_pairs(match_datas, dismatch_datas)

猜你喜欢

转载自blog.csdn.net/qq_22895113/article/details/81814489
今日推荐