diff --git a/simglucose/envs/simglucose_gym_env.py b/simglucose/envs/simglucose_gym_env.py index 59410055..58acc741 100644 --- a/simglucose/envs/simglucose_gym_env.py +++ b/simglucose/envs/simglucose_gym_env.py @@ -21,24 +21,21 @@ class T1DSimEnv(gym.Env): ''' metadata = {'render.modes': ['human']} + SENSOR_HARDWARE = 'Dexcom' + INSULIN_PUMP_HARDWARE = 'Insulet' + def __init__(self, patient_name=None, reward_fun=None): ''' patient_name must be 'adolescent#001' to 'adolescent#010', or 'adult#001' to 'adult#010', or 'child#001' to 'child#010' ''' - seeds = self._seed() # have to hard code the patient_name, gym has some interesting # error when choosing the patient if patient_name is None: patient_name = 'adolescent#001' - patient = T1DPatient.withName(patient_name) - sensor = CGMSensor.withName('Dexcom', seed=seeds[1]) - hour = self.np_random.randint(low=0.0, high=24.0) - start_time = datetime(2018, 1, 1, hour, 0, 0) - scenario = RandomScenario(start_time=start_time, seed=seeds[2]) - pump = InsulinPump.withName('Insulet') - self.env = _T1DSimEnv(patient, sensor, pump, scenario) + self.patient_name = patient_name self.reward_fun = reward_fun + self.seed() def _step(self, action): # This gym only controls basal insulin @@ -59,6 +56,14 @@ def _seed(self, seed=None): # 2**31. seed2 = seeding.hash_seed(seed1 + 1) % 2**31 seed3 = seeding.hash_seed(seed2 + 1) % 2**31 + + hour = self.np_random.randint(low=0.0, high=24.0) + start_time = datetime(2018, 1, 1, hour, 0, 0) + patient = T1DPatient.withName(self.patient_name) + sensor = CGMSensor.withName(self.SENSOR_HARDWARE, seed=seed2) + scenario = RandomScenario(start_time=start_time, seed=seed3) + pump = InsulinPump.withName(self.INSULIN_PUMP_HARDWARE) + self.env = _T1DSimEnv(patient, sensor, pump, scenario) return [seed1, seed2, seed3] def _render(self, mode='human', close=False): diff --git a/tests/test_seed.py b/tests/test_seed.py new file mode 100644 index 00000000..3691da28 --- /dev/null +++ b/tests/test_seed.py @@ -0,0 +1,30 @@ +import gym +import unittest +from simglucose.controller.basal_bolus_ctrller import BBController +from datetime import datetime + + +class TestSeed(unittest.TestCase): + def test_changing_seed_generates_different_results(self): + from gym.envs.registration import register + register( + id='simglucose-adolescent2-v0', + entry_point='simglucose.envs:T1DSimEnv', + kwargs={'patient_name': 'adolescent#002'} + ) + + env = gym.make('simglucose-adolescent2-v0') + + env.seed(0) + observation_seed0 = env.reset() + self.assertEqual(env.env.scenario.start_time, datetime(2018, 1, 1, 16, 0, 0)) + + env.seed(1000) + observation_seed1 = env.reset() + self.assertEqual(env.env.scenario.start_time, datetime(2018, 1, 1, 10, 0, 0)) + + self.assertNotEqual(observation_seed0, observation_seed1) + + +if __name__ == '__main__': + unittest.main()