Skip to content

Commit 64dc922

Browse files
authored
Enable static type checking with Pyrefly (#2136)
Enables static type checking of torchtitan with [pyrefly](https://github.com/facebook/pyrefly). Type checking the code helps catch bugs earlier in the development cycle. * Adds pyrefly to CI, as part of the linting workflow. * Addresses ~100 type errors that can be fixed via local code changes and updates to type annotations, and silences the rest with `# pyrefly: ignore` suppression comments. Note that 325efd9 contains all of the non-comment changes.
1 parent 7a398ea commit 64dc922

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

73 files changed

+516
-89
lines changed

.ci/docker/requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@ expecttest==0.1.6
22
pytest==7.3.2
33
pytest-cov
44
pre-commit
5+
pyrefly==0.45.1
56
tomli-w >= 1.1.0
67
transformers

.ci/docker/requirements-flux.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,2 @@
11
transformers>=4.51.1
2-
einops
32
sentencepiece
4-
pillow

.ci/docker/requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,5 @@ tyro
99
tokenizers >= 0.15.0
1010
safetensors
1111
psutil
12+
einops
13+
pillow

.github/workflows/lint.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ jobs:
2828
run: python -m pip install --upgrade pip
2929
- name: Install lint utilities
3030
run: |
31-
python -m pip install pre-commit
31+
python -m pip install -r requirements.txt -r requirements-dev.txt
32+
python -m pip install --force-reinstall --pre --index-url https://download.pytorch.org/whl/nightly/cu126 torch
3233
pre-commit install-hooks
3334
- name: Get changed files
3435
id: changed-files

.pre-commit-config.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,11 @@ repos:
6161
types: [text]
6262
additional_dependencies:
6363
- tomli
64+
65+
- repo: https://github.com/facebook/pyrefly-pre-commit
66+
rev: 0.45.1
67+
hooks:
68+
- id: pyrefly-check
69+
name: Pyrefly (type checking)
70+
pass_filenames: false
71+
language: system

CONTRIBUTING.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ possible. Contributions should follow the [Contributing Guidelines](#contributin
44

55
### Setup
66
```
7-
pip install -r requirements-dev.txt
7+
pip install -r requirements.txt -r requirements-dev.txt
88
```
99

1010
### Pull Requests

pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ dependencies = [
2525
"tyro",
2626
"tensorboard",
2727
"psutil",
28+
"einops",
29+
"pillow",
2830
]
2931
dynamic = ["version"]
3032

@@ -62,3 +64,7 @@ include = ["torchtitan*"]
6264
[tool.pytest.ini_options]
6365
addopts = ["--showlocals"] # show local variables in tracebacks
6466
testpaths = ["tests"]
67+
68+
[tool.pyrefly]
69+
project-excludes = ["torchtitan/experiments", "**/tests/**"]
70+
ignore-missing-imports = ["torchao.*", "torchft"] # optional dependencies

scripts/checkpoint_conversion/convert_from_hf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,16 @@
1616

1717
@torch.inference_mode()
1818
def convert_from_hf(input_dir, output_dir, model_name, model_flavor):
19-
if model_name == "flux":
20-
import torchtitan.experiments.flux # noqa: F401
2119
# initialize model to allocate memory for state dict
2220
train_spec = train_spec_module.get_train_spec(model_name)
2321
model_args = train_spec.model_args[model_flavor]
2422

2523
with torch.device("cpu"):
2624
model = train_spec.model_cls(model_args)
25+
# pyrefly: ignore [bad-argument-type]
2726
model = ModelWrapper(model)
2827

28+
# pyrefly: ignore [not-callable]
2929
sd_adapter = train_spec.state_dict_adapter(model_args, None)
3030
assert (
3131
sd_adapter is not None

scripts/checkpoint_conversion/convert_to_hf.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,10 @@ def convert_to_hf(
3030

3131
with torch.device("cpu"):
3232
model = train_spec.model_cls(model_args)
33+
# pyrefly: ignore [bad-argument-type]
3334
model = ModelWrapper(model)
3435

36+
# pyrefly: ignore [not-callable]
3537
sd_adapter = train_spec.state_dict_adapter(model_args, hf_assets_path)
3638
assert (
3739
sd_adapter is not None

scripts/checkpoint_conversion/numerical_tests_example.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def loss_fn(logits1, logits2):
2525
probs2 = F.softmax(logits2, dim=-1)
2626

2727
# Calculate KL Divergence
28-
kl_loss = F.kl_div(probs1, probs2, "mean")
28+
kl_loss = F.kl_div(probs1, probs2, reduction="mean")
2929
return kl_loss
3030

3131

@@ -75,10 +75,13 @@ def forward_tt(config_path, checkpoint_path, test_set):
7575

7676
# materialize model
7777
device = torch.device(device_type)
78+
# pyrefly: ignore [missing-attribute]
7879
model.to_empty(device=device)
7980
model.init_weights(buffer_device=device)
81+
# pyrefly: ignore [missing-attribute]
8082
model.eval()
8183

84+
# pyrefly: ignore [bad-argument-type]
8285
modelWrapper = ModelWrapper(model)
8386
state_dict = modelWrapper._get_state_dict()
8487

@@ -94,6 +97,7 @@ def forward_tt(config_path, checkpoint_path, test_set):
9497
input_ids = input_ids.unsqueeze(0)
9598

9699
# obtains the logits of only the last token in the predictions
100+
# pyrefly: ignore [not-callable]
97101
predictions = model(input_ids)[:, -1, :].unsqueeze(1)
98102
output_list.append(predictions)
99103

@@ -120,6 +124,7 @@ def forward_tt(config_path, checkpoint_path, test_set):
120124
config_manager = ConfigManager()
121125
config = config_manager.parse_args([f"--job.config_file={config_path}"])
122126
train_spec = get_train_spec(config.model.name)
127+
# pyrefly: ignore [not-callable]
123128
tokenizer = train_spec.build_tokenizer_fn(config)
124129

125130
# Build test set of randomly generated token ids
@@ -150,10 +155,11 @@ def forward_tt(config_path, checkpoint_path, test_set):
150155
avg_losses = {}
151156

152157
for test_name, (baseline_outputs, conversion_outputs) in test_configs.items():
153-
total_loss = 0
158+
total_loss: int | torch.Tensor = 0
154159
for baseline, outputs in zip(baseline_outputs, conversion_outputs):
155160
total_loss += loss_fn(baseline, outputs)
156161
avg_loss = total_loss / len(test_set)
162+
# pyrefly: ignore [missing-attribute]
157163
avg_losses[test_name] = avg_loss.item()
158164

159165
for test_name, avg_loss in avg_losses.items():

0 commit comments

Comments
 (0)