【python】Python闭包(closure),装饰器(Decorator)和注册机制(Registry)的学习笔记。


前言

最近,接触到了MMCV框架,发现MMCV框架为了方便更换backbone,优化器,学习策略等功能模块,引入了一种注册机制(Registry)的方法,可以有效的管理深度学习框架中的内容。同时,也方便用户通过外部接口灵活的搭配来组建自己的网络。为了深入了解注册机制的原理,我对此进行了学习和整理,并分享给读者朋友。

本文主要分为四个部分:
第一部分是简单介绍下python的闭包原理,闭包是装饰器的基础。
第二部分是简单介绍下python的装饰器原理,注册机制是装饰器的应用场景。
第三部分是简单介绍注册机制,并附上python代码示例。
第四部分是简单介绍MMCV框架注册机制代码,并附上相关代码注释。

由于我学的浅显,如果理解和您有偏差,则望各位大佬及时指出。最后如果您觉得对您有帮助的话,可以给小弟一个赞。 ⌣ ¨ \ddot\smile ¨

参考资料如下所示:
1.Python 函数装饰器|菜鸟教程
2.理解闭包的概念,作者:alpha_panda。
3.Python3 命名空间和作用域
4.mmsegment定制模型(八),作者:alex1801。
5. Registry注册机制,作者:~HardBoy~。
6.理解 Python 装饰器看这一篇就够了,作者:刘志军。

1. 前提概念

在开始学习之前,需要提前了解2个概念,如下所示。
(1)对于python而言,一切皆是对象,包括函数自己也是对象。因此,函数是可以赋值给变量,并通过变量进行函数调用的。(函数也是可以赋值给函数的)

def base():
    print("这是基础函数")
def derived():
    print("这是派生函数")
    
# 以变量的形式调用
Var = base
Var()
# 另外,函数也是可以赋值给函数的,以函数的形式调用
derived = base
derived()

在这里插入图片描述
基于上述概念,可以得出函数其实也是可以作为参数传给新函数的

def base():
    print("这是基础函数")
def derived(func):
    func()
    print("这是派生函数")
    
# 以函数的形式调用
derived(base)

在这里插入图片描述
(2)外部函数之中可以定义内部函数,并返回内部函数。

def outsideFunc():
    print("这是外部函数")
    # 在派生函数内定义基本函数
    def insideFunc():
        print("这是内部函数")
    return insideFunc

Var = outsideFunc()
print(Var)
Var()

在这里插入图片描述

2. python的闭包(closure)

本文关于闭包的介绍相对简单,具体深入细节可以参考博客:理解闭包的概念,解释的非常详细,还列出了闭包函数容易错误的点,本文重点是注册函数,因此仅是简单介绍。

概念: 一个函数可以引用作用域外的变量,同时也可以在函数内部进行使用的函数,称之为闭包函数。而这个作用域外的变量,称之为自由变量。这个自由变量不同于C++的static修饰的静态变量,也不同于全局变量,它是和闭包函数进行绑定的。

注意两点:
1.闭包函数与闭包函数之间,自由变量不会相互干扰。
2.闭包函数的自由变量会传到下一次闭包函数中。

def outside():
    # 此处obj为自由变量
    obj = []
    # inside就被称为闭包函数
    def inside(name):
        obj.append(name)
        print(obj)
    return inside
    
Lis = outside()
Lis("1")
# 对闭包函数自由变量的修改会传到下一次闭包函数中。
Lis("2")
Lis("3")
print("-"*20)
# 闭包函数和闭包函数之间的自由变量不会相互影响。
Lis2 = outside()
Lis2("4")
Lis2("5")
Lis2("6")

在这里插入图片描述

从上例中可以看出,obj是针对inside闭包函数的自由变量。每次调用inside闭包函数时,都会给obj添加一个新值,并且会传到下一次inside闭包函数的调用中。
此处,obj自由变量相对于inside闭包函数的作用相当于一个“全局变量”,只不过作用域仅是针对于包含自由变量和闭包函数的外部环境而言,上例中指的是outside的作用域。python将这种作用域,称为闭包函数外的作用域(Enclosing)。

补充知识:
python有四种作用域分别是:
局部作用域(Local):最内层,包含局部变量,比如一个函数/方法内部。
闭包函数外的作用域(Enclosing):包含了非局部(non-local)也非全局(non-global)的变量。比如两个嵌套函数,一个函数(或类) A 里面又包含了一个函数 B ,那么对于 B 中的名称来说 A 中的作用域就为 nonlocal。
全局作用域(Global):当前脚本的最外层,比如当前模块的全局变量。
内建作用域(Built-in):包含了内建的变量/关键字等,最后被搜索。
规则顺序如下图所示:
在这里插入图片描述

如上图所示,python在局部作用域找不到对应的函数变量,便会去局部外的局部找(例如闭包),再找不到就会去全局作用域找,再者去内建作用域中找。此处更详尽的内容可参考博客Python3 命名空间和作用域

3. python装饰器(Decorator)

了解闭包原理后,再来看装饰器的本质。装饰器其实就是通过闭包原理来封装函数,并返回函数的高级函数。封装的目的是为了在保持原有函数功能的基础上,扩充额外的功能。具体可以参考如下示例。

def add(arr1, arr2):
    return arr1 + arr2
def decorator(func):
    def wrapper(arr1, arr2):
        print("实现新的功能,例如数据加1的功能")
        arr1 += 1
        arr2 += 1
        return func(arr1, arr2)
    return wrapper

print(add(2, 3))

add = decorator(add)
print(add(2, 3))

在这里插入图片描述

从代码中可以看到,我们希望在保持数据相加功能的基础上,再添加一个参数自加一的新功能。如果按照往常,我们需要在之前的函数内添加新的内容,这样的话就会破坏原有的结构。因此,在既需要保留原有函数结构的基础上,额外的添加新功能,就可以通过装饰器来进行实现。上述decorator就是一个装饰器函数,下面add = decorator(add)就是他的调用方式,不过在python中,可以使用@语法来简化调用语句。如下所示:

def decorator(func):
    def wrapper(arr1, arr2):
        print("实现新的功能,例如数据加1的功能")
        arr1 += 1
        arr2 += 1
        return func(arr1, arr2)
    return wrapper

@decorator
# @decorator就等价于add = decorator(add)
def add(arr1, arr2):
    return arr1 + arr2

print(add(2, 3))

可以看到@decorator就等价于add = decorator(add),”@“会把它修饰的函数作为参数传给装饰器函数

补充一个小点:
在经过装饰器函数后,其实add函数的内容已经变成wrapper函数了,所以它的其他相关描述内容也变成了wrapper的描述内容,例如__name__。因此,如果想要不修改原有描述内容的话,可以借助functools函数包中的装饰函数@wraps(func)来实现。这个函数可以复制原有函数的描述内容。

# 不加@wraps的情况下
def decorator(func):
    def wrapper(arr1, arr2):
        print("实现新的功能,例如数据加1的功能")
        arr1 += 1
        arr2 += 1
        return func(arr1, arr2)
    return wrapper
def add(arr1, arr2):
    return arr1 + arr2
#原本add函数
print(add.__name__)
#通过装饰器修饰后
add = decorator(add)
print(add.__name__)

在这里插入图片描述

from functools import wraps
# 添加@wraps的情况下
def decorator(func):
    @wraps(func)
    #@wraps(func) 等同于 wrapper = wraps(func)(wrapper)
    def wrapper(arr1, arr2):
        print("实现新的功能,例如数据加1的功能")
        arr1 += 1
        arr2 += 1
        return func(arr1, arr2)
    return wrapper

def add(arr1, arr2):
    return arr1 + arr2
    
#原本add函数
print(add.__name__)
#通过装饰器修饰后
add = decorator(add)
print(add.__name__)

在这里插入图片描述

4. 注册机制(Registry)

概念: 注册机制主要是实现用户输入的字符串到所需函数或者类的映射,方便项目管理和用户使用。注册机制可以通过python装饰器来构建映射关系。比如:MMCV也是通过装饰器的方法来完成的。

完成注册机制主要有三个步骤:
(1)编写注册机制类。
(2)实例化一个注册机制的对象,即构建注册表。
(3)通过装饰器原理来往注册表添加内容,即实现内容注册

4.1 编写注册机制的类。

class Registry:
    def __init__(self, name=None):
        # 生成注册列表的名字, 如果没有给出,则默认是Registry。
        if name == None:
            self._name = "Registry"
        self._name = name
        #创建注册表,以字典的形式。
        self._obj_list = {
    
    }

    def __registry(self, obj):
        """
        内部注册函数
        :param obj:函数或者类的地址。
        :return:
        """
        #判断是否目标函数或者类已经注册,如果已经注册过则标错,如果没有则进行注册。
        assert(obj.__name__ not in self._obj_list.keys()), "{} already exists in {}".format(obj.__name__, self._name)
        self._obj_list[obj.__name__] = obj

    def registry(self, obj=None):
        """
        # 外部注册函数。注册方法分为两种。
        # 1.通过装饰器调用
        # 2.通过函数的方式进行调用

        :param obj: 函数或者类的本身
        :return:
        """
        # 1.通过装饰器调用
        if obj == None:
            def _no_obj_registry(func__or__class, *args, **kwargs):
                self.__registry(func__or__class)
                # 此时被装饰的函数会被修改为该函数的返回值。
                return func__or__class
                                                
            return _no_obj_registry
        #2.通过函数的方式进行调用
        self.__registry(obj)

    def get(self, name):
        """
        通过字符串name获取对应的函数或者类。
        :param name: 函数或者类的名称
        :return: 对应的函数或者类
        """
        assert (name in self._obj_list.keys()), "{}  没有注册".format(name)
        return self._obj_list[name]

这个注册机制类主要包含三个成员函数,分别是__registry,registry,get和2个成员变量self._name和self._obj_list。

成员变量:
1.self._name变量:表示这个注册表的名称,如果没有给予,则默认为Registry。
2.self._obj_list变量:以字典的形式表示注册表,即字符串与对应函数名的映射关系。
成员函数:
1.registry函数:以两种方式对传入的函数进行注册,一种是通过闭包函数 _no_obj_registry,来实现对自由变量self._obj_list的修改。另一种则是直接通过传入registry函数的obj的参数完成注册。这两种方式具体实现功能是通过__registry函数来实现注册的
2.__registry函数:对传入进来的函数参数进行注册,如果存在则报错,如果不存在则完成注册。
3.get函数:通过查找注册表,实现从字符串name到对应函数或类名称的映射,返回对应名称的函数或类。

4.2 创建注册表

基于注册机制类,来实例化一个对象,这个对象就是我们需要的注册表。

# 生成注册表
REGISTRY_LIST = Registry("REGISTRY_LIST")

4.3 内容注册

通过装饰器原理来往注册表添加内容,即实现内容注册。在下例中,通过语句@REGISTRY_LIST.registry()来实现对create_by_decorator函数的注册。

@REGISTRY_LIST.registry()等价于
test_by_decorator = REGISTRY_LIST.registry()(test_by_decorator),
即_no_obj_registry(test_by_decorator)

# 通过装饰器调用
@REGISTRY_LIST.registry()
# @REGISTRY_LIST.registry()等价于test_by_decorator = REGISTRY_LIST.registry()(test_by_decorator),即_no_obj_registry(test_by_decorator)
def create_by_decorator():
    print("通过装饰器完成注册的函数")


def create_by_function():
    print("直接通过registry函数进行注册")
#当然也可以直接通过传入registry函数进行注册。
REGISTRY_LIST.registry(create_by_function)

#通过字符串来获取对应函数名称的函数
test1 = REGISTRY_LIST.get("create_by_decorator")
test1()
test2 = REGISTRY_LIST.get("create_by_function")
test2()

在这里插入图片描述

5. MMCV的注册机制

由于本人是在windows下运行mmcv框架的。因此,我的registry.py文件的文件路径是F:\SegFormer-master\mmcv-1.2.7\mmcv\utils\registry.py,读者可以根据自身情况查找自己项目中mmcv的registty.py文件位置,Linux应该是安装的mmcv包下。整体代码如下。后面我们会拆开来细看。

import inspect
import warnings
from functools import partial

from .misc import is_seq_of

class Registry:
    """A registry to map strings to classes.

    Args:
        name (str): Registry name.
    """

    def __init__(self, name):
        self._name = name
        self._module_dict = dict()

    def __len__(self):
        return len(self._module_dict)

    def __contains__(self, key):
        return self.get(key) is not None

    def __repr__(self):
        format_str = self.__class__.__name__ + \
                     f'(name={
      
      self._name}, ' \
                     f'items={
      
      self._module_dict})'
        return format_str

    @property
    def name(self):
        return self._name

    @property
    def module_dict(self):
        return self._module_dict

    def get(self, key):
        """Get the registry record.

        Args:
            key (str): The class name in string format.

        Returns:
            class: The corresponding class.
        """
        return self._module_dict.get(key, None)

    def _register_module(self, module_class, module_name=None, force=False):
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, '
                            f'but got {
      
      type(module_class)}')

        if module_name is None:
            module_name = module_class.__name__
        if isinstance(module_name, str):
            module_name = [module_name]
        else:
            assert is_seq_of(
                module_name,
                str), ('module_name should be either of None, an '
                       f'instance of str or list, but got {
      
      type(module_name)}')
        for name in module_name:
            if not force and name in self._module_dict:
                raise KeyError(f'{
      
      name} is already registered '
                               f'in {
      
      self.name}')
            self._module_dict[name] = module_class

    def deprecated_register_module(self, cls=None, force=False):
        warnings.warn(
            'The old API of register_module(module, force=False) '
            'is deprecated and will be removed, please use the new API '
            'register_module(name=None, force=False, module=None) instead.')
        if cls is None:
            return partial(self.deprecated_register_module, force=force)
        self._register_module(cls, force=force)
        return cls

    def register_module(self, name=None, force=False, module=None):
        """Register a module.

        A record will be added to `self._module_dict`, whose key is the class
        name or the specified name, and value is the class itself.
        It can be used as a decorator or a normal function.

        Example:
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module()
            >>> class ResNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> @backbones.register_module(name='mnet')
            >>> class MobileNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> class ResNet:
            >>>     pass
            >>> backbones.register_module(ResNet)

        Args:
            name (str | None): The module name to be registered. If not
                specified, the class name will be used.
            force (bool, optional): Whether to override an existing class with
                the same name. Default: False.
            module (type): Module class to be registered.
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {
      
      type(force)}')
        # NOTE: This is a walkaround to be compatible with the old api,
        # while it may introduce unexpected bugs.
        if isinstance(name, type):
            return self.deprecated_register_module(name, force=force)

        # use it as a normal method: x.register_module(module=SomeClass)
        if module is not None:
            self._register_module(
                module_class=module, module_name=name, force=force)
            return module

        # raise the error ahead of time
        if not (name is None or isinstance(name, str)):
            raise TypeError(f'name must be a str, but got {
      
      type(name)}')

        # use it as a decorator: @x.register_module()
        def _register(cls):
            self._register_module(
                module_class=cls, module_name=name, force=force)
            return cls

        return _register

具体可以看到,mmcv的Registry类包含两个成员变量self._name和self._module_dict,以及主要的6个成员函数,name函数、module_dict函数、get函数、_register_module函数、deprecated_register_module函数和register_module函数。我们分别进行简单介绍。

成员变量:
1.self._name变量:表示这个注册表的名称。
2.self._module_dict变量:以字典的形式表示注册表,即字符串与对应函数名的映射关系。
成员函数:
1.name函数、module_dict函数都是装饰器@property修饰的。Python内置的@property装饰器就是负责把一个方法变成属性调用。
2.register_module函数:完成对目标类的注册,具体代码如下,函数含义已经进行注释。

def register_module(self, name=None, force=False, module=None):
        """注册一个模型

        类名称将被添加到变量self._module_dict中, 该变量的键值是类别名或者专属名字。
        它可以通过装饰器或者函数直接调用。

        Example:
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module()
            >>> class ResNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> @backbones.register_module(name='mnet')
            >>> class MobileNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> class ResNet:
            >>>     pass
            >>> backbones.register_module(ResNet)

        Args:
            name (str | None): 要注册的模块名称。如果未指定,则将使用类名。
            force (bool, optional): 是否用相同的名称重写现有的类。默认值:False。
            module (type): 要注册的模块类。
        """
        #---------------------------------------------------------------------
        #判断输入force参数是否正确。
        #---------------------------------------------------------------------
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {
      
      type(force)}')
        #---------------------------------------------------------------------
        #注意:这是一个与旧api兼容的演练,而它可能会引入意想不到的错误。
        #---------------------------------------------------------------------
        if isinstance(name, type):
            return self.deprecated_register_module(name, force=force)
        #---------------------------------------------------------------------
        #判断module是否存在,如果存在则直接进行注册。并返回module
        #---------------------------------------------------------------------
        # use it as a normal method: x.register_module(module=SomeClass)
        if module is not None:
            self._register_module(
                module_class=module, module_name=name, force=force)
            return module
        #---------------------------------------------------------------------
        #判断输入的name参数是否正确
        #---------------------------------------------------------------------
        # raise the error ahead of time
        if not (name is None or isinstance(name, str)):
            raise TypeError(f'name must be a str, but got {
      
      type(name)}')
        #---------------------------------------------------------------------
        #如果module不存在,则通过装饰器的方式进行注册。
        #---------------------------------------------------------------------
        # use it as a decorator: @x.register_module()
        def _register(cls):
            self._register_module(
                module_class=cls, module_name=name, force=force)
            return cls

        return _register
  1. _register_module函数:具体实现函数或者类的注册方法。具体代码如下,函数已经做了注释。
  def _register_module(self, module_class, module_name=None, force=False):
    """
    具体实现注册方法。
    :param module_class:需要注册的函数本身
    :param module_name:需要注册的函数名称。默认为None
    :param force:是否重写已经存在的函数,默认为False
    """
        #---------------------------------------------------------------------
        #判断module是否是class类。不是类则报错
        #---------------------------------------------------------------------
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, '
                            f'but got {
      
      type(module_class)}')
        #---------------------------------------------------------------------
        #判断module_name是否存在,不存在则默认函数本身名称
        #---------------------------------------------------------------------
        if module_name is None:
            module_name = module_class.__name__
        #---------------------------------------------------------------------
        #判断module_name是否是个字符串,或者列表
        #---------------------------------------------------------------------
        if isinstance(module_name, str):
            module_name = [module_name]
        else:
            assert is_seq_of(
                module_name,
                str), ('module_name should be either of None, an '
                       f'instance of str or list, but got {
      
      type(module_name)}')
        #---------------------------------------------------------------------
        #针对列表中的字符串进行注册。
        #---------------------------------------------------------------------
        for name in module_name:
            if not force and name in self._module_dict:
                raise KeyError(f'{
      
      name} is already registered '
                               f'in {
      
      self.name}')
            #完成注册
            self._module_dict[name] = module_class

4.get函数:通过查找注册表,实现从字符串name到对应函数名称的映射,返回对应名称的函数。

    def get(self, key):
        """或者注册表的键值

        Args:
            key (str): 键值必须是字符串

        Returns:
            class: 键值对应的类.
        """
		return self._module_dict.get(key, None)

5.deprecated_register_module函数:弃用注册模块。这个没看明白。

    def deprecated_register_module(self, cls=None, force=False):
        
        warnings.warn(
            'The old API of register_module(module, force=False) '
            'is deprecated and will be removed, please use the new API '
            'register_module(name=None, force=False, module=None) instead.')
        if cls is None:
            return partial(self.deprecated_register_module, force=force)
        self._register_module(cls, force=force)
        return cls

通过mmcv的Registry简单注册一下。具体实现如下。当然MMCV添加注册模块可不是这样做的。具体如何添加模块,可以参考csdn博客mmsegment定制模型(八

if __name__ == "__main__":

    from torch.optim.adam import Adam
    registry_list = Registry("OPTIM")
    registry_list.register_module(name="registry_adam", module=Adam)
    optim = registry_list.get("registry_adam")
    print(optim)
    print(registry_list.module_dict)

在这里插入图片描述

总结

本文分别总结python闭包,装饰器和注册机制的原理,并列出了代码和输出结果。不过,本文还是仅仅只涉及了最粗浅的部分。在学习各个参考博客时,发现关于闭包和装饰器的内容远远不止这些。如果,您对这些方面有兴趣可以进入参考博客中继续学习。最后,感谢您的阅读。

猜你喜欢

转载自blog.csdn.net/weixin_43610114/article/details/126182474