InĀ [1]:
# Import required libraries
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from scipy import stats
import os
from pathlib import Path

# Set plotly as default renderer
import plotly.io as pio
pio.templates.default = "plotly_white"

MSE Comparison: VAR vs LSTM Models¶

This notebook compares the Mean Squared Error (MSE) performance between different models on the Fisher dataset. Currently comparing:

  • VAR (Vector Autoregression) model
  • LSTM (Long Short-Term Memory) model

The notebook is designed to be flexible for adding more models in the future.

InĀ [2]:
# Define the paths to the MSE CSV files
BASE_PATH = Path("/home/ytli/research/lstm/study5-realdata/fisher")

# Define model paths in a dictionary for easy expansion later
MODEL_PATHS = {
    'VAR': BASE_PATH / 'var' / 'mse.csv',
    'LSTM': BASE_PATH / 'lstm' / 'mse.csv',
    'MLVAR': BASE_PATH / 'mlvar' / 'mse.csv',
    'GIMME': BASE_PATH / 'gimme' / 'mse.csv',
    # Future models can be added here, e.g.:
    # 'GRU': BASE_PATH / 'gru' / 'mse.csv',
}

# Function to load and preprocess MSE data
def load_mse_data(model_paths):
    """Load MSE data from multiple models and combine into a single DataFrame."""
    all_data = []
    
    for model_name, file_path in model_paths.items():
        if file_path.exists():
            # Load the data
            df = pd.read_csv(file_path)
            
            # Add model column
            df['model'] = model_name
            
            # Append to list
            all_data.append(df)
        else:
            print(f"Warning: {file_path} does not exist.")
    
    # Combine all data
    if all_data:
        combined_df = pd.concat(all_data, ignore_index=True)
        return combined_df
    else:
        raise ValueError("No valid MSE data files found.")

# Load the data
mse_data = load_mse_data(MODEL_PATHS)

# Display basic info
print(f"Loaded MSE data for {mse_data['model'].nunique()} models.")
print(f"Total subjects: {mse_data['subject_id'].nunique()}")
print(f"Features available: {mse_data['feature_name'].nunique()}")

# Display the first few rows
mse_data.head()
Loaded MSE data for 4 models.
Total subjects: 40
Features available: 9
Out[2]:
subject_id feature_name mse model
0 p111 down 0.259425 VAR
1 p111 positive 0.130215 VAR
2 p111 content 0.470464 VAR
3 p111 enthusiastic 0.265476 VAR
4 p111 energetic 0.349234 VAR
InĀ [3]:
# Basic statistical analysis

# Overall MSE statistics by model
model_stats = mse_data.groupby('model')['mse'].agg(['mean', 'median', 'std', 'min', 'max']).reset_index()
print("Overall MSE statistics by model:")
display(model_stats)

# Feature-wise MSE comparison
feature_stats = mse_data.groupby(['model', 'feature_name'])['mse'].agg(['mean', 'median', 'std']).reset_index()

# Sort by mean MSE to see which features have highest error
feature_stats_sorted = feature_stats.sort_values(by=['feature_name', 'mean'])
display(feature_stats_sorted)
Overall MSE statistics by model:
model mean median std min max
0 GIMME 1.269682 0.685026 2.337437 0.000355 29.193570
1 LSTM 1.247971 0.692644 2.404799 0.018620 29.134912
2 MLVAR 1.230239 0.694659 2.231106 0.039559 28.002925
3 VAR 1.401815 0.787430 2.359657 0.037601 28.096260
model feature_name mean median std
9 LSTM angry 1.192598 0.666441 1.571567
18 MLVAR angry 1.296812 0.675638 1.929048
0 GIMME angry 1.386666 0.604615 2.018156
27 VAR angry 1.526484 0.794921 2.076638
19 MLVAR content 1.143187 0.872580 1.144393
10 LSTM content 1.160596 0.852058 1.320677
1 GIMME content 1.190719 0.944220 1.313963
28 VAR content 1.290244 0.960834 1.484529
20 MLVAR down 1.095771 0.564518 1.996695
2 GIMME down 1.170654 0.646352 2.269943
29 VAR down 1.290572 0.650436 2.388121
11 LSTM down 1.323277 0.654343 2.994909
12 LSTM energetic 1.817534 0.822456 4.528124
21 MLVAR energetic 1.821117 0.763334 4.349153
3 GIMME energetic 1.849814 0.816260 4.532235
30 VAR energetic 1.965851 0.792798 4.373482
13 LSTM enthusiastic 1.255544 0.660126 1.614619
22 MLVAR enthusiastic 1.259551 0.697563 1.751597
4 GIMME enthusiastic 1.299732 0.758397 1.791810
31 VAR enthusiastic 1.470635 0.809426 2.185973
23 MLVAR hopeless 0.967087 0.473701 1.159483
14 LSTM hopeless 0.967758 0.571769 1.113841
5 GIMME hopeless 0.995098 0.458697 1.231998
32 VAR hopeless 1.116391 0.624417 1.262774
15 LSTM irritable 0.932096 0.669937 0.754931
24 MLVAR irritable 0.949391 0.645403 0.903208
6 GIMME irritable 1.004854 0.653048 0.996148
33 VAR irritable 1.132890 0.825173 0.961552
25 MLVAR positive 1.241801 0.608365 2.059933
7 GIMME positive 1.263742 0.616985 1.998281
16 LSTM positive 1.272362 0.586639 2.192872
34 VAR positive 1.333475 0.736462 2.099234
8 GIMME reassure 1.265857 0.690035 2.925788
26 MLVAR reassure 1.297434 0.805556 2.805472
17 LSTM reassure 1.309972 0.738551 3.127807
35 VAR reassure 1.489792 0.896860 2.767637
InĀ [4]:
# Visualization 1: Overall model comparison (boxplot)
# Create boxplot using Plotly
fig = px.box(mse_data, x='model', y='mse', points='all', 
             title='Overall MSE Comparison Between Models',
             labels={'mse': 'Mean Squared Error (MSE)', 'model': 'Model'},
             color='model',
             height=600, width=900)

# Customize layout
fig.update_layout(
    title_font_size=20,
    xaxis_title_font_size=14,
    yaxis_title_font_size=14,
    legend_title_font_size=14,
    template='plotly_white',
    boxmode='group'
)

# Show the plot
fig.show(renderer="notebook")

# Perform paired t-test to check if differences are statistically significant
# This assumes same subjects and features across models
models = mse_data['model'].unique()
if len(models) > 1:
    print("\nStatistical tests between models:")
    for i in range(len(models)):
        for j in range(i+1, len(models)):
            model1, model2 = models[i], models[j]
            
            # Create a pivot table for paired comparison
            pivot_df = mse_data.pivot_table(
                index=['subject_id', 'feature_name'], 
                columns='model', 
                values='mse'
            ).reset_index()
            
            # Check if both models exist in the pivot
            if model1 in pivot_df.columns and model2 in pivot_df.columns:
                # Remove rows with NaN values
                valid_rows = pivot_df[[model1, model2]].dropna()
                
                if len(valid_rows) > 0:
                    # Perform paired t-test
                    t_stat, p_value = stats.ttest_rel(valid_rows[model1], valid_rows[model2])
                    print(f"\n{model1} vs {model2}:")
                    print(f"Paired t-test: t={t_stat:.4f}, p={p_value:.4f}")
                    print(f"Mean difference: {valid_rows[model1].mean() - valid_rows[model2].mean():.4f}")
                    
                    if p_value < 0.05:
                        if valid_rows[model1].mean() < valid_rows[model2].mean():
                            print(f"Conclusion: {model1} performs significantly better than {model2}")
                        else:
                            print(f"Conclusion: {model2} performs significantly better than {model1}")
                    else:
                        print(f"Conclusion: No significant difference between {model1} and {model2}")
            else:
                print(f"\nCannot compare {model1} vs {model2}: missing data")
Statistical tests between models:

VAR vs LSTM:
Paired t-test: t=4.9267, p=0.0000
Mean difference: 0.1538
Conclusion: LSTM performs significantly better than VAR

VAR vs MLVAR:
Paired t-test: t=8.1430, p=0.0000
Mean difference: 0.1716
Conclusion: MLVAR performs significantly better than VAR

VAR vs GIMME:
Paired t-test: t=6.2769, p=0.0000
Mean difference: 0.1321
Conclusion: GIMME performs significantly better than VAR

LSTM vs MLVAR:
Paired t-test: t=0.6572, p=0.5115
Mean difference: 0.0177
Conclusion: No significant difference between LSTM and MLVAR

LSTM vs GIMME:
Paired t-test: t=-0.8774, p=0.3808
Mean difference: -0.0217
Conclusion: No significant difference between LSTM and GIMME

MLVAR vs GIMME:
Paired t-test: t=-2.6284, p=0.0089
Mean difference: -0.0394
Conclusion: MLVAR performs significantly better than GIMME
InĀ [Ā ]:
# Visualization 2: Feature-wise comparison

# Calculate mean MSE by feature and model
feature_means = mse_data.groupby(['feature_name', 'model'])['mse'].mean().reset_index()

# Create bar chart with Plotly
fig = px.bar(feature_means, x='feature_name', y='mse', color='model', barmode='group',
             title='Average MSE by Feature Across Models',
             labels={'mse': 'Mean Squared Error (MSE)', 'feature_name': 'Feature', 'model': 'Model'},
             height=600, width=1000)

# Customize layout
fig.update_layout(
    title_font_size=20,
    xaxis_title_font_size=14,
    yaxis_title_font_size=14,
    legend_title_font_size=14,
    xaxis_tickangle=-45,
    template='plotly_white'
)

# Add hover data for better interaction
fig.update_traces(hovertemplate='Feature: %{x}<br>MSE: %{y:.4f}<br>Model: %{legendgroup}')

# Show the plot
fig.show(renderer="notebook")
InĀ [Ā ]:
# Visualization 3: Subject-wise comparison
# Calculate average MSE per subject across all features
subject_means = mse_data.groupby(['subject_id', 'model'])['mse'].mean().reset_index()

# Find subjects with highest average MSE (potential outliers)
top_subjects = subject_means.groupby('subject_id')['mse'].mean().nlargest(5).index.tolist()
print(f"Subjects with highest average MSE: {', '.join(top_subjects)}")

# Create bar chart for all subjects
fig = px.bar(subject_means, x='subject_id', y='mse', color='model', barmode='group',
             title='Average MSE by Subject Across Models',
             labels={'mse': 'Mean Squared Error (MSE)', 'subject_id': 'Subject ID', 'model': 'Model'},
             height=600, width=1100)

# Customize layout
fig.update_layout(
    title_font_size=20,
    xaxis_title_font_size=14,
    yaxis_title_font_size=14,
    legend_title_font_size=14,
    xaxis_tickangle=-90,
    template='plotly_white'
)

# Add hover data
fig.update_traces(hovertemplate='Subject: %{x}<br>MSE: %{y:.4f}<br>Model: %{legendgroup}')

# Show the plot
fig.show(renderer="notebook")

# If too many subjects, show only top 20 with highest average MSE
if subject_means['subject_id'].nunique() > 20:
    top_subjects = subject_means.groupby('subject_id')['mse'].mean().nlargest(20).index.tolist()
    top_subjects_data = subject_means[subject_means['subject_id'].isin(top_subjects)]
    
    # Create bar chart for top 20 subjects
    fig2 = px.bar(top_subjects_data, x='subject_id', y='mse', color='model', barmode='group',
                 title='Top 20 Subjects with Highest Average MSE',
                 labels={'mse': 'Mean Squared Error (MSE)', 'subject_id': 'Subject ID', 'model': 'Model'},
                 height=600, width=1100)
    
    # Customize layout
    fig2.update_layout(
        title_font_size=20,
        xaxis_title_font_size=14,
        yaxis_title_font_size=14,
        legend_title_font_size=14,
        xaxis_tickangle=-90,
        template='plotly_white'
    )
    
    # Add hover data
    fig2.update_traces(hovertemplate='Subject: %{x}<br>MSE: %{y:.4f}<br>Model: %{legendgroup}')
    
    # Show the plot
    fig2.show()
Subjects with highest average MSE: p013, p204, p160, p202, p139
InĀ [7]:
# Visualization 4: Correlation between model performances
# This shows if models tend to perform similarly on the same subjects/features

# Create wide-format data with one column per model
pivot_data = mse_data.pivot_table(
    index=['subject_id', 'feature_name'],
    columns='model',
    values='mse'
).reset_index()

# Calculate correlation between model performances
model_cols = [col for col in pivot_data.columns if col not in ['subject_id', 'feature_name']]
if len(model_cols) > 1:
    correlation = pivot_data[model_cols].corr()
    
    # Create heatmap with Plotly
    fig = go.Figure(data=go.Heatmap(
        z=correlation.values,
        x=correlation.columns,
        y=correlation.index,
        colorscale='RdBu_r',
        zmin=-1, zmax=1,
        text=correlation.values.round(3),
        texttemplate='%{text}',
        colorbar=dict(title='Correlation')
    ))
    
    # Update layout
    fig.update_layout(
        title='Correlation Between Model Performances',
        title_font_size=20,
        height=600, width=700,
        template='plotly_white'
    )
    
    # Show the plot
    fig.show(renderer="notebook")
    
    # Scatterplot of model performances (only for 2 models)
    if len(model_cols) == 2:  
        # Create the scatter plot with Plotly
        fig2 = px.scatter(pivot_data, x=model_cols[0], y=model_cols[1],
                         hover_data=['subject_id', 'feature_name'],
                         title=f'Comparison of {model_cols[0]} vs {model_cols[1]} MSE',
                         labels={model_cols[0]: f'{model_cols[0]} MSE', 
                                model_cols[1]: f'{model_cols[1]} MSE'},
                         height=700, width=700)
        
        # Add diagonal line (y=x)
        min_val = min(pivot_data[model_cols].min().min(), 0)
        max_val = pivot_data[model_cols].max().max() * 1.1
        
        fig2.add_trace(go.Scatter(
            x=[min_val, max_val],
            y=[min_val, max_val],
            mode='lines',
            line=dict(color='black', width=2, dash='dash'),
            name='y=x line'
        ))
        
        # Update layout
        fig2.update_layout(
            title_font_size=20,
            xaxis_title_font_size=14,
            yaxis_title_font_size=14,
            legend_title_font_size=14,
            template='plotly_white'
        )
        
        # Show the plot
        fig2.show()
else:
    print("Need at least two models for correlation analysis.")
InĀ [Ā ]:
# Visualization 5: Detailed feature analysis
# Compare each feature across models using boxplots

# Get list of unique features
features = mse_data['feature_name'].unique()
n_features = len(features)

# Calculate number of rows and columns for subplots
n_cols = 3  # Number of columns in the grid
n_rows = (n_features + n_cols - 1) // n_cols  # Calculate needed rows

# Create subplots with Plotly
fig = make_subplots(
    rows=n_rows, cols=n_cols,
    subplot_titles=[f'Feature: {feature}' for feature in features],
    vertical_spacing=0.1
)

# Add a boxplot for each feature
for i, feature in enumerate(features):
    # Calculate row and column for current subplot
    row = i // n_cols + 1
    col = i % n_cols + 1
    
    # Filter data for current feature
    feature_data = mse_data[mse_data['feature_name'] == feature]
    
    # For each model, add a box trace
    for model in feature_data['model'].unique():
        model_data = feature_data[feature_data['model'] == model]
        
        # Add box plot
        fig.add_trace(
            go.Box(
                y=model_data['mse'],
                name=model,
                boxpoints='all',  # Show all points
                jitter=0.3,  # Add jitter to points for better visualization
                pointpos=-1.8,  # Position of points relative to box
                marker=dict(size=3),  # Make points smaller
                showlegend=(i == 0),  # Only show legend for the first feature
                hovertemplate=f'Model: {model}<br>MSE: %{{y:.4f}}'
            ),
            row=row, col=col
        )

# Update layout
fig.update_layout(
    title_text='Feature-wise MSE Comparison Across Models',
    title_font_size=20,
    showlegend=True,
    legend_title_text='Model',
    height=300 * n_rows,  # Adjust height based on number of rows
    width=1000,
    template='plotly_white'
)

# Update y-axes
fig.update_yaxes(title_text='MSE')

# Show the figure
fig.show(renderer="notebook")
InĀ [9]:
# Function to add a new model to the comparison
def add_new_model(model_name, file_path):
    """Add a new model to the comparison.
    
    Args:
        model_name (str): Name of the new model
        file_path (str or Path): Path to the model's MSE CSV file
    
    Returns:
        DataFrame: Updated combined MSE data
    """
    global MODEL_PATHS, mse_data
    
    # Add to model paths
    MODEL_PATHS[model_name] = Path(file_path)
    
    # Reload all data
    try:
        mse_data = load_mse_data(MODEL_PATHS)
        print(f"Successfully added {model_name} to the comparison.")
        print(f"Now comparing {mse_data['model'].nunique()} models.")
        return mse_data
    except Exception as e:
        print(f"Error adding new model: {e}")
        return None

# Example usage:
# new_mse_data = add_new_model('GRU', '/home/ytli/research/lstm/study5-realdata/fisher/gru/mse.csv')