python
파이토치 모델 입력값 확인하기 model.parameters
타닥타닥 토다토닥 부부
2023. 1. 5. 23:35
반응형
파이써 파이토치 모델의 input size 확인 하기 pytorch model parameters
In [1]:
import torch
import torch.nn as nn
In [2]:
#예시 모델 만들기
# Module 객체를 상속받는 형식으로 torch 모델을 생성합니다.
class ModelExample(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(64, 256)
self.fc2 = nn.Linear(256, 256)
self.fc3 = nn.Linear(256, 20)
def forward(self, x):
out = self.fc1(x)
out = F.relu(out)
out = self.fc2(out)
out = F.relu(out)
out = self.fc3(out)
return out
model = ModelExample()
- 모델의 input shape 확인하기
In [3]:
input_parameter = next(model.parameters())
input_parameter.size()
Out[3]:
torch.Size([256, 64])
- 파이토치에서 제공하는 RNN 모델을 사용하여 input shape를 확인합니다.
In [4]:
# rnn 모델 생성
rnn = nn.RNN(input_size=1, hidden_size=10)
In [5]:
# input shape 확인
input_shape = next(rnn.parameters())
input_shape.size()
Out[5]:
torch.Size([10, 1])
In [6]:
# input size가 1인 샘플데이터 적용해보기
input_tensor = torch.randn(100, 1)
out, hidden = rnn(input_tensor)
In [7]:
out.shape
Out[7]:
torch.Size([100, 10])
In [8]:
hidden.shape
Out[8]:
torch.Size([1, 10])
- 깃허브, 블로그 등 모델의 특성을 이해하고 있지만 input 값의 형태를 모를 때 간단히 확인해 볼 수 있는 방법 입니다.
반응형