In [1]:
import pandas as pd
import plotly.express as px
In [8]:
mse_method_df = pd.read_csv("mse_method_df.csv")
mse_method_subject_df = pd.read_csv("mse_method_subject_df.csv")
In [9]:
# Decode config_key into nsteps, clusters, subjects, and seed
df = mse_method_df.copy()
df[["nsteps", "clusters", "subjects", "seed"]] = df["config_key"].str.extract(r"nsteps=(\d+)_clusters=(\d+)_subjects=(\d+)_seed=(\d+)").astype(int)
df = df.query("model != 'naive'") # Exclude naive model for analysis
df
Out[9]:
| | model | mean | median | std | min | max | config_key | base_config_key | seed | nsteps | clusters | subjects |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | VAR_grp_w/o | 2.529396 | 2.312774 | 0.984564 | 0.762429 | 5.035862 | nsteps=150_clusters=12_subjects=2_seed=42 | nsteps=150_clusters=12_subjects=2 | 42 | 150 | 12 | 2 |
| 1 | VAR_grp_w/ | 2.341685 | 1.882104 | 1.891267 | 0.833683 | 17.527476 | nsteps=150_clusters=12_subjects=2_seed=42 | nsteps=150_clusters=12_subjects=2 | 42 | 150 | 12 | 2 |
| 2 | VAR_ind | 2.126768 | 1.789570 | 1.259586 | 0.874318 | 10.998972 | nsteps=150_clusters=12_subjects=2_seed=42 | nsteps=150_clusters=12_subjects=2 | 42 | 150 | 12 | 2 |
| 3 | LSTM_grp_w/o | 2.512106 | 2.310239 | 1.117340 | 0.930896 | 7.331552 | nsteps=150_clusters=12_subjects=2_seed=42 | nsteps=150_clusters=12_subjects=2 | 42 | 150 | 12 | 2 |
| 4 | LSTM_grp_w/ sp | 2.172803 | 2.023987 | 0.832167 | 0.970385 | 6.571903 | nsteps=150_clusters=12_subjects=2_seed=42 | nsteps=150_clusters=12_subjects=2 | 42 | 150 | 12 | 2 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 955 | LSTM_grp_w/o | 1.528850 | 1.458457 | 0.267890 | 1.004945 | 2.369555 | nsteps=1500_clusters=2_subjects=12_seed=46 | nsteps=1500_clusters=2_subjects=12 | 46 | 1500 | 2 | 12 |
| 956 | LSTM_grp_w/ sp | 1.526233 | 1.455376 | 0.269169 | 1.019702 | 2.414470 | nsteps=1500_clusters=2_subjects=12_seed=46 | nsteps=1500_clusters=2_subjects=12 | 46 | 1500 | 2 | 12 |
| 957 | LSTM_grp_w/ oh | 1.554944 | 1.480679 | 0.272828 | 1.056589 | 2.483530 | nsteps=1500_clusters=2_subjects=12_seed=46 | nsteps=1500_clusters=2_subjects=12 | 46 | 1500 | 2 | 12 |
| 958 | LSTM_grp_w/ enc | 1.543440 | 1.492471 | 0.272766 | 1.003309 | 2.490883 | nsteps=1500_clusters=2_subjects=12_seed=46 | nsteps=1500_clusters=2_subjects=12 | 46 | 1500 | 2 | 12 |
| 959 | LSTM_ind | 1.667045 | 1.552357 | 0.369190 | 1.113011 | 2.869350 | nsteps=1500_clusters=2_subjects=12_seed=46 | nsteps=1500_clusters=2_subjects=12 | 46 | 1500 | 2 | 12 |
960 rows × 12 columns
In [10]:
# Average mean MSE across seeds for each (nsteps, clusters, subjects, model) configuration
df_mean_by_seed = df.groupby(["nsteps", "clusters", "subjects", "model"]).agg(
    mean_mse=("mean", "mean"),
    std_mse=("mean", "std"),
    count=("mean", "count")
).reset_index()
df_mean_by_seed
Out[10]:
| | nsteps | clusters | subjects | model | mean_mse | std_mse | count |
|---|---|---|---|---|---|---|---|
| 0 | 150 | 2 | 12 | LSTM_grp_w/ enc | 1.809632 | 0.113523 | 5 |
| 1 | 150 | 2 | 12 | LSTM_grp_w/ oh | 1.783808 | 0.139527 | 5 |
| 2 | 150 | 2 | 12 | LSTM_grp_w/ sp | 1.727778 | 0.147628 | 5 |
| 3 | 150 | 2 | 12 | LSTM_grp_w/o | 1.829636 | 0.153214 | 5 |
| 4 | 150 | 2 | 12 | LSTM_ind | 2.529305 | 0.323055 | 5 |
| ... | ... | ... | ... | ... | ... | ... | ... |
| 187 | 1500 | 12 | 2 | LSTM_grp_w/o | 2.047121 | 0.040751 | 5 |
| 188 | 1500 | 12 | 2 | LSTM_ind | 1.800304 | 0.068679 | 5 |
| 189 | 1500 | 12 | 2 | VAR_grp_w/ | 1.844024 | 0.077039 | 5 |
| 190 | 1500 | 12 | 2 | VAR_grp_w/o | 2.551336 | 0.145536 | 5 |
| 191 | 1500 | 12 | 2 | VAR_ind | 1.830428 | 0.073168 | 5 |
192 rows × 7 columns
In [11]:
# Collapse to mean MSE per model along nsteps, along clusters, and along both (for the faceted plot)
df_grouped_nsteps = df_mean_by_seed.groupby(["model", "nsteps"], as_index=False)["mean_mse"].mean()
df_grouped_clusters = df_mean_by_seed.groupby(["model", "clusters"], as_index=False)["mean_mse"].mean()
df_grouped_facet = df_mean_by_seed.groupby(["model", "nsteps", "clusters"], as_index=False)["mean_mse"].mean()
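The seed-to-seed spread captured in std_mse and count above is not shown in the averaged lines below. A minimal sketch of adding standard-error bars for a single configuration, using only the columns already present in df_mean_by_seed (clusters=12, subjects=2 chosen here just as an example):
In [ ]:
import numpy as np

# Standard error of the mean across seeds, per configuration and model
df_err = df_mean_by_seed.assign(sem=df_mean_by_seed["std_mse"] / np.sqrt(df_mean_by_seed["count"]))
subset = df_err.query("clusters == 12 and subjects == 2")
fig = px.line(
    subset, x="nsteps", y="mean_mse", color="model", markers=True, error_y="sem",
    title="Mean MSE vs. Nsteps (clusters=12, subjects=2), ± SEM across seeds"
)
fig.show(renderer="notebook")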
In [12]:
color_discrete_map = {
    'VAR_grp_w/o': '#D7263D',   # vivid red for all VAR variants
    'VAR_grp_w/': '#D7263D',
    'VAR_ind': '#D7263D',
    'LSTM_grp_w/o': '#21A179',  # green for the plain LSTM variants
    'LSTM_grp_w/ sp': "#25D79F",
    'LSTM_grp_w/ oh': "#498BE0",
    'LSTM_grp_w/ enc': "#773FEF",
    'LSTM_ind': '#21A179',
}
# Keys must match the model labels in the data exactly; unmapped labels fall back to plotly defaults.
line_dash_map = {
    'VAR_grp_w/o': 'solid',
    'VAR_grp_w/': 'dash',
    'VAR_ind': 'dot',
    'LSTM_grp_w/o': 'solid',
    'LSTM_grp_w/ sp': 'dash',
    'LSTM_grp_w/ oh': 'dash',
    'LSTM_grp_w/ enc': 'dash',
    'LSTM_ind': 'dot',
}
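Since plotly express silently falls back to default colors and dash styles for any model label missing from these maps, a quick optional check that both maps cover every model name in the data:
In [ ]:
# Flag any model labels that the color or dash maps do not cover
models = set(df["model"].unique())
print("missing from color map:", models - set(color_discrete_map))
print("missing from dash map:", models - set(line_dash_map))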
Mean MSE by nsteps for each model¶
In [14]:
fig = px.line(
    df_grouped_nsteps, x="nsteps", y="mean_mse", color="model", markers=True, title="Mean MSE vs. Nsteps by Model",
    color_discrete_map=color_discrete_map, line_dash="model", line_dash_map=line_dash_map
)
fig.update_traces(line=dict(width=1.5))
fig.show(renderer="notebook")
Mean MSE by clusters for each model¶
In [15]:
fig = px.line(
    df_grouped_clusters, x="clusters", y="mean_mse", color="model", markers=True, title="Mean MSE vs. Clusters by Model",
    color_discrete_map=color_discrete_map, line_dash="model", line_dash_map=line_dash_map
)
fig.update_traces(line=dict(width=1.5))
fig.show(renderer="notebook")
Mean MSE by clusters, faceted by nsteps¶
In [8]:
fig = px.line(
    df_grouped_facet,
    x="clusters", y="mean_mse", color="model", markers=True,
    facet_col="nsteps", facet_col_wrap=3,
    title="Mean MSE vs. Clusters by Model, for each nsteps",
    color_discrete_map=color_discrete_map,
    line_dash="model", line_dash_map=line_dash_map,
    height=900  # increase this value to make the plot taller
)
fig.update_traces(line=dict(width=1.5))
fig.show(renderer="notebook")
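The facet titles render as "nsteps=150", "nsteps=300", and so on; if shorter labels are preferred, a small optional tweak using plotly's annotation API:
In [ ]:
# Optional: strip the "nsteps=" prefix from the facet titles and re-render
fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
fig.show(renderer="notebook")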
Bad subject count analysis¶
Pick subjects whose LSTM_grp_w/ performance is worse than VAR_grp_w/
In [9]:
mse_method_subject_df[['nsteps', 'clusters', 'subjects']] = mse_method_subject_df['config_key'].str.extract(r'nsteps=(\d+)_clusters=(\d+)_subjects=(\d+)').astype(int)
bad_subjects = mse_method_subject_df[~mse_method_subject_df["is_lstm_better_than_var (w)"]].copy()
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[9], line 2
      1 mse_method_subject_df[['nsteps', 'clusters', 'subjects']] = mse_method_subject_df['config_key'].str.extract(r'nsteps=(\d+)_clusters=(\d+)_subjects=(\d+)').astype(int)
----> 2 bad_subjects = mse_method_subject_df[~mse_method_subject_df["is_lstm_better_than_var (w)"]].copy()

File ~/research/simulations/.venv/lib/python3.12/site-packages/pandas/core/frame.py:4107, in DataFrame.__getitem__(self, key)
File ~/research/simulations/.venv/lib/python3.12/site-packages/pandas/core/indexes/base.py:3819, in Index.get_loc(self, key)

KeyError: 'is_lstm_better_than_var (w)'
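The KeyError indicates that this version of mse_method_subject_df.csv does not contain the flag column. A minimal sketch of rebuilding it from subject-level MSEs, assuming hypothetical columns mse_lstm_grp_w and mse_var_grp_w hold each subject's MSE under LSTM_grp_w/ and VAR_grp_w/ (substitute the actual column names from the CSV):
In [ ]:
# Rebuild the missing flag column; "mse_lstm_grp_w" and "mse_var_grp_w" are assumed names, not the CSV's actual schema
if "is_lstm_better_than_var (w)" not in mse_method_subject_df.columns:
    mse_method_subject_df["is_lstm_better_than_var (w)"] = (
        mse_method_subject_df["mse_lstm_grp_w"] < mse_method_subject_df["mse_var_grp_w"]
    )
bad_subjects = mse_method_subject_df[~mse_method_subject_df["is_lstm_better_than_var (w)"]].copy()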
In [ ]:
# Count bad subjects per (nsteps, clusters); merge against all combinations so configurations with none show 0
bad_subjects_count = bad_subjects.groupby(['nsteps', 'clusters']).size().reset_index(name='bad_subject_count')
all_combinations = mse_method_subject_df[['nsteps', 'clusters']].drop_duplicates()
bad_subjects_count = pd.merge(all_combinations, bad_subjects_count, on=['nsteps', 'clusters'], how='left').fillna(0)
bad_subjects_count['bad_subject_count'] = bad_subjects_count['bad_subject_count'].astype(int)
bad_subjects_count
Out[ ]:
| | nsteps | clusters | bad_subject_count |
|---|---|---|---|
| 0 | 150 | 12 | 17 |
| 1 | 150 | 6 | 16 |
| 2 | 150 | 4 | 14 |
| 3 | 150 | 2 | 10 |
| 4 | 300 | 12 | 22 |
| 5 | 300 | 6 | 17 |
| 6 | 300 | 4 | 14 |
| 7 | 300 | 2 | 20 |
| 8 | 600 | 12 | 22 |
| 9 | 600 | 6 | 16 |
| 10 | 600 | 4 | 15 |
| 11 | 600 | 2 | 13 |
| 12 | 900 | 12 | 13 |
| 13 | 900 | 6 | 13 |
| 14 | 900 | 4 | 7 |
| 15 | 900 | 2 | 14 |
| 16 | 1200 | 12 | 19 |
| 17 | 1200 | 6 | 12 |
| 18 | 1200 | 4 | 6 |
| 19 | 1200 | 2 | 16 |
| 20 | 1500 | 12 | 19 |
| 21 | 1500 | 6 | 7 |
| 22 | 1500 | 4 | 12 |
| 23 | 1500 | 2 | 6 |
In [ ]:
fig = px.line(bad_subjects_count, x='clusters', y='bad_subject_count', color='nsteps', markers=True, title='Bad Subject Count vs. Clusters by Nsteps')
fig.show(renderer="notebook")
In [ ]:
fig = px.line(bad_subjects_count, x='nsteps', y='bad_subject_count', color='clusters', markers=True, title='Bad Subject Count vs. Nsteps by Clusters')
fig.show(renderer="notebook")
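As a compact alternative to the two line plots, the same counts can be pivoted into an nsteps × clusters heatmap (a sketch using only the bad_subjects_count frame built above):
In [ ]:
# Pivot to an nsteps x clusters grid and render as an annotated heatmap
heat = bad_subjects_count.pivot(index="nsteps", columns="clusters", values="bad_subject_count")
fig = px.imshow(heat, text_auto=True, labels=dict(color="bad subjects"), title="Bad Subject Count by Nsteps and Clusters")
fig.show(renderer="notebook")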
In [ ]: