python 多线程下载数据集

最近做项目需要多线程下载数据,就简单学了一下。

先是参考了别人的代码,修改后亲测可用。

原文链接:https://junyiseo.com/python/211.html

# -*- coding: UTF-8 -*-
 
import threading
from time import sleep,ctime
 
class myThread (threading.Thread):
    def __init__(self, threadID, name, s , e):
        threading.Thread.__init__(self)
        self.threadID = threadID
        self.name = name
        self.s = s
        self.e = e
    def run(self):
        print ("Starting " + self.name+ctime())
       # 获得锁,成功获得锁定后返回True
       # 可选的timeout参数不填时将一直阻塞直到获得锁定
       # 否则超时后将返回False
        threadLock.acquire()
        #线程需要执行的方法
        printImg(self.s,self.e)
        # 释放锁
        threadLock.release()
 
listImg = [] #创建需要读取的列表,可以自行创建自己的列表
for i in range(179):
    listImg.append(i)
 
# 按照分配的区间,读取列表内容,需要其他功能在这个方法里设置
def printImg(s,e):
    for i in range(s,e):
        print (i)
 
 
totalThread = 3 #需要创建的线程数,可以控制线程的数量
 
lenList = len(listImg) #列表的总长度
gap = lenList // totalThread #列表分配到每个线程的执行数
 
threadLock = threading.Lock() #锁
threads = [] #创建线程列表
 
# 创建新线程和添加线程到列表
for i in range(totalThread):
    thread = 'thread%s' % i
    if i == 0:
        thread = myThread(0, "Thread-%s" % i, 0,gap)
    elif totalThread==i+1:
        thread = myThread(i, "Thread-%s" % i, i*gap,lenList)
    else:
        thread = myThread(i, "Thread-%s" % i, i*gap,(i+1)*gap)
    threads.append(thread) # 添加线程到列表
 
 
# 循环开启线程
for i in range(totalThread):
    threads[i].start()
 
 
# 等待所有线程完成
for t in threads:
    t.join()
print ("Exiting Main Thread")


然后参考这个代码,我写出了我多线程下载数据集的代码,如下:

# -*- coding: UTF-8 -*-
 
import threading
from time import sleep,ctime


class myThread (threading.Thread):
    def __init__(self, threadID, name, s , e, encoding, path):
        threading.Thread.__init__(self)
        self.threadID = threadID
        self.name = name
        self.s = s
        self.e = e
        self.encoding = encoding
        self.path = path
        self.example = []
       
       
    def run(self):
    
        print ("Starting " + self.name + "  " + ctime())
        
        for line in open(self.path+"/"+"labels.txt",'r',encoding=self.encoding).readlines()[self.s:self.e]:  
            strlabel = line.split(' ')[1].strip('\n')
            if len(strlabel) > max_char_count:
                continue
        
            try:
        
                arr, initial_len = resize_image(
                    os.path.join(self.path, line.split(' ')[0]),
                    max_image_width
                )
        
            except(OSError, NameError):
                print('OSError, Path:',os.path.join(self.path, line.split(' ')[0]))
                continue
        
            #print(label_to_array(strlabel))
            self.example.append(
                (
                    arr,
                    label_to_array(strlabel)
                )
            )
             
        print ("Stoping " + self.name + "  " + ctime())




def load_data():
    
    print('Loading data')
      
    threadNums = [32, 32, 16, 8, 8]
    picNums = [400000, 400000, 150000, 80000, 70000]
    encodings = ['UTF-8', 'UTF-8', 'gbk', 'gbk', 'gbk']
    paths = [examples_path_1, examples_path_2, examples_path_3, examples_path_4, examples_path_5]
    gaps = []
    threads = [] 
    examples = []
    
    totalThreadCount = 0
    
    for i in range (len(threadNums)):
        gap = picNums[i] // threadNums[i]
        gaps.append(gap)
   
    
    for i in range (len(picNums)):
        for j in range(threadNums[i]):
            thread = myThread(i*100+j, "Thread-%s-%s" % (i, j), gaps[i]*j, gaps[i]*(j+1), encodings[i], paths[i])
            threads.append(thread)
            totalThreadCount += 1
            print(paths[i], gaps[i]*j, gaps[i]*(j+1))

    for i in range(totalThreadCount):
        threads[i].start()

    for t in threads:
        t.join()
        
    for t in range(len(threads)):
        threadExample = threads[t].example
        examples = examples + threadExample
        print (t, threads[t].threadID, len(threads[t].example), len(examples))
    print ("Exiting Main Thread")
    
    print(len(examples))

    random.shuffle(examples)
    return examples, len(examples)

猜你喜欢

转载自blog.csdn.net/qq_30534935/article/details/102806742