Добавил:
Katynska
Опубликованный материал нарушает ваши авторские права? Сообщите нам.
Вуз:
Предмет:
Файл:Lab4 / Lab4
.py
import torch
import matplotlib.pyplot as plt
import random
import numpy as np
# Установка seed для воспроизводимости результатов
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
# Определение исходной функции
def target_function(x):
return 2**x * torch.sin(2**-x)
# Определение функции потерь
def loss(pred, target):
squares = abs((pred - target)) # Расчет функции потерь
return squares.mean()
# Определение класса для нейронной сети
class RegressionNet(torch.nn.Module):
def __init__(self, n_hidden_neurons):
super().__init__()
self.fc1 = torch.nn.Linear(1, n_hidden_neurons)
self.act1 = torch.nn.Tanh()
self.fc2 = torch.nn.Linear(n_hidden_neurons, n_hidden_neurons)
self.act2 = torch.nn.Tanh()
self.fc3 = torch.nn.Linear(n_hidden_neurons, 1)
def forward(self, x):
x = self.fc1(x)
x = self.act1(x)
x = self.fc2(x)
x = self.act2(x)
x = self.fc3(x)
return x
# Функция предсказания с визуализацией исходной и предсказанной функций
def predict(net, x, y):
y_pred = net.forward(x) # Изменение параметров сети
plt.plot(x.numpy(), y.numpy(), 'o', label='Groud truth')
plt.plot(x.numpy(), y_pred.data.numpy(), 'o', c='r', label='Prediction');
plt.legend(loc='upper left') # Показать легенду
plt.xlabel('$x$') # Ось x
plt.ylabel('$y$') # Ось y
plt.show() # Показать график
# Подготовка данных для обучения и визуализации
x_train = torch.linspace(-10, 5, 100)
y_train = target_function(x_train)
noise = torch.randn(y_train.shape) / 20. # Создание шума
y_train = y_train + noise
x_train.unsqueeze_(1)
y_train.unsqueeze_(1)
x_validation = torch.linspace(-10, 5, 100)
y_validation = target_function(x_validation)
x_validation.unsqueeze_(1)
y_validation.unsqueeze_(1)
# Инициализация модели
net = RegressionNet(10)
optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
epoch_num = 2000 # Количество эпох для обучения
loss_history = [[0, 0] for _ in range(epoch_num)] # Хранение данных функции потерь
# Обучение нейронной сети
for epoch_index in range(epoch_num):
optimizer.zero_grad()
y_pred = net.forward(x_train)
loss_val = loss(y_pred, y_train)
# Сохранение данных о функции потерь
loss_history[epoch_index][0] = epoch_index
loss_history[epoch_index][1] = loss_val.data.numpy().tolist()
loss_val.backward() # Вычисление градиентов
optimizer.step() # Шаг оптимизации
# Визуализация результатов
predict(net, x_validation, y_validation)
plt.plot([row[0] for row in loss_history][100:], [row[1] for row in loss_history][100:], '.')
plt.title(label='Loss function')
plt.xlabel('Epoch_index')
plt.ylabel('Error')
plt.show()