This repository contains training scripts for scene-level image editing models using GRPO (Group Relative Policy Optimization).
This codebase is built upon:
- Flow-GRPO, licensed under the MIT license;
- Orient-Anything, licensed under the CC-BY-4.0 license;
- lang-segment-anything, licensed under the Apache-2.0 license;
- Grounding-DINO, licensed under the Apache-2.0 license.
- `grpo.py`: Main configuration file containing all GRPO training setups for different models and tasks
- `base.py`: Base configuration with default hyperparameters
- `dpo.py`: Direct Preference Optimization configurations
- `sft.py`: Supervised fine-tuning configurations
The reward system supports multiple scoring methods:
- Text-Image Alignment: CLIP score, PickScore, ImageReward
- Image Quality: Aesthetic score, JPEG compressibility
- Task-Specific: OCR accuracy, GenEval, object manipulation
- Edit Quality: Position accuracy, rotation accuracy, resize accuracy, LPIPS similarity
- Multi-Modal: Qwen-VL based verification
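The individual scores above are typically combined into a single scalar reward via configurable weights. A minimal sketch of that aggregation, assuming a weighted sum; the metric names, weights, and the `combine_rewards` helper are illustrative, not the actual API:

```python
# Hypothetical sketch: aggregating multiple reward signals with configurable
# weights. Metric names and weight values here are illustrative only.

def combine_rewards(scores: dict, weights: dict) -> float:
    """Weighted sum of per-metric scores; metrics without a weight contribute 0."""
    return sum(weights.get(name, 0.0) * value for name, value in scores.items())

scores = {"clip_score": 0.8, "aesthetic": 0.6, "lpips_similarity": 0.9}
weights = {"clip_score": 0.5, "aesthetic": 0.2, "lpips_similarity": 0.3}
reward = combine_rewards(scores, weights)  # 0.5*0.8 + 0.2*0.6 + 0.3*0.9 = 0.79
```

In practice the per-metric weights would come from the reward-weight section of the training config.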
- Single-node training: Use `train_*.py` directly with accelerate
- Multi-node training: Use scripts in `scripts/multi_node/` for distributed training
- Demo scripts: Quick inference examples in `scripts/demo/`
- Test scripts: Evaluation and testing utilities in `scripts/test/`
- Python 3.8+
- PyTorch with CUDA support
- 16 GPUs (2 nodes × 8 GPUs per node)
- Required Python packages (install via `pip install -e .`)
Before running training, update the paths in your configuration:
- Replace `enter_path_here` placeholders in the codebase with your actual paths
- Update `MASTER_ADDR` in `scripts/multi_node/qwenimagedit/main2.sh` to match your master node IP
- Ensure all nodes can communicate via the specified master address and port
To run training on 16 GPUs across 2 nodes (8 GPUs per node):
On node 0 (the master): `sh scripts/multi_node/qwenimagedit/main2.sh 0`
On node 1: `sh scripts/multi_node/qwenimagedit/main2.sh 1`

The training script uses the following default settings:
- GPUs per node: 8
- Number of nodes: 2
- Total GPUs: 16
- Master port: 19001
- Config: `config/grpo.py:counting_qwenimage_edit_mini`
To modify these settings, edit `scripts/multi_node/qwenimagedit/main2.sh`.
Check `config/grpo.py` for available training configurations:
- `counting_qwenimage_edit_mini` - Mini configuration for testing
- `counting_qwenimage_edit_8gpu_2` - 8 GPU setup
- Various task-specific configs for rotation, resize, translation, manipulation
Each configuration specifies:
- Model architecture and checkpoint paths
- Batch sizes and gradient accumulation steps
- Sampling parameters (num_steps, guidance_scale)
- Reward function weights
- Training hyperparameters (learning rate, beta, etc.)
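As a quick sanity check when tuning the batch-size settings above: the effective global batch is the product of the per-GPU batch, the gradient accumulation steps, and the total GPU count. A small illustrative helper (the numbers below are examples, not the defaults of any shipped config):

```python
# Illustrative helper: effective global batch size for multi-node training.
# Per-GPU batch and accumulation values below are examples only.

def effective_batch_size(per_gpu_batch: int, grad_accum_steps: int, num_gpus: int) -> int:
    """Global batch = per-GPU batch * accumulation steps * total GPUs."""
    return per_gpu_batch * grad_accum_steps * num_gpus

# e.g. with a per-GPU batch of 2, 4 accumulation steps, and 16 GPUs:
effective_batch_size(2, 4, 16)  # -> 128
```

When reducing the per-GPU batch to avoid CUDA out-of-memory errors, increasing `grad_accum_steps` proportionally keeps the effective batch unchanged.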
Training logs and checkpoints will be saved according to the `save_dir` specified in your configuration (typically in the `logs/` directory).
- Create a new scorer in `flow_grpo/` (e.g., `my_scorer.py`)
- Implement the scorer class with a `run()` or `__call__()` method
- Add the reward function to `flow_grpo/rewards.py` in the `score_functions` dict
- Configure the reward weight in your config file
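The scorer convention described above can be sketched as follows. This is a minimal, self-contained example: the class name, constructor argument, and dummy scoring logic are all hypothetical, and the real scorer would run model inference instead:

```python
# Hypothetical custom scorer following the __call__ convention described
# above. The scoring logic is a placeholder for real model inference.

class MyScorer:
    def __init__(self, threshold: float = 0.5):
        # Example of a scorer-specific hyperparameter (illustrative).
        self.threshold = threshold

    def __call__(self, images, prompts):
        # Replace with real model inference; here we emit a dummy score
        # per (image, prompt) pair: 1.0 for non-empty prompts, else 0.0.
        return [1.0 if len(prompt) > 0 else 0.0
                for _, prompt in zip(images, prompts)]

# Hypothetical registration in flow_grpo/rewards.py:
# score_functions["my_scorer"] = MyScorer()
```

The registered name (`"my_scorer"` here) is what you would reference when setting the reward weight in your config.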
- Open `config/grpo.py`
- Define a new function (e.g., `def my_custom_config()`)
- Start from an existing config: `config = pickscore_sd3()` or `config = compressibility()`
- Modify parameters as needed
- Use it with: `--config config/grpo.py:my_custom_config`
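The steps above can be sketched as below. For self-containment this sketch uses plain dicts; the real config functions in `config/grpo.py` likely return config objects with many more fields, and the field names and values here are illustrative only:

```python
# Illustrative sketch of deriving a custom config from an existing one.
# Field names and values are hypothetical, not the real defaults.

def pickscore_sd3():
    # Stand-in for an existing config function in config/grpo.py.
    return {"learning_rate": 1e-4, "num_steps": 10, "guidance_scale": 4.5}

def my_custom_config():
    config = pickscore_sd3()          # start from an existing config
    config["learning_rate"] = 3e-4    # modify parameters as needed
    config["guidance_scale"] = 7.0
    return config
```

Training would then pick this up via `--config config/grpo.py:my_custom_config`.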
For different node counts, modify the shell scripts:
- `GPUS_PER_NODE`: GPUs available per node (typically 8)
- `NUM_MACHINES`: Total number of nodes
- `MASTER_ADDR`: IP address of the master node (rank 0)
- `MASTER_PORT`: Communication port (default 19001)
- Connection issues: Verify that `MASTER_ADDR` is correct and that the nodes can communicate
- CUDA out of memory: Reduce the batch size in the config file
- Path errors: Ensure all `enter_path_here` placeholders are replaced with valid paths
- Reward server errors: Check that the reward server IPs (`your-api-server-ip`, `your-reward-server-ip`) are correctly configured
- Import errors: Run `pip install -e .` to install the package in development mode
- NCCL timeout: Increase the timeout or check network connectivity between nodes
If you use this code in your research, please cite the relevant papers for the models and methods used:
@misc{tan2026talk2movereinforcementlearningtextinstructed,
title={Talk2Move: Reinforcement Learning for Text-Instructed Object-Level Geometric Transformation in Scenes},
author={Jing Tan and Zhaoyang Zhang and Yantao Shen and Jiarui Cai and Shuo Yang and Jiajun Wu and Wei Xia and Zhuowen Tu and Stefano Soatto},
year={2026},
eprint={2601.02356},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2601.02356},
}
This codebase is built by Jing Tan during her internship at AWS Agentic AI.
For any questions, feel free to contact her at tj023@ie.cuhk.edu.hk.