LSTM (test-all)
In [1]:
import pandas as pd
import os, sys
from pathlib import Path
import torch
from datetime import datetime
import logging
import mlflow
import warnings
# Set up paths and logging
root_folder_path = Path('/home/ytli/research')
experiment_folder_path = Path('/home/ytli/research/lstm')
sys.path.append(str(root_folder_path))
sys.path.append(str(experiment_folder_path))
from modules.study_multivar import RealDataTimeSeriesAnalysis, calculate_mse_by_subject_feature
from modules.plot import plot_subject_feature_html_interactive
In [2]:
study_name = "study5-realdata"
folder_path = "fisher"
method_name = "lstm"
datafile_path = "data/fisher_all.csv"
# Initialize the analysis
analysis = RealDataTimeSeriesAnalysis(study_name, folder_path, datafile_path)
# Load and prepare data
# feature_cols = ['energetic','enthusiastic','content','irritable','restless','worried','guilty','afraid','anhedonia','angry','hopeless','down','positive','fatigue','tension','concentrate','ruminate','avoid_act','reassure','procrast','avoid_people']
feature_cols = [
'down', 'positive', 'content', 'enthusiastic', 'energetic',
'hopeless', 'angry', 'irritable', 'reassure'
]
df = analysis.load_data()
sliding_windows_dict = analysis.create_sliding_windows(df, feature_cols, window_size=5, stride=1)
train_loader, val_loader, test_loader = analysis.prepare_datasets(sliding_windows_dict)
# Train model
print("\033[1m\033[95mTraining model...\033[0m")
best_model, best_model_checkpoint_metrics, test_results = analysis.train_model(train_loader, val_loader, test_loader)
print(f"\033[1m\033[96mBEST MODEL VAL LOSS:\033[0m {best_model_checkpoint_metrics['val_loss']:.5f}")
print(f"\033[1m\033[96mTEST RESULTS:\033[0m loss = {test_results[0]['test_loss']:.5f}")
# Evaluate and visualize
print("\033[1m\033[95mEvaluating model on test data...\033[0m")
test_eval_data = analysis.evaluate_model(best_model, test_loader)
print("="*80)
2025/05/22 10:47:05 WARNING mlflow.utils.autologging_utils: MLflow pytorch autologging is known to be compatible with 1.9.0 <= torch <= 2.6.0, but the installed version is 2.6.0+cu124. If you encounter errors during autologging, try upgrading / downgrading torch to a compatible version, or try upgrading MLflow.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Training model...
2025/05/22 10:47:24 WARNING mlflow.models.model: Model logged without a signature and input example. Please set `input_example` parameter when logging the model to auto infer the model signature.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
🏃 View run fisher at: http://localhost:8093/#/experiments/92/runs/5d9db9f36acf40978c1508ea3458daf0
🧪 View experiment at: http://localhost:8093/#/experiments/92
BEST MODEL VAL LOSS: 0.74562
TEST RESULTS: loss = 1.24797
Evaluating model on test data...
================================================================================
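For reference, the sliding-window step above (window_size=5, stride=1) can be pictured as cutting each subject's multivariate series into overlapping input windows, each paired with the observation that immediately follows it. The sketch below is a minimal illustration under that assumption; `make_windows` is a hypothetical helper, not the actual implementation of `RealDataTimeSeriesAnalysis.create_sliding_windows`.

```python
# Hedged sketch of per-subject sliding-window construction (window_size=5, stride=1),
# assuming a long-format DataFrame with a 'subject_id' column and one column per feature.
# The helper name and (input, target) layout are illustrative only.
import numpy as np
import pandas as pd

def make_windows(df: pd.DataFrame, feature_cols, window_size=5, stride=1):
    windows = {}
    for subject_id, group in df.groupby("subject_id"):
        values = group[feature_cols].to_numpy(dtype=np.float32)
        pairs = []
        # Each window of `window_size` observations predicts the observation that follows it.
        for start in range(0, len(values) - window_size, stride):
            x = values[start:start + window_size]   # (window_size, n_features)
            y = values[start + window_size]          # next time step, (n_features,)
            pairs.append((x, y))
        windows[subject_id] = pairs
    return windows
```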
In [3]:
plot_subject_feature_html_interactive(test_eval_data, feature_cols)
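The interactive figure produced here is written out as HTML, so nothing renders inline in this export. The sketch below shows one plausible way to build a comparable predicted-vs-actual plot per subject and feature, assuming `test_eval_data` can be arranged as a long-format DataFrame with `subject_id`, `feature_name`, `time`, `actual`, and `predicted` columns; the function name and column layout are illustrative, not the project's actual API.

```python
# Hypothetical sketch of a per-subject, per-feature predicted-vs-actual figure,
# under the column-layout assumption described above.
import pandas as pd
import plotly.express as px

def plot_predictions_sketch(eval_df: pd.DataFrame, html_path: str = "predictions.html") -> None:
    # Stack actual and predicted values into one column so both series share a legend entry.
    long_df = eval_df.melt(
        id_vars=["subject_id", "feature_name", "time"],
        value_vars=["actual", "predicted"],
        var_name="series", value_name="value",
    )
    # One small panel per (subject, feature) pair, with actual vs. predicted lines overlaid.
    fig = px.line(
        long_df, x="time", y="value", color="series",
        facet_row="subject_id", facet_col="feature_name",
    )
    fig.write_html(html_path)  # interactive HTML output
```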
In [4]:
mse_df = calculate_mse_by_subject_feature(test_eval_data, feature_cols)
mse_df.to_csv('mse.csv', index=False)
mse_df.head(20)
Out[4]:
| | subject_id | feature_name | mse |
|---|---|---|---|
| 0 | p111 | down | 0.180548 |
| 1 | p111 | positive | 0.117444 |
| 2 | p111 | content | 0.503352 |
| 3 | p111 | enthusiastic | 0.282189 |
| 4 | p111 | energetic | 0.375488 |
| 5 | p111 | hopeless | 0.400014 |
| 6 | p111 | angry | 0.666034 |
| 7 | p111 | irritable | 0.580033 |
| 8 | p111 | reassure | 0.696139 |
| 9 | p025 | down | 0.662425 |
| 10 | p025 | positive | 2.178204 |
| 11 | p025 | content | 2.336298 |
| 12 | p025 | enthusiastic | 1.888744 |
| 13 | p025 | energetic | 1.006653 |
| 14 | p025 | hopeless | 1.221417 |
| 15 | p025 | angry | 0.869682 |
| 16 | p025 | irritable | 0.819950 |
| 17 | p025 | reassure | 1.556690 |
| 18 | p023 | down | 0.245138 |
| 19 | p023 | positive | 0.570401 |
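As a rough check on what `calculate_mse_by_subject_feature` reports, a per-subject, per-feature MSE can be reproduced with a plain pandas groupby, assuming the evaluation data is available in long format with `actual` and `predicted` columns. The helper below is an illustrative sketch, not the module's implementation.

```python
# Hedged sketch: MSE = mean((actual - predicted)^2) within each (subject, feature) group,
# assuming eval_df has columns ['subject_id', 'feature_name', 'actual', 'predicted'].
import pandas as pd

def mse_by_subject_feature(eval_df: pd.DataFrame) -> pd.DataFrame:
    eval_df = eval_df.assign(sq_err=(eval_df["actual"] - eval_df["predicted"]) ** 2)
    return (
        eval_df.groupby(["subject_id", "feature_name"], as_index=False)["sq_err"]
        .mean()
        .rename(columns={"sq_err": "mse"})
    )
```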