
Adamdad/kat


Kolmogorov–Arnold Transformer:
A PyTorch Implementation


ICLR 2025


Yes, I kan!

🎉 This is a PyTorch/GPU implementation of the paper Kolmogorov–Arnold Transformer (KAT), which replaces the MLP layers in transformers with KAN layers.

For more technical details, please refer to our ICLR'25 paper.

Kolmogorov–Arnold Transformer
πŸ“[Paper] </>[code] </>[Trition/CUDA kernel]
Xingyi Yang, Xinchao Wang
National University of Singapore
International Conference on Learning Representations (ICLR'25)

🔑 Key Insight:

A vanilla ViT with KAN layers struggles to scale effectively. We introduce the KAT model, which integrates Group-Rational KANs (GR-KANs) into transformers for large-scale training scenarios such as ImageNet, achieving significant performance improvements.


🎯 Our Solutions:

  1. Base Function: Replace B-splines with CUDA-implemented rational functions.
  2. Group KAN: Share weights among groups of edges for efficiency.
  3. Initialization: Maintain activation magnitudes across layers.
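As a rough illustration of solutions 1 and 2, the pure-Python sketch below evaluates a group-rational activation. It assumes a "safe" rational form, F(x) = (a_0 + a_1 x + ... + a_m x^m) / (1 + |b_1 x + ... + b_n x^n|), and assumes the channels are split into equal, contiguous groups that each share one coefficient set. The names `rational` and `group_rational` are hypothetical; the kernels in rational_kat_cu run vectorized on the GPU, and the layer's linear weights are omitted here.

```python
from typing import List, Sequence, Tuple

# (numerator coefficients a_0..a_m, denominator coefficients b_1..b_n)
Coeffs = Tuple[Sequence[float], Sequence[float]]


def rational(x: float, a: Sequence[float], b: Sequence[float]) -> float:
    """Assumed safe rational form: P(x) / (1 + |b_1 x + ... + b_n x^n|)."""
    numerator = sum(a_i * x**i for i, a_i in enumerate(a))
    # The absolute value keeps the denominator >= 1, so F(x) has no poles.
    denominator = 1.0 + abs(sum(b_j * x ** (j + 1) for j, b_j in enumerate(b)))
    return numerator / denominator


def group_rational(xs: Sequence[float], groups: List[Coeffs]) -> List[float]:
    """Apply one shared rational function per contiguous group of channels."""
    num_groups = len(groups)
    if num_groups == 0 or len(xs) % num_groups != 0:
        raise ValueError("channel count must be divisible by the number of groups")
    group_size = len(xs) // num_groups
    out = []
    for channel, x in enumerate(xs):
        a, b = groups[channel // group_size]  # every channel in a group shares (a, b)
        out.append(rational(x, a, b))
    return out


# 4 channels, 2 groups: group 0 is the identity, group 1 doubles its inputs.
print(group_rational([1.0, 2.0, 3.0, 4.0], [([0.0, 1.0], []), ([0.0, 2.0], [])]))
# → [1.0, 2.0, 6.0, 8.0]
```

Sharing coefficients per group means a layer with d channels stores only g coefficient sets instead of d, which is the efficiency argument behind Group KAN.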

✅ Updates

  • Release the KAT paper, CUDA implementation and IN-1k training code.
  • 🎉🎉🎉🎉 Triton implementation for 1D and 2D tasks. It is much easier to install than the CUDA version. Please see https://github.com/Adamdad/rational_kat_cu.
  • KAT Detection and segmentation code.
  • KAT on NLP tasks.

🛠️ Installation and Dataset

Our CUDA implementation is available at https://github.com/Adamdad/rational_kat_cu.git. To install:

# install PyTorch and other dependencies first
pip install timm==1.0.3
pip install wandb # optional: I personally use wandb to visualize results
git clone https://github.com/Adamdad/rational_kat_cu.git
cd rational_kat_cu
pip install -e .

📦 Data preparation: organize ImageNet in the following folder structure. You can extract ImageNet with this script.

│imagenet/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......
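Before launching a long run, it can help to check that the dataset matches the layout above. The sketch below is a hypothetical helper (`check_imagenet_layout` is not part of this repository): it assumes class folders are named by synset ID under both `train/` and `val/`, and that images use `.JPEG`, `.jpg`, or `.png` extensions.

```python
from pathlib import Path
from typing import List, Tuple

IMAGE_SUFFIXES = {".jpeg", ".jpg", ".png"}  # compared case-insensitively


def _class_dirs(split_dir: Path) -> List[str]:
    """Sorted names of the class (synset) folders in one split."""
    if not split_dir.is_dir():
        raise FileNotFoundError(f"missing directory: {split_dir}")
    return sorted(p.name for p in split_dir.iterdir() if p.is_dir())


def _count_images(split_dir: Path) -> int:
    """Number of image files anywhere below split_dir."""
    return sum(
        1 for p in split_dir.rglob("*") if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
    )


def check_imagenet_layout(root: str) -> Tuple[int, int, int]:
    """Return (num_classes, num_train_images, num_val_images).

    Raises if train/ or val/ is missing or if their class folders differ.
    """
    root_path = Path(root)
    train_classes = _class_dirs(root_path / "train")
    val_classes = _class_dirs(root_path / "val")
    if train_classes != val_classes:
        mismatched = sorted(set(train_classes) ^ set(val_classes))
        raise ValueError(f"train/ and val/ class folders differ, e.g. {mismatched[:5]}")
    return (
        len(train_classes),
        _count_images(root_path / "train"),
        _count_images(root_path / "val"),
    )
```

For example, `check_imagenet_layout("/local_home/dataset/imagenet/")` should report 1000 classes and 50,000 validation images for a complete ImageNet-1k copy.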

Usage

Refer to example.py for a detailed use case demonstrating how to use KAT with timm to classify an image.
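example.py remains the reference for the full pipeline. The sketch below covers only the last step, turning one image's class logits into top-k predictions. It is plain Python and assumes the logits have already been converted to a list of floats (for instance with `.tolist()` on the PyTorch output); `softmax` and `top_k` are hypothetical helpers, not functions from this repository.

```python
import math
from typing import List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits: Sequence[float], k: int = 5) -> List[Tuple[int, float]]:
    """Return the k (class_index, probability) pairs with the highest probability."""
    probs = softmax(logits)
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(i, probs[i]) for i in ranked[:k]]


# Toy 4-class example; an ImageNet model would produce 1000 logits.
print([idx for idx, _ in top_k([0.1, 2.0, -1.0, 0.5], k=2)])  # → [1, 3]
```

Softmax is monotonic, so ranking by probability gives the same order as ranking by raw logits; the probabilities are only needed for reporting confidence.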

📊 Model Checkpoints

Download pre-trained models or access training checkpoints:

| 🏷️ Model | ⚙️ Setup | 📦 Param | 📈 Top-1 (%) | 🔗 Link |
|---|---|---|---|---|
| KAT-T | From Scratch | 5.7M | 74.6 | link / huggingface |
| KAT-T | From ViT | 5.7M | 75.7 | link / huggingface |
| KAT-S | From Scratch | 22.1M | 81.2 | link / huggingface |
| KAT-S | From ViT | 22.1M | 82.0 | link / huggingface |
| KAT-B | From Scratch | 86.6M | 82.3 | link / huggingface |
| KAT-B | From ViT | 86.6M | 82.8 | link / huggingface |

🎓 Model Training

All training scripts are under scripts/. For example, to train KAT-T on 8 GPUs with a per-GPU batch size of 128, run:

bash scripts/train_kat_tiny_8x128.sh

To change the hyperparameters, edit the corresponding script, for example:

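#!/bin/bash
DATA_PATH=/local_home/dataset/imagenet/

# Rationals in kat_tiny_swish_patch16_224 are initialized to be swish functions.
# Keep comments on their own lines: text after a trailing "\" breaks the line continuation.
bash ./dist_train.sh 8 "$DATA_PATH" \
    --model kat_tiny_swish_patch16_224 \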
    -b 128 \
    --opt adamw \
    --lr 1e-3 \
    --weight-decay 0.05 \
    --epochs 300 \
    --mixup 0.8 \
    --cutmix 1.0 \
    --sched cosine \
    --smoothing 0.1 \
    --drop-path 0.1 \
    --aa rand-m9-mstd0.5 \
    --remode pixel --reprob 0.25 \
    --amp \
    --crop-pct 0.875 \
    --mean 0.485 0.456 0.406 \
    --std 0.229 0.224 0.225 \
    --model-ema \
    --model-ema-decay 0.9999 \
    --output output/kat_tiny_swish_patch16_224 \
    --log-wandb
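Two of the flags above shape how the weights evolve during training: `--sched cosine` (with `--lr 1e-3` as the peak) and `--model-ema` with `--model-ema-decay 0.9999`. The sketch below shows the textbook forms of a cosine learning-rate schedule and an exponential moving average (EMA) of weights. It is a simplification under stated assumptions: it ignores any warmup or minimum LR that timm's scheduler may apply, and `cosine_lr` and `ema_update` are hypothetical names, not timm's API.

```python
import math
from typing import List


def cosine_lr(epoch: int, total_epochs: int, base_lr: float, min_lr: float = 0.0) -> float:
    """Cosine annealing from base_lr at epoch 0 down to min_lr at total_epochs (no warmup)."""
    progress = epoch / total_epochs
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


def ema_update(ema: List[float], params: List[float], decay: float) -> List[float]:
    """One EMA step, element-wise: ema <- decay * ema + (1 - decay) * params."""
    return [decay * e + (1.0 - decay) * p for e, p in zip(ema, params)]


# Values from the script above: --lr 1e-3, --epochs 300.
print(cosine_lr(0, 300, 1e-3))    # peak LR at the start of training
print(cosine_lr(150, 300, 1e-3))  # half of the peak at mid-training
print(cosine_lr(300, 300, 1e-3))  # annealed to min_lr at the end
```

If the EMA is updated once per optimizer step, a decay of 0.9999 averages over roughly 1 / (1 - 0.9999) = 10,000 steps, which is why EMA weights change slowly and smooth out late-training noise.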

🧪 Evaluation

To evaluate our kat_tiny_patch16_224 model, run:

DATA_PATH=/local_home/dataset/imagenet/
CHECKPOINT_PATH=kat_tiny_patch16_224_1f3ad3b2e69821f3d412f2924cf159a0e266f142d739cb68f68f796f5a0fe289.pth
python validate.py $DATA_PATH --model kat_tiny_patch16_224 \
    --checkpoint $CHECKPOINT_PATH -b 512

###################
Validating in float32. AMP not enabled.
Loaded state_dict from checkpoint 'kat_tiny_patch16_224_1f3ad3b2e69821f3d412f2924cf159a0e266f142d739cb68f68f796f5a0fe289.pth'
Model kat_tiny_patch16_224 created, param count: 5718328
Data processing configuration for current model + dataset:
        input_size: (3, 224, 224)
        interpolation: bicubic
        mean: (0.485, 0.456, 0.406)
        std: (0.229, 0.224, 0.225)
        crop_pct: 0.875
        crop_mode: center
Test: [   0/98]  Time: 3.453s (3.453s,  148.28/s)  Loss:  0.6989 (0.6989)  Acc@1:  84.375 ( 84.375)  Acc@5:  96.875 ( 96.875)
.......
Test: [  90/98]  Time: 0.212s (0.592s,  864.23/s)  Loss:  1.1640 (1.1143)  Acc@1:  71.875 ( 74.270)  Acc@5:  93.750 ( 92.220)
 * Acc@1 74.558 (25.442) Acc@5 92.390 (7.610)
--result
{
    "model": "kat_tiny_patch16_224",
    "top1": 74.558,
    "top1_err": 25.442,
    "top5": 92.39,
    "top5_err": 7.61,
    "param_count": 5.72,
    "img_size": 224,
    "crop_pct": 0.875,
    "interpolation": "bicubic"
}
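The data processing configuration printed above (input 224, `crop_pct` 0.875, center crop, ImageNet mean/std) corresponds to a resize-then-center-crop evaluation pipeline. The arithmetic is sketched below in plain Python. It assumes the shorter side is resized to floor(224 / 0.875) = 256 before cropping; the helper names are hypothetical, and the bicubic resampling itself is left to an image library.

```python
import math
from typing import List, Sequence, Tuple

MEAN = (0.485, 0.456, 0.406)  # per-channel RGB mean, from the config above
STD = (0.229, 0.224, 0.225)   # per-channel RGB std, from the config above


def eval_resize_size(img_size: int = 224, crop_pct: float = 0.875) -> int:
    """Shorter-side length to resize to before center cropping (assumed floor convention)."""
    return math.floor(img_size / crop_pct)


def resize_shorter_side(width: int, height: int, target: int) -> Tuple[int, int]:
    """Scale (width, height) so the shorter side equals target, keeping the aspect ratio."""
    if width <= height:
        return target, int(target * height / width)
    return int(target * width / height), target


def center_crop_box(width: int, height: int, crop: int) -> Tuple[int, int, int, int]:
    """(left, top, right, bottom) of a crop x crop window centered in the image.

    Uses integer division; libraries may round odd margins differently by one pixel.
    """
    left = (width - crop) // 2
    top = (height - crop) // 2
    return left, top, left + crop, top + crop


def normalize(rgb: Sequence[float]) -> List[float]:
    """Normalize one RGB pixel with values in [0, 1] using the ImageNet mean/std."""
    return [(v - m) / s for v, m, s in zip(rgb, MEAN, STD)]


# A 500x375 photo: resize to 341x256, then crop the central 224x224 window.
size = eval_resize_size()                      # 256
w, h = resize_shorter_side(500, 375, size)     # (341, 256)
print(center_crop_box(w, h, 224))              # → (58, 16, 282, 240)
```

The crop keeps 224 / 256 = 0.875 of the shorter side, which is exactly what `crop_pct` encodes.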

🙏 Acknowledgments

We extend our gratitude to the authors of rational_activations for their contributions to CUDA rational function implementations that inspired parts of this work. We thank @yuweihao, @florinshen, @Huage001 and @yu-rp for valuable discussions.

📚 BibTeX

If you use this repository, please cite:

@inproceedings{
  yang2025kolmogorovarnold,
  title={Kolmogorov-Arnold Transformer},
  author={Xingyi Yang and Xinchao Wang},
  booktitle={The Thirteenth International Conference on Learning Representations},
  year={2025},
  url={https://openreview.net/forum?id=BCeock53nt}
}