[시리즈별 느린 업데이트 보내기] CREStereo 소스 코드 읽기 2——특정 모듈 읽기

배경

CREStereo는 이전에 일반적인 프레임워크를 기록했습니다. 이제 강력한 존재감으로 일부 모듈의 특정 논리를 살펴보겠습니다.

참고

논문 읽기
소스 코드
논문
이전 블로그

특정 모듈

프레임워크를 전반적으로 읽은 후 심층 분석이 필요한 모듈은
LocalFeatureTransformer, AGCL, BasicUpdateBlock, convex_upsample입니다. AGCL과 함께 주목해야 합니다.

LocalFeatureTransformer

참조 추가

LoFTR은
LoFTR 참조 소스 코드를 읽습니다.

구체적 내용

먼저 코드 관점에서 살펴보고 잘 이해가 되지 않는다면 다른 사람들의 논문을 읽어보자. 앞으로 직접 시작하십시오. 이 레이어는 단순히 self-attention과 cross-attention의 임의의 조합임을 알 수 있으며, 구체적인 조합은 self.layer가 정의되는 방식에 따라 다릅니다.

for layer, name in zip(self.layers, self.layer_names):
    if name == "self":
        feat0 = layer(feat0, feat0, mask0, mask0)
        feat1 = layer(feat1, feat1, mask1, mask1)
    elif name == "cross":
        feat0 = layer(feat0, feat1, mask0, mask1)
        feat1 = layer(feat1, feat0, mask1, mask0)
    else:
        raise KeyError

따라서 초기화 부분을 살펴보십시오. layer_names에 의해 결정되는 하위 모듈 연결 수를 확인할 수 있습니다. 이 하위 모듈은 LoFTREncoderLayer(d_model, nhead, attention)에 의해 정의되므로 코드를 위로 스크롤하여 이 모듈의 특정 콘텐츠를 확인합니다.

def __init__(self, d_model, nhead, layer_names, attention):
    super(LocalFeatureTransformer, self).__init__()

    self.d_model = d_model
    self.nhead = nhead
    self.layer_names = layer_names
    encoder_layer = LoFTREncoderLayer(d_model, nhead, attention)
    self.layers = [
        copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))
    ]
    self._reset_parameters()

아직 처음부터 시작해서 종이 와 결합하여 여기에서 대응하기가 약간 어렵습니다 . 가면서 말하고 싶었는데 다 읽고 나니 기본적으로 똑같다는 걸 알았으니 그냥 여기 있는 논문의 설명을 읽어보세요.

def forward(self, x, source, x_mask=None, source_mask=None):
    bs = x.shape[0]
    query, key, value = x, source, source

    # multi-head attention
    query = F.reshape(
        self.q_proj(query), (bs, -1, self.nhead, self.dim)
    )  # [N, L, (H, D)] (H=8, D=256//8)
    key = F.reshape(
        self.k_proj(key), (bs, -1, self.nhead, self.dim)
    )  # [N, S, (H, D)]
    value = F.reshape(self.v_proj(value), (bs, -1, self.nhead, self.dim))
    message = self.attention(
        query, key, value, q_mask=x_mask, kv_mask=source_mask
    )  # [N, L, (H, D)]
    message = self.merge(
        F.reshape(message, (bs, -1, self.nhead * self.dim))
    )  # [N, L, C]
    message = self.norm1(message)

    # feed-forward network
    message = self.mlp(F.concat([x, message], axis=2))
    message = self.norm2(message)

    return x + message

추신: LinearAttention에는 최적화할 수 있는 작은 부분이 있는데, 기본적으로 속도에 도움이 되지 않습니다 하하. 즉, megengine은 고유한 elu 연산자가 없는 것 같습니다. 이것이 그가 말한 것입니다. 좌우를 살펴보니 여기서 1의 덧셈과 뺄셈을 생략할 수 있을 것 같고, elu_feature_map의 주석 처리된 부분이 제 글씨체입니다.

조달청: 사실 이렇게 보면 트랜스포머 아키텍처가 굉장히 특별하다는 느낌이 들지만, 아주 특별할 뿐이고, reshape 같은 중복 연산이 너무 많아 계산 속도에 영향을 미친다는 느낌이 듭니다. 이것을 어떤 각도에서 조금 수정하는 것이 가능하지 않습니까? 그나저나 코드에서 Linear Transformer를 사용하고 있는데 몇몇 변종은 잘 모르겠습니다.그런 기사를 본 후 여기 에 언급된 Performer를 사용해 볼 수 있습니까?

def elu(x, alpha=1.0):
    return F.maximum(0, x) + F.minimum(0, alpha * (F.exp(x) - 1))

def elu_feature_map(x):
    return elu(x) + 1
    # return F.relu(x) + F.minimum(1, F.exp(x))

AGCL

참조 추가
구체적 내용

좋은 사람, 이건 너무 길어서 읽고 싶지 않아.
먼저 __call__로 읽기 시작하세요. 간단히 말해서 이 두 모듈에 해당합니다. 그런 다음 다음 특정 내용에 해당합니다.

def __call__(self, flow, extra_offset, small_patch=False, iter_mode=False):
    if iter_mode:
        corr = self.corr_iter(self.fmap1, self.fmap2, flow, small_patch)
    else:
        corr = self.corr_att_offset(
            self.fmap1, self.fmap2, flow, extra_offset, small_patch
        )
    return corr

첫 번째는 corr_iter입니다. self.coords는 모듈이 초기화될 때 생성되는 점 좌표 세트와 동일하며, 흐름을 추가하는 것은 새로운 좌표 세트를 의미하며 오른쪽에 있는 좌표 세트에서 점을 수집합니다. 여기 small_patch에서 그의 검색 방법은 선택적 1D 또는 2D 검색이지만 계산의 일관성을 보장하기 위해 9포인트로 보장됨을 알 수 있습니다. 그런 다음 특징 채널 수에 따라 왼쪽 및 오른쪽 특징을 나눕니다. .
왜 4개의 부품을 먼저 분리한 다음 통합해야 합니까? 간소화된 계산과 관련된 것인지, 입력 내용과 관련된 것인지는 모르겠고, 추후에 좀 더 자세한 분석이 필요할 것 같습니다 .

def corr_iter(self, left_feature, right_feature, flow, small_patch):

    coords = self.coords + flow
    coords = F.transpose(coords, (0, 2, 3, 1))
    right_feature = bilinear_sampler(right_feature, coords)

    if small_patch:
        psize_list = [(3, 3), (3, 3), (3, 3), (3, 3)]
        dilate_list = [(1, 1), (1, 1), (1, 1), (1, 1)]
    else:
        psize_list = [(1, 9), (1, 9), (1, 9), (1, 9)]
        dilate_list = [(1, 1), (1, 1), (1, 1), (1, 1)]

    N, C, H, W = left_feature.shape
    lefts = F.split(left_feature, 4, axis=1)
    rights = F.split(right_feature, 4, axis=1)

    corrs = []
    for i in range(len(psize_list)):
        corr = self.get_correlation(
            lefts[i], rights[i], psize_list[i], dilate_list[i]
        )
        corrs.append(corr)

    final_corr = F.concat(corrs, axis=1)

    return final_corr

계속하려면…

추천

출처blog.csdn.net/weixin_42492254/article/details/125081621