Skip to content

Commit

Permalink
Added explanatory Codebook
Browse files Browse the repository at this point in the history
  • Loading branch information
HenningRose committed Jan 13, 2025
1 parent 3e0305a commit 8bf0c7b
Show file tree
Hide file tree
Showing 2 changed files with 234 additions and 10 deletions.
100 changes: 93 additions & 7 deletions examples/example_notebooks/example_generate_showers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,16 @@
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"import os\n",
"import sys\n",
"from pathlib import Path\n",
"\n",
"import awkward as ak\n",
"import numpy as np\n",
"import vector\n",
"from omegaconf import OmegaConf\n",
"\n",
"sys.path.append(\"/beegfs/desy/user/rosehenn/gabbro\")"
"sys.path.append(\"/data/dust/user/rosehenn/gabbro\")"
]
},
{
Expand All @@ -36,8 +40,7 @@
"# this checkpoint is the checkpoint from a backbone training with the next-token-prediction head\n",
"# make sure you have downloaded the checkpoint in advance\n",
"# if not, run the script `checkpoints/download_checkpoints.sh`\n",
"ckpt_path = \"/beegfs/desy/user/rosehenn/gabbro_output/full_resolution/runs/2024-11-21_13-49-55_max-wng060_TerminativeCirculation/checkpoints/epoch_032_loss_4.10881.ckpt\"\n",
"\n",
"ckpt_path = \"/data/dust/user/rosehenn/gabbro_output/full_resolution/runs/2024-11-21_13-49-55_max-wng060_TerminativeCirculation/checkpoints/epoch_032_loss_4.10881.ckpt\"\n",
"gen_model = BackboneNextTokenPredictionLightning.load_from_checkpoint(ckpt_path)\n",
"gen_model.eval()"
]
Expand All @@ -46,7 +49,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generating Jets"
"## Generating Showers"
]
},
{
Expand All @@ -55,13 +58,96 @@
"metadata": {},
"outputs": [],
"source": [
"# save_path = \"/beegfs/desy/user/birkjosc/testing/omnijet/generated_jets.parquet\"\n",
"generated_jets = gen_model.generate_n_jets_batched(\n",
" n_jets=2,\n",
"generated_showers = gen_model.generate_n_showers_batched(\n",
" n_showers=2,\n",
" batch_size=2,\n",
" # saveas=save_path, # use this option if you want to save the awkward array as a parquet file\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"generated_showers"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# --- Load the tokenizer model from checkpoint, and also get the feature_dict from the config ---\n",
"from gabbro.models.vqvae import VQVAELightning\n",
"\n",
"ckpt_path = \"/data/dust/user/rosehenn/gabbro_output/TokTrain/runs/2024-09-21_16-54-39_max-wng062_CerousLocknut/checkpoints/epoch_231_loss_0.17179.ckpt\"\n",
"\n",
"vqvae_model = VQVAELightning.load_from_checkpoint(ckpt_path)\n",
"vqvae_model.eval()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cfg = OmegaConf.load(Path(ckpt_path).parent.parent / \"config.yaml\")\n",
"pp_dict = OmegaConf.to_container(cfg.data.dataset_kwargs_common.feature_dict)\n",
"print(\"\\npp_dict:\")\n",
"for item in pp_dict:\n",
" print(item, pp_dict[item])\n",
"\n",
"# get the cuts from the pp_dict (since this leads to particles being removed during\n",
"# preprocessing/tokenization), thus we also have to remove them from the original showers\n",
"# when we compare the tokenized+reconstructed particles to the original ones\n",
"pp_dict_cuts = {\n",
" feat_name: {\n",
" criterion: pp_dict[feat_name].get(criterion)\n",
" for criterion in [\"larger_than\", \"smaller_than\"]\n",
" }\n",
" for feat_name in pp_dict\n",
"}\n",
"\n",
"print(\"\\npp_dict_cuts:\")\n",
"for item in pp_dict_cuts:\n",
" print(item, pp_dict_cuts[item])\n",
"\n",
"print(\"\\nModel:\")\n",
"print(vqvae_model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# reconstruct the generated tokens to physical features\n",
"\n",
"# note that if you want to reconstruct tokens from the generative model, you'll have\n",
"# to remove the start token from the tokenized array, and subtract 1 from the tokens\n",
"# (since we chose the convention to use 0 as the start token, so the tokens from the\n",
"# generative model are shifted by 1 compared to the ones from the VQ-VAE)\n",
"showers_reconstructed = vqvae_model.reconstruct_ak_tokens(\n",
" tokens_ak=generated_showers[:, 1:] - 1,\n",
" pp_dict=pp_dict,\n",
" batch_size=512,\n",
" pad_length=128,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"showers_reconstructed"
]
}
],
"metadata": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,23 @@
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"import os\n",
"import sys\n",
"from pathlib import Path\n",
"\n",
"import awkward as ak\n",
"import numpy as np\n",
"import vector\n",
"from omegaconf import OmegaConf\n",
"\n",
"sys.path.append(\"/beegfs/desy/user/rosehenn/gabbro\")"
"sys.path.append(\"/data/dust/user/rosehenn/gabbro\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tokenization with the VQ-VAE"
]
},
{
Expand All @@ -24,8 +34,136 @@
"outputs": [],
"source": [
"# --- Load the tokenizer model from checkpoint, and also get the feature_dict from the config ---\n",
"from gabbro.models.vqvae import VQVAELightning\n",
"\n",
"ckpt_path = \"/data/dust/user/rosehenn/gabbro_output/TokTrain/runs/2024-09-21_16-54-39_max-wng062_CerousLocknut/checkpoints/epoch_231_loss_0.17179.ckpt\"\n",
"\n",
"vqvae_model = VQVAELightning.load_from_checkpoint(ckpt_path)\n",
"vqvae_model.eval()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cfg = OmegaConf.load(Path(ckpt_path).parent.parent / \"config.yaml\")\n",
"pp_dict = OmegaConf.to_container(cfg.data.dataset_kwargs_common.feature_dict)\n",
"print(\"\\npp_dict:\")\n",
"for item in pp_dict:\n",
" print(item, pp_dict[item])\n",
"\n",
"# get the cuts from the pp_dict (since this leads to particles being removed during\n",
"# preprocessing/tokenization), thus we also have to remove them from the original showers\n",
"# when we compare the tokenized+reconstructed particles to the original ones\n",
"pp_dict_cuts = {\n",
" feat_name: {\n",
" criterion: pp_dict[feat_name].get(criterion)\n",
" for criterion in [\"larger_than\", \"smaller_than\"]\n",
" }\n",
" for feat_name in pp_dict\n",
"}\n",
"\n",
"print(\"\\npp_dict_cuts:\")\n",
"for item in pp_dict_cuts:\n",
" print(item, pp_dict_cuts[item])\n",
"\n",
"print(\"\\nModel:\")\n",
"print(vqvae_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load shower file"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from gabbro.data.loading import read_shower_file\n",
"\n",
"filename_in = \"/data/dust/user/rosehenn/gabbro/notebooks/array_real.parquet\"\n",
"showers = ak.from_parquet(filename_in)\n",
"showers = showers[:5000]\n",
"# part_features_ak = ak_select_and_preprocess(data_showers, pp_dict_cuts)[:, :128]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tokenize and reconstruct showers"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# tokenization and reconstruction\n",
"\n",
"part_features_ak_tokenized = vqvae_model.tokenize_ak_array(\n",
" ak_arr=showers,\n",
" pp_dict=pp_dict,\n",
" batch_size=4,\n",
" pad_length=1700,\n",
")\n",
"# note that if you want to reconstruct tokens from the generative model, you'll have\n",
"# to remove the start token from the tokenized array, and subtract 1 from the tokens\n",
"# (since we chose the convention to use 0 as the start token, so the tokens from the\n",
"# generative model are shifted by 1 compared to the ones from the VQ-VAE)\n",
"part_features_ak_reco = vqvae_model.reconstruct_ak_tokens(\n",
" tokens_ak=part_features_ak_tokenized,\n",
" pp_dict=pp_dict,\n",
" batch_size=4,\n",
" pad_length=1700,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# inspect the tokenized and reconstructed Showers\n",
"print(\"First 5 tokenized Showers:\")\n",
"for i in range(5):\n",
" print(part_features_ak_tokenized[i])\n",
"\n",
"print(\"\\nFirst 5 reconstructed Showers:\")\n",
"for i in range(5):\n",
" print(part_features_ak_reco[i])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Plot the reconstructed showers"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from gabbro.plotting.feature_plotting import plot_paper_plots\n",
"\n",
"from gabbro.models.vqvae import VQVAELightning"
"fig = plot_paper_plots(\n",
" feature_sets=[showers[: len(part_features_ak_reco)], part_features_ak_reco],\n",
" labels=[\"Geant4\", \"Tokenized\"], # \"OmniJet-$\\\\alpha_C$\" \"BIB-AE\", \"L2L Flows\"\n",
" colors=[\"lightgrey\", \"#1a80bb\", \"#ea801c\", \"#4CAF50\", \"#1a80bb\"],\n",
")\n",
"fig.show()"
]
}
],
Expand Down

0 comments on commit 8bf0c7b

Please sign in to comment.