InĀ [1]:
# Import required libraries
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from scipy import stats
import os
from pathlib import Path
# Set plotly as default renderer
import plotly.io as pio
pio.templates.default = "plotly_white"
MSE Comparison: VAR vs LSTM Models¶
This notebook compares the Mean Squared Error (MSE) performance between different models on the Fisher dataset. Currently comparing:
- VAR (Vector Autoregression) model
- LSTM (Long Short-Term Memory) model
The notebook is designed to be flexible for adding more models in the future.
InĀ [2]:
# Define the paths to the MSE CSV files
BASE_PATH = Path("/home/ytli/research/lstm/study5-realdata/fisher")
# Define model paths in a dictionary for easy expansion later
MODEL_PATHS = {
    'VAR': BASE_PATH / 'var' / 'mse.csv',
    'LSTM': BASE_PATH / 'lstm' / 'mse.csv',
    'MLVAR': BASE_PATH / 'mlvar' / 'mse.csv',
    'GIMME': BASE_PATH / 'gimme' / 'mse.csv',
    # Future models can be added here, e.g.:
    # 'GRU': BASE_PATH / 'gru' / 'mse.csv',
}
# Function to load and preprocess MSE data
def load_mse_data(model_paths):
    """Load MSE data from multiple models and combine into a single DataFrame."""
    all_data = []
    
    for model_name, file_path in model_paths.items():
        if file_path.exists():
            # Load the data
            df = pd.read_csv(file_path)
            
            # Add model column
            df['model'] = model_name
            
            # Append to list
            all_data.append(df)
        else:
            print(f"Warning: {file_path} does not exist.")
    
    # Combine all data
    if all_data:
        combined_df = pd.concat(all_data, ignore_index=True)
        return combined_df
    else:
        raise ValueError("No valid MSE data files found.")
# Load the data
mse_data = load_mse_data(MODEL_PATHS)
# Display basic info
print(f"Loaded MSE data for {mse_data['model'].nunique()} models.")
print(f"Total subjects: {mse_data['subject_id'].nunique()}")
print(f"Features available: {mse_data['feature_name'].nunique()}")
# Display the first few rows
mse_data.head()
Loaded MSE data for 4 models. Total subjects: 40 Features available: 9
Out[2]:
| subject_id | feature_name | mse | model | |
|---|---|---|---|---|
| 0 | p111 | down | 0.259425 | VAR | 
| 1 | p111 | positive | 0.130215 | VAR | 
| 2 | p111 | content | 0.470464 | VAR | 
| 3 | p111 | enthusiastic | 0.265476 | VAR | 
| 4 | p111 | energetic | 0.349234 | VAR | 
InĀ [3]:
# Basic statistical analysis
# Overall MSE statistics by model
model_stats = mse_data.groupby('model')['mse'].agg(['mean', 'median', 'std', 'min', 'max']).reset_index()
print("Overall MSE statistics by model:")
display(model_stats)
# Feature-wise MSE comparison
feature_stats = mse_data.groupby(['model', 'feature_name'])['mse'].agg(['mean', 'median', 'std']).reset_index()
# Sort by mean MSE to see which features have highest error
feature_stats_sorted = feature_stats.sort_values(by=['feature_name', 'mean'])
display(feature_stats_sorted)
Overall MSE statistics by model:
| model | mean | median | std | min | max | |
|---|---|---|---|---|---|---|
| 0 | GIMME | 1.269682 | 0.685026 | 2.337437 | 0.000355 | 29.193570 | 
| 1 | LSTM | 1.247971 | 0.692644 | 2.404799 | 0.018620 | 29.134912 | 
| 2 | MLVAR | 1.230239 | 0.694659 | 2.231106 | 0.039559 | 28.002925 | 
| 3 | VAR | 1.401815 | 0.787430 | 2.359657 | 0.037601 | 28.096260 | 
| model | feature_name | mean | median | std | |
|---|---|---|---|---|---|
| 9 | LSTM | angry | 1.192598 | 0.666441 | 1.571567 | 
| 18 | MLVAR | angry | 1.296812 | 0.675638 | 1.929048 | 
| 0 | GIMME | angry | 1.386666 | 0.604615 | 2.018156 | 
| 27 | VAR | angry | 1.526484 | 0.794921 | 2.076638 | 
| 19 | MLVAR | content | 1.143187 | 0.872580 | 1.144393 | 
| 10 | LSTM | content | 1.160596 | 0.852058 | 1.320677 | 
| 1 | GIMME | content | 1.190719 | 0.944220 | 1.313963 | 
| 28 | VAR | content | 1.290244 | 0.960834 | 1.484529 | 
| 20 | MLVAR | down | 1.095771 | 0.564518 | 1.996695 | 
| 2 | GIMME | down | 1.170654 | 0.646352 | 2.269943 | 
| 29 | VAR | down | 1.290572 | 0.650436 | 2.388121 | 
| 11 | LSTM | down | 1.323277 | 0.654343 | 2.994909 | 
| 12 | LSTM | energetic | 1.817534 | 0.822456 | 4.528124 | 
| 21 | MLVAR | energetic | 1.821117 | 0.763334 | 4.349153 | 
| 3 | GIMME | energetic | 1.849814 | 0.816260 | 4.532235 | 
| 30 | VAR | energetic | 1.965851 | 0.792798 | 4.373482 | 
| 13 | LSTM | enthusiastic | 1.255544 | 0.660126 | 1.614619 | 
| 22 | MLVAR | enthusiastic | 1.259551 | 0.697563 | 1.751597 | 
| 4 | GIMME | enthusiastic | 1.299732 | 0.758397 | 1.791810 | 
| 31 | VAR | enthusiastic | 1.470635 | 0.809426 | 2.185973 | 
| 23 | MLVAR | hopeless | 0.967087 | 0.473701 | 1.159483 | 
| 14 | LSTM | hopeless | 0.967758 | 0.571769 | 1.113841 | 
| 5 | GIMME | hopeless | 0.995098 | 0.458697 | 1.231998 | 
| 32 | VAR | hopeless | 1.116391 | 0.624417 | 1.262774 | 
| 15 | LSTM | irritable | 0.932096 | 0.669937 | 0.754931 | 
| 24 | MLVAR | irritable | 0.949391 | 0.645403 | 0.903208 | 
| 6 | GIMME | irritable | 1.004854 | 0.653048 | 0.996148 | 
| 33 | VAR | irritable | 1.132890 | 0.825173 | 0.961552 | 
| 25 | MLVAR | positive | 1.241801 | 0.608365 | 2.059933 | 
| 7 | GIMME | positive | 1.263742 | 0.616985 | 1.998281 | 
| 16 | LSTM | positive | 1.272362 | 0.586639 | 2.192872 | 
| 34 | VAR | positive | 1.333475 | 0.736462 | 2.099234 | 
| 8 | GIMME | reassure | 1.265857 | 0.690035 | 2.925788 | 
| 26 | MLVAR | reassure | 1.297434 | 0.805556 | 2.805472 | 
| 17 | LSTM | reassure | 1.309972 | 0.738551 | 3.127807 | 
| 35 | VAR | reassure | 1.489792 | 0.896860 | 2.767637 | 
InĀ [4]:
# Visualization 1: Overall model comparison (boxplot)
# Create boxplot using Plotly
fig = px.box(mse_data, x='model', y='mse', points='all', 
             title='Overall MSE Comparison Between Models',
             labels={'mse': 'Mean Squared Error (MSE)', 'model': 'Model'},
             color='model',
             height=600, width=900)
# Customize layout
fig.update_layout(
    title_font_size=20,
    xaxis_title_font_size=14,
    yaxis_title_font_size=14,
    legend_title_font_size=14,
    template='plotly_white',
    boxmode='group'
)
# Show the plot
fig.show(renderer="notebook")
# Perform paired t-test to check if differences are statistically significant
# This assumes same subjects and features across models
models = mse_data['model'].unique()
if len(models) > 1:
    print("\nStatistical tests between models:")
    for i in range(len(models)):
        for j in range(i+1, len(models)):
            model1, model2 = models[i], models[j]
            
            # Create a pivot table for paired comparison
            pivot_df = mse_data.pivot_table(
                index=['subject_id', 'feature_name'], 
                columns='model', 
                values='mse'
            ).reset_index()
            
            # Check if both models exist in the pivot
            if model1 in pivot_df.columns and model2 in pivot_df.columns:
                # Remove rows with NaN values
                valid_rows = pivot_df[[model1, model2]].dropna()
                
                if len(valid_rows) > 0:
                    # Perform paired t-test
                    t_stat, p_value = stats.ttest_rel(valid_rows[model1], valid_rows[model2])
                    print(f"\n{model1} vs {model2}:")
                    print(f"Paired t-test: t={t_stat:.4f}, p={p_value:.4f}")
                    print(f"Mean difference: {valid_rows[model1].mean() - valid_rows[model2].mean():.4f}")
                    
                    if p_value < 0.05:
                        if valid_rows[model1].mean() < valid_rows[model2].mean():
                            print(f"Conclusion: {model1} performs significantly better than {model2}")
                        else:
                            print(f"Conclusion: {model2} performs significantly better than {model1}")
                    else:
                        print(f"Conclusion: No significant difference between {model1} and {model2}")
            else:
                print(f"\nCannot compare {model1} vs {model2}: missing data")
Statistical tests between models: VAR vs LSTM: Paired t-test: t=4.9267, p=0.0000 Mean difference: 0.1538 Conclusion: LSTM performs significantly better than VAR VAR vs MLVAR: Paired t-test: t=8.1430, p=0.0000 Mean difference: 0.1716 Conclusion: MLVAR performs significantly better than VAR VAR vs GIMME: Paired t-test: t=6.2769, p=0.0000 Mean difference: 0.1321 Conclusion: GIMME performs significantly better than VAR LSTM vs MLVAR: Paired t-test: t=0.6572, p=0.5115 Mean difference: 0.0177 Conclusion: No significant difference between LSTM and MLVAR LSTM vs GIMME: Paired t-test: t=-0.8774, p=0.3808 Mean difference: -0.0217 Conclusion: No significant difference between LSTM and GIMME MLVAR vs GIMME: Paired t-test: t=-2.6284, p=0.0089 Mean difference: -0.0394 Conclusion: MLVAR performs significantly better than GIMME
InĀ [Ā ]:
# Visualization 2: Feature-wise comparison
# Calculate mean MSE by feature and model
feature_means = mse_data.groupby(['feature_name', 'model'])['mse'].mean().reset_index()
# Create bar chart with Plotly
fig = px.bar(feature_means, x='feature_name', y='mse', color='model', barmode='group',
             title='Average MSE by Feature Across Models',
             labels={'mse': 'Mean Squared Error (MSE)', 'feature_name': 'Feature', 'model': 'Model'},
             height=600, width=1000)
# Customize layout
fig.update_layout(
    title_font_size=20,
    xaxis_title_font_size=14,
    yaxis_title_font_size=14,
    legend_title_font_size=14,
    xaxis_tickangle=-45,
    template='plotly_white'
)
# Add hover data for better interaction
fig.update_traces(hovertemplate='Feature: %{x}<br>MSE: %{y:.4f}<br>Model: %{legendgroup}')
# Show the plot
fig.show(renderer="notebook")
InĀ [Ā ]:
# Visualization 3: Subject-wise comparison
# Calculate average MSE per subject across all features
subject_means = mse_data.groupby(['subject_id', 'model'])['mse'].mean().reset_index()
# Find subjects with highest average MSE (potential outliers)
top_subjects = subject_means.groupby('subject_id')['mse'].mean().nlargest(5).index.tolist()
print(f"Subjects with highest average MSE: {', '.join(top_subjects)}")
# Create bar chart for all subjects
fig = px.bar(subject_means, x='subject_id', y='mse', color='model', barmode='group',
             title='Average MSE by Subject Across Models',
             labels={'mse': 'Mean Squared Error (MSE)', 'subject_id': 'Subject ID', 'model': 'Model'},
             height=600, width=1100)
# Customize layout
fig.update_layout(
    title_font_size=20,
    xaxis_title_font_size=14,
    yaxis_title_font_size=14,
    legend_title_font_size=14,
    xaxis_tickangle=-90,
    template='plotly_white'
)
# Add hover data
fig.update_traces(hovertemplate='Subject: %{x}<br>MSE: %{y:.4f}<br>Model: %{legendgroup}')
# Show the plot
fig.show(renderer="notebook")
# If too many subjects, show only top 20 with highest average MSE
if subject_means['subject_id'].nunique() > 20:
    top_subjects = subject_means.groupby('subject_id')['mse'].mean().nlargest(20).index.tolist()
    top_subjects_data = subject_means[subject_means['subject_id'].isin(top_subjects)]
    
    # Create bar chart for top 20 subjects
    fig2 = px.bar(top_subjects_data, x='subject_id', y='mse', color='model', barmode='group',
                 title='Top 20 Subjects with Highest Average MSE',
                 labels={'mse': 'Mean Squared Error (MSE)', 'subject_id': 'Subject ID', 'model': 'Model'},
                 height=600, width=1100)
    
    # Customize layout
    fig2.update_layout(
        title_font_size=20,
        xaxis_title_font_size=14,
        yaxis_title_font_size=14,
        legend_title_font_size=14,
        xaxis_tickangle=-90,
        template='plotly_white'
    )
    
    # Add hover data
    fig2.update_traces(hovertemplate='Subject: %{x}<br>MSE: %{y:.4f}<br>Model: %{legendgroup}')
    
    # Show the plot
    fig2.show()
Subjects with highest average MSE: p013, p204, p160, p202, p139
InĀ [7]:
# Visualization 4: Correlation between model performances
# This shows if models tend to perform similarly on the same subjects/features
# Create wide-format data with one column per model
pivot_data = mse_data.pivot_table(
    index=['subject_id', 'feature_name'],
    columns='model',
    values='mse'
).reset_index()
# Calculate correlation between model performances
model_cols = [col for col in pivot_data.columns if col not in ['subject_id', 'feature_name']]
if len(model_cols) > 1:
    correlation = pivot_data[model_cols].corr()
    
    # Create heatmap with Plotly
    fig = go.Figure(data=go.Heatmap(
        z=correlation.values,
        x=correlation.columns,
        y=correlation.index,
        colorscale='RdBu_r',
        zmin=-1, zmax=1,
        text=correlation.values.round(3),
        texttemplate='%{text}',
        colorbar=dict(title='Correlation')
    ))
    
    # Update layout
    fig.update_layout(
        title='Correlation Between Model Performances',
        title_font_size=20,
        height=600, width=700,
        template='plotly_white'
    )
    
    # Show the plot
    fig.show(renderer="notebook")
    
    # Scatterplot of model performances (only for 2 models)
    if len(model_cols) == 2:  
        # Create the scatter plot with Plotly
        fig2 = px.scatter(pivot_data, x=model_cols[0], y=model_cols[1],
                         hover_data=['subject_id', 'feature_name'],
                         title=f'Comparison of {model_cols[0]} vs {model_cols[1]} MSE',
                         labels={model_cols[0]: f'{model_cols[0]} MSE', 
                                model_cols[1]: f'{model_cols[1]} MSE'},
                         height=700, width=700)
        
        # Add diagonal line (y=x)
        min_val = min(pivot_data[model_cols].min().min(), 0)
        max_val = pivot_data[model_cols].max().max() * 1.1
        
        fig2.add_trace(go.Scatter(
            x=[min_val, max_val],
            y=[min_val, max_val],
            mode='lines',
            line=dict(color='black', width=2, dash='dash'),
            name='y=x line'
        ))
        
        # Update layout
        fig2.update_layout(
            title_font_size=20,
            xaxis_title_font_size=14,
            yaxis_title_font_size=14,
            legend_title_font_size=14,
            template='plotly_white'
        )
        
        # Show the plot
        fig2.show()
else:
    print("Need at least two models for correlation analysis.")
InĀ [Ā ]:
# Visualization 5: Detailed feature analysis
# Compare each feature across models using boxplots
# Get list of unique features
features = mse_data['feature_name'].unique()
n_features = len(features)
# Calculate number of rows and columns for subplots
n_cols = 3  # Number of columns in the grid
n_rows = (n_features + n_cols - 1) // n_cols  # Calculate needed rows
# Create subplots with Plotly
fig = make_subplots(
    rows=n_rows, cols=n_cols,
    subplot_titles=[f'Feature: {feature}' for feature in features],
    vertical_spacing=0.1
)
# Add a boxplot for each feature
for i, feature in enumerate(features):
    # Calculate row and column for current subplot
    row = i // n_cols + 1
    col = i % n_cols + 1
    
    # Filter data for current feature
    feature_data = mse_data[mse_data['feature_name'] == feature]
    
    # For each model, add a box trace
    for model in feature_data['model'].unique():
        model_data = feature_data[feature_data['model'] == model]
        
        # Add box plot
        fig.add_trace(
            go.Box(
                y=model_data['mse'],
                name=model,
                boxpoints='all',  # Show all points
                jitter=0.3,  # Add jitter to points for better visualization
                pointpos=-1.8,  # Position of points relative to box
                marker=dict(size=3),  # Make points smaller
                showlegend=(i == 0),  # Only show legend for the first feature
                hovertemplate=f'Model: {model}<br>MSE: %{{y:.4f}}'
            ),
            row=row, col=col
        )
# Update layout
fig.update_layout(
    title_text='Feature-wise MSE Comparison Across Models',
    title_font_size=20,
    showlegend=True,
    legend_title_text='Model',
    height=300 * n_rows,  # Adjust height based on number of rows
    width=1000,
    template='plotly_white'
)
# Update y-axes
fig.update_yaxes(title_text='MSE')
# Show the figure
fig.show(renderer="notebook")
InĀ [9]:
# Function to add a new model to the comparison
def add_new_model(model_name, file_path):
    """Add a new model to the comparison.
    
    Args:
        model_name (str): Name of the new model
        file_path (str or Path): Path to the model's MSE CSV file
    
    Returns:
        DataFrame: Updated combined MSE data
    """
    global MODEL_PATHS, mse_data
    
    # Add to model paths
    MODEL_PATHS[model_name] = Path(file_path)
    
    # Reload all data
    try:
        mse_data = load_mse_data(MODEL_PATHS)
        print(f"Successfully added {model_name} to the comparison.")
        print(f"Now comparing {mse_data['model'].nunique()} models.")
        return mse_data
    except Exception as e:
        print(f"Error adding new model: {e}")
        return None
# Example usage:
# new_mse_data = add_new_model('GRU', '/home/ytli/research/lstm/study5-realdata/fisher/gru/mse.csv')