본문 바로가기
AI HW study/Transformer

pytorch-image-models (Timm) model 코드분석 3

by jyun13 2024. 1. 23.

timm/models/deit.py

""" DeiT - Data-efficient Image Transformers

DeiT model defs and weights from https://github.com/facebookresearch/deit, original copyright below

paper: `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877

paper: `DeiT III: Revenge of the ViT` - https://arxiv.org/abs/2204.07118

Modifications copyright 2021, Ross Wightman
"""
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
from functools import partial
from typing import Sequence, Union

import torch
from torch import nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import resample_abs_pos_embed
from timm.models.vision_transformer import VisionTransformer, trunc_normal_, checkpoint_filter_fn
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model, register_model_deprecations

__all__ = ['VisionTransformerDistilled']  # model_registry will add each entrypoint fn to this

 

더보기

이 코드는 "DeiT" (Data-efficient Image Transformers)라 불리는 Vision Transformer(ViT) 모델의 수정된 버전을 정의합니다. 이는 "DeiT: Data-efficient Image Transformers"

논문 (https://arxiv.org/abs/2012.12877) 및 "DeiT III: Revenge of the ViT" 논문 (https://arxiv.org/abs/2204.07118)에서 제안된 모델의 구현인 것으로 보입니다.

코드의 각 줄을 살펴보겠습니다.

1. `""" DeiT - Data-efficient Image Transformers`: 이것은 코드에 대한 간단한 설명을 제공하는 독스트링입니다. DeiT 모델과 관련이 있으며 Vision Transformers의 데이터 효율적인 버전입니다.

2. `DeiT model defs and weights from https://github.com/facebookresearch/deit, original copyright below`: 이 주석은 DeiT 모델 정의 및 가중치가 주어진 GitHub 리포지토리에서 가져왔음을 언급하고 Facebook, Inc.의 원래 저작권을 인정합니다.

3. `from functools import partial`: `functools` 모듈에서 `partial` 함수를 가져옵니다. 이것은 특정 인수가 고정된 부분 함수를 생성하는 데 사용될 것입니다.

4. `from typing import Sequence, Union`: `typing` 모듈에서 `Sequence`와 `Union`을 가져옵니다. 아마도 함수 시그니처에서 타입 힌트로 사용될 것입니다.

5. `import torch`: PyTorch 라이브러리를 가져옵니다.

6. `from torch import nn as nn`: PyTorch에서 신경망 모듈 (`nn`)을 가져와 `nn`으로 별칭을 붙입니다.

7. `from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD`: `timm.data` 모듈에서 `IMAGENET_DEFAULT_MEAN` 및 `IMAGENET_DEFAULT_STD` 상수를 가져옵니다. 아마도 ImageNet 데이터셋 정규화를 위한 기본 평균 및 표준 편차 값과 관련이 있습니다.

8. `from timm.layers import resample_abs_pos_embed`: `resample_abs_pos_embed` 함수를 가져옵니다.

 

class VisionTransformerDistilled(VisionTransformer):
    """ Vision Transformer w/ Distillation Token and Head

    Distillation token & head support for `DeiT: Data-efficient Image Transformers`
        - https://arxiv.org/abs/2012.12877
    """
    
     def __init__(self, *args, **kwargs):
        weight_init = kwargs.pop('weight_init', '')
        super().__init__(*args, **kwargs, weight_init='skip')
        assert self.global_pool in ('token',)

        self.num_prefix_tokens = 2
        self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.patch_embed.num_patches + self.num_prefix_tokens, self.embed_dim))
        self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
        self.distilled_training = False  # must set this True to train w/ distillation token

        self.init_weights(weight_init)

    def init_weights(self, mode=''):
        trunc_normal_(self.dist_token, std=.02)
        super().init_weights(mode=mode)

    @torch.jit.ignore
더보기

이 코드는 `VisionTransformerDistilled` 클래스를 정의하고 있습니다. 이 클래스는 `VisionTransformer` 클래스를 상속받아서, DeiT (Data-efficient Image Transformers) 모델의 특별한 변형을 나타내고 있습니다. 

클래스의 주석과 각 메소드를 살펴보겠습니다.

1. `class VisionTransformerDistilled(VisionTransformer):`: `VisionTransformer` 클래스를 상속받아 `VisionTransformerDistilled` 클래스를 정의합니다.

2. `""" Vision Transformer w/ Distillation Token and Head`: 이 독스트링은 `VisionTransformerDistilled` 클래스가 Distillation Token 및 Head를 지원하는 Vision Transformer임을 설명하고 있습니다. 이는 DeiT 모델의 기능을 추가한 것입니다.

3. `Distillation token & head support for DeiT: 

Data-efficient Image Transformers - https://arxiv.org/abs/2012.12877`: 이 주석은 DeiT 모델의 Distillation Token 및 Head 지원에 대한 논문 링크를 제공하고 있습니다.

4. 이 코드는 `VisionTransformerDistilled` 클래스의 초기화 메소드를 정의함


4-1. `weight_init = kwargs.pop('weight_init', '')`: 키워드 인수로 전달된 `weight_init` 값을 가져오고, 만약 값이 없으면 빈 문자열('')을 기본값으로 설정합니다. 이 부분은 초기화 모드를 설정하기 위한 용도로 사용됩니다.

4-2. `super().__init__(*args, **kwargs, weight_init='skip')`: 부모 클래스인 `VisionTransformer`의 초기화 메소드를 호출합니다. 앞서 가져온 `weight_init` 값을 'skip'으로 설정하여 부모 클래스의 초기화 메소드를 호출합니다. 이렇게 함으로써 `VisionTransformer` 클래스의 초기화 메소드에서 특별한 초기화 작업을 수행하도록 합니다.

4-3. `assert self.global_pool in ('token',)`: 현재 클래스의 `global_pool` 속성이 'token'으로 설정되어 있는지 확인합니다. 이 값이 'token'이 아니면 AssertionError가 발생합니다.

4-4. `self.num_prefix_tokens = 2`: 현재 클래스의 속성인 `num_prefix_tokens`을 2로 설정합니다.

4-5. `self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))`: Distillation에 사용될 토큰을 나타내는 `dist_token`을 정의합니다. 이는 PyTorch의 `nn.Parameter`로 감싸져 있으므로 모델의 학습 중에 업데이트될 수 있는 학습 가능한 매개변수입니다.

4-6. `self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + self.num_prefix_tokens, self.embed_dim))`: 위치 임베딩을 나타내는 `pos_embed`을 정의합니다. 이 또한 학습 가능한 매개변수로 설정됩니다.

4-7. `self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()`: Distillation 헤드를 정의합니다. `num_classes`가 양수이면 선형 레이어를, 그렇지 않으면 항등 함수를 사용합니다.

4-8. `self.distilled_training = False`: `distilled_training` 속성을 False로 초기화합니다. 이 속성은 distillation token을 사용하여 훈련할지 여부를 결정하는데, 초기에는 사용하지 않도록 False로 설정됩니다.

4-9. `self.init_weights(weight_init)`: `init_weights` 메소드를 호출하여 가중치 초기화를 수행합니다. 이 메소드에서는 Distillation 토큰에 대한 가중치를 특정한 방식으로 초기화한 후에 부모 클래스의 가중치 초기화 메소드를 호출합니다.


5. 이 코드는 `VisionTransformerDistilled` 클래스에 속한 `init_weights` 메소드를 정의하고 있습니다. 

5-1. `trunc_normal_(self.dist_token, std=.02)`: `self.dist_token`에 대한 가중치를 평균이 0이고 표준 편차가 0.02인 절단 정규 분포를 사용하여 초기화합니다. `trunc_normal_` 함수는 텐서의 값을 평균이 0인 정규 분포에서 생성한 후, 일정한 표준 편차 이상의 값은 다시 샘플링하여 절단 정규 분포로 만듭니다. 이렇게 하면 특정 범위를 벗어나는 값이 가중치로 설정되지 않도록 합니다.

5-2. `super().init_weights(mode=mode)`: 부모 클래스인 `VisionTransformer`의 `init_weights` 메소드를 호출합니다. 부모 클래스의 가중치 초기화 메소드를 호출함으로써, 해당 메소드에서 정의한 초기화 작업을 수행합니다. `mode` 인수를 전달하여 부모 클래스의 초기화 메소드가 특정 모드로 동작하도록 지정할 수 있습니다.

5-3. `@torch.jit.ignore`: `torch.jit.ignore` 데코레이터는 TorchScript로 컴파일할 때 해당 메소드를 무시하도록 지정합니다. 이 메소드는 TorchScript에서 무시되어야 하는 경우에 사용됩니다.

 

def group_matcher(self, coarse=False):
        return dict(
            stem=r'^cls_token|pos_embed|patch_embed|dist_token',
            blocks=[
                (r'^blocks\.(\d+)', None),
                (r'^norm', (99999,))]  # final norm w/ last block
        )

    @torch.jit.ignore
    def get_classifier(self):
        return self.head, self.head_dist

    def reset_classifier(self, num_classes, global_pool=None):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

    @torch.jit.ignore
    def set_distilled_training(self, enable=True):
        self.distilled_training = enable