PyTorch Lightning
Basic Model
from typing import Optional

import pytorch_lightning as pl
import torch
import torch.nn as nn
class Learner(pl.LightningModule):
    """LightningModule that trains a wrapped model with cross-entropy and Adam.

    The wrapped ``model`` is called on the inputs of each batch, the
    cross-entropy loss against the batch targets is computed, and the
    parameters are optimised with Adam at a fixed learning rate of 0.005.

    Attributes:
        settings: Copy of the module-level ``defaults`` overlaid with the
            ``settings`` passed to the constructor.
        model: The wrapped ``nn.Module``.
        c: Integer initialised to 0; not used by any method in this class.
    """

    def __init__(self, model: nn.Module, settings: Optional[dict] = None):
        """Store the model and build the effective settings.

        Args:
            model: Network to train; called as ``model(x)``.
            settings: Optional overrides applied on top of ``defaults``.
                ``None`` (the default) means "no overrides".
        """
        super().__init__()
        # NOTE(review): `defaults` is not defined in this snippet; it is
        # presumably a module-level dict defined elsewhere - confirm.
        # Merge into a fresh dict so the shared `defaults` object is never
        # mutated and settings cannot leak from one Learner to the next.
        merged = dict(defaults)
        if settings is not None:
            merged.update(settings)
        self.settings = merged
        self.model = model
        self.c = 0

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute the cross-entropy loss for one ``(x, y)`` batch.

        Args:
            batch: Pair ``(x, y)`` of model inputs and targets.
            batch_idx: Index of the batch (unused).

        Returns:
            Dict with the loss under ``'loss'`` and a ``'log'`` dict
            holding the same value under ``'train_loss'``.
        """
        x, y = batch
        y_hat = self.model(x)
        # Functional form avoids constructing a new loss module every step;
        # it uses the same defaults as nn.CrossEntropyLoss().
        loss = nn.functional.cross_entropy(y_hat, y)
        # NOTE(review): returning a 'log' key is the logging convention of
        # older PyTorch Lightning releases; newer ones use self.log(...) -
        # confirm against the installed version.
        logs = {'train_loss': loss}
        return {'loss': loss, 'log': logs}

    def configure_optimizers(self):
        """Return an Adam optimizer over the wrapped model's parameters."""
        return torch.optim.Adam(self.model.parameters(), lr=0.005)

    def train_dataloader(self):
        """Return the training data loader.

        NOTE(review): `trainloader` is not defined in this snippet; it is
        presumably a module-level DataLoader defined elsewhere - confirm.
        """
        return trainloader