본문 바로가기
AI HW study/Transformer

pytorch-image-models (Timm) model 코드분석 1

by jyun13 2024. 1. 22.

DeiT모델이 어떻게 코드로 구현되어 있는지 숙지해야하는데, DeiT의 모델은 timm에서 구현한 vit 및 deit 모델을 가져와 사용한다.

timm에 있는 vit & deit 코드를 보며 모델이 어떻게 구성되어 있는지 확인해 보고 싶고 앞으로 timm에 있는 models의 코드를 하나씩 설명하고자 한다.

 

1. _pruned

- 하나의 딥러닝 모델의 가중치(weight)를 설명

- 각각의 "conv"는 합성곱 레이어 & "bn"은 배치 정규화(batch normalization) 레이어 & "se.conv"는 Squeeze-and-Excitation(SE) 모듈의 합성곱 가중치 & "downsample"은 다운샘플링을 위한 레이어

ex) layer1.0.conv1.weight는 ResNet 등의 레지듀얼 블록의 첫 번째 합성곱 레이어의 가중치

= 이 가중치의 크기는 [45, 64, 1, 1]으로, 입력 채널이 64이고 출력 채널이 45인 합성곱 필터

- fc.weight는 모델의 마지막 Fully Connected (FC) 레이어의 가중치 -> 크기는 [1000, 2042]로, 2042개의 입력 피처를 1000개의 클래스로 매핑하는 데 사용

- 딥러닝 모델이 이미 학습한 특징을 나타내며, 훈련된 모델을 로드하여 예측을 수행하거나 추가적인 학습을 진행할 때 사용

- 이러한 가중치들은 주로 훈련된 모델의 학습된 특징을 나타내며, 직접적으로 이해하기는 어렵습니다. 이러한 모델은 대개 대량의 이미지 데이터셋에서 사전 훈련된 후 특정 작업에 맞게 fine-tuning되어 사용

 

2. layers ( __init__.py )
- timm 라이브러리에서 layers 모듈을 import하는 부분

- 그러나 이 코드는 deprecated(사용 중지)되었으며, timm.layers를 사용하도록 권장

- 새로운 코드에서는 다음과 같이 수정하여 사용

# 기존 코드
from timm.models.layers import *

# 수정된 코드
from timm.layers.activations import *
from timm.layers.adaptive_avgmax_pool import \
    adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
# ... (다른 import 문들도 동일하게 수정)

 

3. __init__.py

- timm 라이브러리에서 다양한 모델들을 import하는 부분

- import 구문은 각 모델에 해당하는 모듈에서 해당 모델을 가져오는 역할

- 모델을 구성하거나 미리 훈련된 가중치를 로드하는 데 사용되는 기능들도 포함

from .beit import *
from .byoanet import *
from .byobnet import *
from .cait import *
from .coat import *
from .convit import *
from .convmixer import *
from .convnext import *
from .crossvit import *
from .cspnet import *
from .davit import *
from .deit import *
from .densenet import *
from .dla import *
from .dpn import *
from .edgenext import *
from .efficientformer import *
from .efficientformer_v2 import *
from .efficientnet import *
from .efficientvit_mit import *
from .efficientvit_msra import *
from .eva import *
from .fastvit import *
from .focalnet import *
from .gcvit import *
from .ghostnet import *
from .hardcorenas import *
from .hrnet import *
from .inception_next import *
from .inception_resnet_v2 import *
from .inception_v3 import *
from .inception_v4 import *
from .levit import *
from .maxxvit import *
from .metaformer import *
from .mlp_mixer import *
from .mobilenetv3 import *
from .mobilevit import *
from .mvitv2 import *
from .nasnet import *
from .nest import *
from .nfnet import *
from .pit import *
from .pnasnet import *
from .pvt_v2 import *
from .regnet import *
from .repghost import *
from .repvit import *
from .res2net import *
from .resnest import *
from .resnet import *
from .resnetv2 import *
from .rexnet import *
from .selecsls import *
from .senet import *
from .sequencer import *
from .sknet import *
from .swin_transformer import *
from .swin_transformer_v2 import *
from .swin_transformer_v2_cr import *
from .tiny_vit import *
from .tnt import *
from .tresnet import *
from .twins import *
from .vgg import *
from .visformer import *
from .vision_transformer import *
from .vision_transformer_hybrid import *
from .vision_transformer_relpos import *
from .vision_transformer_sam import *
from .volo import *
from .vovnet import *
from .xception import *
from .xception_aligned import *
from .xcit import *

from ._builder import build_model_with_cfg, load_pretrained, load_custom_pretrained, resolve_pretrained_cfg, \
    set_pretrained_download_progress, set_pretrained_check_hash
from ._factory import create_model, parse_model_name, safe_model_name
from ._features import FeatureInfo, FeatureHooks, FeatureHookNet, FeatureListNet, FeatureDictNet
from ._features_fx import FeatureGraphNet, GraphExtractNet, create_feature_extractor, \
    register_notrace_module, is_notrace_module, get_notrace_modules, \
    register_notrace_function, is_notrace_function, get_notrace_functions
from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_state_dict, resume_checkpoint
from ._hub import load_model_config_from_hf, load_state_dict_from_hf, push_to_hf_hub
from ._manipulate import model_parameters, named_apply, named_modules, named_modules_with_params, \
    group_modules, group_parameters, checkpoint_seq, adapt_input_conv
from ._pretrained import PretrainedCfg, DefaultCfg, filter_pretrained_cfg
from ._prune import adapt_model_from_string
from ._registry import split_model_name_tag, get_arch_name, generate_default_cfgs, register_model, \
    register_model_deprecations, model_entrypoint, list_models, list_pretrained, get_deprecated_models, \
    is_model, list_modules, is_model_in_modules, is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value

 

더보기

이 코드는 timm 라이브러리에서 모델을 구축하고 관리하는 데 사용되는 여러 유틸리티 함수와 클래스를 import하고 있습니다. 각 import된 모듈과 함수의 역할을 간단히 설명하겠습니다:

1. `_builder`: 모델을 구축하고 미리 훈련된 가중치를 로드하는 데 사용되는 함수들이 정의된 모듈입니다.
   - `build_model_with_cfg`: 설정(configuration)을 기반으로 모델을 구축합니다.
   - `load_pretrained`: 미리 훈련된 모델 가중치를 로드합니다.
   - `load_custom_pretrained`: 사용자 정의 된 미리 훈련된 모델 가중치를 로드합니다.
   - `resolve_pretrained_cfg`: 미리 훈련된 모델의 설정을 해결합니다.
   - `set_pretrained_download_progress`: 미리 훈련된 모델 다운로드 중 진행 상황을 설정합니다.
   - `set_pretrained_check_hash`: 미리 훈련된 모델의 해시 체크를 설정합니다.

2. `_factory`: 모델을 생성하고 모델 이름을 파싱하는 데 사용되는 함수들이 정의된 모듈입니다.
   - `create_model`: 모델을 생성합니다.
   - `parse_model_name`: 모델 이름을 파싱합니다.
   - `safe_model_name`: 모델 이름을 안전하게 반환합니다.

(파싱 = 모델의 이름을 분석하고 그 정보를 추출하는 과정, 모델 이름은 종종 특정한 패턴이나 규칙을 따르며, 이러한 이름에는 모델의 아키텍처, 크기, 변형 등과 관련된 다양한 정보가 포함

-> 모델을 동적으로 선택하거나 미리 훈련된 모델 가중치를 로드할 때 사용 )

3. `_features`: 특징 정보 및 특징 추출을 위한 클래스와 함수들이 정의된 모듈입니다.

4. `_features_fx`: 특징 추출을 위한 클래스와 함수들이 정의된 모듈입니다.

5. `_helpers`: 모델과 상호 작용하는 데 도움이 되는 여러 함수들이 정의된 모듈입니다.

6. `_hub`: Hugging Face Hub와 관련된 함수들이 정의된 모듈입니다.

7. `_manipulate`: 모델을 조작하고 변경하는 데 사용되는 함수들이 정의된 모듈입니다.

8. `_pretrained`: 미리 훈련된 모델 설정과 관련된 함수와 클래스들이 정의된 모듈입니다.

9. `_prune`: 모델 가중치를 가지치기(pruning)하는 데 사용되는 함수들이 정의된 모듈입니다.

10. `_registry`: 모델 등록 및 관리에 사용되는 함수와 클래스들이 정의된 모듈입니다.

이러한 모듈들은 모델 구축, 미리 훈련된 가중치 로드, 특징 추출 등 다양한 작업에 사용

`timm` 라이브러리를 통해 다양한 비전 모델을 쉽게 사용

 

 

4. timm/models/_builder.py

import dataclasses
import logging
import os
from copy import deepcopy
from typing import Optional, Dict, Callable, Any, Tuple

from torch import nn as nn
from torch.hub import load_state_dict_from_url

from timm.models._features import FeatureListNet, FeatureHookNet
from timm.models._features_fx import FeatureGraphNet
from timm.models._helpers import load_state_dict
from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf
from timm.models._manipulate import adapt_input_conv
from timm.models._pretrained import PretrainedCfg
from timm.models._prune import adapt_model_from_file
from timm.models._registry import get_pretrained_cfg

_logger = logging.getLogger(__name__)

# Global variables for rarely used pretrained checkpoint download progress and hash check.
# Use set_pretrained_download_progress / set_pretrained_check_hash functions to toggle.
_DOWNLOAD_PROGRESS = False
_CHECK_HASH = False
_USE_OLD_CACHE = int(os.environ.get('TIMM_USE_OLD_CACHE', 0)) > 0

__all__ = ['set_pretrained_download_progress', 'set_pretrained_check_hash', 'load_custom_pretrained', 'load_pretrained',
           'pretrained_cfg_for_features', 'resolve_pretrained_cfg', 'build_model_with_cfg']

 

더보기

- timm(Torch Image Models) 라이브러리에서 모델의 사전 훈련된 가중치(pretrained weights)를 로드하는 데 사용되는 여러 함수 및 변수들을 정의하는 모듈


1. `import` 구문: 필요한 라이브러리 및 모듈을 가져옵니다.

2. `_logger`: logging을 위한 logger 객체를 생성합니다.

3. 전역 변수들:
- `_DOWNLOAD_PROGRESS`: 
   - 설명: 사전 훈련된 가중치의 다운로드 진행 상황을 나타내는 변수입니다. 
   - 사용: `set_pretrained_download_progress` 함수를 사용하여 이 변수를 조절할 수 있습니다.
   - 활용: 사전 훈련된 모델 가중치를 다운로드할 때 진행 상황을 표시하거나, 진행 상황을 표시하지 않고 다운로드를 수행할 때 사용됩니다.

 

- `_CHECK_HASH`:
   - 설명: 사전 훈련된 가중치의 해시를 확인할지 여부를 나타내는 변수입니다.
   - 사용: `set_pretrained_check_hash` 함수를 사용하여 이 변수를 조절할 수 있습니다.
   - 활용: 다운로드한 가중치 파일의 해시를 확인하여 파일의 무결성을 검증하는 데 사용됩니다. 무결성 검사를 수행할지 여부를 설정합니다.

- `_USE_OLD_CACHE`:
   - 설명: 이전 버전의 캐시를 사용할지 여부를 나타내는 변수입니다. (이미 이전에 다운로드한 파일들을 계속 사용하겠는가?)
   - 사용: 환경 변수 `TIMM_USE_OLD_CACHE`를 통해 설정할 수 있습니다.
   - 활용: 이전에 다운로드한 캐시를 계속 사용할지 여부를 설정합니다. 이전 버전의 캐시를 사용하면 다운로드 속도를 높일 수 있지만, 새로운 업데이트가 누락될 수 있습니다. 새로운 모델 가중치를 사용하고자 할 때에는 이 옵션을 끄고(0으로 설정) 다시 다운로드 하는 것이 일반적인 절차입니다. 

4. `__all__`: 모듈 외부에서 import 가능한 이름들의 리스트를 정의합니다.

5. 기타 함수 및 클래스 import: 여러 모듈에서 사용되는 함수들과 클래스들을 import합니다. 이들은 모델의 특징(feature), 가중치 조작(weight manipulation), 사전 훈련(pretrained) 관련 기능 등을 구현하는 데 사용됩니다.

이 모듈은 주로 모델의 사전 훈련된 가중치를 관리하고 로드하는 데 사용되는 기능들을 제공합니다.

logging은 파이썬에서 로깅(logging)을 구현하기 위한 내장 모듈입니다.
로깅은 프로그램의 실행 중에 발생하는 이벤트나 정보를 기록하고 관리하는 것을 말합니다.
로깅을 사용하면 프로그램이 어떻게 동작하고 있는지에 대한 정보를 기록하여 나중에 분석하거나 디버깅하는 데 도움

 

def _resolve_pretrained_source(pretrained_cfg):
    cfg_source = pretrained_cfg.get('source', '')
    pretrained_url = pretrained_cfg.get('url', None)
    pretrained_file = pretrained_cfg.get('file', None)
    pretrained_sd = pretrained_cfg.get('state_dict', None)
    hf_hub_id = pretrained_cfg.get('hf_hub_id', None)

    # resolve where to load pretrained weights from
    load_from = ''
    pretrained_loc = ''
    if cfg_source == 'hf-hub' and has_hf_hub(necessary=True):
        # hf-hub specified as source via model identifier
        load_from = 'hf-hub'
        assert hf_hub_id
        pretrained_loc = hf_hub_id
    else:
        # default source == timm or unspecified
        if pretrained_sd:
            # direct state_dict pass through is the highest priority
            load_from = 'state_dict'
            pretrained_loc = pretrained_sd
            assert isinstance(pretrained_loc, dict)
        elif pretrained_file:
            # file load override is the second-highest priority if set
            load_from = 'file'
            pretrained_loc = pretrained_file
        else:
            old_cache_valid = False
            if _USE_OLD_CACHE:
                # prioritized old cached weights if exists and env var enabled
                old_cache_valid = check_cached_file(pretrained_url) if pretrained_url else False
            if not old_cache_valid and hf_hub_id and has_hf_hub(necessary=True):
                # hf-hub available as alternate weight source in default_cfg
                load_from = 'hf-hub'
                pretrained_loc = hf_hub_id
            elif pretrained_url:
                load_from = 'url'
                pretrained_loc = pretrained_url

    if load_from == 'hf-hub' and pretrained_cfg.get('hf_hub_filename', None):
        # if a filename override is set, return tuple for location w/ (hub_id, filename)
        pretrained_loc = pretrained_loc, pretrained_cfg['hf_hub_filename']
    return load_from, pretrained_loc


def set_pretrained_download_progress(enable=True):
    """ Set download progress for pretrained weights on/off (globally). """
    global _DOWNLOAD_PROGRESS
    _DOWNLOAD_PROGRESS = enable


def set_pretrained_check_hash(enable=True):
    """ Set hash checking for pretrained weights on/off (globally). """
    global _CHECK_HASH
    _CHECK_HASH = enable

 

더보기

`_resolve_pretrained_source` 함수: 이 함수는 사전 훈련된 모델의 가중치를 어디서 가져올지를 결정

1. `cfg_source`: 설정 파일에서 모델의 가중치 소스로 지정된 값입니다. ('hf-hub', 'timm', 또는 비어 있을 수 있음)
2. `pretrained_url`: 설정 파일에서 가중치 파일의 URL입니다.
3. `pretrained_file`: 설정 파일에서 가중치 파일의 경로입니다.
4. `pretrained_sd`: 설정 파일에서 직접 전달되는 state_dict입니다.
5. `hf_hub_id`: 설정 파일에서 Hugging Face hub에서 모델을 식별하는 ID입니다.

로딩 소스의 우선순위:
- 'hf-hub'이 설정되어 있고 Hugging Face hub이 사용 가능하면, 해당 hub에서 가중치를 로드합니다.
- 직접적인 state_dict가 전달되었다면, 해당 state_dict를 사용합니다.
- 파일 경로가 지정되었다면, 해당 파일을 사용합니다.
- 위의 조건들이 모두 해당되지 않으면, 이전 버전의 캐시를 사용할 수 있을 경우 사용하고, Hugging Face hub이 사용 가능하면 해당 hub에서 가중치를 로드합니다.
- 최종적으로 URL이 지정되어 있다면 해당 URL에서 가중치를 로드합니다.

 

// assert isinstance(pretrained_loc, dict)

이 코드 라인은 pretrained_loc이 dict 자료형인지 확인하는데 사용됩니다. assert 문은 주어진 조건이 True가 아니면 AssertionError를 발생시키며, 이는 디버깅 및 코드 검증 목적으로 사용됩니다.

여기서 pretrained_loc은 이전에 설정된 가중치의 경로나 딕셔너리일 수 있습니다. 코드는 pretrained_sd 변수에 직접적인 state_dict가 전달된 경우 해당하는지 확인하고, 그렇지 않은 경우 다른 경로나 URL을 찾기 위해 사용됩니다. 이때 pretrained_loc이 딕셔너리인지 확인하는 것은 예외 상황을 처리하기 위함입니다.

//


`set_pretrained_download_progress` 및 `set_pretrained_check_hash` 함수:
- `set_pretrained_download_progress`: 사전 훈련된 가중치 다운로드 중에 진행 상황을 표시할지 여부를 설정합니다.
- `set_pretrained_check_hash`: 사전 훈련된 가중치 다운로드 후 해시를 확인할지 여부를 설정합니다.

이 함수들은 전역 변수 `_DOWNLOAD_PROGRESS` 및 `_CHECK_HASH`를 변경하여 사전 훈련된 가중치의 다운로드 및 검증 옵션을 전역적으로 제어합니다.

`_USE_OLD_CACHE`:
이 변수는 이전 버전의 캐시를 사용할지 여부를 결정합니다. 환경 변수 `TIMM_USE_OLD_CACHE`를 통해 설정할 수 있으며, 이전 버전의 캐시가 사용 가능한 경우 (`_USE_OLD_CACHE`가 True인 경우), 이전 버전의 캐시를 우선시하게 됩니다.

이 모든 함수와 변수는 Timm 라이브러리에서 모델을 효과적으로 관리하고 사전 훈련된 가중치를 로드할 때 사용

 

def load_custom_pretrained(
        model: nn.Module,
        pretrained_cfg: Optional[Dict] = None,
        load_fn: Optional[Callable] = None,
):
    r"""Loads a custom (read non .pth) weight file

    Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls
    a passed in custom load fun, or the `load_pretrained` model member fn.

    If the object is already present in `model_dir`, it's deserialized and returned.
    The default value of `model_dir` is ``<hub_dir>/checkpoints`` where
    `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`.

    Args:
        model: The instantiated model to load weights into
        pretrained_cfg (dict): Default pretrained model cfg
        load_fn: An external standalone fn that loads weights into provided model, otherwise a fn named
            'laod_pretrained' on the model will be called if it exists
    """
    pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
    if not pretrained_cfg:
        _logger.warning("Invalid pretrained config, cannot load weights.")
        return

    load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg)
    if not load_from:
        _logger.warning("No pretrained weights exist for this model. Using random initialization.")
        return
    if load_from == 'hf-hub':
        _logger.warning("Hugging Face hub not currently supported for custom load pretrained models.")
    elif load_from == 'url':
        pretrained_loc = download_cached_file(
            pretrained_loc,
            check_hash=_CHECK_HASH,
            progress=_DOWNLOAD_PROGRESS,
        )

    if load_fn is not None:
        load_fn(model, pretrained_loc)
    elif hasattr(model, 'load_pretrained'):
        model.load_pretrained(pretrained_loc)
    else:
        _logger.warning("Valid function to load pretrained weights is not available, using random initialization.")

 

더보기

`load_custom_pretrained` 함수는 사용자 정의(pretrained된 .pth 파일이 아닌 다른 형식의) 가중치 파일을 모델에 로드하는 역할을 합니다. 이 함수는 주어진 모델에 대한 가중치를 로드하는 과정을 담당하며, 이를 위해 사용자가 제공한 로드 함수(`load_fn`)나 모델 자체에 정의된 `load_pretrained` 메서드를 호출합니다.

여러 인자와 동작이 있습니다:
- `model`: 가중치를 로드할 대상이 되는 인스턴스화된 모델입니다.
- `pretrained_cfg`: 모델의 기본 사전 훈련된 가중치 설정입니다. 이는 모델이나 해당 가중치의 기본 설정을 나타내는 딕셔너리로 생각할 수 있습니다.
- `load_fn`: 사용자 정의 로드 함수입니다. 만약 이 함수가 주어지면, 사용자가 정의한 함수를 사용하여 가중치를 로드합니다. 그렇지 않으면 모델이나 모델의 `load_pretrained` 메서드를 사용합니다.

코드는 먼저 주어진 모델과 관련된 사전 훈련된 가중치 설정(`pretrained_cfg`)을 확인합니다. 

설정이 없는 경우 경고를 출력하고 함수를 종료합니다. 

그런 다음 `_resolve_pretrained_source` 함수를 사용하여 실제로 가중치를 어디서 로드할지를 결정합니다.

로드할 위치(`load_from`)에 따라 가중치를 다운로드하거나(Hugging Face hub이나 URL일 경우), 이미 캐시된 파일이 있다면 캐시된 파일을 사용합니다. 

로드 함수(`load_fn`)가 제공되었다면 해당 함수를 사용하여 가중치를 로드하고, 

그렇지 않으면 모델의 `load_pretrained` 메서드를 사용하여 가중치를 로드합니다. 

 

만약 이 두 가지 방법 중 어느 것도 사용할 수 없다면 경고를 출력하고 무작위 초기화로 가중치를 설정합니다.

 

def load_pretrained(

더보기

`load_pretrained` 함수는 모델에 대한 사전 훈련된 가중치를 로드하는 역할을 합니다. 여러 인자와 동작이 있습니다:

- `model`: 가중치를 로드할 대상이 되는 PyTorch 모델입니다.
- `pretrained_cfg`: 모델의 기본 사전 훈련된 가중치 설정입니다. 이는 모델이나 해당 가중치의 기본 설정을 나타내는 딕셔너리로 생각할 수 있습니다.
- `num_classes`: 대상 모델의 클래스 수입니다.
- `in_chans`: 대상 모델의 입력 채널 수입니다.
- `filter_fn`: state_dict를 필터링하는 사용자 정의 함수입니다.
- `strict`: 가중치 로드 시 strict 모드 여부를 결정하는 플래그입니다.

코드는 먼저 주어진 모델과 관련된 사전 훈련된 가중치 설정(`pretrained_cfg`)을 확인합니다. 설정이 없는 경우 런타임 오류를 발생시킵니다.

그런 다음 `_resolve_pretrained_source` 함수를 사용하여 실제로 가중치를 어디서 로드할지를 결정합니다. 로드할 위치(`load_from`)에 따라 가중치를 다양한 방법으로 로드합니다:

- `state_dict`: 이미 로드된 state_dict를 사용합니다.
- `file`: 파일에서 가중치를 로드합니다.
- `url`: URL에서 가중치를 다운로드하고 로드합니다.
- `hf-hub`: Hugging Face hub에서 가중치를 다운로드하고 로드합니다.

로드된 state_dict를 `model.load_state_dict`를 사용하여 모델에 적용합니다. 로드가 완료된 후에는 몇 가지 추가적인 작업이 수행됩니다:

- `filter_fn`이 제공된 경우 state_dict를 필터링합니다.
- `input_convs`가 설정되어 있고 입력 채널이 3이 아닌 경우, 입력 컨볼루션의 가중치를 조정합니다.
- `classifiers`가 설정되어 있고 클래스 수가 사전 훈련된 가중치와 일치하지 않는 경우, 분류기의 가중치를 제거합니다.
- `label_offset`이 설정된 경우, 가중치에서 불필요한 클래스에 대한 부분을 제거합니다.

마지막으로 로드된 가중치에 대한 정보 및 경고 메시지가 출력됩니다.

 

def pretrained_cfg_for_features(pretrained_cfg):
    pretrained_cfg = deepcopy(pretrained_cfg)
    # remove default pretrained cfg fields that don't have much relevance for feature backbone
    to_remove = ('num_classes', 'classifier', 'global_pool')  # add default final pool size?
    for tr in to_remove:
        pretrained_cfg.pop(tr, None)
    return pretrained_cfg


def _filter_kwargs(kwargs, names):
    if not kwargs or not names:
        return
    for n in names:
        kwargs.pop(n, None)

 

더보기

1. `pretrained_cfg_for_features` 함수는 특징 추출 부분에 사용할 사전 훈련된 가중치 설정을 생성한다. 주어진 설정에서 특징 추출에 필요하지 않은 몇 가지 기본 필드를 제거한 후, 변경된 설정을 반환한다.

2. `pretrained_cfg`를 `deepcopy`하여 새로운 객체를 생성한다.

3. `to_remove`에는 특징 추출과 관련이 적은 몇 가지 기본 사전 훈련된 가중치 설정 필드가 포함되어 있다. 현재는 'num_classes' (클래스 수), 'classifier' (분류기), 'global_pool' (전역 풀링)이 제거 대상으로 설정되어 있다. 'global_pool' 필드를 제거할 때 'add default final pool size?' 주석이 남아있다.

4. `to_remove`에 있는 각 필드에 대해 반복문을 통해 `pretrained_cfg`에서 해당 필드를 제거한다.

5. 최종적으로 변경된 `pretrained_cfg`를 반환한다.

6. `_filter_kwargs` 함수는 주어진 인자 딕셔너리에서 특정한 키들을 필터링하는 함수이다.

7. 만약 `kwargs`가 비어있거나 `names`가 비어있다면 아무 작업도 수행하지 않고 함수를 종료한다.

8. 그렇지 않은 경우, `names` 리스트에 있는 각 키를 `kwargs`에서 제거한다.

 

def _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter):
    """ Update the default_cfg and kwargs before passing to model

    Args:
        pretrained_cfg: input pretrained cfg (updated in-place)
        kwargs: keyword args passed to model build fn (updated in-place)
        kwargs_filter: keyword arg keys that must be removed before model __init__
    """
    # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
    default_kwarg_names = ('num_classes', 'global_pool', 'in_chans')
    if pretrained_cfg.get('fixed_input_size', False):
        # if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size
        default_kwarg_names += ('img_size',)

    for n in default_kwarg_names:
        # for legacy reasons, model __init__args uses img_size + in_chans as separate args while
        # pretrained_cfg has one input_size=(C, H ,W) entry
        if n == 'img_size':
            input_size = pretrained_cfg.get('input_size', None)
            if input_size is not None:
                assert len(input_size) == 3
                kwargs.setdefault(n, input_size[-2:])
        elif n == 'in_chans':
            input_size = pretrained_cfg.get('input_size', None)
            if input_size is not None:
                assert len(input_size) == 3
                kwargs.setdefault(n, input_size[0])
        elif n == 'num_classes':
            default_val = pretrained_cfg.get(n, None)
            # if default is < 0, don't pass through to model
            if default_val is not None and default_val >= 0:
                kwargs.setdefault(n, pretrained_cfg[n])
        else:
            default_val = pretrained_cfg.get(n, None)
            if default_val is not None:
                kwargs.setdefault(n, pretrained_cfg[n])

    # Filter keyword args for task specific model variants (some 'features only' models, etc.)
    _filter_kwargs(kwargs, names=kwargs_filter)
더보기

1. `_update_default_model_kwargs` 함수는 모델을 빌드하기 전에 `pretrained_cfg` 및 `kwargs`를 업데이트하는 함수이다.

2. `default_kwarg_names`에는 모델의 `__init__` 메서드에 전달되는 몇 가지 기본 인수 이름이 포함되어 있다. ('num_classes', 'global_pool', 'in_chans')

3. 만약 `pretrained_cfg`에서 'fixed_input_size'가 True로 설정되어 있다면, 모델은 입력 크기를 고정시키기 위해 'img_size' 인수를 사용한다. 따라서 'img_size'도 `default_kwarg_names`에 추가된다.

4. 반복문을 통해 `default_kwarg_names`에 있는 각 인수에 대해 다음 작업을 수행한다.
   - 'img_size': 만약 'img_size'가 `pretrained_cfg`에 존재하고, 이미 `kwargs`에 설정되어 있지 않으면, `pretrained_cfg`에서 'input_size'를 가져와서 두 번째 및 세 번째 요소를 'img_size'로 설정한다.
   - 'in_chans': 만약 'in_chans'가 `pretrained_cfg`에 존재하고, 이미 `kwargs`에 설정되어 있지 않으면, `pretrained_cfg`에서 'input_size'를 가져와서 첫 번째 요소를 'in_chans'로 설정한다.
   - 'num_classes': 만약 'num_classes'가 `pretrained_cfg`에 존재하고, 이미 `kwargs`에 설정되어 있지 않으면, `pretrained_cfg`에서 'num_classes'를 가져와서 'num_classes'로 설정한다.
   - 나머지 경우에는 `pretrained_cfg`에서 해당 인수를 가져와서 `kwargs`에 설정한다.

5. `kwargs_filter`에 있는 키들을 사용하여 `kwargs`를 필터링한다.

 

//

`_filter_kwargs` 함수는 주어진 `kwargs`에서 특정한 키워드 인수를 제거하는 역할을 합니다. 여러 키워드 인수를 담은 딕셔너리인 `kwargs`와 제거하고자 하는 키워드의 리스트 `names`를 받아서 해당하는 키워드를 딕셔너리에서 제거합니다.

여기에 `_filter_kwargs` 함수의 역할을 좀 더 자세히 설명합니다.

```python
def _filter_kwargs(kwargs, names):
    if not kwargs or not names:
        return
    for n in names:
        kwargs.pop(n, None)
```

1. `kwargs` 또는 `names`가 비어있다면 함수 실행을 종료합니다.
2. 주어진 `names` 리스트에 있는 각 키워드 `n`에 대해 `kwargs`에서 해당 키워드를 제거합니다.

예를 들어, 만약 `kwargs`가 `{'a': 1, 'b': 2, 'c': 3}`이고 `names`가 `['b', 'c']`이면, 함수 실행 후 `kwargs`는 `{'a': 1}`이 됩니다. 함수는 주어진 `names`에 있는 모든 키워드를 제거합니다.

//


6. 최종적으로 업데이트된 `pretrained_cfg`와 `kwargs`를 사용하여 모델을 빌드할 준비를 마친다.

```python
# 사용 예시:

# pretrained_cfg와 kwargs가 어딘가에 정의되어 있다고 가정합니다.

# 모델에 전달하기 전에 default_cfg 및 kwargs를 업데이트합니다.
_update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter=['custom_key'])

# 이제 업데이트된 pretrained_cfg 및 kwargs를 사용하여 모델을 빌드할 수 있습니다.
model = build_model_with_cfg(**kwargs)
```

code snippet은 `_update_default_model_kwargs` 함수를 사용하는 방법을 보여줍니다.

default 구성 (`pretrained_cfg`) 및 키워드 인수 (`kwargs`)를 업데이트한 후 수정된 구성으로 모델을 만들려면 `build_model_with_cfg` 함수에 전달할 수 있습니다.

`kwargs_filter` 매개변수를 사용하면 필요한 경우 키워드 인수에서 특정 키를 걸러낼 수 있습니다.

 

def resolve_pretrained_cfg(
        variant: str,
        pretrained_cfg=None,
        pretrained_cfg_overlay=None,
) -> PretrainedCfg:
    model_with_tag = variant
    pretrained_tag = None
    if pretrained_cfg:
        if isinstance(pretrained_cfg, dict):
            # pretrained_cfg dict passed as arg, validate by converting to PretrainedCfg
            pretrained_cfg = PretrainedCfg(**pretrained_cfg)
        elif isinstance(pretrained_cfg, str):
            pretrained_tag = pretrained_cfg
            pretrained_cfg = None

    # fallback to looking up pretrained cfg in model registry by variant identifier
    if not pretrained_cfg:
        if pretrained_tag:
            model_with_tag = '.'.join([variant, pretrained_tag])
        pretrained_cfg = get_pretrained_cfg(model_with_tag)

    if not pretrained_cfg:
        _logger.warning(
            f"No pretrained configuration specified for {model_with_tag} model. Using a default."
            f" Please add a config to the model pretrained_cfg registry or pass explicitly.")
        pretrained_cfg = PretrainedCfg()  # instance with defaults

    pretrained_cfg_overlay = pretrained_cfg_overlay or {}
    if not pretrained_cfg.architecture:
        pretrained_cfg_overlay.setdefault('architecture', variant)
    pretrained_cfg = dataclasses.replace(pretrained_cfg, **pretrained_cfg_overlay)

    return pretrained_cfg
더보기

이 코드는 주어진 모델 `variant`에 대한 사전 훈련된 모델 설정을 해결하는 함수인 `resolve_pretrained_cfg`입니다. 이 함수는 세 가지 입력 매개변수를 받습니다.

1. `variant` (str): 모델의 변형 또는 식별자입니다.
2. `pretrained_cfg` (optional): 미리 정의된 사전 훈련된 설정이 포함된 객체 또는 설정의 식별자입니다.
3. `pretrained_cfg_overlay` (optional): 추가적인 설정을 덮어쓰기 위한 딕셔너리입니다.

함수의 반환값은 `PretrainedCfg` 형식의 객체입니다. 이 객체는 사전 훈련된 모델을 구성하는 데 사용되는 설정을 포함합니다.

함수의 동작을 간단히 설명하면 다음과 같습니다.

1. `pretrained_cfg`가 주어진 경우:
   - 만약 `pretrained_cfg`가 딕셔너리인 경우, 이를 `PretrainedCfg` 클래스의 인스턴스로 변환합니다.
   - `pretrained_cfg`가 문자열인 경우, `pretrained_tag` 변수에 저장하고 `pretrained_cfg`를 `None`으로 설정합니다.

2. `pretrained_cfg`가 주어지지 않은 경우:
   - `pretrained_tag`가 있으면 `model_with_tag`에 `variant`와 `pretrained_tag`를 결합하여 새로운 식별자를 생성합니다.
   - `get_pretrained_cfg` 함수를 사용하여 모델 레지스트리에서 `model_with_tag`에 해당하는 사전 훈련된 설정을 가져옵니다.
   - 가져온 설정이 없으면 경고를 출력하고 기본 설정을 사용합니다.

3. `pretrained_cfg_overlay`가 주어진 경우:
   - `pretrained_cfg`의 일부 설정을 `pretrained_cfg_overlay`로 덮어씁니다.

4. 최종적으로 수정된 `pretrained_cfg`를 반환합니다.

다음은 예시 사용법입니다:

```python
# Example usage:

# Assuming variant, pretrained_cfg, and pretrained_cfg_overlay are defined somewhere

# Resolve the pretrained configuration
resolved_pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg, pretrained_cfg_overlay)

# Now you can use the resolved_pretrained_cfg to build or load the model
model = build_model_with_cfg(pretrained_cfg=resolved_pretrained_cfg, **other_kwargs)
```

이 코드 스니펫은 `resolve_pretrained_cfg` 함수를 사용하여 모델의 사전 훈련된 설정을 해결한 후, 해당 설정을 사용하여 모델을 구축하거나 로드하는 방법을 보여줍니다.

 

def build_model_with_cfg(
        model_cls: Callable,
        variant: str,
        pretrained: bool,
        pretrained_cfg: Optional[Dict] = None,
        pretrained_cfg_overlay: Optional[Dict] = None,
        model_cfg: Optional[Any] = None,
        feature_cfg: Optional[Dict] = None,
        pretrained_strict: bool = True,
        pretrained_filter_fn: Optional[Callable] = None,
        kwargs_filter: Optional[Tuple[str]] = None,
        **kwargs,
):
    """ Build model with specified default_cfg and optional model_cfg

    This helper fn aids in the construction of a model including:
      * handling default_cfg and associated pretrained weight loading
      * passing through optional model_cfg for models with config based arch spec
      * features_only model adaptation
      * pruning config / model adaptation

    Args:
        model_cls (nn.Module): model class
        variant (str): model variant name
        pretrained (bool): load pretrained weights
        pretrained_cfg (dict): model's pretrained weight/task config
        model_cfg (Optional[Dict]): model's architecture config
        feature_cfg (Optional[Dict]: feature extraction adapter config
        pretrained_strict (bool): load pretrained weights strictly
        pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights
        kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model
        **kwargs: model args passed through to model __init__
    """
    pruned = kwargs.pop('pruned', False)
    features = False
    feature_cfg = feature_cfg or {}

    # resolve and update model pretrained config and model kwargs
    pretrained_cfg = resolve_pretrained_cfg(
        variant,
        pretrained_cfg=pretrained_cfg,
        pretrained_cfg_overlay=pretrained_cfg_overlay
    )

    # FIXME converting back to dict, PretrainedCfg use should be propagated further, but not into model
    pretrained_cfg = pretrained_cfg.to_dict()

    _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter)

    # Setup for feature extraction wrapper done at end of this fn
    if kwargs.pop('features_only', False):
        features = True
        feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
        if 'out_indices' in kwargs:
            feature_cfg['out_indices'] = kwargs.pop('out_indices')

    # Instantiate the model
    if model_cfg is None:
        model = model_cls(**kwargs)
    else:
        model = model_cls(cfg=model_cfg, **kwargs)
    model.pretrained_cfg = pretrained_cfg
    model.default_cfg = model.pretrained_cfg  # alias for backwards compat

    if pruned:
        model = adapt_model_from_file(model, variant)

    # For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
    num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
    if pretrained:
        load_pretrained(
            model,
            pretrained_cfg=pretrained_cfg,
            num_classes=num_classes_pretrained,
            in_chans=kwargs.get('in_chans', 3),
            filter_fn=pretrained_filter_fn,
            strict=pretrained_strict,
        )

    # Wrap the model in a feature extraction module if enabled
    if features:
        feature_cls = FeatureListNet
        output_fmt = getattr(model, 'output_fmt', None)
        if output_fmt is not None:
            feature_cfg.setdefault('output_fmt', output_fmt)
        if 'feature_cls' in feature_cfg:
            feature_cls = feature_cfg.pop('feature_cls')
            if isinstance(feature_cls, str):
                feature_cls = feature_cls.lower()
                if 'hook' in feature_cls:
                    feature_cls = FeatureHookNet
                elif feature_cls == 'fx':
                    feature_cls = FeatureGraphNet
                else:
                    assert False, f'Unknown feature class {feature_cls}'
        model = feature_cls(model, **feature_cfg)
        model.pretrained_cfg = pretrained_cfg_for_features(pretrained_cfg)  # add back pretrained cfg
        model.default_cfg = model.pretrained_cfg  # alias for rename backwards compat (default_cfg -> pretrained_cfg)

    return model
더보기

이 코드는 주어진 모델 클래스(`model_cls`) 및 관련 설정으로 모델을 구축하는 함수인 `build_model_with_cfg`입니다. 이 함수는 다양한 매개변수를 받아 모델을 구성하고, 필요한 경우 사전 훈련된 가중치를 로드하며, 특징 추출 래퍼를 적용하는 등 다양한 작업을 수행합니다. 주요 매개변수와 작동 방식을 설명하겠습니다.

1. `model_cls` (nn.Module): 모델 클래스입니다.
2. `variant` (str): 모델의 변형 또는 식별자입니다.
3. `pretrained` (bool): 사전 훈련된 가중치를 로드할지 여부를 나타내는 플래그입니다.
4. `pretrained_cfg` (Optional[Dict]): 모델의 사전 훈련된 가중치 및 작업 구성에 대한 설정입니다.
5. `pretrained_cfg_overlay` (Optional[Dict]): `pretrained_cfg`를 덮어쓰기 위한 추가 설정입니다.
6. `model_cfg` (Optional[Dict]): 모델의 아키텍처 설정입니다.
7. `feature_cfg` (Optional[Dict]): 특징 추출 어댑터의 설정입니다.
8. `pretrained_strict` (bool): 사전 훈련된 가중치를 엄격하게 로드할지 여부를 나타내는 플래그입니다.
9. `pretrained_filter_fn` (Optional[Callable]): 사전 훈련된 가중치에 대한 필터링 함수입니다.
10. `kwargs_filter` (Optional[Tuple[str]]): 모델 생성에 사용되는 키워드 인수를 필터링하는데 사용됩니다.
11. `**kwargs`: 모델의 `__init__` 메서드에 전달되는 기타 키워드 인수입니다.

이 함수의 주요 작동 방식:

- `pruned` 변수가 있는 경우 모델이 가지는 특성을 나타냅니다.
- `pretrained_cfg`를 해결하고 모델의 기본 설정 및 모델 키워드를 업데이트합니다.
- `model_cfg`가 주어진 경우 해당 구성을 사용하여 모델을 인스턴스화합니다.
- `pruned`가 True인 경우 모델을 파일에서 적응시킵니다.
- `pretrained`가 True인 경우 `load_pretrained` 함수를 사용하여 모델에 사전 훈련된 가중치를 로드합니다.
- `features_only`가 True인 경우 특징만 추출할 수 있도록 모델을 래핑합니다.
- 최종적으로 구성된 모델을 반환합니다.

이 함수는 다양한 모델 구성 및 로딩 작업을 통합하여 사용자가 편리하게 모델을 생성하고 로드할 수 있도록 도와주는 유틸리티 함수입니다.