Visualizing HSIC Measures¶
import sys, os
import warnings
import tqdm
import random
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# Insert the path to the project's source directory so local packages can be imported.
cwd = os.getcwd()
path = f"{cwd}/../../src"
sys.path.insert(0, path)
# toy datasets
from data.toy import generate_dependence_data
# Kernel Dependency measure
from models.dependence import HSIC
from models.kernel import estimate_sigma, sigma_to_gamma, gamma_to_sigma, get_param_grid
# RBIG IT measures
from models.ite_algorithms import run_rbig_models
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import seaborn as sns
%matplotlib inline
warnings.filterwarnings('ignore') # get rid of annoying warnings
%load_ext autoreload
%autoreload 2
# Location and base name of the stored HSIC experiment results
# (resolved relative to the notebook's working directory, `cwd`).
save_path = f'{cwd}/../../results/hsic/'
save_name = 'trial_v3'

# Load the full results table from CSV.
results_df = pd.read_csv(f"{save_path}{save_name}.csv")

# Distinct values of each experimental axis present in the results.
res_noise, res_gammas, res_lines = (
    results_df[column].unique().tolist()
    for column in ('noise', 'gamma', 'function')
)
Figure I - Mutual Information & Noise¶
This first figure shows how the mutual information varies with the amount of noise for each of the four functions: Linear, Sinusoidal, Circular, and Random.
# Figure I: mutual information against noise level, one scatter series per
# toy function, drawn on log-log axes and saved as a PNG.
fig, ax = plt.subplots()

# (value in the 'function' column, legend label) for each series, in plot order.
function_labels = [
    ('line', 'Line'),
    ('sine', 'Sine'),
    ('circ', 'Circle'),
    ('rand', 'Random'),
]
for function_key, legend_label in function_labels:
    function_rows = results_df[results_df['function'] == function_key]
    ax.scatter(function_rows['noise'], function_rows['mi'], label=legend_label)

# The axis scale is a property of the axes, so it only needs to be set once
# rather than after every scatter call.
ax.set_xscale('log')
ax.set_yscale('log')

# Raw string: '\s' is not a valid escape sequence in a normal string literal.
ax.set_xlabel(r'Noise, $\sigma_y$')
ax.set_ylabel('Mutual Information, MI$(X,Y)$')
ax.legend()
ax.set_title('Experimental Parameter Space')
plt.show()

# From here on `save_path` points at the figures directory.
save_path = f'{cwd}/../../results/hsic/figures/'
fig.savefig(f"{save_path}trialv1_parameters.png")
def plot_res_gamma_3D(results_df, function='line', hsic_method='hsic'):
    """Draw and save a 3D scatter of log(MI), log(gamma) and the scorer value.

    The rows of ``results_df`` whose ``'function'`` column equals ``function``
    and whose ``'scorer'`` column equals ``hsic_method`` are plotted with
    ``log(mi)`` on x, ``log(gamma)`` on y and ``value`` on z; points are
    coloured by ``log(gamma)``. The figure is shown and then written to
    ``{cwd}/../../results/hsic/figures/trialv1_{function}_{hsic_method}_3d.png``
    (``cwd`` is the module-level working directory).

    Parameters
    ----------
    results_df : pandas.DataFrame
        Results table; the columns ``'function'``, ``'scorer'``, ``'mi'``,
        ``'gamma'`` and ``'value'`` are read.
    function : str
        One of ``'line'``, ``'sine'``, ``'circ'``, ``'rand'``.
    hsic_method : str
        Value of the ``'scorer'`` column to plot; also used (upper-cased) as
        the z-axis label.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``function`` is not one of the recognised keys, or if
        ``results_df`` has no rows for it.
    """
    titles = {
        'line': 'Linear Function',
        'sine': 'Sine Function',
        'circ': 'Circle Function',
        'rand': 'Random Function',
    }
    # Validate first so a bad key fails before any figure is created.
    if function not in titles:
        raise ValueError(f'Unrecognized function: {function}')
    title = titles[function]

    save_path = f'{cwd}/../../results/hsic/figures/'

    function_rows = results_df[results_df['function'] == function]
    if function_rows.empty:
        # Without this guard the scatter would never be drawn and the
        # colorbar call below would have nothing to attach to.
        raise ValueError(f'No results found for function: {function}')
    scorer_rows = function_rows[function_rows['scorer'] == hsic_method]

    # Compute the log-transformed coordinates once; gamma is reused for colour.
    log_mi = np.log(scorer_rows['mi'].values)
    log_gamma = np.log(scorer_rows['gamma'].values)

    fig = plt.figure()
    # `fig.gca(projection=...)` is no longer accepted by recent Matplotlib
    # releases; `add_subplot` is the supported way to request a 3D axes.
    ax = fig.add_subplot(111, projection='3d')

    surf = ax.scatter(
        log_mi,
        log_gamma,
        scorer_rows['value'].values,
        c=log_gamma,
        cmap='Spectral',
    )

    # NOTE: the x/y data (and the colour scale) are natural logs of MI and gamma.
    ax.set_xlabel('Mutual Information')
    ax.set_ylabel('Gamma')
    ax.set_zlabel(hsic_method.upper())
    fig.colorbar(surf, ax=ax, label='Gamma')

    ax.view_init(30, 35)
    ax.set_title(title)
    plt.tight_layout()
    plt.show()
    fig.savefig(f"{save_path}trialv1_{function}_{hsic_method}_3d.png")
    return None
# Preview the 3D view of the linear case for each scorer in turn.
for scorer_name in ('hsic', 'tka', 'ctka'):
    plot_res_gamma_3D(results_df, function='line', hsic_method=scorer_name)
def plot_res_gamma(results_df, function='line', hsic_method='hsic'):
    """Draw and save a 2D scatter of the scorer value against MI, coloured by gamma.

    The rows of ``results_df`` whose ``'function'`` column equals ``function``
    and whose ``'scorer'`` column equals ``hsic_method`` are plotted with
    ``value`` on x and ``mi`` on y (log-scaled); points are coloured by
    ``gamma`` using a logarithmic colour normalisation. The figure is shown
    and then written to
    ``{cwd}/../../results/hsic/figures/trialv1_{function}_{hsic_method}.png``
    (``cwd`` is the module-level working directory).

    Parameters
    ----------
    results_df : pandas.DataFrame
        Results table; the columns ``'function'``, ``'scorer'``, ``'mi'``,
        ``'gamma'`` and ``'value'`` are read.
    function : str
        One of ``'line'``, ``'sine'``, ``'circ'``, ``'rand'``.
    hsic_method : str
        Value of the ``'scorer'`` column to plot; also used (upper-cased) as
        the x-axis label.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``function`` is not one of the recognised keys, or if
        ``results_df`` has no rows for it.
    """
    titles = {
        'line': 'Linear Function',
        'sine': 'Sine Function',
        'circ': 'Circle Function',
        'rand': 'Random Function',
    }
    # Validate first so a bad key fails before any figure is created.
    if function not in titles:
        raise ValueError(f'Unrecognized function: {function}')
    title = titles[function]

    save_path = f'{cwd}/../../results/hsic/figures/'

    function_rows = results_df[results_df['function'] == function]
    if function_rows.empty:
        # Without this guard the scatter would never be drawn and the
        # colorbar call below would have nothing to attach to.
        raise ValueError(f'No results found for function: {function}')
    scorer_rows = function_rows[function_rows['scorer'] == hsic_method]

    fig, ax = plt.subplots(nrows=1, figsize=(7, 3))

    pts = ax.scatter(
        x=scorer_rows['value'],
        y=scorer_rows['mi'],
        c=scorer_rows['gamma'],
        s=20,
        cmap='Spectral',
        norm=matplotlib.colors.LogNorm(),
    )
    ax.set_xlabel(hsic_method.upper())
    ax.set_ylabel('Mutual Information')
    fig.colorbar(pts, ax=ax, label='Gamma')

    ax.set_yscale('log')
    ax.set_title(title)
    plt.tight_layout()
    plt.show()
    fig.savefig(f"{save_path}trialv1_{function}_{hsic_method}.png")
    return None
Figure II - MI vs Gamma vs HSIC (Linear Function)¶
# Linear case: the 2D plots for every scorer first, then the 3D plots.
for scorer_name in ('hsic', 'tka', 'ctka'):
    plot_res_gamma(results_df, function='line', hsic_method=scorer_name)
for scorer_name in ('hsic', 'tka', 'ctka'):
    plot_res_gamma_3D(results_df, function='line', hsic_method=scorer_name)
Case II - Sine¶
# Sine case: the 2D plots for every scorer first, then the 3D plots.
for scorer_name in ('hsic', 'tka', 'ctka'):
    plot_res_gamma(results_df, function='sine', hsic_method=scorer_name)
for scorer_name in ('hsic', 'tka', 'ctka'):
    plot_res_gamma_3D(results_df, function='sine', hsic_method=scorer_name)
Case III - Circle¶
# Circle case: the 2D plots for every scorer first, then the 3D plots.
for scorer_name in ('hsic', 'tka', 'ctka'):
    plot_res_gamma(results_df, function='circ', hsic_method=scorer_name)
for scorer_name in ('hsic', 'tka', 'ctka'):
    plot_res_gamma_3D(results_df, function='circ', hsic_method=scorer_name)
Case IV - Random¶
# Random case: the 2D plots for every scorer first, then the 3D plots.
for scorer_name in ('hsic', 'tka', 'ctka'):
    plot_res_gamma(results_df, function='rand', hsic_method=scorer_name)
for scorer_name in ('hsic', 'tka', 'ctka'):
    plot_res_gamma_3D(results_df, function='rand', hsic_method=scorer_name)