Python Multiprocessing in Practice

Copyright notice: This is an original article by the author, released under the CC 4.0 BY-SA license. Please include the original source link and this notice when reposting.
Original link: https://blog.csdn.net/qq_36810544/article/details/102514411

Because the training set was far too small, a colleague crawled 5.5M+ image URLs from the web; the txt file holding just the URLs was already 500 MB. So here's the problem: downloading and saving one URL at a time, who knows when all 5.5M images would ever finish. The idea is to download with multiple processes at once, letting the CPU, NIC, and disk work in parallel; and since disk I/O is painfully slow, each worker accumulates a batch of images and writes them out together. Hardware: an i7-7700 (4 cores / 8 threads) with 32 GB RAM. The whole run took roughly 4 hours 30 minutes, with the URL timeout set to 1.5 s (at first there was no timeout, and some process would always hang and stall the whole program — a nasty hidden pitfall!). In the end 5.1M images were downloaded.
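That timeout pitfall deserves its own illustration. Below is a minimal sketch of the fetch pattern (fetch_image is a hypothetical helper and the URL a placeholder, not part of the original script): without the timeout argument, requests.get() can block indefinitely on a dead or stalled server, which is exactly why workers used to hang.

import requests
from PIL import Image

def fetch_image(url):
    # timeout=1.5 bounds the connect and each read; a stalled server now
    # raises requests.exceptions.Timeout instead of hanging the worker
    resp = requests.get(url, stream=True, timeout=1.5)
    return Image.open(resp.raw).convert('RGB')

# img = fetch_image('https://example.com/some_image.jpg')  # placeholder URL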
Next, the 5.1M images had to be deduplicated and images smaller than 100 px filtered out (I forgot to do this during the download, oops). Same approach as above, but deduplication needs fast lookups: Python's set() and dict() are backed by hash tables, so membership tests are O(1) on average, while list() is a linear structure — out of the question for millions of items. Filtering the 5.1M+ images took just under 3 hours and yielded 71K images, which is acceptable efficiency. The code isn't carefully optimized — it's a one-off tool, so "it runs" was the bar, and it's a bit messy!
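A quick, self-contained micro-benchmark (numbers will vary by machine) showing why a hash-based container is the right choice for millions of membership tests:

import timeit

items = list(range(1_000_000))
as_set = set(items)
needle = 999_999  # worst case for the list: the element sits at the end

# the list does a linear scan, the set a hash lookup
print('list:', timeit.timeit(lambda: needle in items, number=100))
print('set: ', timeit.timeit(lambda: needle in as_set, number=100))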

========================== Work makes me happy ==========================
In the code below, download() handles the downloading and filter() the filtering; requests.get() fetches each image, and imagehash computes a per-image hash for deduplication.

import copy
import datetime
import json
import os
import shutil
import time
from multiprocessing import Process, Manager

import imagehash
import numpy as np
import requests
from PIL import Image


def _save_img(url_str, fail_list):
    data_dir = r'D:\bdtb'
    nowTime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    url_list = json.loads(url_str)
    guid = id(url_list)  # cheap per-process identifier for the log lines
    print(nowTime + ': process id {0} start saving images...'.format(guid))

    def _flush(images):
        # batch the disk writes: a timestamp prefix (down to the microsecond)
        # keeps file names from different flushes and processes from colliding
        head = 'bdtb_' + datetime.datetime.now().strftime('%Y%m%d%H%M%S%f')
        for i, img in enumerate(images):
            name = head + '_{0:05}.jpg'.format(i)
            img.save(os.path.join(data_dir, name), quality=90)
        images.clear()

    img_list = []
    for index, url in enumerate(url_list):
        try:
            # the 1.5s timeout is essential: without it a single stalled URL
            # can hang the worker process indefinitely
            img = Image.open(requests.get(url, stream=True, timeout=1.5).raw)
            img = img.convert('RGB')
            img_list.append(img)
            if len(img_list) >= 2000:
                print('{0}: img_list length: {1}'.format(guid, len(img_list)))
                _flush(img_list)
                nowTime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
                print('{0}: {1} has processed {2} urls.'.format(nowTime, guid, index + 1))
        except Exception:
            fail_list.append(url)
            print('{0} urls have failed'.format(len(fail_list)))
            continue

    _flush(img_list)  # write out whatever remains in the last partial batch
    nowTime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print(nowTime + ': {0} saved images successfully!'.format(guid))


def download():
    # Download the images from the ~5.5M URLs, one worker per block of 20000.
    block_num = 20000
    count = 0
    task = []
    process_list = []
    manager = Manager()
    fail_list = manager.list()  # shared list so every worker can report failed URLs
    nowTime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print('start: ' + nowTime)
    with open("url.txt") as f:
        for url in f:
            task.append(url.strip())
            count += 1
            if count % block_num == 0:
                task_copy = copy.copy(task)
                p = Process(target=_save_img, args=(json.dumps(task_copy), fail_list))
                process_list.append(p)
                task.clear()

    if task:  # the last, partial block
        p = Process(target=_save_img, args=(json.dumps(task), fail_list))
        process_list.append(p)

    nowTime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print(nowTime + ': read file end!')
    print('start downloading images...')
    _start_process(process_list)
    print('{0} images failed!'.format(len(fail_list)))
    with open('failed.txt', 'w') as sf:
        for url in fail_list:
            sf.write(url + '\n')
    print('success!')

def _start_process(process_list):
    # Run the workers in batches of 8 (one per hardware thread on the i7-7700),
    # joining a whole batch before launching the next.
    process_sum = len(process_list)
    start = 0
    end = 0

    for i in range(8, process_sum, 8):
        end = i
        start_time = time.perf_counter()  # time.clock() was removed in Python 3.8
        for p in process_list[start:end]:
            p.start()

        for p in process_list[start:end]:
            p.join()

        end_time = time.perf_counter()
        print('{0}:{1} cost {2}s'.format(start, end, end_time - start_time))
        start = end

    if end < process_sum:
        # the last, possibly partial, batch
        for p in process_list[end:process_sum]:
            p.start()

        for p in process_list[end:process_sum]:
            p.join()

def _filter_process(path_list):
    start = time.perf_counter()
    local_dic = {}
    guid = id(local_dic)  # cheap per-process identifier for the log lines
    for path in path_list:
        try:
            img = Image.open(path)
            min_size = np.min(img.size)
            if min_size > 99:  # keep only images whose shorter side is >= 100px
                # keying on the perceptual hash deduplicates within this block;
                # duplicates living in different blocks are not caught
                hash_value = str(imagehash.phash(img))
                local_dic[hash_value] = path
        except Exception:
            continue  # skip unreadable or truncated files
    print('{0}: received {1} images, kept {2} images'.format(guid, len(path_list), len(local_dic)))

    for src_path in local_dic.values():
        des_path = src_path.replace('bdtb', 'bdtb_filter')
        os.makedirs(os.path.dirname(des_path), exist_ok=True)  # ensure the output dir exists
        shutil.copyfile(src_path, des_path)

    end = time.perf_counter()
    print('{0}: processing {1} images cost {2}s'.format(guid, len(path_list), end - start))

def filter():  # shadows the builtin filter(); acceptable for a one-off script
    # Filter the downloaded images: drop low-resolution ones and deduplicate.
    src_dir = r'E:\code\data\bdtb'
    branch_size = 20000  # images handled per worker process
    process_list = []
    start = time.perf_counter()
    # a plain list is enough here: each worker just receives its own slice
    path_list = [os.path.join(src_dir, name) for name in os.listdir(src_dir)]
    end = time.perf_counter()
    print('read file list cost: {0}s'.format(end - start))

    start_index = 0
    end_index = 0
    for i in range(branch_size, len(path_list), branch_size):
        end_index = i
        p = Process(target=_filter_process, args=(path_list[start_index:end_index],))
        process_list.append(p)
        start_index = end_index

    if end_index < len(path_list):  # the last, partial block
        p = Process(target=_filter_process, args=(path_list[end_index:],))
        process_list.append(p)

    nowTime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print('{0}: start image processing...'.format(nowTime))
    _start_process(process_list)
    nowTime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print('{0}: success!'.format(nowTime))

if __name__ == '__main__':
    # download() was run first to fetch the images; filter() is the second pass
    filter()
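One design note on the deduplication: _filter_process keys a dict on str(imagehash.phash(img)), so two images count as duplicates only when their perceptual hashes are exactly equal. imagehash can also compare hashes by Hamming distance, which additionally catches near-duplicates; here is a minimal sketch (the file names are placeholders and the threshold of 5 is a common rule of thumb, not something from this script):

from PIL import Image
import imagehash

h1 = imagehash.phash(Image.open('a.jpg'))
h2 = imagehash.phash(Image.open('b.jpg'))
# subtracting two hashes yields their Hamming distance; small values
# indicate visually similar images
if h1 - h2 <= 5:
    print('near-duplicate pair')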
