diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 0000000..14832e3 --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,14 @@ +## Summary + +This pull request introduces the following changes + +- + +Relates to the following issues + +- + +## Conformity + +- [ ] [Changelog entry](https://github.com/joschkabirk/gabbro/blob/main/changelog.md) +- [ ] [Documentation](https://joschkabirk.github.io/gabbro/) diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml new file mode 100644 index 0000000..6b156ea --- /dev/null +++ b/.github/workflows/docker.yaml @@ -0,0 +1,63 @@ +name: Docker build + +on: + pull_request: + branches: + - '*' + paths: + - 'docker/**' + push: + branches: + - 'main' + tags: + - '*' + paths: + - 'docker/**' + +env: + CONTAINER_REGISTRY: henningrose + IMAGE_NAME: gabbro + +jobs: + build_latest: + runs-on: ubuntu-latest + steps: + + - name: Checkout + uses: actions/checkout@master + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v2 + + - name: Build image + run: docker build -f docker/Dockerfile --tag ${CONTAINER_REGISTRY}/${IMAGE_NAME}:latest docker + + # Login and push this image if this is on the main branch + + - name: Login to Dockerhub container registry + if: github.event_name == 'push' && github.ref == 'refs/heads/main' + run: docker login -u ${{ vars.DOCKERHUB_USERNAME }} -p ${{ secrets.DOCKERHUB_TOKEN }} + + - name: Push image + if: github.event_name == 'push' && github.ref == 'refs/heads/main' + run: docker push ${CONTAINER_REGISTRY}/${IMAGE_NAME}:latest + + build_release: + # Build an extra image for tagged commits + runs-on: ubuntu-latest + if: startsWith(github.event.ref, 'refs/tags') + steps: + - name: Checkout + uses: actions/checkout@master + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v2 + + - name: Build image + run: docker build -f docker/Dockerfile --tag ${CONTAINER_REGISTRY}/${IMAGE_NAME}:${{ github.ref_name }} docker + + - name: Login to Dockerhub container registry + run: docker login -u ${{ vars.DOCKERHUB_USERNAME }} -p ${{ secrets.DOCKERHUB_TOKEN }} + + - name: Push image + run: docker push ${CONTAINER_REGISTRY}/${IMAGE_NAME}:${{ github.ref_name }} diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml new file mode 100644 index 0000000..b3fa8d1 --- /dev/null +++ b/.github/workflows/docs.yaml @@ -0,0 +1,28 @@ +name: Docs +on: + push: + branches: + - main +permissions: + contents: write +jobs: + deploy: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Configure Git Credentials + run: | + git config user.name github-actions[bot] + git config user.email 41898282+github-actions[bot]@users.noreply.github.com + - uses: actions/setup-python@v5 + with: + python-version: 3.x + - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV + - uses: actions/cache@v4 + with: + key: mkdocs-material-${{ env.cache_id }} + path: .cache + restore-keys: | + mkdocs-material- + - run: pip install -r docs/requirements.txt + - run: mkdocs gh-deploy --force diff --git a/.github/workflows/pre_commit.yaml b/.github/workflows/pre_commit.yaml new file mode 100644 index 0000000..09531e8 --- /dev/null +++ b/.github/workflows/pre_commit.yaml @@ -0,0 +1,22 @@ +name: Pre-commit + +on: [push] + +jobs: + pre-commit: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.10' + + - name: Install dependencies and run pre-commit + run: | + python -m pip install --upgrade pip + pip install pre-commit + pre-commit install + pre-commit run --all-files diff --git a/.github/workflows/testing.yaml b/.github/workflows/testing.yaml new file mode 100644 index 0000000..7c16580 --- /dev/null +++ b/.github/workflows/testing.yaml @@ -0,0 +1,17 @@ +name: Unit tests + +on: + push: + branches: [ 'main' ] + pull_request: + branches: [ '*' ] + +jobs: + pytest: + runs-on: ubuntu-latest + container: + image: jobirk/pytorch-image:latest + steps: + - uses: actions/checkout@v3 + - name: Test with pytest + run: bash -c "source /opt/conda/bin/activate && export PYTHONPATH=$PWD:$PYTHONPATH && pytest" diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..56ead48 --- /dev/null +++ b/.gitignore @@ -0,0 +1,174 @@ +# Mac-specific +.DS_Store + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class +ignored* +lightning_logs +*disco.py +automation/.law/index +notebooks/*testing.ipynb +notebooks/plots +outputs/ +compare/ +notebooks/ +job_scripts/*tmp* +*.pt + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..667e9c4 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,89 @@ +default_language_version: + python: python3 +exclude: 'ach_model.py' +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.3.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-docstring-first + - id: check-yaml + - id: debug-statements + - id: detect-private-key + - id: check-toml + - id: check-case-conflict + + - repo: https://github.com/PyCQA/isort + rev: 5.12.0 + hooks: + - id: isort + args: ["--profile", "black", "--filter-files"] + + - repo: https://github.com/asottile/pyupgrade + rev: v2.32.1 + hooks: + - id: pyupgrade + args: [--py38-plus] + + - repo: https://github.com/PyCQA/flake8 + rev: 4.0.1 + hooks: + - id: flake8 + args: + [ + "--extend-ignore", + "E203,E402,E501,F401,F841", + "--exclude", + "logs/*,data/*", + ] + + - repo: https://github.com/PyCQA/bandit + rev: "1.7.1" + hooks: + - id: bandit + args: ["-s", "B101"] + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.1.7 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format + + - repo: https://github.com/executablebooks/mdformat + rev: 0.7.16 + hooks: + - id: mdformat + additional_dependencies: + - mdformat-gfm==0.3.5 + - linkify-it-py==1.0.3 + - markdown-it-py==2.2.0 + + - repo: https://github.com/codespell-project/codespell + rev: v2.1.0 + hooks: + - id: codespell + args: + - --skip=logs/**,data/**,*.ipynb + - --ignore-words-list=hist,circularly,ot + + - repo: https://github.com/kynan/nbstripout + rev: 0.5.0 + hooks: + - id: nbstripout + args: ["--strip-empty-cells"] + + - repo: https://github.com/nbQA-dev/nbQA + rev: 1.4.0 + hooks: + - id: nbqa-black + args: ["--line-length=99"] + - id: nbqa-isort + args: ["--profile=black"] + - id: nbqa-flake8 + args: + [ + "--extend-ignore=E203,E402,E501,F401,F841", + "--exclude=logs/*,data/*", + ] diff --git a/.project-root b/.project-root new file mode 100644 index 0000000..63eab77 --- /dev/null +++ b/.project-root @@ -0,0 +1,2 @@ +# this file is required for inferring the project root directory +# do not delete diff --git a/README.md b/README.md new file mode 100644 index 0000000..b35ee57 --- /dev/null +++ b/README.md @@ -0,0 +1,117 @@ +# OmniJet-α Calo: Applying OJA to calorimeter data + +
+Joschka Birk, Frank Gaede, Anna Hallin, Gregor Kasieczka, Martina Mozzanica, Henning Rose + +[![arXiv](https://img.shields.io/badge/arXiv-2501.05534-red)](https://arxiv.org/abs/2501.05534) +[![PyTorch](https://img.shields.io/badge/PyTorch-2.2-orange)](https://pytorch.org) +[![Lightning](https://img.shields.io/badge/Lightning-2.2.1-purple)](https://lightning.ai) +[![Hydra](https://img.shields.io/badge/Hydra-1.3-blue)](https://hydra.cc) +
+ +This repository contains the code for the results presented in the paper [`OmniJet-α_C: Learning point cloud calorimeter simulations using generative transformers`](https://arxiv.org/abs/2501.05534) +The documentation of the original OmniJet-α can be found at [uhh-pd-ml/omnijet_alpha](https://github.com/uhh-pd-ml/omnijet_alpha) +**Abstract:** + +``` +We show the first use of generative transformers for generating calorimeter showers as point clouds +in a high-granularity calorimeter. Using the tokenizer and generative part of the OmniJet-α model, +we represent the hits in the detector as sequences of integers. This model allows variable-length +sequences, which means that it supports realistic shower development and does not need to be +conditioned on the number of hits. Since the tokenization represents the showers as point clouds, +the model learns the geometry of the showers without being restricted to any particular voxel grid. +``` + +## Table of Contents + +- [How to run the code](#how-to-run-the-code) +- [Dataset](#dataset) +- [Installation](#installation) +- [Tokenization](#tokenization) +- [Generative training](#generative-training) +- [Transfer learning / Classifier training](#transfer-learning--classifier-training) +- [Citation](#citation) + +## How to run the code + +### Dataset + +Instructions on how to download the dataset can be found in the repository [jet-universe/particle_transformer.](https://github.com/FLC-QU-hep/getting_high) + +### Installation + +The recommended (and by us tested) way of running the code is to use the provided docker image at jobirk/omnijet on [DockerHub](https://hub.docker.com/r/jobirk/omnijet). The requirements listed in docker/requirements.txt are installed in the conda environment base of the base image (official pytorch image). Thus, you have to make sure that the conda environment is activated when running the code, which can be done with `source /opt/conda/bin/activate`. + +An interactive session inside a container can be started by running the following command: + +#### On a machine with Singularity + +```sh +singularity shell docker://jobirk/omnijet:latest # start a shell in the container +source /opt/conda/bin/activate # activate the conda environment in the container +``` + +#### On a machine with Docker + +```sh +docker run -it --rm jobirk/omnijet:latest bash # start a shell in the container +source /opt/conda/bin/activate # activate the conda environment in the container +``` + +Alternatively, you can install the requirements from the `docker/requirements.txt` file, but you'll have to add pytorch to the list of requirements, since this is not included in the `requirements.txt` file (we use the official pytorch image as base image). + +Furthermore, you'll have to add/create a `.env` file in the root of the project with the following content: + +```sh +JETCLASS_DIR="" +JETCLASS_DIR_TOKENIZED="" + +# stuff for hydra +LOG_DIR="" +COMET_API_TOKEN="" +HYDRA_FULL_ERROR=1 +``` + +## Tokenization / Reconstruction + +To play around with the already-trained VQ-VAE model, you can download the checkpoint (see `checkpoints/README.md` for instructions) and then have a look at the notebook `examples/notebooks/example_tokenize_and_reconstruct_jets.ipynb`. + +You can run the training of the VQ-VAE model by running the following command: + +```sh +python gabbro/train.py experiment=example_experiment_tokenization +``` + +To create the tokenized dataset, you can run the following command: + +```sh +python python scripts/tokenize_shower_pipeline.py scripts/tokenize_shower.yaml +``` + +Make sure to adjust the settings in the `tokenize_shower.yaml` to your needs and declare the correct folder for your showers. + +## Generative training + +To play around with the already-trained generative model, you can download the checkpoint (see `checkpoints/README.md` for instructions) and then have a look at the notebook `examples/notebooks/example_generate_jets.ipynb`. + +You can run the generative training of the backbone model by running the following command: + +```sh +python gabbro/train.py experiment=example_experiment_backbone_generative +``` + +## Citation + +If you use this code in your research, please cite our paper: + +```bibtex +@misc{birk2025omnijetalphaclearningpoint, + title = {OmniJet-${\alpha_{ C}}$: Learning point cloud calorimeter simulations using generative transformers}, + author = {Joschka Birk and Frank Gaede and Anna Hallin and Gregor Kasieczka and Martina Mozzanica and Henning Rose}, + year = {2025}, + eprint = {2501.05534}, + archivePrefix = {arXiv}, + primaryClass = {hep-ph}, + url = {https://arxiv.org/abs/2501.05534}, +} +``` diff --git a/checkpoints/README.md b/checkpoints/README.md new file mode 100644 index 0000000..8126702 --- /dev/null +++ b/checkpoints/README.md @@ -0,0 +1,8 @@ +# Checkpoints + +In order to keep the repo small, the checkpoints are not included in the repo directly and are instead stored in a separate location. You can download the checkpoints with the following command: + +```sh +# run this in the `checkpoints` directory +./download_checkpoints.sh +``` diff --git a/checkpoints/download_checkpoints.sh b/checkpoints/download_checkpoints.sh new file mode 100644 index 0000000..946f7d2 --- /dev/null +++ b/checkpoints/download_checkpoints.sh @@ -0,0 +1,4 @@ +#!/bin/bash +curl --output checkpoints.tar #TODO: Add link to checkpoints.tar +tar -xvf checkpoints.tar +rm -rf checkpoints.tar diff --git a/configs/callbacks/callbacks_for_generative_training.yaml b/configs/callbacks/callbacks_for_generative_training.yaml new file mode 100644 index 0000000..393ab2f --- /dev/null +++ b/configs/callbacks/callbacks_for_generative_training.yaml @@ -0,0 +1,34 @@ +defaults: + - model_checkpoint.yaml + - model_checkpoint_best.yaml + - model_summary.yaml + - lr_monitor.yaml + - generative_callback.yaml + - early_stopping.yaml + - _self_ + +model_checkpoint: + dirpath: ${paths.output_dir}/checkpoints + filename: "epoch_{epoch:03d}_loss_{val_loss:.5f}" + monitor: "val_loss" + mode: "min" + every_n_epochs: 1 + save_last: True + auto_insert_metric_name: False + +model_checkpoint_best: + dirpath: ${paths.output_dir}/checkpoints + filename: "best" + monitor: "val_loss" + mode: "min" + every_n_epochs: 1 + save_last: false + +early_stopping: + monitor: "val_loss" + patience: 100 + mode: "min" + verbose: true + +model_summary: + max_depth: -1 diff --git a/configs/callbacks/classifier_callback.yaml b/configs/callbacks/classifier_callback.yaml new file mode 100644 index 0000000..df207e1 --- /dev/null +++ b/configs/callbacks/classifier_callback.yaml @@ -0,0 +1,4 @@ +# Generate data, calculate plots and metrics and log them to the logger +tokenization_callback: + _target_: gabbro.callbacks.classifier_callback.ClassifierEvaluationCallback + every_n_epochs: 1 # evaluate every n epochs diff --git a/configs/callbacks/classifier_callbacks.yaml b/configs/callbacks/classifier_callbacks.yaml new file mode 100644 index 0000000..b078006 --- /dev/null +++ b/configs/callbacks/classifier_callbacks.yaml @@ -0,0 +1,34 @@ +defaults: + - model_checkpoint.yaml + - model_checkpoint_best.yaml + - model_summary.yaml + - lr_monitor.yaml + - classifier_callback.yaml + - early_stopping.yaml + - _self_ + +model_checkpoint: + dirpath: ${paths.output_dir}/checkpoints + filename: "epoch_{epoch:03d}_loss_{val_loss:.5f}" + monitor: "val_loss" + mode: "min" + every_n_epochs: 1 + save_last: True + auto_insert_metric_name: False + +model_checkpoint_best: + dirpath: ${paths.output_dir}/checkpoints + filename: "best" + monitor: "val_loss" + mode: "min" + every_n_epochs: 1 + save_last: false + +early_stopping: + monitor: "val_loss" + patience: 10 + mode: "min" + verbose: true + +model_summary: + max_depth: -1 diff --git a/configs/callbacks/device_stats_monitor.yaml b/configs/callbacks/device_stats_monitor.yaml new file mode 100644 index 0000000..820af37 --- /dev/null +++ b/configs/callbacks/device_stats_monitor.yaml @@ -0,0 +1,8 @@ +# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.DeviceStatsMonitor.html#lightning.pytorch.callbacks.DeviceStatsMonitor + +# Automatically monitors and logs device stats during training, validation and testing stage. +# DeviceStatsMonitor is a special callback as it requires a logger to passed as argument to the Trainer. +# Look at the above link for more detailed information. +device_stats_monitor: + _target_: pytorch_lightning.callbacks.DeviceStatsMonitor + cpu_stats: null # if None, it will log CPU stats only if the accelerator is CPU. If True, it will log CPU stats regardless of the accelerator. If False, it will not log CPU stats regardless of the accelerator. diff --git a/configs/callbacks/early_stopping.yaml b/configs/callbacks/early_stopping.yaml new file mode 100644 index 0000000..1c0781e --- /dev/null +++ b/configs/callbacks/early_stopping.yaml @@ -0,0 +1,17 @@ +# https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.EarlyStopping.html + +# Monitor a metric and stop training when it stops improving. +# Look at the above link for more detailed information. +early_stopping: + _target_: lightning.pytorch.callbacks.EarlyStopping + monitor: val_loss # quantity to be monitored, must be specified !!! + min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement + patience: 3 # number of checks with no improvement after which training will be stopped + verbose: False # verbosity mode + mode: "min" # "max" means higher metric value is better, can be also "min" + strict: True # whether to crash the training if monitor is not found in the validation metrics + check_finite: True # when set True, stops training when the monitor becomes NaN or infinite + stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold + divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold + check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch + # log_rank_zero_only: False # this keyword argument isn't available in stable version diff --git a/configs/callbacks/generative_callback.yaml b/configs/callbacks/generative_callback.yaml new file mode 100644 index 0000000..8d5226e --- /dev/null +++ b/configs/callbacks/generative_callback.yaml @@ -0,0 +1,7 @@ +# Generate data, calculate plots and metrics and log them to the logger +generative_callback: + _target_: gabbro.callbacks.generative_callback.GenEvalCallback + n_val_gen_jets: 2000 + every_n_epochs: 1 + starting_at_epoch: 1 + batch_size_for_generation: 256 diff --git a/configs/callbacks/lr_monitor.yaml b/configs/callbacks/lr_monitor.yaml new file mode 100644 index 0000000..8de6705 --- /dev/null +++ b/configs/callbacks/lr_monitor.yaml @@ -0,0 +1,8 @@ +# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.LearningRateMonitor.html#lightning.pytorch.callbacks.LearningRateMonitor + +# Automatically monitor and logs learning rate for learning rate schedulers during training. +# Look at the above link for more detailed information. +lr_monitor: + _target_: lightning.pytorch.callbacks.LearningRateMonitor + logging_interval: "epoch" # "step" or "epoch" or None + log_momentum: False # if True, will also log the momentum value of the optimizer at each step diff --git a/configs/callbacks/model_checkpoint.yaml b/configs/callbacks/model_checkpoint.yaml new file mode 100644 index 0000000..1fb951c --- /dev/null +++ b/configs/callbacks/model_checkpoint.yaml @@ -0,0 +1,19 @@ +# https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.ModelCheckpoint.html + +# Save the model periodically by monitoring a quantity. +# Look at the above link for more detailed information. +model_checkpoint: + _target_: lightning.pytorch.callbacks.ModelCheckpoint + dirpath: null # directory to save the model file + filename: null # checkpoint filename + monitor: null # name of the logged metric which determines when model is improving + verbose: False # verbosity mode + save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt + save_top_k: 500 # save k best models (determined by above metric) + mode: "min" # "max" means higher metric value is better, can be also "min" + auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name + save_weights_only: False # if True, then only the model’s weights will be saved + every_n_train_steps: null # number of training steps between checkpoints + train_time_interval: null # checkpoints are monitored at the specified time interval + every_n_epochs: null # number of epochs between checkpoints + save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation diff --git a/configs/callbacks/model_checkpoint_best.yaml b/configs/callbacks/model_checkpoint_best.yaml new file mode 100644 index 0000000..9a4b84f --- /dev/null +++ b/configs/callbacks/model_checkpoint_best.yaml @@ -0,0 +1,19 @@ +# https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.ModelCheckpoint.html + +# Save the model periodically by monitoring a quantity. +# Look at the above link for more detailed information. +model_checkpoint_best: + _target_: gabbro.callbacks.best_checkpoint_callback.CustomModelCheckpoint + dirpath: null # directory to save the model file + filename: best # checkpoint filename + monitor: "val_acc" # name of the logged metric which determines when model is improving + verbose: true # verbosity mode + save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt + save_top_k: 10 # save k best models (determined by above metric) + mode: "max" # "max" means higher metric value is better, can be also "min" + auto_insert_metric_name: false # when True, the checkpoints filenames will contain the metric name + save_weights_only: false # if True, then only the model’s weights will be saved + every_n_train_steps: null # number of training steps between checkpoints + train_time_interval: null # checkpoints are monitored at the specified time interval + every_n_epochs: null # number of epochs between checkpoints + save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation diff --git a/configs/callbacks/model_summary.yaml b/configs/callbacks/model_summary.yaml new file mode 100644 index 0000000..70a060b --- /dev/null +++ b/configs/callbacks/model_summary.yaml @@ -0,0 +1,7 @@ +# https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.RichModelSummary.html + +# Generates a summary of all layers in a LightningModule with rich text formatting. +# Look at the above link for more detailed information. +model_summary: + _target_: lightning.pytorch.callbacks.RichModelSummary + max_depth: 1 # the maximum depth of layer nesting that the summary will include diff --git a/configs/callbacks/rich_progress_bar.yaml b/configs/callbacks/rich_progress_bar.yaml new file mode 100644 index 0000000..82d2f89 --- /dev/null +++ b/configs/callbacks/rich_progress_bar.yaml @@ -0,0 +1,6 @@ +# https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.RichProgressBar.html + +# Create a progress bar with rich text formatting. +# Look at the above link for more detailed information. +rich_progress_bar: + _target_: lightning.pytorch.callbacks.RichProgressBar diff --git a/configs/callbacks/tokenization_callback.yaml b/configs/callbacks/tokenization_callback.yaml new file mode 100644 index 0000000..cfdac1e --- /dev/null +++ b/configs/callbacks/tokenization_callback.yaml @@ -0,0 +1,3 @@ +# Generate data, calculate plots and metrics and log them to the logger +tokenization_callback: + _target_: gabbro.callbacks.tokenization_callback.TokenizationEvalCallback diff --git a/configs/callbacks/tokenization_callbacks.yaml b/configs/callbacks/tokenization_callbacks.yaml new file mode 100644 index 0000000..771d3c2 --- /dev/null +++ b/configs/callbacks/tokenization_callbacks.yaml @@ -0,0 +1,23 @@ +defaults: + - model_checkpoint.yaml + - model_summary.yaml + - lr_monitor.yaml + - tokenization_callback.yaml + - _self_ + +model_checkpoint: + dirpath: ${paths.output_dir}/checkpoints + filename: "epoch_{epoch:03d}_loss_{val_loss:.5f}" + monitor: "val_loss" + mode: "min" + every_n_epochs: 1 + save_last: True + auto_insert_metric_name: False + +#early_stopping: +# monitor: "val_loss" +# patience: 100 +# mode: "min" + +model_summary: + max_depth: -1 diff --git a/configs/callbacks/tokenization_dummy_callback.yaml b/configs/callbacks/tokenization_dummy_callback.yaml new file mode 100644 index 0000000..ef313a8 --- /dev/null +++ b/configs/callbacks/tokenization_dummy_callback.yaml @@ -0,0 +1,3 @@ +# Generate data, calculate plots and metrics and log them to the logger +tokenization_callback: + _target_: gabbro.callbacks.tokenization_shower_callback.TokenizationEvalCallback diff --git a/configs/callbacks/tokenization_dummy_callbacks.yaml b/configs/callbacks/tokenization_dummy_callbacks.yaml new file mode 100644 index 0000000..35d8efa --- /dev/null +++ b/configs/callbacks/tokenization_dummy_callbacks.yaml @@ -0,0 +1,23 @@ +defaults: + - model_checkpoint.yaml + - model_summary.yaml + - lr_monitor.yaml + - tokenization_dummy_callback.yaml + - _self_ + +model_checkpoint: + dirpath: ${paths.output_dir}/checkpoints + filename: "epoch_{epoch:03d}_loss_{val_loss:.5f}" + monitor: "val_loss" + mode: "min" + every_n_epochs: 1 + save_last: True + auto_insert_metric_name: False + +#early_stopping: +# monitor: "val_loss" +# patience: 100 +# mode: "min" + +model_summary: + max_depth: -1 diff --git a/configs/data/data_generative_e_sorted.yaml b/configs/data/data_generative_e_sorted.yaml new file mode 100644 index 0000000..cea0af4 --- /dev/null +++ b/configs/data/data_generative_e_sorted.yaml @@ -0,0 +1,43 @@ +_target_: gabbro.data.iterable_dataset_shower.IterableCaloDatamodule + + +defaults: + - defaults.yaml + + +h5file: True +batch_size: 64 +n_files_at_once: 1 + + + +dataset_kwargs_common: + feature_dict: + part_token_id_without_last: None + part_token_id_without_first: None + pad_length: 1700 + n_files_at_once: 1 + h5file: True + # n_shower_per_file: 760000 + labels_to_load: + - label_Shower + +dataset_kwargs_train: + max_n_files_per_type: 1 + n_jets_per_file: null + files_dict: + shower: + - ${data.data_dir}/tokenized_test_e_sorted_remapped.parquet + +dataset_kwargs_val: + max_n_files_per_type: 1 + files_dict: + shower: + - ${data.data_dir}/tokenized_test_e_sorted_remapped.parquet + + +dataset_kwargs_test: + max_n_files_per_type: 1 + files_dict: + shower: + - ${data.data_dir}/tokenized_test_e_sorted.parquet diff --git a/configs/data/data_generative_layer_sorted.yaml b/configs/data/data_generative_layer_sorted.yaml new file mode 100644 index 0000000..7e339df --- /dev/null +++ b/configs/data/data_generative_layer_sorted.yaml @@ -0,0 +1,43 @@ +_target_: gabbro.data.iterable_dataset_shower.IterableCaloDatamodule + + +defaults: + - defaults.yaml + + +h5file: True +batch_size: 64 +n_files_at_once: 1 + + + +dataset_kwargs_common: + feature_dict: + part_token_id_without_last: None + part_token_id_without_first: None + pad_length: 1700 + n_files_at_once: 1 + h5file: True + # n_shower_per_file: 760000 + labels_to_load: + - label_Shower + +dataset_kwargs_train: + max_n_files_per_type: 1 + n_jets_per_file: null + files_dict: + shower: + - ${data.data_dir}/tokenized_train_layer_sorted.parquet + +dataset_kwargs_val: + max_n_files_per_type: 1 + files_dict: + shower: + - ${data.data_dir}/tokenized_val_layer_sorted.parquet + + +dataset_kwargs_test: + max_n_files_per_type: 1 + files_dict: + shower: + - ${data.data_dir}/tokenized_test_layer_sorted.parquet diff --git a/configs/data/data_generative_unsorted.yaml b/configs/data/data_generative_unsorted.yaml new file mode 100644 index 0000000..bc6b780 --- /dev/null +++ b/configs/data/data_generative_unsorted.yaml @@ -0,0 +1,43 @@ +_target_: gabbro.data.iterable_dataset_shower.IterableCaloDatamodule + + +defaults: + - defaults.yaml + + +h5file: True +batch_size: 64 +n_files_at_once: 1 + + + +dataset_kwargs_common: + feature_dict: + part_token_id_without_last: None + part_token_id_without_first: None + pad_length: 1700 + n_files_at_once: 1 + h5file: True + # n_shower_per_file: 760000 + labels_to_load: + - label_Shower + +dataset_kwargs_train: + max_n_files_per_type: 1 + n_jets_per_file: null + files_dict: + shower: + - ${data.data_dir}/tokenized_train.parquet + +dataset_kwargs_val: + max_n_files_per_type: 1 + files_dict: + shower: + - ${data.data_dir}/tokenized_val.parquet + + +dataset_kwargs_test: + max_n_files_per_type: 1 + files_dict: + shower: + - ${data.data_dir}/tokenized_test.parquet diff --git a/configs/data/data_taus.yaml b/configs/data/data_taus.yaml new file mode 100644 index 0000000..36c6039 --- /dev/null +++ b/configs/data/data_taus.yaml @@ -0,0 +1,39 @@ +_target_: gabbro.data.dataset_taus.TauDataModule +# file_train: /beegfs/desy/user/birkjosc/datasets/landscape/TopLandscape/train_file.parquet +# file_val: /beegfs/desy/user/birkjosc/datasets/landscape/TopLandscape/val_file.parquet +# file_test: /beegfs/desy/user/birkjosc/datasets/landscape/TopLandscape/test_file.parquet +# file_train: /beegfs/desy/user/birkjosc/datasets/landscape/TopLandscape/train_file_1000.parquet +# file_val: /beegfs/desy/user/birkjosc/datasets/landscape/TopLandscape/val_file_1000.parquet +# file_test: /beegfs/desy/user/birkjosc/datasets/landscape/TopLandscape/test_file_1000.parquet +# n_jets_train: null +# n_jets_val: null +# n_jets_test: null +batch_size: 512 +pad_length: 128 +num_workers: 32 +pin_memory: False + +# dummy keys to make compatible with other dataset structure +data_dir: null +dataset_kwargs_common: + pad_length: 128 + n_files_at_once: null + labels_to_load: null + feature_dict: + part_pt: {multiply_by: 1, subtract_by: 1.8, func: "np.log", inv_func: "np.exp"} + # part_pt: {multiply_by: 0.05, subtract_by: 12} + part_etarel: {multiply_by: 3, larger_than: -0.8, smaller_than: 0.8} + part_phirel: {multiply_by: 3, larger_than: -0.8, smaller_than: 0.8} + +dataset_kwargs_train: + max_n_files_per_type: null + n_jets_per_file: null + files_dict: null + +dataset_kwargs_val: + max_n_files_per_type: 1 + files_dict: null + +dataset_kwargs_test: + max_n_files_per_type: 1 + files_dict: null diff --git a/configs/data/data_tokenization.yaml b/configs/data/data_tokenization.yaml new file mode 100644 index 0000000..45eae23 --- /dev/null +++ b/configs/data/data_tokenization.yaml @@ -0,0 +1,100 @@ +_target_: gabbro.data.iterable_dataset_jetclass.IterableDatamodule + + +defaults: + - defaults.yaml + +data_dir: /beegfs/desy/user/birkjosc/datasets/jetclass/JetClass/ +#/beegfs/desy/user/korcariw/multiRes_CaloClouds/dataset/showers/photons_10_100GeV_10bins_showers_train.h5 +batch_size: 512 + +dataset_kwargs_common: + # feature_dict: + # part_pt: {multiply_by: 0.3, subtract_by: 2.7, func: np.log, inv_func: np.exp} + # part_eta: {multiply_by: 4} + # part_phi: {multiply_by: 4} + pad_length: 128 + n_files_at_once: 10 + labels_to_load: + - label_QCD + - label_Hbb + - label_Hcc + - label_Hgg + - label_H4q + - label_Hqql + - label_Zqq + - label_Wqq + - label_Tbqq + - label_Tbl + +dataset_kwargs_train: + max_n_files_per_type: null + files_dict: + QCD: + - ${data.data_dir}/train_100M/ZJetsToNuNu_* + Hbb: + - ${data.data_dir}/train_100M/HToBB_* + Hcc: + - ${data.data_dir}/train_100M/HToCC_* + Hgg: + - ${data.data_dir}/train_100M/HToGG_* + H4q: + - ${data.data_dir}/train_100M/HToWW4Q_* + Hqql: + - ${data.data_dir}/train_100M/HToWW2Q1L_* + Zqq: + - ${data.data_dir}/train_100M/ZToQQ_* + Wqq: + - ${data.data_dir}/train_100M/WToQQ_* + Tbqq: + - ${data.data_dir}/train_100M/TTBar_* + Tbl: + - ${data.data_dir}/train_100M/TTBarLep_* + +dataset_kwargs_val: + max_n_files_per_type: 1 + files_dict: + QCD: + - ${data.data_dir}/val_5M/ZJetsToNuNu_* + Hbb: + - ${data.data_dir}/val_5M/HToBB_* + Hcc: + - ${data.data_dir}/val_5M/HToCC_* + Hgg: + - ${data.data_dir}/val_5M/HToGG_* + H4q: + - ${data.data_dir}/val_5M/HToWW4Q_* + Hqql: + - ${data.data_dir}/val_5M/HToWW2Q1L_* + Zqq: + - ${data.data_dir}/val_5M/ZToQQ_* + Wqq: + - ${data.data_dir}/val_5M/WToQQ_* + Tbqq: + - ${data.data_dir}/val_5M/TTBar_* + Tbl: + - ${data.data_dir}/val_5M/TTBarLep_* + +dataset_kwargs_test: + max_n_files_per_type: 1 + files_dict: + QCD: + - ${data.data_dir}/test_20M/ZJetsToNuNu_* + Hbb: + - ${data.data_dir}/test_20M/HToBB_* + Hcc: + - ${data.data_dir}/test_20M/HToCC_* + Hgg: + - ${data.data_dir}/test_20M/HToGG_* + H4q: + - ${data.data_dir}/test_20M/HToWW4Q_* + Hqql: + - ${data.data_dir}/test_20M/HToWW2Q1L_* + Zqq: + - ${data.data_dir}/test_20M/ZToQQ_* + Wqq: + - ${data.data_dir}/test_20M/WToQQ_* + Tbqq: + - ${data.data_dir}/test_20M/TTBar_* + Tbl: + - ${data.data_dir}/test_20M/TTBarLep_* diff --git a/configs/data/data_tokenization_dev.yaml b/configs/data/data_tokenization_dev.yaml new file mode 100644 index 0000000..9deeb62 --- /dev/null +++ b/configs/data/data_tokenization_dev.yaml @@ -0,0 +1,41 @@ +_target_: gabbro.data.iterable_dataset_shower.IterableCaloDatamodule + + +defaults: + - defaults.yaml +data_dir: /data/dust/user/korcariw/maxwell.merged/CaloClouds/dataset/ +h5file: True +batch_size: 64 +n_files_at_once: 1 + + + +dataset_kwargs_common: + feature_dict: + x: None + y: None + z: None + energy: None + pad_length: 1700 + n_files_at_once: 1 + h5file: True + labels_to_load: + - label_Shower + +dataset_kwargs_train: + max_n_files_per_type: null + files_dict: + QCD: + - ${data.data_dir}/showers/photons_10_100GeV_float32_sorted_train.h5 +dataset_kwargs_val: + max_n_files_per_type: null + files_dict: + QCD: + - ${data.data_dir}/showers/photons_10_100GeV_float32_sorted_val.h5 + + +dataset_kwargs_test: + max_n_files_per_type: null + files_dict: + QCD: + - ${data.data_dir}/showers/photons_10_100GeV_float32_sorted_test.h5 diff --git a/configs/data/defaults.yaml b/configs/data/defaults.yaml new file mode 100644 index 0000000..4e1f3c9 --- /dev/null +++ b/configs/data/defaults.yaml @@ -0,0 +1,17 @@ +dataset_kwargs_train: + logger_name: IterDataset-Train + shuffle_files: true + shuffle_data: true + +dataset_kwargs_val: + logger_name: IterDataset-Validation + shuffle_files: false + shuffle_data: true + +dataset_kwargs_test: + logger_name: IterDataset-Test + shuffle_files: false + shuffle_data: true + +batch_size: 128 +data_dir: beegfs/desy/user/korcariw/multiRes_CaloClouds/dataset/ diff --git a/configs/data/iter_dataset_jetclass.yaml b/configs/data/iter_dataset_jetclass.yaml new file mode 100644 index 0000000..f8a57a0 --- /dev/null +++ b/configs/data/iter_dataset_jetclass.yaml @@ -0,0 +1,43 @@ +_target_: gabbro.data.iterable_dataset_jetclass.IterableDatamodule + +# --------- + +defaults: + - defaults.yaml + +# --------- +# +data_dir: /beegfs/desy/user/birkjosc/datasets/jetclass/JetClass/ + +dataset_kwargs_common: + feature_dict: null + # part_pt: {multiply_by: 0.3, subtract_by: 2.7, func: np.log, inv_func: np.exp} + # part_eta: {multiply_by: 4} + # part_phi: {multiply_by: 4} + n_files_at_once: 2 + labels_to_load: + - label_QCD + - label_Hbb + - label_Hcc + - label_Hgg + - label_H4q + - label_Hqql + - label_Zqq + - label_Wqq + - label_Tbqq + - label_Tbl + +dataset_kwargs_train: + files_dict: + QCD: + - ${data.data_dir}/train_100M/ZJetsToNuNu_* + +dataset_kwargs_val: + files_dict: + QCD: + - ${data.data_dir}/val_5M/ZJetsToNuNu_* + +dataset_kwargs_test: + files_dict: + QCD: + - ${data.data_dir}/test_20M/ZJetsToNuNu_* diff --git a/configs/data/iter_dataset_jetclass_classification.yaml b/configs/data/iter_dataset_jetclass_classification.yaml new file mode 100644 index 0000000..0def268 --- /dev/null +++ b/configs/data/iter_dataset_jetclass_classification.yaml @@ -0,0 +1,97 @@ +_target_: gabbro.data.iterable_dataset_jetclass.IterableDatamodule + + +defaults: + - defaults.yaml + +data_dir: /beegfs/desy/user/birkjosc/datasets/jetclass/JetClass/ + +batch_size: 512 + +dataset_kwargs_common: + pad_length: 128 + n_files_at_once: 10 + labels_to_load: + - label_QCD + - label_Hbb + - label_Hcc + - label_Hgg + - label_H4q + - label_Hqql + - label_Zqq + - label_Wqq + - label_Tbqq + - label_Tbl + +dataset_kwargs_train: + max_n_files_per_type: null + n_jets_per_file: null + files_dict: + QCD: + - ${data.data_dir}/train_100M/ZJetsToNuNu_* + Hbb: + - ${data.data_dir}/train_100M/HToBB_* + Hcc: + - ${data.data_dir}/train_100M/HToCC_* + Hgg: + - ${data.data_dir}/train_100M/HToGG_* + H4q: + - ${data.data_dir}/train_100M/HToWW4Q_* + Hqql: + - ${data.data_dir}/train_100M/HToWW2Q1L_* + Zqq: + - ${data.data_dir}/train_100M/ZToQQ_* + Wqq: + - ${data.data_dir}/train_100M/WToQQ_* + Tbqq: + - ${data.data_dir}/train_100M/TTBar_* + Tbl: + - ${data.data_dir}/train_100M/TTBarLep_* + +dataset_kwargs_val: + max_n_files_per_type: 1 + files_dict: + QCD: + - ${data.data_dir}/val_5M/ZJetsToNuNu_* + Hbb: + - ${data.data_dir}/val_5M/HToBB_* + Hcc: + - ${data.data_dir}/val_5M/HToCC_* + Hgg: + - ${data.data_dir}/val_5M/HToGG_* + H4q: + - ${data.data_dir}/val_5M/HToWW4Q_* + Hqql: + - ${data.data_dir}/val_5M/HToWW2Q1L_* + Zqq: + - ${data.data_dir}/val_5M/ZToQQ_* + Wqq: + - ${data.data_dir}/val_5M/WToQQ_* + Tbqq: + - ${data.data_dir}/val_5M/TTBar_* + Tbl: + - ${data.data_dir}/val_5M/TTBarLep_* + +dataset_kwargs_test: + max_n_files_per_type: 1 + files_dict: + QCD: + - ${data.data_dir}/test_20M/ZJetsToNuNu_* + Hbb: + - ${data.data_dir}/test_20M/HToBB_* + Hcc: + - ${data.data_dir}/test_20M/HToCC_* + Hgg: + - ${data.data_dir}/test_20M/HToGG_* + H4q: + - ${data.data_dir}/test_20M/HToWW4Q_* + Hqql: + - ${data.data_dir}/test_20M/HToWW2Q1L_* + Zqq: + - ${data.data_dir}/test_20M/ZToQQ_* + Wqq: + - ${data.data_dir}/test_20M/WToQQ_* + Tbqq: + - ${data.data_dir}/test_20M/TTBar_* + Tbl: + - ${data.data_dir}/test_20M/TTBarLep_* diff --git a/configs/data/iter_dataset_jetclass_classification_dev.yaml b/configs/data/iter_dataset_jetclass_classification_dev.yaml new file mode 100644 index 0000000..df1178e --- /dev/null +++ b/configs/data/iter_dataset_jetclass_classification_dev.yaml @@ -0,0 +1,97 @@ +_target_: gabbro.data.iterable_dataset_jetclass.IterableDatamodule + + +defaults: + - defaults.yaml + +data_dir: /beegfs/desy/user/birkjosc/datasets/jetclass/JetClass/ + +batch_size: 512 + +dataset_kwargs_common: + pad_length: 128 + n_files_at_once: 10 + labels_to_load: + - label_QCD + # - label_Hbb + # - label_Hcc + # - label_Hgg + # - label_H4q + # - label_Hqql + # - label_Zqq + # - label_Wqq + - label_Tbqq + # - label_Tbl + +dataset_kwargs_train: + max_n_files_per_type: null + n_jets_per_file: null + files_dict: + QCD: + - ${data.data_dir}/train_100M/ZJetsToNuNu_* + # Hbb: + # - ${data.data_dir}/train_100M/HToBB_* + # Hcc: + # - ${data.data_dir}/train_100M/HToCC_* + # Hgg: + # - ${data.data_dir}/train_100M/HToGG_* + # H4q: + # - ${data.data_dir}/train_100M/HToWW4Q_* + # Hqql: + # - ${data.data_dir}/train_100M/HToWW2Q1L_* + # Zqq: + # - ${data.data_dir}/train_100M/ZToQQ_* + # Wqq: + # - ${data.data_dir}/train_100M/WToQQ_* + Tbqq: + - ${data.data_dir}/train_100M/TTBar_* + # Tbl: + # - ${data.data_dir}/train_100M/TTBarLep_* + +dataset_kwargs_val: + max_n_files_per_type: 1 + files_dict: + QCD: + - ${data.data_dir}/val_5M/ZJetsToNuNu_* + # Hbb: + # - ${data.data_dir}/val_5M/HToBB_* + # Hcc: + # - ${data.data_dir}/val_5M/HToCC_* + # Hgg: + # - ${data.data_dir}/val_5M/HToGG_* + # H4q: + # - ${data.data_dir}/val_5M/HToWW4Q_* + # Hqql: + # - ${data.data_dir}/val_5M/HToWW2Q1L_* + # Zqq: + # - ${data.data_dir}/val_5M/ZToQQ_* + # Wqq: + # - ${data.data_dir}/val_5M/WToQQ_* + Tbqq: + - ${data.data_dir}/val_5M/TTBar_* + # Tbl: + # - ${data.data_dir}/val_5M/TTBarLep_* + +dataset_kwargs_test: + max_n_files_per_type: 1 + files_dict: + QCD: + - ${data.data_dir}/test_20M/ZJetsToNuNu_* + # Hbb: + # - ${data.data_dir}/test_20M/HToBB_* + # Hcc: + # - ${data.data_dir}/test_20M/HToCC_* + # Hgg: + # - ${data.data_dir}/test_20M/HToGG_* + # H4q: + # - ${data.data_dir}/test_20M/HToWW4Q_* + # Hqql: + # - ${data.data_dir}/test_20M/HToWW2Q1L_* + # Zqq: + # - ${data.data_dir}/test_20M/ZToQQ_* + # Wqq: + # - ${data.data_dir}/test_20M/WToQQ_* + Tbqq: + - ${data.data_dir}/test_20M/TTBar_* + # Tbl: + # - ${data.data_dir}/test_20M/TTBarLep_* diff --git a/configs/data/iter_dataset_jetclass_classification_top_vs_qcd.yaml b/configs/data/iter_dataset_jetclass_classification_top_vs_qcd.yaml new file mode 100644 index 0000000..df1178e --- /dev/null +++ b/configs/data/iter_dataset_jetclass_classification_top_vs_qcd.yaml @@ -0,0 +1,97 @@ +_target_: gabbro.data.iterable_dataset_jetclass.IterableDatamodule + + +defaults: + - defaults.yaml + +data_dir: /beegfs/desy/user/birkjosc/datasets/jetclass/JetClass/ + +batch_size: 512 + +dataset_kwargs_common: + pad_length: 128 + n_files_at_once: 10 + labels_to_load: + - label_QCD + # - label_Hbb + # - label_Hcc + # - label_Hgg + # - label_H4q + # - label_Hqql + # - label_Zqq + # - label_Wqq + - label_Tbqq + # - label_Tbl + +dataset_kwargs_train: + max_n_files_per_type: null + n_jets_per_file: null + files_dict: + QCD: + - ${data.data_dir}/train_100M/ZJetsToNuNu_* + # Hbb: + # - ${data.data_dir}/train_100M/HToBB_* + # Hcc: + # - ${data.data_dir}/train_100M/HToCC_* + # Hgg: + # - ${data.data_dir}/train_100M/HToGG_* + # H4q: + # - ${data.data_dir}/train_100M/HToWW4Q_* + # Hqql: + # - ${data.data_dir}/train_100M/HToWW2Q1L_* + # Zqq: + # - ${data.data_dir}/train_100M/ZToQQ_* + # Wqq: + # - ${data.data_dir}/train_100M/WToQQ_* + Tbqq: + - ${data.data_dir}/train_100M/TTBar_* + # Tbl: + # - ${data.data_dir}/train_100M/TTBarLep_* + +dataset_kwargs_val: + max_n_files_per_type: 1 + files_dict: + QCD: + - ${data.data_dir}/val_5M/ZJetsToNuNu_* + # Hbb: + # - ${data.data_dir}/val_5M/HToBB_* + # Hcc: + # - ${data.data_dir}/val_5M/HToCC_* + # Hgg: + # - ${data.data_dir}/val_5M/HToGG_* + # H4q: + # - ${data.data_dir}/val_5M/HToWW4Q_* + # Hqql: + # - ${data.data_dir}/val_5M/HToWW2Q1L_* + # Zqq: + # - ${data.data_dir}/val_5M/ZToQQ_* + # Wqq: + # - ${data.data_dir}/val_5M/WToQQ_* + Tbqq: + - ${data.data_dir}/val_5M/TTBar_* + # Tbl: + # - ${data.data_dir}/val_5M/TTBarLep_* + +dataset_kwargs_test: + max_n_files_per_type: 1 + files_dict: + QCD: + - ${data.data_dir}/test_20M/ZJetsToNuNu_* + # Hbb: + # - ${data.data_dir}/test_20M/HToBB_* + # Hcc: + # - ${data.data_dir}/test_20M/HToCC_* + # Hgg: + # - ${data.data_dir}/test_20M/HToGG_* + # H4q: + # - ${data.data_dir}/test_20M/HToWW4Q_* + # Hqql: + # - ${data.data_dir}/test_20M/HToWW2Q1L_* + # Zqq: + # - ${data.data_dir}/test_20M/ZToQQ_* + # Wqq: + # - ${data.data_dir}/test_20M/WToQQ_* + Tbqq: + - ${data.data_dir}/test_20M/TTBar_* + # Tbl: + # - ${data.data_dir}/test_20M/TTBarLep_* diff --git a/configs/data/iter_dataset_jetclass_classification_top_vs_qcd_transfer_learning.yaml b/configs/data/iter_dataset_jetclass_classification_top_vs_qcd_transfer_learning.yaml new file mode 100644 index 0000000..3be0821 --- /dev/null +++ b/configs/data/iter_dataset_jetclass_classification_top_vs_qcd_transfer_learning.yaml @@ -0,0 +1,101 @@ +_target_: gabbro.data.iterable_dataset_jetclass.IterableDatamodule + + +defaults: + - defaults.yaml + +data_dir: /beegfs/desy/user/birkjosc/datasets/jetclass_tokenized/2024-02-19_20-54-01_nonfissile_defect_a56f_TTBar_ZJetsToNuNu_test_split_to_trainvaltest + +batch_size: 512 + +dataset_kwargs_common: + pad_length: 128 + n_files_at_once: 10 + labels_to_load: + - label_QCD + # - label_Hbb + # - label_Hcc + # - label_Hgg + # - label_H4q + # - label_Hqql + # - label_Zqq + # - label_Wqq + - label_Tbqq + # - label_Tbl + +dataset_kwargs_train: + max_n_files_per_type: null + n_jets_per_file: null + files_dict: + QCD: + # - ${data.data_dir}/train_100M/ZJetsToNuNu_* + - ${data.data_dir}/train_2M/ZJetsToNuNu_* + # Hbb: + # - ${data.data_dir}/train_100M/HToBB_* + # Hcc: + # - ${data.data_dir}/train_100M/HToCC_* + # Hgg: + # - ${data.data_dir}/train_100M/HToGG_* + # H4q: + # - ${data.data_dir}/train_100M/HToWW4Q_* + # Hqql: + # - ${data.data_dir}/train_100M/HToWW2Q1L_* + # Zqq: + # - ${data.data_dir}/train_100M/ZToQQ_* + # Wqq: + # - ${data.data_dir}/train_100M/WToQQ_* + Tbqq: + # - ${data.data_dir}/train_100M/TTBar_* + - ${data.data_dir}/train_2M/TTBar_* + # Tbl: + # - ${data.data_dir}/train_100M/TTBarLep_* + +dataset_kwargs_val: + max_n_files_per_type: 1 + files_dict: + QCD: + # - ${data.data_dir}/val_5M/ZJetsToNuNu_* + - ${data.data_dir}/val_1M/ZJetsToNuNu_* + # Hbb: + # - ${data.data_dir}/val_5M/HToBB_* + # Hcc: + # - ${data.data_dir}/val_5M/HToCC_* + # Hgg: + # - ${data.data_dir}/val_5M/HToGG_* + # H4q: + # - ${data.data_dir}/val_5M/HToWW4Q_* + # Hqql: + # - ${data.data_dir}/val_5M/HToWW2Q1L_* + # Zqq: + # - ${data.data_dir}/val_5M/ZToQQ_* + # Wqq: + # - ${data.data_dir}/val_5M/WToQQ_* + Tbqq: + # - ${data.data_dir}/val_5M/TTBar_* + - ${data.data_dir}/val_1M/TTBar_* + # Tbl: + # - ${data.data_dir}/val_5M/TTBarLep_* + +dataset_kwargs_test: + max_n_files_per_type: 1 + files_dict: + QCD: + - ${data.data_dir}/test_1M/ZJetsToNuNu_* + # Hbb: + # - ${data.data_dir}/test_20M/HToBB_* + # Hcc: + # - ${data.data_dir}/test_20M/HToCC_* + # Hgg: + # - ${data.data_dir}/test_20M/HToGG_* + # H4q: + # - ${data.data_dir}/test_20M/HToWW4Q_* + # Hqql: + # - ${data.data_dir}/test_20M/HToWW2Q1L_* + # Zqq: + # - ${data.data_dir}/test_20M/ZToQQ_* + # Wqq: + # - ${data.data_dir}/test_20M/WToQQ_* + Tbqq: + - ${data.data_dir}/test_1M/TTBar_* + # Tbl: + # - ${data.data_dir}/test_20M/TTBarLep_* diff --git a/configs/data/landscape.yaml b/configs/data/landscape.yaml new file mode 100644 index 0000000..d6e7812 --- /dev/null +++ b/configs/data/landscape.yaml @@ -0,0 +1,34 @@ +_target_: gabbro.data.dataset_landscape.LandscapeDataModule +file_train: /beegfs/desy/user/birkjosc/datasets/landscape/TopLandscape/train_file.parquet +file_val: /beegfs/desy/user/birkjosc/datasets/landscape/TopLandscape/val_file.parquet +file_test: /beegfs/desy/user/birkjosc/datasets/landscape/TopLandscape/test_file.parquet +# file_train: /beegfs/desy/user/birkjosc/datasets/landscape/TopLandscape/train_file_1000.parquet +# file_val: /beegfs/desy/user/birkjosc/datasets/landscape/TopLandscape/val_file_1000.parquet +# file_test: /beegfs/desy/user/birkjosc/datasets/landscape/TopLandscape/test_file_1000.parquet +n_jets_train: null +n_jets_val: null +n_jets_test: null +batch_size: 512 +pad_length: 128 +num_workers: 32 +pin_memory: False + +# dummy keys to make compatible with other dataset structure +data_dir: null +dataset_kwargs_common: + pad_length: 128 + n_files_at_once: null + labels_to_load: null + +dataset_kwargs_train: + max_n_files_per_type: null + n_jets_per_file: null + files_dict: null + +dataset_kwargs_val: + max_n_files_per_type: 1 + files_dict: null + +dataset_kwargs_test: + max_n_files_per_type: 1 + files_dict: null diff --git a/configs/data/tokenized_classification.yaml b/configs/data/tokenized_classification.yaml new file mode 100644 index 0000000..0654a2b --- /dev/null +++ b/configs/data/tokenized_classification.yaml @@ -0,0 +1,5 @@ +_target_: gabbro.data.datamodule_tokenized_jetclass.TokenJetClassDataModule + + +defaults: + - defaults.yaml diff --git a/configs/experiment/example_experiment_backbone_generative.yaml b/configs/experiment/example_experiment_backbone_generative.yaml new file mode 100644 index 0000000..039c0f0 --- /dev/null +++ b/configs/experiment/example_experiment_backbone_generative.yaml @@ -0,0 +1,134 @@ +# @package _global_ + +# to execute this experiment run: +# python train.py experiment=example_experiment_backbone_generative + +defaults: + # - override /data: tokenized_classification.yaml + # - override /data: iter_dataset_jetclass_classification_top_vs_qcd + - override /data: data_generative_e_sorted + - override /model: backbone_generative.yaml + - override /callbacks: callbacks_for_generative_training.yaml + - override /trainer: ddp.yaml +# - override /trainer: gpu.yaml + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +# add here checkpoint to continue training +#ckpt_path: "path/to/your/checkpoint.ckpt" + +project_name: "example_backbone" +tags: ["generative"] + +run_note: "" + +seed: 44 + +load_weights_from: False +load_weights_strict: False + + +data: + batch_size: 16 # NOTE: adapt the limit_train_batches accordingly + + data_dir: /data/dust/user/rosehenn/gabbro/compare/2024-09-21_16-54-39_max-wng062_CerousLocknut/ # this is the path to the tokenized dataset + + dataset_kwargs_train: + max_n_files_per_type: 1 + n_shower_per_file: 1000 + dataset_kwargs_val: + n_shower_per_file: 100 + dataset_kwargs_test: + n_shower_per_file: 100 + dataset_kwargs_common: + load_only_once: true + pad_length: 1700 + n_files_at_once: 1 + h5file: True + random_seed_for_per_file_shuffling: 42 + feature_dict: + # part_token_id: {} + part_token_id_without_last: {} # <-- this will be the input for the gen. model + part_token_id_without_first: {} # <-- this will be the target for the gen. model + token_id_cfg: + remove_start_token: false + remove_end_token: false + shift_tokens_minus_one: false + +callbacks: + generative_callback: + n_val_gen_jets: 10 # increased again for better comparison + starting_at_epoch: 300 + every_n_epochs: 1 + batch_size_for_generation: 8 + data_dir: ${data.data_dir} + seed_shuffle_val_data: 2 # loads the validation data (not tokeneized) for plotting + plot_best_checkpoint: False #If you want to plot the best Checkpoint. False plots the current checkpoint. + early_stopping: + patience: 300 # number of checks with no improvement after which training will be stopped +trainer: + max_steps: 10000000 + gradient_clip_val: 1 + log_every_n_steps: 400 + limit_train_batches: 1.0 # 11.875 with batch size 64, to have 760000 samples per epoch # increased again + limit_val_batches: 1.0 # 2.900 with batch size 64, to have 185k samples per epoch # increased again + # precision: "bf16-true" + # num_sanity_val_steps: 10 + +# setting load_weights_from will load the weights from the given checkpoint, but start training from scratch +# load_weights_from: + +model: + # --- model architecture configuration --- + # model_class_name: "BackboneWithClasshead" + # model_kwargs_loaded: null + token_dir: ${data.data_dir} + exclude_padded_values_from_loss: True + model_kwargs: + # keep_backbone_fixed: false + # --- + return_embeddings: True # meaning that the new head structure is used instead of the old one + # n_out_nodes: 2 + # if you want to transfer the weights from a backbone model, you can specify the path here + # backbone_weights_path: "path/to/your/backbone_weights.ckpt" + embedding_dim: 256 + attention_dropout: 0.0 + vocab_size: 65538 #adjust to your codebook size + stop_token_weight: 1.0 + max_sequence_len: 1700 + temperature: 1.0 + stop_token_threshold: 0.0 + n_GPT_blocks: 3 + n_heads: 8 + verbosity: true + # --- optimizer configuration --- + optimizer: + _target_: gabbro.utils.optimizer.ranger.Ranger + _partial_: true + lr: 0.001 + weight_decay: 1e-2 + betas: [0.95, 0.999] + eps: 1e-5 + alpha: 0.5 + k: 6 + + # --- learning rate scheduler --- + scheduler: + _target_: torch.optim.lr_scheduler.ConstantLR + _partial_: true + total_iters: 1 + factor: 1.0 + # --- learning rate scheduler --- + +task_name: "omnijet_backbone" + +logger: + wandb: + project: ${project_name} + tags: ${tags} + # group: ${project_name} + name: ${task_name} + comet: + experiment_name: null + project_name: ${project_name} diff --git a/configs/experiment/example_experiment_tokenization.yaml b/configs/experiment/example_experiment_tokenization.yaml new file mode 100644 index 0000000..17084f1 --- /dev/null +++ b/configs/experiment/example_experiment_tokenization.yaml @@ -0,0 +1,114 @@ +# @package _global_ + +# to execute this experiment run: +# python gabbro/train.py experiment=example_experiment_tokenization + +defaults: + - override /data: data_tokenization_dev.yaml + - override /model: model_vqvae_transformer.yaml + - override /callbacks: tokenization_dummy_callbacks.yaml + - override /trainer: ddp.yaml + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +# add here checkpoint to continue training +#ckpt_path: "path/to/your/checkpoint.ckpt" + +project_name: "tokenization_example" +tags: ["vqvae_tokenization"] + +run_note: "" # "here you can add a note to the run which will be displayed in the logging service" + +seed: 1603 +load_weights_from: false + +data: + data_dir: /data/dust/user/korcariw/maxwell.merged/CaloClouds/dataset/ # this is the path to the dataset + batch_size: 16 # NOTE: adapt the limit_train_batches accordingly + h5file: True + dataset_kwargs_train: + max_n_files_per_type: 1 # for this example we only load one file per type + n_shower_per_file: 10000 # for this example we only load 10000 showers per file + shuffle_only_once: true # shuffle the training dataset only once + dataset_kwargs_val: + shuffle_only_once: true # shuffle the validation dataset only once + n_shower_per_file: 1000 # for this example we only load 1000 showers per file + seed_shuffle_data: 42 + dataset_kwargs_test: + shuffle_only_once: true # shuffle the test dataset only once + n_shower_per_file: 1000 # for this example we only load 1000 showers per file + seed_shuffle_data: 42 + dataset_kwargs_common: + n_files_at_once: 1 # load 10 files at once (which are all in this case) + load_only_once: true # load the files only once and keep them in memory + pad_length: 1700 # pad the showers to a length of 1700 hits + energy_threshold: 0 # ignore hits with energy below this threshold + energy_sorting: false # sort the hits by energy (starting with the highest energy hit) (for the VQVAE this is irrelevant) + feature_dict: + x: {"multiply_by": 0.3, "subtract_by": 14.5} + y: {"multiply_by": 0.3, "subtract_by": 14.5} + z: {"multiply_by": 0.2, "subtract_by": 15.5} + energy: {"multiply_by": 1, "subtract_by": -1,"func": "np.log","inv_func": "np.exp"} + +trainer: + max_epochs: 600 + gradient_clip_val: 1 + log_every_n_steps: 60 + limit_train_batches: 1.0 # 1.0 means all batches, 0.1 means 10% of all batches and e.g. 2700 means 2700 batches (to define the "epoch", which we might want to be smaller than the whole dataset to get faster feedback on the training process) + limit_val_batches: 1.0 # --> using 200*512 = 102400 validation samples, around 10k per type + +model: + model_kwargs_loaded: null + # --- optimizer configuration --- + optimizer: + _target_: gabbro.utils.optimizer.ranger.Ranger + _partial_: true + lr: 0.001 + weight_decay: 1e-2 + betas: [0.95, 0.999] + eps: 1e-5 + alpha: 0.5 + k: 6 + + # --- learning rate scheduler --- + scheduler: + _target_: gabbro.schedulers.lr_scheduler.OneCycleCooldown + _partial_: true + warmup: 10 # epochs until max_lr is reached + cooldown: 20 # epochs to decrease to initial_lr after max_lr is reached + cooldown_final: 50 # epochs to decrease to final_lr after max_lr is reached + max_lr: 3e-4 + initial_lr: 3e-4 + final_lr: 3e-4 # final_lr is used after the second cooldown + + # --- model architecture configuration --- + model_type: VQVAENormFormer + model_kwargs: + input_dim: 4 + hidden_dim: 128 + latent_dim: 12 + num_blocks: 4 + num_heads: 8 + alpha: 10 + vq_kwargs: + num_codes: 65536 #32768 + beta: 0.9 + kmeans_init: false + norm: null + cb_norm: null + affine_lr: 2 + sync_nu: 1 + replace_freq: 100 + +task_name: "tokenization" + +logger: + wandb: + project: ${project_name} + tags: ${tags} + # group: ${project_name} + name: ${task_name} + comet: + experiment_name: null + project_name: ${project_name} diff --git a/configs/extras/default.yaml b/configs/extras/default.yaml new file mode 100644 index 0000000..b9c6b62 --- /dev/null +++ b/configs/extras/default.yaml @@ -0,0 +1,8 @@ +# disable python warnings if they annoy you +ignore_warnings: False + +# ask user for tags if none are provided in the config +enforce_tags: True + +# pretty print config tree at the start of the run using Rich library +print_config: True diff --git a/configs/hydra/default.yaml b/configs/hydra/default.yaml new file mode 100644 index 0000000..e42b6e5 --- /dev/null +++ b/configs/hydra/default.yaml @@ -0,0 +1,13 @@ +# https://hydra.cc/docs/configure_hydra/intro/ + +# enable color logging +defaults: + - override hydra_logging: colorlog + - override job_logging: colorlog + +# output directory, generated dynamically on each run +run: + dir: ${paths.log_dir}/${project_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S}_${nodename_bigram:} +sweep: + dir: ${paths.log_dir}/${project_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S} + subdir: ${hydra.job.num} diff --git a/configs/logger/comet.yaml b/configs/logger/comet.yaml new file mode 100644 index 0000000..865a7e9 --- /dev/null +++ b/configs/logger/comet.yaml @@ -0,0 +1,12 @@ +# https://www.comet.ml + +comet: + _target_: lightning.pytorch.loggers.comet.CometLogger + api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable + save_dir: "${paths.output_dir}" + project_name: "deep-learning" + rest_api_key: null + # experiment_name: "" + experiment_key: null # set to resume experiment + offline: False + prefix: "" diff --git a/configs/logger/csv.yaml b/configs/logger/csv.yaml new file mode 100644 index 0000000..844ec67 --- /dev/null +++ b/configs/logger/csv.yaml @@ -0,0 +1,7 @@ +# csv logger built in lightning + +csv: + _target_: pytorch_lightning.loggers.csv_logs.CSVLogger + save_dir: "${paths.output_dir}" + name: "csv/" + prefix: "" diff --git a/configs/logger/many_loggers.yaml b/configs/logger/many_loggers.yaml new file mode 100644 index 0000000..29110aa --- /dev/null +++ b/configs/logger/many_loggers.yaml @@ -0,0 +1,6 @@ +# train with many loggers at once + +defaults: + - comet.yaml + - csv.yaml + - wandb.yaml diff --git a/configs/logger/wandb.yaml b/configs/logger/wandb.yaml new file mode 100644 index 0000000..b30f7a2 --- /dev/null +++ b/configs/logger/wandb.yaml @@ -0,0 +1,16 @@ +# https://wandb.ai + +wandb: + _target_: lightning.pytorch.loggers.wandb.WandbLogger + # name: "" # name of the run (normally generated by wandb) + save_dir: "${paths.output_dir}" + offline: False + id: null # pass correct id to resume experiment! + anonymous: null # enable anonymous logging + project: "deep-learning" + log_model: False # upload lightning ckpts + prefix: "" # a string to put at the beginning of metric keys + # entity: "" # set to name of your wandb team + group: "" + tags: [] + job_type: "" diff --git a/configs/model/backbone_generative.yaml b/configs/model/backbone_generative.yaml new file mode 100644 index 0000000..42ea00a --- /dev/null +++ b/configs/model/backbone_generative.yaml @@ -0,0 +1,20 @@ +_target_: gabbro.models.backbone.BackboneNextTokenPredictionLightning + +model_kwargs: + embedding_dim: 256 + attention_dropout: 0.1 + vocab_size: 8194 + max_sequence_len: 128 + n_GPT_blocks: 3 + n_heads: 8 + verbosity: false + +optimizer: + _target_: torch.optim.AdamW + _partial_: true + lr: 0.001 # ParT paper uses RAdam optimizer with initial lr of 0.001 + weight_decay: 0 + +scheduler: + _target_: torch.optim.lr_scheduler.ConstantLR + _partial_: true diff --git a/configs/model/model_vqvae_mlp.yaml b/configs/model/model_vqvae_mlp.yaml new file mode 100644 index 0000000..e68dfde --- /dev/null +++ b/configs/model/model_vqvae_mlp.yaml @@ -0,0 +1,42 @@ +_target_: gabbro.models.vqvae.VQVAELightning + +model_type: "MLP" + +model_kwargs: + input_dim: 3 + encoder_layers: [128, 128] + decoder_layers: [128, 128] + latent_dim: 3 + alpha: 5 + vq_kwargs: + num_codes: 10000 + beta: 0.9 + kmeans_init: true + norm: null + cb_norm: null + affine_lr: 0.0 + sync_nu: 2 + replace_freq: 20 + dim: -1 + +optimizer: + _target_: torch.optim.AdamW + _partial_: true + lr: 0.001 + # weight_decay: 0.05 + +scheduler: + _target_: torch.optim.lr_scheduler.ConstantLR + _partial_: true + +# using the method listed in the paper https://arxiv.org/abs/1902.08570, but with other parameters +# scheduler: +# _target_: src.schedulers.lr_scheduler.OneCycleCooldown +# _partial_: true +# warmup: 4 +# cooldown: 10 +# cooldown_final: 10 +# max_lr: 0.0002 +# initial_lr: 0.00003 +# final_lr: 0.00002 +# max_iters: 200 diff --git a/configs/model/model_vqvae_transformer.yaml b/configs/model/model_vqvae_transformer.yaml new file mode 100644 index 0000000..eb9152e --- /dev/null +++ b/configs/model/model_vqvae_transformer.yaml @@ -0,0 +1,43 @@ +_target_: gabbro.models.vqvae.VQVAELightning + +model_type: "Transformer" + +model_kwargs: + input_dim: 3 + hidden_dim: 128 + latent_dim: 16 + num_blocks: 3 + num_heads: 8 + alpha: 5 + vq_kwargs: + num_codes: 2048 + beta: 0.9 + kmeans_init: true + norm: null + cb_norm: null + affine_lr: 0.0 + sync_nu: 2 + replace_freq: 20 + dim: -1 + +optimizer: + _target_: torch.optim.AdamW + _partial_: true + lr: 0.001 + # weight_decay: 0.05 + +scheduler: + _target_: torch.optim.lr_scheduler.ConstantLR + _partial_: true + +# using the method listed in the paper https://arxiv.org/abs/1902.08570, but with other parameters +# scheduler: +# _target_: src.schedulers.lr_scheduler.OneCycleCooldown +# _partial_: true +# warmup: 4 +# cooldown: 10 +# cooldown_final: 10 +# max_lr: 0.0002 +# initial_lr: 0.00003 +# final_lr: 0.00002 +# max_iters: 200 diff --git a/configs/model/particleflow.yaml b/configs/model/particleflow.yaml new file mode 100644 index 0000000..92d4b15 --- /dev/null +++ b/configs/model/particleflow.yaml @@ -0,0 +1,28 @@ +_target_: gabbro.models.classifiers.ClassifierPL + +model_class_name: "ParticleFlow" + +model_kwargs: + n_tokens: 513 + +optimizer: + _target_: torch.optim.AdamW + _partial_: true + lr: 0.001 # ParT paper uses RAdam optimizer with initial lr of 0.001 + weight_decay: 0 + +scheduler: + _target_: torch.optim.lr_scheduler.ConstantLR + _partial_: true + +# using the method listed in the paper https://arxiv.org/abs/1902.08570, but with other parameters +# scheduler: +# _target_: src.schedulers.lr_scheduler.OneCycleCooldown +# _partial_: true +# warmup: 4 +# cooldown: 10 +# cooldown_final: 10 +# max_lr: 0.0002 +# initial_lr: 0.00003 +# final_lr: 0.00002 +# max_iters: 200 diff --git a/configs/paths/default.yaml b/configs/paths/default.yaml new file mode 100644 index 0000000..e3fecf9 --- /dev/null +++ b/configs/paths/default.yaml @@ -0,0 +1,18 @@ +# path to root directory +# this requires PROJECT_ROOT environment variable to exist +# you can replace it with "." if you want the root to be the current working directory +# root_dir: ${oc.env:PROJECT_ROOT} + +# path to data directory +# data_dir: ${oc.env:DATA_DIR} + +# path to logging directory +log_dir: ${oc.env:LOG_DIR} + +# path to output directory, created dynamically by hydra +# path generation pattern is specified in `configs/hydra/default.yaml` +# use it to store all files generated during the run, like ckpts and metrics +output_dir: ${hydra:run.dir} + +# path to working directory +work_dir: ${hydra:runtime.cwd} diff --git a/configs/train.yaml b/configs/train.yaml new file mode 100644 index 0000000..8cfc1a0 --- /dev/null +++ b/configs/train.yaml @@ -0,0 +1,51 @@ +# @package _global_ + +# specify here default configuration +# order of defaults determines the order in which configs override each other +defaults: + - _self_ + - data: iter_dataset.yaml + - model: treeformer.yaml + - callbacks: default_callbacks.yaml + - logger: many_loggers.yaml # set logger here or use command line (e.g. `python train.py logger=tensorboard`) + - trainer: gpu.yaml + - paths: default.yaml + - extras: default.yaml + - hydra: default.yaml + + # experiment configs allow for version control of specific hyperparameters + # e.g. best hyperparameters for given model and datamodule + - experiment: treeformer_dev.yaml + + # config for hyperparameter optimization + - hparams_search: null + + # optional local config for machine/user specific settings + # it's optional since it doesn't need to exist and is excluded from version control + - optional local: default.yaml + + # debugging config (enable through command line, e.g. `python train.py debug=default) + - debug: null + +# task name, determines output directory path +task_name: "train" + +# tags to help you identify your experiments +# you can overwrite this in experiment configs +# overwrite from command line with `python train.py tags="[first_tag, second_tag]"` +# appending lists from command line is currently not supported :( +# https://github.com/facebookresearch/hydra/issues/1547 +tags: ["dev"] + +# set False to skip model training +train: True + +# evaluate on test set, using best model weights achieved during training +# lightning chooses best weights based on the metric specified in checkpoint callback +test: True + +# simply provide checkpoint path to resume training +ckpt_path: null + +# seed for random number generators in pytorch, numpy and python.random +seed: 12345 diff --git a/configs/trainer/cpu.yaml b/configs/trainer/cpu.yaml new file mode 100644 index 0000000..640f71d --- /dev/null +++ b/configs/trainer/cpu.yaml @@ -0,0 +1,5 @@ +defaults: + - default.yaml + +accelerator: cpu +devices: 1 diff --git a/configs/trainer/ddp.yaml b/configs/trainer/ddp.yaml new file mode 100644 index 0000000..b577d41 --- /dev/null +++ b/configs/trainer/ddp.yaml @@ -0,0 +1,13 @@ +defaults: + - default.yaml + +# use "ddp_spawn" instead of "ddp", +# it's slower but normal "ddp" currently doesn't work ideally with hydra +# https://github.com/facebookresearch/hydra/issues/2070 +# https://pytorch-lightning.readthedocs.io/en/latest/accelerators/gpu_intermediate.html#distributed-data-parallel-spawn +strategy: ddp + +accelerator: gpu +devices: -1 +num_nodes: 1 +sync_batchnorm: True diff --git a/configs/trainer/default.yaml b/configs/trainer/default.yaml new file mode 100644 index 0000000..78384f2 --- /dev/null +++ b/configs/trainer/default.yaml @@ -0,0 +1,20 @@ +_target_: lightning.Trainer + +default_root_dir: ${paths.output_dir} + +min_epochs: 1 # prevents early stopping +# max_epochs: 10 + +accelerator: cpu +devices: 1 +enable_progress_bar: False + +# mixed precision for extra speed-up +# precision: 16 + +# perform a validation loop every N training epochs +check_val_every_n_epoch: 1 + +# set True to to ensure deterministic results +# makes training slower but gives more reproducibility than just setting seeds +deterministic: False diff --git a/configs/trainer/gpu.yaml b/configs/trainer/gpu.yaml new file mode 100644 index 0000000..d5e5773 --- /dev/null +++ b/configs/trainer/gpu.yaml @@ -0,0 +1,5 @@ +defaults: + - default.yaml + +accelerator: gpu +devices: 1 diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100644 index 0000000..5bd9aac --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,55 @@ +FROM pytorch/pytorch:2.2.0-cuda11.8-cudnn8-runtime + +RUN apt-get update && apt-get install -y \ + vim \ + neovim \ + wget \ + curl \ + git \ + zsh \ + libxcb-xinerama0 \ + lsb-release \ + bat \ + ripgrep + +RUN curl -sL https://raw.githubusercontent.com/wimpysworld/deb-get/main/deb-get | bash -s install deb-get +RUN deb-get install \ + zenith \ + git-delta + +# install just +RUN wget -qO - 'https://proget.makedeb.org/debian-feeds/prebuilt-mpr.pub' | gpg --dearmor | sudo tee /usr/share/keyrings/prebuilt-mpr-archive-keyring.gpg 1> /dev/null +RUN echo "deb [arch=all,$(dpkg --print-architecture) signed-by=/usr/share/keyrings/prebuilt-mpr-archive-keyring.gpg] https://proget.makedeb.org prebuilt-mpr $(lsb_release -cs)" | sudo tee /etc/apt/sources.list.d/prebuilt-mpr.list +RUN apt update && apt install -y just + +# Install GitHub cli +RUN mkdir -p -m 755 /etc/apt/keyrings && wget -qO- https://cli.github.com/packages/githubcli-archive-keyring.gpg | tee /etc/apt/keyrings/githubcli-archive-keyring.gpg > /dev/null \ + && chmod go+r /etc/apt/keyrings/githubcli-archive-keyring.gpg \ + && echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/githubcli-archive-keyring.gpg] https://cli.github.com/packages stable main" | tee /etc/apt/sources.list.d/github-cli.list > /dev/null \ + && apt update \ + && apt install gh -y + +# add the 'slurm' user in order to make slurm work from within container +# (if the corresponding libraries are mounted to the container) +RUN adduser --disabled-password --gecos "" slurm + +# allow pip install of "sklearn", which should be replaced by "scikit-learn" +# (if this is not used, the github CI pipeline fails from time to time since +# some packages still have "sklearn" in their dependencies) +ENV SKLEARN_ALLOW_DEPRECATED_SKLEARN_PACKAGE_INSTALL=True + +COPY requirements.txt . + +RUN pip install -r requirements.txt + +SHELL ["/bin/bash", "--login", "-c"] + +RUN conda install -y jupyter +RUN git clone https://github.com/minyoungg/vqtorch && cd vqtorch && pip install -e . + +# move anaconda binary path to the end, otherwise the "clear" command in the +# terminal is broken +ENV PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/opt/anaconda/bin + +# activate conda environment by default +RUN ["/bin/bash", "-c", "source /opt/conda/bin/activate"] diff --git a/docker/requirements.txt b/docker/requirements.txt new file mode 100644 index 0000000..83d4755 --- /dev/null +++ b/docker/requirements.txt @@ -0,0 +1,51 @@ +cupy-cuda117==10.6.0 +dcor==0.6 # calculate distance correlation +fsspec==2023.12.2 +gif==23.3.0 # simplifies animations +h5py==3.10.0 +MarkupSafe==2.1.3 +matplotlib==3.8.4 +mnist==0.2.2 +nltk==3.8.1 # natural-language toolkit to generate unique but readable run-ids +pandas==2.2.1 +pyarrow==15.0.2 +scikit-learn==1.4.1.post1 +seaborn==0.13.0 +tables==3.9.2 +zuko==1.1.0 + +# -------- hep stuff ------- +awkward==2.6.3 +energyflow +fastjet==3.4.1.3 +jetnet==0.2.5 +law==0.1.18 +ragged==0.1.0 +uproot==5.1.2 +weaver-core==0.4.15 + +# --------- pytorch --------- # +lightning==2.2.1 +pytorch-lightning==2.2.1 +torchdyn==1.0.6 +torchmetrics==1.3.2 +torchvision==0.17.0 + +# --------- hydra --------- # +hydra-colorlog==1.2.0 +hydra-core==1.3.2 +hydra-optuna-sweeper==1.2.0 + +# --------- loggers --------- # +comet-ml==3.39.2 +wandb==0.16.6 + +# --------- dev tools & others --------- # +black==24.3.0 +isort==5.13.2 +pre-commit==3.7.0 # hooks for applying linters on commit +pylint==3.1.0 +pyrootutils==1.0.4 # standardizing the project root setup +pytest==8.1.1 # tests +rich==13.7.1 # beautiful text formatting in terminal +ruff==0.3.5 # fast linter/formatter diff --git a/examples/example_notebooks/example_generate_showers.ipynb b/examples/example_notebooks/example_generate_showers.ipynb new file mode 100644 index 0000000..118b34d --- /dev/null +++ b/examples/example_notebooks/example_generate_showers.ipynb @@ -0,0 +1,174 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "import os\n", + "import sys\n", + "from pathlib import Path\n", + "\n", + "import awkward as ak\n", + "import numpy as np\n", + "import vector\n", + "from omegaconf import OmegaConf\n", + "\n", + "sys.path.append(\"/data/dust/user/rosehenn/gabbro\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Shower generation with a trained OmniJet model\n", + "\n", + "This notebook provides a short example on how to load a trained OmniJet model with the next-token-prediction head and generate jets with it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from gabbro.models.backbone import BackboneNextTokenPredictionLightning\n", + "\n", + "# this checkpoint is the checkpoint from a backbone training with the nex-token-prediction head\n", + "# make sure you have downloaded the checkpoint in advance\n", + "# if not, run the script `checkpoints/download_checkpoints.sh`\n", + "ckpt_path = \"/data/dust/user/rosehenn/gabbro_output/full_resolution/runs/2024-11-21_13-49-55_max-wng060_TerminativeCirculation/checkpoints/epoch_032_loss_4.10881.ckpt\"\n", + "gen_model = BackboneNextTokenPredictionLightning.load_from_checkpoint(ckpt_path)\n", + "gen_model.eval()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Generating Showers" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "generated_showers = gen_model.generate_n_showers_batched(\n", + " n_showers=2,\n", + " batch_size=2,\n", + " # saveas=save_path, # use this option if you want to save the awkward array as a parquet file\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "generated_showers" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# --- Load the tokenizer model from checkpoint, and also get the feature_dict from the config ---\n", + "from gabbro.models.vqvae import VQVAELightning\n", + "\n", + "ckpt_path = \"/data/dust/user/rosehenn/gabbro_output/TokTrain/runs/2024-09-21_16-54-39_max-wng062_CerousLocknut/checkpoints/epoch_231_loss_0.17179.ckpt\"\n", + "\n", + "vqvae_model = VQVAELightning.load_from_checkpoint(ckpt_path)\n", + "vqvae_model.eval()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cfg = OmegaConf.load(Path(ckpt_path).parent.parent / \"config.yaml\")\n", + "pp_dict = OmegaConf.to_container(cfg.data.dataset_kwargs_common.feature_dict)\n", + "print(\"\\npp_dict:\")\n", + "for item in pp_dict:\n", + " print(item, pp_dict[item])\n", + "\n", + "# get the cuts from the pp_dict (since this leads to particles being removed during\n", + "# preprocessing/tokenization), thus we also have to remove them from the original jets\n", + "# when we compare the tokenized+reconstructed particles to the original ones)\n", + "pp_dict_cuts = {\n", + " feat_name: {\n", + " criterion: pp_dict[feat_name].get(criterion)\n", + " for criterion in [\"larger_than\", \"smaller_than\"]\n", + " }\n", + " for feat_name in pp_dict\n", + "}\n", + "\n", + "print(\"\\npp_dict_cuts:\")\n", + "for item in pp_dict_cuts:\n", + " print(item, pp_dict_cuts[item])\n", + "\n", + "print(\"\\nModel:\")\n", + "print(vqvae_model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# reconstruct the generated tokens to physical features\n", + "\n", + "# note that if you want to reconstruct tokens from the generative model, you'll have\n", + "# to remove the start token from the tokenized array, and subtract 1 from the tokens\n", + "# (since we chose the convention to use 0 as the start token, so the tokens from the\n", + "# generative model are shifted by 1 compared to the ones from the VQ-VAE)\n", + "showers_reconstructed = vqvae_model.reconstruct_ak_tokens(\n", + " tokens_ak=generated_showers[:, 1:] - 1,\n", + " pp_dict=pp_dict,\n", + " batch_size=512,\n", + " pad_length=128,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "showers_reconstructed" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/example_notebooks/example_tokenize_and_reconstruct_showers.ipynb b/examples/example_notebooks/example_tokenize_and_reconstruct_showers.ipynb new file mode 100644 index 0000000..52482bf --- /dev/null +++ b/examples/example_notebooks/example_tokenize_and_reconstruct_showers.ipynb @@ -0,0 +1,191 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "import os\n", + "import sys\n", + "from pathlib import Path\n", + "\n", + "import awkward as ak\n", + "import numpy as np\n", + "import vector\n", + "from omegaconf import OmegaConf\n", + "\n", + "sys.path.append(\"/data/dust/user/rosehenn/gabbro\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Tokenization with the VQ-VAE" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# --- Load the tokenizer model from checkpoint, and also get the feature_dict from the config ---\n", + "from gabbro.models.vqvae import VQVAELightning\n", + "\n", + "ckpt_path = \"/data/dust/user/rosehenn/gabbro_output/TokTrain/runs/2024-09-21_16-54-39_max-wng062_CerousLocknut/checkpoints/epoch_231_loss_0.17179.ckpt\"\n", + "\n", + "vqvae_model = VQVAELightning.load_from_checkpoint(ckpt_path)\n", + "vqvae_model.eval()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cfg = OmegaConf.load(Path(ckpt_path).parent.parent / \"config.yaml\")\n", + "pp_dict = OmegaConf.to_container(cfg.data.dataset_kwargs_common.feature_dict)\n", + "print(\"\\npp_dict:\")\n", + "for item in pp_dict:\n", + " print(item, pp_dict[item])\n", + "\n", + "# get the cuts from the pp_dict (since this leads to particles being removed during\n", + "# preprocessing/tokenization), thus we also have to remove them from the original jets\n", + "# when we compare the tokenized+reconstructed particles to the original ones)\n", + "pp_dict_cuts = {\n", + " feat_name: {\n", + " criterion: pp_dict[feat_name].get(criterion)\n", + " for criterion in [\"larger_than\", \"smaller_than\"]\n", + " }\n", + " for feat_name in pp_dict\n", + "}\n", + "\n", + "print(\"\\npp_dict_cuts:\")\n", + "for item in pp_dict_cuts:\n", + " print(item, pp_dict_cuts[item])\n", + "\n", + "print(\"\\nModel:\")\n", + "print(vqvae_model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load shower file" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from gabbro.data.loading import read_shower_file\n", + "\n", + "filename_in = \"/data/dust/user/rosehenn/gabbro/notebooks/array_real.parquet\"\n", + "showers = ak.from_parquet(filename_in)\n", + "showers = showers[:5000]\n", + "# part_features_ak = ak_select_and_preprocess(data_showers, pp_dict_cuts)[:, :128]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Tokenize and reconstruct showers" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# tokenization and reconstruction\n", + "\n", + "part_features_ak_tokenized = vqvae_model.tokenize_ak_array(\n", + " ak_arr=showers,\n", + " pp_dict=pp_dict,\n", + " batch_size=4,\n", + " pad_length=1700,\n", + ")\n", + "# note that if you want to reconstruct tokens from the generative model, you'll have\n", + "# to remove the start token from the tokenized array, and subtract 1 from the tokens\n", + "# (since we chose the convention to use 0 as the start token, so the tokens from the\n", + "# generative model are shifted by 1 compared to the ones from the VQ-VAE)\n", + "part_features_ak_reco = vqvae_model.reconstruct_ak_tokens(\n", + " tokens_ak=part_features_ak_tokenized,\n", + " pp_dict=pp_dict,\n", + " batch_size=4,\n", + " pad_length=1700,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# inspect the tokenized and reconstructed Showers\n", + "print(\"First 5 tokenized Showers:\")\n", + "for i in range(5):\n", + " print(part_features_ak_tokenized[i])\n", + "\n", + "print(\"\\nFirst 5 reconstructed Showers:\")\n", + "for i in range(5):\n", + " print(part_features_ak_reco[i])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Plot the reconstructed showers" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from gabbro.plotting.feature_plotting import plot_paper_plots\n", + "\n", + "fig = plot_paper_plots(\n", + " feature_sets=[showers[: len(part_features_ak_reco)], part_features_ak_reco],\n", + " labels=[\"Geant4\", \"Tokenized\"], # \"OmniJet-$\\\\alpha_C$\" \"BIB-AE\", \"L2L Flows\"\n", + " colors=[\"lightgrey\", \"#1a80bb\", \"#ea801c\", \"#4CAF50\", \"#1a80bb\"],\n", + ")\n", + "fig.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/job_scripts/example_job_train_generative.sh b/examples/job_scripts/example_job_train_generative.sh new file mode 100755 index 0000000..7c01c85 --- /dev/null +++ b/examples/job_scripts/example_job_train_generative.sh @@ -0,0 +1,34 @@ +#!/bin/bash +# NOTE: This script is an example and should be adjusted to your needs. +# The fields which need to be adjusted are marked with "ADJUST THIS". + +######################### +## SLURM JOB COMMANDS ### +######################### +#SBATCH --partition=maxgpu +#SBATCH --constraint="A100&GPUx1" +#SBATCH --time=100:00:00 +#SBATCH --exclude= +#SBATCH --job-name TokenGen_example +#SBATCH --output /data/dust/user/rosehenn/gabbro_output/logs/slurm_logs/%x_%j.log # ADJUST THIS to your log path +#SBATCH --mail-user your.email + +echo "Starting job $SLURM_JOB_ID with the following script:" +echo "----------------------------------------------------------------------------" +echo +cat $0 + +source ~/.bashrc +cd /data/dust/user/rosehenn/gabbro/ # ADJUST THIS to your repository path + +LOGFILE="/data/dust/user/rosehenn/gabbro_output/logs/slurm_logs/${SLURM_JOB_NAME}_${SLURM_JOB_ID}.log" # ADJUST THIS to your log path +PYTHON_COMMAND="python gabbro/train.py experiment=example_experiment_backbone_generative.yaml" + +# run the python command in the singularity container +# ADJUST THIS to your singularity image path +singularity exec --nv --bind /data:/data \ + --env SLURM_JOB_ID="$SLURM_JOB_ID" --env SLURM_LOGFILE="$LOGFILE" \ + docker://jobirk/omnijet:latest \ + bash -c "source /opt/conda/bin/activate && $PYTHON_COMMAND" +## ---------------------- End of job script ----------------------------------- +################################################################################ diff --git a/examples/job_scripts/example_job_train_tokenization.sh b/examples/job_scripts/example_job_train_tokenization.sh new file mode 100755 index 0000000..25cadcf --- /dev/null +++ b/examples/job_scripts/example_job_train_tokenization.sh @@ -0,0 +1,35 @@ +#!/bin/bash +# NOTE: This script is an example and should be adjusted to your needs. +# The fields which need to be adjusted are marked with "ADJUST THIS". + +######################### +## SLURM JOB COMMANDS ### +######################### +#SBATCH --partition=maxgpu +#SBATCH --constraint=GPU +#SBATCH --time=48:00:00 +#SBATCH --job-name TokTrain_example # ADJUST THIS to your job name +#SBATCH --output /data/dust/user/rosehenn/gabbro_output/logs/slurm_logs/%x_%j.log # ADJUST THIS to your log path +#SBATCH --mail-user your.email + +echo "Starting job $SLURM_JOB_ID with the following script:" +echo "----------------------------------------------------------------------------" +echo +cat $0 + +source ~/.bashrc +cd /data/dust/user/rosehenn/gabbro/ # ADJUST THIS to your repository path + +LOGFILE="/data/dust/user/rosehenn/gabbro_output/logs/slurm_logs/${SLURM_JOB_NAME}_${SLURM_JOB_ID}.log" # ADJUST THIS to your log path +PYTHON_COMMAND="python gabbro/train.py experiment=example_experiment_tokenization.yaml \ +" + +# run the python command in the singularity container +# ADJUST THIS to your singularity image path +srun singularity exec --nv --bind /data:/data \ + --env SLURM_JOB_ID="$SLURM_JOB_ID" --env SLURM_LOGFILE="$LOGFILE" \ + docker://jobirk/omnijet:latest \ + bash -c "source /opt/conda/bin/activate && $PYTHON_COMMAND" + +## ---------------------- End of job script ----------------------------------- +################################################################################ diff --git a/gabbro/callbacks/best_checkpoint_callback.py b/gabbro/callbacks/best_checkpoint_callback.py new file mode 100644 index 0000000..3b82145 --- /dev/null +++ b/gabbro/callbacks/best_checkpoint_callback.py @@ -0,0 +1,18 @@ +from lightning.pytorch.callbacks import ModelCheckpoint + + +class CustomModelCheckpoint(ModelCheckpoint): + """Custom ModelCheckpoint callback that allows to specify the state_key to be used for the best + checkpoint. + + This workaround is needed because it's not allowed to have two ModelCheckpoint callbacks with + the same state_key in the same Trainer. + """ + + def __init__(self, state_key="best_checkpoint", **kwargs): + super().__init__(**kwargs) + self._state_key = state_key + + @property + def state_key(self) -> str: + return self._state_key diff --git a/gabbro/callbacks/generative_callback.py b/gabbro/callbacks/generative_callback.py new file mode 100644 index 0000000..1939552 --- /dev/null +++ b/gabbro/callbacks/generative_callback.py @@ -0,0 +1,357 @@ +"""Callback for evaluating the generative token model.""" + +import os +from pathlib import Path + +import awkward as ak +import lightning as L +import numpy as np +import torch.distributed as dist +import vector + +import gabbro.plotting.utils as plot_utils +from gabbro.models.backbone import BackboneNextTokenPredictionLightning +from gabbro.plotting.feature_plotting import plot_paper_plots +from gabbro.utils.arrays import np_to_ak +from gabbro.utils.pylogger import get_pylogger +from gabbro.utils.utils import analyze_first_10_tokens + +pylogger = get_pylogger("GenEvalCallback") +vector.register_awkward() + + +class GenEvalCallback(L.Callback): + def __init__( + self, + image_path: str = None, + image_filetype: str = "png", + no_trainer_info_in_filename: bool = False, + save_result_arrays: bool = None, + n_val_gen_jets: int = 10, + starting_at_epoch: int = 0, + every_n_epochs: int = 1, + batch_size_for_generation: int = 2, + plot_best_checkpoint: bool = False, + data_dir: str = "data", + seed_shuffle_val_data: int = None, + n_data: int = 1000, + weights: str = "some_weights", + file_path: str = "some_file_path", + ): + """Callback for evaluating the tokenization of particles. + + Parameters + ---------- + image_path : str + Path to save the images to. If None, the images are saved to the + default_root_dir of the trainer. + image_filetype : str + Filetype to save the images as. Default is "png". + no_trainer_info_in_filename : bool + If True, the filename of the images will not contain the epoch and + global step information. Default is False. + save_result_arrays : bool + If True, the results are saved as parquet file. Default is None. + n_val_gen_jets : int + Number of validation jets to generate. Default is 10. + starting_at_epoch : int + Start evaluating the model at this epoch. Default is 0. + every_n_epochs : int + Evaluate the model every n epochs. Default is 1. + batch_size_for_generation : int + Batch size for generating the jets. Default is 512. + plot_best_checkpoint : bool + If True, the best checkpoint is used for generating the showers. + n_data : int + This is just the number of showers used for training + weights : str + This is the path to the weights used for the backbone + """ + super().__init__() + self.comet_logger = None + self.image_path = image_path + self.n_val_gen_jets = n_val_gen_jets + self.image_filetype = image_filetype + self.no_trainer_info_in_filename = no_trainer_info_in_filename + self.save_results_arrays = save_result_arrays + self.every_n_epochs = every_n_epochs + self.starting_at_epoch = starting_at_epoch + self.batch_size_for_generation = batch_size_for_generation + self.best_checkpoint = plot_best_checkpoint + self.data_dir = data_dir + self.seed_shuffle_val_data = seed_shuffle_val_data + self.n_data = str(n_data) + if weights is None: + self.weights = "no_weights" + elif weights == "some_weights": + self.weights = "no_weights" + elif weights == "": + self.weights = "no_weights" + else: + self.weights = weights + pylogger.info(f"the label used for the saving of the data: {self.weights}") + pylogger.info(f"the actual input for the weights: {weights}") + self.filepath = file_path + + def on_validation_epoch_end(self, trainer, pl_module): + if trainer.current_epoch < self.starting_at_epoch: + pylogger.info( + "Skipping generation. Starting evaluating with this callback" + f" at epoch {self.starting_at_epoch}." + ) + return None + if trainer.current_epoch % self.every_n_epochs != 0: + pylogger.info( + f"Skipping generation. Only evaluating every {self.every_n_epochs} epochs." + ) + return None + if len(pl_module.val_token_ids_list) == 0: + pylogger.warning( + "No validation data available. Skipping generation in validation end." + ) + return None + self.plot_real_vs_gen_jets(trainer, pl_module) + + def on_test_epoch_end(self, trainer, pl_module): + pass + + def on_train_end(self, trainer, pl_module): + """Called at the end of fit (training + optional testing).""" + if len(pl_module.val_token_ids_list) == 0: + pylogger.warning("No validation data available. Skipping generation in train end.") + return None + + self.plot_real_vs_gen_jets(trainer, pl_module) + + def plot_real_vs_gen_jets(self, trainer, pl_module): + plot_utils.set_mpl_style() + + # get loggers + for logger in trainer.loggers: + if isinstance(logger, L.pytorch.loggers.CometLogger): + self.comet_logger = logger.experiment + elif isinstance(logger, L.pytorch.loggers.WandbLogger): + self.wandb_logger = logger.experiment + # convert the numpy arrays and masks of the real jets to ak arrays of token + # ids + pylogger.info(f"Starting generation of {self.n_val_gen_jets} jets...") + np_real_token_ids = np.concatenate(pl_module.val_token_ids_list) + np_real_token_masks = np.concatenate(pl_module.val_token_masks_list) + + pylogger.info(f"np_real_token_ids.shape: {np_real_token_ids.shape}") + pylogger.info(f"np_real_token_masks.shape: {np_real_token_masks.shape}") + real_token_ids = np_to_ak( + x=np_real_token_ids, + names=["part_token_id"], + mask=np_real_token_masks, + ) + + if self.best_checkpoint: + # This Part is new and updates to the best / not the current Checkpoint + best_checkpoint_path = trainer.checkpoint_callback.best_model_path + pylogger.info(f"Loading best checkpoint from {best_checkpoint_path}") + pl_module = BackboneNextTokenPredictionLightning.load_from_checkpoint( + best_checkpoint_path + ) # Call on the class + pl_module.eval() + pl_module.load_backbone_weights(ckpt_path=best_checkpoint_path) + ######## + + self.real_token_ids = ak.values_astype(real_token_ids["part_token_id"], "int64") + self.gen_token_ids = pl_module.generate_n_jets_batched( + self.n_val_gen_jets, batch_size=self.batch_size_for_generation + ) + data_dir = self.data_dir + data_dir = Path(data_dir) + data_dir = data_dir / "reconstructed_test.parquet" + + pylogger.info(f"real_token_ids: {self.real_token_ids}") + pylogger.info(f"gen_token_ids: {self.gen_token_ids}") + + pylogger.info(f"Length of generated shower: {len(self.gen_token_ids)}") + pylogger.info(f"Length of real shower: {len(self.real_token_ids)}") + + plot_dir = ( + self.image_path + if self.image_path is not None + else trainer.default_root_dir + "/plots/" + ) + os.makedirs(plot_dir, exist_ok=True) + filename_real = f"{plot_dir}/epoch{trainer.current_epoch}_gstep{trainer.global_step}_real_shower.parquet" + # Get the rank of the GPU + rank = dist.get_rank() if dist.is_initialized() else 0 + filename_gen = filename_real.replace("real_shower", f"gen_shower_rank{rank}") + + # log min max values of the token ids and of the number of constituents + multiplicity_real = ak.num(self.real_token_ids) + multiplicity_gen = ak.num(self.gen_token_ids) + pylogger.info( + f"Real shower: min multiplicity: {ak.min(multiplicity_real)}, " + f"max multiplicity: {ak.max(multiplicity_real)}" + ) + pylogger.info( + f"Gen shower: min multiplicity: {ak.min(multiplicity_gen)}, " + f"max multiplicity: {ak.max(multiplicity_gen)}" + ) + pylogger.info( + f"Real shower: min token id: {ak.min(self.real_token_ids)}, " + f"max token id: {ak.max(self.real_token_ids)}" + ) + pylogger.info( + f"Gen shower: min token id: {ak.min(self.gen_token_ids)}, " + f"max token id: {ak.max(self.gen_token_ids)}" + ) + + # check if there are nan values in the token ids + if np.sum(np.isnan(ak.flatten(self.real_token_ids))) > 0: + pylogger.warning("Real token ids contain NaN values.") + if np.sum(np.isnan(ak.flatten(self.gen_token_ids))) > 0: + pylogger.warning("Generated token ids contain NaN values.") + + # ak.to_parquet(self.real_token_ids, filename_real) + ak.to_parquet(self.gen_token_ids, filename_gen) + pylogger.info(f"Real shower saved to {filename_real}") + pylogger.info(f"Generated shower saved to {filename_gen}") + + def reconstruct_ak_array( + ak_array_filepath, start_token_included, end_token_included, shift_tokens_by_minus_one + ): + token_dir = Path(pl_module.token_dir) + config_file = token_dir / "config.yaml" + ckpt_file = token_dir / "model.ckpt" + input_file = ak_array_filepath + output_file = ak_array_filepath.replace(".parquet", "_reco.parquet") + + REPO_DIR = Path(__file__).resolve().parent.parent.parent + PYTHON_COMMAND = [ + "python", + f"{REPO_DIR}/scripts/reconstruct_shower_tokens.py", + f"--tokens_file={input_file}", + f"--output_file={output_file}", + f"--ckpt_file={ckpt_file}", + f"--config_file={config_file}", + f"--start_token_included={start_token_included}", + f"--end_token_included={end_token_included}", + f"--shift_tokens_by_minus_one={shift_tokens_by_minus_one}", + ] + os.system(" ".join(PYTHON_COMMAND)) # nosec + + return output_file + + # self.real_reco_file = reconstruct_ak_array(filename_real, 1, 1, 1) #Dont use this, because we dont need to reconstruct the data again and rather want to compare to the original data + self.gen_reco_file = reconstruct_ak_array(filename_gen, 1, 0, 1) + # TODO: make this adjustable + p4s_real = ak.from_parquet( + "/beegfs/desy/user/rosehenn/gabbro/notebooks/array_test.parquet" + ) + p4s_real_token = ak.from_parquet( + "/beegfs/desy/user/rosehenn/gabbro/compare/2024-09-21_16-54-39_max-wng062_CerousLocknut/tokenized_test_e_sorted.parquet" + ) + + # Barrier to ensure all ranks have finished processing before proceeding + if dist.is_initialized(): + dist.barrier() + + if rank == 0: + base_filename = filename_gen.replace("_rank0.parquet", "") + world_size = dist.get_world_size() if dist.is_initialized() else 1 + + gen_data_list = [] + gen_data_list_token = [] + for i in range(world_size): + filename = base_filename + f"_rank{i}_reco.parquet" + filename_token = base_filename + f"_rank{i}.parquet" + if os.path.exists(filename): + gen_data = ak.from_parquet(filename) + gen_data_list.append(gen_data) + else: + print(f"Warning: File {filename_token} does not exist.") + if os.path.exists(filename_token): + gen_data_token = ak.from_parquet(filename_token) + gen_data_list_token.append(gen_data_token) + + # Combine all the data + if gen_data_list: + p4s_gen = ak.concatenate(gen_data_list) + else: + p4s_gen = ak.Array([]) # Handle the case where no files are found + if gen_data_list_token: + p4s_gen_token = ak.concatenate(gen_data_list_token) + else: + p4s_gen_token = ak.Array([]) + # Now p4s_gen contains the combined data from all ranks + + min_length = min(len(p4s_real), len(p4s_gen)) + min_length_token = min(len(p4s_real_token), len(p4s_gen_token)) + + # Truncate to the shorter length + p4s_real = p4s_real[:min_length] + p4s_gen = p4s_gen[:min_length] + + p4s_real_token = p4s_real_token[:min_length_token] + p4s_gen_token = p4s_gen_token[:min_length_token] + + real_token_counts = analyze_first_10_tokens(p4s_real_token) + gen_token_counts = analyze_first_10_tokens(p4s_gen_token) + + pylogger.info(f"Real token counts: {real_token_counts}") + pylogger.info(f"Generated token counts: {gen_token_counts}") + + # Analysis of the first 10 tokens + mean_real_token_count_10 = np.mean(real_token_counts[:10]) + mean_gen_token_count_10 = np.mean(gen_token_counts[:10]) + + mean_gen_token_count = np.mean(gen_token_counts) + mean_real_token_count = np.mean(real_token_counts) + + diversity_real_tokens = len(real_token_counts) + diversity_gen_tokens = len(gen_token_counts) + + if self.comet_logger is not None: + self.comet_logger.log_metrics( + { + "mean_real_token_count_10": mean_real_token_count_10, + "mean_gen_token_count_10": mean_gen_token_count_10, + "mean_gen_token_count": mean_gen_token_count, + "mean_real_token_count": mean_real_token_count, + "diversity_real_tokens": diversity_real_tokens, + "diversity_gen_tokens": diversity_gen_tokens, + }, + step=trainer.global_step, + ) + + # Plot the real vs. generated showers + pylogger.info(f"Real shower: {p4s_real}") + pylogger.info(f"Generated shower: {p4s_gen}") + pylogger.info( + f"Plotting {len(p4s_real)} real showers and {len(p4s_gen)} generated showers..." + ) + if self.best_checkpoint: + plot_kwargs = { + "filepath": self.filepath, + "weights": self.weights, + "n_data": self.n_data, + "transfer_learning": True, + } + else: + plot_kwargs = {} + + fig = plot_paper_plots( + feature_sets=[p4s_real, p4s_gen], + colors=["lightgrey", "cornflowerblue"], + labels=["Geant4", "Generated"], + **plot_kwargs, + ) + + image_filename = f"{plot_dir}/epoch{trainer.current_epoch}_gstep{trainer.global_step}_real_vs_gen_shower.{self.image_filetype}" + # image_filename_COG = f"{plot_dir}/epoch{trainer.current_epoch}_gstep{trainer.global_step}_real_vs_gen_shower_COG.{self.image_filetype}" + fig.savefig(image_filename) + # fig_COG.savefig(image_filename_COG) + + if self.comet_logger is not None: + for fname in [image_filename]: + self.comet_logger.log_image( + fname, name=fname.split("/")[-1], step=trainer.global_step + ) + if dist.is_initialized(): + dist.barrier() diff --git a/gabbro/callbacks/tokenization_shower_callback.py b/gabbro/callbacks/tokenization_shower_callback.py new file mode 100644 index 0000000..20700ea --- /dev/null +++ b/gabbro/callbacks/tokenization_shower_callback.py @@ -0,0 +1,914 @@ +"""Callback for evaluating the tokenization of particles.""" + +import math +import os + +import awkward as ak +import lightning as L +import matplotlib.pyplot as plt +import numpy as np +import torch.distributed as dist +import vector +from matplotlib.gridspec import GridSpec +from scipy.stats import wasserstein_distance + +from gabbro.utils.arrays import ak_preprocess, np_to_akward + +# from gabbro.plotting.plotting_functions import plot_p4s +from gabbro.utils.pylogger import get_pylogger +from gabbro.utils.utils import KL, get0Momentum, get_diff_construct + +default_labels = {"x": "$x$", "y": "$y$", "z": "$z$", "energy": "$E$"} + +pylogger = get_pylogger("TokenizationEvalCallback") +vector.register_awkward() + + +class TokenizationEvalCallback(L.Callback): + def __init__( + self, + image_path: str = None, + image_filetype: str = "png", + no_trainer_info_in_filename: bool = False, + save_result_arrays: bool = None, + ): + """Callback for evaluating the tokenization of particles. + + Parameters + ---------- + image_path : str + Path to save the images to. If None, the images are saved to the + default_root_dir of the trainer. + image_filetype : str + Filetype to save the images as. Default is "png". + no_trainer_info_in_filename : bool + If True, the filename of the images will not contain the epoch and + global step information. Default is False. + save_result_arrays : bool + If True, the results are saved as parquet file. Default is None. + """ + super().__init__() + self.comet_logger = None + self.image_path = image_path + self.image_filetype = image_filetype + self.no_trainer_info_in_filename = no_trainer_info_in_filename + self.save_results_arrays = save_result_arrays + + def on_validation_epoch_end(self, trainer, pl_module): + if dist.is_initialized() and dist.get_rank() != 0: + return + pl_module.concat_validation_loop_predictions() + self.plot(trainer, pl_module, stage="val") + + def on_test_epoch_end(self, trainer, pl_module): + if dist.is_initialized() and dist.get_rank() != 0: + return + pl_module.concat_test_loop_predictions() + self.plot(trainer, pl_module, stage="test") + + def plot(self, trainer, pl_module, stage="val"): + if stage == "val" and not hasattr(pl_module, "val_x_original_concat"): + pylogger.info("No validation predictions found. Skipping plotting.") + return + + pylogger.info( + f"Running TokenizationEvalCallback epoch: {trainer.current_epoch} step:" + f" {trainer.global_step}" + ) + # get loggers + for logger in trainer.loggers: + if isinstance(logger, L.pytorch.loggers.CometLogger): + self.comet_logger = logger.experiment + elif isinstance(logger, L.pytorch.loggers.WandbLogger): + self.wandb_logger = logger.experiment + + plot_dir = ( + self.image_path if self.image_path is not None else trainer.default_root_dir + "/plots" + ) + os.makedirs(plot_dir, exist_ok=True) + if self.no_trainer_info_in_filename: + plot_filename = f"{plot_dir}/evaluation_overview.{self.image_filetype}" + else: + plot_filename = f"{plot_dir}/epoch{trainer.current_epoch}_gstep{trainer.global_step}_overview.{self.image_filetype}" + + if stage == "val": + x_recos = pl_module.val_x_reco_concat + x_originals = pl_module.val_x_original_concat + masks = pl_module.val_mask_concat + pylogger.info(f"x_recos.shape: {x_recos.shape}") + pylogger.info(f"x_originals.shape: {x_originals.shape}") + pylogger.info(f"masks.shape: {masks.shape}") + # labels = pl_module.val_labels_concat + code_idx = pl_module.val_code_idx_concat + elif stage == "test": + # return and print that there are no test predictions if there are none + if not hasattr(pl_module, "test_x_original_concat"): + pylogger.info("No test predictions found. Skipping plotting.") + return + x_recos = pl_module.test_x_reco_concat + x_originals = pl_module.test_x_original_concat + masks = pl_module.test_mask_concat + # labels = pl_module.test_labels_concat + code_idx = pl_module.test_code_idx_concat + else: + raise ValueError(f"stage {stage} not recognized") + + if stage == "test": + pylogger.info(f"x_original_concat.shape: {x_originals.shape}") + pylogger.info(f"x_reco_concat.shape: {x_recos.shape}") + pylogger.info(f"masks_concat.shape: {masks.shape}") + # pylogger.info(f"labels_concat.shape: {labels.shape}") + + pp_dict = trainer.datamodule.hparams.dataset_kwargs_common.feature_dict + + x_reco_ak_pp = np_to_akward(x_recos, pp_dict) + x_original_ak_pp = np_to_akward(x_originals, pp_dict) + + pylogger.info(f"x rocusntructed before preprocess {x_reco_ak_pp}") + pylogger.info(f"x_original before preprocess: {x_original_ak_pp}") + x_reco_ak = ak_preprocess(x_reco_ak_pp, pp_dict=pp_dict, inverse=True) + x_original_ak = ak_preprocess(x_original_ak_pp, pp_dict=pp_dict, inverse=True) + pylogger.info(f"x_reconstructed after preprocess {x_reco_ak}") + pylogger.info(f"x_original after preprocess: {x_original_ak}") + x_complete = ak.to_numpy(x_original_ak["x"]) + y_complete = ak.to_numpy(x_original_ak["y"]) + z_complete = ak.to_numpy(x_original_ak["z"]) + energy_xyz_complete = ak.to_numpy(x_original_ak["energy"]) + energy_xyz_complete = energy_xyz_complete * masks + pylogger.info(f"energy: {energy_xyz_complete}") + + x_complete_reco = ak.to_numpy(x_reco_ak["x"]) + y_complete_reco = ak.to_numpy(x_reco_ak["y"]) + z_complete_reco = ak.to_numpy(x_reco_ak["z"]) + energy_xyz_complete_reco = ak.to_numpy(x_reco_ak["energy"]) + energy_xyz_complete_reco = energy_xyz_complete_reco * masks + + x = x_complete[masks].flatten() + y = y_complete[masks].flatten() + z = z_complete[masks].flatten() + energy_xyz = energy_xyz_complete[masks].flatten() + + x_reco = x_complete_reco[masks].flatten() + y_reco = y_complete_reco[masks].flatten() + z_reco = z_complete_reco[masks].flatten() + energy_xyz_reco = energy_xyz_complete_reco[masks].flatten() + + pylogger.info(f"energy: {energy_xyz}") + pylogger.info(f"energy_reco: {energy_xyz_reco}") + x_bin_min = min(x) - 0.5 + x_bin_max = max(x) + 1.5 + y_bin_min = min(y) - 0.5 + y_bin_max = max(y) + 1.5 + z_bin_min = min(z) - 0.5 + z_bin_max = max(z) + 1.5 + + fig = plt.figure(figsize=(18, 12), facecolor="white") + fig.suptitle("Projected Showers", fontsize=30) + + gs = GridSpec(2, 3) + ############################################################ + # First Histogram - Energy Plots + ############################################################ + bins = np.logspace(np.log(0.001), np.log(max(energy_xyz)), 150, base=np.e) + ax0 = fig.add_subplot(gs[0]) + ax0.axvline(0.1, linestyle="--", color="black", label="MIP") + ax0.set_title("Visible Energy") + ax0.hist( + energy_xyz, + bins=bins, + histtype="stepfilled", + lw=2, + alpha=0.5, + label="original", + color="silver", + ) + ax0.hist( + energy_xyz_reco, + bins=bins, + histtype="step", + lw=2, + alpha=1.0, + label="reconstructed", + color="red", + ) + wasserstein_dist = wasserstein_distance(energy_xyz, energy_xyz_reco) + kl_divergence = KL(energy_xyz, energy_xyz_reco, bins) + + ax0.text( + 0.05, + 0.95, + f"Wasserstein Distance: {wasserstein_dist:.3f}", + transform=plt.gca().transAxes, + ) + ax0.text(0.05, 0.90, f"KL Divergence: {kl_divergence:.3f}", transform=plt.gca().transAxes) + ax0.set_xlabel("Visible energy (MeV)") + ax0.set_ylabel("a.u.") + ax0.legend(loc="upper right") + ax0.set_xscale("log") + ax0.set_yscale("log") + + # Plot for y non-logarithmic + ax1 = fig.add_subplot(gs[1]) + ax1.set_title("Visible Energy") + ax1.axvline(0.1, linestyle="--", color="black", label="MIP") + bins = np.logspace(np.log(0.001), np.log(max(energy_xyz)), 150, base=np.e) + ax1.hist( + energy_xyz, + bins=bins, + histtype="stepfilled", + lw=2, + alpha=0.5, + label="original", + color="silver", + ) + ax1.hist( + energy_xyz_reco, + bins=bins, + histtype="step", + lw=2, + alpha=1.0, + label="reconstructed", + color="red", + ) + ax1.set_xlabel("Visible energy (MeV)") + ax1.set_ylabel("a.u.") + ax1.set_xscale("log") + ax1.legend(loc="upper right") + wasserstein_dist = wasserstein_distance(energy_xyz, energy_xyz_reco) + kl_divergence = KL(energy_xyz, energy_xyz_reco, 150) + ax1.text( + 0.05, + 0.95, + f"Wasserstein Distance: {wasserstein_dist:.3f}", + transform=plt.gca().transAxes, + ) + ax1.text(0.05, 0.90, f"KL Divergence: {kl_divergence:.3f}", transform=plt.gca().transAxes) + + # Energy Sum Histogram + ax2 = fig.add_subplot(gs[2]) + ax2.set_title("Energy Sum") + + data1 = np.sum(energy_xyz_complete, axis=-1) + data2 = np.sum(energy_xyz_complete_reco, axis=-1) + + ax2.hist( + data1, + bins=50, + histtype="stepfilled", + lw=2, + alpha=1.0, + label="original", + color="silver", + ) + ax2.hist( + data2, + bins=50, + histtype="step", + lw=2, + alpha=1.0, + label="reconstructed", + color="red", + ) + wasserstein_dist = wasserstein_distance(data1, data2) + kl_divergence = KL(data1, data2, 50) + ax2.text( + 0.05, + 0.95, + f"Wasserstein Distance: {wasserstein_dist:.3f}", + transform=plt.gca().transAxes, + ) + ax2.text(0.05, 0.90, f"KL Divergence: {kl_divergence:.3f}", transform=plt.gca().transAxes) + ax2.set_xlabel("Visible energy sum (MeV)") + ax2.set_ylabel("a.u.") + ax2.legend(loc="upper right") + + """ + # Number of Hits Histogram + ax2 = fig.add_subplot(gs[2]) + ax2.set_title("Number of Hits") + ax2.hist((energy_xyz_complete != 0).reshape(-1, size_of_event).sum(axis=1), bins=50, histtype="stepfilled", lw=2, alpha=1.0, label="original",color="silver") + ax2.hist( + (energy_xyz_complete_reco != 0).reshape(-1, size_of_event).sum(axis=1), + bins=50, + histtype="step", + lw=2, + alpha=1.0, + label="reconstructed", + color="red", + ) + ax2.set_xlabel("n_hits") + ax2.set_ylabel("a.u.") + ax2.legend(loc = "upper right") + """ + + # Plot for only y-scale to 0.1 + ax3 = fig.add_subplot(gs[3]) + bins = np.logspace(np.log(0.1), np.log(max(energy_xyz)), 150, base=np.e) + ax3.set_title("Visible Energy") + ax3.hist( + energy_xyz, + bins=bins, + histtype="stepfilled", + lw=2, + alpha=0.5, + label="original", + color="silver", + ) + ax3.hist( + energy_xyz_reco, + bins=bins, + histtype="step", + lw=2, + alpha=1.0, + label="reconstructed", + color="red", + ) + + wasserstein_dist = wasserstein_distance(energy_xyz, energy_xyz_reco) + kl_divergence = KL(energy_xyz, energy_xyz_reco, 150) + ax3.text( + 0.05, + 0.95, + f"Wasserstein Distance: {wasserstein_dist:.3f}", + transform=plt.gca().transAxes, + ) + ax3.text(0.05, 0.90, f"KL Divergence: {kl_divergence:.3f}", transform=plt.gca().transAxes) + + ax3.set_xlabel("Visible energy (MeV)") + ax3.set_ylabel("a.u.") + ax3.legend(loc="upper right") + ax3.set_yscale("log") + ax3.set_xscale("log") + + # Plot for only x-scale logarithmic + ax4 = fig.add_subplot(gs[4]) + bins = np.logspace(np.log(0.1), np.log(max(energy_xyz)), 150, base=np.e) + ax4.set_title("Visible Energy") + + ax4.hist( + energy_xyz, + bins, + histtype="stepfilled", + lw=2, + alpha=0.5, + label="original", + color="silver", + ) + ax4.hist( + energy_xyz_reco, + bins, + histtype="step", + lw=2, + alpha=1.0, + label="reconstructed", + color="red", + ) + ax4.set_xlabel("Visible energy (MeV)") + ax4.set_ylabel("a.u.") + ax4.legend(loc="upper right") + ax4.set_xscale("log") + wasserstein_dist = wasserstein_distance(energy_xyz, energy_xyz_reco) + kl_divergence = KL(energy_xyz, energy_xyz_reco, bins) + ax4.text( + 0.05, + 0.95, + f"Wasserstein Distance: {wasserstein_dist:.3f}", + transform=plt.gca().transAxes, + ) + ax4.text(0.05, 0.90, f"KL Divergence: {kl_divergence:.3f}", transform=plt.gca().transAxes) + + # z-start-layer + + max_energy_indices = np.argmax(energy_xyz_complete, axis=1) + max_energy_z_values = z_complete[np.arange(len(z_complete)), max_energy_indices] + + max_energy_indices_reco = np.argmax(energy_xyz_complete_reco, axis=1) + max_energy_z_values_reco = z_complete_reco[ + np.arange(len(z_complete_reco)), max_energy_indices_reco + ] + ax5 = fig.add_subplot(gs[5]) + ax5.set_title("z start layer") + step = math.ceil(z_bin_max / 11) + bins = np.arange(z_bin_min, z_bin_max) + ax5.hist( + max_energy_z_values, + bins=bins, + histtype="stepfilled", + lw=2, + alpha=1.0, + color="silver", + label="original", + ) + ax5.hist( + max_energy_z_values_reco, + bins=bins, + histtype="step", + lw=2, + alpha=1.0, + color="red", + label="reconstructed", + ) + wasserstein_dist = wasserstein_distance(max_energy_z_values, max_energy_z_values_reco) + kl_divergence = KL(max_energy_z_values, max_energy_z_values_reco, bins) + ax5.text( + 0.05, + 0.95, + f"Wasserstein Distance: {wasserstein_dist:.3f}", + transform=plt.gca().transAxes, + ) + ax5.text(0.05, 0.90, f"KL Divergence: {kl_divergence:.3f}", transform=plt.gca().transAxes) + + ax5.set_xlabel("z") + ax5.set_ylabel("a.u.") + ax5.ticklabel_format( + axis="y", style="sci", scilimits=(0, 0), useMathText=True + ) # Set scientific notation for y-axis + + ax5.set_xticks(np.arange(z_bin_min, z_bin_max, step)) + ax5.legend(loc="upper right") + + fig.suptitle("Distributions") + + fig.tight_layout() + + rep = "_overview" + filename = plot_filename.replace(rep, "_visible_energy") + pylogger.info(f"Saving plot to {filename}") + fig.savefig(filename) + + ############################################################ + # Second Histogram --- x,y,z Distribution and 0th Moment + ############################################################ + + fig_0Moment = plt.figure(figsize=(18, 12), facecolor="white") + fig_0Moment.suptitle("0th Moment", fontsize=30) + gs2 = GridSpec(2, 3) + + ax0 = fig_0Moment.add_subplot(gs2[0]) + x0 = get0Momentum(x_complete, energy_xyz_complete) + x0_reco = get0Momentum(x_complete_reco, energy_xyz_complete_reco) + average = sum(x0) / len(x0) + if average < 1: + offset = 0.4 + else: + offset = average * 0.05 + + if average < 0: + bins = np.arange(-average - offset, -average + offset, 0.005) + else: + bins = np.arange(average - offset, average + offset, 0.005) + + ax0.set_title("[X] distribution") + ax0.hist( + x0, + bins=bins, + histtype="stepfilled", + lw=2, + alpha=1.0, + color="silver", + label="original", + ) + ax0.hist( + x0_reco, + bins=bins, + histtype="step", + lw=2, + alpha=1.0, + color="red", + label="reconstructed", + ) + data1 = x0 + data2 = x0_reco + wasserstein_dist = wasserstein_distance(data1, data2) + kl_divergence = KL(data1, data2, bins) + ax0.text( + 0.05, + 0.95, + f"Wasserstein Distance: {wasserstein_dist:.3f}", + transform=plt.gca().transAxes, + ) + ax0.text(0.05, 0.90, f"KL Divergence: {kl_divergence:.3f}", transform=plt.gca().transAxes) + ax0.set_xlabel("X") + ax0.set_ylabel("a.u.") + ax0.legend(loc="upper right") + + ax1 = fig_0Moment.add_subplot(gs2[1]) + y0 = get0Momentum(y_complete, energy_xyz_complete) + y0_reco = get0Momentum(y_complete_reco, energy_xyz_complete_reco) + average = sum(y0) / len(y0) + if average < 1: + offset = 0.4 + else: + offset = average * 0.05 + + if average < 0: + bins = np.arange(-average - offset, -average + offset, 0.005) + else: + bins = np.arange(average - offset, average + offset, 0.005) + ax1.set_title("[Y] distribution") + ax1.hist( + y0, + bins=bins, + histtype="stepfilled", + lw=2, + alpha=1.0, + color="silver", + label="original", + ) + ax1.hist( + y0_reco, + bins=bins, + histtype="step", + lw=2, + alpha=1.0, + color="red", + label="reconstructed", + ) + wasserstein_dist = wasserstein_distance(y0, y0_reco) + kl_divergence = KL(y0, y0_reco, bins) + ax1.text( + 0.05, + 0.95, + f"Wasserstein Distance: {wasserstein_dist:.3f}", + transform=plt.gca().transAxes, + ) + ax1.text(0.05, 0.90, f"KL Divergence: {kl_divergence:.3f}", transform=plt.gca().transAxes) + ax1.set_xlabel("Y") + ax1.set_ylabel("a.u.") + ax1.legend(loc="upper right") + + z0 = get0Momentum(z_complete, energy_xyz_complete) + z0_reco = get0Momentum(z_complete_reco, energy_xyz_complete_reco) + average = sum(z0) / len(z0) + if average < 1: + offset = 1.4 + else: + offset = average * 0.45 + + if average < 0: + bins = np.arange(-average - offset, -average + offset, 0.05) + else: + bins = np.arange(average - offset, average + offset, 0.05) + ax2 = fig_0Moment.add_subplot(gs2[2]) + ax2.set_title("[Z] distribution") + ax2.hist( + z0, + bins=bins, + histtype="stepfilled", + lw=2, + alpha=1.0, + color="silver", + label="original", + ) + ax2.hist( + z0_reco, + bins=bins, + histtype="step", + lw=2, + alpha=1.0, + color="red", + label="reconstructed", + ) + wasserstein_dist = wasserstein_distance(z0, z0_reco) + kl_divergence = KL(z0, z0_reco, bins) + ax2.text( + 0.05, + 0.95, + f"Wasserstein Distance: {wasserstein_dist:.3f}", + transform=plt.gca().transAxes, + ) + ax2.text(0.05, 0.90, f"KL Divergence: {kl_divergence:.3f}", transform=plt.gca().transAxes) + ax2.set_xlabel("Z") + ax2.set_ylabel("a.u.") + ax2.legend(loc="upper right") + + # X Distribution + ax3 = fig_0Moment.add_subplot(gs2[3]) + ax3.set_title("[x] distribution") + ax3.yaxis.set_major_formatter(plt.ScalarFormatter(useMathText=True)) + ax3.hist( + x, + bins=np.arange(x_bin_min, x_bin_max), + histtype="stepfilled", + lw=2, + alpha=1.0, + color="silver", + label="original", + ) + ax3.hist( + x_reco, + bins=np.arange(x_bin_min, x_bin_max), + histtype="step", + lw=2, + alpha=1.0, + label="reconstructed", + color="red", + ) + data1 = x + data2 = x_reco + wasserstein_dist = wasserstein_distance(data1, data2) + kl_divergence = KL(data1, data2, np.arange(x_bin_min, x_bin_max)) + ax3.text( + 0.05, + 0.95, + f"Wasserstein Distance: {wasserstein_dist:.3f}", + transform=plt.gca().transAxes, + ) + ax3.text(0.05, 0.90, f"KL Divergence: {kl_divergence:.3f}", transform=plt.gca().transAxes) + ax3.set_xlabel("[x]") + ax3.set_ylabel("Number of hits") + ax3.set_xticks(np.arange(x_bin_min, x_bin_max, step)) + ax3.legend(loc="upper right") + + # Y Distribution + ax4 = fig_0Moment.add_subplot(gs2[4]) + ax4.set_title("[y] distribution") + ax4.yaxis.set_major_formatter(plt.ScalarFormatter(useMathText=True)) + ax4.hist( + y, + bins=np.arange(y_bin_min, y_bin_max), + histtype="stepfilled", + lw=2, + alpha=1.0, + color="silver", + label="original", + ) + ax4.hist( + y_reco, + bins=np.arange(y_bin_min, y_bin_max), + histtype="step", + lw=2, + alpha=1.0, + label="reconstructed", + color="red", + ) + data1 = y + data2 = y_reco + wasserstein_dist = wasserstein_distance(data1, data2) + kl_divergence = KL(data1, data2, np.arange(y_bin_min, y_bin_max)) + ax4.text( + 0.05, + 0.95, + f"Wasserstein Distance: {wasserstein_dist:.3f}", + transform=plt.gca().transAxes, + ) + ax4.text(0.05, 0.90, f"KL Divergence: {kl_divergence:.3f}", transform=plt.gca().transAxes) + ax4.set_xlabel("[y]") + ax4.set_ylabel("Number of hits") + ax4.set_xticks(np.arange(y_bin_min, y_bin_max, step)) + ax4.legend(loc="upper right") + + # Z Distribution + ax5 = fig_0Moment.add_subplot(gs2[5]) + ax5.set_title("[z] distribution") + ax5.yaxis.set_major_formatter(plt.ScalarFormatter(useMathText=True)) + ax5.hist( + z, + bins=np.arange(z_bin_min, z_bin_max), + histtype="stepfilled", + lw=2, + alpha=1.0, + color="silver", + label="original", + ) + ax5.hist( + z_reco, + bins=np.arange(z_bin_min, z_bin_max), + histtype="step", + lw=2, + alpha=1.0, + label="reconstructed", + color="red", + ) + data1 = z + data2 = z_reco + wasserstein_dist = wasserstein_distance(data1, data2) + kl_divergence = KL(data1, data2, np.arange(z_bin_min, z_bin_max)) + ax5.text( + 0.05, + 0.95, + f"Wasserstein Distance: {wasserstein_dist:.3f}", + transform=plt.gca().transAxes, + ) + ax5.text(0.05, 0.90, f"KL Divergence: {kl_divergence:.3f}", transform=plt.gca().transAxes) + ax5.set_xlabel("[z]") + ax5.set_ylabel("Number of hits") + ax5.set_xticks(np.arange(z_bin_min, z_bin_max, step)) + ax5.legend(loc="upper right") + + ############################################################ + + fig.suptitle("Distributions") + + fig.tight_layout() + + rep = "_overview" + filename = plot_filename.replace(rep, "_test") + pylogger.info(f"Saving plot to {filename}") + fig.savefig(filename) + + fig_0Moment.suptitle("Distributions") + + fig_0Moment.tight_layout() + rep = "_overview" + filename2 = plot_filename.replace(rep, "_0Moment_test") + pylogger.info(f"Saving plot to {filename2}") + fig_0Moment.savefig(filename2) + + ############################################################ + # Third Histogram ---- Error + ############################################################ + fig_shift = plt.figure(figsize=(24, 12), facecolor="white") + fig_shift.suptitle("shift", fontsize=30) + gs3 = GridSpec(2, 4) + + x_diff = get_diff_construct(x, x_reco) + y_diff = get_diff_construct(y, y_reco) + z_diff = get_diff_construct(z, z_reco) + e_diff = get_diff_construct(energy_xyz, energy_xyz_reco) + + bins = np.arange(-0.05, 0.05, 0.001) + bins_e = np.arange(-1, 1, 0.02) + + ax0 = fig_shift.add_subplot(gs3[0]) + ax0.set_title("$x_{reco} - x_{true}$ Distribution") + ax0.hist( + x_diff, + bins=bins, + histtype="stepfilled", + lw=2, + alpha=1.0, + color="silver", + label="error", + ) + ax0.axvline(0, linestyle="--", color="black", label="center") + ax0.set_xlabel("$x_{diff}$") + ax0.set_ylabel("a.u.") + ax0.legend(loc="upper right") + + ax1 = fig_shift.add_subplot(gs3[1]) + ax1.set_title("$y_{reco} - y_{true}$ Distribution") + ax1.hist( + y_diff, + bins=bins, + histtype="stepfilled", + lw=2, + alpha=1.0, + color="silver", + label="error", + ) + ax1.axvline(0, linestyle="--", color="black", label="center") + ax1.set_xlabel("$y_{diff}$") + ax1.set_ylabel("a.u.") + ax1.legend(loc="upper right") + + ax2 = fig_shift.add_subplot(gs3[2]) + ax2.set_title("$z_{reco} - z_{true}$ Distribution") + ax2.hist( + z_diff, + bins=bins, + histtype="stepfilled", + lw=2, + alpha=1.0, + color="silver", + label="error", + ) + ax2.axvline(0, linestyle="--", color="black", label="center") + ax2.set_xlabel("$z_{diff}$") + ax2.set_ylabel("a.u.") + ax2.legend(loc="upper right") + + ax3 = fig_shift.add_subplot(gs3[3]) + ax3.set_title("$energy_{reco} - energy_{true}$ Distribution") + ax3.hist( + e_diff, + bins=bins_e, + histtype="stepfilled", + lw=2, + alpha=1.0, + color="silver", + label="error", + ) + ax3.axvline(0, linestyle="--", color="black", label="center") + ax3.set_xlabel("$energy_{diff}$") + ax3.set_ylabel("a.u.") + ax3.legend(loc="upper right") + + ############################################################ + # Create the second row of the 3th plot with the larger errors + ############################################################ + + bins = np.arange(-10, 10, 0.25) + bins_e = np.arange(-10, 10, 0.2) + + ax4 = fig_shift.add_subplot(gs3[4]) + ax4.set_title("$x_{reco} - x_{true}$ Distribution") + ax4.hist( + x_diff, + bins=bins, + histtype="stepfilled", + lw=2, + alpha=1.0, + color="silver", + label="error", + ) + ax4.axvline(0, linestyle="--", color="black", label="center") + ax4.set_xlabel("$x_{diff}$") + ax4.set_ylabel("a.u.") + ax4.set_yscale("log") + ax4.legend(loc="upper right") + + ax5 = fig_shift.add_subplot(gs3[5]) + ax5.set_title("$y_{reco} - y_{true}$ Distribution") + ax5.hist( + y_diff, + bins=bins, + histtype="stepfilled", + lw=2, + alpha=1.0, + color="silver", + label="error", + ) + ax5.axvline(0, linestyle="--", color="black", label="center") + ax5.set_xlabel("$y_{diff}$") + ax5.set_ylabel("a.u.") + ax5.set_yscale("log") + ax5.legend(loc="upper right") + + ax6 = fig_shift.add_subplot(gs3[6]) + ax6.set_title("$z_{reco} - z_{true}$ Distribution") + ax6.hist( + z_diff, + bins=bins, + histtype="stepfilled", + lw=2, + alpha=1.0, + color="silver", + label="error", + ) + ax6.axvline(0, linestyle="--", color="black", label="center") + ax6.set_xlabel("$z_{diff}$") + ax6.set_ylabel("a.u.") + ax6.legend(loc="upper right") + ax6.set_yscale("log") + + ax7 = fig_shift.add_subplot(gs3[7]) + ax7.set_title("$energy_{reco} - energy_{true}$ Distribution") + ax7.hist( + e_diff, + bins=bins_e, + histtype="stepfilled", + lw=2, + alpha=1.0, + color="silver", + label="error", + ) + ax7.set_yscale("log") + ax7.axvline(0, linestyle="--", color="black", label="center") + ax7.set_xlabel("$energy_{diff}$") + ax7.set_ylabel("a.u.") + ax7.legend(loc="upper right") + + fig_shift.suptitle("Distributions of the error") + + fig_shift.tight_layout() + rep = "_overview" + filename3 = plot_filename.replace(rep, "_shift_test") + pylogger.info(f"Saving plot to {filename3}") + fig_shift.savefig(filename3) + + # log the plots + if self.comet_logger is not None: + for fname in [filename, filename2, filename3]: + self.comet_logger.log_image( + fname, name=fname.split("/")[-1], step=trainer.global_step + ) + + # calculate per-feature mean abs error + shape = x_recos.shape + x_recos_reshaped = x_recos.reshape(-1, shape[-1]) + x_originals_reshaped = x_originals.reshape(-1, shape[-1]) + particle_feature_mean_absolute_error = np.mean( + np.abs(x_recos_reshaped - x_originals_reshaped), axis=1 + ) + particle_feature_mean_error = np.mean(x_recos_reshaped - x_originals_reshaped, axis=1) + # calculate codebook utilization + n_codes = pl_module.model.vq_kwargs["num_codes"] + codebook_utilization = len(np.unique(code_idx)) / n_codes + + # log the mean squared error + if self.comet_logger is not None: + self.comet_logger.log_metric( + f"{stage}_codebook_utilization", codebook_utilization, step=trainer.global_step + ) + for i, feature in enumerate(pp_dict.keys()): + self.comet_logger.log_metric( + f"{stage}_mean_abserr_{feature}", + particle_feature_mean_absolute_error[i], + step=trainer.global_step, + ) + self.comet_logger.log_metric( + f"{stage}_mean_err_{feature}", + particle_feature_mean_error[i], + step=trainer.global_step, + ) diff --git a/gabbro/data/__init__.py b/gabbro/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gabbro/data/data_tokenization.py b/gabbro/data/data_tokenization.py new file mode 100644 index 0000000..67dd314 --- /dev/null +++ b/gabbro/data/data_tokenization.py @@ -0,0 +1,239 @@ +import logging +import os +from pathlib import Path + +import awkward as ak +import vector +from omegaconf import OmegaConf + +from gabbro.data.loading import read_shower_file +from gabbro.models.vqvae import VQVAELightning +from gabbro.utils.mapping import merge_duplicates_numpy + +# import gabbro.plotting as jplt + +vector.register_awkward() + +logger = logging.getLogger(__name__) + + +def tokenize_shower_file( + filename_in: str, + model_ckpt_path: str, + filename_out: str = None, + add_start_end_tokens: bool = False, + print_model: bool = False, + energy_sorting: bool = False, + layer_sorting: bool = False, + n_load: int = None, +): + """Tokenize a single file using a trained model. + + Parameters + ---------- + filename : str + Path to the file to be tokenized. + model_ckpt_path : str + Path to the model checkpoint. + filename_out : str, optional + Path to the output file. + add_start_end_tokens : bool, optional + Whether to add start and end tokens to the tokenized sequence. + print_model : bool, optional + Whether to print the model architecture. + n_load : int, optional + Number of events to load from the file. If None, all events are loaded. + + Returns + ------- + tokens_int : ak.Array + Array of tokens. + p4s_original : ak.Array + Momentum4D array of the original particles. + x_ak_original : ak.Array + Array of the original particles. + """ + data_showers = read_shower_file( + filename_in, + n_load=n_load, + ) + if energy_sorting: + print("Sorting by energy") + # Sort the showers by energy + sorted_energy = ak.argsort(data_showers.energy, axis=1, ascending=False) + # Update data_showers with sorted energy + data_showers = data_showers[sorted_energy] + if layer_sorting: + print("Sorting by layer") + # Sort the showers by layer + sorted_layer = ak.argsort(data_showers.z, axis=1, ascending=True) + # Update data_showers with sorted layer + data_showers = data_showers[sorted_layer] + + # --- Model and config loading --- + ckpt_path = Path(model_ckpt_path) + config_path = ckpt_path.parent.parent / "config.yaml" + cfg = OmegaConf.load(config_path) + logger.info(f"Loaded config from {config_path}") + model = VQVAELightning.load_from_checkpoint(ckpt_path) + if print_model: + print(model) + pp_dict = cfg.data.dataset_kwargs_common["feature_dict"] + logger.info("Preprocessing dictionary:") + for key, value in pp_dict.items(): + logger.info(f" | {key}: {value}") + + model = model.to("cuda") + model.eval() + + # -------------------------------- + + p4s_original = ak.zip( + { + "x": data_showers["x"], + "y": data_showers["y"], + "z": data_showers["z"], + "energy": data_showers["energy"], + }, + with_name="Momentum4D", + ) + # ak_data_shower = np_to_akward(data_showers, pp_dict) + + tokens = model.tokenize_shower_ak_array(data_showers, pp_dict) + print("tokens", tokens) + + if add_start_end_tokens: + n_tokens = model.model.vqlayer.num_codes + tokens = ak.concatenate( + [ + ak.zeros_like(tokens[:, :1]), # start token is 0 + tokens + 1, + ak.ones_like(tokens[:, :1]) + n_tokens, # end token is n_tokens + 1 + ], + axis=1, + ) + + tokens_int = ak.values_astype(tokens, int) + + if filename_out is not None: + os.makedirs(os.path.dirname(filename_out), exist_ok=True) + logger.info(f"Saving tokenized file to {filename_out}") + ak.to_parquet(tokens_int, filename_out) + + print("p4s_original z", p4s_original.z) + print("p4s_original energy", p4s_original.energy) + + return tokens_int, p4s_original, data_showers + + +def reconstruct_shower_file( + filename_in: str, + model_ckpt_path: str, + config_path: str, + start_token_included: bool = False, + end_token_included: bool = False, + shift_tokens_by_minus_one: bool = False, + filename_out: str = None, + print_model: bool = False, + device: str = "cuda", + merge_duplicates: bool = True, +): + """Reconstruct a single file using a trained model and the tokenized file. + + Parameters + ---------- + filename_in : str + Path to the file to be tokenized. + model_ckpt_path : strgemini + Path to the model checkpoint. + config_path : str + Path to the config file. + filename_out : str, optional + Path to the output file. + start_token_included : bool, optional + Whether the start token is included in the tokenized sequence. + end_token_included : bool, optional + Whether the end token is included in the tokenized sequence. + shift_tokens_by_minus_one : bool, optional + Whether to shift the tokens by -1. + print_model : bool, optional + Whether to print the model architecture. + device : str, optional + Device to use for the model. + merge_duplicates : bool, optional + Whether to merge the duplicate voxels. + + Returns + ------- + p4s_reco : ak.Array + Momentum4D array of the reconstructed particles. + x_reco_ak : ak.Array + Array of the reconstructed particles. + labels_onehot : np.ndarray + One-hot encoded labels of the shower type. Only returned if return_labels is True. + """ + + # --- Model and config loading --- + ckpt_path = Path(model_ckpt_path) + cfg = OmegaConf.load(config_path) + logger.info(f"Loaded config from {config_path}") + model = VQVAELightning.load_from_checkpoint(ckpt_path) + if print_model: + logger.info(model) + pp_dict = cfg.data.dataset_kwargs_common["feature_dict"] + # logger.info("Preprocessing dictionary:") + # for key, value in pp_dict.items(): + # logger.info(f" | {key}: {value}") + + model = model.to(device) + model.eval() + # -------------------------------- + + tokens = ak.from_parquet(filename_in) + logger.info(f"tokens: {tokens}") + + if end_token_included: + logger.info("Removing end token") + tokens = tokens[:, :-1] + logger.info(f"Tokens with end token removed: {tokens}") + if start_token_included: + logger.info("Removing start token and shifting tokens by -1") + tokens = tokens[:, 1:] + logger.info(f"Tokens with start token removed: {tokens}") + if shift_tokens_by_minus_one: + logger.info("Shifting tokens by -1") + tokens = tokens - 1 + logger.info(f"Tokens shifted by -1: {tokens}") + + logger.info(f"Smallest token in file: {ak.min(tokens)}") + logger.info(f"Largest token in file: {ak.max(tokens)}") + + x_reco_ak = model.reconstruct_shower_ak_tokens(tokens, pp_dict, hide_pbar=False, batch_size=2) + + logger.info(f"x_reco_ak x: {x_reco_ak.x}") + logger.info(f"x_reco_ak energy: {x_reco_ak.energy}") + + logger.info(f"reconstructed file: {x_reco_ak}") + + if merge_duplicates: + x_reco_ak = merge_duplicates_numpy( + x_reco_ak + ) # this maps the duplicate voxels on the true grid and sums the energy + logger.info(f"reconstructed file after merge: {x_reco_ak}") + + p4s_reco = ak.zip( + { + "x": x_reco_ak.x if "x" in x_reco_ak.fields else x_reco_ak.x, + "y": x_reco_ak.y if "y" in x_reco_ak.fields else x_reco_ak.y, + "z": x_reco_ak.z if "z" in x_reco_ak.fields else x_reco_ak.z, + "energy": x_reco_ak.energy if "energy" in x_reco_ak.fields else x_reco_ak.energy, + }, + with_name="Momentum4D", + ) + logger.info(f"p4s_reco energy: {p4s_reco.energy}") + if filename_out is not None: + os.makedirs(os.path.dirname(filename_out), exist_ok=True) + logger.info(f"Saving tokenized file to {filename_out}") + ak.to_parquet(p4s_reco, filename_out) + + return p4s_reco, x_reco_ak diff --git a/gabbro/data/iterable_dataset_shower.py b/gabbro/data/iterable_dataset_shower.py new file mode 100644 index 0000000..038825b --- /dev/null +++ b/gabbro/data/iterable_dataset_shower.py @@ -0,0 +1,537 @@ +import gc +import glob +import random +from typing import Optional + +import awkward as ak +import lightning as L +import numpy as np +import torch +import torch.distributed as dist +import vector +from torch.distributed import get_rank, get_world_size +from torch.utils.data import DataLoader, IterableDataset, get_worker_info + +from gabbro.data.loading import read_shower_file, read_tokenized_shower_file +from gabbro.utils.arrays import ak_pad, ak_padding, ak_preprocess, ak_to_np_stack +from gabbro.utils.pylogger import get_pylogger + +vector.register_awkward() + + +class CustomIterableDataset(IterableDataset): + """Custom IterableDataset that loads data from multiple files.""" + + def __init__( + self, + files_dict: dict, + n_files_at_once: int = None, + n_shower_per_file: int = None, + max_n_files_per_type: int = None, + shuffle_files: bool = True, + shuffle_data: bool = True, + seed: int = 4697, + seed_shuffle_data: int = 3838, + pad_length: int = 1700, + logger_name: str = "CustomIterableDataset", + feature_dict: dict = None, + labels_to_load: list = None, + token_reco_cfg: dict = None, + token_id_cfg: dict = None, + load_only_once: bool = False, + shuffle_only_once: bool = False, + random_seed_for_per_file_shuffling: int = 4350, + h5file: bool = False, + energy_threshold: float = 0, + energy_sorting: bool = False, + **kwargs, + ): + """ + Parameters + ---------- + files_dict : dict + Dict with the file names for each type. Can be e.g. a dict like + {"tbqq": ["tbqq_0.root", ...], "qcd": ["qcd_0.root", ...], ...}. + n_files_at_once : int, optional + Number of files to load at once. If None, one file per files_dict key + is loaded. + n_shower_per_file : int, optional + Number of showers loaded from each individual file. Defaults to None, which + means that all showers are loaded. + max_n_files_per_type : int, optional + Maximum number of files to use per type. If None, all files are used. + Can be used to use e.g. always the first file from the sorted list of files + in validation. + shuffle_files : bool, optional + Whether to shuffle the list of files. + shuffle_data : bool, optional + Whether to shuffle the data after loading. + seed : int, optional + Random seed. + seed_shuffle_data : int, optional + Random seed for shuffling the data. This is useful if you want to shuffle + the data in the same way for different datasets (e.g. train and val). + The default value is 3838. + pad_length : int, optional + Maximum number of particles per shower. If a shower has more particles, the + first pad_length particles are used, the rest is discarded. + logger_name : str, optional + Name of the logger. + feature_dict : dict, optional + Dictionary with the features to load. The keys are the names of the features + and the values are the preprocessing parameters passed to the + `ak_select_and_preprocess` function. + labels_to_load : list, optional + List with the shower_type labels to load. + token_reco_cfg : dict, optional + Dictionary with the configuration to reconstruct the tokenized showerclass files. + If None, this is not used. + token_id_cfg : dict, optional + Dictionary with the tokenization configuration, this is to be used when the + token-id data is to be loaded. If None, this is ignored. + load_only_once : bool, optional + If True, the data is loaded only once and then returned in the same order + in each iteration. NOTE: this is only useful if the whole dataset fits into + memory. If the dataset is too large, this will lead to a memory error. + shuffle_only_once : bool, optional + If True, the data is shuffled only once and then returned in the same order + in each iteration. NOTE: this should only be used for val/test. + random_seed_for_per_file_shuffling : int, optional + Random seed for shuffling the showers within a file. This is useful if you want + to only load a subset of the showers from a file and want to choose different + showers in different training runs. + If load_only_once is False, this is ignored. + h5file : bool, optional + If True, the data is loaded from an h5 file. If False, the data is loaded from a root file. + energy_threshhold : float, optional + Is the minimum energy for which shower events are considered as non-zero. + energy_sorting : bool, optional + If True, the showers are sorted by energy in descending order. + **kwargs + Additional keyword arguments. + + """ + if feature_dict is None: + raise ValueError("feature_dict must be provided.") + if labels_to_load is None: + raise ValueError("labels_to_load must be provided.") + + worker_info = get_worker_info() + rank = get_rank() if dist.is_initialized() else 0 + world_size = get_world_size() if dist.is_initialized() else 1 + + self.multi_gpu_info = { + "num_gpus": torch.cuda.device_count(), + "process_rank": rank, + "world_size": world_size, + "device": f"cuda:{rank}" if torch.cuda.is_available() else "cpu", + "worker_id": worker_info.id if worker_info is not None else 0, + "num_workers": worker_info.num_workers if worker_info is not None else 1, + } + + self.logger_name = logger_name + self.setup_logger(rank=None) + + self.logger.info(f"{[f'{key}={value}' for key, value in self.multi_gpu_info.items()]}") + + self.logger.info(f"Using seed {seed}") + self.pad_length = pad_length + self.shuffle_data = shuffle_data + self.shuffle_files = shuffle_files + self.processed_files_counter = 0 + self.max_n_files_per_type = max_n_files_per_type + self.n_shower_per_file = n_shower_per_file + self.feature_dict = feature_dict + self.labels_to_load = labels_to_load + self.particle_features_list = [feat for feat in self.feature_dict.keys()] + self.seed_shuffle_data = seed_shuffle_data + self.load_only_once = load_only_once + self.shuffle_only_once = shuffle_only_once + self.data_shuffled = False + self.random_seed_for_per_file_shuffling = random_seed_for_per_file_shuffling + self.h5file = h5file + self.energy_threshold = energy_threshold + self.energy_sorting = energy_sorting + + if self.random_seed_for_per_file_shuffling is not None: + if not self.load_only_once: + self.logger.warning( + "random_seed_for_per_file_shuffling is only used if load_only_once is True." + ) + self.random_seed_for_per_file_shuffling = None + else: + self.logger.info( + f"Using random seed {self.random_seed_for_per_file_shuffling} for per-file shuffling." + ) + + self.logger.info(f"Using the following labels: {self.labels_to_load}") + self.logger.info(f"Using the following particle features: {self.particle_features_list}") + self.logger.info(f"pad_length {self.pad_length} for the number of hits per shower.") + self.logger.info(f"energy_threshold {self.energy_threshold}") + self.logger.info(f"shuffle_data={self.shuffle_data}") + self.logger.info(f"shuffle_files={self.shuffle_files}") + self.logger.info( + "Number of showers loaded per file: " + f"{self.n_shower_per_file if self.n_shower_per_file is not None else 'all'}" + ) + self.logger.info("Using the following features:") + for feat, params in self.feature_dict.items(): + self.logger.info(f"- {feat}: {params}") + self.files_dict = {} + for shower_type, files in files_dict.items(): + expanded_files = [] + for file in files: + expanded_files.extend(sorted(list(glob.glob(file)))) + self.files_dict[shower_type] = ( + expanded_files + if max_n_files_per_type is None + else expanded_files[:max_n_files_per_type] + ) + + self.logger.info(f"Files for shower_type {shower_type}:") + for file in self.files_dict[shower_type]: + self.logger.info(f" - {file}") + + if self.load_only_once: + self.logger.warning( + "load_only_once is True. This means that there will only be the initial data loading." + ) + + # add all files from the dict to a list (the values are lists of files) + self.file_list = [] + for files in self.files_dict.values(): + self.file_list.extend(files) + + # if not specified how many files to use at once, use one file per shower_type + if n_files_at_once is None: + self.n_files_at_once = len(self.files_dict) + else: + if n_files_at_once > len(self.file_list): + self.logger.warning( + f"n_files_at_once={n_files_at_once} is larger than the number of files in the" + f" dataset ({len(self.file_list)})." + ) + self.logger.warning(f"Setting n_files_at_once to {len(self.file_list)}.") + self.n_files_at_once = len(self.file_list) + else: + self.n_files_at_once = n_files_at_once + + self.logger.info(f"Will load {self.n_files_at_once} files at a time and combine them.") + + self.file_indices = np.array([0, self.n_files_at_once]) + self.file_iterations = len(self.file_list) // self.n_files_at_once + if self.load_only_once: + self.file_iterations = 1 + + self.current_part_data = None + self.current_part_mask = None + self.token_reco_cfg = token_reco_cfg + self.token_id_cfg = token_id_cfg + + def setup_logger(self, rank: int = None) -> None: + self.logger = get_pylogger(f"{__name__}-{self.logger_name}", rank=rank) + self.logger.info("Logger set up (potentially with new rank information).") + + def get_data(self): + """Returns a generator (i.e. iterator) that goes over the current files list and returns + batches of the corresponding data.""" + # Iterate over shower_type + self.logger.debug("\n>>> __iter__ called\n") + self.file_indices = np.array([0, self.n_files_at_once]) + + # shuffle the file list + if self.shuffle_files: + self.logger.info(">>> Shuffling files") + random.shuffle(self.file_list) + # self.logger.info(">>> self.file_list:") + # for filename in self.file_list: + # self.logger.info(f" - {filename}") + + # Iterate over files + for j in range(self.file_iterations): + self.logger.debug(20 * "-") + # Increment file index if not first iteration + if j > 0: + self.logger.info(">>> Incrementing file index") + self.file_indices += self.n_files_at_once + + # stop the iteration if self.file_indices[1] is larger than the number of files + # FIXME: this means that the last batch of files (in case the number of files is not + # divisible by self.n_files_at_once) is not used --> fix this + # but if shuffling is used, this should not be a problem + if self.file_indices[1] <= len(self.file_list): + self.load_next_files() + + # loop over the current data + for i in range(self.start_idx_this_gpu, self.end_idx_this_gpu): + yield { + "part_features": self.current_part_data[i], + "part_mask": self.current_part_mask[i], + } + + def __iter__(self): + """Returns an iterable which represents an iterator that iterates over the dataset.""" + # get current global rank to make sure the logger is set up correctly and displays + # the rank in the logs + self.multi_gpu_info["process_rank"] = get_rank() if dist.is_initialized() else 0 + self.setup_logger(rank=self.multi_gpu_info["process_rank"]) + self.logger.info(">>> __iter__(self.get_data()) called") + return iter(self.get_data()) + + def set_indices_for_this_rank(self): + """Set the start and end indices to load for this rank.""" + # set the indices to load for each gpu + if self.multi_gpu_info["world_size"] > 1: + # split the self.current_part_data over the gpus + n_shower = len(self.current_part_data) + n_shower_per_gpu = n_shower // self.multi_gpu_info["world_size"] + self.start_idx_this_gpu = n_shower_per_gpu * self.multi_gpu_info["process_rank"] + self.end_idx_this_gpu = n_shower_per_gpu * (self.multi_gpu_info["process_rank"] + 1) + else: + self.start_idx_this_gpu = 0 + self.end_idx_this_gpu = len(self.current_part_data) + + self.logger.info( + f"Rank {self.multi_gpu_info['process_rank']} will load data from index " + f"{self.start_idx_this_gpu} to {self.end_idx_this_gpu}" + ) + + def load_next_files(self): + if self.load_only_once: + if self.current_part_data is not None: + self.logger.warning("Data has already been loaded. Will not load again.") + self.shuffle_current_data() + return + if self.processed_files_counter > 0: + self.logger.info( + f"self.processed_files_counter={self.processed_files_counter} is larger than 0 " + f"and smaller than the total number of files in the dataset ({len(self.file_list)})." + " This means that the files list was not fully traversed in the previous " + "iteration. Will continue with the current files list." + ) + self.part_data_list = [] + self.mask_data_list = [] + self.shower_type_labels_list = [] + + self.current_files = self.file_list[self.file_indices[0] : self.file_indices[1]] + self.logger.info(f">>> Loading next files - self.file_indices={self.file_indices}") + if self.load_only_once: + self.logger.warning( + "Loading data only once. Will not load again.\n" + "--> This will be the data for all iterations." + ) + for i_file, filename in enumerate(self.current_files): + self.logger.info(f"{i_file+1} / {len(self.current_files)} : {filename}") + self.logger.info(f"Loading data from file: {filename}") + self.logger.info(f"Is the Data loaded from h5 file?: {self.h5file}") + + # This Part will be used if you want to load an h5 file: + if self.h5file: + # this part will be used if you want to load a parquet file of tokens + if self.token_id_cfg is not None: + self.logger.info("Loading tokenized shower file") + tokens = read_tokenized_shower_file( + filename, + particle_features=["part_token_id"], + remove_start_token=self.token_id_cfg.get("remove_start_token", False), + remove_end_token=self.token_id_cfg.get("remove_end_token", False), + shift_tokens_minus_one=self.token_id_cfg.get( + "shift_tokens_minus_one", False + ), + n_load=self.n_shower_per_file, + random_seed=self.random_seed_for_per_file_shuffling, + ) + self.logger.info(f"the tokens are: {tokens}") + ak_x_particles = ak.Array( + { + "part_token_id": tokens["part_token_id"], + "part_token_id_without_last": tokens["part_token_id"][:, :-1], + "part_token_id_without_first": tokens["part_token_id"][:, 1:], + } + ) + self.logger.info(f"the ak_x_particles is: {ak_x_particles}") + ak_x_particles = ak_preprocess(ak_x_particles, self.feature_dict) + self.logger.info("the Data was successfully preprocessed") + ak_x_particles_padded, ak_mask_particles = ak_pad( + ak_x_particles, self.pad_length, return_mask=True + ) + self.logger.info("the Data was successfully padded") + np_x_particles_padded = ak_to_np_stack( + ak_x_particles_padded, names=self.particle_features_list + ) + self.logger.info("the Data was successfully stacked to numpy") + # mask to numpy + np_mask_particles = ak.to_numpy(ak_mask_particles) + # add the data to the lists + self.part_data_list.append(torch.tensor(np_x_particles_padded)) + self.mask_data_list.append(torch.tensor(np_mask_particles, dtype=torch.bool)) + + else: + # this part is for loading an h5 file in the shower_format + # load data (only the amount of showers defined if n_shower_per_file is set) + data_showers = read_shower_file( + filename, + n_load=self.n_shower_per_file, + ) + if self.energy_sorting: + # Sort the showers by energy + sorted_energy = ak.argsort(data_showers.energy, axis=1, ascending=False) + # Update data_showers with sorted energy + data_showers = data_showers[sorted_energy] + + # Applying the Padding and getting the mask + ak_x_particles_padded, ak_mask_particles = ak_padding( + data_showers, self.pad_length, self.energy_threshold + ) + # Apply the mask to the energy field using element-wise multiplication + ak_x_particles_padded["energy"] = ak.where( + ~ak_mask_particles, + 1, + ak_x_particles_padded["energy"], + ) + # shape = ak.to_numpy(ak_x_particles_padded["energy"]).shape + # shape_mask = ak.to_numpy(ak_mask_particles).shape + # self.logger.info( + # f"Shape {shape} of the padded file and mask {shape_mask} with the pad_length of {self.pad_length}" + # ) + + # preprocessing the data + ak_x_particles = ak_preprocess(ak_x_particles_padded, self.feature_dict) + self.logger.info("ak_preprocess ran successfully") + + # Define a constant for the large negative value + + self.logger.info("ak.where ran successfully") + + np_x_particles_padded = ak_to_np_stack( + ak_x_particles, names=self.particle_features_list + ) + # mask to numpy + np_mask_particles = ak.to_numpy(ak_mask_particles) + # add the data to the lists + self.part_data_list.append(torch.tensor(np_x_particles_padded)) + self.mask_data_list.append(torch.tensor(np_mask_particles, dtype=torch.bool)) + gc.collect() + # self.shower_type_labels_list.append(torch.tensor(np_shower_type_labels)) + + # mask = torch.any(torch.tensor(data_showers) != 0, dim=-1) + # self.logger.info("this is the mask shape:", mask.shape) + # Add the data to the lists (adjust based on data structure) + # self.part_data_list.append(torch.tensor(data_showers)) + # self.mask_data_list.append(torch.tensor(mask)) + + # This part would be used if you want to load a root file: + else: + pass + + # concatenate the data from all files + self.current_part_data = torch.cat(self.part_data_list, dim=0) + self.current_part_mask = torch.cat(self.mask_data_list, dim=0) + if not self.h5file: + self.current_shower_type_labels_one_hot = torch.cat( + self.shower_type_labels_list, dim=0 + ) + + self.shuffle_current_data() + + self.logger.info( + f">>> Data loaded. (self.current_part_data.shape = {self.current_part_data.shape})" + ) + self.set_indices_for_this_rank() + + self.processed_files_counter += self.n_files_at_once + + self.logger.info( + "Updating self.processed_files_counter. The new value is " + f"self.processed_files_counter = {self.processed_files_counter}." + ) + self.logger.info( + "Checking if all files in the current files list have been processed. " + "If so, the file list will be shuffled (unless `shuffle_files=False`)" + "such that the next iteration will proceed with a new file list." + ) + + def shuffle_current_data(self): + # shuffle the data + if self.shuffle_only_once and self.data_shuffled: + self.logger.info("Data has already been shuffled. Will not shuffle again.") + return + if self.shuffle_data: + if self.seed_shuffle_data is not None: + self.logger.info(f"Shuffling data with seed {self.seed_shuffle_data}") + rng = np.random.default_rng(self.seed_shuffle_data) + else: + self.logger.info("Shuffling data without seed") + rng = np.random.default_rng() + perm = rng.permutation(len(self.current_part_data)) + self.current_part_data = self.current_part_data[perm] + self.current_part_mask = self.current_part_mask[perm] + if not self.h5file: + self.current_shower_type_labels_one_hot = self.current_shower_type_labels_one_hot[ + perm + ] + self.data_shuffled = True + self.logger.info("Data shuffled.") + + +class IterableCaloDatamodule(L.LightningDataModule): + def __init__( + self, + dataset_kwargs_train: dict, + dataset_kwargs_val: dict, + dataset_kwargs_test: dict, + dataset_kwargs_common: dict, + batch_size: int = 256, + **kwargs, + ): + super().__init__() + + # save the parameters as attributes + self.save_hyperparameters() + + def prepare_data(self) -> None: + """Prepare the data.""" + pass + + def setup(self, stage: Optional[str] = None) -> None: + if stage == "fit": + self.train_dataset = CustomIterableDataset( + **self.hparams.dataset_kwargs_train, + **self.hparams.dataset_kwargs_common, + ) + self.val_dataset = CustomIterableDataset( + **self.hparams.dataset_kwargs_val, + **self.hparams.dataset_kwargs_common, + ) + elif stage == "test": + self.test_dataset = CustomIterableDataset( + **self.hparams.dataset_kwargs_test, + **self.hparams.dataset_kwargs_common, + ) + + def train_dataloader(self): + # Use a DistributedSampler for multi-GPU training + return DataLoader( + self.train_dataset, + batch_size=self.hparams.batch_size, + # pin_memory=True, # Pre-transfer data to pinned memory + # num_workers=2, + ) + + def val_dataloader(self): + # Optionally use a sampler for validation + return DataLoader( + self.val_dataset, + batch_size=self.hparams.batch_size, + # pin_memory=True, + # num_workers=2, + ) + + def test_dataloader(self): + # Optionally use a sampler for testing + return DataLoader( + self.test_dataset, + batch_size=self.hparams.batch_size, + # pin_memory=True, + # num_workers=2, + ) diff --git a/gabbro/data/loading.py b/gabbro/data/loading.py new file mode 100644 index 0000000..1e63ce2 --- /dev/null +++ b/gabbro/data/loading.py @@ -0,0 +1,121 @@ +import logging + +import awkward as ak +import h5py +import numpy as np +import vector + +logger = logging.getLogger(__name__) + +vector.register_awkward() + + +def read_tokenized_shower_file( + filepath, + particle_features=["part_token_id"], + remove_start_token=False, + remove_end_token=False, + shift_tokens_minus_one=False, + n_load=None, + random_seed=None, +): + """Reads a file containing a list of file paths. + + Parameters: + ---------- + filepath : str + Path to the file. + particle_features : List[str], optional + A list of particle-level features to be loaded. Should only contain "part_token_id". + labels : List[str], optional + A list of truth labels to be loaded. + remove_start_token : bool, optional + Whether to remove the start token from the tokenized sequence. + remove_end_token : bool, optional + Whether to remove the end token from the tokenized sequence. + shift_tokens_minus_one : bool, optional + Whether to shift the token values by -1. + n_load : int, optional + Number of events to load. If None, all events are loaded. + random_seed : int, optional + Random seed for shuffling the data. If None, no shuffling is performed. + + + Returns: + ------- + tokens : List[str] + A list of file paths. + """ + + ak_tokens = ak.from_parquet(filepath) + + if random_seed is not None: + print(f"shuffling with random seed {random_seed}") + rng = np.random.default_rng(random_seed) + permutation = rng.permutation(len(ak_tokens)) + print("ak_tokens", ak_tokens) + ak_tokens = ak_tokens[permutation] + print("ak_tokens after permutation", ak_tokens) + + if n_load is not None: + print(f"will only load tokens of {n_load} events") + ak_tokens = ak_tokens[:n_load] + + # one-hot encode the shower type + + if remove_start_token: + ak_tokens = ak_tokens[:, 1:] + if remove_end_token: + ak_tokens = ak_tokens[:, :-1] + if shift_tokens_minus_one: + ak_tokens = ak_tokens - 1 + + x_ak = ak.Array({"part_token_id": ak_tokens}) + + return x_ak + + +def read_shower_file(filepath, n_load=None, chunk_size=1000): + """Loads a single file from the showerClass dataset. + + Parameters: + ---------- + filepath : str + Path to the h5 data file. + n_load : int, optional + Number of showers to load. If None, load all data. + chunk_size : int, optional + Size of chunks to load at a time. Default is 1000. + + Returns: + ------- + shower : ak.Array + An awkward array of the shower features. + """ + with h5py.File(filepath, "r") as h5file: + dataset_showers = h5file["showers"] + total_showers = dataset_showers.shape[0] + + if n_load is None: + n_load = total_showers + + data_dict = { + "x": [], + "y": [], + "z": [], + "energy": [], + } + + for start in range(0, n_load, chunk_size): + end = min(start + chunk_size, n_load) + table = dataset_showers[start:end] + + data_dict["x"].append(table[:, :, 0]) + data_dict["y"].append(table[:, :, 1]) + data_dict["z"].append(table[:, :, 2]) + data_dict["energy"].append(table[:, :, 3]) + + # Concatenate chunks + data_dict = {key: ak.concatenate(value) for key, value in data_dict.items()} + ak_array = ak.Array(data_dict) + return ak_array diff --git a/gabbro/metrics/jet_substructure.py b/gabbro/metrics/jet_substructure.py new file mode 100644 index 0000000..bb1eda6 --- /dev/null +++ b/gabbro/metrics/jet_substructure.py @@ -0,0 +1,279 @@ +"""Module with functions related to calculating jet substructure.""" + +import logging +import os + +import awkward as ak +import fastjet +import h5py +import numpy as np +import vector + +vector.register_awkward() + +pylogger = logging.getLogger("jet_substructure") +logging.basicConfig(level=logging.INFO) + + +def calc_deltaR(particles, jet): + jet = ak.unflatten(ak.flatten(jet), counts=1) + return particles.deltaR(jet) + + +class JetSubstructure: + """Class to calculate and store the jet substructure variables. + + Definitions as in slide 7 here: + https://indico.cern.ch/event/760557/contributions/3262382/attachments/1796645/2929179/lltalk.pdf + """ + + def __init__( + self, + particles, + R=0.8, + beta=1.0, + use_wta_pt_scheme=False, + ): + """Run the jet clustering and calculate the substructure variables. The clustering is + performed with the kt algorithm and the WTA pt scheme. + + Parameters + ---------- + particles : awkward array + The particles that are clustered into jets. Have to be vector Momentum4D objects + R : float, optional + The jet radius, by default 0.8 + beta : float, optional + The beta parameter for N-subjettiness, by default 1.0 + use_wta_pt_scheme : bool, optional + Whether to use the WTA pt scheme for the clustering, by default False + """ + + print(f"Calculating substructure for {len(particles)} jets") + mask_too_few_particles = ak.num(particles) < 3 + n_jets_with_nparticles_too_small = ak.sum(mask_too_few_particles) + if n_jets_with_nparticles_too_small > 0: + print(f"There are {n_jets_with_nparticles_too_small} jets with less than 3 particles.") + raise ValueError("Jets with too few particles are not allowed.") + + self.R = R + self.beta = beta + self.particles = particles + self.particles_sum = ak.sum(particles, axis=1) + self.jet_mass = self.particles_sum.mass + self.jet_pt = self.particles_sum.pt + self.jet_eta = self.particles_sum.eta + self.jet_phi = self.particles_sum.phi + self.jet_n_constituents = ak.num(particles) + + if use_wta_pt_scheme: + jetdef = fastjet.JetDefinition(fastjet.kt_algorithm, self.R, fastjet.WTA_pt_scheme) + else: + jetdef = fastjet.JetDefinition(fastjet.kt_algorithm, self.R) + print("Clustering jets with fastjet") + print("Jet definition:", jetdef) + self.cluster = fastjet.ClusterSequence(particles, jetdef) + self.inclusive_jets = self.cluster.inclusive_jets() + self.exclusive_jets_1 = self.cluster.exclusive_jets(n_jets=1) + self.exclusive_jets_2 = self.cluster.exclusive_jets(n_jets=2) + self.exclusive_jets_3 = self.cluster.exclusive_jets(n_jets=3) + + print("Calculating N-subjettiness") + self._calc_d0() + self._calc_tau1() + self._calc_tau2() + self._calc_tau3() + self.tau21 = self.tau2 / self.tau1 + self.tau32 = self.tau3 / self.tau2 + print("Calculating D2") + # D2 as defined in https://arxiv.org/pdf/1409.6298.pdf + self.d2 = self.cluster.exclusive_jets_energy_correlator(njets=1, func="d2") + + def _calc_d0(self): + """Calculate the d0 values.""" + self.d0 = ak.sum(self.particles.pt * self.R**self.beta, axis=1) + + def _calc_tau1(self): + """Calculate the tau1 values.""" + self.delta_r_1i = calc_deltaR(self.particles, self.exclusive_jets_1[:, :1]) + self.pt_i = self.particles.pt + # calculate the tau1 values + self.tau1 = ak.sum(self.pt_i * self.delta_r_1i**self.beta, axis=1) / self.d0 + + def _calc_tau2(self): + """Calculate the tau2 values.""" + delta_r_1i = calc_deltaR(self.particles, self.exclusive_jets_2[:, :1]) + delta_r_2i = calc_deltaR(self.particles, self.exclusive_jets_2[:, 1:2]) + self.pt_i = self.particles.pt + # add new axis to make it broadcastable + min_delta_r = ak.min( + ak.concatenate( + [ + delta_r_1i[..., np.newaxis] ** self.beta, + delta_r_2i[..., np.newaxis] ** self.beta, + ], + axis=-1, + ), + axis=-1, + ) + self.tau2 = ak.sum(self.pt_i * min_delta_r, axis=1) / self.d0 + + def _calc_tau3(self): + """Calculate the tau3 values.""" + delta_r_1i = calc_deltaR(self.particles, self.exclusive_jets_3[:, :1]) + delta_r_2i = calc_deltaR(self.particles, self.exclusive_jets_3[:, 1:2]) + delta_r_3i = calc_deltaR(self.particles, self.exclusive_jets_3[:, 2:3]) + self.pt_i = self.particles.pt + min_delta_r = ak.min( + ak.concatenate( + [ + delta_r_1i[..., np.newaxis] ** self.beta, + delta_r_2i[..., np.newaxis] ** self.beta, + delta_r_3i[..., np.newaxis] ** self.beta, + ], + axis=-1, + ), + axis=-1, + ) + self.tau3 = ak.sum(self.pt_i * min_delta_r, axis=1) / self.d0 + + def get_substructure_as_ak_array(self): + """Return the substructure variables as a dictionary.""" + return ak.Array( + { + "tau1": self.tau1, + "tau2": self.tau2, + "tau3": self.tau3, + "tau21": self.tau21, + "tau32": self.tau32, + "d2": self.d2, + "jet_mass": self.jet_mass, + "jet_pt": self.jet_pt, + "jet_eta": self.jet_eta, + "jet_phi": self.jet_phi, + "jet_n_constituents": self.jet_n_constituents, + } + ) + + +def calc_substructure( + particles_sim, + particles_gen, + R=0.8, + filename=None, +): + """Calculate the substructure variables for the given particles and save them to a file. + + Parameters + ---------- + particles_sim : awkward array + The particles of the simulated jets. + particles_gen : awkward array + The particles of the generated jets. + R : float, optional + The jet radius, by default 0.8 + filename : str, optional + The filename to save the results to, by default None (don't save) + """ + if filename is None: + print("No filename given, won't save the results.") + else: + if os.path.exists(filename): + print(f"File {filename} already exists, won't overwrite.") + return + print(f"Saving results to {filename}") + + substructure_sim = JetSubstructure(particles_sim, R=R) + substructure_gen = JetSubstructure(particles_gen, R=R) + names = [ + "tau1", + "tau2", + "tau3", + "tau21", + "tau32", + "d2", + "jet_mass", + "jet_pt", + ] + with h5py.File(filename, "w") as f: + for name in names: + f[f"{name}_sim"] = substructure_sim.__dict__[name] + f[f"{name}_gen"] = substructure_gen.__dict__[name] + + +def load_substructure_data( + h5_file_path, + keys=["tau1", "tau2", "tau3", "tau21", "tau32", "d2", "jet_mass", "jet_pt"], +): + """Load the substructure data from the h5 file. + + Args: + h5_file_path (str): Path to the h5 file + keys (list, optional): List of keys to load from the h5 file. Defaults to ["tau1", "tau2", "tau3", "tau21", "tau32", "d2", "jet_mass", "jet_pt"]. + + Returns: + data_gen: Array of shape (n_features, n_jets) with the substructure data for the generated jets + data_jetclass: Array of shape (n_features, n_jets) with the substructure data for the JetClass jets + """ + + # load substructure for model generated data + data_substructure = [] + data_substructure_jetclass = [] + with h5py.File(h5_file_path) as f: + tau21 = np.array(f["tau21_gen"]) + tau32 = np.array(f["tau32_gen"]) + d2 = np.array(f["d2_gen"]) + # jet_mass = np.array(f["jet_mass_gen"]) + # jet_pt = np.array(f["jet_pt_gen"]) + tau21_isnan = np.isnan(tau21) + tau32_isnan = np.isnan(tau32) + d2_isnan = np.isnan(d2) + if np.sum(tau21_isnan) > 0 or np.sum(tau32_isnan) > 0 or np.sum(d2_isnan) > 0: + pylogger.warning(f"Found {np.sum(tau21_isnan)} nan values in tau21") + pylogger.warning(f"Found {np.sum(tau32_isnan)} nan values in tau32") + pylogger.warning(f"Found {np.sum(d2_isnan)} nan values in d2") + pylogger.warning("Setting nan values to zero.") + tau21[tau21_isnan] = 0 + tau32[tau32_isnan] = 0 + d2[d2_isnan] = 0 + for key in keys: + data_substructure.append(np.array(f[key + "_gen"])) + # set nan values in tau21, tau32 and d2 to zero + data_substructure[keys.index("tau21")][tau21_isnan] = 0 + data_substructure[keys.index("tau32")][tau32_isnan] = 0 + data_substructure[keys.index("d2")][d2_isnan] = 0 + + # load substructure for JetClass data + tau21_jetclass = np.array(f["tau21_sim"]) + tau32_jetclass = np.array(f["tau32_sim"]) + d2_jetclass = np.array(f["d2_sim"]) + # jet_mass_jetclass = np.array(f["jet_mass_sim"]) + # jet_pt_jetclass = np.array(f["jet_pt_sim"]) + tau21_jetclass_isnan = np.isnan(tau21_jetclass) + tau32_jetclass_isnan = np.isnan(tau32_jetclass) + d2_jetclass_isnan = np.isnan(d2_jetclass) + if ( + np.sum(tau21_jetclass_isnan) > 0 + or np.sum(tau32_jetclass_isnan) > 0 + or np.sum(d2_jetclass_isnan) > 0 + ): + pylogger.warning(f"Found {np.sum(tau21_jetclass_isnan)} nan values in tau21") + pylogger.warning(f"Found {np.sum(tau32_jetclass_isnan)} nan values in tau32") + pylogger.warning(f"Found {np.sum(d2_jetclass_isnan)} nan values in d2") + pylogger.warning("Setting nan values to zero.") + tau21_jetclass[tau21_jetclass_isnan] = 0 + tau32_jetclass[tau32_jetclass_isnan] = 0 + d2_jetclass[d2_jetclass_isnan] = 0 + for key in keys: + data_substructure_jetclass.append(np.array(f[key + "_sim"])) + # set nan values in tau21, tau32 and d2 to zero + data_substructure_jetclass[keys.index("tau21")][tau21_jetclass_isnan] = 0 + data_substructure_jetclass[keys.index("tau32")][tau32_jetclass_isnan] = 0 + data_substructure_jetclass[keys.index("d2")][d2_jetclass_isnan] = 0 + + data_substructure = np.array(data_substructure) + data_substructure_jetclass = np.array(data_substructure_jetclass) + return ( + data_substructure, + data_substructure_jetclass, + ) diff --git a/gabbro/metrics/utils.py b/gabbro/metrics/utils.py new file mode 100644 index 0000000..91fb98d --- /dev/null +++ b/gabbro/metrics/utils.py @@ -0,0 +1,115 @@ +"""Utility functions for metrics.""" + +import numpy as np +import scipy + + +def quantiled_kl_divergence( + sample_ref: np.ndarray, + sample_approx: np.ndarray, + n_bins: int = 30, + return_bin_edges=False, +): + """Calculate the KL divergence using quantiles on sample_ref to define the bounds. + + Parameters + ---------- + sample_ref : np.ndarray + The first sample to compare (this is the reference, so in the context of + jet generation, those are the real jets). + sample_approx : np.ndarray + The second sample to compare (this is the model/approximation, so in the + context of jet generation, those are the generated jets). + n_bins : int + The number of bins to use for the histogram. Those bins are defined by + equiprobably quantiles of sample_ref. + return_bin_edges : bool, optional + If True, return the bins used to calculate the KL divergence. + """ + bin_edges = np.quantile(sample_ref, np.linspace(0.0, 1.0, n_bins + 1)) + bin_edges[0] = float("-inf") + bin_edges[-1] = float("inf") + pk = np.histogram(sample_ref, bin_edges)[0] / len(sample_ref) + qk = np.histogram(sample_approx, bin_edges)[0] / len(sample_approx) + qk += 1e-10 # Add small constant to avoid 0 count + kl = scipy.stats.entropy(pk, qk) + if return_bin_edges: + return kl, bin_edges + return kl + + +def calc_quantiled_kl_divergence_for_dict( + dict_reference: dict, + dict_approx: dict, + names: list, + n_bins: int = 30, +): + """Calculate the quantiled KL divergence for two dictionaries of samples. + + Parameters + ---------- + dict_reference : dict + The first dictionary of samples. + dict_approx : dict + The second dictionary of samples. + names : list + The names of the samples to compare. All names must be included in both dicts. + """ + # loop over the names and calculate the quantiled kld for each name + klds = {} + for name in names: + klds[name] = quantiled_kl_divergence( + sample_ref=np.array(dict_reference[name]), + sample_approx=np.array(dict_approx[name]), + n_bins=n_bins, + ) + return klds + + +def calc_accuracy(preds, labels, verbose=False): + """Calculates accuracy and AUC. + + Parameters + ---------- + preds : array-like + Classifier scores. Tensor of shape (n_samples, n_classes). + labels : array-like + Array with the true labels (one-hot encoded). Tensor of shape (n_samples, n_classes). + + Returns + ------- + accuracy : float + Accuracy. + """ + accuracy = (np.argmax(preds, axis=1) == np.argmax(labels, axis=1)).mean() + + return accuracy + + +def calc_rejection(scores, labels, verbose=False, sig_eff=0.3): + """Calculates the R30 metric. + + Parameters + ---------- + scores : array-like + Classifier scores (probability of being signal). Array of shape (n_samples,). + labels : array-like + Array with the true labels (0 or 1). Array of shape (n_samples,). + sig_eff : float, optional + Signal efficiency at which to calculate the rejection. + + Returns + ------- + rejection : float + Rejection metric value. + cut_value : float + Cut value for this rejection. + """ + is_signal = labels == 1 + cut_value = np.percentile(scores[is_signal], 100 - sig_eff * 100) + background_efficiency = np.sum(scores[~is_signal] > cut_value) / np.sum(~is_signal) + if verbose: + print(f"cut_value = {cut_value}") + print(f"background_efficiency = {background_efficiency}") + rejection = 1 / background_efficiency + return rejection, cut_value diff --git a/gabbro/models/backbone.py b/gabbro/models/backbone.py new file mode 100644 index 0000000..679ebd8 --- /dev/null +++ b/gabbro/models/backbone.py @@ -0,0 +1,1144 @@ +"""Backbone model with different heads.""" + +import time +from typing import Any, Dict, Tuple + +import awkward as ak +import lightning as L +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +import vector +from tqdm import tqdm + +from gabbro.metrics.utils import calc_accuracy +from gabbro.models.gpt_model import BackboneModel +from gabbro.utils.arrays import fix_padded_logits +from gabbro.utils.pylogger import get_pylogger + +vector.register_awkward() + +logger = get_pylogger(__name__) + +# ------------------------------------------------------------------------- +# ------------ BACKBONE + Generative (next-token-prediction) head --------- +# ------------------------------------------------------------------------- + + +class NextTokenPredictionHead(nn.Module): + """Head for predicting the next token in a sequence.""" + + def __init__(self, embedding_dim, vocab_size): + super().__init__() + self.fc1 = nn.Linear(embedding_dim, vocab_size) + + def forward(self, x): + return self.fc1(x) + + +class BackboneNextTokenPredictionLightning(L.LightningModule): + """PyTorch Lightning module for training the backbone model.""" + + def __init__( + self, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler = None, + model_kwargs={}, + token_dir=None, + verbose=False, + exclude_padded_values_from_loss: bool = True, + **kwargs, + ) -> None: + super().__init__() + self.save_hyperparameters(logger=False) + + # initialize the backbone + self.module = BackboneModel(**model_kwargs) + + # initialize the model head + self.head = NextTokenPredictionHead( + embedding_dim=model_kwargs["embedding_dim"], + vocab_size=model_kwargs["vocab_size"], + ) + + # initialize the loss function + # self.criterion = nn.CrossEntropyLoss() + self.criterion = nn.CrossEntropyLoss( + weight=torch.tensor( + [1.0] * (model_kwargs["vocab_size"] - 1) + [model_kwargs["stop_token_weight"]] + ).to(self.device) + ) + + self.token_dir = token_dir + self.verbose = verbose + + self.train_loss_history = [] + self.val_loss_list = [] + + self.validation_cnt = 0 + self.validation_output = {} + + self.backbone_weights_path = model_kwargs.get("backbone_weights_path", "None") + + self.pylogger = get_pylogger(__name__) + self.pylogger.info(f"Backbone weights path: {self.backbone_weights_path}") + + if self.backbone_weights_path is not None: + if self.backbone_weights_path != "None": + self.load_backbone_weights(self.backbone_weights_path) + + def load_backbone_weights(self, ckpt_path): + self.pylogger.info(f"Loading backbone weights from {ckpt_path}") + ckpt = torch.load(ckpt_path) + # print(ckpt) + state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt + + # print("state_dict",state_dict) + loaded_state_dict = {k: v for k, v in state_dict.items() if "tril" not in k} + self.load_state_dict(loaded_state_dict, strict=False) + print("Backbone weights loaded") + + def forward(self, x, mask=None): + if self.module.return_embeddings: + backbone_out = self.module(x, mask) + logits = self.head(backbone_out) + else: + logits = self.module(x, mask) + if self.verbose: + self.pylogger.info("Logits shape: ", logits.shape) + return logits + + def model_step(self, batch, return_logits=False): + """Perform a single model step on a batch of data. + + Parameters + ---------- + batch : dict + A batch of data as a dictionary containing the input and target tensors, + as well as the mask. + return_logits : bool, optional + Whether to return the logits or not. (default is False) + """ + + # all token-ids up to the last one are the input, the ones from the second + # to the (including) last one are the target + # this model step uses the convention that the first particle feature + # is the token, with the tokens up to the last one + # the second particle feature is the target token (i.e. the next token) + + X = batch["part_features"] + X = X.squeeze().long() + input = X[:, :, 0] + targets = X[:, :, 1] + mask = batch["part_mask"] + + # Add print statements after data extraction + + # compute the logits (i.e. the predictions for the next token) + logits = self.forward(input, mask) + + if self.hparams.exclude_padded_values_from_loss: + logits = fix_padded_logits(logits, mask, factor=1e6) + + # reshape the logits and targets to work with the loss function + B, T, C = logits.shape + logits = logits.view(B * T, C) + targets = targets.contiguous().view(B * T) + + loss = self.criterion(logits, targets) + + if return_logits: + return loss, X, logits, mask, targets + + return loss + + @torch.no_grad() + def generate_batch(self, batch_size): + """Generate a batch of shower constituents autoregressively, stopping generation for each + sequence individually when the stop token is encountered.""" + device = next(self.module.parameters()).device + idx = torch.zeros(batch_size, 1).long().to(device) + stop_token_id = self.module.vocab_size - 2 + completed_sequences = torch.zeros(batch_size, dtype=torch.bool).to(device) + + for i in range(self.module.max_sequence_len): + if torch.all(completed_sequences): + break # Stop if all sequences have generated the stop token + + logits = self(idx) + self.pylogger.info( + "Logit shape input for generation: ", logits.shape + ) if self.verbose else None + logits = logits[:, -1, :] + + logits = logits / self.module.temperature + probs = F.softmax(logits[:, 1:], dim=-1) + + stop_token_probs = probs[:, stop_token_id] + + stop_token_probs[stop_token_probs < self.module.stop_token_threshold] = 0 + probs[:, stop_token_id] = stop_token_probs + + probs_sum = probs.sum(dim=-1, keepdim=True) + probs = probs / probs_sum + + idx_next = torch.multinomial(probs, num_samples=1) + 1 + idx = torch.cat((idx, idx_next), dim=1) + self.pylogger.info( + "appended idx_next to original idx, shape: ", idx.shape + ) if self.verbose else None + + completed_sequences = completed_sequences | (idx_next.squeeze(-1) == stop_token_id + 1) + + # No need to truncate, sequences stopped naturally + gen_batch_np = idx.detach().cpu().numpy() + gen_batch_ak = ak.from_numpy(gen_batch_np) + gen_batch_until_stop = [] + + # loop over the showers in the batch, and only keep the tokens until the stop token + for shower in gen_batch_ak: + stop_token_position = np.where(shower == self.module.vocab_size - 1) + if len(stop_token_position[0]) > 0: + stop_token_position = stop_token_position[0][0] + else: + stop_token_position = shower.shape[0] + gen_batch_until_stop.append(shower[:stop_token_position]) + + return ak.Array(gen_batch_until_stop) + + def generate_n_showers_batched(self, n_showers, batch_size, saveas=None): + """Generate showers in batches. + + Parameters + ---------- + n_showers : int + Number of showers to generate. + batch_size : int + Batch size to use during generation (use as large as possible with memory.) + saveas : str, optional + Path to save the generated showers to (in parquet format). (default is None) + + Returns + ------- + ak.Array + The generated showers (i.e. their token ids, in the shape (n_showers, ). + """ + n_batches = n_showers // batch_size + generated_showers = [] + + self.pylogger.info( + f"Generating {n_showers} showers in {n_batches} batches of size {batch_size}" + ) + + for i in tqdm(range(n_batches)): + gen_batch_ak = self.generate_batch(batch_size) + generated_showers.append(gen_batch_ak) + + # concatenate the generated batches + generated_showers = ak.concatenate(generated_showers)[:n_showers] + + if saveas is not None: + self.pylogger.info(f"Saving generated showers to {saveas}") + ak.to_parquet(generated_showers, saveas) + + return generated_showers + + def training_step(self, batch, batch_idx: int) -> torch.Tensor: + """Perform a single training step on a batch of data from the training set.""" + loss = self.model_step(batch) + + self.train_loss_history.append(float(loss)) + self.log( + "train_loss", loss.item(), on_step=True, on_epoch=True, prog_bar=True, sync_dist=True + ) + + return loss + + def on_train_start(self) -> None: + self.pylogger.info("`on_train_start` called.") + self.pylogger.info("Setting up the logger with the correct rank.") + self.pylogger = get_pylogger(__name__, rank=self.trainer.global_rank) + self.pylogger.info("Logger set up.") + + self.preprocessing_dict = ( + self.trainer.datamodule.hparams.dataset_kwargs_common.feature_dict + ) + + def on_train_epoch_start(self): + self.pylogger.info("`on_train_epoch_start` called.") + self.pylogger.info(f"Epoch {self.trainer.current_epoch} starting.") + self.epoch_train_start_time = time.time() # start timing the epoch + + def on_train_epoch_end(self): + self.pylogger.info("`on_train_epoch_end` called.") + self.epoch_train_end_time = time.time() + self.epoch_train_duration_minutes = ( + self.epoch_train_end_time - self.epoch_train_start_time + ) / 60 + self.log( + "epoch_train_duration_minutes", + self.epoch_train_duration_minutes, + on_epoch=True, + prog_bar=False, + sync_dist=True, + ) + if len(self.train_loss_history) > 0: + self.pylogger.info( + f"Epoch {self.trainer.current_epoch} finished in" + f" {self.epoch_train_duration_minutes:.1f} minutes. " + f"Current step: {self.global_step}. Current loss: {self.train_loss_history[-1]}." + f" rank: {self.global_rank}" + ) + if dist.is_initialized(): + dist.barrier() + self.pylogger.info("Barrier at epoch end.") + + def on_train_end(self) -> None: + self.pylogger.info("`on_train_end` called.") + + def on_validation_epoch_start(self) -> None: + self.pylogger.info("`on_validation_epoch_start` called.") + self.val_token_ids_list = [] + self.val_token_masks_list = [] + + def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None: + loss, X, logits, mask, targets = self.model_step(batch, return_logits=True) + + self.val_token_ids_list.append(batch["part_features"].float().detach().cpu().numpy()) + self.val_token_masks_list.append(batch["part_mask"].float().detach().cpu().numpy()) + self.log( + "val_loss", loss.item(), on_step=True, on_epoch=True, prog_bar=True, sync_dist=True + ) + # self.log("batch_idx", batch["part_features"], on_step=True, on_epoch=True, prog_bar=True) + # self.pylogger.info(f"first_batch {batch['part_features'][0]}") + # self.pylogger.info("val_token_ids_list", self.val_token_ids_list, on_step=True, on_epoch=True, prog_bar=True) + + return loss + + def on_test_epoch_start(self) -> None: + self.pylogger.info("`on_test_epoch_start` called.") + + def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None: + loss, X, logits, mask, targets = self.model_step(batch, return_logits=True) + self.log("test_loss", loss.item(), on_step=True, on_epoch=True, prog_bar=True) + + def on_validation_epoch_end(self) -> None: + """Lightning hook that is called when a validation epoch ends.""" + self.pylogger.info("`on_validation_epoch_end` called.") + + def on_test_epoch_end(self): + self.pylogger.info("`on_test_epoch_end` called.") + + def configure_optimizers(self) -> Dict[str, Any]: + """Configures optimizers and learning-rate schedulers to be used for training.""" + self.pylogger.info("`configure_optimizers` called.") + optimizer = self.hparams.optimizer(params=self.parameters()) + if self.hparams.scheduler is not None: + scheduler = self.hparams.scheduler(optimizer=optimizer) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "monitor": "val_loss", + "interval": "epoch", + "frequency": 1, + }, + } + + return {"optimizer": optimizer} + + +# ------------------------------------------------------------------------- +# ------------------ BACKBONE + Classification head ----------------------- +# ------------------------------------------------------------------------- + + +class NormformerCrossBlock(nn.Module): + def __init__(self, input_dim, mlp_dim, num_heads, dropout_rate=0.0): + super().__init__() + self.input_dim = input_dim + self.num_heads = num_heads + self.dropout_rate = dropout_rate + + # define the MultiheadAttention layer with layer normalization + self.norm1 = nn.LayerNorm(input_dim) + self.attn = nn.MultiheadAttention( + input_dim, num_heads, batch_first=True, dropout=self.dropout_rate + ) + self.norm2 = nn.LayerNorm(input_dim) + + # define the MLP with layer normalization + self.mlp = nn.Sequential( + nn.LayerNorm(input_dim), # Add layer normalization + nn.Linear(input_dim, mlp_dim), + nn.SiLU(), + nn.Dropout(self.dropout_rate), + nn.Linear(mlp_dim, input_dim), + ) + + # initialize weights of mlp[-1] and layer norm after attn block to 0 + # such that the residual connection is the identity when the block is + # initialized + nn.init.zeros_(self.mlp[-1].weight) + nn.init.zeros_(self.mlp[-1].bias) + nn.init.zeros_(self.norm1.weight) + + def forward(self, x, class_token, mask=None, return_attn_weights=False): + # x: (B, S, F) + # mask: (B, S) + x = x * mask.unsqueeze(-1) + + # calculate cross-attention + x_norm = self.norm1(x) + attn_output, attn_weights = self.attn( + query=class_token, key=x_norm, value=x_norm, key_padding_mask=mask != 1 + ) + return attn_output + + +class ClassifierNormformer(torch.nn.Module): + def __init__( + self, + input_dim, + hidden_dim, + num_heads=4, + dropout_rate=0.0, + num_class_blocks=3, + model_kwargs={"n_out_nodes": 2, "fc_params": [(100, 0.1), (100, 0.1)]}, + **kwargs, + ): + super().__init__() + + self.model_kwargs = model_kwargs + self.n_out_nodes = model_kwargs["n_out_nodes"] + self.dropout_rate = dropout_rate + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.num_heads = num_heads + self.num_class_blocks = num_class_blocks + self.class_token = nn.Parameter(torch.randn(1, 1, hidden_dim)) + + self.input_projection = nn.Linear(self.input_dim, self.hidden_dim) + self.class_attention_blocks = nn.ModuleList( + [ + NormformerCrossBlock( + input_dim=self.hidden_dim, + num_heads=self.num_heads, + dropout_rate=self.dropout_rate, + mlp_dim=self.hidden_dim, + ) + for _ in range(self.num_class_blocks) + ] + ) + self.final_mlp = nn.Sequential( + nn.Linear(self.hidden_dim, self.hidden_dim), + nn.SiLU(), + nn.Dropout(self.dropout_rate), + nn.Linear(self.hidden_dim, self.hidden_dim), + nn.SiLU(), + nn.Dropout(self.dropout_rate), + nn.Linear(self.hidden_dim, self.model_kwargs["n_out_nodes"]), + ) + + self.loss_history = [] + self.lr_history = [] + + def forward(self, x, mask): + # expand class token and add to mask + class_token = self.class_token.expand(x.size(0), -1, -1) + mask_with_token = torch.cat([torch.ones(x.size(0), 1).to(x.device), mask], dim=1) + + # pass through class attention blocks, always use the updated class token + for block in self.class_attention_blocks: + x_class_token_and_x_encoded = torch.cat([class_token, x], dim=1) + # class_token = block(x_class_token_and_x_encoded, mask=mask_with_token)[:, :1, :] + class_token = block(x_class_token_and_x_encoded, class_token, mask=mask_with_token) + + # pass through final mlp + class_token = self.final_mlp(class_token).squeeze(1) + return class_token + + +class ClassificationHead(torch.nn.Module): + """Classification head for the backbone model.""" + + def __init__(self, model_kwargs={"n_out_nodes": 2}): + super().__init__() + self.backbone_weights_path = None + + if "n_out_nodes" not in model_kwargs: + model_kwargs["n_out_nodes"] = 2 + if "return_embeddings" not in model_kwargs: + model_kwargs["return_embeddings"] = True + + self.n_out_nodes = model_kwargs["n_out_nodes"] + model_kwargs.pop("n_out_nodes") + + self.classification_head_linear_embed = nn.Linear( + model_kwargs["embedding_dim"], + model_kwargs["embedding_dim"], + ) + self.classification_head_linear_class = nn.Linear( + model_kwargs["embedding_dim"], + self.n_out_nodes, + ) + + def forward(self, x, mask): + embeddings = F.relu(self.classification_head_linear_embed(x)) + embeddings_sum = torch.sum(embeddings * mask.unsqueeze(-1), dim=1) + logits = self.classification_head_linear_class(embeddings_sum) + return logits + + +class BackboneClassificationLightning(L.LightningModule): + """Backbone with classification head.""" + + def __init__( + self, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler = None, + class_head_type: str = "summation", + model_kwargs: dict = {}, + **kwargs, + ) -> None: + super().__init__() + self.save_hyperparameters(logger=False) + + # initialize the backbone + self.module = BackboneModel(**model_kwargs) + + # initialize the model head + if class_head_type == "summation": + self.head = ClassificationHead( + model_kwargs={ + "n_out_nodes": model_kwargs["n_out_nodes"], + "embedding_dim": model_kwargs["embedding_dim"], + } + ) + elif class_head_type == "class_attention": + self.head = ClassifierNormformer( + input_dim=model_kwargs["embedding_dim"], + hidden_dim=model_kwargs["embedding_dim"], + model_kwargs={"n_out_nodes": model_kwargs["n_out_nodes"]}, + num_heads=2, + num_class_blocks=3, + dropout_rate=0.0, + ) + else: + raise ValueError(f"Invalid class_head_type: {class_head_type}") + + self.criterion = torch.nn.CrossEntropyLoss() + + self.train_loss_history = [] + self.val_loss_history = [] + + self.backbone_weights_path = model_kwargs.get("backbone_weights_path", "None") + logger.info(f"Backbone weights path: {self.backbone_weights_path}") + + if self.backbone_weights_path is not None: + if self.backbone_weights_path != "None": + self.load_backbone_weights(self.backbone_weights_path) + + def load_backbone_weights(self, ckpt_path=None): + logger.info(f"Loading backbone weights from {ckpt_path}") + if ckpt_path is None: + ckpt_path = ( + self.hparams.model_kwargs.backbone_weights_path + ) # Or wherever your default path is stored + logger.info(f"Loading backbone weights from {ckpt_path}") + ckpt = torch.load(ckpt_path) + state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt + self.load_state_dict(state_dict, strict=False) + + def forward(self, X, mask): + embeddings = self.module(X, mask) + logits = self.head(embeddings, mask) + return logits + + def on_train_start(self) -> None: + """Lightning hook that is called when training begins.""" + logger.info("`on_train_start` called.") + + def on_train_epoch_start(self) -> None: + logger.info("`on_train_epoch_start` called.") + self.train_preds_list = [] + self.train_labels_list = [] + logger.info(f"Epoch {self.trainer.current_epoch} started.") + + def model_step(self, batch) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Perform a single model step on a batch of data. + + :param batch: A batch of data (a tuple) containing the input tensor of images and target labels. + + :return: A tuple containing (in order): + - A tensor of losses. + - A tensor of predictions. + - A tensor of target labels. + """ + X = batch["part_features"].to("cuda") + mask = batch["part_mask"].to("cuda") + shower_labels = batch["shower_type_labels"] + if len(X.size()) == 2: + X = X.unsqueeze(-1) + X = X.squeeze().long() + # one-hot encode the labels + logits = self.forward(X, mask) + labels = F.one_hot(shower_labels.squeeze(), num_classes=self.head.n_out_nodes).float() + loss = self.criterion(logits.to("cuda"), labels.to("cuda")) + return loss, logits, labels + + def training_step(self, batch, batch_idx: int) -> torch.Tensor: + """Perform a single training step on a batch of data from the training set. + + :param batch: A batch of data (a tuple) containing the input tensor of images and target + labels. + :param batch_idx: The index of the current batch. + :return: A tensor of losses between model predictions and targets. + """ + loss, logits, targets = self.model_step(batch) + + preds = torch.softmax(logits, dim=1) + self.train_preds_list.append(preds.float().detach().cpu().numpy()) + self.train_labels_list.append(targets.float().detach().cpu().numpy()) + self.train_loss_history.append(loss.float().detach().cpu().numpy()) + + acc = calc_accuracy( + preds=preds.float().detach().cpu().numpy(), + labels=targets.float().detach().cpu().numpy(), + ) + + self.log( + "train_loss", loss.item(), on_step=True, on_epoch=True, prog_bar=True, sync_dist=True + ) + self.log("train_acc", acc, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True) + + return loss + + def on_train_epoch_end(self): + logger.info("`on_train_epoch_end` called.") + self.train_preds = np.concatenate(self.train_preds_list) + self.train_labels = np.concatenate(self.train_labels_list) + logger.info(f"Epoch {self.trainer.current_epoch} finished.") + dist.barrier() + plt.plot(self.train_loss_history) + + def on_validation_epoch_start(self) -> None: + logger.info("`on_validation_epoch_start` called.") + self.val_preds_list = [] + self.val_labels_list = [] + + def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None: + loss, logits, targets = self.model_step(batch) + preds = torch.softmax(logits, dim=1) + self.val_preds_list.append(preds.float().detach().cpu().numpy()) + self.val_labels_list.append(targets.float().detach().cpu().numpy()) + # update and log metrics + acc = calc_accuracy( + preds=preds.float().detach().cpu().numpy(), + labels=targets.float().detach().cpu().numpy(), + ) + self.log( + "val_loss", loss.item(), on_step=True, on_epoch=True, prog_bar=True, sync_dist=True + ) + self.log("val_acc", acc, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True) + + def on_validation_epoch_end(self) -> None: + """Lightning hook that is called when a validation epoch ends.""" + logger.info("`on_validation_epoch_end` called.") + self.val_preds = np.concatenate(self.val_preds_list) + self.val_labels = np.concatenate(self.val_labels_list) + + def on_test_start(self): + logger.info("`on_test_start` called.") + self.test_loop_preds_list = [] + self.test_loop_labels_list = [] + + def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None: + """Perform a single test step on a batch of data from the test set.""" + loss, logits, targets = self.model_step(batch) + preds = torch.softmax(logits, dim=1) + self.test_loop_preds_list.append(preds.float().detach().cpu().numpy()) + self.test_loop_labels_list.append(targets.float().detach().cpu().numpy()) + + acc = calc_accuracy( + preds=preds.float().detach().cpu().numpy(), + labels=targets.float().detach().cpu().numpy(), + ) + self.log( + "test_loss", loss.item(), on_step=True, on_epoch=True, prog_bar=True, sync_dist=True + ) + self.log("test_acc", acc, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True) + + def on_test_epoch_end(self): + logger.info("`on_test_epoch_end` called.") + self.test_preds = np.concatenate(self.test_loop_preds_list) + self.test_labels = np.concatenate(self.test_loop_labels_list) + + def configure_optimizers(self) -> Dict[str, Any]: + """Configures optimizers and learning-rate schedulers to be used for training.""" + logger.info("`configure_optimizers` called.") + if self.hparams.model_kwargs.keep_backbone_fixed: + logger.info("--- Keeping backbone fixed. ---") + optimizer = self.hparams.optimizer( + [ + {"params": self.module.parameters(), "lr": 0.0}, + {"params": self.head.parameters()}, + ] + ) + else: + optimizer = self.hparams.optimizer(params=self.parameters()) + if self.hparams.scheduler is not None: + scheduler = self.hparams.scheduler(optimizer=optimizer) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "monitor": "val_loss", + "interval": "epoch", + "frequency": 1, + }, + } + return {"optimizer": optimizer} + + +class BackboneMPMLightning(L.LightningModule): + """Backbone model with NextTokenPredictionHead used for predicting a masked particle.""" + + def __init__( + self, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler = None, + model_kwargs={}, + token_dir=None, + **kwargs, + ) -> None: + super().__init__() + self.save_hyperparameters(logger=False) + + # --------------- load pretrained model --------------- # + # if kwargs.get("load_pretrained", False): + self.module = BackboneModel(**model_kwargs) + + self.head = NextTokenPredictionHead( + embedding_dim=model_kwargs["embedding_dim"], + vocab_size=model_kwargs["vocab_size"], + ) + self.backbone_weights_path = model_kwargs.get("backbone_weights_path", None) + self.token_dir = token_dir + + self.train_loss_history = [] + self.val_loss_list = [] + + self.validation_cnt = 0 + self.validation_output = {} + + self.criterion = nn.CrossEntropyLoss() + + if self.backbone_weights_path is not None: + if self.backbone_weights_path != "None": + self.load_backbone_weights(self.backbone_weights_path) + + def load_backbone_weights(self, ckpt_path=None): # Add ckpt_path parameter + if ckpt_path is None: + ckpt_path = ( + self.hparams.model_kwargs.backbone_weights_path + ) # Or wherever your default path is stored + logger.info(f"Loading backbone weights from {ckpt_path}") + logger.info(f"Loading backbone weights from {ckpt_path}") + ckpt = torch.load(ckpt_path) + state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt + self.load_state_dict(state_dict, strict=False) + print("Backbone weights loaded") + + def forward(self, x, mask): + embedding = self.module(x, mask) + logits = self.head(embedding) + return logits + + def multi_masking(self, showerlenghts, mask_percent, mm=True): + showerlenghts = showerlenghts.cpu().numpy() + if mm: + to_mask = showerlenghts // (100 / mask_percent) + else: + to_mask = np.ones_like(showerlenghts) + + batch_mask = [] + + for shower, mask_amount in enumerate(to_mask): + mask_values = np.random.choice( + showerlenghts[shower], size=int(mask_amount), replace=False + ) + mask_part = np.arange(128) > 130 + for index in mask_values: + mask_part[index] = True + + batch_mask.append(mask_part) + mask = torch.from_numpy(np.asarray(batch_mask)) + return mask + + def model_step(self, batch, return_logits=False, return_output=False): + """Perform a single model step on a batch of data.""" + # preparing the data + X = batch["part_features"] + X = X.squeeze().long() + mask = batch["part_mask"] + showerlen = mask.sum(axis=1) + lenmask = showerlen > 1 + mask_to_fill = self.multi_masking(showerlen[lenmask], 10, mm=False).to("cuda") + X_len_masked = X[lenmask] + + X_masked = X_len_masked.masked_fill(mask_to_fill, 8192) + targets = X[lenmask][:, :] + + X = X_masked[:, :] + mask = mask[lenmask][:, :] + # forward pass + logits = self.forward(X, mask) + + # calculating accuracy and output metrics + B, T, C = logits.shape + argmax = torch.argmax(logits, axis=2) + + masked_particle_mask = mask_to_fill[:, :] & mask + total_masked_particles = masked_particle_mask.sum().item() + correct_masked_particles = ( + (argmax[masked_particle_mask] == targets[masked_particle_mask]).sum().item() + ) + accuracy_masked_particles = correct_masked_particles / total_masked_particles + + masked_logits = logits[masked_particle_mask].to("cpu") + masked_targets = targets[masked_particle_mask].to("cpu") + + loss = self.criterion(masked_logits, masked_targets) + + if return_logits: + return loss, X, masked_logits, mask, masked_targets + + return loss, accuracy_masked_particles + + @torch.no_grad() + def predict_tokens(self, batch): + """Mask and predict tokens on a batch of data.""" + self.to("cuda") + X = batch["part_features"].to("cuda") # .to("cpu") + X_orig = X + X = X.squeeze().long() + mask = batch["part_mask"].to("cuda") # .to("cpu") + showerlen = mask.sum(axis=1) + lenmask = showerlen > 1 + mask_to_fill = self.multi_masking(showerlen[lenmask], 10, mm=False).to("cuda") + X_len_masked = X[lenmask] + X_masked = X_len_masked.masked_fill(mask_to_fill, 8192) + # all token-ids up to the last one are the input, the ones from the second + # to the (including) last one are the target + targets = X[lenmask][:, :] + X = X_masked[:, :] + mask = mask[lenmask][:, :] + + logits = self.forward(X, mask) + B, T, C = logits.shape + argmax = torch.argmax(logits, axis=2) + masked_particle_mask = mask_to_fill[:, :] & mask + total_masked_particles = masked_particle_mask.sum().item() + correct_masked_particles = ( + (argmax[masked_particle_mask] == targets[masked_particle_mask]).sum().item() + ) + accuracy_masked_particles = correct_masked_particles / total_masked_particles + + masked_logits = logits[masked_particle_mask].to("cpu") + masked_targets = targets[masked_particle_mask].to("cpu") + + self.log( + "masked particle accuracy", + accuracy_masked_particles, + on_step=True, + on_epoch=False, + prog_bar=True, + ) + print("masked particle accuracy:", accuracy_masked_particles) + return ( + X_orig.cpu(), + X_len_masked.cpu(), + mask_to_fill.cpu(), + masked_targets, + masked_logits, + ) + + def training_step(self, batch, batch_idx: int) -> torch.Tensor: + """Perform a single training step on a batch of data from the training set.""" + + loss, accuracy_masked_particles = self.model_step(batch) + self.log( + "masked particle accuracy", + accuracy_masked_particles, + on_step=True, + on_epoch=True, + prog_bar=True, + ) + self.train_loss_history.append(float(loss)) + self.log("train_loss", loss.item(), on_step=True, on_epoch=True, prog_bar=True) + + return loss + + def on_train_start(self) -> None: + self.preprocessing_dict = ( + self.trainer.datamodule.hparams.dataset_kwargs_common.feature_dict + ) + + def on_train_epoch_start(self): + logger.info(f"Epoch {self.trainer.current_epoch} starting.") + self.epoch_train_start_time = time.time() # start timing the epoch + + def on_train_epoch_end(self): + self.epoch_train_end_time = time.time() + self.epoch_train_duration_minutes = ( + self.epoch_train_end_time - self.epoch_train_start_time + ) / 60 + self.log( + "epoch_train_duration_minutes", + self.epoch_train_duration_minutes, + on_epoch=True, + prog_bar=False, + ) + if len(self.train_loss_history) > 0: + logger.info( + f"Epoch {self.trainer.current_epoch} finished in" + f" {self.epoch_train_duration_minutes:.1f} minutes. " + f"Current step: {self.global_step}. Current loss FULL: {self.train_loss_history}." + f"Current step: {self.global_step}. Current loss: {self.train_loss_history[-1]}." + ) + + def on_train_end(self): + pass + + def on_validation_epoch_start(self) -> None: + self.val_token_preds_list = [] + self.val_token_target_list = [] + + def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None: + loss, X, logits, mask, targets = self.model_step(batch, return_logits=True) + preds = torch.softmax(logits, dim=1) + self.val_token_preds_list.append(preds.detach().cpu().numpy()) + self.val_token_target_list.append(targets.detach().cpu().numpy()) + acc = (np.argmax(preds.detach().cpu().numpy()) == targets.detach().cpu().numpy()).mean() + self.log("val_loss", loss.item(), on_step=True, on_epoch=True, prog_bar=True) + self.log("val_acc", acc, on_step=True, on_epoch=True, prog_bar=True) + + return loss + + def on_validation_epoch_end(self) -> None: + """Lightning hook that is called when a validation epoch ends.""" + self.val_preds = np.concatenate(self.val_token_preds_list) + self.val_labels = np.concatenate(self.val_token_target_list) + + def on_test_epoch_start(self) -> None: + pass + + def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None: + loss, x_original, x_reco, mask, labels, code_idx = self.model_step(batch, return_x=True) + self.log( + "test_loss", loss.item(), on_step=True, on_epoch=True, prog_bar=True, sync_dist=True + ) + + def on_test_epoch_end(self): + pass + + def configure_optimizers(self) -> Dict[str, Any]: + """Configures optimizers and learning-rate schedulers to be used for training. + + Normally you'd need one, but in the case of GANs or similar you might need multiple. + + Examples: + https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers + + :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training. + """ + optimizer = self.hparams.optimizer(params=self.parameters()) + + if self.hparams.scheduler is not None: + scheduler = self.hparams.scheduler(optimizer=optimizer) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "monitor": "val_loss", + "interval": "epoch", + "frequency": 1, + }, + } + return {"optimizer": optimizer} + + +class RegressionHead(nn.Module): + """Head for predicting the shower features using regression..""" + + def __init__(self, embedding_dim, num_shower_features=4): + super().__init__() + self.fc1 = nn.Linear(embedding_dim, num_shower_features) + + def forward(self, x): + x_dim_red = x.mean(axis=1) + return self.fc1(x_dim_red) + + +class BackboneRegressionLightning(L.LightningModule): + """Backbone model with NextTokenPredictionHead used for predicting a masked particle.""" + + def __init__( + self, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler = None, + model_kwargs={}, + token_dir=None, + **kwargs, + ) -> None: + super().__init__() + self.save_hyperparameters(logger=False) + + # --------------- load pretrained model --------------- # + # if kwargs.get("load_pretrained", False): + self.module = BackboneModel(**model_kwargs) + + self.head = RegressionHead( + embedding_dim=model_kwargs["embedding_dim"], + num_shower_features=model_kwargs["num_shower_features"], # needs to be implemented + ) + self.backbone_weights_path = model_kwargs.get("backbone_weights_path", None) + self.token_dir = token_dir + + self.train_loss_history = [] + self.val_loss_list = [] + + self.validation_cnt = 0 + self.validation_output = {} + + self.criterion = nn.MSELoss() + + if self.backbone_weights_path is not None: + if self.backbone_weights_path != "None": + self.load_backbone_weights(self.backbone_weights_path) + + def load_backbone_weights(self, ckpt_path): + logger.info(f"Loading backbone weights from {ckpt_path}") + ckpt = torch.load(ckpt_path) + state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt + self.load_state_dict(state_dict, strict=False) + print("Backbone weights loaded") + + def forward(self, x, mask): + embedding = self.module(x, mask) + pred_shower_features = self.head(embedding) + return pred_shower_features + + def model_step(self, batch, return_logits=False, return_output=False): + """Perform a single model step on a batch of data.""" + # preparing the data + X = batch["part_features"].to("cuda") + targets = X.mean(axis=1).squeeze() # needs to be changed to the shower feature arrays + + X = X.squeeze().long() + mask = batch["part_mask"].to("cuda") + + # forward pass + pred_shower_features = self.forward(X, mask) + + targets = batch[ + "shower_features" + ] # needs to be deleted after testing!!! This is just to make the test run. + + print("shape of output: ", pred_shower_features.shape) + # calculating loss and output metrics + loss = self.criterion(pred_shower_features, targets) + if return_logits: + return loss, X, mask, pred_shower_features + + return loss + + def training_step(self, batch, batch_idx: int) -> torch.Tensor: + """Perform a single training step on a batch of data from the training set.""" + + loss = self.model_step(batch) + + self.train_loss_history.append(float(loss)) + self.log("train_loss", loss.item(), on_step=True, on_epoch=True, prog_bar=True) + + return loss + + def on_train_start(self) -> None: + self.preprocessing_dict = ( + self.trainer.datamodule.hparams.dataset_kwargs_common.feature_dict + ) + + def on_train_epoch_start(self): + logger.info(f"Epoch {self.trainer.current_epoch} starting.") + self.epoch_train_start_time = time.time() # start timing the epoch + + def on_train_epoch_end(self): + self.epoch_train_end_time = time.time() + self.epoch_train_duration_minutes = ( + self.epoch_train_end_time - self.epoch_train_start_time + ) / 60 + self.log( + "epoch_train_duration_minutes", + self.epoch_train_duration_minutes, + on_epoch=True, + prog_bar=False, + ) + logger.info( + f"Epoch {self.trainer.current_epoch} finished in" + f" {self.epoch_train_duration_minutes:.1f} minutes." + ) + + def on_train_end(self): + pass + + def on_validation_epoch_start(self) -> None: + pass + + def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None: + loss, X, mask, pred_shower_features = self.model_step(batch, return_logits=True) + self.log("val_loss", loss.item(), on_step=True, on_epoch=True, prog_bar=True) + + return loss + + def on_validation_epoch_end(self) -> None: + """Lightning hook that is called when a validation epoch ends.""" + pass + + def on_test_epoch_start(self) -> None: + pass + + def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None: + loss, X, mask, pred_shower_features = self.model_step(batch, return_logits=True) + self.log("test_loss", loss.item(), on_step=True, on_epoch=True, prog_bar=True) + + def on_test_epoch_end(self): + pass + + def configure_optimizers(self) -> Dict[str, Any]: + """Configures optimizers and learning-rate schedulers to be used for training. + + Normally you'd need one, but in the case of GANs or similar you might need multiple. + + Examples: + https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers + + :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training. + """ + optimizer = self.hparams.optimizer(params=self.parameters()) + + if self.hparams.scheduler is not None: + scheduler = self.hparams.scheduler(optimizer=optimizer) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "monitor": "val_loss", + "interval": "epoch", + "frequency": 1, + }, + } + return {"optimizer": optimizer} diff --git a/gabbro/models/classifiers.py b/gabbro/models/classifiers.py new file mode 100644 index 0000000..0119a86 --- /dev/null +++ b/gabbro/models/classifiers.py @@ -0,0 +1,533 @@ +import sys +from typing import Any, Dict, Tuple + +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from lightning import LightningModule +from weaver.nn.model.ParticleTransformer import ParticleTransformer # noqa: E402 + +# from gabbro.models.gpt_model import FullModel +from gabbro.models.vqvae import NormformerStack + +sys.path.append("/home//home/birkjosc/repositories") + +from gabbro.models.gpt_model_sequential import FullModel # noqa: E402 + + +def calc_accuracy(preds, labels, verbose=False): + """Calculates accuracy and AUC. + + Parameters + ---------- + preds : array-like + Classifier scores. Tensor of shape (n_samples, n_classes). + labels : array-like + Array with the true labels (one-hot encoded). Tensor of shape (n_samples, n_classes). + + Returns + ------- + accuracy : float + Accuracy. + """ + accuracy = (np.argmax(preds, axis=1) == np.argmax(labels, axis=1)).mean() + + return accuracy + + +def calc_rejection(scores, labels, verbose=False, sig_eff=0.3): + """Calculates the R30 metric. + + Parameters + ---------- + scores : array-like + Classifier scores (probability of being signal). Array of shape (n_samples,). + labels : array-like + Array with the true labels (0 or 1). Array of shape (n_samples,). + sig_eff : float, optional + Signal efficiency at which to calculate the rejection. + + Returns + ------- + rejection : float + Rejection metric value. + cut_value : float + Cut value for this rejection. + """ + is_signal = labels == 1 + cut_value = np.percentile(scores[is_signal], 100 - sig_eff * 100) + background_efficiency = np.sum(scores[~is_signal] > cut_value) / np.sum(~is_signal) + if verbose: + print(f"cut_value = {cut_value}") + print(f"background_efficiency = {background_efficiency}") + rejection = 1 / background_efficiency + return rejection, cut_value + + +class NormformerCrossBlock(nn.Module): + def __init__(self, input_dim, mlp_dim, num_heads, dropout_rate=0.1): + super().__init__() + self.input_dim = input_dim + self.num_heads = num_heads + self.dropout_rate = dropout_rate + + # define the MultiheadAttention layer with layer normalization + self.norm1 = nn.LayerNorm(input_dim) + self.attn = nn.MultiheadAttention(input_dim, num_heads, batch_first=True, dropout=0.1) + self.norm2 = nn.LayerNorm(input_dim) + + # define the MLP with layer normalization + self.mlp = nn.Sequential( + nn.LayerNorm(input_dim), # Add layer normalization + nn.Linear(input_dim, mlp_dim), + nn.SiLU(), + nn.Dropout(self.dropout_rate), + nn.Linear(mlp_dim, input_dim), + ) + + # initialize weights of mlp[-1] and layer norm after attn block to 0 + # such that the residual connection is the identity when the block is + # initialized + nn.init.zeros_(self.mlp[-1].weight) + nn.init.zeros_(self.mlp[-1].bias) + nn.init.zeros_(self.norm1.weight) + + def forward(self, x, class_token, mask=None, return_attn_weights=False): + # x: (B, S, F) + # mask: (B, S) + x = x * mask.unsqueeze(-1) + + # calculate cross-attention + x_norm = self.norm1(x) + attn_output, attn_weights = self.attn( + query=class_token, key=x_norm, value=x_norm, key_padding_mask=mask != 1 + ) + return attn_output + + +# --------------------------- Particle Flow Network --------------------------- +class ParticleFlow(nn.Module): + """Definition of the Particle Flow Network.""" + + def __init__( + self, + input_dim=None, + n_out_nodes=2, + n_embed=16, + n_tokens=None, + **kwargs, + ): + """Initialise Particle Flow Network. + + Parameters + ---------- + input_dim : int, optional + Number of features per point. + n_out_nodes : int, optional + Number of output nodes. + n_embed : int, optional + Number of embedding dimensions, only used if n_tokens is not None. + n_tokens : int, optional + Number of codebook entries (i.e. number of different tokens), only + used if input_dim is None. + """ + + super().__init__() + + if input_dim is None and n_tokens is None: + raise ValueError("Either input_dim or n_tokens must be specified") + + self.n_out_nodes = n_out_nodes + self.n_tokens = n_tokens + self.n_embed = n_embed + + if n_tokens is None: + self.phi_1 = nn.Linear(input_dim, 100) + else: + self.embedding = nn.Embedding(n_tokens, n_embed) + self.phi_1 = nn.Linear(n_embed, 100) + + self.phi_2 = nn.Linear(100, 100) + self.phi_3 = nn.Linear(100, 256) + self.F_1 = nn.Linear(256, 100) + self.F_2 = nn.Linear(100, 100) + self.F_3 = nn.Linear(100, 100) + self.output_layer = nn.Linear(100, self.n_out_nodes) + + def forward(self, x, mask): + """Definition of the ParticleFlow forward pass. + + Parameters + ---------- + + Returns + ------- + """ + batch_size, n_points, n_features = x.size() + + # propagate through phi + if self.n_tokens is not None: + x = self.embedding(x).squeeze() + x = F.relu(self.phi_1(x)) + x = F.relu(self.phi_2(x)) + x = F.relu(self.phi_3(x)) + + # sum over points dim. + # todo: check if sum_scale helps * self.sum_scale + x_sum = torch.sum(x * mask[..., None], dim=1) + + # propagate through F + x = F.relu(self.F_1(x_sum)) + x = F.relu(self.F_2(x)) + x = F.relu(self.F_3(x)) + + # output layer - no activation used here, since both nn.BCEWithLogitsLoss + # and nn.CrossEntropyLoss expect the logits as input + x_out = self.output_layer(x) + + return x_out + + +class ClassifierPL(LightningModule): + """Pytorch-lightning module for jet classification.""" + + def __init__( + self, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler = None, + model_class_name: str = "ParticleFlow", + model_kwargs: dict = {}, + **kwargs, + ) -> None: + super().__init__() + self.save_hyperparameters(logger=False) + + self.model_class_name = model_class_name + if "keep_backbone_fixed" in model_kwargs: + self.keep_backbone_fixed = model_kwargs["keep_backbone_fixed"] + model_kwargs.pop("keep_backbone_fixed") + else: + self.keep_backbone_fixed = False + + if self.model_class_name == "ParticleFlow": + self.model = ParticleFlow(**model_kwargs) + elif self.model_class_name == "ClassifierNormformer": + self.model = ClassifierNormformer(**model_kwargs) + elif self.model_class_name == "ParT": + self.model = ParticleTransformer(**model_kwargs) + elif self.model_class_name == "BackboneWithClasshead": + self.model = BackboneWithClasshead(model_kwargs) + else: + raise ValueError(f"Model class {model_class_name} not supported.") + + self.criterion = torch.nn.CrossEntropyLoss() + + self.train_loss_history = [] + self.val_loss_history = [] + + def forward(self, features, mask): + return self.model(features, mask) + + def on_train_start(self) -> None: + """Lightning hook that is called when training begins.""" + pass + + def on_train_epoch_start(self) -> None: + self.train_preds_list = [] + self.train_labels_list = [] + print(f"Epoch {self.trainer.current_epoch} started.", end="\r") + + def model_step(self, batch) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Perform a single model step on a batch of data. + + :param batch: A batch of data (a tuple) containing the input tensor of images and target labels. + + :return: A tuple containing (in order): + - A tensor of losses. + - A tensor of predictions. + - A tensor of target labels. + """ + X = batch["part_features"] + mask = batch["part_mask"] + jet_labels = batch["jet_type_labels"] + if len(X.size()) == 2: + X = X.unsqueeze(-1) + if self.model_class_name == "BackboneWithClasshead": + X = X.squeeze().long() + # one-hot encode the labels + labels = F.one_hot(jet_labels.squeeze(), num_classes=self.model.n_out_nodes).float() + logits = self.forward(X, mask) + loss = self.criterion(logits.to("cuda"), labels.to("cuda")) + return loss, logits, labels + + def training_step(self, batch, batch_idx: int) -> torch.Tensor: + """Perform a single training step on a batch of data from the training set. + + :param batch: A batch of data (a tuple) containing the input tensor of images and target + labels. + :param batch_idx: The index of the current batch. + :return: A tensor of losses between model predictions and targets. + """ + # X = batch["part_features"] + # # save X to file with torch.save(X, "X.pt") and epoch and step inside the filename + # filename_base = "/beegfs/desy/user/birkjosc/testing/test_same_batch_1000jets" + # # filename_base = "/beegfs/desy/user/birkjosc/testing/test_same_batch_without_load_only_once_1000jets" + # filename = f"{filename_base}/x_epoch_{self.trainer.current_epoch}_step_{self.trainer.global_step}.pt" + # print(f"Saving X to {filename}") + # torch.save(X, filename) + loss, logits, targets = self.model_step(batch) + + preds = torch.softmax(logits, dim=1) + self.train_preds_list.append(preds.detach().cpu().numpy()) + self.train_labels_list.append(targets.detach().cpu().numpy()) + self.train_loss_history.append(loss.detach().cpu().numpy()) + + acc = calc_accuracy( + preds=preds.detach().cpu().numpy(), labels=targets.detach().cpu().numpy() + ) + + self.log( + "train_loss", loss.item(), on_step=True, on_epoch=True, prog_bar=True, sync_dist=True + ) + self.log("train_acc", acc, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True) + + return loss + + def on_train_epoch_end(self): + self.train_preds = np.concatenate(self.train_preds_list) + self.train_labels = np.concatenate(self.train_labels_list) + print(f"Epoch {self.trainer.current_epoch} finished.", end="\r") + plt.plot(self.train_loss_history) + + def on_validation_epoch_start(self) -> None: + self.val_preds_list = [] + self.val_labels_list = [] + + def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None: + loss, logits, targets = self.model_step(batch) + preds = torch.softmax(logits, dim=1) + self.val_preds_list.append(preds.detach().cpu().numpy()) + self.val_labels_list.append(targets.detach().cpu().numpy()) + # update and log metrics + acc = calc_accuracy( + preds=preds.detach().cpu().numpy(), labels=targets.detach().cpu().numpy() + ) + self.log( + "val_loss", loss.item(), on_step=True, on_epoch=True, prog_bar=True, sync_dist=True + ) + self.log("val_acc", acc, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True) + + def on_validation_epoch_end(self) -> None: + """Lightning hook that is called when a validation epoch ends.""" + self.val_preds = np.concatenate(self.val_preds_list) + self.val_labels = np.concatenate(self.val_labels_list) + + def on_test_start(self): + self.test_loop_preds_list = [] + self.test_loop_labels_list = [] + + def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None: + """Perform a single test step on a batch of data from the test set.""" + loss, logits, targets = self.model_step(batch) + preds = torch.softmax(logits, dim=1) + self.test_loop_preds_list.append(preds.detach().cpu().numpy()) + self.test_loop_labels_list.append(targets.detach().cpu().numpy()) + + acc = calc_accuracy( + preds=preds.detach().cpu().numpy(), labels=targets.detach().cpu().numpy() + ) + self.log( + "test_loss", loss.item(), on_step=True, on_epoch=True, prog_bar=True, sync_dist=True + ) + self.log("test_acc", acc, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True) + + def on_test_epoch_end(self): + self.test_preds = np.concatenate(self.test_loop_preds_list) + self.test_labels = np.concatenate(self.test_loop_labels_list) + + def configure_optimizers(self) -> Dict[str, Any]: + """Configures optimizers and learning-rate schedulers to be used for training. + + Normally you'd need one, but in the case of GANs or similar you might need multiple. + + Examples: + https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers + + :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training. + """ + if self.keep_backbone_fixed: + print("--- Keeping backbone fixed. ---") + optimizer = self.hparams.optimizer( + [ + {"params": self.model.module.parameters(), "lr": 0.0}, + {"params": self.model.classification_head_linear_embed.parameters()}, + {"params": self.model.classification_head_linear_class.parameters()}, + ] + ) + else: + optimizer = self.hparams.optimizer(params=self.parameters()) + if self.hparams.scheduler is not None: + scheduler = self.hparams.scheduler(optimizer=optimizer) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "monitor": "val_loss", + "interval": "epoch", + "frequency": 1, + }, + } + return {"optimizer": optimizer} + + +class ClassifierNormformer(torch.nn.Module): + def __init__( + self, + input_dim, + hidden_dim, + num_heads=1, + num_enc_blocks=2, + class_head_kwargs={"n_out_nodes": 2, "fc_params": [(100, 0.1), (100, 0.1)]}, + dropout_rate=0.1, + num_class_blocks=3, + **kwargs, + ): + super().__init__() + + self.dropout_rate = dropout_rate + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.num_heads = num_heads + self.num_enc_blocks = num_enc_blocks + self.num_class_blocks = num_class_blocks + self.class_head_kwargs = class_head_kwargs + self.class_token = nn.Parameter(torch.randn(1, 1, hidden_dim)) + + self.input_projection = nn.Linear(self.input_dim, self.hidden_dim) + self.encoder_normformer = NormformerStack( + hidden_dim=self.hidden_dim, + num_heads=self.num_heads, + num_blocks=self.num_enc_blocks, + dropout_rate=self.dropout_rate, + ) + self.class_attention_blocks = nn.ModuleList( + [ + NormformerCrossBlock( + input_dim=self.hidden_dim, + num_heads=self.num_heads, + dropout_rate=self.dropout_rate, + mlp_dim=self.hidden_dim, + ) + for _ in range(self.num_class_blocks) + ] + ) + self.initialize_classification_head() + + self.loss_history = [] + self.lr_history = [] + + def forward(self, x, mask): + # encode + x = self.input_projection(x) + x_encoded = self.encoder_normformer(x, mask=mask) + # concatenate class token and x + class_token = self.class_token.expand(x.size(0), -1, -1) + mask_with_token = torch.cat([torch.ones(x.size(0), 1).to(x.device), mask], dim=1) + + # pass through class attention blocks, always use the updated class token + for block in self.class_attention_blocks: + x_class_token_and_x_encoded = torch.cat([class_token, x_encoded], dim=1) + # class_token = block(x_class_token_and_x_encoded, mask=mask_with_token)[:, :1, :] + class_token = block(x_class_token_and_x_encoded, class_token, mask=mask_with_token) + + # x = x * mask.unsqueeze(-1) + + # # sum over points dim + # x = torch.sum(x, dim=1) + + # pass class token through classification head + # x = self.classification_head(x[:, 0, :]) + + return self.classification_head(class_token.squeeze(1)) + + def initialize_classification_head(self): + if self.class_head_kwargs is None: + self.class_head_kwargs = { + "fc_params": [ + [128, 0.1], + [128, 0.1], + ], + "n_out_nodes": 2, + } + + fc_params = [[self.hidden_dim, 0]] + self.class_head_kwargs["fc_params"] + self.n_out_nodes = self.class_head_kwargs["n_out_nodes"] + + layers = [] + + for i in range(1, len(fc_params)): + in_dim = fc_params[i - 1][0] + out_dim = fc_params[i][0] + dropout_rate = fc_params[i][1] + layers.extend( + [ + nn.Linear(in_dim, out_dim), + nn.Dropout(dropout_rate), + nn.ReLU(), + ] + ) + # add final layer + layers.extend([nn.Linear(fc_params[-1][0], self.n_out_nodes)]) + + self.classification_head = nn.Sequential(*layers) + + +class BackboneWithClasshead(torch.nn.Module): + def __init__(self, model_kwargs={"n_out_nodes": 2}): + super().__init__() + self.backbone_weights_path = None + + if "n_out_nodes" not in model_kwargs: + model_kwargs["n_out_nodes"] = 2 + if "return_embeddings" not in model_kwargs: + model_kwargs["return_embeddings"] = True + + self.n_out_nodes = model_kwargs["n_out_nodes"] + model_kwargs.pop("n_out_nodes") + # remove backbone weights path from model_kwargs + if "backbone_weights_path" in model_kwargs: + self.backbone_weights_path = model_kwargs["backbone_weights_path"] + model_kwargs.pop("backbone_weights_path") + + # backbone + self.module = FullModel(**model_kwargs) + # load weights if available + if self.backbone_weights_path is not None: + if self.backbone_weights_path != "None": + print(f"Loading weights from {self.backbone_weights_path}") + self.load_weights(self.backbone_weights_path) + self.module.return_embeddings = True + # initialize classification head + self.classification_head_linear_embed = nn.Linear( + model_kwargs["embedding_dim"], + model_kwargs["embedding_dim"], + ) + self.classification_head_linear_class = nn.Linear( + model_kwargs["embedding_dim"], + self.n_out_nodes, + ) + + def load_weights(self, ckpt_path): + print(f"Loading weights from {ckpt_path}") + ckpt = torch.load(ckpt_path) + state_dict = ckpt["model_state_dict"] if "model_state_dict" in ckpt else ckpt + self.load_state_dict(state_dict, strict=False) + + def forward(self, x, mask): + embeddings = self.module(x, mask) + embeddings = F.relu(self.classification_head_linear_embed(embeddings)) + embeddings_sum = torch.sum(embeddings * mask.unsqueeze(-1), dim=1) + logits = self.classification_head_linear_class(embeddings_sum) + return logits diff --git a/gabbro/models/gpt_model.py b/gabbro/models/gpt_model.py new file mode 100644 index 0000000..e4ed303 --- /dev/null +++ b/gabbro/models/gpt_model.py @@ -0,0 +1,218 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import vector + +from gabbro.utils.pylogger import get_pylogger + +vector.register_awkward() + +logger = get_pylogger(__name__) + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + embedding_dim: int, + n_heads: int, + attention_dropout: float, + max_sequence_len: int = 256, + apply_causal_mask=False, + ): + super().__init__() + assert embedding_dim % n_heads == 0, "Embedding dim must be divisible by number of heads" + + self.head_dim = embedding_dim // n_heads + self.n_heads = n_heads + self.embedding_dim = embedding_dim + self.apply_causal_mask = apply_causal_mask + + self.key = nn.Linear(embedding_dim, embedding_dim, bias=False) + self.query = nn.Linear(embedding_dim, embedding_dim, bias=False) + self.value = nn.Linear(embedding_dim, embedding_dim, bias=False) + + # Create a causal attention mask and store it as self.tril. Being a + # buffer means that it will not be included as parameters in the model. + self.register_buffer("tril", torch.tril(torch.ones(max_sequence_len, max_sequence_len))) + self.dropout = nn.Dropout(attention_dropout) + + self.proj = nn.Linear(embedding_dim, embedding_dim) + + def forward(self, x, padding_mask=None): + B, T, C = x.shape + # input of size (batch, time-step, channels); channels = embedding dimension + # output of size (batch, time-step, embedding_dim) + + k = self.key(x) # (B, T, E) + q = self.query(x) # (B, T, E) + v = self.value(x) # (B, T, E) + + # We implicitly split the matrix by adding a `num_heads` dimension + # Unroll last dim: (B, T, E) -> (B, T, num_heads, head_dim) + k = k.view(B, T, self.n_heads, self.head_dim) + v = v.view(B, T, self.n_heads, self.head_dim) + q = q.view(B, T, self.n_heads, self.head_dim) + + # Transpose: (B, T, n_heads, head_dim) -> (B, n_heads, T, head_dim) + k = k.transpose(1, 2) + q = q.transpose(1, 2) + v = v.transpose(1, 2) + + # Compute scaled dot-product attention + # (B, n_heads, T, head_dim) @ (B, n_heads, head_dim, T) -> (B, n_heads, T, T) + attn_scores = q @ k.transpose(2, 3) * k.shape[-1] ** -0.5 + + if padding_mask is not None: + padding_mask = padding_mask.unsqueeze(-1).expand(-1, -1, T) # (B, T) -> (B, T, T) + # (B, T, T) -> (B, n_heads, T, T) + padding_mask = padding_mask.unsqueeze(1).expand(B, self.n_heads, T, T) + # Need to set a finite number for the masking, instead of -inf, + # otherwise softmax results in nans. + # (B, n_heads, T, T) + attn_scores = attn_scores.masked_fill(padding_mask == 0, float("-1e9")) + + # Apply the causal mask, cropped to the sequence length + # (B, n_heads, T, T) + if self.apply_causal_mask: + attn_scores = attn_scores.masked_fill(self.tril[:T, :T] == 0, float("-inf")) + + attn_weights = F.softmax(attn_scores, dim=-1) # (B, n_heads, T, T) + attn_weights = self.dropout(attn_weights) + + # attn_weights have shape (B, n_heads, T, T) and v (B, n_heads, T, head_dim) + # (B, n_heads, T, head_dim) -> (B, T, n_heads, head_dim) + context_vec = (attn_weights @ v).transpose(1, 2) + + # Combine heads, where embedding_dim = n_heads * head_dim + context_vec = context_vec.contiguous().view(B, T, self.embedding_dim) + context_vec = self.proj(context_vec) + + return context_vec + + +class FeedForward(nn.Module): + """Simple linear layer followed by a non-linearity to be placed after the attention blocks.""" + + def __init__(self, embedding_dim): + super().__init__() + self.net = nn.Sequential( + nn.Linear(embedding_dim, 4 * embedding_dim), + nn.ReLU(), + nn.Linear(4 * embedding_dim, embedding_dim), + # nn.Dropout(dropout), + ) + + def forward(self, x): + return self.net(x) + + +class GPT_DecoderBlock(nn.Module): + """The GPT decoder block.""" + + def __init__( + self, + embedding_dim: int, + attention_dropout: int, + n_heads: int, + verbose: bool = False, + apply_causal_mask: bool = True, + max_sequence_len: int = 256, + ): + super().__init__() + self.verbose = verbose + self.apply_causal_mask = apply_causal_mask + self.mha_block = MultiHeadAttention( + embedding_dim, + n_heads, + attention_dropout, + apply_causal_mask=apply_causal_mask, + max_sequence_len=max_sequence_len, + ) + self.ff_block = FeedForward(embedding_dim) + self.layernorm_1 = nn.LayerNorm(embedding_dim) + self.layernorm_2 = nn.LayerNorm(embedding_dim) + + def forward(self, x, padding_mask=None): + x_residual = x + + x = self.mha_block(x, padding_mask=padding_mask) + x += x_residual + + x = self.layernorm_1(x) + x_residual = x + + x = self.ff_block(x) + x += x_residual + + x = self.layernorm_2(x) + + return x + + +class BackboneModel(nn.Module): + """Model that is used as the backbone in our studies. + + Going from integer tokens to embeddings via an embedding table, then through a stack of GPT + blocks. The output is the final embeddings. + """ + + def __init__( + self, + embedding_dim: int, + attention_dropout: float, + vocab_size: int, + max_sequence_len: int, + n_heads: int, + n_GPT_blocks: int, + n_classes: int = 2, + classify: bool = False, + verbosity: bool = True, + n_tokens: int = None, + return_embeddings=False, # only there for now for backwards-compatibility with the old model + temperature: float = 1.0, + stop_token_threshold: float = 0.0, + apply_causal_mask=True, + **kwargs, + ): + super().__init__() + + self.apply_causal_mask = apply_causal_mask + + if not self.apply_causal_mask: + logger.warning( + "NOT applying causal mask in the attention blocks. If you are using " + "this model for an autoregressive generative task, this is probably " + "not what you want." + ) + + self.vocab_size = vocab_size + self.max_sequence_len = max_sequence_len + self.temperature = temperature + self.stop_token_threshold = stop_token_threshold + self.verbose = verbosity + self.embedding_table = nn.Embedding(vocab_size, embedding_dim) + self.return_embeddings = return_embeddings + + GPT_block_stack = [] + for _ in range(n_GPT_blocks): + GPT_block_stack.extend( + [ + GPT_DecoderBlock( + embedding_dim, + attention_dropout, + n_heads=n_heads, + verbose=self.verbose, + apply_causal_mask=self.apply_causal_mask, + max_sequence_len=self.max_sequence_len, + ) + ] + ) + self.GPT_blocks = nn.Sequential(*GPT_block_stack) + + def forward(self, x, padding_mask=None): + x = self.embedding_table(x) + + for block in self.GPT_blocks: + x = block(x, padding_mask=padding_mask) + + return x diff --git a/gabbro/models/vqvae.py b/gabbro/models/vqvae.py new file mode 100644 index 0000000..4cf43fb --- /dev/null +++ b/gabbro/models/vqvae.py @@ -0,0 +1,1202 @@ +import logging +import time +from pathlib import Path +from typing import Any, Dict, Tuple + +import awkward as ak +import lightning as L +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import vector +from torch.utils.data import DataLoader, TensorDataset +from tqdm import tqdm +from vqtorch.nn import VectorQuant + +from gabbro.utils.arrays import ( + ak_pad, + ak_padding, + ak_preprocess, + ak_select_and_preprocess, + ak_to_np_stack, + np_to_ak, +) + +vector.register_awkward() + +logger = logging.getLogger(__name__) + + +class VQVAEMLP(torch.nn.Module): + def __init__( + self, + input_dim=2, + latent_dim=2, + encoder_layers=None, + decoder_layers=None, + vq_kwargs={}, + **kwargs, + ): + """Initializes the VQ-VAE model. + + Parameters + ---------- + codebook_size : int, optional + The size of the codebook. The default is 8. + embed_dim : int, optional + The dimension of the embedding space. The default is 2. + input_dim : int, optional + The dimension of the input data. The default is 2. + encoder_layers : list, optional + List of integers representing the number of units in each encoder layer. + If None, a default encoder with a single linear layer is used. The default is None. + decoder_layers : list, optional + List of integers representing the number of units in each decoder layer. + If None, a default decoder with a single linear layer is used. The default is None. + """ + + super().__init__() + self.vq_kwargs = vq_kwargs + self.embed_dim = latent_dim + self.input_dim = input_dim # for shower constituents, eta and phi + + # --- Encoder --- # + if encoder_layers is None: + self.encoder = torch.nn.Linear(self.input_dim, self.embed_dim) + else: + enc_layers = [] + enc_layers.append(torch.nn.Linear(self.input_dim, encoder_layers[0])) + enc_layers.append(torch.nn.ReLU()) + + for i in range(len(encoder_layers) - 1): + enc_layers.append(torch.nn.Linear(encoder_layers[i], encoder_layers[i + 1])) + enc_layers.append(torch.nn.ReLU()) + enc_layers.append(torch.nn.Linear(encoder_layers[-1], self.embed_dim)) + + self.encoder = torch.nn.Sequential(*enc_layers) + + # --- Vector-quantized layer --- # + self.vqlayer = VectorQuant(feature_size=self.embed_dim, **vq_kwargs) + + # --- Decoder --- # + if decoder_layers is None: + self.decoder = torch.nn.Linear(self.embed_dim, self.input_dim) + else: + dec_layers = [] + dec_layers.append(torch.nn.Linear(self.embed_dim, decoder_layers[0])) + dec_layers.append(torch.nn.ReLU()) + + for i in range(len(decoder_layers) - 1): + dec_layers.append(torch.nn.Linear(decoder_layers[i], decoder_layers[i + 1])) + dec_layers.append(torch.nn.ReLU()) + dec_layers.append(torch.nn.Linear(decoder_layers[-1], self.input_dim)) + + self.decoder = torch.nn.Sequential(*dec_layers) + + self.loss_history = [] + self.lr_history = [] + + def forward(self, samples, mask=None): + # mask is there for compatibility with the transformer model + # encode + z_embed = self.encoder(samples) + # quantize + z_q2, vq_out = self.vqlayer(z_embed) + # decode + x_reco = self.decoder(z_q2) + return x_reco, vq_out + + +class NormformerBlock(nn.Module): + def __init__(self, input_dim, mlp_dim, num_heads, dropout_rate=0.1): + super().__init__() + self.input_dim = input_dim + self.num_heads = num_heads + self.dropout_rate = dropout_rate + + # define the MultiheadAttention layer with layer normalization + self.norm1 = nn.LayerNorm(input_dim) + self.attn = nn.MultiheadAttention(input_dim, num_heads, batch_first=True, dropout=0.1) + self.norm2 = nn.LayerNorm(input_dim) + + # define the MLP with layer normalization + self.mlp = nn.Sequential( + nn.LayerNorm(input_dim), # Add layer normalization + nn.Linear(input_dim, mlp_dim), + nn.SiLU(), + nn.Dropout(self.dropout_rate), + nn.Linear(mlp_dim, input_dim), + ) + + # initialize weights of mlp[-1] and layer norm after attn block to 0 + # such that the residual connection is the identity when the block is + # initialized + nn.init.zeros_(self.mlp[-1].weight) + nn.init.zeros_(self.mlp[-1].bias) + nn.init.zeros_(self.norm1.weight) + + def forward(self, x, mask=None, return_attn_weights=False): + # x: (B, S, F) + # mask: (B, S) + x = x * mask.unsqueeze(-1) + + # calculate self-attention + x_norm = self.norm1(x) + attn_output, attn_weights = self.attn(x_norm, x_norm, x_norm, key_padding_mask=mask != 1) + # Add residual connection and permute back to (B, S, F) + attn_res = self.norm2(attn_output) + x + + output = self.mlp(attn_res) + attn_res + + if return_attn_weights: + return output, attn_weights + + # output shape: (B, S, F) + return output + + +class Transformer(torch.nn.Module): + def __init__( + self, + input_dim, + output_dim, + hidden_dim, + num_heads=1, + num_blocks=2, + skip_out_proj=False, + ): + super().__init__() + + self.project_in = nn.Linear(input_dim, hidden_dim) + + self.num_blocks = num_blocks + self.skip_out_proj = skip_out_proj + self.hidden_dim = hidden_dim + self.num_heads = num_heads + self.output_dim = output_dim + + self.blocks = nn.ModuleList( + [ + NormformerBlock(input_dim=hidden_dim, mlp_dim=hidden_dim, num_heads=num_heads) + for _ in range(num_blocks) + ] + ) + self.project_out = nn.Linear(hidden_dim, output_dim) + + def forward(self, x, mask): + x = self.project_in(x) + for i, block in enumerate(self.blocks): + x = block(x, mask=mask) + if self.skip_out_proj: + return x * mask.unsqueeze(-1) + x = self.project_out(x) * mask.unsqueeze(-1) + return x + + +class NormformerStack(torch.nn.Module): + def __init__( + self, + hidden_dim, + num_heads=1, + num_blocks=2, + skip_out_proj=False, + dropout_rate=0.1, + ): + super().__init__() + + self.num_blocks = num_blocks + self.skip_out_proj = skip_out_proj + self.hidden_dim = hidden_dim + self.num_heads = num_heads + self.dropout_rate = dropout_rate + + self.blocks = nn.ModuleList( + [ + NormformerBlock( + input_dim=self.hidden_dim, + mlp_dim=self.hidden_dim, + num_heads=self.num_heads, + dropout_rate=self.dropout_rate, + ) + for _ in range(num_blocks) + ] + ) + + def forward(self, x, mask): + for i, block in enumerate(self.blocks): + x = block(x, mask=mask) + return x * mask.unsqueeze(-1) + + +class VQVAETransformer(torch.nn.Module): + def __init__( + self, + input_dim, + latent_dim, + hidden_dim, + num_heads=1, + num_blocks=2, + vq_kwargs={}, + **kwargs, + ): + super().__init__() + + self.vq_kwargs = vq_kwargs + self.latent_dim = latent_dim + + self.encoder = Transformer( + input_dim=input_dim, + output_dim=latent_dim, + hidden_dim=hidden_dim, + num_heads=num_heads, + num_blocks=num_blocks, + ) + self.vqlayer = VectorQuant(feature_size=latent_dim, **vq_kwargs) + self.decoder = Transformer( + input_dim=latent_dim, + output_dim=input_dim, + hidden_dim=hidden_dim, + num_heads=num_heads, + num_blocks=num_blocks, + ) + self.loss_history = [] + self.lr_history = [] + + def forward(self, x, mask): + # encode + x = self.encoder(x, mask=mask) + z_embed = x * mask.unsqueeze(-1) + # quantize + z, vq_out = self.vqlayer(z_embed) + # decode + x_reco = self.decoder(z, mask=mask) + return x_reco, vq_out + + +class VQVAENormFormer(torch.nn.Module): + """This is basically just a re-factor of the VQVAETransformer class, but with more modular + model components, making it easier to use some components in other models.""" + + def __init__( + self, + input_dim, + latent_dim, + hidden_dim, + num_heads=1, + num_blocks=2, + vq_kwargs={}, + **kwargs, + ): + super().__init__() + + self.loss_history = [] + self.lr_history = [] + + self.vq_kwargs = vq_kwargs + self.input_dim = input_dim + self.latent_dim = latent_dim + self.hidden_dim = hidden_dim + self.num_heads = num_heads + self.num_blocks = num_blocks + + # Model components: + self.input_projection = nn.Linear(self.input_dim, self.hidden_dim) + self.encoder_normformer = NormformerStack( + hidden_dim=self.hidden_dim, + num_heads=self.num_heads, + num_blocks=self.num_blocks, + ) + self.latent_projection_in = nn.Linear(self.hidden_dim, self.latent_dim) + self.vqlayer = VectorQuant(feature_size=self.latent_dim, **vq_kwargs) + self.latent_projection_out = nn.Linear(self.latent_dim, self.hidden_dim) + self.decoder_normformer = NormformerStack( + hidden_dim=self.hidden_dim, + num_heads=self.num_heads, + num_blocks=self.num_blocks, + ) + self.output_projection = nn.Linear(hidden_dim, input_dim) + + def forward(self, x, mask): + # encode + x = self.input_projection(x) + x = self.encoder_normformer(x, mask=mask) + z_embed = self.latent_projection_in(x) * mask.unsqueeze(-1) + # quantize + z, vq_out = self.vqlayer(z_embed) + # decode + x_reco = self.latent_projection_out(z) * mask.unsqueeze(-1) + x_reco = self.decoder_normformer(x_reco, mask=mask) + x_reco = self.output_projection(x_reco) * mask.unsqueeze(-1) + return x_reco, vq_out + + +class VQVAELightning(L.LightningModule): + """PyTorch Lightning module for training a VQ-VAE.""" + + def __init__( + self, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler = None, + model_kwargs={}, + model_type="Transformer", + **kwargs, + ) -> None: + super().__init__() + self.save_hyperparameters(logger=False) + + # --------------- load pretrained model --------------- # + # if kwargs.get("load_pretrained", False): + if model_type == "MLP": + self.model = VQVAEMLP(**model_kwargs) + elif model_type == "Transformer": + self.model = VQVAETransformer(**model_kwargs) + elif model_type == "VQVAENormFormer": + self.model = VQVAENormFormer(**model_kwargs) + else: + raise ValueError(f"Unknown model type: {model_type}") + + self.train_loss_history = [] + self.val_loss_list = [] + + self.validation_cnt = 0 + self.validation_output = {} + + # loss function + self.criterion = torch.nn.MSELoss() + + # for tracking best so far validation accuracy + self.val_x_original = [] + self.val_x_reco = [] + self.val_mask = [] + + def forward( + self, + x_particle, + mask_particle, + ): + x_particle_reco, vq_out = self.model(x_particle, mask=mask_particle) + return x_particle_reco, vq_out + + def model_step(self, batch, return_x=False): + """Perform a single model step on a batch of data.""" + + # x_particle, mask_particle, labels = batch + x_particle = batch["part_features"] + mask_particle = batch["part_mask"] + # labels = batch["shower_type_labels"] + + x_particle_reco, vq_out = self.forward(x_particle, mask_particle) + + reco_loss = torch.sum( + ( + x_particle_reco * mask_particle.unsqueeze(-1) + - x_particle * mask_particle.unsqueeze(-1) + ) + ** 2 + ) / torch.sum(mask_particle) + + alpha = self.hparams["model_kwargs"]["alpha"] + cmt_loss = vq_out["loss"] + code_idx = vq_out["q"] + loss = reco_loss + alpha * cmt_loss + self.log("train_cmt_loss", cmt_loss.item(), on_step=True, on_epoch=True, prog_bar=True) + # logger.info(f"loss = reco_loss + alpha * cmt_loss: {loss} = {reco_loss} + {alpha} * {cmt_loss}") + + if return_x: + return loss, x_particle, x_particle_reco, mask_particle, code_idx # labels + + return loss + + def training_step(self, batch, batch_idx: int) -> torch.Tensor: + """Perform a single training step on a batch of data from the training set.""" + loss = self.model_step(batch) + + self.train_loss_history.append(float(loss)) + self.log("train_loss", loss.item(), on_step=True, on_epoch=True, prog_bar=True) + + return loss + + def on_train_start(self) -> None: + self.preprocessing_dict = ( + self.trainer.datamodule.hparams.dataset_kwargs_common.feature_dict + ) + logger.info(f"Preprocessing dict: {self.preprocessing_dict}") + + def on_train_epoch_start(self): + logger.info(f"Epoch {self.trainer.current_epoch} starting.") + self.epoch_train_start_time = time.time() # start timing the epoch + + def on_train_epoch_end(self): + self.epoch_train_end_time = time.time() + self.epoch_train_duration_minutes = ( + self.epoch_train_end_time - self.epoch_train_start_time + ) / 60 + self.log( + "epoch_train_duration_minutes", + self.epoch_train_duration_minutes, + on_epoch=True, + prog_bar=False, + ) + logger.info( + f"Epoch {self.trainer.current_epoch} finished in" + f" {self.epoch_train_duration_minutes:.1f} minutes." + ) + + def on_train_end(self): + pass + + def on_validation_epoch_start(self) -> None: + logger.info("on_validation_epoch_start called.") + self.val_x_original = [] + self.val_x_reco = [] + self.val_mask = [] + # self.val_labels = [] + self.val_code_idx = [] + + def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None: + # print("on_validation_epoch_step called.") + loss, x_original, x_reco, mask, code_idx = self.model_step(batch, return_x=True) + + # save the original and reconstructed data + self.val_x_original.append(x_original.detach().cpu().numpy()) + self.val_x_reco.append(x_reco.detach().cpu().numpy()) + self.val_mask.append(mask.detach().cpu().numpy()) + # self.val_labels.append(labels.detach().cpu().numpy()) + self.val_code_idx.append(code_idx.detach().cpu().numpy()) + + self.log("val_loss", loss.item(), on_step=True, on_epoch=True, prog_bar=True) + + # for the first validation step, plot the model + if batch_idx == 0: + # get loggers + comet_logger = None + for logger in self.trainer.loggers: + if isinstance(logger, L.pytorch.loggers.CometLogger): + comet_logger = logger.experiment + + curr_epoch, curr_step = self.trainer.current_epoch, self.trainer.global_step + + plot_dir = Path(self.trainer.default_root_dir + "/plots/") + plot_dir.mkdir(exist_ok=True) + plot_filename = f"{plot_dir}/epoch{curr_epoch}_gstep{curr_step}_original_vs_reco.png" + # log the plot + plot_model( + self.model, + samples=batch["part_features"], + masks=batch["part_mask"], + device=self.device, + saveas=plot_filename, + ) + if comet_logger is not None: + comet_logger.log_image( + plot_filename, name=plot_filename.split("/")[-1], step=curr_step + ) + + return loss + + def on_test_epoch_start(self) -> None: + logger.info("on_test_epoch_start called.") + self.test_x_original = [] + self.test_x_reco = [] + self.test_mask = [] + # self.test_labels = [] + self.test_code_idx = [] + + def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None: + logger.info("Test step called.") + loss, x_original, x_reco, mask, code_idx = self.model_step(batch, return_x=True) + + # save the original and reconstructed data + self.test_x_original.append(x_original.detach().cpu().numpy()) + self.test_x_reco.append(x_reco.detach().cpu().numpy()) + self.test_mask.append(mask.detach().cpu().numpy()) + # self.test_labels.append(labels.detach().cpu().numpy()) + self.test_code_idx.append(code_idx.detach().cpu().numpy()) + + self.log("test_loss", loss.item(), on_step=True, on_epoch=True, prog_bar=True) + + def tokenize_ak_array(self, ak_arr, pp_dict, batch_size=256, pad_length=128, hide_pbar=False): + """Tokenize an awkward array of showers. + + Parameters + ---------- + ak_arr : ak.Array + Awkward array of showers, shape (N_showers, , N_features). + pp_dict : dict + Dictionary with preprocessing information. + batch_size : int, optional + Batch size for the evaluation loop. The default is 256. + pad_length : int, optional + Length to which the tokens are padded. The default is 128. + hide_pbar : bool, optional + Whether to hide the progress bar. The default is False. + + Returns + ------- + ak.Array + Awkward array of tokens, shape (N_showers, ). + """ + + # preprocess the ak_arrary + ak_arr = ak_select_and_preprocess(ak_arr, pp_dict=pp_dict) + ak_arr_padded, mask = ak_pad(ak_arr, maxlen=pad_length, return_mask=True) + # convert to numpy + arr = ak_to_np_stack(ak_arr_padded, names=pp_dict.keys()) + # convert to torch tensor + x = torch.from_numpy(arr).float() + mask = torch.from_numpy(mask.to_numpy()).float() + + codes = [] + dataset = TensorDataset(x, mask) + dataloader = DataLoader(dataset, batch_size=batch_size) + + with torch.no_grad(): + if not hide_pbar: + pbar = tqdm(dataloader) + else: + pbar = dataloader + for i, (x_batch, mask_batch) in enumerate(pbar): + # move to device + x_batch = x_batch.to(self.device) + mask_batch = mask_batch.to(self.device) + x_particle_reco, vq_out = self.forward(x_batch, mask_batch) + code = vq_out["q"] + codes.append(code) + codes = torch.cat(codes, dim=0).detach().cpu().numpy() + mask = mask.detach().cpu().numpy() + tokens = np_to_ak(codes, names=["token"], mask=mask)["token"] + return tokens + + def tokenize_shower_ak_array( + self, ak_arr, pp_dict, batch_size=32, pad_length=1700, hide_pbar=False + ): + """Tokenize an awkward array of showers. + + Parameters + ---------- + ak_arr : ak.Array + Awkward array of showers, shape (N_showers, , N_features). + pp_dict : dict + Dictionary with preprocessing information. + batch_size : int, optional + Batch size for the evaluation loop. The default is 64. + pad_length : int, optional + Length to which the tokens are padded. The default is 500. + hide_pbar : bool, optional + Whether to hide the progress bar. The default is False. + + Returns + ------- + ak.Array + Awkward array of tokens, shape (N_showers, ). + """ + + # preprocess the ak_arrary + ak_arr, mask = ak_padding(ak_arr, maxlen=pad_length, energy_threshold=0) + ak_arr = ak_preprocess(ak_arr, pp_dict=pp_dict) + ak_arr["energy"] = ak.where( + ak.to_numpy(ak_arr["energy"]) == -np.inf, + -10000000000, + ak_arr["energy"], + ) + + # convert to numpy + arr = ak_to_np_stack(ak_arr, names=pp_dict.keys()) + # convert to torch tensor + x = torch.from_numpy(arr).float() + mask = torch.from_numpy(mask.to_numpy()).float() + + codes = [] + dataset = TensorDataset(x, mask) + dataloader = DataLoader(dataset, batch_size=batch_size) + + with torch.no_grad(): + if not hide_pbar: + pbar = tqdm(dataloader) + else: + pbar = dataloader + for i, (x_batch, mask_batch) in enumerate(pbar): + # move to device + # allocated_memory = torch.cuda.memory_allocated() / 1024**2 # in MB + # reserved_memory = torch.cuda.memory_reserved() / 1024**2 # in MB + # print(f"Iteration {i}: Allocated Memory: {allocated_memory:.2f} MB, Reserved Memory: {reserved_memory:.2f} MB") + x_batch = x_batch.to(self.device) + mask_batch = mask_batch.to(self.device) + x_particle_reco, vq_out = self.forward(x_batch, mask_batch) + code = vq_out["q"] + code = code.detach().cpu() + codes.append(code) + del code + + torch.cuda.empty_cache() + codes = torch.cat(codes, dim=0).detach().cpu().numpy() + mask = mask.detach().cpu().numpy() + tokens = np_to_ak(codes, names=["token"], mask=mask)["token"] + return tokens + + def reconstruct_ak_tokens( + self, tokens_ak, pp_dict, batch_size=256, pad_length=128, hide_pbar=False + ): + """Reconstruct tokenized awkward array. + + Parameters + ---------- + tokens_ak : ak.Array + Awkward array of tokens, shape (N_showers, ). + pp_dict : dict + Dictionary with preprocessing information. + batch_size : int, optional + Batch size for the evaluation loop. The default is 256. + pad_length : int, optional + Length to which the tokens are padded. The default is 128. + hide_pbar : bool, optional + Whether to hide the progress bar. The default is False. + + Returns + ------- + ak.Array + Awkward array of reconstructed showers, shape (N_showers, , N_features). + """ + + self.model.eval() + + tokens, mask = ak_pad(tokens_ak, maxlen=pad_length, return_mask=True) + tokens = torch.from_numpy(tokens.to_numpy()).long() + mask = torch.from_numpy(mask.to_numpy()).float() + + x_reco = [] + dataset = TensorDataset(tokens, mask) + dataloader = DataLoader(dataset, batch_size=batch_size) + + codebook = self.model.vqlayer.codebook.weight + + # if the codebook has an affine transform, apply it + # before using it to reconstruct the data + # see https://github.com/minyoungg/vqtorch/blob/main/vqtorch/nn/vq.py#L102-L104 + if hasattr(self.model.vqlayer, "affine_transform"): + codebook = self.model.vqlayer.affine_transform(codebook) + + last_batch = None + with torch.no_grad(): + if not hide_pbar: + pbar = tqdm(dataloader) + else: + pbar = dataloader + for i, (tokens_batch, mask_batch) in enumerate(pbar): + # move to device + tokens_batch = tokens_batch.to(self.device) + mask_batch = mask_batch.to(self.device) + try: + z_q = F.embedding(tokens_batch, codebook) + except Exception as e: # noqa: E722 + print(f"Error in embedding: {e}") + print("batch shape", tokens_batch.shape) + print("batch max", tokens_batch.max()) + print("batch min", tokens_batch.min()) + + if last_batch is not None: + break + + if hasattr(self.model, "latent_projection_out"): + x_reco_batch = self.model.latent_projection_out(z_q) * mask_batch.unsqueeze(-1) + x_reco_batch = self.model.decoder_normformer(x_reco_batch, mask=mask_batch) + x_reco_batch = self.model.output_projection( + x_reco_batch + ) * mask_batch.unsqueeze(-1) + elif hasattr(self.model, "decoder"): + x_reco_batch = self.model.decoder(z_q) + else: + raise ValueError("Unknown model structure. Cannot reconstruct.") + x_reco.append(x_reco_batch) + + x_reco = torch.cat(x_reco, dim=0).detach().cpu().numpy() + x_reco_ak = np_to_ak(x_reco, names=pp_dict.keys(), mask=mask.detach().cpu().numpy()) + x_reco_ak = ak_select_and_preprocess(x_reco_ak, pp_dict, inverse=True) + + return x_reco_ak + + def reconstruct_shower_ak_tokens( + self, tokens_ak, pp_dict, batch_size=32, pad_length=1700, hide_pbar=False + ): + """Reconstruct tokenized awkward array. + + Parameters + ---------- + tokens_ak : ak.Array + Awkward array of tokens, shape (N_showers, ). + pp_dict : dict + Dictionary with preprocessing information. + batch_size : int, optional + Batch size for the evaluation loop. The default is 256. + pad_length : int, optional + Length to which the tokens are padded. The default is 128. + hide_pbar : bool, optional + Whether to hide the progress bar. The default is False. + + Returns + ------- + ak.Array + Awkward array of reconstructed showers, shape (N_showers, , N_features). + """ + + self.model.eval() + + # preprocess the tokens + tokens, mask = ak_pad(tokens_ak, maxlen=pad_length, return_mask=True) + tokens = torch.from_numpy(tokens.to_numpy()).long() + mask = torch.from_numpy(mask.to_numpy()).float() + + x_reco = [] + dataset = TensorDataset(tokens, mask) + dataloader = DataLoader(dataset, batch_size=batch_size) + + codebook = self.model.vqlayer.codebook.weight + + # if the codebook has an affine transform, apply it + # before using it to reconstruct the data + # see https://github.com/minyoungg/vqtorch/blob/main/vqtorch/nn/vq.py#L102-L104 + if hasattr(self.model.vqlayer, "affine_transform"): + print("applying affine transform") + codebook = self.model.vqlayer.affine_transform(codebook) + print("applied affine transform") + + last_batch = None + with torch.no_grad(): + if not hide_pbar: + pbar = tqdm(dataloader) + else: + pbar = dataloader + for i, (tokens_batch, mask_batch) in enumerate(pbar): + # move to device + tokens_batch = tokens_batch.to(self.device) + mask_batch = mask_batch.to(self.device) + # allocated_memory = torch.cuda.memory_allocated() / 1024**2 # in MB + # reserved_memory = torch.cuda.memory_reserved() / 1024**2 # in MB + # print(f"Iteration {i}: Allocated Memory: {allocated_memory:.2f} MB, Reserved Memory: {reserved_memory:.2f} MB") + + try: + z_q = F.embedding(tokens_batch, codebook) + except Exception as e: # noqa: E722 + print(f"Error in embedding: {e}") + print("batch shape", tokens_batch.shape) + print("batch max", tokens_batch.max()) + print("batch min", tokens_batch.min()) + + if last_batch is not None: + break + + if hasattr(self.model, "latent_projection_out"): + x_reco_batch = self.model.latent_projection_out(z_q) * mask_batch.unsqueeze(-1) + x_reco_batch = self.model.decoder_normformer(x_reco_batch, mask=mask_batch) + x_reco_batch = self.model.output_projection( + x_reco_batch + ) * mask_batch.unsqueeze(-1) + elif hasattr(self.model, "decoder"): + x_reco_batch = self.model.decoder(z_q) + else: + raise ValueError("Unknown model structure. Cannot reconstruct.") + x_reco_batch = x_reco_batch.detach().cpu() + x_reco.append(x_reco_batch) + del x_reco_batch + + x_reco = torch.cat(x_reco, dim=0).detach().cpu().numpy() + x_reco_ak = np_to_ak(x_reco, names=pp_dict.keys(), mask=mask.detach().cpu().numpy()) + x_reco_ak = ak_preprocess(x_reco_ak, pp_dict, inverse=True) + + return x_reco_ak + + def on_validation_end(self) -> None: + logger.info("`on_validation_end` called.") + """Lightning hook that is called when a validation epoch ends.""" + # self.concat_validation_loop_predictions() + + def concat_validation_loop_predictions(self) -> None: + self.val_x_original_concat = np.concatenate(self.val_x_original) + self.val_x_reco_concat = np.concatenate(self.val_x_reco) + self.val_mask_concat = np.concatenate(self.val_mask) + self.val_code_idx_concat = np.concatenate(self.val_code_idx) + + def concat_test_loop_predictions(self) -> None: + self.test_x_original_concat = np.concatenate(self.test_x_original) + self.test_x_reco_concat = np.concatenate(self.test_x_reco) + self.test_mask_concat = np.concatenate(self.test_mask) + self.test_code_idx_concat = np.concatenate(self.test_code_idx) + + def on_test_end(self): + logger.info("`on_test_end` called.") + # self.concat_test_loop_predictions() + + def configure_optimizers(self) -> Dict[str, Any]: + """Configures optimizers and learning-rate schedulers to be used for training. + + Normally you'd need one, but in the case of GANs or similar you might need multiple. + + Examples: + https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers + + :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training. + """ + optimizer = self.hparams.optimizer(params=self.parameters()) + if self.hparams.scheduler is not None: + scheduler = self.hparams.scheduler(optimizer=optimizer) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "monitor": "val_loss", + "interval": "epoch", + "frequency": 1, + }, + } + + return {"optimizer": optimizer} + + +def sample_shower_constituents(data_constituents, N, data_mask=None): + """Sample N shower constituents from the data. + + Parameters + ---------- + data_constituents : Tensor + Tensor of shape (N_constituents, N_features) containing the shower constituents. + N : int + Number of constituents to sample. + """ + idx = torch.randint(0, len(data_constituents), size=(N,)) + if data_mask is not None: + return data_constituents[idx], data_mask[idx] + return data_constituents[idx] + + +def vqvae_loss(model, x, mask, factor_codebook=0.1, factor_commit=0.05): + """Loss function. + + Parameters + ---------- + model : nn.Module + The model. + x : Tensor + The input data. + factor_codebook : float, optional + Weight for the codebook learning loss. The default is 0.1. + factor_commit : float, optional + Weight for the commitment loss. The default is 0.05. + """ + r, z_e, z_q1, idx = model(x, mask=mask) + # create detached copies of the embeddings and codes to calculate the codebook + # learning loss and the commitment loss + z_e_stopped = z_e.clone().detach() + z_q_stopped = z_q1.clone().detach() + + if mask is not None: + z_e_stopped = z_e_stopped[mask == 1] + z_q_stopped = z_q_stopped[mask == 1] + r = r[mask == 1] + x = x[mask == 1] + z_e = z_e[mask == 1] + z_q1 = z_q1[mask == 1] + idx = idx[mask == 1] + + # fmt: off + # reconstruction loss: MSE between input and reconstructed data + mse_reco = torch.nn.functional.mse_loss(r, x) # get good reco + # codebook learning loss: MSE between embeddings and codes + # --> moves the codebook vectors towards the "fixed" embeddings + mse_cbk = torch.nn.functional.mse_loss(z_e_stopped, z_q1) + # commitment loss: MSE between embeddings and detached codes + # --> moves the embeddings towards the "fixed" codebook vectors + mse_commit = torch.nn.functional.mse_loss(z_e, z_q_stopped) + # fmt: on + + # the total loss is a weighted sum of the three losses + loss = mse_reco + factor_codebook * mse_cbk + factor_commit * mse_commit + # loss = mse_reco # + factor_codebook * mse_cbk + factor_commit * mse_commit + return loss + + +def plot_model(model, samples, device="cuda", n_examples_to_plot=2000, masks=None, saveas=None): + """Visualize the model. + + Parameters + ---------- + model : nn.Module + The model. + samples : Tensor + The input data. + device : str, optional + Device to use. The default is "cuda". + n_examples_to_plot : int, optional + Number of examples to plot. The default is 200. + """ + + samples = samples.to(device) + model = model.to(device) + + # run the model on the input data + with torch.no_grad(): + # print(f"Model device: {next(model.parameters()).device}") + # print(f"Samples device: {samples.device}") + r, vq_out = model(samples, masks) + z_q = vq_out["z_q"] + z_e = vq_out["z"] + idx = vq_out["q"] + + if masks is not None: + r = r[masks == 1] + z_e = z_e[masks == 1] + z_q = z_q[masks == 1] + idx = idx[masks == 1] + + z_e = z_e.squeeze(1) + z_q = z_q.squeeze(1) + idx = idx.squeeze(1) + + # move r, z_e, z_q, idx to cpu for plotting + r = r.detach().cpu() + z_e = z_e.detach().cpu() + z_q = z_q.detach().cpu() + idx = idx.detach().cpu() + + samples = samples.detach().cpu().numpy() + if masks is not None: + masks = masks.detach().cpu().numpy() + samples = samples[masks == 1] + + # create detached copy of the codebook to plot this + fig, axarr = plt.subplots(1, 5, figsize=(15, 3)) + # axarr = axarr.flatten() + + style_tokens = dict(color="forestgreen") + style_true = dict(color="royalblue") + style_tokens_emb = dict(color="darkorange") + style_true_emb = dict(color="darkorchid") + + ax = axarr[0] + ax.scatter( + z_e[:n_examples_to_plot, 0], + z_e[:n_examples_to_plot, 1], + alpha=0.4, + marker="o", + label="Samples", + **style_true_emb, + ) + ax.scatter( + z_q[:n_examples_to_plot, 0], + z_q[:n_examples_to_plot, 1], + alpha=0.6, + marker="x", + label="Closest tokens", + **style_tokens_emb, + ) + ax.set_xlabel("$e_1$") + ax.set_ylabel("$e_2$") + ax.legend(loc="upper right") + ax.set_title("Embeddings \n(samples and closest tokens)") + + ax = axarr[1] + ax.scatter( + z_e[:n_examples_to_plot, 0], + z_e[:n_examples_to_plot, 2], + alpha=0.2, + s=26, + **style_true_emb, + label="Samples", + ) + ax.scatter( + z_q[:n_examples_to_plot, 0], + z_q[:n_examples_to_plot, 2], + alpha=0.7, + s=26, + **style_tokens_emb, + marker="x", + label="Closest tokens", + ) + ax.set_xlabel("$e_1$") + ax.set_ylabel("$e_3$") + ax.set_title("Embeddings \n(samples and closest token)") + ax.legend(loc="upper right") + + # plot the original sample and the reconstructed sample (the first sample in the batch) + # plot original sample + ax = axarr[2] + ax.scatter( + samples[:n_examples_to_plot, 0], + samples[:n_examples_to_plot, 1], + alpha=0.2, + s=26, + **style_true, + label="Original", + ) + ax.set_xlabel("$x$") + ax.set_ylabel("$y$") + ax.set_title("Original constituents \n(first few in batch)") + # plot reconstructed sample + ax.scatter( + r[:n_examples_to_plot, 0], + r[:n_examples_to_plot, 1], + alpha=0.7, + s=26, + marker="x", + **style_tokens, + label="Reco. token", + ) + ax.set_xlabel("$x$") + ax.set_ylabel("$y$") + ax.set_title("Data space \nTrue vs reconstructed") + ax.legend(loc="upper right") + + # plot true vs reconstructed for deltaR and ptrel + ax = axarr[3] + ax.scatter( + samples[:n_examples_to_plot, 0], + samples[:n_examples_to_plot, 2], + s=26, + alpha=0.2, + **style_true, + label="Original", + ) + ax.scatter( + r[:n_examples_to_plot, 0], + r[:n_examples_to_plot, 2], + s=26, + alpha=0.7, + **style_tokens, + marker="x", + label="Reco. tokens", + ) + ax.set_xlabel("$x$") + ax.set_ylabel("$z$") + ax.legend(loc="upper right") + ax.set_title("Data space \nTrue vs reconstructed") + + # plot the histogram of the codebook indices (i.e. a codebook_size x codebook_size + # histogram with each entry in the histogram corresponding to one sample associated + # with the corresponding codebook entry) + ax = axarr[4] + n_codes = model.vq_kwargs["num_codes"] + bins = np.linspace(-0.5, n_codes + 0.5, n_codes + 1) + ax.hist(idx, bins=bins) + ax.set_title( + "Codebook histogram\n(Each entry corresponds to one sample\nbeing associated with that" + " codebook entry)", + fontsize=8, + ) + + # make empty axes invisible + def is_axes_empty(ax): + return not ( + ax.lines + or ax.patches + or ax.collections + or ax.images + or ax.texts + or ax.artists + or ax.tables + ) + + for ax in axarr.flatten(): + if is_axes_empty(ax): + ax.set_visible(False) + + fig.tight_layout() + plt.show() + if saveas is not None: + fig.savefig(saveas) + + +def plot_loss(loss_history, lr_history, moving_average=100): + if len(loss_history) < moving_average: + print("Not enough steps to plot loss history") + return + fig, ax1 = plt.subplots(figsize=(5, 2)) + ax2 = ax1.twinx() + + # Plot loss history + loss_history = np.array(loss_history) + loss_history = np.convolve(loss_history, np.ones(moving_average), "valid") / moving_average + ax1.plot(loss_history, color="blue") + ax1.set_xlabel("Step") + ax1.set_ylabel("Loss") + ax1.set_yscale("log") + ax1.grid(True, which="both", ls="-", alpha=0.5) + ax1.set_title(f"Loss history (moving average over {moving_average} steps)", fontsize=8) + + # Plot lr history + ax2.plot(lr_history, color="red") + ax2.set_ylabel("Learning Rate") + + fig.tight_layout() + plt.show() + + +def train_vae( + model, + data, + masks=None, + n_steps=10_000, + plot_every=100, + batch_size=256, + loss_moving_avg=100, + lr=5e-3, + device="cuda", + reduce_lr_after=500, + lr_decay_rate=0.999, + factor_codebook=0.1, + factor_commit=0.05, +): + # move model and data to device + model = model.to(device) + data = data.to(device) + + # initialize optimizer + opt = optim.Adam(model.parameters(), lr) + # lr scheduler + # scheduler = optim.lr_scheduler.StepLR(opt, reduce_lr_after, gamma) + # scheduler = optim.lr_scheduler.ExponentialLR(opt, gamma=lr_decay_rate) + model.loss_history = [] + model.lr_history = [] + pbar = tqdm(range(n_steps)) + + # Automatic Mixed Precision + # scaler = torch.cuda.amp.GradScaler(enabled=device == "cuda") + + # training loop + for i in pbar: + # plot the model every plot_every steps (before the update to also see the + if masks is not None: + samples, masks = sample_shower_constituents(data, batch_size, data_mask=masks) + masks = masks.to(device) + else: + samples = sample_shower_constituents(data, batch_size) + samples.to(device) + + # plot + # if i % plot_every == 0: + # plot_loss(model.loss_history, model.lr_history, loss_moving_avg) + + model.zero_grad() + opt.zero_grad() + + # with torch.cuda.amp.autocast(enabled=device == "cuda"): + loss = vqvae_loss(model, samples, masks, factor_codebook, factor_commit) + loss.backward() + opt.step() + # scaler.scale(loss).backward() + # scaler.step(opt) + # scaler.update() + + # start with lr decay after reduce_lr_after steps + # if i > reduce_lr_after: + # scheduler.step() + # model.loss_history.append(float(loss)) + # model.lr_history.append(scheduler.get_last_lr()[0]) + # if i > loss_moving_avg: + # loss_avg = np.mean(model.loss_history[-loss_moving_avg:]) + # else: + # loss_avg = np.mean(model.loss_history) + # pbar.set_description(f"Step {i}, loss: {loss_avg:.5f}") + + return model diff --git a/gabbro/plotting/__init__.py b/gabbro/plotting/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gabbro/plotting/feature_plotting.py b/gabbro/plotting/feature_plotting.py new file mode 100644 index 0000000..eda03ce --- /dev/null +++ b/gabbro/plotting/feature_plotting.py @@ -0,0 +1,1826 @@ +import json +import math + +import awkward as ak +import matplotlib as mpl +import matplotlib.pyplot as plt +import matplotlib.ticker as ticker +import numpy as np +import pandas as pd +import seaborn as sns +import vector +from matplotlib.gridspec import GridSpec +from matplotlib.lines import Line2D +from scipy.stats import wasserstein_distance + +import gabbro.plotting.utils as plot_utils +from gabbro.metrics.utils import quantiled_kl_divergence +from gabbro.plotting.utils import plot_ratios +from gabbro.utils.utils import ( + KL, + find_max_energy_z, + get_COG_ak, + sum_energy_per_layer, + sum_energy_per_radial_distance, + write_distances_to_json, +) + +vector.register_awkward() + + +def binclip(x, bins, dropinf=False): + binfirst_center = bins[0] + (bins[1] - bins[0]) / 2 + binlast_center = bins[-2] + (bins[-1] - bins[-2]) / 2 + if dropinf: + print("Dropping inf") + print("len(x) before:", len(x)) + x = x[~np.isinf(x)] + print("len(x) after:", len(x)) + return np.clip(x, binfirst_center, binlast_center) + + +def get_bin_centers_and_bin_heights_from_hist(hist): + """Return the bin centers and bin heights from a histogram. + + Parameters + ---------- + hist : tuple + The output of matplotlib hist. + + Returns + ------- + bin_centers : array-like + The bin centers. + bin_heights : array-like + The bin heights. + """ + bin_centers = (hist[1][:-1] + hist[1][1:]) / 2 + bin_heights = hist[0] + return bin_centers, bin_heights + + +def plot_hist_with_ratios( + comp_dict: dict, + bins: np.ndarray, + ax_upper: plt.Axes, + ax_ratio: plt.Axes = None, + ref_dict: dict = None, + ratio_range: tuple = None, + xlabel: str = None, + logy: bool = False, + leg_loc: str = "best", + underoverflow: bool = True, + leg_title: str = None, + leg_ncols: int = 1, + return_hist_curve: bool = False, +): + """Plot histograms of the reference and comparison arrays, and their ratio. + + Parameters: + ---------- + ax_upper : plt.Axes + Axes for the upper panel. + ax_ratio : plt.Axes + Axes for the ratio panel. + ref_dict : dict + Dict with {id: {"arr": ..., "hist_kwargs": ...}, ...} of the reference array. + comp_dict : dict + Dict with {id: {"arr": ..., "hist_kwargs": ...}, ...} of the comparison arrays. + bins : np.ndarray + Bin edges for the histograms. + ratio_range : tuple, optional + Range of the y-axis for the ratio plot. + xlabel : str, optional + Label for the x-axis. + logy : bool, optional + Whether to plot the y-axis in log scale. + leg_loc : str, optional + Location of the legend. + underoverflow : bool, optional + Whether to include underflow and overflow bins. Default is True. + leg_title : str, optional + Title of the legend. + leg_ncols : int, optional + Number of columns in the legend. Default is 1. + return_hist_curve : bool, optional + Whether to return the histogram curves in a dict. Default is False. + + Returns + ------- + hist_curve_dict : dict + Dict with {id: (bin_centers, bin_heights), ...} of the histogram curves. + Only returned if `return_hist_curve` is True. Both bin_centers and bin_heights + are array-like. + """ + + legend_handles = [] + hist_curve_dict = {} + + if ref_dict is not None: + ref_arr = list(ref_dict.values())[0] + ref_label = list(ref_dict.keys())[0] + kwargs_ref = dict(histtype="stepfilled", color="k", alpha=0.25, label=ref_label) + + if leg_title is not None: + # plot empty array with alpha 0 to create a legend entry + ax_upper.hist([], alpha=0, label=leg_title) + + kwargs_common = dict(bins=bins, density=True) + if ref_dict is not None: + hist_ref = ax_upper.hist(binclip(ref_arr["arr"], bins), **kwargs_common, **kwargs_ref) + + if ax_ratio is not None: + ax_ratio.axhline(1, color="black", linestyle="--", lw=1) + + # loop over entries in comp_dict and plot them + for i, (arr_id, arr_dict) in enumerate(comp_dict.items()): + kwargs_comp = dict(histtype="step") | arr_dict.get("hist_kwargs", {}) + if "linestyle" in kwargs_comp: + if kwargs_comp["linestyle"] == "dotted": + kwargs_comp["linestyle"] = plot_utils.get_good_linestyles("densely dotted") + hist_comp = ax_upper.hist(binclip(arr_dict["arr"], bins), **kwargs_common, **kwargs_comp) + if return_hist_curve: + hist_curve_dict[arr_id] = get_bin_centers_and_bin_heights_from_hist(hist_comp) + legend_handles.append( + Line2D( + [], + [], + color=kwargs_comp.get("color", "C1"), + lw=kwargs_comp.get("lw", 1), + label=kwargs_comp.get("label", arr_id), + linestyle=kwargs_comp.get("linestyle", "-"), + ) + ) + if ax_ratio is not None: + # calculate and plot ratio + ratio = hist_comp[0] / hist_ref[0] + # duplicate the first entry to avoid a gap in the plot (due to step plot) + ratio = np.append(np.array(ratio[0]), np.array(ratio)) + bin_edges = hist_ref[1] + ax_ratio.step(bin_edges, ratio, where="pre", **arr_dict.get("hist_kwargs", {})) + + ax_upper.legend( + # handles=legend_handles, + loc=leg_loc, + frameon=False, + title=leg_title, + ncol=leg_ncols, + ) + # re-do legend, with the first handle kep and the others replaced by the new list + old_handles, old_labels = ax_upper.get_legend_handles_labels() + new_handles = old_handles[:1] + legend_handles if ref_dict is not None else legend_handles + ax_upper.legend( + handles=new_handles, + loc=leg_loc, + frameon=False, + title=leg_title, + ncol=leg_ncols, + ) + ax_upper.set_ylabel("Normalized") + + ax_upper.set_xlim(bins[0], bins[-1]) + + if ax_ratio is not None: + ax_ratio.set_xlim(bins[0], bins[-1]) + ax_upper.set_xticks([]) + + if ratio_range is not None: + ax_ratio.set_ylim(*ratio_range) + if xlabel is not None: + if ax_ratio is not None: + ax_ratio.set_xlabel(xlabel) + else: + ax_upper.set_xlabel(xlabel) + if logy: + ax_upper.set_yscale("log") + return hist_curve_dict if return_hist_curve else None + + +def plot_two_shower_versions(const1, const2, label1="version1", label2="version2", title=None): + """Plot the constituent and shower features for two shower collections. + + Parameters: + ---------- + const1 : awkward array + Constituents of the first shower collection. + const2 : awkward array + Constituents of the second shower collection. + title : str, optional + Title of the plot. + """ + + showers1 = ak.sum(const1, axis=1) + showers2 = ak.sum(const2, axis=1) + + fig, axarr = plt.subplots(4, 4, figsize=(12, 8)) + histkwargs = dict(bins=100, density=True, histtype="step") + + part_feats = ["pt", "eta", "phi", "mass"] + for i, feat in enumerate(part_feats): + axarr[0, i].hist(ak.flatten(const1[feat]), **histkwargs, label=label1) + axarr[0, i].hist(ak.flatten(const2[feat]), **histkwargs, label=label1) + axarr[0, i].set_xlabel(f"Constituent {feat}") + # plot the difference + axarr[1, i].hist( + ak.flatten(const2[feat]) - ak.flatten(const1[feat]), + **histkwargs, + label=f"{label2} - {label1}", + ) + axarr[1, i].set_xlabel(f"Constituent {feat} resolution") + + shower_feats = ["pt", "eta", "phi", "mass"] + for i, feat in enumerate(shower_feats): + axarr[2, i].hist(getattr(showers1, feat), **histkwargs, label=label1) + axarr[2, i].hist(getattr(showers2, feat), **histkwargs, label=label2) + axarr[2, i].set_xlabel(f"shower {feat}") + axarr[3, i].hist( + getattr(showers2, feat) - getattr(showers1, feat), + **histkwargs, + label=f"{label2} - {label1}", + ) + axarr[3, i].set_xlabel(f"shower {feat} resolution") + + axarr[0, 0].legend(frameon=False) + axarr[1, 0].legend(frameon=False) + axarr[2, 0].legend(frameon=False) + axarr[3, 0].legend(frameon=False) + + if title is not None: + fig.suptitle(title) + + fig.tight_layout() + # plt.show() + return fig, axarr + + +def plot_features( + ak_array_dict, + names=None, + label_prefix=None, + flatten=True, + histkwargs=None, + legend_only_on=None, + legend_kwargs={}, + ax_rows=1, + decorate_ax_kwargs={}, + bins_dict=None, + colors=None, +): + """Plot the features of the constituents or showers. + + Parameters: + ---------- + ak_array_dict : dict of awkward array + Dict with {"name": ak.Array, ...} of the constituents or showers to plot. + names : list of str or dict, optional + Names of the features to plot. Either a list of names, or a dict of {"name": "label", ...}. + label_prefix : str, optional + Prefix for the plot x-axis labels. + flatten : bool, optional + Whether to flatten the arrays before plotting. Default is True. + histkwargs : dict, optional + Keyword arguments passed to plt.hist. + legend_only_on : int, optional + Plot the legend only on the i-th subplot. Default is None. + legend_kwargs : dict, optional + Keyword arguments passed to ax.legend. + ax_rows : int, optional + Number of rows of the subplot grid. Default is 1. + decorate_ax_kwargs : dict, optional + Keyword arguments passed to `decorate_ax`. + bins_dict : dict, optional + Dict of {name: bins} for the histograms. `name` has to be the same as the keys in `names`. + colors : list, optional + List of colors for the histograms. Has to have the same length as the number of arrays. + If shorter, the colors will be repeated. + """ + + default_hist_kwargs = {"density": True, "histtype": "step", "bins": 100} + + # setup colors + if colors is not None: + if len(colors) < len(ak_array_dict): + print( + "Warning: colors list is shorter than the number of arrays. " + "Will use default colors for remaining ones." + ) + colors = colors + [f"C{i}" for i in range(len(ak_array_dict) - len(colors))] + + if histkwargs is None: + histkwargs = default_hist_kwargs + else: + histkwargs = default_hist_kwargs | histkwargs + + # create the bins dict + if bins_dict is None: + bins_dict = {} + # loop over all names - if the name is not in the bins_dict, use the default bins + for name in names: + if name not in bins_dict: + bins_dict[name] = histkwargs["bins"] + + # remove default bins from histkwargs + histkwargs.pop("bins") + + if isinstance(names, list): + names = {name: name for name in names} + + ax_cols = len(names) // ax_rows + 1 + + fig, axarr = plt.subplots(ax_rows, ax_cols, figsize=(3 * ax_cols, 2 * ax_rows)) + axarr = axarr.flatten() + + legend_handles = [] + legend_labels = [] + + for i_label, (label, ak_array) in enumerate(ak_array_dict.items()): + color = colors[i_label] if colors is not None else f"C{i_label}" + legend_labels.append(label) + for i, (feat, feat_label) in enumerate(names.items()): + if flatten: + values = ak.flatten(getattr(ak_array, feat)) + else: + values = getattr(ak_array, feat) + + if not isinstance(bins_dict[feat], int): + values = binclip(values, bins_dict[feat]) + + _, _, patches = axarr[i].hist(values, **histkwargs, bins=bins_dict[feat], color=color) + axarr[i].set_xlabel( + feat_label if label_prefix is None else f"{label_prefix} {feat_label}" + ) + if i == 0: + legend_handles.append( + Line2D( + [], + [], + color=patches[0].get_edgecolor(), + lw=patches[0].get_linewidth(), + label=label, + linestyle=patches[0].get_linestyle(), + ) + ) + + legend_kwargs["handles"] = legend_handles + legend_kwargs["labels"] = legend_labels + legend_kwargs["frameon"] = False + for i, _ax in enumerate(axarr): + if legend_only_on is None: + _ax.legend(**legend_kwargs) + else: + if i == legend_only_on: + _ax.legend(**legend_kwargs) + + plot_utils.decorate_ax(_ax, **decorate_ax_kwargs) + + fig.tight_layout() + return fig, axarr + + +def plot_features_pairplot( + arr, + names=None, + pairplot_kwargs={}, + input_type="ak_constituents", +): + """Plot the features of the constituents or showers using a pairplot. + + Parameters: + ---------- + arr : awkward array or numpy array + Constituents or showers. + part_names : list or dict, optional + List of names of the features to plot, or dict of {"name": "label", ...}. + pairplot_kwargs : dict, optional + Keyword arguments passed to sns.pairplot. + input_type : str, optional + Type of the input array. Can be "ak_constituents", "ak_showers", or "np_flat". + "ak_constituents" is an awkward array of shower constituents of shape `(n_showers, , n_features)`. + "ak_showers" is an awkward array of showers of shape `(n_showers, n_features)`. + "np_flat" is a numpy array of shape `(n_entries, n_features)` + + + Returns: + -------- + pairplot : seaborn.axisgrid.PairGrid + Pairplot object of the features. + """ + + if isinstance(names, list): + names = {name: name for name in names} + + sns.set_style("dark") + # create a dataframe from the awkward array + if input_type == "ak_constituents": + df = pd.DataFrame( + {feat_label: ak.flatten(getattr(arr, feat)) for feat, feat_label in names.items()} + ) + elif input_type == "ak_showers": + df = pd.DataFrame({feat_label: getattr(arr, feat) for feat, feat_label in names.items()}) + elif input_type == "np_flat": + df = pd.DataFrame( + {feat_label: arr[:, i] for i, (feat, feat_label) in enumerate(names.items())} + ) + else: + raise ValueError(f"Invalid input_type: {input_type}") + pairplot = sns.pairplot(df, kind="hist", **pairplot_kwargs) + plt.show() + + # reset the style + plt.rcdefaults() + + return pairplot + + +def plot_shower_features( + generated_features: ak = None, + real_features: ak = None, + colours: list = ["cornflowerblue", "darkorange"], + labels: list = ["Real", "Generated"], +): + """Plot the features of the constituents or showers. + + Parameters: + ---------- + generated_features : awkward array + Features of the generated showers. + real_features : awkward array + Features of the real showers. + """ + + voxel = ak.to_numpy(ak.num(real_features["x"])) + voxel_gen = ak.to_numpy(ak.num(generated_features["x"])) + + shower_energy = ak.to_numpy(ak.sum(real_features["energy"], axis=1)) + shower_energy_gen = ak.to_numpy(ak.sum(generated_features["energy"], axis=1)) + + max_z = find_max_energy_z(real_features["energy"], real_features["z"]) + max_z_gen = find_max_energy_z(generated_features["energy"], generated_features["z"]) + + x_zero = ak.to_numpy(get_COG_ak(real_features["x"], real_features["energy"])) + y_zero = ak.to_numpy(get_COG_ak(real_features["y"], real_features["energy"])) + z_zero = ak.to_numpy(get_COG_ak(real_features["z"], real_features["energy"])) + + x_zero_gen = ak.to_numpy(get_COG_ak(generated_features["x"], generated_features["energy"])) + y_zero_gen = ak.to_numpy(get_COG_ak(generated_features["y"], generated_features["energy"])) + z_zero_gen = ak.to_numpy(get_COG_ak(generated_features["z"], generated_features["energy"])) + + x = ak.flatten(real_features["x"]).to_numpy() + y = ak.flatten(real_features["y"]).to_numpy() + z = ak.flatten(real_features["z"]).to_numpy() + energy = ak.flatten(real_features["energy"]).to_numpy() + + x_gen = ak.flatten(generated_features["x"]).to_numpy() + y_gen = ak.flatten(generated_features["y"]).to_numpy() + z_gen = ak.flatten(generated_features["z"]).to_numpy() + energy_gen = ak.flatten(generated_features["energy"]).to_numpy() + + x_bin_min = min(x) - 1.5 + x_bin_max = max(x) + 2.5 + y_bin_min = x_bin_min + y_bin_max = x_bin_max + z_bin_min = x_bin_min + z_bin_max = x_bin_max + + fig = plt.figure(figsize=(18, 12), facecolor="white") + gs = GridSpec(2, 3) + ############################################################ + # First Histogram - Energy Plots + ############################################################ + + bins = np.logspace(np.log(0.1), np.log(max(energy)), 150, base=np.e) + ax0 = fig.add_subplot(gs[0]) + ax0.set_title("Visible Energy") + ax0.hist( + [energy, energy_gen], + bins=bins, + histtype="step", + lw=2, + alpha=0.5, + label=labels, + color=colours, + ) + wasserstein_dist = wasserstein_distance(energy, energy_gen) + kl_divergence = KL(energy, energy_gen, bins) + + ax0.text( + 0.05, + 0.95, + f"Wasserstein Distance: {wasserstein_dist:.3f}", + transform=plt.gca().transAxes, + ) + ax0.text(0.05, 0.90, f"KL Divergence: {kl_divergence:.3f}", transform=plt.gca().transAxes) + ax0.set_xlabel("Visible energy (MeV)") + ax0.set_ylabel("a.u.") + ax0.legend(loc="upper right") + ax0.set_xscale("log") + ax0.set_yscale("log") + + # Energy Sum Histogram + ax1 = fig.add_subplot(gs[3]) + ax1.set_title("Energy Sum") + data1 = shower_energy + data2 = shower_energy_gen + ax1.hist( + [data1, data2], + bins=30, + histtype="step", + lw=2, + alpha=1.0, + label=labels, + color=colours, + ) + wasserstein_dist = wasserstein_distance(data1, data2) + kl_divergence = KL(data1, data2, 30) + ax1.text( + 0.05, + 0.95, + f"Wasserstein Distance: {wasserstein_dist:.3f}", + transform=plt.gca().transAxes, + ) + ax1.text(0.05, 0.90, f"KL Divergence: {kl_divergence:.3f}", transform=plt.gca().transAxes) + ax1.set_xlabel("Visible energy sum (MeV)") + ax1.set_ylabel("a.u.") + ax1.legend(loc="upper right") + + # z-start-layer + + # Create a new figure + ax2 = fig.add_subplot(gs[4]) + ax2.set_title("z start layer") + step = math.ceil(z_bin_max / 11) + bins = np.arange(z_bin_min, z_bin_max) + ax2.hist( + [max_z, max_z_gen], + bins=bins, + histtype="step", + lw=2, + alpha=1.0, + color=colours, + label=labels, + ) + wasserstein_dist = wasserstein_distance(max_z, max_z_gen) + kl_divergence = KL(max_z, max_z_gen, bins) + ax2.text( + 0.05, + 0.95, + f"Wasserstein Distance: {wasserstein_dist:.3f}", + transform=plt.gca().transAxes, + ) + ax2.text(0.05, 0.90, f"KL Divergence: {kl_divergence:.3f}", transform=plt.gca().transAxes) + ax2.set_xlabel("z") + ax2.set_ylabel("a.u.") + ax2.ticklabel_format( + axis="y", style="sci", scilimits=(0, 0), useMathText=True + ) # Set scientific notation for y-axis + + ax2.set_xticks(np.arange(z_bin_min, z_bin_max, step)) + ax2.legend(loc="upper right") + + # Plot for only y-scale logarithmic + ax3 = fig.add_subplot(gs[1]) + ax3.set_title("Visible Energy") + ax3.hist( + [energy, energy_gen], + bins=150, + histtype="step", + lw=2, + alpha=0.5, + label=labels, + color=colours, + ) + + wasserstein_dist = wasserstein_distance(energy, energy_gen) + kl_divergence = KL(energy, energy_gen, 150) + ax3.text( + 0.05, + 0.95, + f"Wasserstein Distance: {wasserstein_dist:.3f}", + transform=plt.gca().transAxes, + ) + ax3.text(0.05, 0.90, f"KL Divergence: {kl_divergence:.3f}", transform=plt.gca().transAxes) + + ax3.set_xlabel("Visible energy (MeV)") + ax3.set_ylabel("a.u.") + ax3.legend(loc="upper right") + ax3.set_yscale("log") + + # Plot for only x-scale logarithmic + ax4 = fig.add_subplot(gs[2]) + bins = np.logspace(np.log(0.1), np.log(max(energy)), 150, base=np.e) + ax4.set_title("Visible Energy") + ax4.hist( + [energy, energy_gen], + bins, + histtype="step", + lw=2, + alpha=0.5, + label=labels, + color=colours, + ) + + ax4.set_xlabel("Visible energy (MeV)") + ax4.set_ylabel("a.u.") + ax4.legend(loc="upper right") + ax4.set_xscale("log") + wasserstein_dist = wasserstein_distance(energy, energy_gen) + kl_divergence = KL(energy, energy_gen, bins) + ax4.text( + 0.05, + 0.95, + f"Wasserstein Distance: {wasserstein_dist:.3f}", + transform=plt.gca().transAxes, + ) + ax4.text(0.05, 0.90, f"KL Divergence: {kl_divergence:.3f}", transform=plt.gca().transAxes) + + # Number of Hits Histogram + ax5 = fig.add_subplot(gs[5]) + ax5.set_title("Number of Hits") + ax5.hist( + [voxel, voxel_gen], bins=30, histtype="step", lw=2, alpha=1.0, label=labels, color=colours + ) + ax5.set_xlabel("n_hits") + ax5.set_ylabel("a.u.") + ax5.legend(loc="upper right") + wasserstein_dist = wasserstein_distance(voxel, voxel_gen) + kl_divergence = KL(voxel, voxel_gen, 30) + ax5.text( + 0.05, + 0.95, + f"Wasserstein Distance: {wasserstein_dist:.3f}", + transform=plt.gca().transAxes, + ) + ax5.text(0.05, 0.90, f"KL Divergence: {kl_divergence:.3f}", transform=plt.gca().transAxes) + + fig.suptitle("Distributions") + + fig.tight_layout() + + ############################################################ + # Second Histogram --- x,y,z Distribution and 0th Moment + ############################################################ + + fig_COG = plt.figure(figsize=(18, 12), facecolor="white") + gs2 = GridSpec(2, 3) + + ax0 = fig_COG.add_subplot(gs2[0]) + + average = sum(x_zero) / len(x_zero) + if average < 1: + offset = 0.4 + else: + offset = average * 0.05 + + if average < 0: + bins = np.arange(-average - offset, -average + offset, 0.005) + else: + bins = np.arange(average - offset, average + offset, 0.005) + + ax0.set_title("[X] distribution") + ax0.hist( + [x_zero, x_zero_gen], + bins=bins, + histtype="step", + lw=2, + alpha=1.0, + color=colours, + label=labels, + ) + data1 = x_zero + data2 = x_zero_gen + wasserstein_dist = wasserstein_distance(data1, data2) + kl_divergence = KL(data1, data2, bins) + ax0.text( + 0.05, + 0.95, + f"Wasserstein Distance: {wasserstein_dist:.3f}", + transform=plt.gca().transAxes, + ) + ax0.text(0.05, 0.90, f"KL Divergence: {kl_divergence:.3f}", transform=plt.gca().transAxes) + ax0.set_xlabel("X") + ax0.set_ylabel("a.u.") + ax0.legend(loc="upper right") + + ax1 = fig_COG.add_subplot(gs2[1]) + average = sum(y_zero) / len(y_zero) + if average < 1: + offset = 0.4 + else: + offset = average * 0.05 + + if average < 0: + bins = np.arange(-average - offset, -average + offset, 0.005) + else: + bins = np.arange(average - offset, average + offset, 0.005) + ax1.set_title("[Y] distribution") + ax1.hist( + [y_zero, y_zero_gen], + bins=bins, + histtype="step", + lw=2, + alpha=1.0, + color=colours, + label=labels, + ) + + wasserstein_dist = wasserstein_distance(y_zero, y_zero_gen) + kl_divergence = KL(y_zero, y_zero_gen, bins) + ax1.text( + 0.05, + 0.95, + f"Wasserstein Distance: {wasserstein_dist:.3f}", + transform=plt.gca().transAxes, + ) + ax1.text(0.05, 0.90, f"KL Divergence: {kl_divergence:.3f}", transform=plt.gca().transAxes) + ax1.set_xlabel("Y") + ax1.set_ylabel("a.u.") + ax1.legend(loc="upper right") + + average = sum(z_zero) / len(z_zero) + if average < 1: + offset = 1.4 + else: + offset = average * 0.45 + + if average < 0: + bins = np.arange(-average - offset, -average + offset, 0.05) + else: + bins = np.arange(average - offset, average + offset, 0.05) + ax2 = fig_COG.add_subplot(gs2[2]) + ax2.set_title("[Z] distribution") + ax2.hist( + [z_zero, z_zero_gen], + bins=bins, + histtype="step", + lw=2, + alpha=1.0, + color=colours, + label=labels, + ) + + wasserstein_dist = wasserstein_distance(z_zero, z_zero_gen) + kl_divergence = KL(z_zero, z_zero_gen, bins) + ax2.text( + 0.05, + 0.95, + f"Wasserstein Distance: {wasserstein_dist:.3f}", + transform=plt.gca().transAxes, + ) + ax2.text(0.05, 0.90, f"KL Divergence: {kl_divergence:.3f}", transform=plt.gca().transAxes) + ax2.set_xlabel("Z") + ax2.set_ylabel("a.u.") + ax2.legend(loc="upper right") + + # X Distribution + ax3 = fig_COG.add_subplot(gs2[3]) + ax3.set_title("[x] distribution") + ax3.yaxis.set_major_formatter(plt.ScalarFormatter(useMathText=True)) + ax3.hist( + [x, x_gen], + bins=np.arange(x_bin_min, x_bin_max), + histtype="step", + lw=2, + alpha=1.0, + color=colours, + label=labels, + ) + + data1 = x + data2 = x_gen + wasserstein_dist = wasserstein_distance(data1, data2) + kl_divergence = KL(data1, data2, np.arange(x_bin_min, x_bin_max)) + ax3.text( + 0.05, + 0.95, + f"Wasserstein Distance: {wasserstein_dist:.3f}", + transform=plt.gca().transAxes, + ) + ax3.text(0.05, 0.90, f"KL Divergence: {kl_divergence:.3f}", transform=plt.gca().transAxes) + ax3.set_xlabel("[x]") + ax3.set_ylabel("Number of hits") + ax3.set_xticks(np.arange(x_bin_min, x_bin_max, step)) + ax3.legend(loc="upper right") + + # Y Distribution + ax4 = fig_COG.add_subplot(gs2[4]) + ax4.set_title("[y] distribution") + ax4.yaxis.set_major_formatter(plt.ScalarFormatter(useMathText=True)) + ax4.hist( + [y, y_gen], + bins=np.arange(y_bin_min, y_bin_max), + histtype="step", + lw=2, + alpha=1.0, + color=colours, + label=labels, + ) + + data1 = y + data2 = y_gen + wasserstein_dist = wasserstein_distance(data1, data2) + kl_divergence = KL(data1, data2, np.arange(y_bin_min, y_bin_max)) + ax4.text( + 0.05, + 0.95, + f"Wasserstein Distance: {wasserstein_dist:.3f}", + transform=plt.gca().transAxes, + ) + ax4.text(0.05, 0.90, f"KL Divergence: {kl_divergence:.3f}", transform=plt.gca().transAxes) + ax4.set_xlabel("[y]") + ax4.set_ylabel("Number of hits") + ax4.set_xticks(np.arange(y_bin_min, y_bin_max, step)) + ax4.legend(loc="upper right") + + # Z Distribution + ax5 = fig_COG.add_subplot(gs2[5]) + ax5.set_title("[z] distribution") + ax5.yaxis.set_major_formatter(plt.ScalarFormatter(useMathText=True)) + ax5.hist( + [z, z_gen], + bins=np.arange(z_bin_min, z_bin_max), + histtype="step", + lw=2, + alpha=1.0, + color=colours, + label=labels, + ) + + data1 = z + data2 = z_gen + wasserstein_dist = wasserstein_distance(data1, data2) + kl_divergence = KL(data1, data2, np.arange(z_bin_min, z_bin_max)) + ax5.text( + 0.05, + 0.95, + f"Wasserstein Distance: {wasserstein_dist:.3f}", + transform=plt.gca().transAxes, + ) + ax5.text(0.05, 0.90, f"KL Divergence: {kl_divergence:.3f}", transform=plt.gca().transAxes) + ax5.set_xlabel("[z]") + ax5.set_ylabel("Number of hits") + ax5.set_xticks(np.arange(z_bin_min, z_bin_max, step)) + ax5.legend(loc="upper right") + return fig, fig_COG + + +def plot_compare_gen_showers( + feature_sets: list, labels: list = None, colors: list = None, **kwargs +): + """Plots the features of multiple constituent or shower sets. + + Args: + feature_sets: A list of dictionaries, each containing awkward arrays for "x", "y", "z", and "energy" features. + labels: (Optional) A list of labels for the feature sets (defaults to 'Set 1', 'Set 2', etc.). + colors: (Optional) A list of colors for the feature sets (defaults to a matplotlib colormap). + kwargs: Additional keyword arguments to pass to the plotting functions. + """ + + num_sets = len(feature_sets) + + if labels is None: + labels = [f"Set {i + 1}" for i in range(num_sets)] + if colors is None: + colors = plt.cm.get_cmap("tab10").colors # Use matplotlib's colormap + + # Preprocessing & feature extraction + extracted_features = [] + for features in feature_sets: + # Filter voxels with energy > 0.1 + mask = features["energy"] > 0.1 + filtered_features = { + "x": features["x"][mask], + "y": features["y"][mask], + "z": features["z"][mask], + "energy": features["energy"][mask], + } + + extracted_features.append( + { + "voxel": ak.to_numpy(ak.num(filtered_features["x"])), + "energy": ak.flatten(features["energy"]).to_numpy(), # Keep all energies here + "shower_energy": ak.to_numpy(ak.sum(filtered_features["energy"], axis=1)), + "max_z": find_max_energy_z(filtered_features["energy"], filtered_features["z"]), + "x_zero": ak.to_numpy( + get_COG_ak(filtered_features["x"], filtered_features["energy"]) + ), + "y_zero": ak.to_numpy( + get_COG_ak(filtered_features["y"], filtered_features["energy"]) + ), + "z_zero": ak.to_numpy( + get_COG_ak(filtered_features["z"], filtered_features["energy"]) + ), + "x": ak.flatten(filtered_features["x"]).to_numpy(), + "y": ak.flatten(filtered_features["y"]).to_numpy(), + "z": ak.flatten(filtered_features["z"]).to_numpy(), + "distance": filtered_features["x"].to_numpy(), # TODO maybe delete this function + "energy_filtered": ak.flatten(filtered_features["energy"]).to_numpy(), + } + ) + + # Plotting (two figures) + mpl.rcParams["xtick.labelsize"] = 15 + mpl.rcParams["ytick.labelsize"] = 15 + # mpl.rcParams['font.size'] = 28 + mpl.rcParams["font.size"] = 10 + mpl.rcParams["legend.frameon"] = False + mpl.rcParams["text.usetex"] = False + mpl.rcParams["font.family"] = "sans-serif" + + fig = plt.figure(figsize=(18, 12), facecolor="white") + fig_COG = plt.figure(figsize=(18, 12), facecolor="white") + + # Call the plotting functions, passing the feature sets, labels, and colors + plot_distributions(fig, extracted_features, labels, colors, **kwargs) + plot_cog_and_spatial(fig_COG, extracted_features, labels, colors, **kwargs) + fig_COG.tight_layout() + fig.tight_layout() + + return fig, fig_COG + + +def plot_distributions(fig, features_list, labels, colors, **kwargs): + """Plots the distributions of energy, energy sum, number of hits, and z start layer.""" + gs = fig.add_gridspec( + 5, 3, wspace=0.3, hspace=0.1, height_ratios=[3, 0.8, 0.9, 3, 0.8] + ) # 3 rows for the different distributions + # print("Plotting distributions:max(features_list[z])", max(features_list["z"])) + + # Binning setup (adjust ranges and bins as needed for your data) + fontsize_labels = 18 + + first_features = features_list[0] + x_max = max(first_features["x"]) + + if x_max < 12: # smaller dataset + energy_sum = 2000 + energy = 140 + z = 10.5 + n_hits = 400 + else: + energy_sum = 2000 + energy = 70 + z = 31.5 + n_hits = 1700 + + energy_bins = np.logspace(np.log10(0.01), np.log10(energy), 50) # Logarithmic bins for energy + energy_sum_bins = np.arange(0, energy_sum, 50) + max_z_bins = np.arange(-1.5, z, 1) # Linear bins for z start layer + voxel_bins = np.arange(0, n_hits, 50) # The number of hits + dist_e_bins = np.arange(0, 21, 1) # The distance + + # Energy Distribution + ax5 = fig.add_subplot(gs[0, 0]) # vis cell energy x log + ax0 = fig.add_subplot(gs[0, 1]) # vis cell energy x/y log + ax4 = fig.add_subplot(gs[0, 2]) # energy over distance + ax1 = fig.add_subplot(gs[3, 0]) # energy sum + ax2 = fig.add_subplot(gs[3, 1]) # z start layer + ax3 = fig.add_subplot(gs[3, 2]) # number of hits + + # looping through all input data to be plottet on the different distributions + for features, label, color in zip(features_list, labels, colors): + histtype = "stepfilled" if features is features_list[0] else "step" + edgecolor = "gray" if histtype == "stepfilled" else color + linestyle = ( + "--" + if len(features_list) > 2 + and ( + features is features_list[2] + or len(features_list) > 3 + and (features is features_list[3]) + ) + else "-" + ) + alpha = 0.95 + ax0.hist( + features["energy"], + bins=energy_bins, + linestyle=linestyle, + histtype=histtype, + edgecolor=edgecolor, + lw=2, + alpha=alpha, + label=label, + color=color, + ) + ax1.hist( + features["shower_energy"], + bins=energy_sum_bins, + histtype=histtype, + edgecolor=edgecolor, + linestyle=linestyle, + lw=2, + alpha=alpha, + label=label, + color=color, + ) + ax2.hist( + features["max_z"], + bins=max_z_bins, + histtype=histtype, + edgecolor=edgecolor, + linestyle=linestyle, + lw=2, + alpha=alpha, + label=label, + color=color, + ) + ax3.hist( + features["voxel"], + bins=voxel_bins, + histtype=histtype, + edgecolor=edgecolor, + linestyle=linestyle, + lw=2, + alpha=alpha, + label=label, + color=color, + ) + ax4.hist( + features["distance"], + bins=dist_e_bins, + weights=features["energy_filtered"], + histtype=histtype, + edgecolor=edgecolor, + linestyle=linestyle, + lw=2, + alpha=alpha, + label=label, + color=color, + ) + ax5.hist( + features["energy"], + bins=energy_bins, + histtype=histtype, + edgecolor=edgecolor, + linestyle=linestyle, + lw=2, + alpha=alpha, + label=label, + color=color, + ) + # ax0.set_xlabel("Energy (MeV)") + ax0.set_ylabel("a.u.", fontsize=fontsize_labels) + ax0.set_xscale("log") + ax0.set_yscale("log") + ax0.axvspan(0.01, 0.1, facecolor="lightgray", alpha=0.5, hatch="/") + ax0.tick_params(axis="x", labelbottom=False) + ymin, ymax = ax0.get_ylim() + new_ymax = ymax + 62 * ymax + ax0.set_ylim(ymin, new_ymax) + + # Create twin axis for ratio plot + ax0_twin = fig.add_subplot(gs[1, 1], sharex=ax0) + mask = [0.7, 1.3] + plot_ratios(ax0_twin, features_list, energy_bins, "energy", labels, colors, mask=mask) + # Add horizontal line at y=1 + ax0_twin.axhline(y=1, color="gray", linestyle="--") + ax0_twin.axvspan(0.01, 0.1, facecolor="lightgray", alpha=0.5, hatch="/") + + # Set y-axis limits + ax0_twin.set_ylim(mask) + ax0_twin.set_ylabel("ratio", color="black", fontsize=fontsize_labels) + ax0_twin.set_xlabel("visible cell energy [MeV]", fontsize=fontsize_labels) + ax0_twin.tick_params(axis="y", labelcolor="black") + + # Energy Sum Distribution + ax1.set_ylabel("a.u.", fontsize=fontsize_labels) + ax1.yaxis.set_major_formatter(ticker.ScalarFormatter(useMathText=True)) + ax1.ticklabel_format(axis="y", style="sci", scilimits=(0, 0), useMathText=True) + ax1.tick_params(axis="x", labelbottom=False) + ymin, ymax = ax1.get_ylim() + new_ymax = ymax + 0.35 * ymax + ax1.set_ylim(ymin, new_ymax) + # Create twin axis for ratio plot + ax1_twin = fig.add_subplot(gs[4, 0], sharex=ax1) + plot_ratios( + ax1_twin, features_list, energy_sum_bins, "shower_energy", labels, colors, mask=mask + ) + ax1_twin.axhline(y=1, color="gray", linestyle="--") + # Set y-axis limits + ax1_twin.set_ylim(mask) + ax1_twin.set_ylabel("ratio", color="black", fontsize=fontsize_labels) + ax1_twin.set_xlabel("energy sum [MeV]", fontsize=fontsize_labels) + ax1_twin.tick_params(axis="y", labelcolor="black") + + # Z Start Layer Distribution + ax2.set_ylabel("a.u.", fontsize=fontsize_labels) + ax2.tick_params(axis="x", labelbottom=False) + ax2.set_yscale("log") + ymin, ymax = ax2.get_ylim() + new_ymax = ymax + 64 * ymax + ax2.set_ylim(ymin, new_ymax) + # Create twin axis for ratio plot + ax2_twin = fig.add_subplot(gs[4, 1], sharex=ax2) + mask = [0.6, 1.4] + plot_ratios(ax2_twin, features_list, max_z_bins, "max_z", labels, colors, mask=mask) + ax2_twin.axhline(y=1, color="gray", linestyle="--") + + # Set y-axis limits + ax2_twin.set_ylim(mask) + ax2_twin.set_ylabel("ratio", color="black", fontsize=fontsize_labels) + ax2_twin.set_xlabel("shower start layer [layer]", fontsize=fontsize_labels) + ax2_twin.tick_params(axis="y", labelcolor="black") + + # Number of Hits (Voxel) Distribution + ax3.set_ylabel("# showers", fontsize=fontsize_labels) + ax3.tick_params(axis="x", labelbottom=False) + ax3.yaxis.set_major_formatter(ticker.ScalarFormatter(useMathText=True)) + ax3.ticklabel_format(axis="y", style="sci", scilimits=(0, 0), useMathText=True) + ymin, ymax = ax3.get_ylim() + new_ymax = ymax + 0.44 * ymax + ax3.set_ylim(ymin, new_ymax) + + # Create twin axis for ratio plot + ax3_twin = fig.add_subplot(gs[4, 2], sharex=ax3) + plot_ratios(ax3_twin, features_list, voxel_bins, "voxel", labels, colors, mask=mask) + + ax3_twin.axhline(y=1, color="gray", linestyle="--") + + # Set y-axis limits + ax3_twin.set_ylim(mask) + ax3_twin.set_ylabel("ratio", color="black", fontsize=fontsize_labels) + ax3_twin.set_xlabel("number of hits", fontsize=fontsize_labels) + ax3_twin.tick_params(axis="y", labelcolor="black") + + # Energy Distribution per Layer + # FIXME + ax4.set_ylabel("energy [MeV]", fontsize=fontsize_labels) + ax4.set_yscale("log") + ax4.tick_params(axis="x", labelbottom=False) + ymin, ymax = ax4.get_ylim() + new_ymax = ymax + 0.18 * ymax + ax4.set_ylim(ymin, new_ymax) + + # Create twin axis for ratio plot + ax4_twin = fig.add_subplot(gs[1, 2], sharex=ax4) + mask = [0.7, 1.3] + plot_ratios( + ax4_twin, + features_list, + dist_e_bins, + "distance", + labels, + colors, + mask=mask, + weights="energy_filtered", + ) + + ax4_twin.axhline(y=1, color="gray", linestyle="--") + + # Set y-axis limits + ax4_twin.set_ylim(mask) + ax4_twin.set_ylabel("ratio", color="black", fontsize=fontsize_labels) + ax4_twin.set_xlabel("radius [pixels]", fontsize=fontsize_labels) + ax4_twin.tick_params(axis="y", labelcolor="black") + + # Energy Distribution only x-logarithmic + ax5.set_ylabel("a.u.", fontsize=fontsize_labels) + ax5.set_xscale("log") + ax5.tick_params(axis="x", labelbottom=False) + ax5.axvspan(0.01, 0.1, facecolor="lightgray", alpha=0.5, hatch="/") + ax5.yaxis.set_major_formatter(ticker.ScalarFormatter(useMathText=True)) + ax5.ticklabel_format(axis="y", style="sci", scilimits=(0, 0), useMathText=True) + ymin, ymax = ax0.get_ylim() + new_ymax = ymax + 0.34 * ymax + ax0.set_ylim(ymin, new_ymax) + # Create twin axis for ratio plot + ax5_twin = fig.add_subplot(gs[1, 0], sharex=ax5) + + plot_ratios(ax5_twin, features_list, energy_bins, "energy", labels, colors, mask=mask) + + ax5_twin.axhline(y=1, color="gray", linestyle="--") + + # Set y-axis limits + ax5_twin.set_ylim(mask) + ax5_twin.set_ylabel("ratio", color="black", fontsize=fontsize_labels) + ax5_twin.set_xlabel("visible cell energy [MeV]", fontsize=fontsize_labels) + ax5_twin.tick_params(axis="y", labelcolor="black") + ax5_twin.axvspan(0.01, 0.1, facecolor="lightgray", alpha=0.5, hatch="/") + + # Add legend to the first subplot (energy) + legend_elements = [ + Line2D([0], [0], color=color, lw=2, label=label) for color, label in zip(colors, labels) + ] + # Create the figure + ax5.legend(handles=legend_elements, loc="upper right") + ax2.legend(handles=legend_elements, loc="upper right") + ax3.legend(handles=legend_elements, loc="upper right") + ax0.legend(handles=legend_elements, loc="upper right") + ax4.legend(handles=legend_elements, loc="upper right") + ax1.legend(handles=legend_elements, loc="upper right") + + # Add divergence metrics to the plots + if len(features_list) == 2: + for ax in [ax0, ax4, ax5]: + add_divergence_metrics( + ax, + features_list[0]["energy"], + features_list[1]["energy"], + energy_bins, + "energy", + fontsize=fontsize_labels - 2, + **kwargs, + ) + add_divergence_metrics( + ax1, + features_list[0]["shower_energy"], + features_list[1]["shower_energy"], + energy_sum_bins, + "energy_sum", + fontsize=fontsize_labels - 2, + **kwargs, + ) + add_divergence_metrics( + ax2, + features_list[0]["max_z"], + features_list[1]["max_z"], + max_z_bins, + "max_z", + fontsize=fontsize_labels - 2, + **kwargs, + ) + add_divergence_metrics( + ax3, + features_list[0]["voxel"], + features_list[1]["voxel"], + voxel_bins, + "n_hits", + fontsize=fontsize_labels - 2, + **kwargs, + ) + + +def plot_cog_and_spatial(fig_COG, features_list, labels, colors, **kwargs): + """Plots the COG distributions and spatial distributions of x, y, and z.""" + gs2 = fig_COG.add_gridspec( + 5, 3, wspace=0.3, hspace=0.1, height_ratios=[3, 0.8, 0.9, 3, 0.8] + ) # 3 rows for the different distributions + fontsize_labels = 18 + legend_elements = [ + Line2D([0], [0], color=color, lw=2, label=label) for color, label in zip(colors, labels) + ] + + # COG Distribution Plots + for i in range(3): + ax = fig_COG.add_subplot(gs2[0, i]) + ax_twin = fig_COG.add_subplot(gs2[1, i], sharex=ax) + string = "x_zero" if i == 0 else "y_zero" if i == 1 else "z_zero" + for features, label, color in zip(features_list, labels, colors): + histtype = "stepfilled" if features is features_list[0] else "step" + edgecolor = "gray" if histtype == "stepfilled" else color + linestyle = ( + "--" + if len(features_list) > 2 + and ( + features is features_list[2] + or len(features_list) > 3 + and (features is features_list[3]) + ) + else "-" + ) + + data = features[string] + average = np.mean(data) + + if average < 7: # smaller rebinned dataset + average = 4.5 + # for z 1.4, for x and y 0.4 + offset = 1.4 if i == 2 else 0.4 + # for z 0.2, for x and y 0.05 + steps = 0.2 if i == 2 else 0.05 + + else: # Full resolution dataset + average = 14.5 + # for z 1.4, for x and y 0.4 + offset = 8 if i == 2 else 0.4 + # for z 0.25, for x and y 0.05 + steps = 0.5 if i == 2 else 0.025 + + bins = ( + np.arange(average - offset, average + offset, steps) + if average >= 0 + else np.arange(-average - offset, -average + offset, steps) + ) + ax.hist( + data, + bins=bins, + histtype=histtype, + lw=2, + alpha=0.8, + linestyle=linestyle, + label=label, + edgecolor=edgecolor, + color=color, + ) + mask = [0.5, 1.5] + plot_ratios(ax_twin, features_list, bins, string, labels, colors, mask=mask) + ax_twin.set_xlabel( + f"center of gravity {chr(ord('X')+i)} [voxel]", fontsize=fontsize_labels + ) # Extract the dimension (X, Y, or Z) from the title + ax.set_ylabel("# showers", fontsize=fontsize_labels) + ax.tick_params(axis="x", labelbottom=False) + ax.legend(handles=legend_elements, loc="upper right") + ax.yaxis.set_major_formatter(ticker.ScalarFormatter(useMathText=True)) + ax.ticklabel_format(axis="y", style="sci", scilimits=(0, 0), useMathText=True) + ymin, ymax = ax.get_ylim() + new_ymax = ymax + 0.28 * ymax + ax.set_ylim(ymin, new_ymax) + + ax_twin.axhline(y=1, color="gray", linestyle="--") + + # Set y-axis limits + ax_twin.set_ylim(mask) + ax_twin.set_ylabel("ratio", color="black", fontsize=fontsize_labels) + ax_twin.set_xlabel( + f"center of gravity {chr(ord('X')+i)} [{'layer' if i == 2 else 'cell'}]", + fontsize=fontsize_labels, + ) + + ax_twin.tick_params(axis="y", labelcolor="black") + if len(features_list) == 2: + add_divergence_metrics( + ax, + features_list[0]["x_zero" if i == 0 else "y_zero" if i == 1 else "z_zero"], + features_list[1]["x_zero" if i == 0 else "y_zero" if i == 1 else "z_zero"], + bins, + "X" if i == 0 else "Y" if i == 1 else "Z", + fontsize=fontsize_labels - 2, + **kwargs, + ) + + # Spatial Distribution Plots + for i in range(3): + ax = fig_COG.add_subplot(gs2[3, i]) + ax_twin = fig_COG.add_subplot(gs2[4, i], sharex=ax) + string = "x" if i == 0 else "y" if i == 1 else "z" + for features, label, color in zip(features_list, labels, colors): + histtype = "stepfilled" if features is features_list[0] else "step" + edgecolor = "gray" if histtype == "stepfilled" else color + linestyle = ( + "--" + if len(features_list) > 2 + and ( + features is features_list[2] + or len(features_list) > 3 + and (features is features_list[3]) + ) + else "-" + ) + bins = np.arange(-0.5, 31.5, 1) + data = features[string] + ax.hist( + data, + bins=bins, + histtype=histtype, + lw=2, + alpha=0.8, + linestyle=linestyle, + label=label, + color=color, + edgecolor=edgecolor, + ) + mask = [0.7, 1.3] + plot_ratios(ax_twin, features_list, bins, string, labels, colors, mask=mask) + ax_twin.set_xlabel( + f"spatial distribution {chr(ord('x')+i)} [{'layer' if i == 2 else 'cell'}]", + fontsize=fontsize_labels, + ) + ax.set_ylabel("a.u.", fontsize=fontsize_labels) + ax.legend(handles=legend_elements, loc="upper right") + ax.yaxis.set_major_formatter(ticker.ScalarFormatter(useMathText=True)) + ax.ticklabel_format(axis="y", style="sci", scilimits=(0, 0), useMathText=True) + ax.tick_params(axis="x", labelbottom=False) + ymin, ymax = ax.get_ylim() + new_ymax = ymax + 0.28 * ymax + ax.set_ylim(ymin, new_ymax) + ax_twin.axhline(y=1, color="gray", linestyle="--") + + # Set y-axis limits + ax_twin.set_ylim(mask) + ax_twin.set_ylabel("ratio", color="black", fontsize=fontsize_labels) + ax_twin.tick_params(axis="y", labelcolor="black") + if len(features_list) == 2: + add_divergence_metrics( + ax, + features_list[0]["x" if i == 0 else "y" if i == 1 else "z"], + features_list[1]["x" if i == 0 else "y" if i == 1 else "z"], + bins, + "x" if i == 0 else "y" if i == 1 else "z", + fontsize=fontsize_labels - 2, + **kwargs, + ) + + +def save_metrics_to_file(file_path, metrics): + """Save metrics to a JSON file. + + Parameters: + ---------- + file_path : str + Path to the JSON file. + metrics : dict + Dictionary containing the metrics to save. + """ + try: + with open(file_path) as file: + data = json.load(file) + except FileNotFoundError: + data = [] + + data.append(metrics) + + with open(file_path, "w") as file: + json.dump(data, file, indent=4) + + +def add_divergence_metrics(ax, data1, data2, bins, feature, fontsize, **kwargs): + """Calculates and adds Wasserstein distance and KL divergence to the plot.""" + wasserstein_dist = wasserstein_distance(data1, data2) + bins = int(len(bins)) + kl_divergence = quantiled_kl_divergence(data1, data2, bins, False) + filepath = kwargs.get("filepath", None) + weights = kwargs.get("weights", None) + n_data = kwargs.get("n_data", None) + transfer_learning = kwargs.get("transfer_learning", False) + + if transfer_learning: + write_distances_to_json( + kl_divergence, wasserstein_dist, filepath, weights, n_data, feature + ) + + ax.text( + 1.0, + 1.05, + f"W-distance: {wasserstein_dist:.2e}", + transform=ax.transAxes, + ha="right", + va="top", + fontsize=fontsize, + ) + ax.text( + 1.0, + 1.10, + f"KLD: {kl_divergence:.2e}", + transform=ax.transAxes, + ha="right", + va="top", + fontsize=fontsize, + ) + + +def plot_paper_plots(feature_sets: list, labels: list = None, colors: list = None, **kwargs): + """Plots the features of multiple constituent or shower sets. + + Args: + feature_sets: A list of dictionaries, each containing awkward arrays for "x", "y", "z", and "energy" features. + labels: (Optional) A list of labels for the feature sets (defaults to 'Set 1', 'Set 2', etc.). + colors: (Optional) A list of colors for the feature sets (defaults to a matplotlib colormap). + kwargs: Additional keyword arguments to pass to the plotting functions. + """ + + num_sets = len(feature_sets) + + if labels is None: + labels = [f"Set {i + 1}" for i in range(num_sets)] + if colors is None: + colors = plt.cm.get_cmap("tab10").colors # Use matplotlib's colormap + + # Preprocessing & feature extraction + features_list = [] + for features in feature_sets: + # Filter voxels with energy > 0.1 + mask = features["energy"] > 0.1 + filtered_features = { + "x": features["x"][mask], + "y": features["y"][mask], + "z": features["z"][mask], + "energy": features["energy"][mask], + } + + features_list.append( + { + "voxel": ak.to_numpy(ak.num(filtered_features["x"])), + "energy": ak.flatten(features["energy"]).to_numpy(), # Keep all energies here + "shower_energy": ak.to_numpy(ak.sum(filtered_features["energy"], axis=1)), + # "max_z": find_max_energy_z(filtered_features["energy"], filtered_features["z"]), + "x_zero": ak.to_numpy( + get_COG_ak(filtered_features["x"], filtered_features["energy"]) + ), + "y_zero": ak.to_numpy( + get_COG_ak(filtered_features["y"], filtered_features["energy"]) + ), + "z_zero": ak.to_numpy( + get_COG_ak(filtered_features["z"], filtered_features["energy"]) + ), + "x": ak.flatten(filtered_features["x"]).to_numpy(), + "y": ak.flatten(filtered_features["y"]).to_numpy(), + "z": ak.flatten(filtered_features["z"]).to_numpy(), + "distance": np.mean( + sum_energy_per_radial_distance( + filtered_features["x"], filtered_features["y"], filtered_features["energy"] + ), + axis=0, + ), + "energy_filtered": ak.flatten(filtered_features["energy"]).to_numpy(), + "energy_per_layer": np.mean( + sum_energy_per_layer(filtered_features["z"], filtered_features["energy"]), + axis=0, + ), + "pixel": np.arange(0, 21) + 0.5, + "hits": np.arange(0, 29) + 0.5, + } + ) + + # Plotting (two figures) + mpl.rcParams["xtick.labelsize"] = 15 + mpl.rcParams["ytick.labelsize"] = 15 + # mpl.rcParams['font.size'] = 28 + mpl.rcParams["font.size"] = 10 + mpl.rcParams["legend.frameon"] = False + mpl.rcParams["text.usetex"] = False + mpl.rcParams["font.family"] = "sans-serif" + + fig = plt.figure(figsize=(18, 12), facecolor="white") + + """Plots the distributions of energy, energy sum, number of hits, and z start layer.""" + gs = fig.add_gridspec( + 5, 3, wspace=0.3, hspace=0.1, height_ratios=[3, 0.8, 0.9, 3, 0.8] + ) # 3 rows for the different distributions + # print("Plotting distributions:max(features_list[z])", max(features_list["z"])) + + # Binning setup (adjust ranges and bins as needed for your data) + fontsize_labels = 18 + + energy_sum = 2000 + energy = 70 + n_hits = 1700 + + energy_bins = np.logspace(np.log10(0.01), np.log10(energy), 50) # Logarithmic bins for energy + energy_sum_bins = np.arange(0, energy_sum, 75) + voxel_bins = np.arange(0, n_hits, 50) # The number of hits + dist_e_bins = np.arange(0, 21, 1) # The distance + bins_cog = np.arange(8, 22, 0.5) + bins_z = np.arange(0, 31.5, 1) + + # Energy Distribution + ax0 = fig.add_subplot(gs[0, 0]) # vis cell energy x/y log + ax1 = fig.add_subplot(gs[0, 1]) # energy sum + ax2 = fig.add_subplot(gs[0, 2]) # number of hits + ax3 = fig.add_subplot(gs[3, 0]) # center of gravity Z + ax4 = fig.add_subplot(gs[3, 1]) # spatial distribution Z + ax5 = fig.add_subplot(gs[3, 2]) # energy over distance + + # looping through all input data to be plottet on the different distributions + for features, label, color in zip(features_list, labels, colors): + histtype = "stepfilled" if features is features_list[0] else "step" + edgecolor = "gray" if histtype == "stepfilled" else color + linestyle = ( + "--" + if len(features_list) > 2 + and ( + features is features_list[2] + or len(features_list) > 3 + and (features is features_list[3]) + ) + else "-" + ) + alpha = 0.95 + ax0.hist( + features["energy"], + bins=energy_bins, + linestyle=linestyle, + histtype=histtype, + edgecolor=edgecolor, + lw=2, + alpha=alpha, + label=label, + color=color, + ) + ax1.hist( + features["shower_energy"], + bins=energy_sum_bins, + histtype=histtype, + edgecolor=edgecolor, + linestyle=linestyle, + lw=2, + alpha=alpha, + label=label, + density=True, + color=color, + ) + ax2.hist( + features["voxel"], + bins=voxel_bins, + histtype=histtype, + edgecolor=edgecolor, + linestyle=linestyle, + lw=2, + alpha=alpha, + label=label, + density=True, + color=color, + ) + ax3.hist( + features["z_zero"], + bins=bins_cog, + histtype=histtype, + lw=2, + alpha=alpha, + linestyle=linestyle, + label=label, + edgecolor=edgecolor, + density=True, + color=color, + ) + ax4.hist( + features["hits"], + bins=bins_z, + histtype=histtype, + lw=2, + alpha=alpha, + label=label, + color=color, + linestyle=linestyle, + weights=features["energy_per_layer"], + ) + ax5.hist( + features["pixel"], + bins=dist_e_bins, + weights=features["distance"], + histtype=histtype, + edgecolor=edgecolor, + linestyle=linestyle, + lw=2, + alpha=alpha, + label=label, + color=color, + ) + # ax0.set_xlabel("Energy (MeV)") + ax0.set_ylabel("a.u.", fontsize=fontsize_labels) + ax0.set_xscale("log") + ax0.set_yscale("log") + ax0.set_xlim(left=0.01) + ax0.axvspan(0.01, 0.1, ymin=0, ymax=0.73, facecolor="lightgray", alpha=0.2, hatch="/") + ax0.tick_params(axis="x", labelbottom=False) + ymin, ymax = ax0.get_ylim() + new_ymax = ymax + 1620 * ymax + ax0.set_ylim(ymin, new_ymax) + # Create twin axis for ratio plot + + mask = [0.7, 1.3] + ax0_twin = fig.add_subplot(gs[1, 0], sharex=ax0) + ax0_twin.set_xlim(left=0.01) + plot_ratios(ax0_twin, features_list, energy_bins, "energy", labels, colors, mask=mask) + # Add horizontal line at y=1 + ax0_twin.axhline(y=1, color="gray", linestyle="--") + ax0_twin.axvspan(0.01, 0.1, facecolor="lightgray", alpha=0.5, hatch="/") + # Set y-axis limits + ax0_twin.set_ylim(mask) + ax0_twin.set_ylabel("ratio", color="black", fontsize=fontsize_labels) + ax0_twin.set_xlabel("visible cell energy [MeV]", fontsize=fontsize_labels) + ax0_twin.tick_params(axis="y", labelcolor="black") + + # Energy Sum Distribution + ax1.set_ylabel("normalized", fontsize=fontsize_labels) + ax1.yaxis.set_major_formatter(ticker.ScalarFormatter(useMathText=True)) + ax1.ticklabel_format(axis="y", style="sci", scilimits=(0, 0), useMathText=True) + ax1.tick_params(axis="x", labelbottom=False) + ymin, ymax = ax1.get_ylim() + new_ymax = ymax + 0.45 * ymax + ax1.set_ylim(ymin, new_ymax) + # Create twin axis for ratio plot + ax1_twin = fig.add_subplot(gs[1, 1], sharex=ax1) + plot_ratios( + ax1_twin, features_list, energy_sum_bins, "shower_energy", labels, colors, mask=mask + ) + ax1_twin.axhline(y=1, color="gray", linestyle="--") + # Set y-axis limits + ax1_twin.set_ylim(mask) + ax1_twin.set_ylabel("ratio", color="black", fontsize=fontsize_labels) + ax1_twin.set_xlabel("energy sum [MeV]", fontsize=fontsize_labels) + ax1_twin.tick_params(axis="y", labelcolor="black") + + # Number of Hits (Voxel) Distribution + mask = [0.6, 1.4] + ax2.set_ylabel("normalized", fontsize=fontsize_labels) + ax2.tick_params(axis="x", labelbottom=False) + ax2.yaxis.set_major_formatter(ticker.ScalarFormatter(useMathText=True)) + ax2.ticklabel_format(axis="y", style="sci", scilimits=(0, 0), useMathText=True) + ymin, ymax = ax2.get_ylim() + new_ymax = ymax + 0.44 * ymax + ax2.set_ylim(ymin, new_ymax) + + # Create twin axis for ratio plot + ax2_twin = fig.add_subplot(gs[1, 2], sharex=ax2) + plot_ratios(ax2_twin, features_list, voxel_bins, "voxel", labels, colors, mask) + + ax2_twin.axhline(y=1, color="gray", linestyle="--") + + # Set y-axis limits + ax2_twin.set_ylim(mask) + ax2_twin.set_ylabel("ratio", color="black", fontsize=fontsize_labels) + ax2_twin.set_xlabel("number of hits", fontsize=fontsize_labels) + ax2_twin.tick_params(axis="y", labelcolor="black") + + # Center of Gravity Z Distribution + ax3.set_ylabel("normalized", fontsize=fontsize_labels) + ax3.tick_params(axis="x", labelbottom=False) + ax3.yaxis.set_major_formatter(ticker.ScalarFormatter(useMathText=True)) + ax3.ticklabel_format(axis="y", style="sci", scilimits=(0, 0), useMathText=True) + ymin, ymax = ax3.get_ylim() + new_ymax = ymax + 0.48 * ymax + ax3.set_ylim(ymin, new_ymax) + + # Create twin axis for ratio plot + ax3_twin = fig.add_subplot(gs[4, 0], sharex=ax3) + mask = (0.4, 1.6) + plot_ratios(ax3_twin, features_list, bins_cog, "z_zero", labels, colors, mask=mask) + + ax3_twin.axhline(y=1, color="gray", linestyle="--") + + # Set y-axis limits + + ax3_twin.set_ylim(mask) + ax3_twin.set_ylabel("ratio", color="black", fontsize=fontsize_labels) + ax3_twin.set_xlabel("center of gravity Z [layer]", fontsize=fontsize_labels) + ax3_twin.tick_params(axis="y", labelcolor="black") + + # Z Distribution + ax4.set_ylabel("energy [MeV]", fontsize=fontsize_labels) + ax4.tick_params(axis="x", labelbottom=False) + ax4.set_yscale("log") + ax4.set_xlim(0, 30) + ymin, ymax = ax4.get_ylim() + new_ymax = ymax + 40 * ymax + ax4.set_ylim(ymin, new_ymax) + + # Create twin axis for ratio plot + ax4_twin = fig.add_subplot(gs[4, 1], sharex=ax4) + mask = [0.7, 1.3] + plot_ratios( + ax4_twin, features_list, bins_z, "hits", labels, colors, mask, weights="energy_per_layer" + ) + + ax4_twin.axhline(y=1, color="gray", linestyle="--") + + # Set y-axis limits + + ax4_twin.set_ylim(mask) + ax4_twin.set_xlim(0, 30) + ax4_twin.set_ylabel("ratio", color="black", fontsize=fontsize_labels) + ax4_twin.set_xlabel("layer", fontsize=fontsize_labels) + ax4_twin.tick_params(axis="y", labelcolor="black") + + # Energy Distribution per Layer + ax5.set_ylabel("energy [MeV]", fontsize=fontsize_labels) + ax5.set_yscale("log") + ax5.set_xlim(0, 21) + ax5.tick_params(axis="x", labelbottom=False, labelsize=fontsize_labels) + ymin, ymax = ax5.get_ylim() + new_ymax = ymax + 40 * ymax + ax5.set_ylim(ymin, new_ymax) + + # Create twin axis for ratio plot + ax5_twin = fig.add_subplot(gs[4, 2], sharex=ax5) + mask = [0.7, 1.3] + plot_ratios( + ax5_twin, features_list, dist_e_bins, "pixel", labels, colors, mask, weights="distance" + ) + + ax5_twin.axhline(y=1, color="gray", linestyle="--") + + # Set y-axis limits + ax5_twin.set_ylim(mask) + ax5_twin.set_xlim(0, 21) + ax5_twin.set_ylabel("ratio", color="black", fontsize=fontsize_labels) + ax5_twin.set_xlabel("radius [pixels]", fontsize=fontsize_labels) + ax5_twin.tick_params(axis="y", labelcolor="black") + + # Add legend to the first subplot (energy) + legend_elements = [ + Line2D( + [0], + [0], + color=color, + lw=2, + label=label, + linestyle="--" + if len(features_list) > 2 + and ( + features is features_list[2] + or len(features_list) > 3 + and (features is features_list[3]) + ) + else "-", + ) + for color, label, features in zip(colors, labels, features_list) + ] + # Create the figure + ax5.legend(handles=legend_elements, loc="upper right", fontsize=fontsize_labels - 5, ncol=2) + ax2.legend(handles=legend_elements, loc="upper right", fontsize=fontsize_labels - 5, ncol=2) + ax3.legend(handles=legend_elements, loc="upper right", fontsize=fontsize_labels - 5, ncol=2) + ax0.legend(handles=legend_elements, loc="upper right", fontsize=fontsize_labels - 5, ncol=2) + ax4.legend(handles=legend_elements, loc="upper right", fontsize=fontsize_labels - 5, ncol=2) + ax1.legend(handles=legend_elements, loc="upper right", fontsize=fontsize_labels - 5, ncol=2) + + return fig diff --git a/gabbro/plotting/utils.py b/gabbro/plotting/utils.py new file mode 100644 index 0000000..115a73e --- /dev/null +++ b/gabbro/plotting/utils.py @@ -0,0 +1,560 @@ +from pathlib import Path + +import matplotlib as mpl +import matplotlib.pyplot as plt +import numpy as np +from cycler import cycler +from matplotlib.transforms import Bbox, ScaledTranslation + +rcParams = mpl.rcParams +DEFAULT_ALPHA = 0.95 +DEFAULT_LABELS = { + "part_pt": "Particle $p_{\\mathrm{T}}$ [GeV]", + "part_eta": "Particle $\\eta$", + "part_phi": "Particle $\\phi$", + "part_ptrel": "Particle $p_{\\mathrm{T}}^\\mathrm{rel}$", + "part_etarel": "Particle $\\eta^\\mathrm{rel}$", + "part_phirel": "Particle $\\phi^\\mathrm{rel}$", + "jet_pt": "Jet $p_{\\mathrm{T}}$ [GeV]", + "jet_eta": "Jet $\\eta$", + "jet_phi": "Jet $\\phi$", + "jet_mass": "Jet mass [GeV]", +} + +params_to_update = { + # --- axes --- + # https://matplotlib.org/stable/gallery/color/named_colors.html + "axes.prop_cycle": cycler( + "color", + [ + mpl.colors.ColorConverter().to_rgba(col, DEFAULT_ALPHA) + for col in [ + "steelblue", + "orange", + "forestgreen", + "purple", + "firebrick", + "lightseagreen", + "yellowgreen", + "hotpink", + "dimgrey", + "olive", + ] + ], + ), + # --- figure --- + "figure.figsize": (3.5, 2.5), + # "figure.dpi": 130, + # --- grid --- + "grid.color": "black", + "grid.alpha": 0.1, + "grid.linestyle": "-", + "grid.linewidth": 1, + # --- legend --- + "legend.fontsize": 10, + "legend.frameon": False, + "legend.numpoints": 1, + "legend.scatterpoints": 1, + # --- lines --- + "lines.linewidth": 1.5, + "lines.markeredgewidth": 0, + "lines.markersize": 7, + # "lines.solid_capstyle": "round", + # --- patches --- + "patch.facecolor": "4C72B0", + "patch.linewidth": 1.7, + # --- histogram --- + "hist.bins": 100, + # --- font --- + "font.family": "sans-serif", + "font.sans-serif": "Arial, Liberation Sans, DejaVu Sans, Bitstream Vera Sans, sans-serif", + # --- image --- + "image.cmap": "Greys", +} +params_to_update_dark = { + # --- axes --- + "axes.prop_cycle": cycler( + "color", + [ + mpl.colors.ColorConverter().to_rgba(col, DEFAULT_ALPHA) + for col in [ + "dodgerblue", + "red", + "mediumseagreen", + "darkorange", + "orchid", + "turquoise", + "#64B5CD", + ] + ], + ), + # --- figure --- + "figure.facecolor": "#1A1D22", + # --- grid --- + "grid.color": "lightgray", + # --- lines --- + "lines.solid_capstyle": "round", + # --- font --- + "font.family": "sans-serif", + "font.sans-serif": "Arial, Liberation Sans, DejaVu Sans, Bitstream Vera Sans, sans-serif", + # --- image --- + "image.cmap": "Greys", + # --- legend --- + "legend.frameon": False, + "legend.numpoints": 1, + "legend.scatterpoints": 1, + # --- xtick --- + "xtick.direction": "in", + "xtick.color": "white", + # --- ytick --- + "ytick.direction": "out", + "ytick.color": "white", + # --- axes.axisbelow --- + "axes.axisbelow": True, + "lines.color": "white", + "patch.edgecolor": "white", + "text.color": "white", + "axes.facecolor": "#1A1D22", + "axes.edgecolor": "lightgray", + "axes.labelcolor": "white", + "figure.edgecolor": "#1A1D22", + "savefig.facecolor": "#1A1D22", + "savefig.edgecolor": "#1A1D22", +} + + +def reset_mpl_style(): + """Reset matplotlib rcParams to default.""" + rcParams.update(mpl.rcParamsDefault) + + +def set_mpl_style(darkmode=False): + """Set matplotlib rcParams to custom configuration.""" + reset_mpl_style() + rcParams.update(params_to_update if not darkmode else params_to_update_dark) + + +def save(fig, saveas, transparent=True): + """Save a figure both as pdf and as png. + + Parameters + ---------- + fig : matplotlib.figure.Figure + The figure to save. + saveas : str + The path to save the figure to, expected to end in ".pdf". + transparent : bool, optional + Whether to save the figure with a transparent background, by default True + """ + save_kwargs = dict(transparent=transparent, dpi=300, bbox_inches="tight") + # create the directory if it does not exist + if not Path(saveas).parent.exists(): + print(f"Creating directory {Path(saveas).parent}") + Path(saveas).parent.mkdir(parents=True, exist_ok=True) + print(f"Saving figure to {saveas}") + fig.savefig(saveas, **save_kwargs) + fig.savefig(str(saveas).replace(".pdf", ".png"), **save_kwargs) + + +# # default seaborn aesthetic +# # darkgrid + deep palette + notebook context + +# # axes.axisbelow: True +# # axes.edgecolor: white +# # axes.facecolor: EAEAF2 +# # axes.grid: True +# # axes.labelcolor: .15 +# # axes.labelsize: 11 +# # axes.linewidth: 0 +# # axes.prop_cycle: cycler('color', ['4C72B0', '55A868', 'C44E52', '8172B2', 'CCB974', '64B5CD']) +# axes.prop_cycle: cycler('color', ['55A868', 'C44E52', '4C72B0', '8172B2', 'CCB974', '64B5CD']) +# # axes.prop_cycle: cycler('color', ['55A868', '9E4AC2','4C72B0' , 'C44E52', '8172B2', 'CCB974', '64B5CD']) +# # axes.titlesize: 12 + +# figure.facecolor: white + + +# # xtick.color: .15 +# # xtick.direction: out +# # xtick.labelsize: 10 +# # xtick.major.pad: 7 +# # xtick.major.size: 0 +# # xtick.major.width: 1 +# # xtick.minor.size: 0 +# # xtick.minor.width: .5 + +# # ytick.color: .15 +# # ytick.direction: out +# # ytick.labelsize: 10 +# # ytick.major.pad: 7 +# # ytick.major.size: 0 +# # ytick.major.width: 1 +# # ytick.minor.size: 0 +# # ytick.minor.width: .5 + +# # figure.facecolor: white +# # text.color: .15 +# # axes.labelcolor: .15 +# # legend.frameon: False +# # legend.numpoints: 1 +# # legend.scatterpoints: 1 +# # xtick.direction: in +# # ytick.direction: out +# # xtick.color: .15 +# # ytick.color: .15 +# # axes.axisbelow: True +# # image.cmap: Greys +# # font.family: sans-serif +# # font.sans-serif: Arial, Liberation Sans, DejaVu Sans, Bitstream Vera Sans, sans-serif +# # grid.linestyle: - +# # lines.solid_capstyle: round + +# # Seaborn dark parameters +# # axes.grid: False +# # axes.facecolor: EAEAF2 +# # axes.edgecolor: white +# # axes.linewidth: 0 +# # grid.color: white +# # xtick.major.size: 0 +# # ytick.major.size: 0 +# # xtick.minor.size: 0 +# # ytick.minor.size: 0 + + +def decorate_ax( + ax, + yscale=1.3, + text=None, + text_line_spacing=1.2, + text_font_size=12, + draw_legend=False, + indent=0.7, + top_distance=1.2, + hepstyle=False, + remove_first_ytick=False, +): + """Helper function to decorate the axes. + + Parameters + ---------- + ax : matplotlib.axes.Axes + Axes to decorate + yscale : float, optional + Factor by which the y-axis is scaled, by default 1.3 + text : str, optional + Text to add to the plot, by default None + text_line_spacing : float, optional + Spacing between lines of text, by default 1.2 + text_font_size : int, optional + Font size of the text, by default 12 + draw_legend : bool, optional + Draw the legend with `frameon=False`, by default False + indent : float, optional + Horizontal indent, by default 0.7 + top_distance : float, optional + Vertical indent, by default 1.2 + hepstyle : bool, optional + Use the atlasify function to make the plot look like an ATLAS plot, by default False + remove_first_ytick : bool, optional + Remove the first y-tick, by default False. + Can be useful to avoid overlap with the ratio plot ticks. + """ + PT = 1 / 72 # 1 point in inches + + # reset the y-axis limits (if they were changed before, it can happen + # that the y-axis is not scaled correctly. especially it happens that ymin + # becomes 0 even after setting logscale, which raises an error below as we + # divide by ymin for logscale) + if yscale != 1: + xmin, xmax = ax.get_xlim() + ax.relim() + ax.autoscale() + ax.set_xlim(xmin, xmax) + + # This weird order is necessary to allow for later + # saving in logscaled y-axis + if ax.get_yscale() == "log": + ymin, _ = ax.get_ylim() + ax.set_yscale("linear") + _, ymax = ax.get_ylim() + ax.set_yscale("log") + yscale = (ymax / ymin) ** (yscale - 0.99) + else: + ymin, ymax = ax.get_ylim() + + # scale the y-axis to avoid overlap with text + ax.set_ylim(top=yscale * (ymax - ymin) + ymin) + + if text is None: + pass + elif isinstance(text, str): + # translation from the left side of the axes (aka indent) + trans_indent = ScaledTranslation( + indent * text_line_spacing * PT * text_font_size, + 0, + ax.figure.dpi_scale_trans, + ) + # translation from the top of the axes + trans_top = ScaledTranslation( + 0, + -top_distance * text_line_spacing * PT * text_font_size, + ax.figure.dpi_scale_trans, + ) + + # add each line of the tag text to the plot + for line in text.split("\n"): + # fmt: off + ax.text(0, 1, line, transform=ax.transAxes + trans_top + trans_indent, fontsize=text_font_size) # noqa: E501 + trans_top += ScaledTranslation(0, -text_line_spacing * text_font_size * PT, ax.figure.dpi_scale_trans) # noqa: E501 + # fmt: on + else: + raise TypeError("`text` attribute of the plot has to be of type `str`.") + + if draw_legend: + ax.legend(frameon=False) + + if remove_first_ytick: + # remove the first y-tick label of the upper plot to avoid overlap with the ratio plot + ax.set_yticks(ax.get_yticks()[1:]) + + +def get_good_linestyles(names=None): + """Returns a list of good linestyles. + + Parameters + ---------- + names : list or str, optional + List or string of the name(s) of the linestyle(s) you want to retrieve, e.g. + "densely dotted" or ["solid", "dashdot", "densely dashed"], by default None + + Returns + ------- + list + List of good linestyles. Either the specified selection or the whole list in + the predefined order. + + Raises + ------ + ValueError + If `names` is not a str or list. + """ + linestyle_tuples = { + "solid": "solid", + "densely dashed": (0, (5, 1)), + "densely dotted": (0, (1, 1)), + "densely dashdotted": (0, (3, 1, 1, 1)), + "densely dashdotdotted": (0, (3, 1, 1, 1, 1, 1)), + "dotted": (0, (1, 1)), + "dashed": (0, (5, 5)), + "dashdot": "dashdot", + "loosely dashed": (0, (5, 10)), + "loosely dotted": (0, (1, 10)), + "loosely dashdotted": (0, (3, 10, 1, 10)), + "loosely dashdotdotted": (0, (3, 10, 1, 10, 1, 10)), + "dashdotted": (0, (3, 5, 1, 5)), + "dashdotdotted": (0, (3, 5, 1, 5, 1, 5)), + } + + default_order = [ + "solid", + "densely dotted", + "densely dashed", + "densely dashdotted", + "densely dashdotdotted", + "dotted", + "dashed", + "dashdot", + # "loosely dotted", + # "loosely dashed", + # "loosely dashdotted", + # "loosely dashdotdotted", + "dashdotted", + "dashdotdotted", + ] + if names is None: + names = default_order * 3 + elif isinstance(names, str): + return linestyle_tuples[names] + elif not isinstance(names, list): + raise ValueError("Invalid type of `names`, has to be a list of strings or a string.") + return [linestyle_tuples[name] for name in names] + + +def get_col(i): + """Get the i-th color from the default color cycle.""" + return rcParams["axes.prop_cycle"].by_key()["color"][i % len(rcParams["axes.prop_cycle"])] + + +def get_label(name): + """Get the label for the given name.""" + return DEFAULT_LABELS[name] + + +def get_ax_and_fig_2x5(ratio=True, figsize=(18, 6)): + """Create a 2x5 grid of axes for plotting histograms and ratios. The axes are returned as a + 2x10 grid, where the first row contains the histogram plots and the second row contains the + ratio plots. + + Parameters: + ---------- + ratio : bool, optional + Whether to include a ratio plot. Default is True. + figsize : tuple, optional + Size of the figure. + """ + + if ratio: + gridspec = dict(hspace=0.0, height_ratios=[1, 0.3, 0.3, 1, 0.3]) + fig, ax = plt.subplots(5, 5, figsize=figsize, gridspec_kw=gridspec) + # make third row invisible + for i in range(5): + ax[2, i].axis("off") + # remove the dummy axes in the middle + axes = np.concatenate([ax[:2, :], ax[3:, :]], axis=1) + else: + fig, ax = plt.subplots(2, 5, figsize=figsize) + axes = ax.flatten() + return fig, axes + + +def get_ax_and_fig_for_ratio_plot(figsize=(4, 3)): + """Returns fig, axes for a ratio plot. ax[0] is the histogram, ax[1] is the ratio. + + Parameters: + ---------- + figsize : tuple, optional + Size of the figure. + """ + + gridspec = dict(hspace=0.0, height_ratios=[1, 0.3]) + fig, ax = plt.subplots(2, 1, figsize=figsize, gridspec_kw=gridspec) + # make third second ax invisible + # ax[1].axis("off") + # remove the dummy axes in the middle + # axes = np.concatenate([ax[0], ax[-1]], axis=1) + return fig, ax + + +def save_two_subplots( + fig, + ax1, + ax2, + saveas: str, + expanded_x=1.05, + expanded_y=1.05, +): + """Save two subplots of a figure to a single file. + + Parameters + ---------- + fig : matplotlib.figure.Figure + The figure containing the subplots. + ax1 : matplotlib.axes.Axes + The first subplot. + ax2 : matplotlib.axes.Axes + The second subplot. + saveas : str + The path to save the figure to. + expanded_x : float, optional + Factor by which to expand the bounding box in the x-direction, by default 1.05 + expanded_y : float, optional + Factor by which to expand the bounding box in the y-direction, by default 1.05 + """ + bbox1 = ax1.get_tightbbox().transformed(fig.dpi_scale_trans.inverted()) + bbox2 = ax2.get_tightbbox().transformed(fig.dpi_scale_trans.inverted()) + bbox_total = Bbox.union([bbox1, bbox2]) + fig.savefig(saveas, bbox_inches=bbox_total.expanded(expanded_x, expanded_y)) + + +def save_subplot( + fig, + ax, + saveas: str, + expanded_x=1.05, + expanded_y=1.05, +): + """Save a single subplot of a figure to a file. + + Parameters + ---------- + fig : matplotlib.figure.Figure + The figure containing the subplot. + ax : matplotlib.axes.Axes + The subplot to save. + saveas : str + The path to save the figure to. + expanded_x : float, optional + Factor by which to expand the bounding box in the x-direction, by default 1.05 + expanded_y : float, optional + Factor by which to expand the bounding box in the y-direction, by default 1.05 + """ + bbox = ax.get_tightbbox().transformed(fig.dpi_scale_trans.inverted()) + fig.savefig(saveas, bbox_inches=bbox.expanded(expanded_x, expanded_y)) + + +def plot_ratios(ax_twin, features_list, bins, feature, labels, colors, mask, weights=None): + """Plots ratio plots for each feature compared to the first feature. + + Args: + ax_twin (matplotlib.axes.Axes): The twin axis on which to plot the ratios. + features_list (list): A list of dictionaries, each containing a feature's data. + bins (array-like): The bin edges for the histogram. + feature (str): Name of the feature to plot. + labels (list): A list of labels for each feature. + colors (list): A list of colors for each feature. + weights (str, optional): Key for weights in the feature dictionaries. Default is None. + + Returns: + None + """ + + for i in range(1, len(features_list)): # Compare feature 1 with feature 2 and 3 + if weights: + counts1, _ = np.histogram( + features_list[0][feature], bins=bins, weights=features_list[0][weights] + ) + counts2, _ = np.histogram( + features_list[i][feature], bins=bins, weights=features_list[i][weights] + ) + else: + counts1, _ = np.histogram(features_list[0][feature], bins=bins) + counts2, _ = np.histogram(features_list[i][feature], bins=bins) + + ratios = [] + for j in range(len(counts1)): + if counts2[j] == 0: + ratios.append(np.nan) + else: + ratios.append(counts2[j] / counts1[j]) + linestyle = "--" if i > 1 else "-" + ax_twin.step( + bins, # Use all energy bin edges + ratios + [ratios[-1]], # Duplicate last ratio for the step + where="post", # 'post' creates the step after the data point + color=colors[i], # Use color from the current feature + label=f"ratio to {labels[i]}", # Use label from the current feature + linestyle=linestyle, # Use linestyle from the current feature + ) + bin_centers = (bins[:-1] + bins[1:]) / 2 + # Add out-of-bounds markers at the edge of the plot + mask_below = np.array(ratios) < mask[0] + mask_above = np.array(ratios) > mask[1] + + ax_twin.plot( + bin_centers[mask_below], + np.full_like(bin_centers[mask_below], mask[0]), + marker="v", + color=colors[i], + markersize=6, + clip_on=False, + linestyle="None", + ) + ax_twin.plot( + bin_centers[mask_above], + np.full_like(bin_centers[mask_above], mask[1]), + marker="^", + color=colors[i], + markersize=6, + clip_on=False, + linestyle="None", + ) diff --git a/gabbro/schedulers/lr_scheduler.py b/gabbro/schedulers/lr_scheduler.py new file mode 100644 index 0000000..17e6658 --- /dev/null +++ b/gabbro/schedulers/lr_scheduler.py @@ -0,0 +1,95 @@ +import numpy as np +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler + + +class CosineWarmupScheduler(_LRScheduler): + def __init__(self, optimizer, warmup, max_iters): + self.warmup = warmup + self.max_num_iters = max_iters + super().__init__(optimizer) + + def get_lr(self): + lr_factor = self.get_lr_factor(epoch=self.last_epoch) + return [base_lr * lr_factor for base_lr in self.base_lrs] + + def get_lr_factor(self, epoch): + lr_factor = 0.5 * (1 + np.cos(np.pi * epoch / self.max_num_iters)) + if epoch <= self.warmup: + lr_factor *= epoch * 1.0 / self.warmup + return lr_factor + + +class OneCycleCooldown(_LRScheduler): + """LR scheduler that implements the one-cycle learning rate policy. + + Followed by a cooldown period where the learning rate is gradually reduced to a minimum value. + """ + + def __init__( + self, + optimizer, + warmup, + cooldown, + cooldown_final, + initial_lr, + max_lr, + final_lr=1e-6, + ): + """optimizer (Optimizer): Wrapped optimizer. + + warmup: number of epochs to warmup for. + cooldown: number of epochs to cooldown for. + initial_lr: initial learning rate. + max_lr: maximum learning rate. + final_lr: final learning rate. + """ + self.warmup = warmup + self.cooldown = cooldown + self.cooldown_final = cooldown_final + self.initial_lr = initial_lr + self.max_lr = max_lr + self.final_lr = final_lr + super().__init__(optimizer) + + def get_lr(self): + lr = self.get_lr_factor(epoch=self.last_epoch) + return [lr for base_lr in self.base_lrs] + + def get_lr_factor(self, epoch): + if epoch <= self.warmup: + lr = self.initial_lr + (self.max_lr - self.initial_lr) * epoch / self.warmup + elif epoch <= self.warmup + self.cooldown: + lr = ( + self.max_lr + - (self.max_lr - self.initial_lr) * (epoch - self.warmup) / self.cooldown + ) + elif epoch <= self.warmup + self.cooldown + self.cooldown_final: + lr = ( + self.initial_lr + - (self.initial_lr - self.final_lr) + * (epoch - self.warmup - self.cooldown) + / self.cooldown_final + ) + else: + lr = self.final_lr + return lr + + +class WarmupToConstant(_LRScheduler): + """Gradually warm-up learning rate in optimizer to a constant value.""" + + def __init__(self, optimizer: Optimizer, num_steps: int = 100) -> None: + """ + args: + optimizer (Optimizer): Wrapped optimizer. + num_steps: target learning rate is reached at num_steps. + """ + self.num_steps = num_steps + self.finished = False + super().__init__(optimizer) + + def get_lr(self) -> list[float]: + if self.last_epoch > self.num_steps: + return [base_lr for base_lr in self.base_lrs] + return [(base_lr / self.num_steps) * self.last_epoch for base_lr in self.base_lrs] diff --git a/gabbro/train.py b/gabbro/train.py new file mode 100644 index 0000000..e5fc18b --- /dev/null +++ b/gabbro/train.py @@ -0,0 +1,238 @@ +import hashlib +import os +import time +from pathlib import Path +from typing import List, Optional, Tuple + +import hydra +import lightning as L +import pyrootutils +import torch +from omegaconf import DictConfig, OmegaConf +from pytorch_lightning.loggers import Logger + +import gabbro.utils.git_utils as git_utils +from gabbro.models.vqvae import VQVAELightning +from gabbro.utils.bigram import get_bigram +from gabbro.utils.pylogger import get_pylogger +from gabbro.utils.utils import ( + get_metric_value, + instantiate_callbacks, + instantiate_loggers, + log_hyperparameters, + task_wrapper, +) + +pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) +# ------------------------------------------------------------------------------------ # +# the setup_root above is equivalent to: +# - adding project root dir to PYTHONPATH +# (so you don't need to force user to install project as a package) +# (necessary before importing any local modules e.g. `from src import utils`) +# - setting up PROJECT_ROOT environment variable +# (which is used as a base for paths in "configs/paths/default.yaml") +# (this way all filepaths are the same no matter where you run the code) +# - loading environment variables from ".env" in root dir +# +# you can remove it if you: +# 1. either install project as a package or move entry files to project root dir +# 2. set `root_dir` to "." in "configs/paths/default.yaml" +# +# more info: https://github.com/ashleve/pyrootutils +# ------------------------------------------------------------------------------------ # + + +log = get_pylogger(__name__) + + +def get_nodename_bigram(): + """Generate a unique run identifier based on the nodename and a random bigram. + Example: `max-wng029_QuickBear` + + The bigram is generated from the nodename and the current time, which means + that two runs starting at the same time on different nodes will have different + bigrams (if the nodename is not included, two runs starting at the same time + will have the same bigram). + + Returns: + str: Unique run identifier. + """ + nodename = os.uname().nodename + # cleanup + nodename = nodename.split(".")[0] + + nodename_with_time = f"{nodename}_{int(time.time())}" + + # get a hash of the nodename + hashed_nodename_with_time = hashlib.sha256(nodename_with_time.encode()).hexdigest() + + # bigram + bigram = get_bigram(seed=int(hashed_nodename_with_time, 16)) + + return "_".join([nodename, bigram]) + + +OmegaConf.register_new_resolver("nodename_bigram", get_nodename_bigram, use_cache=True) + + +@task_wrapper +def train(cfg: DictConfig) -> Tuple[dict, dict]: + """Trains the model. Can additionally evaluate on a testset, using best weights obtained during + training. + + This method is wrapped in optional @task_wrapper decorator which applies extra utilities + before and after the call. + + Args: + cfg (DictConfig): Configuration composed by Hydra. + + Returns: + Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects. + """ + + # check if cuda available + if not torch.cuda.is_available(): + raise ValueError("CUDA is not available!") + else: + log.info("CUDA is available.") + + # set seed for random number generators in pytorch, numpy and python.random + if cfg.get("seed"): + L.seed_everything(cfg.seed, workers=True) + + # extract number of input features from datamodule + # input_dim = len(cfg.data.dataset_kwargs_common.feature_dict) + len( + # cfg.data.dataset_kwargs_common.get("feature_dict_jet", []) + # ) + # log.info(f"Input dim: {input_dim}") + # if "input_dim" in cfg.model: + # cfg.model.input_dim = input_dim + # elif "model_kwargs" in cfg.model: + # cfg.model.model_kwargs.input_dim = input_dim + # else: + # raise ValueError( + # "Could not find input_dim in model config. Please add it to model or model_kwargs." + # ) + + log.info(f"Git Status: {git_utils.get_git_status()}") + log.info(f"Git Hash: {git_utils.get_git_hash()}") + log.info(f"Last Commit Message: {git_utils.get_last_commit_message()}") + + log.info(f"Instantiating datamodule <{cfg.data._target_}>") + datamodule: L.LightningDataModule = hydra.utils.instantiate(cfg.data) + + log.info(f"Instantiating model <{cfg.model._target_}>") + model: L.LightningModule = hydra.utils.instantiate(cfg.model) + + # TODO: this is a bit of a hack, but it works for now (used to load the + # VQVAE encoder for the classifier) + if cfg.get("load_weights_from", False): + log.info(f"Loading model weights from {cfg.load_weights_from}") + + load_cpt_path = Path(cfg.load_weights_from).parent.parent / "config.yaml" + log.info("Model config before loading weights:") + log.info(OmegaConf.to_yaml(cfg.model)) + cfg_ckpt = OmegaConf.load(load_cpt_path) + cfg.model.model_kwargs_loaded = cfg_ckpt.model.model_kwargs + + if isinstance(model, VQVAELightning): + model = VQVAELightning.load_from_checkpoint( + cfg.load_weights_from, + strict=cfg.get("load_weights_strict", False), + ) + else: + raise ValueError("Model not recognized!") + log.info("Model config after loading weights:") + log.info(OmegaConf.to_yaml(cfg.model)) + + if cfg.model.model_kwargs.get("class_head_kwargs", False): + model.model.class_head_kwargs = cfg.model.model_kwargs.class_head_kwargs + model.model.initialize_classification_head() + + log.info("Instantiating loggers...") + logger: List[Logger] = instantiate_loggers(cfg.get("logger")) + + log.info("Instantiating callbacks...") + callbacks: List[L.Callback] = instantiate_callbacks(cfg.get("callbacks")) + + log.info(f"Model: \n{model}") + + log.info(f"Instantiating trainer <{cfg.trainer._target_}>") + trainer: L.Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger) + + object_dict = { + "cfg": cfg, + "datamodule": datamodule, + "model": model, + "callbacks": callbacks, + "logger": logger, + "trainer": trainer, + "git": { + "git_hash": git_utils.get_git_hash(), + "git_status": git_utils.get_git_status(), + "git_last_commit_message": git_utils.get_last_commit_message(), + }, + "slurm": { + "job_id": os.environ.get("SLURM_JOB_ID", None), + "log_file": os.environ.get("SLURM_LOGFILE", None), + }, + "load_weights_from": cfg.get("load_weights_from", None), + } + + if logger: + log.info("Logging hyperparameters!") + log_hyperparameters(object_dict) + + # save config for reproducibility and debugging + cfg_backup_file = f'{cfg.trainer.get("default_root_dir")}/config.yaml' + with open(cfg_backup_file, "w") as f: + log.info(f"Saving config to {cfg_backup_file}") + OmegaConf.save(cfg, f) + + if cfg.get("train"): + log.info("Starting training!") + trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")) + + train_metrics = trainer.callback_metrics + + if cfg.get("test"): + log.info("Starting testing!") + ckpt_path = trainer.checkpoint_callback.best_model_path + if ckpt_path == "": + log.warning("Best ckpt not found! Using current weights for testing...") + ckpt_path = None + trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) + log.info(f"Best ckpt path: {ckpt_path}") + + test_metrics = trainer.callback_metrics + + # merge train and test metrics + metric_dict = {**train_metrics, **test_metrics} + + return metric_dict, object_dict + + +@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml") +def main(cfg: DictConfig) -> Optional[float]: + # set CUDA_LAUNCH_BLOCKING=1 to get more informative stack traces + os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + torch.set_float32_matmul_precision("medium") + + experiment_name = Path(cfg.trainer.default_root_dir).name.split("_")[3] + cfg.logger.comet.experiment_name = experiment_name + cfg.logger.wandb.name = experiment_name + + # train the model + metric_dict, _ = train(cfg) + + # safely retrieve metric value for hydra-based hyperparameter optimization + metric_value = get_metric_value( + metric_dict=metric_dict, metric_name=cfg.get("optimized_metric") + ) + + # return optimized metric + return metric_value + + +if __name__ == "__main__": + main() diff --git a/gabbro/utils/__init__.py b/gabbro/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gabbro/utils/arrays.py b/gabbro/utils/arrays.py new file mode 100644 index 0000000..4fa750a --- /dev/null +++ b/gabbro/utils/arrays.py @@ -0,0 +1,513 @@ +import awkward as ak +import numpy as np +import torch +import vector + +vector.register_awkward() + + +def ak_pad(x: ak.Array, maxlen: int, axis: int = 1, fill_value=0, return_mask=False): + """Function to pad an awkward array to a specified length. The array is padded along the + specified axis. + + Parameters + ---------- + x : awkward array + Array to pad. + maxlen : int + Length to pad to. + axis : int, optional + Axis along which to pad. Default is 1. + fill_value : float or int, optional + Value to use for padding. Default is 0. + return_mask : bool, optional + If True, also return a mask array indicating which values are padded. + Default is False. + If the input array has fields, the mask is created from the first field. + + Returns + ------- + awkward array + Padded array. + mask : awkward array + Mask array indicating which values are padded. Only returned if return_mask is True. + """ + padded_x = ak.fill_none(ak.pad_none(x, maxlen, axis=axis, clip=True), fill_value) + if return_mask: + if len(x.fields) >= 1: + mask = ak.ones_like(x[x.fields[0]], dtype="bool") + else: + mask = ak.ones_like(x, dtype="bool") + mask = ak.fill_none(ak.pad_none(mask, maxlen, axis=axis, clip=True), False) + return padded_x, mask + return padded_x + + +def np_to_ak(x: np.ndarray, names: list, mask: np.ndarray = None): + """Function to convert a numpy array and its mask to an awkward array. The features + corresponding to the names are assumed to correspond to the last axis of the array. + + Parameters + ---------- + x : np.ndarray + Array to convert. + names : list + List of field names (corresponding to the features in x along the last dimension). + mask : np.ndarray, optional + Mask array. Default is None. + """ + + if mask is None: + mask = np.ones_like(x[..., 0], dtype="bool") + + return ak.Array( + { + name: ak.values_astype( + ak.drop_none(ak.mask(ak.Array(x[..., i]), mask != 0)), + "float32", + ) + for i, name in enumerate(names) + } + ) + + +def np_to_akward(x: np.ndarray, pp_dict: dict): + """Function to convert a numpy array to an awkward array with specified labels. + + Parameters + ---------- + x : np.ndarray + Array to convert. + pp_dict : dict + Dictionary containing field names as keys. + """ + + return ak.Array( + { + name: ak.values_astype(ak.Array(x[..., i]), "float32") + for i, name in enumerate(pp_dict.keys()) + } + ) + + +def ak_to_np_stack(ak_array: ak.Array, names: list = None, axis: int = -1): + """Function to convert an awkward array to a numpy array by stacking the values of the + specified fields. This is much faster than ak.to_numpy(ak_array) for large arrays. + + Parameters + ---------- + ak_array : awkward array + Array to convert. + names : list, optional + List of field names to convert. Default is None. + axis : int, optional + Axis along which to stack the values. Default is -1. + """ + if names is None: + raise ValueError("names must be specified") + return ak.to_numpy( + np.stack( + [ak.to_numpy(ak.values_astype(ak_array[name], "float32")) for name in names], + axis=axis, + ) + ) + + +def np_PtEtaPhi_to_Momentum4D(arr, mask, log_pt=False): + """Convert numpy array with 4-momenta to ak array of Momentum4D objects. + NOTE: the input array is assumed to be in (pT, eta, phi) format, thus mass = 0. + + Expects an array of shape (batch_size, num_particles, 3) + where the last dimension is (pt, eta, phi) + + Returns an ak array of shape (batch_size, var, 4) of Momentum4D objects + + If log_pt is True, the corresponding variable is exponentiated + before being passed to Momentum4D + + Parameters + ---------- + arr : np.ndarray + Input array of shape (batch_size, num_particles, 3) + mask : np.ndarray + Mask array of shape (batch_size, num_particles) + log_pt : bool, optional + Whether to exponentiate pt, by default False + + Returns + ------- + ak.Array + Array of Momentum4D objects + """ + + p4 = ak.zip( + { + "pt": np.clip(arr[:, :, 0], 0, None) if not log_pt else np.exp(arr[:, :, 0]), + "eta": arr[:, :, 1], + "phi": arr[:, :, 2], + "mass": ak.zeros_like(arr[:, :, 0]), + }, + with_name="Momentum4D", + ) + # mask the array + ak_mask = ak.Array(mask) + return ak.drop_none(ak.mask(p4, ak_mask == 1)) + + +def ak_select_and_preprocess(ak_array: ak.Array, pp_dict=None, inverse=False): + """Function to select and pre-process fields from an awkward array. + + Parameters + ---------- + ak_array : awkward array + Array to convert. + pp_dict : dict, optional + Dictionary with pre-processing values for each field. Default is None. + The dictionary should have the following format: + { + "field_name_1": {"multiply_by": 1, "subtract_by": 0, "func": "np.log"}, + "field_name_2": {"multiply_by": 1, "subtract_by": 0, "func": None}, + ... + } + inverse : bool, optional + If True, the inverse of the pre-processing is applied. Default is False. + """ + if pp_dict is None: + pp_dict = {} + + # define initial mask as all True + first_feat = list(pp_dict.keys())[0] + selection_mask = ak.ones_like(ak_array[first_feat], dtype="bool") + + for name, params in pp_dict.items(): + if params is None: + pp_dict[name] = {"subtract_by": 0, "multiply_by": 1, "func": None} + else: + if "subtract_by" not in params: + pp_dict[name]["subtract_by"] = 0 + if "multiply_by" not in params: + pp_dict[name]["multiply_by"] = 1 + if "func" not in params: + pp_dict[name]["func"] = None + if "inv_func" not in params: + pp_dict[name]["inv_func"] = None + + if pp_dict[name]["func"] is not None: + if pp_dict[name]["inv_func"] is None: + raise ValueError( + "If a function is specified, an inverse function must also be specified." + ) + else: + if pp_dict[name]["inv_func"] is not None: + raise ValueError( + "If an inverse function is specified, a function must also be specified." + ) + # apply selection cuts + if pp_dict[name].get("larger_than") is not None: + selection_mask = selection_mask & (ak_array[name] > pp_dict[name]["larger_than"]) + if pp_dict[name].get("smaller_than") is not None: + selection_mask = selection_mask & (ak_array[name] < pp_dict[name]["smaller_than"]) + + if inverse: + return ak.Array( + { + name: ( + eval(params["inv_func"])( # nosec + getattr(ak_array, name) / params["multiply_by"] + params["subtract_by"] + ) + if params["inv_func"] + else getattr(ak_array, name) / params["multiply_by"] + params["subtract_by"] + ) + for name, params in pp_dict.items() + } + ) + return ak.Array( + { + name: ( + ( + eval(params["func"])(getattr(ak_array, name)[selection_mask]) # nosec + if params["func"] + else getattr(ak_array, name)[selection_mask] + ) + - params["subtract_by"] + ) + * params["multiply_by"] + for name, params in pp_dict.items() + } + ) + + +# define a padding function for the shower arrays +def ak_padding(x: ak.Array, maxlen: int, energy_threshold: float): + """Pads an Awkward Array and creates a mask based on energy values. + + Args: + x (ak.Array): The Awkward Array to pad. + maxlen (int): The maximum length to pad to. + energy_threshold (float): The threshold for considering energy as non-zero. + + Returns: + tuple: A tuple containing the padded array and the mask. + """ + + # Create mask based on energy threshold + mask = x["energy"] > energy_threshold + + # Pad both data and mask to maxlen + padded_x = ak.pad_none(x, maxlen, axis=1, clip=True) + padded_mask = ak.pad_none(mask, maxlen, axis=1, clip=True) + + # Fill None in data with zeros (or custom values if needed) + + return padded_x, padded_mask + + +# define a function to preprocess shower data for use in the model +def ak_preprocess(ak_array: ak.Array, pp_dict=None, inverse: bool = False): + """Function to select and pre-process fields from an awkward array. + + Parameters + ---------- + ak_array : awkward array + Array to convert. + pp_dict : dict, optional + Dictionary with pre-processing values for each field. Default is None. + The dictionary should have the following format: + { + "field_name_1": {"multiply_by": 1, "subtract_by": 0, "func": "np.log"}, + "field_name_2": {"multiply_by": 1, "subtract_by": 0, "func": None}, + ... + } + inverse : bool, optional + If True, the inverse of the pre-processing is applied. Default is False. + """ + if pp_dict is None: + pp_dict = {} + + # Get the input shape + # num_records = len(ak_array) + # num_values_per_field = len(ak_array["x"][0]) # Assuming uniform structure + # input_shape = (num_records, num_values_per_field) + # define initial mask as all True + first_feat = list(pp_dict.keys())[0] + selection_mask = ak.ones_like(ak_array[first_feat], dtype="bool") + + for name, params in pp_dict.items(): + if params is None: + pp_dict[name] = {"subtract_by": 0, "multiply_by": 1, "func": None} + # pylogger.info(f"if params is None: {pp_dict[name]}") + else: + # pylogger.info(f"if params is not None: {pp_dict[name]}") + + if "subtract_by" not in params: + pp_dict[name]["subtract_by"] = 0 + if "multiply_by" not in params: + pp_dict[name]["multiply_by"] = 1 + if "func" not in params: + pp_dict[name]["func"] = None + if "inv_func" not in params: + pp_dict[name]["inv_func"] = None + + if pp_dict[name]["func"] is not None: + if pp_dict[name]["inv_func"] is None: + raise ValueError( + "If a function is specified, an inverse function must also be specified." + ) + else: + if pp_dict[name]["inv_func"] is not None: + raise ValueError( + "If an inverse function is specified, a function must also be specified." + ) + # apply selection cuts + if pp_dict[name].get("larger_than") is not None: + selection_mask = selection_mask & (ak_array[name] > pp_dict[name]["larger_than"]) + if pp_dict[name].get("smaller_than") is not None: + selection_mask = selection_mask & (ak_array[name] < pp_dict[name]["smaller_than"]) + + if inverse: + result_array = ak.Array( + { + name: ( + eval(params["inv_func"])( # nosec + getattr(ak_array, name) / params["multiply_by"] + params["subtract_by"] + ) + if params["inv_func"] + else getattr(ak_array, name) / params["multiply_by"] + params["subtract_by"] + ) + for name, params in pp_dict.items() + } + ) + else: + result_array = ak.Array( + { + name: ( + ( + ( + eval(params["func"])(getattr(ak_array, name)) # nosec + if params["func"] + else getattr(ak_array, name) + ) + - params["subtract_by"] + ) + * params["multiply_by"] + ) + for name, params in pp_dict.items() + } + ) + + # numpy_array = ak.to_numpy(result_array) + # pylogger.info("numpy_array: ", numpy_array) + # pylogger.info("numpy_array.shape: ", numpy_array.shape) + # reshaped_numpy_array = numpy_array.reshape(input_shape) + # pylogger.info("reshaped_numpy_array: ", reshaped_numpy_array) + # return ak.from_numpy(reshaped_numpy_array) + return result_array + + +# define a function to sort ak.Array by pt +def sort_by_pt(constituents: ak.Array, ascending: bool = False): + """Sort ak.Array of jet constituents by the pt + Args: + constituents (ak.Array): constituents array that should be sorted by pt. + It should have a pt attribute. + ascending (bool, optional): If True, the first value in each sorted + group will be smallest; if False, the order is from largest to + smallest. Defaults to False. + Returns: + ak.Array: sorted constituents array + """ + if isinstance(constituents, ak.Array): + try: + temppt = constituents.pt + except AttributeError: + raise AttributeError( + "Trying to sort an ak.Array without a pt attribute. Please check the input." + ) + indices = ak.argsort(temppt, axis=1, ascending=ascending) + return constituents[indices] + + +def ak_smear(arr, sigma=0, seed=42): + """Helper function to smear an array of values by a given sigma. + + Parameters + ---------- + arr : awkward array + The array to smear + sigma : float, optional + The sigma of the smearing, by default 0 (i.e. no smearing) + seed : int, optional + Seed for the random number generator, by default 42 + """ + # Convert it to a 1D numpy array and perform smearing + numpy_arr = ak.to_numpy(arr.layout.content) + + if sigma != 0: + rng = np.random.default_rng(seed) + numpy_arr = rng.normal(numpy_arr, sigma) + + # Convert it back to awkward form + return ak.Array(ak.contents.ListOffsetArray(arr.layout.offsets, ak.Array(numpy_arr).layout)) + + +def ak_clip(arr, clip_min=None, clip_max=None): + """Helper function to clip the values of an array. + + Parameters + ---------- + arr : awkward array + The array to clip + clip_min : float, optional + Minimum value to clip to, by default None + clip_max : float, optional + Maximum value to clip to, by default None + """ + # Convert it to a 1D numpy array and perform clipping + numpy_arr = ak.to_numpy(arr.layout.content) + + if clip_min is not None: + numpy_arr = np.clip(numpy_arr, clip_min, None) + + if clip_max is not None: + numpy_arr = np.clip(numpy_arr, None, clip_max) + + # Convert it back to awkward form + return ak.Array(ak.contents.ListOffsetArray(arr.layout.offsets, ak.Array(numpy_arr).layout)) + + +def count_appearances(arr, mask, count_up_to: int = 10): + """ + Parameters + ---------- + arr : np.ndarray + Array of integers, shape (n_jets, n_constituents) + mask : np.ndarray + Mask array, shape (n_jets, n_constituents) + count_up_to : int, optional + The maximum number of appearances to check for, by default 10 + + Returns + ------- + np.ndarray + Array of shape (n_jets, n_tokens) containing the counts of each token. + I.e. if the maximum token number is 5, the array will have 5 columns + indicating how many times each token appears in each jet. + np.ndarray + Array of shape (n_jets, count_up_to) containing the number of tokens + that appear 0, 1, 2, 3, ... times in each jet. + np.ndarray + Array of shape (n_jets, count_up_to) containing the fraction of tokens + that appear 0, 1, 2, 3, ... times in each jet. + """ + # fill the masked values with one above the maximum value in the array + arr = np.where(mask != 0, arr, np.max(arr) + 1) + + # Count the occurrences of each integer in each row + counts = np.array([np.bincount(row) for row in arr]) + # remove the last column, which is the count of the maximum (fill) value + counts = counts[:, :-1] + + # calculate how many tokens appear 0, 1, 2, 3, ... times + n_token_appearances = [] + for i in range(count_up_to + 1): + n_token_appearances.append(np.sum(np.array(counts) == i, axis=1)) + + # calculate the percentages of tokens that appear 0, 1, 2, 3, ... times + n_tokens_total = np.sum(mask, axis=1) + frac_token_appearances = np.array( + [n * i / n_tokens_total for i, n in enumerate(n_token_appearances)] + ) + + return counts, np.array(n_token_appearances).T, frac_token_appearances.T + + +def fix_padded_logits(logits, mask, factor=1e6): + """Used to fix a tensor of logits if the sequences are padded after some token. The logits of + the padded values are all set to 0, except for the first value, which is set to `factor`. This + is useful when using the logits to calculate the loss. + + Parameters + ---------- + logits : torch.Tensor + Tensor of logits. Shape (batch_size, seq_len, n_tokens) + mask : torch.Tensor + Mask tensor. Shape (batch_size, seq_len) + factor : float, optional + Value to set the first token of the padded values to. Default is 1e6. + + Returns + ------- + torch.Tensor + Fixed logits. + """ + # fix the padded logits + logits = logits * mask.unsqueeze(dim=-1) + # set the logits of padded values to [1e6, -1e6, -1e6, ...] + logits = logits + torch.cat( + [ + (~mask).unsqueeze(-1) * factor, + torch.zeros_like(logits[:, :, 1:]), + ], + dim=-1, + ) + return logits diff --git a/gabbro/utils/bigram.py b/gabbro/utils/bigram.py new file mode 100644 index 0000000..aeb90a1 --- /dev/null +++ b/gabbro/utils/bigram.py @@ -0,0 +1,32 @@ +"""Tools to help with submitting jobs to the cluster.""" +import ssl + +import nltk +import numpy as np +from nltk.corpus import wordnet + + +def get_bigram(seed): + """Return a random bigram of the form _.""" + try: + _create_unverified_https_context = ssl._create_unverified_context + except AttributeError: + pass + else: + ssl._create_default_https_context = _create_unverified_https_context + + nltk.download("wordnet") # Download WordNet data + adjectives = [synset.lemmas()[0].name() for synset in wordnet.all_synsets(wordnet.ADJ)] + nouns = [synset.lemmas()[0].name() for synset in wordnet.all_synsets(wordnet.NOUN)] + + adjectives = [ + adj for adj in adjectives if "-" not in adj and "_" not in adj and adj[0].islower() + ] + nouns = [noun for noun in nouns if "-" not in noun and "_" not in noun and noun[0].islower()] + + rng = np.random.default_rng(seed) + i_adj, i_noun = rng.choice(len(adjectives)), rng.choice(len(nouns)) + + # Return the bigram with the words capitalized + + return adjectives[i_adj].capitalize() + nouns[i_noun].capitalize() diff --git a/gabbro/utils/git_utils.py b/gabbro/utils/git_utils.py new file mode 100644 index 0000000..3c58ffb --- /dev/null +++ b/gabbro/utils/git_utils.py @@ -0,0 +1,30 @@ +"""Useful functions for git operations. + +(e.g. getting git hash, last commit message, etc.) +""" + +import subprocess # nosec + + +def get_git_hash(): + return subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode("utf-8") # nosec + + +def get_git_status(): + cmd = "git diff -- . ':!*.ipynb' --color" + process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True) # nosec + stdout, stderr = process.communicate() + git_diff_output = stdout.decode("utf-8") + separator_start = f"\n{100*'='}\n{'=' * 10} start git diff {'=' * 10}\n" + separator_end = f"\n{'=' * 10} end git diff {'=' * 10}\n{100*'='}\n" + return separator_start + git_diff_output + separator_end + + +def get_last_commit_message(): + return ( + subprocess.check_output( # nosec + ["git", "log", "-1", "--pretty=%B"], + ) + .strip() + .decode("utf-8") + ) diff --git a/gabbro/utils/mapping.py b/gabbro/utils/mapping.py new file mode 100644 index 0000000..808db56 --- /dev/null +++ b/gabbro/utils/mapping.py @@ -0,0 +1,134 @@ +import awkward as ak +import numpy as np +import vector + +vector.register_awkward() + + +def merge_duplicates_numpy_old(rounded_showers): + """Merges duplicate voxels using NumPy for intermediate processing. + + Args: + rounded_showers: An Awkward Array of showers with voxel + coordinates. + + Returns: + An Awkward Array with duplicate voxels merged (energies summed). + """ + + merged_showers = [] + for shower in rounded_showers: + # Convert to NumPy arrays + x = np.round(shower.x.to_numpy()).astype(int) + y = np.round(shower.y.to_numpy()).astype(int) + z = np.round(shower.z.to_numpy()).astype(int) + energy = shower.energy.to_numpy() + + # Create a structured array for easier duplicate handling + voxel_ids = np.stack((x, y, z), axis=-1) + + # Find unique voxels and their indices + unique_voxels, inverse_indices = np.unique(voxel_ids, return_inverse=True, axis=0) + + # Sum energies of duplicate voxels + max_energy = np.zeros(len(unique_voxels)) + np.maximum.at(max_energy, inverse_indices, energy) + + # Construct updated shower with merged energies + merged_shower = ak.zip( + { + "x": unique_voxels[:, 0], + "y": unique_voxels[:, 1], + "z": unique_voxels[:, 2], + "energy": max_energy, + }, + with_name="data", + ) + + merged_showers.append(merged_shower) + + return ak.Array(merged_showers) + + +def merge_duplicates_numpy(rounded_showers): + """Merges duplicate voxels using NumPy for intermediate processing and takes the maximum + energy. + + Args: + rounded_showers: An Awkward Array of showers with voxel + coordinates. + + Returns: + An Awkward Array with duplicate voxels merged (maximum energy taken). + """ + + merged_showers = [] + for shower in rounded_showers: + # Convert to NumPy arrays + x = np.round(shower.x.to_numpy()).astype(int) + y = np.round(shower.y.to_numpy()).astype(int) + z = np.round(shower.z.to_numpy()).astype(int) + energy = shower.energy.to_numpy() + + # Create a structured array for easier duplicate handling + voxel_ids = np.stack((x, y, z), axis=-1) + + # Find unique voxels and their indices + unique_voxels, inverse_indices = np.unique(voxel_ids, return_inverse=True, axis=0) + + # Max energies of duplicate voxels + # max_energy = np.zeros(len(unique_voxels)) + # np.maximum.at(max_energy, inverse_indices, energy) + + # Shift duplicate voxels by z-layer if spot is free + for idx in range(len(voxel_ids)): + if np.sum(inverse_indices == idx) > 1: # Check for duplicates + # Find the unique voxel for this duplicate + unique_voxel = unique_voxels[inverse_indices[idx]] + + # Find a free spot in the z-direction and shift the duplicate + # Find the unique voxel for this duplicate + unique_voxel = unique_voxels[inverse_indices[idx]] + + # Get the energies of the duplicate voxels + duplicate_indices = np.where(inverse_indices == inverse_indices[idx])[0] + duplicate_energies = energy[duplicate_indices] + + # Find the index of the voxel with the highest energy + max_energy_index = duplicate_indices[np.argmax(duplicate_energies)] + + # Shift all other duplicate voxels + for duplicate_idx in duplicate_indices: + if duplicate_idx != max_energy_index: + shift = 1 + while True: + # Check for free spot in +z direction + if not np.any( + (voxel_ids == (unique_voxel + [0, 0, shift])).all(axis=1) + ): + voxel_ids[duplicate_idx] = unique_voxel + [0, 0, shift] + break + + # Check for free spot in -z direction + if not np.any( + (voxel_ids == (unique_voxel + [0, 0, -shift])).all(axis=1) + ): + voxel_ids[duplicate_idx] = unique_voxel + [0, 0, -shift] + break + + shift += 1 # Increment the shift for next iteration + + # Construct updated shower with maximum energies + merged_shower = ak.zip( + { + "x": voxel_ids[:, 0], + "y": voxel_ids[:, 1], + "z": voxel_ids[:, 2], + "energy": energy, + }, + with_name="data", + ) + + merged_showers.append(merged_shower) + + return ak.Array(merged_showers) diff --git a/gabbro/utils/optimizer/lookahead.py b/gabbro/utils/optimizer/lookahead.py new file mode 100644 index 0000000..f2865e0 --- /dev/null +++ b/gabbro/utils/optimizer/lookahead.py @@ -0,0 +1,117 @@ +"""Implements the Lookahead optimizer. + +Code taken from +https://github.com/hqucms/weaver-core/blob/main/weaver/utils/nn/optimizer/lookahead.py +""" + +from collections import defaultdict + +import torch +from torch.optim import Optimizer + + +# https://github.com/lonePatient/lookahead_pytorch/blob/1055128057408fe8533ffa30654551a317f07f0a/optimizer.py +class Lookahead(Optimizer): + """PyTorch implementation of the lookahead wrapper. + + Lookahead Optimizer: https://arxiv.org/abs/1907.08610 + """ + + def __init__(self, optimizer, alpha=0.5, k=6, pullback_momentum="none"): + """ + :param optimizer:inner optimizer + :param k (int): number of lookahead steps + :param alpha(float): linear interpolation factor. 1.0 recovers the inner optimizer. + :param pullback_momentum (str): change to inner optimizer momentum on interpolation update + """ + if not 0.0 <= alpha <= 1.0: + raise ValueError(f"Invalid slow update rate: {alpha}") + if not 1 <= k: + raise ValueError(f"Invalid lookahead steps: {k}") + self.optimizer = optimizer + self.alpha = alpha + self.k = k + self.step_counter = 0 + assert pullback_momentum in ["reset", "pullback", "none"] + self.pullback_momentum = pullback_momentum + self.defaults = optimizer.defaults + self.reset() + + def reset(self): + self.param_groups = self.optimizer.param_groups + self.state = defaultdict(dict) + + # Cache the current optimizer parameters + for group in self.optimizer.param_groups: + for p in group["params"]: + param_state = self.state[p] + param_state["cached_params"] = torch.zeros_like(p.data) + param_state["cached_params"].copy_(p.data) + + def __getstate__(self): + return { + "state": self.state, + "optimizer": self.optimizer, + "alpha": self.alpha, + "step_counter": self.step_counter, + "k": self.k, + "pullback_momentum": self.pullback_momentum, + } + + def zero_grad(self): + self.optimizer.zero_grad() + + def state_dict(self): + return self.optimizer.state_dict() + + def load_state_dict(self, state_dict): + self.optimizer.load_state_dict(state_dict) + self.reset() + + def _backup_and_load_cache(self): + """Useful for performing evaluation on the slow weights (which typically generalize + better)""" + for group in self.optimizer.param_groups: + for p in group["params"]: + param_state = self.state[p] + param_state["backup_params"] = torch.zeros_like(p.data) + param_state["backup_params"].copy_(p.data) + p.data.copy_(param_state["cached_params"]) + + def _clear_and_load_backup(self): + for group in self.optimizer.param_groups: + for p in group["params"]: + param_state = self.state[p] + p.data.copy_(param_state["backup_params"]) + del param_state["backup_params"] + + def step(self, closure=None): + """Performs a single Lookahead optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = self.optimizer.step(closure) + self.step_counter += 1 + + if self.step_counter >= self.k: + self.step_counter = 0 + # Lookahead and cache the current optimizer parameters + for group in self.optimizer.param_groups: + for p in group["params"]: + param_state = self.state[p] + p.data.mul_(self.alpha).add_( + param_state["cached_params"], alpha=1.0 - self.alpha + ) # crucial line + param_state["cached_params"].copy_(p.data) + if self.pullback_momentum == "pullback": + internal_momentum = self.optimizer.state[p]["momentum_buffer"] + self.optimizer.state[p]["momentum_buffer"] = internal_momentum.mul_( + self.alpha + ).add_(param_state["cached_mom"], alpha=1.0 - self.alpha) + param_state["cached_mom"] = self.optimizer.state[p]["momentum_buffer"] + elif self.pullback_momentum == "reset": + self.optimizer.state[p]["momentum_buffer"] = torch.zeros_like(p.data) + + return loss diff --git a/gabbro/utils/optimizer/radam.py b/gabbro/utils/optimizer/radam.py new file mode 100644 index 0000000..0184bd7 --- /dev/null +++ b/gabbro/utils/optimizer/radam.py @@ -0,0 +1,289 @@ +"""Implements the RAdam optimizer. + +Code taken from https://github.com/hqucms/weaver-core/blob/main/weaver/utils/nn/optimizer/radam.py +""" + +import math + +import torch +from torch.optim.optimizer import Optimizer + + +# https://github.com/LiyuanLucasLiu/RAdam/blob/688cb1ec99944d52690c1034f6dcfe830b24d3fd/radam/radam.py +class RAdam(Optimizer): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + degenerated_to_sgd=True, + ): + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + + self.degenerated_to_sgd = degenerated_to_sgd + if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): + for param in params: + if "betas" in param and ( + param["betas"][0] != betas[0] or param["betas"][1] != betas[1] + ): + param["buffer"] = [[None, None, None] for _ in range(10)] + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + buffer=[[None, None, None] for _ in range(10)], + ) + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + + def step(self, closure=None): + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data.float() + if grad.is_sparse: + raise RuntimeError("RAdam does not support sparse gradients") + + p_data_fp32 = p.data.float() + + state = self.state[p] + + if len(state) == 0: + state["step"] = 0 + state["exp_avg"] = torch.zeros_like(p_data_fp32) + state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) + else: + state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32) + state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] + + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + state["step"] += 1 + buffered = group["buffer"][int(state["step"] % 10)] + if state["step"] == buffered[0]: + N_sma, step_size = buffered[1], buffered[2] + else: + buffered[0] = state["step"] + beta2_t = beta2 ** state["step"] + N_sma_max = 2 / (1 - beta2) - 1 + N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t) + buffered[1] = N_sma + + # more conservative since it's an approximated value + if N_sma >= 5: + step_size = math.sqrt( + (1 - beta2_t) + * (N_sma - 4) + / (N_sma_max - 4) + * (N_sma - 2) + / N_sma + * N_sma_max + / (N_sma_max - 2) + ) / (1 - beta1 ** state["step"]) + elif self.degenerated_to_sgd: + step_size = 1.0 / (1 - beta1 ** state["step"]) + else: + step_size = -1 + buffered[2] = step_size + + # more conservative since it's an approximated value + if N_sma >= 5: + if group["weight_decay"] != 0: + p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"]) + denom = exp_avg_sq.sqrt().add_(group["eps"]) + p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group["lr"]) + p.data.copy_(p_data_fp32) + elif step_size > 0: + if group["weight_decay"] != 0: + p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"]) + p_data_fp32.add_(exp_avg, alpha=-step_size * group["lr"]) + p.data.copy_(p_data_fp32) + + return loss + + +class PlainRAdam(Optimizer): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + degenerated_to_sgd=True, + ): + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + + self.degenerated_to_sgd = degenerated_to_sgd + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + + def step(self, closure=None): + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data.float() + if grad.is_sparse: + raise RuntimeError("RAdam does not support sparse gradients") + + p_data_fp32 = p.data.float() + + state = self.state[p] + + if len(state) == 0: + state["step"] = 0 + state["exp_avg"] = torch.zeros_like(p_data_fp32) + state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) + else: + state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32) + state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] + + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + state["step"] += 1 + beta2_t = beta2 ** state["step"] + N_sma_max = 2 / (1 - beta2) - 1 + N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t) + + # more conservative since it's an approximated value + if N_sma >= 5: + if group["weight_decay"] != 0: + p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"]) + step_size = ( + group["lr"] + * math.sqrt( + (1 - beta2_t) + * (N_sma - 4) + / (N_sma_max - 4) + * (N_sma - 2) + / N_sma + * N_sma_max + / (N_sma_max - 2) + ) + / (1 - beta1 ** state["step"]) + ) + denom = exp_avg_sq.sqrt().add_(group["eps"]) + p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size) + p.data.copy_(p_data_fp32) + elif self.degenerated_to_sgd: + if group["weight_decay"] != 0: + p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"]) + step_size = group["lr"] / (1 - beta1 ** state["step"]) + p_data_fp32.add_(exp_avg, alpha=-step_size) + p.data.copy_(p_data_fp32) + + return loss + + +class AdamW(Optimizer): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup=0): + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, warmup=warmup) + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + + def step(self, closure=None): + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data.float() + if grad.is_sparse: + raise RuntimeError( + "Adam does not support sparse gradients, please consider SparseAdam instead" + ) + + p_data_fp32 = p.data.float() + + state = self.state[p] + + if len(state) == 0: + state["step"] = 0 + state["exp_avg"] = torch.zeros_like(p_data_fp32) + state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) + else: + state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32) + state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] + + state["step"] += 1 + + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + denom = exp_avg_sq.sqrt().add_(group["eps"]) + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + + if group["warmup"] > state["step"]: + scheduled_lr = 1e-8 + state["step"] * group["lr"] / group["warmup"] + else: + scheduled_lr = group["lr"] + + step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1 + + if group["weight_decay"] != 0: + p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * scheduled_lr) + + p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size) + + p.data.copy_(p_data_fp32) + + return loss diff --git a/gabbro/utils/optimizer/ranger.py b/gabbro/utils/optimizer/ranger.py new file mode 100644 index 0000000..02f9ab4 --- /dev/null +++ b/gabbro/utils/optimizer/ranger.py @@ -0,0 +1,20 @@ +"""Implementation of the Ranger optimizer. + +Code taken from https://github.com/hqucms/weaver-core/blob/main/weaver/utils/nn/optimizer/ranger.py +""" + +from .lookahead import Lookahead +from .radam import RAdam + + +def Ranger( + params, + lr=1e-3, # lr + betas=(0.95, 0.999), + eps=1e-5, + weight_decay=0, # RAdam options + alpha=0.5, + k=6, # LookAhead options +): + radam = RAdam(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + return Lookahead(radam, alpha, k) diff --git a/gabbro/utils/pylogger.py b/gabbro/utils/pylogger.py new file mode 100644 index 0000000..17258ce --- /dev/null +++ b/gabbro/utils/pylogger.py @@ -0,0 +1,35 @@ +import logging +import os + + +def get_pylogger(name=__name__, rank=None) -> logging.Logger: + """Initializes multi-GPU-friendly python command line logger. + + Parameters + ---------- + name : str, optional + Name of the logger. Default is __name__. + rank : int, optional + Rank of the current process. If not provided, it will be retrieved from + torch.distributed.get_rank(). + + Returns + ------- + logging.Logger + Logger object. + """ + if rank is None: + rank = "unknown" + rank_string = f"rank:{rank}" + + hostname = os.getenv("HOSTNAME", default="unknown-host") + + logger = logging.getLogger(f"{hostname}|{rank_string}|{name}") + + # this ensures all logging levels get marked with the rank zero decorator + # otherwise logs would get multiplied for each GPU process in multi-GPU setup + # logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical") + # for level in logging_levels: + # setattr(logger, level, rank_zero_only(getattr(logger, level))) + + return logger diff --git a/gabbro/utils/rich_utils.py b/gabbro/utils/rich_utils.py new file mode 100644 index 0000000..bca8e5d --- /dev/null +++ b/gabbro/utils/rich_utils.py @@ -0,0 +1,101 @@ +from pathlib import Path +from typing import Sequence + +import rich +import rich.syntax +import rich.tree +from hydra.core.hydra_config import HydraConfig +from omegaconf import DictConfig, OmegaConf, open_dict +from pytorch_lightning.utilities import rank_zero_only +from rich.prompt import Prompt + +from gabbro.utils import pylogger + +log = pylogger.get_pylogger(__name__) + + +@rank_zero_only +def print_config_tree( + cfg: DictConfig, + print_order: Sequence[str] = ( + "data", + "model", + "callbacks", + "logger", + "trainer", + "paths", + "extras", + ), + resolve: bool = False, + save_to_file: bool = False, +) -> None: + """Prints content of DictConfig using Rich library and its tree structure. + + Args: + cfg (DictConfig): Configuration composed by Hydra. + print_order (Sequence[str], optional): Determines in what order config components are printed. + resolve (bool, optional): Whether to resolve reference fields of DictConfig. + save_to_file (bool, optional): Whether to export config to the hydra output folder. + """ + + style = "dim" + tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) + + queue = [] + + # add fields from `print_order` to queue + for field in print_order: + ( + queue.append(field) + if field in cfg + else log.warning( + f"Field '{field}' not found in config. Skipping '{field}' config printing..." + ) + ) + + # add all the other fields to queue (not specified in `print_order`) + for field in cfg: + if field not in queue: + queue.append(field) + + # generate config tree from queue + for field in queue: + branch = tree.add(field, style=style, guide_style=style) + + config_group = cfg[field] + if isinstance(config_group, DictConfig): + branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) + else: + branch_content = str(config_group) + + branch.add(rich.syntax.Syntax(branch_content, "yaml")) + + # print config tree + rich.print(tree) + + # save config tree to file + if save_to_file: + with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: + rich.print(tree, file=file) + + +@rank_zero_only +def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: + """Prompts user to input tags from command line if no tags are provided in config.""" + + if not cfg.get("tags"): + if "id" in HydraConfig().cfg.hydra.job: + raise ValueError("Specify tags before launching a multirun!") + + log.warning("No tags provided in config. Prompting user to input tags...") + tags = Prompt.ask("Enter a list of comma separated tags", default="dev") + tags = [t.strip() for t in tags.split(",") if t != ""] + + with open_dict(cfg): + cfg.tags = tags + + log.info(f"Tags: {cfg.tags}") + + if save_to_file: + with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: + rich.print(cfg.tags, file=file) diff --git a/gabbro/utils/submit_tools.py b/gabbro/utils/submit_tools.py new file mode 100644 index 0000000..2324958 --- /dev/null +++ b/gabbro/utils/submit_tools.py @@ -0,0 +1,165 @@ +"""Tools to help with submitting jobs to the cluster.""" + +import argparse +import itertools +import os +import re + + +def from_dict(dct): + """Return a function that looks up keys in dct.""" + + def lookup(match): + key = match.group(1) + return dct.get(key, f"<{key} not found>") + + return lookup + + +def convert_values_to_strings(dct): + """Convert all values in dct to strings.""" + return {k: str(v) for k, v in dct.items()} + + +def replace_placeholders(file_in, file_out, subs): + """Replace placeholders of the form @@@@ in file_in and write to file_out. + + Parameters + ---------- + file_in : str + Input file. + file_out : str + Output file. + subs : dict + Dictionary mapping placeholders to their replacements, i.e. `{"dummy": "foo"} + will replace @@dummy@@ with foo. + """ + with open(file_in) as f: + text = f.read() + with open(file_out, "w") as f: + f.write(re.sub("@@(.*?)@@", from_dict(subs), text)) + + +def calc_batches_per_node(config, mode="train"): + # convert parameters to floats + n_files_at_once = float(config[f"{mode}_n_files_at_once"]) + n_jets_per_file = float(config[f"{mode}_n_jets_per_file"]) + batch_size = float(config["batch_size"]) + num_gpus_per_node = float(config["num_gpus_per_node"]) + num_nodes = float(config["num_nodes"]) + + batches_per_node = ( + n_files_at_once * n_jets_per_file / batch_size / num_gpus_per_node / num_nodes + ) + config[f"limit_{mode}_batches"] = str(int(batches_per_node)) + + +def create_job_scripts_from_template_and_submit( + hparams_to_try, + hparams_defaults, + job_file_template="job_template.sh", +): + """Create job scripts from a template and submit them to the cluster. This function also + initialized as argument parser under the hood. I.e. the following command line arguments are + available if this function is. + + used in your script: + --dry_run: Don't actually submit the jobs. + --print_run_script: Print the run script of the individual jobs to the console. + --use_bash: Run the job script with bash instead of sbatch (for debugging on + interactive nodes). + + + Parameters + ---------- + hparams_to_try : dict + Dictionary mapping hyperparameters to lists of values to try. + Those parameters have to appear in the job_file_template with the + placeholders @@@@. + hparams_defaults : dict + Dictionary mapping hyperparameters to default values. + job_file_template : str + Path to the template file. + """ + + parser = get_job_script_parser() + args = parser.parse_args() + + for k, v in hparams_defaults.items(): + if k not in hparams_to_try: + hparams_to_try[k] = v + + combinations = list(itertools.product(*hparams_to_try.values())) + + for i, combination in enumerate(combinations): + subs = dict(zip(hparams_to_try.keys(), combination)) + subs = convert_values_to_strings(subs) + print(100 * "-") + print(f"Config {i+1}/{len(combinations)}:") + # ---- + # check if it was requested to calculate the limit_train_batches or limit_val_batches + limit_train_batches = subs.get("limit_train_batches") + limit_val_batches = subs.get("limit_val_batches") + if limit_train_batches is not None or limit_val_batches is not None: + if isinstance(limit_train_batches, str): + if limit_train_batches == "calculate": + print("Calculating limit_train_batches from other parameters.") + calc_batches_per_node(subs) + if isinstance(limit_val_batches, str): + if limit_val_batches == "calculate": + print("Calculating limit_val_batches from other parameters.") + calc_batches_per_node(subs, mode="val") + # ---- + # print key-value pairs formatted as a table + max_key_len = max(len(k) for k in subs.keys()) + for k, v in subs.items(): + print(f"{k:>{max_key_len}} : {v}") + print(100 * "-") + replace_placeholders(job_file_template, "run_tmp.sh", subs) + + # if "use_bash" is true, remove "srun " from the run script + if args.use_bash: + with open("run_tmp.sh") as f: + run_script = f.read() + run_script = run_script.replace("srun ", "") + with open("run_tmp.sh", "w") as f: + f.write(run_script) + + if args.print_run_script: + print("Run script:") + print("-----------") + with open("run_tmp.sh") as f: + print(f.read()) + if not args.dry_run: + if args.use_bash: + os.system("bash run_tmp.sh") # nosec + else: + os.system("sbatch run_tmp.sh") # nosec + + +def get_job_script_parser(): + """Return an argument parser for job scripts. + + Returns + ------- + argparse.ArgumentParser + Argument parser for job scripts with the following flags: + --dry_run: Don't actually submit the jobs. + --print_run_script: Print the run script of the individual jobs to the console. + --use_bash: Run the job script with bash instead of sbatch (for debugging on interactive nodes). + """ + parser = argparse.ArgumentParser() + parser.add_argument("--dry_run", action="store_true", help="Don't actually submit the jobs.") + parser.add_argument( + "--print_run_script", + action="store_true", + default=False, + help="Print the run script of the individual jobs to the console.", + ) + parser.add_argument( + "--use_bash", + action="store_true", + default=False, + help="Run the job script with bash instead of sbatch (for debugging on interactive nodes).", + ) + return parser diff --git a/gabbro/utils/utils.py b/gabbro/utils/utils.py new file mode 100644 index 0000000..b50eacb --- /dev/null +++ b/gabbro/utils/utils.py @@ -0,0 +1,474 @@ +import json +import warnings +from importlib.util import find_spec +from typing import Callable, List + +import awkward as ak +import hydra +import numpy as np +from omegaconf import DictConfig +from pytorch_lightning import Callback +from pytorch_lightning.loggers import Logger +from pytorch_lightning.utilities import rank_zero_only + +from gabbro.utils import pylogger, rich_utils + +log = pylogger.get_pylogger(__name__) + + +def translate_bash_range(wildcard: str, verbose: bool = False): + """Translate bash range to list of strings with the corresponding numbers. + + Parameters + ---------- + wildcard : str + Wildcard string with bash range (or not). + verbose : bool, optional + If True, print debug messages. + + Returns + ------- + list + List of strings with the corresponding numbers. + """ + + # raise value error if two ranges are found + if wildcard.count("{") > 1: + raise ValueError( + f"Only one range is allowed in the wildcard. Provided the following wildcard: {wildcard}" + ) + + if "{" in wildcard and ".." in wildcard and "}" in wildcard: + log.info("Bash range found in wildcard --> translating to list of remaining wildcards.") + start = wildcard.find("{") + end = wildcard.find("}") + prefix = wildcard[:start] + suffix = wildcard[end + 1 :] + wildcard_range = wildcard[start + 1 : end] + start_number = int(wildcard_range.split("..")[0]) + end_number = int(wildcard_range.split("..")[1]) + if verbose: + log.info( + f"Prefix: {prefix}, Suffix: {suffix}, Start: {start_number}, End: {end_number}" + ) + return [f"{prefix}{i}{suffix}" for i in range(start_number, end_number + 1)] + else: + # print("No range found in wildcard") + return [wildcard] + + +def task_wrapper(task_func: Callable) -> Callable: + """Optional decorator that wraps the task function in extra utilities. + + Makes multirun more resistant to failure. + + Utilities: + - Calling the `utils.extras()` before the task is started + - Calling the `utils.close_loggers()` after the task is finished or failed + - Logging the exception if occurs + - Logging the output dir + """ + + def wrap(cfg: DictConfig): + # execute the task + try: + # apply extra utilities + extras(cfg) + + metric_dict, object_dict = task_func(cfg=cfg) + + # things to do if exception occurs + except Exception as ex: + # save exception to `.log` file + log.exception("") + + # when using hydra plugins like Optuna, you might want to disable raising exception + # to avoid multirun failure + raise ex + + # things to always do after either success or exception + finally: + # display output dir path in terminal + log.info(f"Output dir: {cfg.paths.output_dir}") + + # close loggers (even if exception occurs so multirun won't fail) + close_loggers() + + return metric_dict, object_dict + + return wrap + + +def extras(cfg: DictConfig) -> None: + """Applies optional utilities before the task is started. + + Utilities: + - Ignoring python warnings + - Setting tags from command line + - Rich config printing + """ + + # return if no `extras` config + if not cfg.get("extras"): + log.warning("Extras config not found! ") + return + + # disable python warnings + if cfg.extras.get("ignore_warnings"): + log.info("Disabling python warnings! ") + warnings.filterwarnings("ignore") + + # prompt user to input tags from command line if none are provided in the config + if cfg.extras.get("enforce_tags"): + log.info("Enforcing tags! ") + rich_utils.enforce_tags(cfg, save_to_file=True) + + # pretty print config tree using Rich library + if cfg.extras.get("print_config"): + log.info("Printing config tree with Rich! ") + rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True) + + +def instantiate_callbacks(callbacks_cfg: DictConfig, ckpt_path: str = None) -> List[Callback]: + """Instantiates callbacks from config.""" + callbacks: List[Callback] = [] + + if not callbacks_cfg: + log.warning("No callback configs found! Skipping..") + return callbacks + + if not isinstance(callbacks_cfg, DictConfig): + raise TypeError("Callbacks config must be a DictConfig!") + + for _, cb_conf in callbacks_cfg.items(): + if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: + log.info(f"Instantiating callback <{cb_conf._target_}>") + callbacks.append(hydra.utils.instantiate(cb_conf)) + + return callbacks + + +def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: + """Instantiates loggers from config.""" + logger: List[Logger] = [] + + if not logger_cfg: + log.warning("No logger configs found! Skipping...") + return logger + + if not isinstance(logger_cfg, DictConfig): + raise TypeError("Logger config must be a DictConfig!") + + for _, lg_conf in logger_cfg.items(): + if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: + log.info(f"Instantiating logger <{lg_conf._target_}>") + logger.append(hydra.utils.instantiate(lg_conf)) + + return logger + + +@rank_zero_only +def log_hyperparameters(object_dict: dict) -> None: + """Controls which config parts are saved by lightning loggers. + + Additionally saves: + - Number of model parameters + """ + + hparams = {} + + cfg = object_dict["cfg"] + model = object_dict["model"] + trainer = object_dict["trainer"] + + if not trainer.logger: + log.warning("Logger not found! Skipping hyperparameter logging...") + return + + hparams["model"] = cfg["model"] + + # save number of model parameters + hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) + hparams["model/params/trainable"] = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + hparams["model/params/non_trainable"] = sum( + p.numel() for p in model.parameters() if not p.requires_grad + ) + + hparams["data"] = cfg["data"] + hparams["trainer"] = cfg["trainer"] + + hparams["callbacks"] = cfg.get("callbacks") + hparams["extras"] = cfg.get("extras") + + hparams["task_name"] = cfg.get("task_name") + hparams["tags"] = cfg.get("tags") + hparams["ckpt_path"] = cfg.get("ckpt_path") + hparams["seed"] = cfg.get("seed") + hparams["run_note"] = cfg.get("run_note") + hparams["git"] = object_dict.get("git") + hparams["slurm"] = object_dict.get("slurm") + hparams["load_weights_from"] = object_dict.get("load_weights_from") + + # send hparams to all loggers + for logger in trainer.loggers: + logger.log_hyperparams(hparams) + + +def get_metric_value(metric_dict: dict, metric_name: str) -> float: + """Safely retrieves value of the metric logged in LightningModule.""" + + if not metric_name: + log.info("Metric name is None! Skipping metric value retrieval...") + return None + + if metric_name not in metric_dict: + raise Exception( + f"Metric value not found! \n" + "Make sure metric name logged in LightningModule is correct!\n" + "Make sure `optimized_metric` name in `hparams_search` config is correct!" + ) + + metric_value = metric_dict[metric_name].item() + log.info(f"Retrieved metric value! <{metric_name}={metric_value}>") + + return metric_value + + +def close_loggers() -> None: + """Makes sure all loggers closed properly (prevents logging failure during multirun).""" + + log.info("Closing loggers...") + + if find_spec("wandb"): # if wandb is installed + import wandb + + if wandb.run: + log.info("Closing wandb!") + wandb.finish() + + +@rank_zero_only +def save_file(path: str, content: str) -> None: + """Save file in rank zero mode (only on one process in multi-GPU setup).""" + with open(path, "w+") as file: + file.write(content) + + +def get0Momentum(x: float, weights: float) -> float: + # calculate the 0-momentum + out = (x * weights).sum(-1) + return out / weights.sum(-1) + + +def get_diff_construct(original: float, reconstructed: float) -> float: + # calculate the difference + return reconstructed - original + + +def KL(data1: np.ndarray, data2: np.ndarray, bins: int) -> float: + """Calculates the KL divergence between two probability distributions. + + Args: + data1: The first dataset (samples). + data2: The second dataset (samples). + bins: The number of bins for creating histograms. + + Returns: + The KL divergence (a non-negative float). + """ + hist1, bin_edges1 = np.histogram(data1, bins, density=True) + hist2, bin_edges2 = np.histogram(data2, bins, density=True) + epsilon = 1e-8 + hist1 = np.maximum(hist1, epsilon) # Ensure no zero values + hist2 = np.maximum(hist2, epsilon) + + # Assuming you have your hist1 and hist2 arrays from before + + # Get bin widths + bin_widths1 = np.diff(bin_edges1) + bin_widths2 = np.diff(bin_edges2) + + # Calculate approximate probabilities + hist1 = hist1 * bin_widths1 + hist2 = hist2 * bin_widths2 + + # Handle cases where bins are zero in either histogram + nonzero_mask = (hist1 != 0) & (hist2 != 0) + + # Calculate KL divergence only for non-zero bins + kl_div = np.sum(hist1[nonzero_mask] * np.log(hist1[nonzero_mask] / hist2[nonzero_mask])) + return kl_div + + +def find_max_energy_z(energy: ak.Array, z: ak.Array) -> ak.Array: + """Finds the z-value corresponding to the maximum energy in each shower. + + Args: + energy: Awkward array of energy values for each shower. + z: Awkward array of z-values for each shower. + + Returns: + Awkward array of z-values corresponding to the maximum energy in each shower. + """ + z = ak.fill_none(z, 0) + energy = ak.fill_none(energy, 0) + max_energy_indices = ak.argmax(energy, axis=1) + max_energy_indices = ak.fill_none(max_energy_indices, 0) + + shower_indices = ak.from_numpy(np.arange(len(z))) + + max_energy_z_values = z[shower_indices, max_energy_indices] + + return ak.to_numpy(max_energy_z_values) + + +def get_COG_ak(x: ak.Array, weights: ak.Array) -> ak.Array: + """Calculates the 0-momentum for each individual shower (subarray) in an awkward array. + + Args: + x: Awkward array of coordinates. + weights: Awkward array of weights (e.g., energy). + + Returns: + Awkward array of 0-momentum values for each shower. + """ + # Element-wise multiplication for each shower + weighted_x = x * weights + + # Calculate the sum of weighted x and sum of weights for each shower + sum_weighted_x = ak.sum(weighted_x, axis=-1) + sum_weights = ak.sum(weights, axis=-1) + + # Divide the sums to get the 0-momentum for each shower + return sum_weighted_x / sum_weights + + +def find_radial_profile(x: ak.Array, y: ak.Array) -> ak.Array: + """finds the energy-weighted distances from the incident point in the x-y-plane. + + Args: + x: Awkward array of x-coordinates for each point in the shower. + y: Awkward array of y-coordinates for each point in the shower. + energy: Awkward array of energy values for each point in the shower. + + Returns: + Awkward array of radial distances from the center for each point in the shower. + """ + + # Calculate the middle (mean) of x and y coordinates for each shower + x_middle = 14.5 + y_middle = 14.5 + + # Calculate the radial distance from the center + radial_distance = ((x - x_middle) ** 2 + (y - y_middle) ** 2) ** 0.5 + # Fill None values with 0 to ensure each axis has the same length + radial_distance = ak.flatten(radial_distance, axis=1) + radial_distance = ak.to_numpy(radial_distance) + + return radial_distance + + +def sum_energy_per_radial_distance(x: ak.Array, y: ak.Array, energy: ak.Array) -> ak.Array: + """Sums up the energy per radial distance bin for each shower. + + Args: + radial_distance: Awkward array of radial distances for each point in the shower. + energy: Awkward array of energy values for each point in the shower. + Returns: + Awkward array of summed energy values per radial distance bin for each shower. + """ + x_middle = 14.5 + y_middle = 14.5 + + # Calculate the radial distance from the center + radial_distance = ((x - x_middle) ** 2 + (y - y_middle) ** 2) ** 0.5 + result = [] + radial_bins = np.arange(0, 22) + for radial_shower, energy_shower in zip(radial_distance, energy): + # Bin energy values according to radial distance bins + hist, _ = np.histogram(radial_shower, bins=radial_bins, weights=energy_shower) + result.append(hist) + # Convert the result to an awkward array for better handling + binned_energy = ak.Array(result) + return binned_energy + + +def sum_energy_per_layer(z: ak.Array, energy: ak.Array) -> ak.Array: + """Sums up the energy per layer (z-bin) for each shower. + + Args: + z: Awkward array of z-coordinates for each point in the shower. + energy: Awkward array of energy values for each point in the shower. + z_bins: Array of z-bin edges. + Returns: + Awkward array of summed energy values per z-bin for each shower. + """ + result = [] + z_bins = np.arange(0, 30) + for z_shower, energy_shower in zip(z, energy): + # Bin energy values according to Z-bins + hist, _ = np.histogram(z_shower, bins=z_bins, weights=energy_shower) + result.append(hist) + # Convert the result to an awkward array for better handling + binned_energy = ak.Array(result) + return binned_energy + + +def write_distances_to_json(kld, wasserstein, filepath, weights, n_data, feature): + """Writes KLD and Wasserstein distances to a JSON file, structured for plotting. + + Args: + kld: Kullback-Leibler divergence. + wasserstein: Wasserstein distance. + filepath: Path to the JSON file. + weights: "weights" or "no weights". + n_data: Number of training data (e.g., "100", "1000", "10000"). + feature: Feature name (e.g., "energy", "energy_sum", "max_z"). + """ + + # Load existing JSON data if the file exists + try: + with open(filepath) as f: + data = json.load(f) + except FileNotFoundError: + data = {} + + # Create a nested dictionary for the feature + if feature not in data: + data[feature] = {} + + # Add data for the current setting + if weights not in data[feature]: + data[feature][weights] = {} + + if n_data not in data[feature][weights]: + data[feature][weights][n_data] = {} + + if "kld" not in data[feature][weights][n_data]: + data[feature][weights][n_data] = {"kld": [], "wasserstein": []} + + # Add the new entry to the data list + data[feature][weights][n_data]["kld"].append(kld) + data[feature][weights][n_data]["wasserstein"].append(wasserstein) + + # Write the updated data back to the JSON file + with open(filepath, "w") as f: + json.dump(data, f, indent=4) + + +# Analyze the first 10 tokens of each shower and their commonality +def analyze_first_10_tokens(token_ids: ak.Array) -> np.ndarray: + """Analyzes the first 10 tokens of each shower and their commonality. + + Args: + token_ids: Awkward array of token IDs for each shower. + + Returns: + Dictionary with unique token sequences as keys and their counts as values. + """ + first_10_tokens = ak.to_numpy(ak.pad_none(token_ids[:, 1:11], 10, clip=True)) + unique, counts = np.unique(first_10_tokens.flatten(), return_counts=True) + counts = np.sort(counts)[::-1] + return counts diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..443fb67 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,32 @@ +[tool.pytest.ini_options] +addopts = [ + "--color=yes", + "--durations=0", + "--strict-markers", + "--doctest-modules", +] +filterwarnings = [ + "ignore::DeprecationWarning", + "ignore::UserWarning", +] +log_cli = "True" +markers = [ + "slow: slow tests", +] +minversion = "6.0" +testpaths = "tests/" + +[tool.coverage.report] +exclude_lines = [ + "pragma: nocover", + "raise NotImplementedError", + "raise NotImplementedError()", + "if __name__ == .__main__.:", +] +[tool.black] +line-length = 99 +preview = "True" + +[tool.ruff] +include = ["*.ipynb"] +line-length = 99 diff --git a/scripts/tokenize_shower.py b/scripts/tokenize_shower.py new file mode 100644 index 0000000..1613e84 --- /dev/null +++ b/scripts/tokenize_shower.py @@ -0,0 +1,204 @@ +import argparse +import logging +import shutil +from pathlib import Path + +import vector +import yaml +from omegaconf import OmegaConf + +from gabbro.data.data_tokenization import reconstruct_shower_file, tokenize_shower_file + +# import gabbro.plotting as jplt + +vector.register_awkward() + +logger = logging.getLogger(__name__) + + +def copy_checkpoint(ckpt_path, directory): + """Copies a checkpoint file to a specified directory.""" + ckpt_path = Path(ckpt_path) + directory = Path(directory) + + if not ckpt_path.exists(): + raise FileNotFoundError(f"Checkpoint file not found at: {ckpt_path}") + + directory.mkdir(parents=True, exist_ok=True) # Create the directory if it doesn't exist + + new_ckpt_path = directory / "model.ckpt" # Maintain the original filename + + try: + shutil.copy2(ckpt_path, new_ckpt_path) # Use shutil.copy2 to preserve metadata + except Exception as e: + raise RuntimeError(f"Error copying checkpoint file: {e}") + + print(f"Checkpoint file copied to: {new_ckpt_path}") + + +def load_config(config_file_path): + """Loads the YAML configuration file and returns it as a dictionary.""" + config_path = Path(config_file_path) + try: + with config_path.open("r") as f: + config = yaml.safe_load(f) + except FileNotFoundError: + raise FileNotFoundError(f"Configuration file not found at: {config_file_path}") + except yaml.YAMLError as e: + raise ValueError(f"Invalid YAML format in configuration file: {e}") + return config + + +def main(): + parser = argparse.ArgumentParser(description="Your script description") + parser.add_argument("config_file", help="Path to the YAML configuration file") + args = parser.parse_args() + + config = load_config(args.config_file) + + network = config["data"]["network"] + epoch = config["data"]["epoch"] + tokenize = config["data"]["tokenize"] + reconstruct = config["data"]["reconstruct"] + save_config = config["data"]["save_config"] + save_ckpt = config["data"]["save_ckpt"] + train = config["data"]["train"] + test = config["data"]["test"] + val = config["data"]["val"] + + print(f"networkfile: {network}") + print(f"epoch: {epoch}") + # network = "2024-06-06_10-04-25_max-cmsg010_OutermostOsteopetrosis" + # epoch = "epoch_013_loss_13236648851865075712.00000.ckpt" + filename_in_train = "/beegfs/desy/user/korcariw/CaloClouds/dataset/showers/photons_10_100GeV_float32_sorted_train.h5" + filename_in_test = "/beegfs/desy/user/korcariw/CaloClouds/dataset/showers/photons_10_100GeV_float32_sorted_test.h5" + filename_in_val = "/beegfs/desy/user/korcariw/CaloClouds/dataset/showers/photons_10_100GeV_float32_sorted_val.h5" + + ckpt_path = ( + f"/beegfs/desy/user/rosehenn/gabbro_output/TokTrain/runs/{network}/checkpoints/{epoch}" + ) + + directory = f"/beegfs/desy/user/rosehenn/gabbro/compare/{network}" + filename_out_test = ( + f"/beegfs/desy/user/rosehenn/gabbro/compare/{network}/tokenized_test.parquet" + ) + filename_out_train = ( + f"/beegfs/desy/user/rosehenn/gabbro/compare/{network}/tokenized_train.parquet" + ) + filename_out_val = f"/beegfs/desy/user/rosehenn/gabbro/compare/{network}/tokenized_val.parquet" + + filename_out_2_train = ( + f"/beegfs/desy/user/rosehenn/gabbro/compare/{network}/reconstructed_train.parquet" + ) + filename_out_2_test = ( + f"/beegfs/desy/user/rosehenn/gabbro/compare/{network}/reconstructed_test.parquet" + ) + filename_out_2_val = ( + f"/beegfs/desy/user/rosehenn/gabbro/compare/{network}/reconstructed_val.parquet" + ) + config_path = f"/beegfs/desy/user/rosehenn/gabbro_output/TokTrain/runs/{network}/config.yaml" + + if tokenize: + print("Tokenizing data...") + if train: + # this function will save the tokenized data to a parquet file in the desired location + tokens_int, p4s_original, data_showers = tokenize_shower_file( + filename_in=filename_in_train, + model_ckpt_path=ckpt_path, + filename_out=filename_out_train, + add_start_end_tokens=True, + energy_sorting=True, + n_load=760000, + ) + print("Tokenized training data saved to:", filename_out_train) + + if test: + tokens_int, p4s_original, data_showers = tokenize_shower_file( + filename_in=filename_in_test, + model_ckpt_path=ckpt_path, + filename_out=filename_out_test, + add_start_end_tokens=True, + energy_sorting=True, + n_load=760000, + ) + print("Tokenized testing data saved to:", filename_out_test) + + if val: + tokens_int, p4s_original, data_showers = tokenize_shower_file( + filename_in=filename_in_val, + model_ckpt_path=ckpt_path, + filename_out=filename_out_val, + add_start_end_tokens=True, + energy_sorting=True, + n_load=760000, + ) + print("Tokenized validation data saved to:", filename_out_val) + + if reconstruct: + print("Reconstructing data...") + if train: + data, p4data = reconstruct_shower_file( + filename_in=filename_out_train, + model_ckpt_path=ckpt_path, + config_path=config_path, + filename_out=filename_out_2_train, + start_token_included=True, + end_token_included=True, + shift_tokens_by_minus_one=True, + print_model=False, + device="cuda", + merge_duplicates=True, + ) + print("Reconstructed training data saved to:", filename_out_2_train) + if test: + data, p4data = reconstruct_shower_file( + filename_in=filename_out_test, + model_ckpt_path=ckpt_path, + config_path=config_path, + filename_out=filename_out_2_test, + start_token_included=True, + end_token_included=True, + shift_tokens_by_minus_one=True, + print_model=False, + device="cuda", + merge_duplicates=True, + ) + print("Reconstructed testing data saved to:", filename_out_2_test) + + if val: + data, p4data = reconstruct_shower_file( + filename_in=filename_out_val, + model_ckpt_path=ckpt_path, + config_path=config_path, + filename_out=filename_out_2_val, + start_token_included=True, + end_token_included=True, + shift_tokens_by_minus_one=True, + print_model=False, + device="cuda", + merge_duplicates=True, + ) + print("Reconstructed validation data saved to:", filename_out_2_val) + + if save_config: + output_dir = Path(directory) + output_dir.mkdir( + parents=True, exist_ok=True + ) # Create the output directory if it doesn't exist + # Extract the original filename + original_filename = Path(config_path).name + # Construct the new file path within the output directory + new_config_path = output_dir / original_filename + + with new_config_path.open("w") as f: + config_of_network = load_config(config_path) + OmegaConf.save(config_of_network, f) + + print(f"Modified configuration saved to: {new_config_path}") + + if save_ckpt: # Or a Path object + copy_checkpoint(ckpt_path, directory) + + +if __name__ == "__main__": + main() diff --git a/scripts/tokenize_shower.yaml b/scripts/tokenize_shower.yaml new file mode 100644 index 0000000..5018a61 --- /dev/null +++ b/scripts/tokenize_shower.yaml @@ -0,0 +1,13 @@ +data: + + network: "2024-09-21_16-54-39_max-wng062_CerousLocknut" + epoch: "epoch_231_loss_0.17179.ckpt" + + + tokenize: True # Will transform the data into tokens and save it to the compare directory + reconstruct: True # Will transform the tokens back into the original data and save it to the compare directory + save_config: True # Will save the configuration file to the compare directory + save_ckpt: True # Will save the checkpoint file to the compare directory + train: True # Will transform the train data + test: False # Will transform the test data + val: False # Will transform the validation data diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..d61a0ef --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,377 @@ +"""Tests for functions in gabbro/utils.""" + +import unittest + +import awkward as ak +import numpy as np + +from gabbro.utils.arrays import ( + ak_clip, + ak_pad, + ak_select_and_preprocess, + ak_smear, + ak_to_np_stack, + count_appearances, + np_to_ak, + sort_by_pt, +) + + +class TestNpToAk(unittest.TestCase): + def setUp(self): + # np array of shape (2, 3, 2) (2 jets, 3 constituents, 2 features) + self.np_array = np.array( + [ + [[1, 2], [3, 3], [0, 0]], + # also want to mask the "4", but checking if this is corrected by the function + [[2, 2], [4, 0], [0, 0]], + ] + ) + self.np_mask = np.array( + [ + [True, True, False], + [True, False, False], + ] + ) + self.names = ["pt", "eta"] + self.ak_arrary_expected = ak.Array( + { + "pt": [[1, 3], [2]], + "eta": [[2, 3], [2]], + } + ) + self.ak_arrary_expected_without_mask = ak.Array( + { + "pt": [[1, 3, 0], [2, 4, 0]], + "eta": [[2, 3, 0], [2, 0, 0]], + } + ) + + def test_np_to_ak_with_mask(self): + result = np_to_ak(self.np_array, mask=self.np_mask, names=self.names) + for i, name in enumerate(self.names): + self.assertTrue(ak.all(result[name] == self.ak_arrary_expected[name])) + + def test_np_to_ak_without_mask(self): + result = np_to_ak(self.np_array, names=self.names) + for i, name in enumerate(self.names): + self.assertTrue(ak.all(result[name] == self.ak_arrary_expected_without_mask[name])) + + +class TestAkToNpStack(unittest.TestCase): + def setUp(self): + self.ak_array = ak.Array( + { + "pt": [[1, 2, 3], [2, 4]], + "eta": [[0, 0, 0], [2, 2]], + "phi": [[0, 0, 0], [3, 3]], + "E": [[1, 1, 1], [4, 4]], + } + ) + # use as the arget array a version where the pt and eta are swapped + # --> this check both that the order of the stacked fields is correct + # and that not all features have to be selected + self.np_array_padded_len5_eta_pt = np.array( + [ + [[0, 1], [0, 2], [0, 3], [0, 0], [0, 0]], + [[2, 2], [2, 4], [0, 0], [0, 0], [0, 0]], + ] + ) + + def test_ak_to_np_stack(self): + input_data = ak_pad(self.ak_array, maxlen=5, axis=1) + result = ak_to_np_stack(input_data, axis=2, names=["eta", "pt"]) + + try: + self.assertTrue(np.array_equal(result, self.np_array_padded_len5_eta_pt)) + except AssertionError: + print("Arrays are not equal:") + print("Expected:", self.np_array_padded_len5_eta_pt) + print("Actual:", result) + raise AssertionError + + +class TestSortByPt(unittest.TestCase): + def test_error_raise(self): + input_array_without_pt = ak.Array( + { + "eta": [[0, 0, 1], [2, 1]], + "phi": [[0, 0, 1], [2, 1]], + } + ) + + # check that AttributeError is raised + with self.assertRaises(AttributeError): + sort_by_pt(input_array_without_pt) + + def test_sorting_order(self): + """Test the function sort_by_pt()""" + + input_array = ak.Array( + { + "pt": [[2, 1, 3], [2, 4]], + "eta": [[0, 0, 1], [2, 1]], + } + ) + expected_sorted_array = ak.Array( + { + "pt": [[3, 2, 1], [4, 2]], + "eta": [[1, 0, 0], [1, 2]], + } + ) + + sorted_array = sort_by_pt(input_array) + + # compare array.pt and array.eta as lists + self.assertEqual(sorted_array.pt.tolist(), expected_sorted_array.pt.tolist()) + self.assertEqual(sorted_array.eta.tolist(), expected_sorted_array.eta.tolist()) + + +class TestAkSelectAndPreprocess(unittest.TestCase): + def setUp(self): + self.input_array = ak.Array( + { + "pt": [[2, 1], [2]], + "eta": [[0, 1], [1]], + "phi": [[0, 1], [2]], + } + ) + self.pp_dict = { + "pt": {"subtract_by": 1, "multiply_by": 3, "func": "np.log", "inv_func": "np.exp"}, + "eta": {"subtract_by": 0, "multiply_by": 2}, + "phi": None, + } + self.expected_output = ak.Array( + { + "pt": [[(np.log(2) - 1) * 3, (np.log(1) - 1) * 3], [(np.log(2) - 1) * 3]], + "eta": [[0 * 2, 1 * 2], [1 * 2]], + "phi": [[0, 1], [2]], + } + ) + + def test_ak_select_and_preprocess(self): + result = ak_select_and_preprocess(self.input_array, pp_dict=self.pp_dict) + for field in self.input_array.fields: + self.assertEqual(result[field].tolist(), self.expected_output[field].tolist()) + + def test_ak_select_and_preprocess_no_inv_func(self): + """Test error raise if `func` is defined, but `inv_func` isn't.""" + + pp_dict_wrong = { + "pt": {"subtract_by": 1, "multiply_by": 3, "func": "np.log"}, + } + with self.assertRaises(ValueError): + ak_select_and_preprocess(self.input_array, pp_dict=pp_dict_wrong) + + def test_ak_select_and_preprocess_no_func(self): + """Test error raise if `inv_func` is defined, but `func` isn't.""" + + pp_dict_wrong = { + "pt": {"subtract_by": 1, "multiply_by": 3, "inv_func": "np.exp"}, + } + with self.assertRaises(ValueError): + ak_select_and_preprocess(self.input_array, pp_dict=pp_dict_wrong) + + def test_inverse_sanity_check(self): + """Test that applying the preprocessing once and then applying the inverse results in the + same array again.""" + + result = ak_select_and_preprocess(self.input_array, pp_dict=self.pp_dict) + result = ak_select_and_preprocess(result, pp_dict=self.pp_dict, inverse=True) + + for field in self.pp_dict.keys(): + self.assertEqual(result[field].tolist(), self.input_array[field].tolist()) + + def test_single_selection_cut(self): + """Test that the function applies a single selection cut correctly.""" + + arr = ak.Array( + { + "pt": [[1, 2, 3], [4, 5]], + "eta": [[5, 6, 7], [8, 9]], + } + ) + pp_dict = {"pt": {"larger_than": 2}, "eta": None} + arr_selected_expected = ak.Array( + { + "pt": [[3], [4, 5]], + "eta": [[7], [8, 9]], + } + ) + arr_selected = ak_select_and_preprocess(arr, pp_dict) + for field in pp_dict.keys(): + self.assertEqual(arr_selected[field].tolist(), arr_selected_expected[field].tolist()) + + def test_multiple_selection_cuts(self): + """Test that the function applies multiple selection cuts correctly.""" + arr = ak.Array( + { + "pt": [[1, 2, 3], [4, 5]], + "eta": [[5, 6, 7], [8, 9]], + } + ) + pp_dict = {"pt": {"larger_than": 2}, "eta": {"smaller_than": 9}} + arr_selected_expected = ak.Array( + { + "pt": [[3], [4]], + "eta": [[7], [8]], + } + ) + arr_selected = ak_select_and_preprocess(arr, pp_dict) + for field in pp_dict.keys(): + self.assertEqual(arr_selected[field].tolist(), arr_selected_expected[field].tolist()) + + def test_selection_and_transform_combined(self): + """Test that the function applies selection cuts and transforms the input array.""" + arr = ak.Array( + { + "pt": [[1, 2, 3], [4, 5]], + "eta": [[5, 6, 7], [8, 9]], + } + ) + pp_dict = { + "pt": { + "larger_than": 2, + "subtract_by": 1, + "multiply_by": 3, + }, + "eta": {"smaller_than": 9}, + } + arr_selected_expected = ak.Array( + { + "pt": [[(3 - 1) * 3], [(4 - 1) * 3]], + "eta": [[7], [8]], + } + ) + arr_selected = ak_select_and_preprocess(arr, pp_dict) + for field in pp_dict.keys(): + self.assertEqual(arr_selected[field].tolist(), arr_selected_expected[field].tolist()) + + +class TestAkSmearAndClip(unittest.TestCase): + def setUp(self): + self.input_array = ak.Array( + { + "pt": [[2, 1], [2]], + } + ) + + def test_smear(self): + """Test that the function smears the input array.""" + result = ak_smear(self.input_array["pt"], sigma=0.05, seed=101) + expected_result = [ + [1.9604923750018493, 0.8982687259084063], + [2.030165087346238], + ] + self.assertEqual(result.tolist(), expected_result) + + def test_clipmin(self): + """Test that the function clips the input array to min value.""" + result = ak_clip(self.input_array["pt"], clip_min=1.5) + expected_result = [ + [2, 1.5], + [2], + ] + self.assertEqual(result.tolist(), expected_result) + + def test_clipmax(self): + """Test that the function clips the input array to max value.""" + result = ak_clip(self.input_array["pt"], clip_max=1.5) + expected_result = [ + [1.5, 1], + [1.5], + ] + self.assertEqual(result.tolist(), expected_result) + + def test_clipminmax(self): + """Test that the function clips the input array to min and max value.""" + result = ak_clip(self.input_array["pt"], clip_min=1.5, clip_max=1.8) + expected_result = [ + [1.8, 1.5], + [1.8], + ] + self.assertEqual(result.tolist(), expected_result) + + def test_smear_and_clip(self): + """Test that the function smears and clips the input array.""" + result = ak_clip( + ak_smear( + self.input_array["pt"], + sigma=0.05, + seed=101, + ), + clip_min=0.9, + clip_max=2.01, + ) + expected_result = [ + [1.9604923750018493, 0.9], + [2.01], + ] + self.assertEqual(result.tolist(), expected_result) + + +class TestTokenCounting(unittest.TestCase): + def __init__(self): + tokens_dev = np.array( + [ + [1, 2, 4, 4, 0], + [6, 6, 6, 6, 0], + ] + ) + mask_dev = np.array( + [ + [1, 1, 1, 1, 0], + [1, 1, 1, 1, 0], + ] + ) + + token_counts_expected = np.array( + # each row corresponds to how often the token appears in the respective jet + [ + [0, 1, 1, 0, 2, 0, 0], + [0, 0, 0, 0, 0, 0, 4], + ] + ) + n_token_appearance_expected = np.array( + # each row corresponds to how many tokens are part of the group + # of appearing i-times. E.g. here + # in the first jet: + # - 4 tokens appear 0 times + # - 2 token appears 1 time, + # - 1 token appears 2 times. + # in the second jet: + # - 6 tokens appear 0 times + # - 1 tokens appear 4 times + [ + [4, 2, 1, 0, 0], + [6, 0, 0, 0, 1], + ] + ) + frac_token_appearance_expected = np.array( + # each row corresponds to the fraction of tokens being part of the + # group that appear i-times + # I.e. here: + # in the first jet: + # - first one is a filler + # - 50% of the tokens appear 1 time + # - 50% of the tokens appear 2 times + # in the second jet: + # - first one is a filler + # - 100% of the tokens appear 4 times + [ + [ + [0.0, 0.5, 0.5, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + ] + ] + ) + + token_counts, n_token_appearance, frac_token_appearance = count_appearances( + tokens_dev, + mask=mask_dev, + count_up_to=4, + ) + + assert np.array_equal(token_counts, token_counts_expected) + assert np.array_equal(n_token_appearance, n_token_appearance_expected) + assert np.array_equal(frac_token_appearance, frac_token_appearance_expected) diff --git a/tests/test_utils_jet_types.py b/tests/test_utils_jet_types.py new file mode 100644 index 0000000..e728346 --- /dev/null +++ b/tests/test_utils_jet_types.py @@ -0,0 +1,53 @@ +"""Tests for functions in gabbro/utils/jet_types.""" + +import unittest + +from gabbro.utils.jet_types import ( + get_jet_type_from_file_prefix, + get_numerical_label_from_file_prefix, + get_tex_label_from_numerical_label, +) + + +class TestJetTypeHelperFunctions(unittest.TestCase): + def test_get_numerical_from_fileprefix(self): + self.assertEqual(get_numerical_label_from_file_prefix("ZJetsToNuNu_"), 0) + self.assertEqual(get_numerical_label_from_file_prefix("HToBB_"), 1) + self.assertEqual(get_numerical_label_from_file_prefix("HToCC_"), 2) + self.assertEqual(get_numerical_label_from_file_prefix("HToGG_"), 3) + self.assertEqual(get_numerical_label_from_file_prefix("HToWW4Q_"), 4) + self.assertEqual(get_numerical_label_from_file_prefix("HToWW2Q1L_"), 5) + self.assertEqual(get_numerical_label_from_file_prefix("ZToQQ_"), 6) + self.assertEqual(get_numerical_label_from_file_prefix("WToQQ_"), 7) + self.assertEqual(get_numerical_label_from_file_prefix("TTBar_"), 8) + self.assertEqual(get_numerical_label_from_file_prefix("TTBarLep_"), 9) + with self.assertRaises(ValueError): + get_numerical_label_from_file_prefix("invalid_prefix") + + def test_get_tex_label_from_numerical_label(self): + self.assertEqual(get_tex_label_from_numerical_label(0), "$q/g$") + self.assertEqual(get_tex_label_from_numerical_label(1), "$H\\rightarrow b\\bar{b}$") + self.assertEqual(get_tex_label_from_numerical_label(2), "$H\\rightarrow c\\bar{c}$") + self.assertEqual(get_tex_label_from_numerical_label(3), "$H\\rightarrow gg$") + self.assertEqual(get_tex_label_from_numerical_label(4), "$H\\rightarrow 4q$") + self.assertEqual(get_tex_label_from_numerical_label(5), "$H\\rightarrow \\ell\\nu qq'$") + self.assertEqual(get_tex_label_from_numerical_label(6), "$Z\\rightarrow q\\bar{q}$") + self.assertEqual(get_tex_label_from_numerical_label(7), "$W\\rightarrow qq'$") + self.assertEqual(get_tex_label_from_numerical_label(8), "$t\\rightarrow bqq'$") + self.assertEqual(get_tex_label_from_numerical_label(9), "$t\\rightarrow b\\ell\\nu$") + with self.assertRaises(ValueError): + get_tex_label_from_numerical_label(10) + + def test_get_jet_type_from_file_prefix(self): + self.assertEqual(get_jet_type_from_file_prefix("ZJetsToNuNu_"), "QCD") + self.assertEqual(get_jet_type_from_file_prefix("HToBB_"), "Hbb") + self.assertEqual(get_jet_type_from_file_prefix("HToCC_"), "Hcc") + self.assertEqual(get_jet_type_from_file_prefix("HToGG_"), "Hgg") + self.assertEqual(get_jet_type_from_file_prefix("HToWW4Q_"), "H4q") + self.assertEqual(get_jet_type_from_file_prefix("HToWW2Q1L_"), "Hqql") + self.assertEqual(get_jet_type_from_file_prefix("ZToQQ_"), "Zqq") + self.assertEqual(get_jet_type_from_file_prefix("WToQQ_"), "Wqq") + self.assertEqual(get_jet_type_from_file_prefix("TTBar_"), "Tbqq") + self.assertEqual(get_jet_type_from_file_prefix("TTBarLep_"), "Tbl") + with self.assertRaises(ValueError): + get_jet_type_from_file_prefix("invalid_prefix")