In [1]:
import pandas as pd
import plotly.express as px
In [8]:
mse_method_df = pd.read_csv("mse_method_df.csv")
mse_method_subject_df = pd.read_csv("mse_method_subject_df.csv")
In [9]:
# Decode config_key column into nsteps, clusters, and subjects
df = mse_method_df.copy()
df[["nsteps", "clusters", "subjects", "seed"]] = df["config_key"].str.extract(r"nsteps=(\d+)_clusters=(\d+)_subjects=(\d+)_seed=(\d+)").astype(int)
df = df.query("model != 'naive'")  # Exclude naive model for analysis
df
Out[9]:
model mean median std min max config_key base_config_key seed nsteps clusters subjects
0 VAR_grp_w/o 2.529396 2.312774 0.984564 0.762429 5.035862 nsteps=150_clusters=12_subjects=2_seed=42 nsteps=150_clusters=12_subjects=2 42 150 12 2
1 VAR_grp_w/ 2.341685 1.882104 1.891267 0.833683 17.527476 nsteps=150_clusters=12_subjects=2_seed=42 nsteps=150_clusters=12_subjects=2 42 150 12 2
2 VAR_ind 2.126768 1.789570 1.259586 0.874318 10.998972 nsteps=150_clusters=12_subjects=2_seed=42 nsteps=150_clusters=12_subjects=2 42 150 12 2
3 LSTM_grp_w/o 2.512106 2.310239 1.117340 0.930896 7.331552 nsteps=150_clusters=12_subjects=2_seed=42 nsteps=150_clusters=12_subjects=2 42 150 12 2
4 LSTM_grp_w/ sp 2.172803 2.023987 0.832167 0.970385 6.571903 nsteps=150_clusters=12_subjects=2_seed=42 nsteps=150_clusters=12_subjects=2 42 150 12 2
... ... ... ... ... ... ... ... ... ... ... ... ...
955 LSTM_grp_w/o 1.528850 1.458457 0.267890 1.004945 2.369555 nsteps=1500_clusters=2_subjects=12_seed=46 nsteps=1500_clusters=2_subjects=12 46 1500 2 12
956 LSTM_grp_w/ sp 1.526233 1.455376 0.269169 1.019702 2.414470 nsteps=1500_clusters=2_subjects=12_seed=46 nsteps=1500_clusters=2_subjects=12 46 1500 2 12
957 LSTM_grp_w/ oh 1.554944 1.480679 0.272828 1.056589 2.483530 nsteps=1500_clusters=2_subjects=12_seed=46 nsteps=1500_clusters=2_subjects=12 46 1500 2 12
958 LSTM_grp_w/ enc 1.543440 1.492471 0.272766 1.003309 2.490883 nsteps=1500_clusters=2_subjects=12_seed=46 nsteps=1500_clusters=2_subjects=12 46 1500 2 12
959 LSTM_ind 1.667045 1.552357 0.369190 1.113011 2.869350 nsteps=1500_clusters=2_subjects=12_seed=46 nsteps=1500_clusters=2_subjects=12 46 1500 2 12

960 rows × 12 columns

In [10]:
df_mean_by_seed = df.groupby(["nsteps", "clusters", "subjects", "model"]).agg(
    mean_mse=("mean", "mean"),
    std_mse=("mean", "std"),
    count=("mean", "count")
).reset_index()
df_mean_by_seed
Out[10]:
nsteps clusters subjects model mean_mse std_mse count
0 150 2 12 LSTM_grp_w/ enc 1.809632 0.113523 5
1 150 2 12 LSTM_grp_w/ oh 1.783808 0.139527 5
2 150 2 12 LSTM_grp_w/ sp 1.727778 0.147628 5
3 150 2 12 LSTM_grp_w/o 1.829636 0.153214 5
4 150 2 12 LSTM_ind 2.529305 0.323055 5
... ... ... ... ... ... ... ...
187 1500 12 2 LSTM_grp_w/o 2.047121 0.040751 5
188 1500 12 2 LSTM_ind 1.800304 0.068679 5
189 1500 12 2 VAR_grp_w/ 1.844024 0.077039 5
190 1500 12 2 VAR_grp_w/o 2.551336 0.145536 5
191 1500 12 2 VAR_ind 1.830428 0.073168 5

192 rows × 7 columns

In [11]:
# Prepare means for plotting
df_grouped_nsteps = df_mean_by_seed.groupby(["model", "nsteps"], as_index=False)["mean_mse"].mean()
df_grouped_clusters = df_mean_by_seed.groupby(["model", "clusters"], as_index=False)["mean_mse"].mean()
df_grouped_facet = df_mean_by_seed.groupby(["model", "nsteps", "clusters"], as_index=False)["mean_mse"].mean()
In [12]:
color_discrete_map = {
    'VAR_grp_w/o': '#D7263D',  # vivid red
    'VAR_grp_w/': '#D7263D',
    'VAR_ind': '#D7263D',
    'LSTM_grp_w/o': '#21A179',  # nice green
    'LSTM_grp_w/ sp': "#25D79F",
    'LSTM_grp_w/ oh': "#498BE0",
    'LSTM_grp_w/ enc': "#773FEF",
    'LSTM_ind': '#21A179',
}
line_dash_map = {
    'VAR_grp_w/o': 'solid',
    'VAR_grp_w/': 'dash',
    'VAR_ind': 'dot',
    'LSTM_grp_w/o': 'solid',
    'LSTM_grp_w/': 'dash',
    'LSTM_ind': 'dot',
}

mean mse by nsteps for each model¶

In [14]:
fig = px.line(
    df_grouped_nsteps, x="nsteps", y="mean_mse", color="model", markers=True, title="Mean vs. Nsteps by Model",
    color_discrete_map=color_discrete_map, line_dash="model", line_dash_map=line_dash_map
)
fig.update_traces(line=dict(width=1.5))
fig.show(renderer="notebook")

mean mse by clusters for each model¶

In [15]:
fig = px.line(
    df_grouped_clusters, x="clusters", y="mean_mse", color="model", markers=True, title="Mean vs. Clusters by Model",
    color_discrete_map=color_discrete_map, line_dash="model", line_dash_map=line_dash_map
)
fig.update_traces(line=dict(width=1.5))
fig.show(renderer="notebook")

mean mse by clusters, faceted by nsteps¶

In [8]:
fig = px.line(
    df_grouped_facet, 
    x="clusters", y="mean", color="model", markers=True,
    facet_col="nsteps", facet_col_wrap=3, 
    title="Mean vs. Clusters by Model, for each nsteps",
    color_discrete_map=color_discrete_map, 
    line_dash="model", line_dash_map=line_dash_map,
    height=900   # increase this value to make the plot taller
)

fig.update_traces(line=dict(width=1.5))
fig.show(renderer="notebook")

Bad subject count analysis¶

Pick subject whose lstm_w performance is worse than var_w

In [9]:
mse_method_subject_df[['nsteps', 'clusters', 'subjects']] = mse_method_subject_df['config_key'].str.extract(r'nsteps=(\d+)_clusters=(\d+)_subjects=(\d+)').astype(int)
bad_subjects = mse_method_subject_df[~mse_method_subject_df["is_lstm_better_than_var (w)"]].copy()
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
File ~/research/simulations/.venv/lib/python3.12/site-packages/pandas/core/indexes/base.py:3812, in Index.get_loc(self, key)
   3811 try:
-> 3812     return self._engine.get_loc(casted_key)
   3813 except KeyError as err:

File pandas/_libs/index.pyx:167, in pandas._libs.index.IndexEngine.get_loc()

File pandas/_libs/index.pyx:196, in pandas._libs.index.IndexEngine.get_loc()

File pandas/_libs/hashtable_class_helper.pxi:7088, in pandas._libs.hashtable.PyObjectHashTable.get_item()

File pandas/_libs/hashtable_class_helper.pxi:7096, in pandas._libs.hashtable.PyObjectHashTable.get_item()

KeyError: 'is_lstm_better_than_var (w)'

The above exception was the direct cause of the following exception:

KeyError                                  Traceback (most recent call last)
Cell In[9], line 2
      1 mse_method_subject_df[['nsteps', 'clusters', 'subjects']] = mse_method_subject_df['config_key'].str.extract(r'nsteps=(\d+)_clusters=(\d+)_subjects=(\d+)').astype(int)
----> 2 bad_subjects = mse_method_subject_df[~mse_method_subject_df["is_lstm_better_than_var (w)"]].copy()

File ~/research/simulations/.venv/lib/python3.12/site-packages/pandas/core/frame.py:4107, in DataFrame.__getitem__(self, key)
   4105 if self.columns.nlevels > 1:
   4106     return self._getitem_multilevel(key)
-> 4107 indexer = self.columns.get_loc(key)
   4108 if is_integer(indexer):
   4109     indexer = [indexer]

File ~/research/simulations/.venv/lib/python3.12/site-packages/pandas/core/indexes/base.py:3819, in Index.get_loc(self, key)
   3814     if isinstance(casted_key, slice) or (
   3815         isinstance(casted_key, abc.Iterable)
   3816         and any(isinstance(x, slice) for x in casted_key)
   3817     ):
   3818         raise InvalidIndexError(key)
-> 3819     raise KeyError(key) from err
   3820 except TypeError:
   3821     # If we have a listlike key, _check_indexing_error will raise
   3822     #  InvalidIndexError. Otherwise we fall through and re-raise
   3823     #  the TypeError.
   3824     self._check_indexing_error(key)

KeyError: 'is_lstm_better_than_var (w)'
In [ ]:
bad_subjects_count = bad_subjects.groupby(['nsteps', 'clusters']).size().reset_index(name='bad_subject_count')
all_combinations = mse_method_subject_df[['nsteps', 'clusters']].drop_duplicates()
bad_subjects_count = pd.merge(all_combinations, bad_subjects_count, on=['nsteps', 'clusters'], how='left').fillna(0)
bad_subjects_count['bad_subject_count'] = bad_subjects_count['bad_subject_count'].astype(int)
bad_subjects_count
Out[ ]:
nsteps clusters bad_subject_count
0 150 12 17
1 150 6 16
2 150 4 14
3 150 2 10
4 300 12 22
5 300 6 17
6 300 4 14
7 300 2 20
8 600 12 22
9 600 6 16
10 600 4 15
11 600 2 13
12 900 12 13
13 900 6 13
14 900 4 7
15 900 2 14
16 1200 12 19
17 1200 6 12
18 1200 4 6
19 1200 2 16
20 1500 12 19
21 1500 6 7
22 1500 4 12
23 1500 2 6
In [ ]:
fig = px.line(bad_subjects_count, x='clusters', y='bad_subject_count', color='nsteps', markers=True, title='Bad Subject Count vs. Clusters by Nsteps')
fig.show(renderer="notebook")
In [ ]:
fig = px.line(bad_subjects_count, x='nsteps', y='bad_subject_count', color='clusters', markers=True, title='Bad Subject Count vs. Clusters by Nsteps')
fig.show(renderer="notebook")
In [ ]: