Skip to article frontmatter | Skip to article content
Site not loading correctly?

This may be due to an incorrect BASE_URL configuration. See the MyST Documentation for reference.

Skorch

VAE Loss


class VAELoss(nn.Module):
    """Loss for a variational auto-encoder: reconstruction MSE plus weighted KL.

    The loss is ``mse(Xhat, X) + beta * KL``, where ``KL`` is the closed-form
    Kullback-Leibler divergence between a diagonal Gaussian parameterised by
    ``mu`` / ``log_var`` and a standard normal.

    Both terms are averaged over *all* elements of their tensors (not summed
    over the latent dimension), so their relative scale depends on the input
    and latent sizes; ``beta`` lets callers rebalance them.

    Args:
        beta: Multiplier applied to the KL term. The default of ``1.0``
            reproduces the plain, unweighted sum of both terms.
    """

    def __init__(self, beta: float = 1.0) -> None:
        super().__init__()
        self.beta = beta

    def forward(self, model_output, X: torch.Tensor) -> torch.Tensor:
        """Compute the loss for one batch.

        Args:
            model_output: A 3-item sequence ``(Xhat, mu, log_var)`` as
                produced by the model: the reconstruction, and what are
                presumably the mean and log-variance of the latent
                distribution (verify against the model's ``forward``).
            X: Target tensor the reconstruction is compared against; must be
                compatible with ``Xhat`` for ``F.mse_loss``.

        Returns:
            A scalar tensor: mean squared reconstruction error plus ``beta``
            times the element-wise mean KL divergence.
        """
        Xhat, mu, log_var = model_output

        # Closed-form KL(N(mu, exp(log_var)) || N(0, 1)), averaged element-wise.
        kl_divergence = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).mean()

        # Default reduction of mse_loss is the mean over every element.
        reconstruction_loss = F.mse_loss(Xhat, X)

        return reconstruction_loss + self.beta * kl_divergence