AIMET API 文档(2)


1.1.3 模型准备器 API

AIMET PyTorch ModelPreparer API 使用 PyTorch 1.9+ 版本中提供的新图形转换功能,并自动执行用户所需的模型定义更改。 例如,它将前向传递中定义的函数更改为 torch.nn.Module 类型模块,以用于激活和元素函数。 此外,当重用 torch.nn.Module 类型的模块时,它会展开为独立的模块。

强烈建议用户首先使用 AIMET PyTorch ModelPreparer API,然后使用返回的模型作为所有 AIMET 量化功能的输入。

AIMET PyTorch ModelPreparer API 至少需要 PyTorch 1.9 版本。

1.1.3.1 顶层API

aimet_torch.model_preparer.prepare_model(model, modules_to_exclude=None, module_classes_to_exclude=None, concrete_args=None)[source]

使用 torch.FX 符号跟踪 API 准备和修改 AIMET 功能的 pytorch 模型。

  1. 将 torch.nn.function 替换为 torch.nn.Module 类型的模块
  2. 为重用/重复模块创建新的独立 torch.nn.Module 实例

参数:

  • model (Module) – 要修改的 pytorch 模型。
  • modules_to_exclude (Optional[List[Module]]) – 跟踪时要排除的模块列表。
  • module_classes_to_exclude (Optional[List[Callable]]) – 跟踪时要排除的模块类列表。
  • concrete_args (Optional[Dict[str, Any]]) – 允许你部分专业化你的功能,无论是删除控制流还是数据结构。 如果模型具有控制流,torch.fx 将无法跟踪模型。 详细检查 torch.fx.symbolic_trace API。

返回类型GraphModule
返回:修改后的pytorch模型

1.1.3.2 代码示例

所需导入

import torch
import torch.nn.functional as F
from aimet_torch.model_preparer import prepare_model

示例 1:具有函数式 reLU 的模型

我们从下面的模型开始,它包含两个函数relus和relu方法在forward方法中。

class ModelWithFunctionalReLU(torch.nn.Module):
    """ Model that uses functional ReLU instead of nn.Modules. Expects input of shape (1, 3, 32, 32) """
    def __init__(self):
        super(ModelWithFunctionalReLU, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 6, 5)
        self.conv2 = torch.nn.Conv2d(6, 16, 5)
        self.fc1 = torch.nn.Linear(9216, 128)
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x).relu()
        return x

通过传入模型来在模型上运行模型准备器 API。

def model_preparer_functional_example():

    # Load the model and keep in eval() mode
    model = ModelWithFunctionalReLU().eval()
    input_shape = (1, 3, 32, 32)
    input_tensor = torch.randn(*input_shape)

    # Call to prepare_model API
    prepared_model = prepare_model(model)
    print(prepared_model)

    # Compare the outputs of original and transformed model
    assert torch.allclose(model(input_tensor), prepared_model(input_tensor))

之后,我们得到prepared_model,它在功能上与原始模型相同。 用户可以通过比较两个模型的输出来验证这一点。

prepared_model 现在应该将所有三个功能 relus 转换为 torch.nn.ReLU 模块,这些模块满足模型指南中描述的模型指南。

示例 2:具有重用 torch.nn.ReLU 模块的模型

我们从以下模型开始,其中包含 torch.nn.ReLU 模块,该模块在模型前向函数内的多个实例中使用。

class ModelWithReusedReLU(torch.nn.Module):
    """ Model that uses single ReLU instances multiple times in the forward. Expects input of shape (1, 3, 32, 32) """
    def __init__(self):
        super(ModelWithReusedReLU, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 6, 5)
        self.conv2 = torch.nn.Conv2d(6, 16, 5)
        self.relu = torch.nn.ReLU()
        self.fc1 = torch.nn.Linear(9216, 128)
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        return x

通过传入模型来在模型上运行模型准备器 API。

def model_preparer_reused_example():

    # Load the model and keep in eval() mode
    model = ModelWithReusedReLU().eval()
    input_shape = (1, 3, 32, 32)
    input_tensor = torch.randn(*input_shape)

    # Call to prepare_model API
    prepared_model = prepare_model(model)
    print(prepared_model)

    # Compare the outputs of original and transformed model
    assert torch.allclose(model(input_tensor), prepared_model(input_tensor))

之后,我们得到prepared_model,它在功能上与原始模型相同。 用户可以通过比较两个模型的输出来验证这一点。

prepared_model 应该有单独的独立的 torch.nn.Module 实例,满足模型指南中描述的模型指南。

示例 3:带有逐元素加法的模型

我们从以下模型开始,其中包含模型前向函数内的元素添加操作。

class ModelWithElementwiseAddOp(torch.nn.Module):
    def __init__(self):
        super(ModelWithElementwiseAddOp, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 6, 5, bias=False)
        self.conv2 = torch.nn.Conv2d(3, 6, 5)

    def forward(self, *inputs):
        x1 = self.conv1(inputs[0])
        x2 = self.conv2(inputs[1])
        x = x1 + x2
        return x

通过传入模型来在模型上运行模型准备器 API。

def model_preparer_elementwise_add_example():

    # Load the model and keep in eval() mode
    model = ModelWithElementwiseAddOp().eval()
    input_shape = (1, 3, 32, 32)
    input_tensor = [torch.randn(*input_shape), torch.randn(*input_shape)]

    # Call to prepare_model API
    prepared_model = prepare_model(model)
    print(prepared_model)

    # Compare the outputs of original and transformed model
    assert torch.allclose(model(*input_tensor), prepared_model(*input_tensor))

之后,我们得到prepared_model,它在功能上与原始模型相同。 用户可以通过比较两个模型的输出来验证这一点。

1.1.3.3 torch.fx 符号跟踪 API 的限制

torch.fx 符号跟踪的限制:https://pytorch.org/docs/stable/fx.html#limitations-of-symbolic-tracing

  1. torch.fx 循环或 if-else 语句不支持动态控制流,其中条件可能取决于某些输入值。 它只能跟踪一个执行路径,所有其他未跟踪的分支将被忽略。 例如,跟踪以下简单函数时,将失败并显示 TraceError ,指出“符号跟踪变量不能用作控制流的输入”:
def f(x, flag):
    if flag:
        return x
    else:
        return x*2

torch.fx.symbolic_trace(f) # Fails!
fx.symbolic_trace(f, concrete_args={
    
    'flag': True})

此问题的解决方法:

  • 许多情况下的动态控制流可以简单地变成静态控制流,这由torch.fx符号跟踪支持。 静态控制流是 where 循环或 if-else 语句,其值在不同模型前向传递中不能改变。 通过将具体值传递给“concrete_args”以专门化您的前向函数,可以消除对输入值的数据依赖性,从而跟踪此类情况。

  • 在真正的动态控制流中,用户应该使用 torch.fx.wrap API 将这段代码包装在模型级范围内,这会将其保留为节点,而不是通过以下方式进行跟踪:

    @torch.fx.wrap
    def custom_function_not_to_be_traced(x, y):
        """ Function which we do not want to be traced, when traced using torch FX API, call to this function will
        be inserted as call_function, and won't be traced through """
        for i in range(2):
            x += x
            y += y
        return x * x + y * y
    
    
  1. 符号跟踪默认不支持不使用 torch_function 机制的非 Torch 函数。

此问题的解决方法:

  • 如果我们不想在符号跟踪中捕获它们,那么用户应该在模块级范围内使用 torch.fx.wrap() API:

    import torch
    import torch.fx
    torch.fx.wrap('len')  # call the API at module-level scope.
    torch.fx.wrap('sqrt') # call the API at module-level scope.
    
    class ModelWithNonTorchFunction(torch.nn.Module):
        def __init__(self):
            super(ModelWithNonTorchFunction, self).__init__()
            self.conv = torch.nn.Conv2d(3, 4, kernel_size=2, stride=2, padding=2, bias=False)
    
        def forward(self, *inputs):
            x = self.conv(inputs[0])
            return x / sqrt(len(x))
    
    model = ModelWithNonTorchFunction().eval()
    model_transformed = prepare_model(model)
    
    
  1. 通过重写 Tracer.is_leaf_module() API 自定义跟踪行为

在符号跟踪中,叶模块显示为节点而不是被跟踪,并且所有标准 torch.nn 模块都是默认的叶模块集。 但可以通过重写 Tracer.is_leaf_module() API 来更改此行为。

AIMET 模型准备器 API 公开了“module_to_exclude”参数,该参数可用于防止跟踪某些模块。 例如,让我们检查以下代码片段,我们不想进一步跟踪 CustomModule:

class CustomModule(torch.nn.Module):
    @staticmethod
    def forward(x):
        return x * torch.nn.functional.softplus(x).sigmoid()

class CustomModel(torch.nn.Module):
    def __init__(self):
        super(CustomModel, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 8, kernel_size=2)
        self.custom = CustomModule()

    def forward(self, inputs):
        x = self.conv1(inputs)
        x = self.custom(x)
        return x

model = CustomModel().eval()
prepared_model = prepare_model(model, modules_to_exclude=[model.custom])
print(prepared_model)

在此示例中,“self.custom”被保留为节点并且不被跟踪。

  1. 张量构造函数不可追踪

例如,让我们检查以下代码片段:

def f(x):
    return torch.arange(x.shape[0], device=x.device)

torch.fx.symbolic_trace(f)

Error traceback:
    return torch.arange(x.shape[0], device=x.device)
    TypeError: arange() received an invalid combination of arguments - got (Proxy, device=Attribute), but expected one of:
    * (Number end, *, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
    * (Number start, Number end, Number step, *, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)

上面的代码片段是有问题的,因为 torch.arange() 的参数是依赖于输入的。 此问题的解决方法:

  • 使用确定性构造函数(硬编码),以便它们产生的值将作为常量嵌入到图中:

    def f(x):
        return torch.arange(10, device=torch.device('cpu'))
    
  • 或者使用 torch.fx.wrap API 来包装 torch.arange() 并调用它:

@torch.fx.wrap
def do_not_trace_me(x):
    return torch.arange(x.shape[0], device=x.device)

def f(x):
    return do_not_trace_me(x)

torch.fx.symbolic_trace(f)

猜你喜欢

转载自blog.csdn.net/weixin_38498942/article/details/133066832