Skip to article frontmatterSkip to article content

Skorch

VAE Loss


class VAELoss(nn.Module):
    def __init__(self) -> None:
        super(VAELoss, self).__init__()

    def forward(self, model_output, X) -> torch.Tensor:
        """
        Comments.
        Args:
            None.
        Returns:
            None.
        """
        Xhat, mu, log_var = model_output
        KL_Divergence = - 0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).mean()
        Reconstruction_Loss = F.mse_loss(Xhat, X)

        loss = Reconstruction_Loss + KL_Divergence

        return loss