def forward(self, input_data):
    flattened_data = self.flatten(input_data)
    logits = self.dense_layers(flattened_data)  # "logits" here means the raw scores fed into the final softmax
    predictions = self.softmax(logits)
    # note: if loss_fn below is nn.CrossEntropyLoss, the logits should be returned
    # directly instead, since that loss applies log-softmax internally
    return predictions
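For context, here is a minimal sketch of the module this forward method could sit in, built around the attribute names the method uses (flatten, dense_layers, softmax). The class name SimpleNet (matching the simple_net instance used later) and the layer sizes are illustrative assumptions, not taken from the original.

import torch
from torch import nn

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_layers = nn.Sequential(
            nn.Linear(28 * 28, 256),  # assumes flattened 28x28 single-channel inputs
            nn.ReLU(),
            nn.Linear(256, 10),       # assumes 10 output classes
        )
        self.softmax = nn.Softmax(dim=1)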
def train_one_epoch(model, data_loader, loss_fn, optimizer, device):
    model.train()  # switch to training mode (enables dropout, batch-norm updates)
    loss_sum = 0.
    correct = 0
    for inputs, targets in data_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        # forward pass: compute predictions and loss
        predictions = model(inputs)  # calling the model invokes forward() automatically
        loss = loss_fn(predictions, targets)
        # backpropagate the loss and update the weights
        optimizer.zero_grad()  # reset accumulated gradients
        loss.backward()        # compute gradients
        optimizer.step()       # apply the weight update
        loss_sum += loss.item()  # .item() returns the value of a one-element tensor as a standard Python number
        with torch.no_grad():
            _, predictions_indexes = torch.max(predictions, 1)  # index of the highest score per sample
            correct += torch.sum(predictions_indexes == targets).item()  # .item() keeps correct a plain int
            # or: correct += (predictions.argmax(1) == targets).type(torch.float).sum().item()

    # report epoch averages
    print(f"train loss: {loss_sum / len(data_loader):.4f}, "
          f"train accuracy: {correct / len(data_loader.dataset):.4f}")
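Since the train call further down also receives a validation loader, an evaluation pass presumably runs each epoch. Here is a minimal sketch of such a pass, mirroring train_one_epoch; the name validate_one_epoch is an assumption, not from the original.

def validate_one_epoch(model, data_loader, loss_fn, device):
    model.eval()  # switch to eval mode (disables dropout, freezes batch-norm stats)
    loss_sum = 0.
    correct = 0
    with torch.no_grad():  # no gradients needed for evaluation
        for inputs, targets in data_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            predictions = model(inputs)
            loss_sum += loss_fn(predictions, targets).item()
            correct += (predictions.argmax(1) == targets).sum().item()
    print(f"val loss: {loss_sum / len(data_loader):.4f}, "
          f"val accuracy: {correct / len(data_loader.dataset):.4f}")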
# create data loaders for the train and validation sets
train_data_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)  # shuffle training batches each epoch
val_data_loader = DataLoader(val_data, batch_size=BATCH_SIZE)
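Before training, it can help to draw one batch and check its shapes; the exact shapes depend on the dataset.

inputs, targets = next(iter(train_data_loader))
print(inputs.shape, targets.shape)  # the first dimension of each should be BATCH_SIZE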
# train the model
train(simple_net, train_data_loader, val_data_loader, loss_fn, optimizer, device, EPOCHS)
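The train function itself is not defined in this section; here is a minimal sketch consistent with the call above, assuming it simply alternates the two per-epoch passes.

def train(model, train_loader, val_loader, loss_fn, optimizer, device, epochs):
    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")
        train_one_epoch(model, train_loader, loss_fn, optimizer, device)
        validate_one_epoch(model, val_loader, loss_fn, device)
    print("Training finished")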
# save the model
torch.save(simple_net.state_dict(), "simple_net.pth")
print("Model saved")

# other ways to save:
# torch.save(model.state_dict(), "my_model.pth")  # save only the parameters (recommended)
# torch.save(model, "my_model.pth")               # save the whole model (pickles the class itself)
# checkpoint = {"net": model.state_dict(), "optimizer": optimizer.state_dict(), "epoch": epoch}
# torch.save(checkpoint, "checkpoint.pth")        # save a full training checkpoint for resuming
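To use the saved weights later, the state dict is loaded back into a freshly constructed model of the same architecture; SimpleNet here stands in for whatever class simple_net is an instance of.

# restore parameters into a new instance of the same architecture
model = SimpleNet()
model.load_state_dict(torch.load("simple_net.pth"))
model.eval()  # switch to eval mode before running inference

# restoring a full training checkpoint saved as a dict
# checkpoint = torch.load("checkpoint.pth")
# model.load_state_dict(checkpoint["net"])
# optimizer.load_state_dict(checkpoint["optimizer"])
# start_epoch = checkpoint["epoch"] + 1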