Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(dcph): enable federated learning with DeepCoxPH #134

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Ignore __pycache__
**/__pycache__

# Poetry lock file
poetry.lock
8 changes: 6 additions & 2 deletions auton_survival/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def _load_pbc_dataset(sequential):
e.append(event[data['id'] == id_])
return x, t, e

def load_support():
def load_support(return_features=False):

"""Helper function to load and preprocess the SUPPORT dataset.
The SUPPORT Dataset comes from the Vanderbilt University study
Expand Down Expand Up @@ -182,6 +182,9 @@ def load_support():
'temp', 'pafi', 'alb', 'bili', 'crea', 'sod', 'ph',
'glucose', 'bun', 'urine', 'adlp', 'adls']

if return_features:
return outcomes, data[cat_feats+num_feats], {'cat': cat_feats, 'num': num_feats}

return outcomes, data[cat_feats+num_feats]


Expand Down Expand Up @@ -303,9 +306,10 @@ def load_dataset(dataset='SUPPORT', **kwargs):
\( e \) the censoring indicators.
"""
sequential = kwargs.get('sequential', False)
return_features = kwargs.get('return_features', True)

if dataset == 'SUPPORT':
return load_support()
return load_support(return_features=return_features)
if dataset == 'PBC':
return _load_pbc_dataset(sequential)
if dataset == 'FRAMINGHAM':
Expand Down
60 changes: 51 additions & 9 deletions auton_survival/models/cph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@

r""" Deep Cox Proportional Hazards Model"""

from collections import namedtuple
import torch
import numpy as np
import pandas as pd

from .dcph_torch import DeepCoxPHTorch, DeepRecurrentCoxPHTorch
from .dcph_utilities import train_dcph, predict_survival
Expand Down Expand Up @@ -74,8 +76,19 @@ def __init__(self, layers=None, random_seed=0):

self.layers = layers
self.fitted = False
self.initialized = False
self.breslow = None
self.random_seed = random_seed

@property
def torch_module(self):
if self.initialized:
return self.torch_model[0]
else:
raise Exception("Torch module not initialized. " +
"Please call `fit` or `init_torch_model` " +
"before accessing `torch_module`.")

def __call__(self):
if self.fitted:
print("A fitted instance of the Deep Cox PH model")
Expand Down Expand Up @@ -136,10 +149,18 @@ def _gen_torch_model(self, inputdim, optimizer):

return DeepCoxPHTorch(inputdim, layers=self.layers,
optimizer=optimizer)

def init_torch_model(self, inputdim, optimizer):
if not self.initialized:
self.torch_model = (
self._gen_torch_model(inputdim, optimizer),
None,
)
self.initialized = True

def fit(self, x, t, e, vsize=0.15, val_data=None,
iters=1, learning_rate=1e-3, batch_size=100,
optimizer="Adam"):
optimizer="Adam", breslow=True, patience=3):

r"""This method is used to train an instance of the DSM model.

Expand All @@ -166,6 +187,9 @@ def fit(self, x, t, e, vsize=0.15, val_data=None,
optimizer: str
The choice of the gradient based optimization method. One of
'Adam', 'RMSProp' or 'SGD'.
breslow: bool
If breslow is set to False the Breslow Estimator will not be fitted.
Default value is True.

"""

Expand All @@ -179,25 +203,31 @@ def fit(self, x, t, e, vsize=0.15, val_data=None,

inputdim = x_train.shape[-1]

model = self._gen_torch_model(inputdim, optimizer)

model, _ = train_dcph(model,
self.init_torch_model(inputdim, optimizer)

torch_module = self.torch_model[0]

model, losses = train_dcph(torch_module,
(x_train, t_train, e_train),
(x_val, t_val, e_val),
epochs=iters,
lr=learning_rate,
bs=batch_size,
return_losses=True,
random_seed=self.random_seed)
random_seed=self.random_seed,
breslow=breslow,
patience=patience)

self.torch_model = (model[0].eval(), model[1])
DcphModel = namedtuple('DcphModel', 'module breslow')
self.torch_model = DcphModel(model[0].eval(), model[1])
self.losses = losses
self.fitted = True

self.breslow = True if self.torch_model.breslow is not None else False

return self

def predict_risk(self, x, t=None):

if self.fitted:
if self.breslow and self.fitted:
return 1-self.predict_survival(x, t)
else:
raise Exception("The model has not been fitted yet. Please fit the " +
Expand Down Expand Up @@ -233,6 +263,18 @@ def predict_survival(self, x, t=None):
scores = predict_survival(self.torch_model, x, t)
return scores

@torch.inference_mode()
def predict_time_independent_risk(self, x: torch.Tensor) -> torch.Tensor:
if self.fitted:
x = self._preprocess_test_data(x)
self.torch_module.eval()
return self.torch_module(x)
else:
raise Exception(
"The model has not been fitted yet. Please fit the "
+ "model using the `fit` method on some training data "
+ "before calling `predict_time_independent_risk`."
)

class DeepRecurrentCoxPH(DeepCoxPH):
r"""A deep recurrent Cox PH model.
Expand Down
18 changes: 6 additions & 12 deletions auton_survival/models/cph/dcph_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def train_step(model, x, t, e, optimizer, bs=256, seed=100):
batches = (n // bs) + 1

epoch_loss = 0

model.train()

for i in range(batches):

Expand All @@ -67,6 +69,7 @@ def train_step(model, x, t, e, optimizer, bs=256, seed=100):

return epoch_loss/n

@torch.inference_mode()
def test_step(model, x, t, e):

with torch.no_grad():
Expand All @@ -77,7 +80,7 @@ def test_step(model, x, t, e):

def train_dcph(model, train_data, val_data, epochs=50,
patience=3, bs=256, lr=1e-3, debug=False,
random_seed=0, return_losses=False):
random_seed=0, return_losses=False, breslow=True):

torch.manual_seed(random_seed)
np.random.seed(random_seed)
Expand Down Expand Up @@ -126,23 +129,14 @@ def train_dcph(model, train_data, val_data, epochs=50,
patience_ = 0

if patience_ == patience:

minm = np.argmin(losses)
model.load_state_dict(dics[minm])

breslow_spline = fit_breslow(model, xt, tt_, et_)

if return_losses:
return (model, breslow_spline), losses
else:
return (model, breslow_spline)
break

valc = valcn

minm = np.argmin(losses)
model.load_state_dict(dics[minm])

breslow_spline = fit_breslow(model, xt, tt_, et_)
breslow_spline = fit_breslow(model, xt, tt_, et_) if breslow else None

if return_losses:
return (model, breslow_spline), losses
Expand Down
164 changes: 164 additions & 0 deletions tests/test_dcph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
"""This module contains test functions to
test the DeepCoxPH
models on certain standard datasets.
"""
import unittest
from auton_survival.metrics import survival_regression_metric
from auton_survival.models.cph import DeepCoxPH, DeepCoxPHTorch
from auton_survival import datasets, preprocessing
from sksurv import metrics
from sklearn.model_selection import train_test_split
from sksurv.linear_model.coxph import BreslowEstimator

import numpy as np
import pandas as pd



class TestDCPH(unittest.TestCase):
"""Base Class for all test functions"""
def _get_support_dataset(self):
return datasets.load_dataset(
"SUPPORT",
return_features=True
)

def _preprocess_data(self, features, feat_dict):
return preprocessing.Preprocessor().fit_transform(
features, feat_dict['cat'], feat_dict['num']
)

def _init_and_validate_dataset_preprocessing(self):
outcomes, features, feat_dict = self._get_support_dataset()

self.assertIsInstance(outcomes, pd.DataFrame)
self.assertIsInstance(features, pd.DataFrame)
self.assertIsInstance(feat_dict, dict)

# Preprocess (Impute and Scale) the features
features = self._preprocess_data(features, feat_dict)

x = features
t = outcomes.time.values
e = outcomes.event.values

self.assertIsInstance(x, pd.DataFrame)
self.assertIsInstance(t, np.ndarray)
self.assertIsInstance(e, np.ndarray)

self.assertEqual(x.shape, (9105, 38))
self.assertEqual(t.shape, (9105,))
self.assertEqual(e.shape, (9105,))

(
features_train,
features_test,
outcomes_train,
outcomes_test,
) = train_test_split(features, outcomes, test_size=0.25, random_state=42)

return features_train, features_test, outcomes_train, outcomes_test

def setUp(self):
self.data = self._init_and_validate_dataset_preprocessing()

def test_dcph_support_e2e(self):
"""E2E for DCPH with the SUPPORT dataset"""
(
features_train,
features_test,
outcomes_train,
outcomes_test,
) = self.data

# Train a Deep Cox Proportional Hazards (DCPH) model
model = DeepCoxPH(layers=[128, 64, 32])

self.assertIsInstance(model, DeepCoxPH)

model.fit(
features_train,
outcomes_train.time.values,
outcomes_train.event.values,
iters=30,
patience=5,
vsize=0.1,
)

self.assertIsInstance(model.torch_model, tuple)
self.assertIsInstance(model.torch_model[0], DeepCoxPHTorch)

self.assertIs(model.torch_model[0], model.torch_module)
self.assertIs(model.torch_model[0], model.torch_model.module)

self.assertIs(model.torch_model[1], model.torch_model.breslow)
self.assertIsInstance(model.torch_model.breslow, BreslowEstimator)

# Predict risk at specific time horizons.
times = [365, 365 * 2, 365 * 4]

survival_probability = model.predict_survival(features_test, t=times)
risk_score = model.predict_risk(features_test, t=times)

np.testing.assert_equal((risk_score+survival_probability).all(), 1.0)

ctds = survival_regression_metric(
"ctd",
outcomes_test,
survival_probability,
times,
outcomes_train=outcomes_train,
)

self.assertIsInstance(ctds, list)

for ctd in ctds:
self.assertIsInstance(ctd, float)

boolean_outcomes = list(
map(lambda i: True if i == 1 else False, outcomes_test.event.values)
)

cic = metrics.concordance_index_censored(
boolean_outcomes,
outcomes_test.time.values,
model.predict_time_independent_risk(features_test).squeeze(),
)

self.assertIsInstance(cic, tuple)
self.assertIsInstance(cic[0], float)

def test_dcph_should_not_fit_breslow_when_breslow_is_false(self):
"""
Verify BreslowEstimator is not fitted if breslow=false
"""
(
features_train,
features_test,
outcomes_train,
outcomes_test,
) = self.data

# Train a Deep Cox Proportional Hazards (DCPH) model
model = DeepCoxPH(layers=[128, 64, 32])

model.fit(
features_train,
outcomes_train.time.values,
outcomes_train.event.values,
iters=30,
patience=5,
vsize=0.1,
breslow=False
)

times = [365, 365 * 2, 365 * 4]

self.assertIsNone(model.torch_model[1])
self.assertIsNone(model.torch_model.breslow)

with self.assertRaises(Exception) as cm:
model.predict_survival(features_test, t=times)

with self.assertRaises(Exception) as cm:
model.predict_risk(features_test, t=times)
Loading