The Simple Temporal Attention (SimTA) module is an attention module for analyzing asynchronous time series.
Predicting clinical outcome is remarkably important but challenging. Research efforts have been devoted to seeking significant biomarkers associated with therapy response and/or patient survival. However, these biomarkers are generally costly and invasive, and possibly unsatisfactory for novel therapies. On the other hand, multi-modal, heterogeneous, unaligned temporal data is continuously generated in clinical practice. This paper aims at a unified deep learning approach to predict patient prognosis and therapy response with easily accessible data, e.g., radiographics, laboratory and clinical information. Prior arts focus on modeling a single data modality or ignore temporal changes. Importantly, clinical time series are asynchronous in practice, i.e., recorded with irregular intervals. In this study, we formalize prognosis modeling as a multi-modal asynchronous time series classification task, and propose a MIA-Prognosis framework with Measurement, Intervention and Assessment (MIA) information to predict therapy response, where a Simple Temporal Attention (SimTA) module is developed to process the asynchronous time series. Experiments on a synthetic dataset validate the superiority of SimTA over standard RNN-based approaches. Furthermore, we evaluate the proposed method on an in-house, retrospective dataset of real-world non-small cell lung cancer patients under anti-PD-1 immunotherapy. The proposed method achieves promising performance in predicting the immunotherapy response. Notably, our predictive model can further stratify low-risk and high-risk patients in terms of long-term survival.
For more details, please refer to our paper:
MIA-Prognosis: A Deep Learning Framework to Predict Therapy Response
Jiancheng Yang, Jiajun Chen, Kaiming Kuang, Tiancheng Lin, Junjun He, Bingbing Ni
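For intuition, here is a minimal, self-contained sketch of a SimTA-style attention layer, in which the attention coefficients are computed purely from the pairwise time intervals between observations through a learnable non-negative decay. The class name and all details below are our own illustrative simplification, not the exact implementation found in models/.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimTALayer(nn.Module):
    """Illustrative SimTA-style layer: attention coefficients depend only on time intervals."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.value = nn.Linear(in_dim, out_dim)    # value projection of the input features
        self.decay = nn.Parameter(torch.zeros(1))  # learnable decay rate, kept non-negative via softplus

    def forward(self, x, t):
        # x: (batch, seq_len, in_dim) asynchronous observations; t: (batch, seq_len) time stamps
        v = self.value(x)                           # (B, L, out_dim)
        dt = t.unsqueeze(2) - t.unsqueeze(1)        # (B, L, L), dt[i, j] = t_i - t_j
        logits = -F.softplus(self.decay) * dt       # older observations receive lower scores
        logits = logits.masked_fill(dt < 0, float('-inf'))  # do not attend to future observations
        attn = torch.softmax(logits, dim=-1)        # attention over past (and current) observations
        return attn @ v                             # (B, L, out_dim) temporally aggregated features

# Toy usage: four observations recorded at irregular intervals
x = torch.randn(1, 4, 8)
t = torch.tensor([[0.0, 1.5, 4.0, 9.0]])
print(SimTALayer(8, 16)(x, t).shape)  # torch.Size([1, 4, 16])
```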
- SimTA/
  - config/: Training configurations
  - dataset/: Toy data generator, PyTorch dataset and toy data pickles used in experiments
  - engine/: Keras-like training and evaluation API
  - etc/: Images for README.md
  - models/: PyTorch implementation of the SimTA modules and the LSTM models used for comparison
  - run/: Training and evaluation scripts
  - utils/: Utility functions
pytorch>=1.3.0
numpy>=1.18.0
scikit-learn>=0.23.2
tqdm>=4.38.0
To train the SimTA model on the toy dataset, run:
python -m run.main_simta --logdir <logging_directory>
Tensorboard logs, configuration files and trained PyTorch models are saved under the logging_directory that you specify.
To train LSTM-based approaches for comparison, run:
python -m run.main_lstm --logdir <logging_directory>
python -m run.main_lstm_time_interval --logdir <logging_directory>
python -m run.main_lstm_time_stamp --logdir <logging_directory>
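The three LSTM baselines differ only in how (if at all) time information is supplied to the network. As a rough, hypothetical illustration (the actual baseline implementations are in models/), the interval and stamp variants can be thought of as concatenating the time interval or the absolute time stamp to each step's features:

```python
import torch
import torch.nn as nn

class LSTMBaseline(nn.Module):
    """Hypothetical sketch of an LSTM baseline that optionally appends time information."""

    def __init__(self, in_dim, hidden_dim, use_time='interval'):
        super().__init__()
        self.use_time = use_time  # 'none', 'interval' or 'stamp'
        extra = 0 if use_time == 'none' else 1
        self.lstm = nn.LSTM(in_dim + extra, hidden_dim, batch_first=True)
        self.head = nn.Linear(hidden_dim, 1)  # regression head, matching the toy task

    def forward(self, x, t):
        # x: (B, L, in_dim) observations; t: (B, L) time stamps
        if self.use_time == 'interval':
            dt = t - torch.cat([t[:, :1], t[:, :-1]], dim=1)  # gap to the previous observation
            x = torch.cat([x, dt.unsqueeze(-1)], dim=-1)
        elif self.use_time == 'stamp':
            x = torch.cat([x, t.unsqueeze(-1)], dim=-1)
        h, _ = self.lstm(x)
        return self.head(h[:, -1])  # predict from the last hidden state
```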
To evaluate the SimTA model on the toy dataset, run:
python -m run.main_simta --modelpath <pytorch_model_weights_path>
You can load your trained model weights by specifying the modelpath argument, which should point to a PyTorch .pth file.
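For reference, the standard PyTorch save/load pattern for .pth weight files looks like the sketch below; the module here is a stand-in, and whether the scripts save a state_dict or a full model is an assumption, so substitute the actual model class from models/.

```python
import torch
import torch.nn as nn

# Stand-in module for illustration; replace with the actual SimTA/LSTM model from models/.
model = nn.Linear(8, 1)
torch.save(model.state_dict(), 'weights.pth')

restored = nn.Linear(8, 1)
restored.load_state_dict(torch.load('weights.pth', map_location='cpu'))
restored.eval()  # switch to evaluation mode before inference
```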
To evaluate LSTM-based approaches for comparison, run:
python -m run.evaluate_lstm --modelpath <pytorch_model_weights_path>
python -m run.evaluate_lstm_time_interval --modelpath <pytorch_model_weights_path>
python -m run.evaluate_lstm_time_stamp --modelpath <pytorch_model_weights_path>
Since the in-house clinical dataset used in our paper cannot be made public, we hereby open-source our toy data for proof-of-concept experiments (.pickle files in /dataset), along with the code we use to generate it. To generate toy data as we did in our paper, run:
python -m dataset.generate_toy_data
You may find the toy data configuration in /dataset/toy_data_cfg.py and play with it a little. For concrete details of the toy data settings, please refer to our paper.
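As a rough illustration of what an asynchronous toy series looks like (everything below is made up for illustration; the real generator and its settings live in dataset/ and are described in the paper):

```python
import numpy as np

rng = np.random.default_rng(0)

def make_toy_series(length=10):
    """Generate one asynchronous series: irregular time stamps and a value observed at each."""
    gaps = rng.exponential(scale=1.0, size=length)            # irregular gaps between observations
    t = np.cumsum(gaps)                                       # asynchronous time stamps
    x = np.sin(0.5 * t) + 0.1 * rng.standard_normal(length)   # noisy observations
    target = x[-1]                                            # regress some function of the series
    return t, x, target

t, x, y = make_toy_series()
print(t.round(2), x.round(2), round(float(y), 2))
```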
Here is a screenshot of the Tensorboard logs for SimTA and the three LSTM-based approaches used in our paper. On the toy dataset, SimTA performs considerably better than the LSTM-based approaches in terms of mean squared error on the regression task.