Skip to content

Visualizing HSIC Measures

import sys, os
import warnings
import tqdm
import random
import pandas as pd
import numpy as np

import matplotlib.pyplot as plt

# Put the project's ``src`` directory (two levels above the current working
# directory) at the front of the import path, so the local ``data`` and
# ``models`` packages imported below can be found.
cwd = os.getcwd()
path = "{}/../../src".format(cwd)
sys.path.insert(0, path)

# toy datasets
from data.toy import generate_dependence_data

# Kernel Dependency measure
from models.dependence import HSIC
from models.kernel import estimate_sigma, sigma_to_gamma, gamma_to_sigma, get_param_grid

# RBIG IT measures
from models.ite_algorithms import run_rbig_models

import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import seaborn as sns
%matplotlib inline

warnings.filterwarnings('ignore') # get rid of annoying warnings

%load_ext autoreload
%autoreload 2
# Location and base name of the CSV holding the HSIC experiment results.
save_path   = f'{cwd}/../../results/hsic/'
save_name   = 'trial_v3'

results_df = pd.read_csv(f"{save_path}{save_name}.csv")

# Distinct experiment settings present in the results (first-seen order).
res_noise, res_gammas, res_lines = (
    results_df[column].unique().tolist()
    for column in ('noise', 'gamma', 'function')
)

Figure I - Mutual Information & Noise

This first figure shows how the mutual information varies with the amount of noise for each of the four functions: linear, sinusoidal, circular, and random.

# Figure I: scatter MI(X,Y) against the noise level for each generating
# function (one series per function), with both axes log-scaled.
function_labels = {
    'line': 'Line',
    'sine': 'Sine',
    'circ': 'Circle',
    'rand': 'Random',
}

fig, ax = plt.subplots()
for function_name, function_label in function_labels.items():
    function_df = results_df[results_df['function'] == function_name]
    ax.scatter(function_df['noise'], function_df['mi'], label=function_label)

# Apply the log scales once, after all series are drawn.
ax.set_xscale('log')
ax.set_yscale('log')

# Raw string: '\s' is an invalid escape sequence in a normal string literal
# (warning on recent Python); the resulting text is unchanged.
ax.set_xlabel(r'Noise, $\sigma_y$')
ax.set_ylabel('Mutual Information, MI$(X,Y)$')
ax.legend()
ax.set_title('Experimental Parameter Space')

save_path = f'{cwd}/../../results/hsic/figures/'
# savefig fails if the output directory does not exist yet.
os.makedirs(save_path, exist_ok=True)
# Write the file before showing, independent of how the backend handles show().
fig.savefig(f"{save_path}trialv1_parameters.png")
plt.show()
def plot_res_gamma_3D(results_df, function='line', hsic_method='hsic',
                      save_prefix='trialv1'):
    """Draw a 3-D scatter of log(MI), log(gamma) and a dependence score.

    One point is drawn per row of ``results_df`` whose ``function`` column
    equals ``function`` and whose ``scorer`` column equals ``hsic_method``.
    The figure is saved under ``{cwd}/../../results/hsic/figures/`` and then
    shown.

    Parameters
    ----------
    results_df : pandas.DataFrame
        Experiment results; the columns ``function``, ``scorer``, ``mi``,
        ``gamma`` and ``value`` are read.
    function : str
        Generating function to plot: ``'line'``, ``'sine'``, ``'circ'`` or
        ``'rand'``.
    hsic_method : str
        Value of the ``scorer`` column to plot (e.g. ``'hsic'``, ``'tka'``,
        ``'ctka'``); its upper-cased form labels the z axis.
    save_prefix : str, optional
        Filename prefix of the saved PNG. The default keeps the historical
        ``trialv1`` file name.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``function`` is not one of the recognised names.
    """
    titles = {
        'line': 'Linear Function',
        'sine': 'Sine Function',
        'circ': 'Circle Function',
        'rand': 'Random Function',
    }
    # Validate before doing anything else. The previous error message
    # interpolated an undefined name (``line``), so an unknown function
    # raised NameError instead of the intended ValueError.
    if function not in titles:
        raise ValueError(f'Unrecognized function: {function}')
    title = titles[function]

    save_path = f'{cwd}/../../results/hsic/figures/'

    # Select the rows for this function and scorer once, instead of
    # re-filtering for every plotted column.
    selected = results_df[
        (results_df['function'] == function)
        & (results_df['scorer'] == hsic_method)
    ]

    fig = plt.figure()
    # ``Figure.gca(projection=...)`` no longer accepts keyword arguments
    # (removed in Matplotlib 3.6); ``add_subplot`` creates the 3-D axes.
    ax = fig.add_subplot(projection='3d')

    if not selected.empty:
        # MI and gamma are plotted as natural logs on linear axes;
        # non-positive values become -inf/NaN under np.log.
        log_gamma = np.log(selected['gamma'].values)
        surf = ax.scatter(
            np.log(selected['mi'].values),
            log_gamma,
            selected['value'].values,
            c=log_gamma,
            cmap='Spectral',
        )
        # Only add a colorbar when there is data to colour.
        fig.colorbar(surf, ax=ax, label='Gamma (log)')

    # Labels say "(log)" because the plotted coordinates are np.log values.
    ax.set_xlabel('Mutual Information (log)')
    ax.set_ylabel('Gamma (log)')
    ax.set_zlabel(hsic_method.upper())
    ax.view_init(30, 35)
    ax.set_title(title)
    plt.tight_layout()

    # savefig fails if the output directory does not exist yet; write the
    # file before showing the figure.
    os.makedirs(save_path, exist_ok=True)
    fig.savefig(f"{save_path}{save_prefix}_{function}_{hsic_method}_3d.png")
    plt.show()

    return None
# 3-D views of every dependence score for the linear function.
for method in ('hsic', 'tka', 'ctka'):
    plot_res_gamma_3D(results_df, function='line', hsic_method=method)
def plot_res_gamma(results_df, function='line', hsic_method='hsic',
                   save_prefix='trialv1'):
    """Scatter a dependence score against MI, coloured by gamma.

    One point is drawn per row of ``results_df`` whose ``function`` column
    equals ``function`` and whose ``scorer`` column equals ``hsic_method``.
    The x axis is the score (``value``), and the y axis is ``mi`` on a log
    scale. Colour encodes ``gamma`` with a logarithmic normalisation. The
    figure is saved under ``{cwd}/../../results/hsic/figures/`` and then
    shown.

    Parameters
    ----------
    results_df : pandas.DataFrame
        Experiment results; the columns ``function``, ``scorer``, ``mi``,
        ``gamma`` and ``value`` are read.
    function : str
        Generating function to plot: ``'line'``, ``'sine'``, ``'circ'`` or
        ``'rand'``.
    hsic_method : str
        Value of the ``scorer`` column to plot (e.g. ``'hsic'``, ``'tka'``,
        ``'ctka'``); its upper-cased form labels the x axis.
    save_prefix : str, optional
        Filename prefix of the saved PNG. The default keeps the historical
        ``trialv1`` file name.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``function`` is not one of the recognised names.
    """
    titles = {
        'line': 'Linear Function',
        'sine': 'Sine Function',
        'circ': 'Circle Function',
        'rand': 'Random Function',
    }
    # Validate before doing anything else. The previous error message
    # interpolated an undefined name (``line``), so an unknown function
    # raised NameError instead of the intended ValueError.
    if function not in titles:
        raise ValueError(f'Unrecognized function: {function}')
    title = titles[function]

    save_path = f'{cwd}/../../results/hsic/figures/'

    # Select the rows for this function and scorer once, instead of
    # re-filtering for every plotted column.
    selected = results_df[
        (results_df['function'] == function)
        & (results_df['scorer'] == hsic_method)
    ]

    fig, ax = plt.subplots(nrows=1, figsize=(7, 3))

    if not selected.empty:
        pts = ax.scatter(
            x=selected['value'],
            y=selected['mi'],
            c=selected['gamma'],
            s=20,
            cmap='Spectral',
            # Colour gamma on a logarithmic normalisation.
            norm=matplotlib.colors.LogNorm(),
        )
        # Only add a colorbar when there is data to colour.
        fig.colorbar(pts, ax=ax, label='Gamma')

    ax.set_xlabel(hsic_method.upper())
    ax.set_ylabel('Mutual Information')
    ax.set_yscale('log')
    ax.set_title(title)
    plt.tight_layout()

    # savefig fails if the output directory does not exist yet; write the
    # file before showing the figure.
    os.makedirs(save_path, exist_ok=True)
    fig.savefig(f"{save_path}{save_prefix}_{function}_{hsic_method}.png")
    plt.show()

    return None

Figure II - MI vs Gamma vs HSIC; Case I - Linear

# Linear function: 2-D plots for every score first, then the 3-D views.
for plotter in (plot_res_gamma, plot_res_gamma_3D):
    for method in ('hsic', 'tka', 'ctka'):
        plotter(results_df, function='line', hsic_method=method)

Case II - Sine

# Sine function: 2-D plots for every score first, then the 3-D views.
for plotter in (plot_res_gamma, plot_res_gamma_3D):
    for method in ('hsic', 'tka', 'ctka'):
        plotter(results_df, function='sine', hsic_method=method)

Case III - Circle

# Circle function: 2-D plots for every score first, then the 3-D views.
for plotter in (plot_res_gamma, plot_res_gamma_3D):
    for method in ('hsic', 'tka', 'ctka'):
        plotter(results_df, function='circ', hsic_method=method)

Case IV - Random

# Random function: 2-D plots for every score first, then the 3-D views.
for plotter in (plot_res_gamma, plot_res_gamma_3D):
    for method in ('hsic', 'tka', 'ctka'):
        plotter(results_df, function='rand', hsic_method=method)