In [1]:
import pandas as pd
import plotly.express as px
In [2]:
# Load the two upstream result tables:
#   mse_method_df         - MSE summary statistics per (model, config_key)
#   mse_method_subject_df - per-config results that carry the
#                           "is_lstm_better_than_var (w)" flag
#                           (presumably one row per subject — TODO confirm)
mse_method_df, mse_method_subject_df = map(
    pd.read_csv,
    ("mse_method_df.csv", "mse_method_subject_df.csv"),
)
In [3]:
# Decode config_key ("nsteps=<n>_clusters=<c>_subjects=<s>") into three integer
# columns, then drop the naive baseline so only the fitted models are compared.
parsed_config = (
    mse_method_df["config_key"]
    .str.extract(r"nsteps=(\d+)_clusters=(\d+)_subjects=(\d+)")
    .astype(int)
)
df = mse_method_df.assign(
    nsteps=parsed_config[0],
    clusters=parsed_config[1],
    subjects=parsed_config[2],
).loc[lambda frame: frame["model"] != "naive"]
df
Out[3]:
model mean median std min max config_key nsteps clusters subjects
0 VAR_grp_w/o 2.516169 2.337215 0.964860 0.759451 5.054440 nsteps=150_clusters=12_subjects=2 150 12 2
1 VAR_grp_w/ 2.341685 1.882104 1.891267 0.833683 17.527476 nsteps=150_clusters=12_subjects=2 150 12 2
2 VAR_ind 2.126768 1.789570 1.259586 0.874318 10.998972 nsteps=150_clusters=12_subjects=2 150 12 2
3 LSTM_grp_w/o 2.395542 2.273671 0.895071 0.850160 6.454545 nsteps=150_clusters=12_subjects=2 150 12 2
4 LSTM_grp_w/ 2.198785 2.048785 0.779989 1.118295 5.547787 nsteps=150_clusters=12_subjects=2 150 12 2
... ... ... ... ... ... ... ... ... ... ...
139 VAR_grp_w/ 1.512782 1.470542 0.254903 1.070859 2.302964 nsteps=1500_clusters=2_subjects=12 1500 2 12
140 VAR_ind 1.507700 1.472241 0.249388 1.076034 2.123898 nsteps=1500_clusters=2_subjects=12 1500 2 12
141 LSTM_grp_w/o 1.418524 1.377471 0.179482 1.032010 1.986018 nsteps=1500_clusters=2_subjects=12 1500 2 12
142 LSTM_grp_w/ 1.405119 1.380497 0.173763 1.054613 1.929285 nsteps=1500_clusters=2_subjects=12 1500 2 12
143 LSTM_ind 1.518143 1.471425 0.226905 1.055793 2.091106 nsteps=1500_clusters=2_subjects=12 1500 2 12

144 rows × 10 columns

In [4]:
# For each plotting view, average the per-configuration "mean" MSE within each
# model across every configuration dimension that is not part of the key.
# This is an unweighted mean of means over those configurations.
df_grouped_nsteps, df_grouped_clusters, df_grouped_facet = (
    df.groupby(["model", *extra_keys], as_index=False)["mean"].mean()
    for extra_keys in (["nsteps"], ["clusters"], ["nsteps", "clusters"])
)
In [5]:
# Colour encodes the model family (VAR vs. LSTM); dash style encodes the
# variant within a family (grp_w/o -> solid, grp_w/ -> dash, ind -> dot).
_var_models = ('VAR_grp_w/o', 'VAR_grp_w/', 'VAR_ind')
_lstm_models = ('LSTM_grp_w/o', 'LSTM_grp_w/', 'LSTM_ind')
_variant_dashes = ('solid', 'dash', 'dot')

color_discrete_map = {
    **dict.fromkeys(_var_models, '#D7263D'),   # vivid red
    **dict.fromkeys(_lstm_models, '#21A179'),  # nice green
}
line_dash_map = {
    **dict(zip(_var_models, _variant_dashes)),
    **dict(zip(_lstm_models, _variant_dashes)),
}

Mean MSE by nsteps for each model

In [6]:
# Mean MSE per model against nsteps, averaged over all configurations that
# share the same nsteps. Colour = model family, dash = variant.
# The y column is literally named "mean", so label it explicitly; otherwise the
# figure never says which metric is plotted.
fig = px.line(
    df_grouped_nsteps,
    x="nsteps",
    y="mean",
    color="model",
    line_dash="model",
    markers=True,
    color_discrete_map=color_discrete_map,
    line_dash_map=line_dash_map,
    labels={"mean": "Mean MSE"},
    title="Mean MSE vs. Nsteps by Model",
)
fig.update_traces(line=dict(width=1.5))
fig.show(renderer="notebook")

Mean MSE by clusters for each model

In [7]:
# Mean MSE per model against the number of clusters, averaged over all
# configurations that share the same cluster count.
# Label the "mean" column as MSE so the figure names its metric.
fig = px.line(
    df_grouped_clusters,
    x="clusters",
    y="mean",
    color="model",
    line_dash="model",
    markers=True,
    color_discrete_map=color_discrete_map,
    line_dash_map=line_dash_map,
    labels={"mean": "Mean MSE"},
    title="Mean MSE vs. Clusters by Model",
)
fig.update_traces(line=dict(width=1.5))
fig.show(renderer="notebook")

Mean MSE by clusters, faceted by nsteps

In [8]:
# Mean MSE per model against clusters, with one facet per nsteps value
# (wrapped three facets per row).
# Label the "mean" column as MSE so every facet's y-axis names the metric.
fig = px.line(
    df_grouped_facet,
    x="clusters",
    y="mean",
    color="model",
    line_dash="model",
    markers=True,
    facet_col="nsteps",
    facet_col_wrap=3,
    color_discrete_map=color_discrete_map,
    line_dash_map=line_dash_map,
    labels={"mean": "Mean MSE"},
    title="Mean MSE vs. Clusters by Model, for each nsteps",
)
fig.update_traces(line=dict(width=1.5))
fig.show(renderer="notebook")

Bad subject count analysis

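Select the subjects for which LSTM (w/) does not outperform VAR (w/), i.e. rows whose `is_lstm_better_than_var (w)` flag is False. Then count them per (nsteps, clusters) configuration.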

In [9]:
# Decode config_key into integer nsteps / clusters / subjects columns, using the
# same encoding as mse_method_df. This adds the columns to the loaded frame in place.
config_columns = ['nsteps', 'clusters', 'subjects']
decoded_config = mse_method_subject_df['config_key'].str.extract(r'nsteps=(\d+)_clusters=(\d+)_subjects=(\d+)')
mse_method_subject_df[config_columns] = decoded_config.astype(int)

# "Bad" rows are those where the flag is False, i.e. LSTM (w/) was NOT better
# than VAR (w/). This is broader than strictly "worse".
lstm_not_better = ~mse_method_subject_df["is_lstm_better_than_var (w)"]
bad_subjects = mse_method_subject_df.loc[lstm_not_better].copy()
In [10]:
# Count "bad" subjects per (nsteps, clusters) configuration. Configurations with
# no bad subjects are missing from the groupby result, so left-join onto the
# full configuration grid and report them as 0.
config_grid = mse_method_subject_df[['nsteps', 'clusters']].drop_duplicates()
bad_counts = (
    bad_subjects
    .groupby(['nsteps', 'clusters'])
    .size()
    .rename('bad_subject_count')
    .reset_index()
)
bad_subjects_count = (
    config_grid
    .merge(bad_counts, on=['nsteps', 'clusters'], how='left')
    .fillna({'bad_subject_count': 0})
    .astype({'bad_subject_count': int})
)
bad_subjects_count
Out[10]:
nsteps clusters bad_subject_count
0 150 12 17
1 150 8 22
2 150 4 15
3 150 2 10
4 300 12 18
5 300 8 16
6 300 4 16
7 300 2 10
8 600 12 14
9 600 8 12
10 600 4 6
11 600 2 2
12 900 12 14
13 900 8 6
14 900 4 4
15 900 2 0
16 1200 12 11
17 1200 8 4
18 1200 4 1
19 1200 2 0
20 1500 12 9
21 1500 8 5
22 1500 4 2
23 1500 2 0
In [11]:
# One line per nsteps value showing how the bad-subject count changes with the
# number of clusters.
fig = px.line(
    bad_subjects_count,
    x='clusters',
    y='bad_subject_count',
    color='nsteps',
    markers=True,
    title='Bad Subject Count vs. Clusters by Nsteps',
)
fig.show(renderer="notebook")
In [12]:
# One line per cluster count showing how the bad-subject count changes with
# nsteps. The title was copy-pasted from the previous cell ("vs. Clusters by
# Nsteps"), but here x is nsteps and colour is clusters, so the title is fixed.
fig = px.line(
    bad_subjects_count,
    x='nsteps',
    y='bad_subject_count',
    color='clusters',
    markers=True,
    title='Bad Subject Count vs. Nsteps by Clusters',
)
fig.show(renderer="notebook")
In [ ]: