
Source code for the differential saliency method used in "Re-understanding Finite-State Representations of Recurrent Policy Networks"


modanesh/Differential_IG


Differential Integrated Gradient

This is the implementation of the differential saliency method used in "Re-understanding Finite-State Representations of Recurrent Policy Networks", accepted to the International Conference on Machine Learning (ICML) 2021.
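As background, the sketch below shows plain integrated gradients (IG), the attribution technique this method builds on. Each feature's attribution is (x − x′) times the average gradient along the straight path from a baseline x′ to the input x, and the attributions sum to F(x) − F(x′). This is a generic pure-Python illustration, not the repository's implementation. It uses a finite-difference gradient and a hypothetical toy scoring function (`toy_score`) in place of the recurrent policy network, and it does not reproduce the paper's differential variant.

```python
from typing import Callable, List

Vector = List[float]


def numerical_grad(f: Callable[[Vector], float], x: Vector,
                   eps: float = 1e-6) -> Vector:
    """Central finite-difference gradient of a scalar function f at point x."""
    grad = []
    for i in range(len(x)):
        x_plus, x_minus = list(x), list(x)
        x_plus[i] += eps
        x_minus[i] -= eps
        grad.append((f(x_plus) - f(x_minus)) / (2 * eps))
    return grad


def integrated_gradients(f: Callable[[Vector], float], x: Vector,
                         baseline: Vector, steps: int = 64) -> Vector:
    """Approximate IG attributions: (x - baseline) * mean path gradient.

    The path integral from baseline to x is estimated with a midpoint
    Riemann sum over `steps` points on the straight line between them.
    """
    if len(x) != len(baseline):
        raise ValueError("x and baseline must have the same length")
    grad_sum = [0.0] * len(x)
    for k in range(steps):
        alpha = (k + 0.5) / steps  # midpoint of the k-th sub-interval
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(numerical_grad(f, point)):
            grad_sum[i] += g
    return [(xi - b) * s / steps for xi, b, s in zip(x, baseline, grad_sum)]


def toy_score(v: Vector) -> float:
    """Hypothetical stand-in for a policy network's scalar output."""
    return v[0] ** 2 + 3.0 * v[0] * v[1] - v[2]


if __name__ == "__main__":
    x, baseline = [1.0, 2.0, 0.5], [0.0, 1.0, 1.5]
    attributions = integrated_gradients(toy_score, x, baseline)
    # Completeness: attributions should sum to toy_score(x) - toy_score(baseline).
    print(attributions, sum(attributions), toy_score(x) - toy_score(baseline))
```

For this toy quadratic the attributions come out close to [5.5, 1.5, 1.0], summing to F(x) − F(x′) = 8.0. With a real network you would use automatic differentiation rather than finite differences.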

Installation

  • Python 3.5+
  • To install dependencies:
    pip install -r requirements.txt

Usage

Use main_IG.py to experiment with Atari games and main_IG_control.py to experiment with classic control tasks from OpenAI Gym.

To begin, you need models trained with the MMN (Moore Machine Network) code. Once you have completed all of its steps, you will have an MMN model, which is what this repo needs. Put the trained models into the inputs directory under an appropriate name.

With the models in place, run the following command to get the results for Atari games:

python main_IG.py --env_type=atari --input_index=43 --baseline_index=103 --env PongDeterministic-v4 --qbn_sizes 64 100 --gru_size 32

The argument values can be changed as needed.

To get the results for control tasks, run:

python main_IG_control.py --env_type=classic_control --input_index=10 --baseline_index=106 --env CartPole-v1 --qbn_sizes 4 4 --gru_size 32
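To compare one input against several baselines, the commands above can also be generated from Python. This is a hypothetical helper, not part of the repository. `build_command` and `sweep_baselines` are illustrative names, the flags are copied verbatim from the two commands above, and the sketch does not check whether an index is valid for the loaded data. By default it only prints the commands; pass `run=True` to execute them one after another.

```python
import shlex
import subprocess
import sys
from typing import Iterable, List, Sequence


def build_command(script: str, env_type: str, env: str, input_index: int,
                  baseline_index: int, qbn_sizes: Sequence[int],
                  gru_size: int) -> List[str]:
    """Build one argument list using the same flags as the README commands."""
    return [
        sys.executable, script,
        "--env_type={}".format(env_type),
        "--input_index={}".format(input_index),
        "--baseline_index={}".format(baseline_index),
        "--env", env,
        "--qbn_sizes", *[str(s) for s in qbn_sizes],
        "--gru_size", str(gru_size),
    ]


def sweep_baselines(baseline_indices: Iterable[int], run: bool = False,
                    **kwargs) -> List[List[str]]:
    """Build (and optionally run) one command per baseline index.

    With run=False (the default) the commands are only printed. With
    run=True each command is executed in turn, and a non-zero exit status
    raises subprocess.CalledProcessError.
    """
    commands = [build_command(baseline_index=b, **kwargs) for b in baseline_indices]
    for cmd in commands:
        print(" ".join(shlex.quote(part) for part in cmd))
        if run:
            subprocess.run(cmd, check=True)
    return commands


if __name__ == "__main__":
    # Values from the README's Atari example; 110 is an arbitrary extra baseline.
    sweep_baselines([103, 110], script="main_IG.py", env_type="atari",
                    env="PongDeterministic-v4", input_index=43,
                    qbn_sizes=[64, 100], gru_size=32)
```

The sketch uses `str.format` rather than f-strings so that it stays compatible with the Python 3.5+ requirement listed above.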

Results are saved into the results folder. Sample results, for example for CartPole, are already provided in the repo.

Citation

If you find this work useful in your research, please cite it with:

@inproceedings{danesh2021re,
  title={Re-understanding Finite-State Representations of Recurrent Policy Networks},
  author={Danesh, Mohamad H and Koul, Anurag and Fern, Alan and Khorram, Saeed},
  booktitle={International Conference on Machine Learning},
  pages={2388--2397},
  year={2021},
  organization={PMLR}
}
