반응형
pytorch nn.Parameter 파이토치 밑바닥 부터 레이어 쌓기¶
- 문득 AI 모델에서 업데이트 되는 가중치 즉 파라미터를 코드에 어떻게 반영할 수 있을까를 생각해 보았습니다.
nn.Parameter
를 사용하면 파라미터 텐서를 만들 수 있습니다- 처음에는 랜덤한 텐서를 만들지만, 옵티마이저(
optimizer
)를 거치면 사용자가 따로 계산하지 않아도 자동 업데이트 됩니다.
In [1]:
import torch
import torch.nn as nn
import numpy as np
# 랜덤한 파라미터 텐서 생성 예시
x = torch.randn(10, 3)
w = nn.Parameter(torch.randn(5, 3))
In [2]:
# linear 연산 함수(funtional.linear)를 사용할때 weight는 out_features 다음 in_features 순으로 작성해야 합니다.
y = nn.functional.linear(x,w,bias=None)
y
Out[2]:
tensor([[-0.6175, -0.1348, 1.2370, -0.0166, -1.5779], [-0.1767, -0.4421, 0.3978, 0.4268, 3.2577], [ 1.3622, 0.4835, -1.9993, 0.7709, -0.5035], [ 0.7897, 0.4182, -1.0723, 0.4259, -1.8684], [ 2.0184, 0.3768, -2.7528, 1.7204, 1.8516], [ 0.6454, 0.2382, -0.3400, 1.1125, -2.1665], [ 0.0357, 0.7990, -0.2945, -1.0159, -6.7663], [ 0.5522, -0.0957, -0.7828, 0.6194, 2.4895], [ 1.2228, 0.6646, -2.2806, -0.1283, -1.1746], [ 0.5738, 0.3911, -1.4218, -0.5718, -0.2392]], grad_fn=<MmBackward0>)
활용 예시¶
y = x·w^T
방정식을 torch 모듈로 만들기 위해Parameter
함수를 사용한 예시 코드는 아래와 같습니다.
In [3]:
# 파이토치 모델 객체 생성
class CustomLinear(nn.Module):
def __init__(self, in_features, out_features):
super(CustomLinear, self).__init__()
self.W = nn.Parameter(torch.tensor(np.random.randn(in_features, out_features), dtype=torch.float32))
def forward(self, x):
y = nn.functional.linear(x,self.W,bias=None)
return y
In [4]:
# 모델 활용
x = torch.randn(10, 3)
model = CustomLinear(5, 3)
model(x)
Out[4]:
tensor([[ 0.2684, 0.2009, -0.3103, 0.9142, -1.0912], [ 3.1719, 1.8421, -4.1588, 0.1639, -6.3709], [-0.7688, -1.0604, 0.2915, -2.8571, 3.2368], [-0.5173, 0.9207, 2.1568, 2.2048, -0.2471], [-0.6416, -0.3402, 0.8586, 1.4128, 0.3990], [ 1.3391, 2.5827, 0.4691, 0.8698, -3.0496], [-0.2431, 1.1407, 1.8947, 0.8161, 0.0721], [-0.4500, -0.6526, 0.0861, 1.1783, 0.1336], [ 1.7873, 0.5172, -2.9461, -2.6261, -1.9513], [-1.3689, -0.6517, 1.9704, 0.0671, 2.6751]], grad_fn=<MmBackward0>)
반응형
'python' 카테고리의 다른 글
python argparse True False(action="store_true") (0) | 2023.04.24 |
---|---|
브이월드 api, 주소를 활용하여 위경도 정보 가져오기 (1) | 2023.04.23 |
python tkinter, GUI 입력창 버튼 만들기 : tk.Entry, tk.Button (0) | 2023.04.19 |
파이썬 mulit image show, matplotlib, 여러장, 두장 이미지 시각화 (0) | 2023.04.18 |
파이썬 GUI 프로그램 개발 tkinter를 활용한 gui 프로그램 개발 : 윈도우 창 이름 변경 (0) | 2023.04.18 |
댓글