Flask上下文管理源码分析

引出的问题

Flask如何使用上下文临时把某些对象变为全局可访问

首先我们做如下的几种情况的假设

情况一:单进程单线程

这种情况可以基于全局变量存储临时的对象

情况二:单进程多线程

这种情况会出现多个线程共享全局的变量,为了每个线程中的数据不被其他线程修改,可以借助hreading.local对象,为每个线程做唯一的表示用来做键,请求的对象作为值来实现

多线程共享数据的问题

import threading
class Foo(object):
    def __init__(self):
        self.name = 0

local_values = Foo()

def func(num):
    local_values.name = num
    import time
    time.sleep(1)
    print(local_values.name, threading.current_thread().name)


for i in range(20):
    th = threading.Thread(target=func, args=(i,), name='线程%s' % i)
    th.start()

我们可以看到最后把每个线程中对象中name值都变为了19,不能保证每个线程中对象中的值唯一

使用hreading.local对象可以对每个线程做唯一的表示可以解决上述的问题

import threading

local_values = threading.local()

def func(num):
    local_values.name = num
    import time
    time.sleep(1)
    print(local_values.name, threading.current_thread().name)


for i in range(20):
    th = threading.Thread(target=func, args=(i,), name='线程%s' % i)
    th.start()

可以看到每个线程中的值唯一

- 情况三:单进程单线程(多个协程)Flask 的上下文管理就是基于这种情况做的

 在这种情况下使用上面的方法可以保证线程中的数据唯一,但是使用其内部创建多个协程后,hreading.local只能对线程作唯一的标示,协程是在单线程下切换的,所以多个协程还会出现共享数据的问题

解决的思路:为每个程做唯一的标示,我们可以通过python自带的greenlet模块中的getcurrent来实现

只需对上面的代码做简单的修改即可

import threading
try:
    from greenlet import getcurrent as get_ident # 协程
except ImportError:
    try:
        from thread import get_ident
    except ImportError:
        from _thread import get_ident # 线程


class Local(object):
    def __init__(self):
        self.storage = {}
        self.get_ident = get_ident

    def set(self,k,v):
        ident = self.get_ident()
        origin = self.storage.get(ident)
        if not origin:
            origin = {k:v}
        else:
            origin[k] = v
        self.storage[ident] = origin

    def get(self,k):
        ident = self.get_ident()
        origin = self.storage.get(ident)
        if not origin:
            return None
        return origin.get(k,None)

local_values = Local()


def task(num):
    local_values.set('name',num)
    import time
    time.sleep(1)
    print(local_values.get('name'), threading.current_thread().name)


for i in range(20):
    th = threading.Thread(target=task, args=(i,),name='线程%s' % i)
    th.start()

测试的结果如下

使用面向对象中方法对其进行简单的优化

在初始化的时候设置属性的时候,为了避免循环引用,我们可以这样做  object.__setattr__(self, 'storage', {})

class Foo(object):

    def __init__(self):
        object.__setattr__(self, 'storage', {})
        # self.storage = {}

    def __setattr__(self, key, value):
        self.storage = {'k1':'v1'}
        print(key,value)

    def __getattr__(self, item):
        print(item)
        return 'df'


obj = Foo()

# obj.x = 123
# 对象.xx

修改后的代码如下所示 

import threading
try:
    from greenlet import getcurrent as get_ident # 协程
except ImportError:
    try:
        from thread import get_ident
    except ImportError:
        from _thread import get_ident # 线程


class Local(object):

    def __init__(self):
        object.__setattr__(self, '__storage__', {})
        object.__setattr__(self, '__ident_func__', get_ident)


    def __getattr__(self, name):
        try:
            return self.__storage__[self.__ident_func__()][name]
        except KeyError:
            raise AttributeError(name)

    def __setattr__(self, name, value):
        ident = self.__ident_func__()
        storage = self.__storage__
        try:
            storage[ident][name] = value
        except KeyError:
            storage[ident] = {name: value}

    def __delattr__(self, name):
        try:
            del self.__storage__[self.__ident_func__()][name]
        except KeyError:
            raise AttributeError(name)


local_values = Local()


def task(num):
    local_values.name = num
    import time
    time.sleep(1)
    print(local_values.name, threading.current_thread().name)


for i in range(20):
    th = threading.Thread(target=task, args=(i,),name='线程%s' % i)
    th.start()

猜你喜欢

转载自www.cnblogs.com/crazymagic/p/9589351.html
今日推荐