Experiment - RBIG Sample Consistency#

import sys, os
from pyprojroot import here

# hard-coded project roots (pyprojroot's here() could locate these instead)
pysim_root = "/home/emmanuel/code/pysim"
rbig_root = "/home/emmanuel/code/rbig"
# append to path
sys.path.append(str(rbig_root))
sys.path.append(str(pysim_root))

import numpy as np

# MATPLOTLIB Settings
import matplotlib as mpl
import matplotlib.pyplot as plt

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# SEABORN SETTINGS
import seaborn as sns
import corner

sns.set_context(context="poster", font_scale=0.7)

%load_ext autoreload
%autoreload 2
%load_ext lab_black

Mutual Information#

# ==========================
# INITIALIZE LOGGER
# ==========================
import wandb

wandb_logger = wandb.init(project="rbig4it", entity="ipl_uv")
Tracking run with wandb version 0.10.30
Syncing run charmed-brook-1 to Weights & Biases.
Project page: https://wandb.ai/ipl_uv/rbig4it
Run page: https://wandb.ai/ipl_uv/rbig4it/runs/ep72ky8b
Run data is saved locally in /home/emmanuel/documents/research_notebook/research_notebook/content/notes/info_theory/experiments/wandb/run-20210609_001640-ep72ky8b

Dataset#

from pysim.data.information.gaussian import (
    generate_gaussian_data,
    generate_gaussian_rotation_data,
)
from pysim.data.information.linear import generate_linear_entropy_data
from pysim.data.information.studentt import generate_studentt_data
from functools import partial
from pysim.information.entropy import marginal_entropy
from pysim.information.histogram import hist_entropy
from typing import NamedTuple


def get_tc_datasets(n_samples, n_features, seed, dataset="gaussian", **kwargs):

    if dataset == "gaussian":

        res = generate_gaussian_data(
            n_samples=n_samples,
            n_features=n_features,
            seed=seed,
            n_base_samples=500_000,
        )

    elif dataset == "gaussian_rotation":

        res = generate_gaussian_rotation_data(
            n_samples=n_samples,
            n_features=n_features,
            seed=seed,
            n_base_samples=500_000,
        )

    elif dataset == "linear_rotation":

        f = partial(marginal_entropy, estimator=hist_entropy, bins="sqrt")

        res = generate_linear_entropy_data(
            n_samples=n_samples,
            n_features=n_features,
            seed=seed,
            marg_h_estimator=f,
            estimator_name="histogram",
            n_base_samples=500_000,
        )
    elif dataset == "studentt":

        res = generate_studentt_data(
            n_samples=n_samples,
            n_features=n_features,
            seed=seed,
            n_base_samples=500_000,
            df=kwargs.get("df", 3.0),
        )

    elif dataset == "cauchy":
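        # a Student-t with df=1 is the Cauchy distribution (extremely heavy tails)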

        res = generate_studentt_data(
            n_samples=n_samples,
            n_features=n_features,
            seed=seed,
            n_base_samples=500_000,
            df=1.0,
        )

    else:
        raise ValueError(f"Unrecognized dataset: {dataset}")

    return res
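
Each of these generators returns a result tuple whose X field holds the samples and whose TC field holds the corresponding analytic total correlation; the TC value is used as the ground truth in the experiments below.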
datasets = [
    "gaussian",
    "gaussian_rotation",
    "linear_rotation",
    "studentt",
    "cauchy",
]
for idataset in datasets:
    res = get_tc_datasets(n_samples=10_000, n_features=10, seed=42, dataset=idataset)
    print(f"{idataset.capitalize()}")
    fig = corner.corner(res.X, hist_factor=2, color="red")
    plt.tight_layout()
    plt.gcf()
    wandb.log({f"data_{idataset}_X": wandb.Image(fig)})
    plt.show()
    plt.close(fig)
Gaussian
[corner plot of the Gaussian samples]
Gaussian_rotation
[corner plot of the rotated-Gaussian samples]
Linear_rotation
[corner plot of the linear-rotation samples; corner warns "Too few points to create valid contours" for a few panels]
Studentt
[corner plot of the Student-t samples; the heavy tails trigger repeated "Too few points to create valid contours" warnings]
Cauchy
[corner plot of the Cauchy samples; the heavy tails trigger repeated "Too few points to create valid contours" warnings]

Estimators#
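
The estimators below all target the total correlation (multi-information), the multivariate generalisation of mutual information: \(T(\mathbf{X}) = \sum_{d=1}^{D} H(X_d) - H(X_1, \dots, X_D)\), i.e. the sum of the marginal entropies minus the joint entropy. The total_correlation helper below implements exactly this decomposition for any plug-in entropy estimator f.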

from pysim.information.gaussian import gauss_total_corr
from mutual_info.mutual_info import entropy as mi_entropy
from npeet.entropy_estimators import entropy as npeet_entropy
from rbig.rbig import RBIG
import time
from scipy import stats


class TCResult(NamedTuple):
    time: float
    TC: float
    name: str


def total_correlation(X, f, **kwargs):

    return np.sum([f(ix[:, None], **kwargs) for ix in X.T]) - f(X, **kwargs)


def get_tc_estimators(X, method="gaussian", **kwargs):

    if method == "gaussian":
        t0 = time.time()
        res = gauss_total_corr(X=X.copy())
        t1 = time.time() - t0
        return TCResult(t1, res, "gaussian")

    elif method == "knn_nbs":
        t0 = time.time()
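        # kNN entropy from the mutual_info package (k=10 neighbours, no marginal transform)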
        res = total_correlation(X, mi_entropy, k=10, transform=None)
        t1 = time.time() - t0
        return TCResult(t1, res, "knn_nbs")

    elif method == "knn_eps":
        t0 = time.time()
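        # NPEET entropy with base=np.e so the estimate is in nats, matching the other estimators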
        res = total_correlation(X, npeet_entropy, k=10, base=np.e)
        t1 = time.time() - t0
        return TCResult(t1, res, "knn_eps")

    elif method == "rbig":
        from rbig.rbig import RBIG

        t0 = time.time()
        rbig_model = RBIG(pdf_extension=10)
        rbig_model.fit(X.copy())
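        # scale by ln(2): presumably converting RBIG's bits to nats so all estimators share units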
        res = rbig_model.mutual_information * np.log(2)
        t1 = time.time() - t0
        return TCResult(t1, res, "rbig")
    else:
        raise ValueError(f"Unrecognized estimator: {method}")
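
A hypothetical usage sketch (not run here), assuming only the helpers defined above: draw one small Gaussian dataset and compare every estimator against the analytic value.

demo = get_tc_datasets(n_samples=1_000, n_features=3, seed=0, dataset="gaussian")

for imethod in ["gaussian", "knn_nbs", "knn_eps", "rbig"]:
    est = get_tc_estimators(X=demo.X.copy(), method=imethod)
    print(f"{est.name:>8s}: TC approx {est.TC:.3f} (truth {demo.TC:.3f}, {est.time:.2f}s)")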

Toy Data#

from pysim.utils import dict_product


datasets = [
    "gaussian",
    "gaussian_rotation",
    "linear_rotation",
    "studentt",
    "cauchy",
]
n_samples = [100, 1_000, 10_000, 100_000]

params = {
    "n_features": [2, 3, 5, 10, 15],
    "n_trials": list(np.arange(1, 11)),
    "dataset": datasets,
}
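# dict_product presumably expands this grid into a list of dicts, one per combination of the value lists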
params = dict_product(params)

models = [
    "gaussian",
    "knn_nbs",
    "knn_eps",
    "rbig",
]

Gaussian#

Implementation Notes:

  1. We randomly generate a symmetric, positive semi-definite \(D \times D\) matrix, which acts as our covariance matrix \(\boldsymbol{\Sigma}\).

  2. We use a mean \(\boldsymbol{\mu}\) of \(\mathbf{0}\).

  3. We generate an upper limit of 5e5 base data points and then randomly subsample the requested number (see the sketch after these notes).
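
For reference, the analytic total correlation of a zero-mean Gaussian with covariance \(\boldsymbol{\Sigma}\) is \(T = \frac{1}{2}\left(\sum_{d}\log \Sigma_{dd} - \log\det\boldsymbol{\Sigma}\right)\). A minimal sketch of steps 1-3, assuming only NumPy (pysim's generate_gaussian_data wraps equivalent logic, so details may differ):

def toy_gaussian_tc_data(n_samples, n_features, seed, n_base_samples=500_000):
    rng = np.random.RandomState(seed)

    # 1. random symmetric positive semi-definite covariance: A A^T
    A = rng.randn(n_features, n_features)
    cov = A @ A.T

    # 2. zero mean, large base sample
    X_base = rng.multivariate_normal(np.zeros(n_features), cov, size=n_base_samples)

    # 3. randomly subsample the requested number of points (without replacement)
    idx = rng.choice(n_base_samples, size=n_samples, replace=False)

    # analytic total correlation of the Gaussian (nats)
    tc = 0.5 * (np.sum(np.log(np.diag(cov))) - np.linalg.slogdet(cov)[1])
    return X_base[idx], tc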

from functools import partial
import tqdm  # the loop below uses tqdm.tqdm directly
import pandas as pd

results_df = pd.DataFrame()
results_dict = {}

with tqdm.tqdm(params) as pbar:

    for i, iparam in enumerate(pbar):

        # generate data
        res_tuple = get_tc_datasets(
            n_samples=200_000,
            n_features=iparam["n_features"],
            seed=iparam["n_trials"],
            dataset=iparam["dataset"],
        )

        #         results_dict = {**results_dict, **iparam}
        #         results_dict["model"] = "true"
        #         results_dict["dataset"] = iparam["dataset"]
        #         results_dict["tc"] = res_tuple.TC

        #         results_dict["time"] = 0.0

        #         results_df = pd.concat(
        #             [results_df, pd.DataFrame(results_dict, index=[i])], axis=0
        #         )

        rng = np.random.RandomState(iparam["n_trials"])

        for isamples in n_samples:

            # Subsample the data (without replacement, so no duplicated points)
            sub_idx = rng.choice(res_tuple.X.shape[0], isamples, replace=False)

            X_sub = res_tuple.X[sub_idx]

            for imodel in models:

                pbar.set_description(
                    f"Features: {iparam['n_features']} | Trial: {iparam['n_trials']} | Samples: {isamples} | Models: {imodel}"
                )

                # KNN (NEIGHBOURS)

                # do calculation
                res = get_tc_estimators(X=X_sub.copy(), method=imodel)

                results_dict = {**results_dict, **iparam}
                results_dict["model"] = res.name
                results_dict["dataset"] = iparam["dataset"]
                results_dict["approx"] = res.TC
                results_dict["true"] = res_tuple.TC
                results_dict["time"] = res.time
                results_dict["n_samples"] = isamples
                results_dict["n_features"] = iparam["n_features"]

                results_df = pd.concat(
                    [results_df, pd.DataFrame(results_dict, index=[i])], axis=0
                )
[progress bar: Features: 15 | Trial: 5 | Samples: 100000 | Models: rbig — 224/250 combinations, ~4 h 46 min elapsed]
results_df.tail()
results_df.to_csv("./rbig_consistency.csv")
wandb.log({"results": wandb.Table(dataframe=results_df)})
import json
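# temp_res.json is assumed to hold a cached copy of the results table in pandas' "split" orientation ({"columns": [...], "data": [...]})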

with open("./temp_res.json") as json_data:
    data = json.load(json_data)

results_df = pd.DataFrame(data["data"], columns=data["columns"])
results_df.head()
   n_samples  n_features  n_trials   dataset     model        mi      time    approx      true
0       1000           1         1  gaussian      true  0.452329  0.000000  0.000000  0.000000
1       1000           1         1  gaussian  gaussian  0.452329  0.002186  0.451450  0.452329
2       1000           1         1  gaussian        rv  0.452329  0.000250  0.593796  0.452329
3       1000           1         1  gaussian   knn_nbs  0.452329  0.024578  0.414026  0.452329
4       1000           1         1  gaussian   knn_eps  0.452329  0.011041  0.428369  0.452329
results_df_approx = results_df.copy()
def create_results_xr(df):

    df = df[df["n_samples"] != 100]

    # make index the experiment params
    df = df.set_index(["model", "n_samples", "n_features", "n_trials", "dataset"])

    #
    df_xr = df.to_xarray()

    return df_xr


def get_mu_std(ds):

    # mean
    mu = ds.mean(dim="n_trials")

    # standard deviation
    std = ds.std(dim="n_trials")

    return mu, std
results_xr = create_results_xr(results_df_approx)
results_xr
<xarray.Dataset>
Dimensions:     (dataset: 5, model: 4, n_features: 5, n_samples: 3, n_trials: 10)
Coordinates:
  * model       (model) object 'gaussian' 'knn_eps' 'knn_nbs' 'rbig'
  * n_samples   (n_samples) int64 1000 10000 100000
  * n_features  (n_features) int64 2 3 5 10 15
  * n_trials    (n_trials) int64 1 2 3 4 5 6 7 8 9 10
  * dataset     (dataset) object 'cauchy' 'gaussian' ... 'studentt'
Data variables:
    approx      (model, n_samples, n_features, n_trials, dataset) float64 0.7...
    true        (model, n_samples, n_features, n_trials, dataset) float64 0.6...
    time        (model, n_samples, n_features, n_trials, dataset) float64 0.0...
def plot_results(results_xr, dataset="gaussian", n_features=2):

    fig, ax = plt.subplots(figsize=(10, 5))

    # True
    mu, std = get_mu_std(
        results_xr.sel(model="gaussian", n_features=n_features, dataset=dataset).true
    )

    mu.plot.line(ax=ax, x="n_samples", color="Black", linewidth=5, label="Truth")
    ax.plot(mu.n_samples, mu.values + std.values, linestyle="--", color="gray")
    ax.plot(mu.n_samples, mu.values - std.values, linestyle="--", color="gray")
    ax.fill_between(
        mu.n_samples,
        mu.values - std.values,
        mu.values + std.values,
        alpha=0.3,
        color="gray",
    )

    # Gaussian Approximation
    mu, std = get_mu_std(
        results_xr.sel(model="gaussian", n_features=n_features, dataset=dataset).approx
    )

    mu.plot.line(
        ax=ax, x="n_samples", color="Green", linewidth=5, label=r"Gaussian",
    )
    ax.plot(mu.n_samples, mu.values + std.values, linestyle="--", color="green")
    ax.plot(mu.n_samples, mu.values - std.values, linestyle="--", color="green")
    ax.fill_between(
        mu.n_samples,
        mu.values - std.values,
        mu.values + std.values,
        alpha=0.3,
        color="green",
    )

    # KNN
    mu, std = get_mu_std(
        results_xr.sel(model="knn_nbs", n_features=n_features, dataset=dataset).approx
    )

    mu.plot.line(
        ax=ax, x="n_samples", linewidth=5, color="Orange", label=r"$k$-Neighbours"
    )
    ax.plot(mu.n_samples, mu.values + std.values, linestyle="--", color="orange")
    ax.plot(mu.n_samples, mu.values - std.values, linestyle="--", color="orange")
    ax.fill_between(
        mu.n_samples,
        mu.values - std.values,
        mu.values + std.values,
        alpha=0.3,
        color="orange",
    )
    # KNN (Epsilon)
    mu, std = get_mu_std(
        results_xr.sel(model="knn_eps", n_features=n_features, dataset=dataset).approx
    )

    mu.plot.line(
        ax=ax, x="n_samples", linewidth=5, color="Blue", label=r"$\epsilon$-Neighbours"
    )
    ax.plot(mu.n_samples, mu.values + std.values, linestyle="--", color="blue")
    ax.plot(mu.n_samples, mu.values - std.values, linestyle="--", color="blue")
    ax.fill_between(
        mu.n_samples,
        mu.values - std.values,
        mu.values + std.values,
        alpha=0.3,
        color="blue",
    )
    # RBIG
    mu, std = get_mu_std(
        results_xr.sel(model="rbig", n_features=n_features, dataset=dataset).approx
    )

    mu.plot.line(ax=ax, x="n_samples", linewidth=5, color="Red", label=r"RBIG")
    ax.plot(mu.n_samples, mu.values + std.values, linestyle="--", color="red")
    ax.plot(mu.n_samples, mu.values - std.values, linestyle="--", color="red")
    ax.fill_between(
        mu.n_samples,
        mu.values - std.values,
        mu.values + std.values,
        alpha=0.3,
        color="red",
    )
    ax.set_ylabel("Total Correlation", fontsize=20)
    ax.set_xlabel("Number of Samples", fontsize=20)
    ax.grid(which="both")
    ax.set_xscale("log")
    ax.set_title("")
    ax.legend()
    plt.gcf()
    wandb.log({f"consistency_{dataset}_{n_features}": wandb.Image(fig)})
    plt.show()
    plt.close(fig)
for idataset in datasets:

    print(idataset)
    for idims in [2, 3, 5, 10, 15]:
        print(f"Dimensions: {idims}")
        plot_results(results_xr, idataset, idims)
gaussian
[sample-consistency curves (Truth, Gaussian, k-Neighbours, epsilon-Neighbours, RBIG) for dimensions 2, 3, 5, 10, 15]
gaussian_rotation
[sample-consistency curves for dimensions 2, 3, 5, 10, 15]
linear_rotation
[sample-consistency curves for dimensions 2, 3, 5, 10, 15]
studentt
[sample-consistency curves for dimensions 2, 3, 5, 10, 15]
cauchy
[sample-consistency curves for dimensions 2, 3, 5, 10, 15]

Results#

Plot II - Sample Consistency#

def get_sample_tc(df, model, dataset="gaussian"):

    df = df[df["model"] == model]
    df = df[df["dataset"] == dataset]

    # estimates and the sample sizes they were computed at
    df_samples = df["n_samples"]
    df_approx = df["approx"]

    return df_approx, df_samples

Plot I - Indiscriminate#

def get_approx_true(df, model, dataset="gaussian"):

    df = df[df["model"] == model]
    df = df[df["dataset"] == dataset]

    # estimated and true values for this (model, dataset) pair
    df_true = df["true"]
    df_approx = df["approx"]

    return df_approx, df_true
df_approx, df_true = get_approx_true(results_df, "gaussian")
fig, ax = plt.subplots()

ax.scatter(df_approx.values, df_true.values, s=10)
ax.set(xlabel="Approx. Mutual Info", ylabel="True Mutual Info")
plt.tight_layout()
plt.show()
# wandb.log({f"scatter_gaussian_pearson_dim": wandb.Image(fig)})
[scatter plot of approximate vs. true mutual information for the Gaussian estimator]

Results#

def plot_results(results_df, model="gaussian", dataset="gaussian"):

    results_df = results_df[results_df["n_samples"] != 100]

    df_approx, df_true = get_approx_true(results_df, model, dataset)

    min_val, max_val = df_true.min(), df_true.max()
    fig, ax = plt.subplots()

    ax.scatter(df_approx.values, df_true.values, s=10, zorder=3)
    ax.plot(
        [min_val, max_val],
        [min_val, max_val],
        color="black",
        linestyle="-",
        linewidth=2,
    )
    ax.set(xlabel="Est. Mutual Info", ylabel="True Mutual Info")
    ax.set_title(f"Data: {dataset} | Model: {model}")
    ax.grid()
    plt.tight_layout()
    plt.gcf()
    wandb.log({f"accuracy_{dataset}_{model}": wandb.Image(fig)})
    plt.show()
    plt.close(fig)
models = [
    "gaussian",
    "knn_nbs",
    "knn_eps",
    "rbig",
]

for imodel in models:
    for idataset in datasets:

        plot_results(results_df, model=imodel, dataset=idataset)
[estimated vs. true total correlation scatter plots, one per (model, dataset) pair: gaussian, knn_nbs, knn_eps, rbig x gaussian, gaussian_rotation, linear_rotation, studentt, cauchy]