반응형
[파이썬, 파이토치] 모델 커스텀 : pytorch model custom¶
In [1]:
# 샘플 모델을 생성합니다
from torch import nn
class OriginNetwork(nn.Module):
def __init__(self, n_labels:int):
super(OriginNetwork, self).__init__()
self.layer1 = nn.Sequential(
nn.Conv2d(1, 32, 3, 1, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(2)
)
self.layer2 = nn.Sequential(
nn.Conv2d(32, 16, 2, 1, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(2)
)
self.flatten = nn.Flatten()
self.fc = nn.Linear(16*7*7, n_labels)
def forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
x = self.flatten(x)
x = self.fc(x)
return x
In [2]:
model = OriginNetwork(10)
model
Out[2]:
OriginNetwork( (layer1): Sequential( (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU(inplace=True) (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (layer2): Sequential( (0): Conv2d(32, 16, kernel_size=(2, 2), stride=(1, 1), padding=(1, 1)) (1): ReLU(inplace=True) (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (flatten): Flatten(start_dim=1, end_dim=-1) (fc): Linear(in_features=784, out_features=10, bias=True) )
In [3]:
# 샘플 모델의 fc 레이어를 하나더 추가 합니다.
class CustomNetwork(nn.Module):
def __init__(self, n_labels):
super(CustomNetwork, self).__init__()
# 기존에 모델을 불러옵니다.
model = OriginNetwork(n_labels)
self.model = vars(model)['_modules']
# 변화가 없는 레이어를 구성합니다
layers = []
self.layer_name_list = [n for n in self.model][:-1]
for name in self.layer_name_list:
layers.append(self.model[name])
self.layer1 = nn.ModuleList(layers) # nn.ModuleList를 통해
# 새로운 레이어를 추가 합니다.
self.fc1 = nn.Linear(784, 100)
self.fc2 = nn.Linear(100, n_labels)
def forward(self, x):
# 레이터의 순서를 정의 합니다.
for name in self.layer1:
x = name(x)
x = self.fc1(x)
x = self.fc2(x)
return x
In [4]:
model = CustomNetwork(10)
model
Out[4]:
CustomNetwork( (layer1): ModuleList( (0): Sequential( (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU(inplace=True) (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (1): Sequential( (0): Conv2d(32, 16, kernel_size=(2, 2), stride=(1, 1), padding=(1, 1)) (1): ReLU(inplace=True) (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (2): Flatten(start_dim=1, end_dim=-1) ) (fc1): Linear(in_features=784, out_features=100, bias=True) (fc2): Linear(in_features=100, out_features=10, bias=True) )
반응형
'python' 카테고리의 다른 글
ConversionError: Failed to convert value(s) to axis units: (0) | 2023.08.26 |
---|---|
argparse 복붙용 기본 코드 (0) | 2023.08.20 |
딕셔너리 값이 최대값인 키 도출 (파이썬) (0) | 2023.08.19 |
파이토치 데이터 셔플 ranperm (0) | 2023.08.19 |
파이썬 데코레이터 간편 설명 (0) | 2023.08.18 |
댓글