Naive NerFs#

Neural Fields (NerFs) are an emerging class of coordinate-based neural networks. There have been many developments in the last few years in applying NerFs to data such as images. In this tutorial, I will introduce NerFs from the geoscience perspective and highlight some potential advantages of using these methods. I will demonstrate some concrete work on sea surface height interpolation and highlight some of the problems (and potential solutions) I faced when applying this class of methods to spatiotemporal data.

import autoroot
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jax.random as jrandom
import optax
import numpy as np
import pandas as pd
import equinox as eqx
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm, trange
from jaxtyping import Float, Array, PyTree, ArrayLike
import wandb
from dataclasses import dataclass
import hydra
from omegaconf import OmegaConf

jax.config.update("jax_enable_x64", False)  # keep JAX arrays in 32-bit precision

# Plot styling: start from seaborn's defaults, then a shrunken "talk" context.
sns.reset_defaults()
sns.set_context("talk", font_scale=0.7)

%env XLA_PYTHON_CLIENT_PREALLOCATE=false

%matplotlib inline
%load_ext autoreload
%autoreload 2

Coordinate-Based Models#

\[ \mathbf{x}\in\mathbb{R}^D \]

Example I: Time Series

\[\begin{split} \begin{aligned} \text{Input}: && t &\in \mathbb{R}^+ \\ \text{Output}: && f &: \mathbb{R}\rightarrow\mathbb{R}^U \end{aligned} \end{split}\]

Example II: Images

\[ x,y \in \mathbb{R} \]

Recap Formulation#

We are interested in learning non-linear functions \(\boldsymbol{f}\).

\[ \begin{aligned} \boldsymbol{f}(\mathbf{x}) &= \mathbf{w}^\top\boldsymbol{\phi}(\mathbf{x})+\mathbf{b} \end{aligned} \]

where \(\boldsymbol{\phi}(\cdot)\) is a basis function. Neural Fields typically try to learn this basis function via a series of composite functions of the form

\[ \boldsymbol{\phi}(\mathbf{x}) = \boldsymbol{\phi}_L\circ\boldsymbol{\phi}_{L-1} \circ\cdots\circ \boldsymbol{\phi}_2\circ\boldsymbol{\phi}_{1}(\mathbf{x}) \]

Problems#

Here, we will demonstrate a problem that a naive network has.

Data#

from dataclasses import dataclass


@dataclass()
class FoxDM:
    """Hydra structured config for the ``ImageFox`` datamodule.

    ``_target_`` names the class that ``hydra.utils.instantiate`` builds; the
    remaining fields are forwarded to it as keyword arguments.
    """

    _target_: str = "jejeqx._src.datamodules.image.ImageFox"
    # Loader settings.
    batch_size: int = 10000
    # NOTE(review): presumably a downsampling factor for the image — confirm.
    resize: int = 2
    shuffle: bool = False
    split_method: str = "even"
    # Absolute, machine-specific location of the source image.
    image_url: str = "/gpfswork/rech/cli/uvo53rl/projects/jejeqx/data/images/fox.jpg"


# Turn the FoxDM dataclass into an OmegaConf config, let Hydra instantiate the
# class named by its ``_target_``, then run the datamodule's setup step.
config_dm = OmegaConf.structured(FoxDM())
dm = hydra.utils.instantiate(config_dm)
dm.setup()


# First 32 training samples, unpacked into inputs/targets; used below for
# quick output-shape checks of each model.
init = dm.ds_train[:32]
x_init, y_init = init
img = dm.load_image()
# Displayed as the cell output.
img.shape

import os

# Show the ground-truth image without ticks and export it for the slides.
fig, ax = plt.subplots()

ax.imshow(img)
ax.set(xticks=[], yticks=[])
plt.tight_layout()
# ``savefig`` does not create missing folders; this is the first export in the
# notebook, so make sure the output directory exists.
os.makedirs("./figures/demo", exist_ok=True)
fig.savefig("./figures/demo/fox.png", dpi=100, transparent=True)
plt.show()
../../_images/3429023a44e3149f556c2a8afdeab7dab61f81b94d00ba879a3165c93937d861.png

Coordinates#

\[\begin{split} \vec{\mathbf{x}} = \begin{bmatrix} x \\ y \end{bmatrix} \end{split}\]
\[\begin{split} \mathbf{y} = \begin{bmatrix} R \\ G \\ B \end{bmatrix} \end{split}\]
# Summarise the image size and how many coordinates the training split holds.
for label, value in (
    ("Image Shape", img.shape),
    ("Number of Coords", f"{len(dm.ds_train):,}"),
):
    print(f"{label}: {value}")

Model#

The input data is a coordinate vector, \(\mathbf{x}_\phi\), of the image coordinates.

\[ \mathbf{x}_\phi \in \mathbb{R}^{D_\phi} \]

where \(D_\phi = 2\), i.e. \(\mathbf{x}_\phi = [x, y]^\top\). So we are interested in learning a function, \(\boldsymbol{f}\), such that we can input a coordinate vector and output a scalar/vector value for the pixel.

\[ \mathbf{u} = \boldsymbol{f}(\mathbf{x}_\phi; \boldsymbol{\theta}) \]

MLP Layer#

\[ \mathbf{f}_\ell(\mathbf{x}) = \sigma\left(\mathbf{w}^{(\ell)}\mathbf{x} + \mathbf{b}^{(\ell)} \right) \]

where \(\sigma\) is the ReLU activation function.

\[ \sigma(\mathbf{x}) = \text{ReLU}(\mathbf{x}) \]
from typing import Dict
from dataclasses import field


@dataclass
class Activation:
    """Hydra target for the hidden-layer activation of the MLP."""

    # _target_: str = "jejeqx._src.nets.activations.Tanh"
    _target_: str = "jejeqx._src.nets.activations.ReLU"


@dataclass
class IdentityAct:
    """Hydra target for an identity (no-op) final activation."""

    _target_: str = "equinox.nn.Identity"


@dataclass
class Key:
    """Hydra target that builds ``jax.random.PRNGKey(seed)``."""

    _target_: str = "jax.random.PRNGKey"
    seed: int = 123


@dataclass
class MLPModel:
    """Structured config for an ``equinox.nn.MLP`` mapping 2 inputs to 3 outputs."""

    _target_: str = "equinox.nn.MLP"
    in_size: int = 2
    out_size: int = 3
    width_size: int = 128
    depth: int = 5
    # Nested configs must be created through ``default_factory``. Dataclass
    # instances are unhashable, and Python 3.11+ rejects them as plain field
    # defaults (ValueError: mutable default ... use default_factory). The
    # factory also gives every MLPModel its own sub-config objects.
    activation: Activation = field(default_factory=Activation)
    final_activation: IdentityAct = field(default_factory=IdentityAct)
    key: Key = field(default_factory=lambda: Key(seed=42))
# initialize model
# Hydra builds the equinox MLP named by MLPModel._target_ from the structured config.
model_config = OmegaConf.structured(MLPModel())

model = hydra.utils.instantiate(model_config)

# Print the module tree (arrays are summarised by dtype and shape).
eqx.tree_pprint(model)
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
MLP(
  layers=(
    Linear(
      weight=f32[128,2],
      bias=f32[128],
      in_features=2,
      out_features=128,
      use_bias=True
    ),
    Linear(
      weight=f32[128,128],
      bias=f32[128],
      in_features=128,
      out_features=128,
      use_bias=True
    ),
    Linear(
      weight=f32[128,128],
      bias=f32[128],
      in_features=128,
      out_features=128,
      use_bias=True
    ),
    Linear(
      weight=f32[128,128],
      bias=f32[128],
      in_features=128,
      out_features=128,
      use_bias=True
    ),
    Linear(
      weight=f32[128,128],
      bias=f32[128],
      in_features=128,
      out_features=128,
      use_bias=True
    ),
    Linear(
      weight=f32[3,128],
      bias=f32[3],
      in_features=128,
      out_features=3,
      use_bias=True
    )
  ),
  activation=ReLU(),
  final_activation=Identity(),
  in_size=2,
  out_size=3,
  width_size=128,
  depth=5
)

Note: We have created a function that takes a vector and outputs a vector. In JAX, we don’t have to think about batches until later.

# Single-sample check: the unbatched model maps one input vector to one output vector.
x_vector = x_init[0]
y_vector = y_init[0]

out_vector = model(x_vector)  # predict
assert out_vector.shape == y_vector.shape

Now we can handle batches by auto-vectorizing the model with `vmap`.

# check output of models
# jax.vmap maps the single-sample model over the leading axis of x_init.
out = jax.vmap(model)(x_init)

assert out.shape == y_init.shape

Optimizer (+ Learning Rate)#

For this, we will use a simple adam optimizer with a learning_rate of 1e-4. Many studies suggest that a lower learning rate works well with these methods because there is a lot of data. In addition, a bigger batch_size is also desirable. We will set num_epochs to 2_000, which should be good enough for a single image. Obviously, more epochs and a better learning rate scheduler would give better results, but this will be sufficient for this demo.

import optax

import math

num_epochs = 2_000


@dataclass
class Optimizer:
    """Structured config for ``optax.adam`` with a fixed learning rate."""

    _target_: str = "optax.adam"
    learning_rate: float = 1e-4


@dataclass
class Scheduler:
    """Structured config for ``optax.warmup_cosine_decay_schedule``.

    ``decay_steps`` is not stored here; it is passed at instantiation time
    from the training length.
    """

    _target_: str = "optax.warmup_cosine_decay_schedule"
    init_value: float = 0.0
    peak_value: float = 1e0
    warmup_steps: int = 500
    end_value: float = 1e-5


use_scheduler = False
optim_config = OmegaConf.structured(Optimizer())
scheduler_config = OmegaConf.structured(Scheduler())

# initialize optimizer
optimizer = hydra.utils.instantiate(optim_config)

if use_scheduler:
    # An optax schedule advances once per optimizer update, i.e. once per
    # batch, so steps per epoch = ceil(num_samples / batch_size), not
    # len(dm.ds_train). Assumes the training loader batches with
    # config_dm.batch_size — confirm against the datamodule.
    num_steps_per_epoch = math.ceil(len(dm.ds_train) / config_dm.batch_size)

    scheduler = hydra.utils.instantiate(
        scheduler_config, decay_steps=int(num_epochs * num_steps_per_epoch)
    )

    # Scale Adam's updates (already multiplied by learning_rate) by the schedule.
    optimizer = optax.chain(optimizer, optax.scale_by_schedule(scheduler))

Scheduler#

drawing

Fig.1 - An example for learning rate reduction when the validation loss stagnates. Source: Blog

We will use a simple learning rate scheduler - reduce_lr_on_plateau. This will automatically reduce the learning rate as the validation loss stagnates. It will ensure that we really squeeze out as much performance as possible from our models during the training procedure. We start with a (relatively) high learning_rate of 1e-4, so we will set the patience to 5 epochs: if the validation loss does not improve for 5 consecutive epochs, we decrease the learning rate by a factor of 0.1. (Note: the optimizer cell above actually defines a warmup-cosine-decay schedule, and it is switched off via `use_scheduler = False`.)

This is a rather crude (but effective) method but it tends to work well in some situations. A better method might be the cosine_annealing method or the exponential_decay method. See other examples.

We are interested in the posterior of the parameters given the data, so we can use Bayes' theorem to express it.

\[ p(\boldsymbol{\theta}|\mathcal{D}) \propto p(\mathcal{D}|\boldsymbol{\theta})p(\boldsymbol{\theta}) \]

If we ignore the prior (equivalently, assume it is flat), this can be solved with the maximum likelihood method. In this case, we need to define the likelihood term for the data. We can assume a Gaussian likelihood because we are working with continuous data. To make things simple, we can also assume a constant noise level.

\[ p(y|\mathbf{x};\boldsymbol{\theta})=\mathcal{N} \left( \boldsymbol{f}(\mathbf{x};\boldsymbol{\theta}), \sigma^2 \right) \]

This is the maximum likelihood estimation problem. If we assume our samples are i.i.d., we get the following minimization problem

\[ \boldsymbol{\theta} = \underset{\boldsymbol{\theta}}{\text{argmin}} \hspace{1mm} \sum_{y,\mathbf{x}\in\mathcal{D}}-\log p(y|\mathbf{x};\boldsymbol{\theta}) \]

Notice the sleight of hand: minimizing the negative log-likelihood is the same as maximizing the log-likelihood.

If we assume a noise level of 1, i.e. \(\sigma=1\), then (up to an additive constant and a factor of \(1/2\)) this loss reduces to the mean squared error (MSE) loss function:

\[ \mathcal{L}(\boldsymbol{\theta}) = \frac{1}{|\mathcal{D}|} \sum_{n\in\mathcal{D}} \left(y_n - \boldsymbol{f_\theta}(\mathbf{x}_n)\right)^2 \]

We can take minibatches

\[ \mathcal{B}=\left\{\mathbf{x}_b,y_b \right\}^B_{b=1} \]

which is a proper subset of the dataset, \(\mathcal{B} \subset \mathcal{D}\).

So our new loss function will be:

\[ \mathcal{L}(\boldsymbol{\theta}) = \frac{1}{|\mathcal{B}|} \sum_{n\in\mathcal{B}} \left(y_n - \boldsymbol{f_\theta}(\mathbf{x}_n)\right)^2 \]

Trainer Module#

import glob
import os
from pathlib import Path

from jejeqx._src.trainers.base import TrainerModule
from jejeqx._src.trainers.callbacks import wandb_model_artifact
from jejeqx._src.losses import psnr


class RegressorTrainer(TrainerModule):
    """Trainer for coordinate -> value regression with an MSE objective.

    Supplies the train/eval/test/predict step functions that ``TrainerModule``
    uses, and saves/logs the model at the end of training when a logger is set.
    """

    def __init__(self, model, optimizer, **kwargs):
        super().__init__(model=model, optimizer=optimizer, pl_logger=None, **kwargs)

    def create_functions(self):
        """Build and return ``(train_step, eval_step, test_step, predict_step)``.

        Every step receives a batch ``(x, y)`` and maps the single-sample model
        over the leading axis of ``x`` with ``jax.vmap``.
        """

        def mse(model, batch):
            # Mean squared error over every element of the batch.
            x, y = batch
            pred = jax.vmap(model)(x)
            return jnp.mean((y - pred) ** 2)

        # Gradients are only needed for the training step; the evaluation
        # steps call ``mse`` directly and skip the backward pass.
        mse_loss = eqx.filter_value_and_grad(mse)

        def train_step(state, batch):
            loss, grads = mse_loss(state.params, batch)
            state = state.update_state(state, grads)
            psnr_loss = psnr(loss)
            metrics = {"loss": loss, "psnr": psnr_loss}
            return state, loss, metrics

        def eval_step(model, batch):
            loss = mse(model, batch)
            return {"loss": loss, "psnr": psnr(loss)}

        def test_step(model, batch):
            x, y = batch
            out = jax.vmap(model)(x)
            # Reuse the single forward pass instead of running the model again.
            loss = jnp.mean((y - out) ** 2)
            return out, {"loss": loss, "psnr": psnr(loss)}

        def predict_step(model, batch):
            x, _ = batch
            return jax.vmap(model)(x)

        return train_step, eval_step, test_step, predict_step

    def on_training_end(self):
        """Save the model and log it as a wandb artifact if a logger is attached.

        NOTE(review): ``__init__`` passes ``pl_logger=None`` to the base class,
        so this branch only runs if ``TrainerModule`` sets ``self.pl_logger``
        some other way — confirm.
        """
        if self.pl_logger:
            save_dir = Path(self.log_dir).joinpath(self.save_name)
            self.save_model(save_dir)
            wandb_model_artifact(self)
            self.pl_logger.finalize("success")
seed, debug = 123, False
enable_progress_bar, log_dir = False, "./"

trainer = RegressorTrainer(
    model, optimizer,
    seed=seed, debug=debug, enable_progress_bar=enable_progress_bar, log_dir=log_dir,
)

# Baseline: score the freshly initialised MLP on the test split.
out, metrics = trainer.test_model(dm.test_dataloader())
metrics
{'loss': 0.37384870648384094, 'psnr': 3.4702699184417725}
# Replace the random weights with the saved MLP checkpoint and re-score on the test split.
trainer.load_model("./checkpoints/checkpoint_model_mlp.ckpt")
out, metrics = trainer.test_model(dm.test_dataloader())
metrics
{'loss': 0.008850343525409698, 'psnr': 41.5156135559082}
%%time

# metrics = trainer.train_model(dm, num_epochs=num_epochs)
CPU times: user 13 µs, sys: 3 µs, total: 16 µs
Wall time: 31.5 µs
# Saving is disabled here; uncomment to write the trained MLP (and trainer state) to disk.
# trainer.save_model("./checkpoints/checkpoint_model_mlp.ckpt")
# trainer.save_state("checkpoint_state.ckpt")
out, metrics = trainer.test_model(dm.test_dataloader())
metrics
{'loss': 0.008850343525409698, 'psnr': 41.5156135559082}
# Start the comparison table with the MLP's test-set scores.
all_metrics = pd.DataFrame(
    [["mlp", metrics["loss"], metrics["psnr"]]], columns=["model", "MSE", "SNR"]
)
all_metrics
model MSE SNR
0 mlp 0.00885 41.515614
out_mlp = dm.coordinates_2_image(out)

# Side-by-side: ground truth vs. the naive MLP reconstruction.
fig, ax = plt.subplots(ncols=2, figsize=(8, 4))
for axis, (image, title) in zip(ax, [(img, "Original"), (out_mlp, "Naive MLP")]):
    axis.imshow(image)
    axis.set(title=title)

plt.tight_layout()
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
../../_images/9be9c1c17f065524561fb6dbb1ceb083d7726ce76233bfbb2921d44c9b59a25f.png
fig, ax = plt.subplots()

# Tick-free, transparent-background export of the MLP reconstruction.
ax.imshow(out_mlp)
ax.set_xticks([])
ax.set_yticks([])
fig.tight_layout()
fig.savefig("./figures/demo/mlp.png", dpi=100, transparent=True)
plt.close(fig)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

Random Fourier Features#

\[ \boldsymbol{\phi}(\mathbf{x}) = \sqrt{\frac{\sigma^2}{N_{RF}}} \left[ \cos(\boldsymbol{\Omega}\mathbf{x}), \sin(\boldsymbol{\Omega}\mathbf{x}) \right] \]

where \(\boldsymbol{\Omega}\) is a random matrix sampled from a Gaussian distribution.

\[\begin{split} \begin{aligned} \boldsymbol{\Omega}&\sim\mathcal{N}(\mathbf{0},\boldsymbol{\Lambda}^{-1}_D) \\ \boldsymbol{\Lambda}_D &= \text{diag} \left(\lambda_1, \lambda_2, \ldots, \lambda_D\right) \end{aligned} \end{split}\]

So our final neural network with the additional basis function:

\[ \boldsymbol{f}(\mathbf{x};\boldsymbol{\theta}) = \mathbf{w}^\top\boldsymbol{\phi}(\mathbf{x})+\mathbf{b} \]

where \(\boldsymbol{\phi}(\cdot)\) is the learned basis network.

from typing import Dict
from dataclasses import field


@dataclass
class RFFModel:
    """Structured config for ``jejeqx``'s ``RFFNet`` mapping 2 inputs to 3 outputs."""

    _target_: str = "jejeqx._src.nets.nerfs.ffn.RFFNet"
    in_size: int = 2
    out_size: int = 3
    width_size: int = 256
    depth: int = 5
    # NOTE(review): presumably enables per-dimension length scales (ARD) for
    # the Fourier features — confirm against RFFNet.
    ard: bool = True
    method: str = "rbf"
    # A dataclass instance is unhashable, and Python 3.11+ rejects it as a plain
    # field default; build it lazily instead.
    key: Key = field(default_factory=lambda: Key(seed=42))
# initialize model
# Hydra builds the RFFNet named by RFFModel._target_ from the structured config.
model_config = OmegaConf.structured(RFFModel())

model = hydra.utils.instantiate(model_config)

# check output of models
# Batched forward pass on the 32 init samples must match the target shape.
out = jax.vmap(model)(x_init)

assert out.shape == y_init.shape
# eqx.tree_pprint(model)
seed, debug = 123, False
enable_progress_bar, log_dir = False, "./"
num_epochs = 6_000  # longer than the 2_000 epochs set in the optimizer cell

trainer = RegressorTrainer(
    model, optimizer,
    seed=seed, debug=debug, enable_progress_bar=enable_progress_bar, log_dir=log_dir,
)

# False: reuse the saved checkpoint below instead of training further.
train_more = False

# Baseline: score the freshly initialised RFF network.
out, metrics = trainer.test_model(dm.test_dataloader())
metrics
{'loss': 0.3707599341869354, 'psnr': 3.5415871143341064}
# Load saved weights. When no further training is requested, the checkpoint is
# the only source of trained weights, so a load failure must not pass silently.
try:
    trainer.load_model("./checkpoints/checkpoint_model_rff.ckpt")
except Exception as err:
    if not train_more:
        raise RuntimeError(
            "Could not load ./checkpoints/checkpoint_model_rff.ckpt; "
            "set train_more = True to train from scratch."
        ) from err
    print(f"No usable RFF checkpoint ({err!r}); training from scratch.")
out, metrics = trainer.test_model(dm.test_dataloader())
metrics
{'loss': 0.0028392504900693893, 'psnr': 53.401878356933594}
%%time

# Train only when explicitly requested via train_more.
if train_more:
    metrics = trainer.train_model(dm, num_epochs=num_epochs)
CPU times: user 16 µs, sys: 3 µs, total: 19 µs
Wall time: 36.2 µs
# Persist freshly trained weights, then re-score on the test split.
if train_more:
    trainer.save_model("./checkpoints/checkpoint_model_rff.ckpt")
out, metrics = trainer.test_model(dm.test_dataloader())
metrics
{'loss': 0.0028392504900693893, 'psnr': 53.401878356933594}
# Append the RFF scores. ignore_index renumbers rows 0..n-1 instead of
# repeating label 0 for every appended single-row frame, which would make
# all_metrics.loc[0] return several rows.
all_metrics = pd.concat(
    [
        all_metrics,
        pd.DataFrame(
            data=[["rff", metrics["loss"], metrics["psnr"]]],
            columns=["model", "MSE", "SNR"],
        ),
    ],
    ignore_index=True,
)
all_metrics
model MSE SNR
0 mlp 0.008850 41.515614
0 rff 0.002839 53.401878
out_rff = dm.coordinates_2_image(out)
fig, ax = plt.subplots(ncols=3, figsize=(12, 4))

# Ground truth first, then the two reconstructions, left to right.
titled_images = (
    (img, "Original"),
    (out_mlp, "Naive MLP"),
    (out_rff, "Random Fourier Features"),
)
for axis, (image, title) in zip(ax, titled_images):
    axis.imshow(image)
    axis.set(title=title)

plt.tight_layout()
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
../../_images/0c7e75a5a981a72fe8967e59c7bd3bfce80bfe5a23be8d8c9496f49eb0939687.png
fig, ax = plt.subplots()

# Tick-free, transparent-background export of the RFF reconstruction.
ax.imshow(out_rff)
ax.set_xticks([])
ax.set_yticks([])
fig.tight_layout()
fig.savefig("./figures/demo/ffn.png", dpi=100, transparent=True)
plt.close(fig)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

Custom Activation Functions#

SIREN

One of the most famous methods is the SIREN method. This replaces the standard activation function, \(\sigma\), with a sinusoidal function.

\[ \phi(\mathbf{x})_\ell = \sin \left( \omega_\ell\left( \mathbf{w}_\ell\mathbf{x} + \mathbf{b}_\ell \right)\right) \]

So our final neural network with the additional basis function:

\[ \boldsymbol{f}(\mathbf{x};\boldsymbol{\theta}) = \mathbf{w}^\top\boldsymbol{\phi}(\mathbf{x})+\mathbf{b} \]

where \(\boldsymbol{\phi}(\cdot)\) is the learned basis network.

@dataclass
class Key:
    """Hydra target that builds ``jax.random.PRNGKey(seed)``."""

    _target_: str = "jax.random.PRNGKey"
    seed: int = 123


@dataclass
class SirenBasis:
    """Structured config for the SIREN basis network (2 inputs -> 256 features)."""

    _target_: str = "jejeqx._src.nets.nerfs.siren.SirenNet"
    in_size: int = 2
    out_size: int = 256
    width_size: int = 256
    depth: int = 5
    key: Key = field(default_factory=Key)


@dataclass
class LinearModel:
    """Structured config for the linear read-out (256 features -> 3 outputs).

    ``in_features`` must equal ``SirenBasis.out_size``.
    """

    _target_: str = "equinox.nn.Linear"
    in_features: int = 256
    out_features: int = 3
    use_bias: bool = True
    # NOTE: defaults to the same seed (123) as SirenBasis.key.
    key: Key = field(default_factory=Key)


@dataclass
class NerFModel:
    """Structured config composing a basis network with a linear read-out."""

    _target_: str = "jejeqx._src.nets.nerfs.base.NerF"
    # Nested configs go through default_factory: dataclass instances are
    # unhashable and Python 3.11+ rejects them as plain field defaults.
    network: LinearModel = field(default_factory=LinearModel)
    basis_net: SirenBasis = field(default_factory=SirenBasis)
from jejeqx._src.nets.nerfs.base import NerF
# initialize model
# Hydra builds the NerF (SIREN basis + linear read-out) from the structured config.
model_config = OmegaConf.structured(NerFModel())

model = hydra.utils.instantiate(model_config)

# check output of models
out = jax.vmap(model)(x_init)

# The read-out produces 3 outputs per sample, the same shape as the targets.
assert out.shape == y_init.shape
# eqx.tree_pprint(model)
out.shape
(32, 3)
seed, debug = 123, False
enable_progress_bar, log_dir = False, "./"

trainer = RegressorTrainer(
    model, optimizer,
    seed=seed, debug=debug, enable_progress_bar=enable_progress_bar, log_dir=log_dir,
)

# False: reuse the saved checkpoint below instead of training further.
train_more = False

# Baseline: score the freshly initialised SIREN network.
out, metrics = trainer.test_model(dm.test_dataloader())
metrics
{'loss': 0.47536465525627136, 'psnr': 0.8998086452484131}
# Load saved weights. When no further training is requested, the checkpoint is
# the only source of trained weights, so a load failure must not pass silently.
try:
    trainer.load_model("./checkpoints/checkpoint_model_siren.ckpt")
except Exception as err:
    if not train_more:
        raise RuntimeError(
            "Could not load ./checkpoints/checkpoint_model_siren.ckpt; "
            "set train_more = True to train from scratch."
        ) from err
    print(f"No usable SIREN checkpoint ({err!r}); training from scratch.")
out, metrics = trainer.test_model(dm.test_dataloader())
metrics
{'loss': 0.0019353348761796951, 'psnr': 57.65052795410156}
%%time

# Train only when explicitly requested via train_more.
if train_more:
    metrics = trainer.train_model(dm, num_epochs=num_epochs)
CPU times: user 15 µs, sys: 3 µs, total: 18 µs
Wall time: 34.1 µs
# Persist freshly trained weights, then re-score on the test split.
if train_more:
    trainer.save_model("./checkpoints/checkpoint_model_siren.ckpt")
out, metrics = trainer.test_model(dm.test_dataloader())
metrics
{'loss': 0.0019353348761796951, 'psnr': 57.65052795410156}
# Append the SIREN scores. ignore_index renumbers rows 0..n-1 instead of
# repeating label 0 for every appended single-row frame.
all_metrics = pd.concat(
    [
        all_metrics,
        pd.DataFrame(
            data=[["siren", metrics["loss"], metrics["psnr"]]],
            columns=["model", "MSE", "SNR"],
        ),
    ],
    ignore_index=True,
)
out_siren = dm.coordinates_2_image(out)
fig, ax = plt.subplots(ncols=2, nrows=2, figsize=(8, 8))

# Panels are filled row-major: top-left, top-right, bottom-left, bottom-right.
titled_images = [
    (img, "Original"),
    (out_mlp, "MLP"),
    (out_rff, "Fourier Features"),
    (out_siren, "Siren"),
]
for axis, (image, title) in zip(ax.flat, titled_images):
    axis.imshow(image)
    axis.set(title=title)

plt.tight_layout()
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
../../_images/e1ebd8c6cf78c3c22b0d42c9f214b50520c23dc1ccdafadd75631a4d74a0e1ea.png
fig, ax = plt.subplots()

# Tick-free, transparent-background export of the SIREN reconstruction.
ax.imshow(out_siren)
ax.set_xticks([])
ax.set_yticks([])
fig.tight_layout()
fig.savefig("./figures/demo/siren.png", dpi=100, transparent=True)
plt.close(fig)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

Multiplicative Filter Networks (MFN)#

Fourier Network#

import optax

import math

num_epochs = 2_000


@dataclass
class Optimizer:
    """Structured config for ``optax.adam`` with a fixed learning rate."""

    _target_: str = "optax.adam"
    learning_rate: float = 1e-4


@dataclass
class Scheduler:
    """Structured config for ``optax.warmup_cosine_decay_schedule``.

    ``decay_steps`` is not stored here; it is passed at instantiation time
    from the training length.
    """

    _target_: str = "optax.warmup_cosine_decay_schedule"
    init_value: float = 0.0
    peak_value: float = 1e0
    warmup_steps: int = 500
    end_value: float = 1e-5


use_scheduler = False
optim_config = OmegaConf.structured(Optimizer())
scheduler_config = OmegaConf.structured(Scheduler())

# initialize optimizer
optimizer = hydra.utils.instantiate(optim_config)

if use_scheduler:
    # An optax schedule advances once per optimizer update, i.e. once per
    # batch, so steps per epoch = ceil(num_samples / batch_size), not
    # len(dm.ds_train). Assumes the training loader batches with
    # config_dm.batch_size — confirm against the datamodule.
    num_steps_per_epoch = math.ceil(len(dm.ds_train) / config_dm.batch_size)

    scheduler = hydra.utils.instantiate(
        scheduler_config, decay_steps=int(num_epochs * num_steps_per_epoch)
    )

    # Scale Adam's updates (already multiplied by learning_rate) by the schedule.
    optimizer = optax.chain(optimizer, optax.scale_by_schedule(scheduler))
from typing import Dict
from dataclasses import field


@dataclass
class IdentityAct:
    """Hydra target for an identity (no-op) final activation."""

    _target_: str = "equinox.nn.Identity"


@dataclass
class Key:
    """Hydra target that builds ``jax.random.PRNGKey(seed)``."""

    _target_: str = "jax.random.PRNGKey"
    seed: int = 123


@dataclass
class MFNModel:
    """Structured config for the multiplicative-filter ``FourierNet`` (2 inputs -> 3 outputs)."""

    _target_: str = "jejeqx._src.nets.nerfs.mfn.FourierNet"
    in_size: int = 2
    out_size: int = 3
    width_size: int = 256
    depth: int = 4
    # Nested configs go through default_factory: dataclass instances are
    # unhashable and Python 3.11+ rejects them as plain field defaults.
    final_activation: IdentityAct = field(default_factory=IdentityAct)
    key: Key = field(default_factory=lambda: Key(seed=123))
# initialize model
# Hydra builds the FourierNet named by MFNModel._target_ from the structured config.
model_config = OmegaConf.structured(MFNModel())

model = hydra.utils.instantiate(model_config)

# eqx.tree_pprint(model)
seed, debug = 123, False
enable_progress_bar, log_dir = False, "./"

trainer = RegressorTrainer(
    model, optimizer,
    seed=seed, debug=debug, enable_progress_bar=enable_progress_bar, log_dir=log_dir,
)

# False: reuse the saved checkpoint below instead of training further.
train_more = False

# Baseline: score the freshly initialised FourierNet.
out, metrics = trainer.test_model(dm.test_dataloader())
metrics
# Load saved weights. When no further training is requested, the checkpoint is
# the only source of trained weights, so a load failure must not pass silently.
try:
    trainer.load_model("./checkpoints/checkpoint_model_mfnfourier.ckpt")
except Exception as err:
    if not train_more:
        raise RuntimeError(
            "Could not load ./checkpoints/checkpoint_model_mfnfourier.ckpt; "
            "set train_more = True to train from scratch."
        ) from err
    print(f"No usable FourierNet checkpoint ({err!r}); training from scratch.")
out, metrics = trainer.test_model(dm.test_dataloader())
metrics
%%time

# Train only when explicitly requested via train_more.
if train_more:
    metrics = trainer.train_model(dm, num_epochs=num_epochs)
# Persist freshly trained weights, then re-score on the test split.
if train_more:
    trainer.save_model("./checkpoints/checkpoint_model_mfnfourier.ckpt")
out, metrics = trainer.test_model(dm.test_dataloader())
metrics
# Append the FourierNet scores. ignore_index renumbers rows 0..n-1 instead of
# repeating label 0 for every appended single-row frame.
all_metrics = pd.concat(
    [
        all_metrics,
        pd.DataFrame(
            data=[["mfnfourier", metrics["loss"], metrics["psnr"]]],
            columns=["model", "MSE", "SNR"],
        ),
    ],
    ignore_index=True,
)
out_mfn_fourier = dm.coordinates_2_image(out)
fig, ax = plt.subplots(ncols=2, nrows=3, figsize=(8, 8))

ax[0, 0].imshow(img)
ax[0, 0].set(title="Original")
ax[0, 1].imshow(out_mlp)
ax[0, 1].set(title="MLP")
ax[1, 0].imshow(out_rff)
ax[1, 0].set(title="Fourier Features")
ax[1, 1].imshow(out_siren)
ax[1, 1].set(title="Siren")
# Fill the grid left-to-right and hide the one unused panel, instead of
# leaving an empty set of axes at [2, 0].
ax[2, 0].imshow(out_mfn_fourier)
ax[2, 0].set(title="MFN (FourierNet)")
ax[2, 1].set_axis_off()

plt.tight_layout()
plt.show()

GaborNet#

from typing import Dict
from dataclasses import field


@dataclass
class IdentityAct:
    """Hydra target for an identity (no-op) final activation."""

    _target_: str = "equinox.nn.Identity"


@dataclass
class Key:
    """Hydra target that builds ``jax.random.PRNGKey(seed)``."""

    _target_: str = "jax.random.PRNGKey"
    seed: int = 123


@dataclass
class MFNModel:
    """Structured config for the multiplicative-filter ``GaborNet`` (2 inputs -> 3 outputs)."""

    _target_: str = "jejeqx._src.nets.nerfs.mfn.GaborNet"
    in_size: int = 2
    out_size: int = 3
    width_size: int = 256
    depth: int = 4
    # Nested configs go through default_factory: dataclass instances are
    # unhashable and Python 3.11+ rejects them as plain field defaults.
    final_activation: IdentityAct = field(default_factory=IdentityAct)
    key: Key = field(default_factory=lambda: Key(seed=123))
# Small random pair for checking the distance helpers: x_ has shape (10,), y_ has shape (10, 20).
x_ = jrandom.normal(key=jrandom.PRNGKey(123), shape=(10,))
y_ = jrandom.normal(key=jrandom.PRNGKey(42), shape=(10, 20))
from scipy.spatial.distance import pdist, cdist
import einops
# Reference squared distances from x_ (D,) to each column of y_ (D, R).
# cdist compares *rows*, so x_ becomes a single row (1, D) and y_ is transposed
# to (R, D); the result has shape (1, R).
cdist(
    einops.rearrange(x_, "D -> 1 D"),
    einops.rearrange(y_, "D R -> R D"),
    metric="sqeuclidean",
).shape
def norm(x, y):
    """Squared Euclidean distance between ``x`` and every column of ``y``.

    Uses the expansion ``||x - y_r||^2 = ||x||^2 + ||y_r||^2 - 2 x . y_r``.

    Args:
        x: vector of shape (D,); a stack of vectors of shape (N, D) also works.
        y: matrix of shape (D, R) whose columns are compared against ``x``.

    Returns:
        Squared distances of shape (R,) (or (N, R)), with size-1 axes squeezed.
    """
    a = np.sum(x**2, axis=-1)[..., None]
    b = np.sum(y**2, axis=0)[None, :]
    # The cross term is the inner product x . y_r, not a product of the norms.
    c = -2 * (x @ y)

    D = a + b + c
    return D.squeeze()


def norm2(x, y):
    """Squared Euclidean distance from vector ``x`` (D,) to each column of ``y`` (D, R).

    Expands ``||x - y_r||^2`` as ``||x||^2 + ||y_r||^2 - 2 x . y_r``; the inner
    products are computed with ``einsum``. Size-1 axes are squeezed.
    """
    x_sq = np.sum(x**2, axis=-1)[..., None]
    y_sq = np.sum(y**2, axis=0)[None, :]
    cross = np.einsum("i,ij->j", x, y)
    return (x_sq + y_sq - 2 * cross).squeeze()


def norm3(x, y):
    """Euclidean (not squared) distance from vector ``x`` (D,) to each column of ``y`` (D, R).

    Forms the difference matrix explicitly, sums its squares down each column,
    then takes the square root. Size-1 axes are squeezed.
    """
    diff = x[..., None] - y
    squared = np.einsum("ij,ij->j", diff, diff)
    return np.sqrt(squared).squeeze()


def norm4(x, y):
    """Euclidean distance from vector ``x`` (D,) to each column of ``y`` (D, R).

    Returns an array of shape (R,), matching ``norm3``.
    """
    # Reduce over axis 0 (the feature axis D) so each column of y gives one
    # distance; axis=1 would instead reduce over the R columns.
    return np.linalg.norm(x[..., None] - y, ord=2, axis=0)
# Compare the distance helpers on the random x_ (10,) / y_ (10, 20) pair.
norm(x_, y_), norm2(x_, y_), norm3(x_, y_), norm4(x_, y_)
# o_ and o__ should agree if both helpers compute the same squared distance.
o_ = norm(x_, y_)
o__ = norm2(x_, y_)
o_.shape, o__.shape
o_
o__
# Intermediate terms of the expansion ||x||^2 + ||y_r||^2 - 2 x . y_r.
a = (x_**2).sum(axis=-1)[..., None]
b = (y_**2).sum(axis=0)[None, ...]
c = jnp.einsum("i,ij -> j", x_, y_)
a.shape, b.shape, (a + b).shape, c.shape
x_.shape, y_.shape
from jejeqx._src.nets.nerfs.mfn import GaborLayer

# Smoke test: apply a single Gabor layer to one training input.
# NOTE(review): presumably GaborLayer(in_features, out_features) — confirm argument order.
layer = GaborLayer(2, 3)
layer(x_init[0])
# initialize model
# Hydra builds the model named by the most recent MFNModel._target_ (GaborNet).
model_config = OmegaConf.structured(MFNModel())

model = hydra.utils.instantiate(model_config)

# eqx.tree_pprint(model)
seed, debug = 123, False
enable_progress_bar, log_dir = False, "./"

trainer = RegressorTrainer(
    model, optimizer,
    seed=seed, debug=debug, enable_progress_bar=enable_progress_bar, log_dir=log_dir,
)

# True: train (or continue training) the GaborNet in the cells below.
train_more = True

# Baseline: score the freshly initialised GaborNet.
out, metrics = trainer.test_model(dm.test_dataloader())
metrics
# Load saved weights. When no further training is requested, the checkpoint is
# the only source of trained weights, so a load failure must not pass silently.
try:
    trainer.load_model("./checkpoints/checkpoint_model_mfngabor.ckpt")
except Exception as err:
    if not train_more:
        raise RuntimeError(
            "Could not load ./checkpoints/checkpoint_model_mfngabor.ckpt; "
            "set train_more = True to train from scratch."
        ) from err
    print(f"No usable GaborNet checkpoint ({err!r}); training from scratch.")
out, metrics = trainer.test_model(dm.test_dataloader())
metrics
%%time
# Train only when explicitly requested via train_more.
if train_more:
    metrics = trainer.train_model(dm, num_epochs=num_epochs)
# Persist freshly trained weights, then re-score on the test split.
if train_more:
    trainer.save_model("./checkpoints/checkpoint_model_mfngabor.ckpt")
out, metrics = trainer.test_model(dm.test_dataloader())
metrics
# Append the GaborNet scores. ignore_index renumbers rows 0..n-1 instead of
# repeating label 0 for every appended single-row frame.
all_metrics = pd.concat(
    [
        all_metrics,
        pd.DataFrame(
            data=[["mfngabor", metrics["loss"], metrics["psnr"]]],
            columns=["model", "MSE", "SNR"],
        ),
    ],
    ignore_index=True,
)
out_mfn_gabor = dm.coordinates_2_image(out)
fig, ax = plt.subplots(ncols=2, nrows=3, figsize=(8, 8))

# Row-major layout: each (image, title) pair fills the next panel.
titled_images = [
    (img, "Original"),
    (out_mlp, "MLP"),
    (out_rff, "Fourier Features"),
    (out_siren, "Siren"),
    (out_mfn_fourier, "MFN (Fourier)"),
    (out_mfn_gabor, "MFN (Gabor)"),
]
for axis, (image, title) in zip(ax.flat, titled_images):
    axis.imshow(image)
    axis.set(title=title)

plt.tight_layout()
plt.show()
all_metrics