
Commit 86f27fe

[BUG] Fix bugs in dist env (#1251)
* add necessary synchronization for data_iter destroying and creating bug when in dist
* refine examples code
* print CPU mem info by default, add psutil to requirements.txt and refine install_mesh.sh
* fix train_LBFGS_epoch_func too
* Update install_mesh.sh
  Co-authored-by: Copilot <[email protected]>
* refine code as copilot
* move uss&pss_msg to benchmark_flag branch
* Apply suggestion from @Copilot
  Co-authored-by: Copilot <[email protected]>
* refine
* update API doc

---------

Co-authored-by: Copilot <[email protected]>
1 parent e583698 commit 86f27fe

File tree: 8 files changed (+73, -28 lines)

8 files changed

+73
-28
lines changed

docs/zh/api/utils/misc.md (1 addition, 0 deletions)

@@ -9,6 +9,7 @@
 - Prettydefaultdict
 - RankZeroOnly
 - RankZeroFirst
+- Synchronized
 - Timer
 - all_gather
 - concat_dict_list

examples/amgnet/amgnet_airfoil.py (3 additions, 9 deletions)

@@ -86,6 +86,7 @@ def train(cfg: DictConfig):
         loss=ppsci.loss.FunctionalLoss(train_mse_func),
         name="Sup",
     )
+    cfg.TRAIN.iters_per_epoch = len(sup_constraint.data_loader)
     # wrap constraints together
     constraint = {sup_constraint.name: sup_constraint}

@@ -121,16 +122,9 @@ def train(cfg: DictConfig):
     solver = ppsci.solver.Solver(
         model,
         constraint,
-        cfg.output_dir,
-        optimizer,
-        None,
-        cfg.TRAIN.epochs,
-        cfg.TRAIN.iters_per_epoch,
-        save_freq=cfg.TRAIN.save_freq,
-        eval_during_train=cfg.TRAIN.eval_during_train,
-        eval_freq=cfg.TRAIN.eval_freq,
+        optimizer=optimizer,
         validator=validator,
-        eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
+        cfg=cfg,
     )
     # train model
     solver.train()
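The rewritten Solver call above drops the long positional argument list and lets output_dir, epochs, iters_per_epoch, and the eval settings be picked up from the Hydra config via `cfg=cfg`. A minimal sketch of that pattern, using a hypothetical `build_solver` helper rather than the full example script:

```python
# Sketch only: `build_solver` is a hypothetical helper, not part of the example.
import ppsci
from omegaconf import DictConfig


def build_solver(cfg: DictConfig, model, sup_constraint, optimizer, validator):
    # Take iters_per_epoch from the constraint's dataloader length, which in
    # distributed runs typically reflects the per-rank shard rather than a
    # hard-coded config value.
    cfg.TRAIN.iters_per_epoch = len(sup_constraint.data_loader)
    constraint = {sup_constraint.name: sup_constraint}
    return ppsci.solver.Solver(
        model,
        constraint,
        optimizer=optimizer,
        validator=validator,
        cfg=cfg,  # output_dir, epochs, save/eval settings come from the config
    )
```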

examples/biharmonic2d/biharmonic2d.py (3 additions, 4 deletions)

@@ -413,10 +413,9 @@ def inference(cfg: DictConfig):
     x_flatten = x_grad.reshape(-1, 1)
     y_flatten = y_grad.reshape(-1, 1)

-    with ppsci.misc.Timer("infer"):
-        output_dict = predictor.predict(
-            {"x": x_flatten, "y": y_flatten}, cfg.INFER.batch_size
-        )
+    output_dict = predictor.predict(
+        {"x": x_flatten, "y": y_flatten}, cfg.INFER.batch_size
+    )

     # mapping data to cfg.INFER.output_keys
     output_dict = {

install_mesh.sh (7 additions, 13 deletions)

@@ -8,28 +8,22 @@ echo "==== Python Environment Verified ===="


 echo "==== Step 2: Checking CMake Environment ===="
-# Verify if cmake exists in current environment. If not, download a temporary version.
-if command -v cmake >/dev/null 2>&1; then
-    echo "CMake found: $(cmake --version | head -n 1)"
-else
-    echo "CMake not found. Installing temporary CMake 3.23.0 ..."
-    wget -c https://paddle-org.bj.bcebos.com/paddlescience/cmake-3.23.0-linux-x86_64.tar.gz
-    tar -zxf cmake-3.23.0-linux-x86_64.tar.gz --checkpoint=.100 --totals
-    rm -f cmake-3.23.0-linux-x86_64.tar.gz
-    export PATH=$PWD/cmake-3.23.0-linux-x86_64/bin:$PATH
-    echo "Temporary CMake installed: $(cmake --version | head -n 1)"
-fi
+# Always install a temporary version of CMake 3.23.0, regardless of existing installations.
+echo "Installing temporary CMake 3.23.0 ..."
+wget -c https://paddle-org.bj.bcebos.com/paddlescience/cmake-3.23.0-linux-x86_64.tar.gz
+tar -zxf cmake-3.23.0-linux-x86_64.tar.gz --checkpoint=.100 --totals
+rm -f cmake-3.23.0-linux-x86_64.tar.gz
+export PATH=$PWD/cmake-3.23.0-linux-x86_64/bin:$PATH
+echo "Temporary CMake installed: $(cmake --version | head -n 1)"
 echo "CMake environment ready."

-
 echo "==== Step 3: Downloading PyMesh Package ===="
 # Download PyMesh package if not already present.
 wget -c https://paddle-org.bj.bcebos.com/paddlescience/PyMesh.tar.gz
 echo "Download completed. Extracting package..."
 tar -zxf PyMesh.tar.gz --checkpoint=.1000 --totals
 echo "PyMesh package extracted."

-
 echo "==== Step 4: Entering PyMesh Directory ===="
 cd PyMesh
 export PYMESH_PATH=$(pwd)

ppsci/solver/printer.py (9 additions, 0 deletions)

@@ -17,6 +17,7 @@
 from typing import Dict
 from typing import Optional

+import psutil
 from paddle import device

 from ppsci.utils import logger
@@ -80,13 +81,21 @@ def log_train_info(
         + ", ".join(filter(None, [metric_msg, time_msg, ips_msg, eta_msg]))
     )
     if solver.benchmark_flag:
+        # GPU memory
         max_mem_reserved_msg = (
             f"max_mem_reserved: {device.max_memory_reserved() // (1 << 20)} MB"
         )
         max_mem_allocated_msg = (
             f"max_mem_allocated: {device.max_memory_allocated() // (1 << 20)} MB"
         )
        log_str += f", {max_mem_reserved_msg}, {max_mem_allocated_msg}"
+
+        # CPU memory
+        _process = psutil.Process()
+        _mem_full = _process.memory_full_info()
+        uss_msg = f"USS: {_mem_full.uss // (1 << 20)} MB"  # Unique Set Size (USS) memory in MB
+        pss_msg = f"PSS: {_mem_full.pss // (1 << 20)} MB"  # Proportional Set Size (PSS) memory in MB
+        log_str += f", {uss_msg}, {pss_msg}"
     logger.info(log_str)

     # reset time information after printing
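For reference, the CPU-memory figures added above come from psutil's `memory_full_info()`: USS is the memory unique to the process, and PSS additionally counts shared pages proportionally, which is more informative than plain RSS when dataloader workers share memory. A standalone sketch of those calls (USS is widely available; PSS is reported on Linux only):

```python
import psutil

process = psutil.Process()            # the current process
mem = process.memory_full_info()      # unlike memory_info(), includes uss/pss
print(f"USS: {mem.uss // (1 << 20)} MB")  # pages unique to this process
print(f"PSS: {mem.pss // (1 << 20)} MB")  # shared pages counted proportionally (Linux)
```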

ppsci/solver/train.py (11 additions, 2 deletions)

@@ -14,6 +14,7 @@

 from __future__ import annotations

+import gc
 import sys
 import time
 from typing import TYPE_CHECKING
@@ -88,7 +89,11 @@ def train_epoch_func(solver: "solver.Solver", epoch_id: int, log_freq: int):
             try:
                 input_dict, label_dict, weight_dict = next(_constraint.data_iter)
             except StopIteration:
-                _constraint.data_iter = iter(_constraint.data_loader)
+                with misc.Synchronized(solver.world_size > 1):
+                    if hasattr(_constraint, "data_iter"):
+                        del _constraint.data_iter
+                        gc.collect()
+                    _constraint.data_iter = iter(_constraint.data_loader)
                 input_dict, label_dict, weight_dict = next(_constraint.data_iter)

             if solver.nvtx_flag:  # only for nsight analysis
@@ -245,7 +250,11 @@ def train_LBFGS_epoch_func(solver: "solver.Solver", epoch_id: int, log_freq: int
             try:
                 input_dict, label_dict, weight_dict = next(_constraint.data_iter)
             except StopIteration:
-                _constraint.data_iter = iter(_constraint.data_loader)
+                with misc.Synchronized(solver.world_size > 1):
+                    if hasattr(_constraint, "data_iter"):
+                        del _constraint.data_iter
+                        gc.collect()
+                    _constraint.data_iter = iter(_constraint.data_loader)
                 input_dict, label_dict, weight_dict = next(_constraint.data_iter)
             reader_cost += time.perf_counter() - reader_tic
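The change above is the core of the fix: instead of silently building a new iterator while the old one (and its worker processes) may still be alive on some ranks, the exhausted iterator is deleted and garbage-collected first, and `misc.Synchronized` places a barrier before and after the block so every rank destroys and recreates its iterator in lockstep. A minimal sketch of the same pattern in isolation (the `state` object is hypothetical):

```python
import gc

from ppsci.utils import misc


def refresh_data_iter(state, world_size: int):
    """Safely recreate an exhausted dataloader iterator under data parallelism."""
    with misc.Synchronized(world_size > 1):  # barrier on entry and exit when distributed
        if hasattr(state, "data_iter"):
            # Release the exhausted iterator (and its worker processes /
            # shared-memory buffers) before spawning a fresh one.
            del state.data_iter
            gc.collect()
        state.data_iter = iter(state.data_loader)
    return next(state.data_iter)
```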

ppsci/utils/misc.py (38 additions, 0 deletions)

@@ -43,6 +43,7 @@
     "Prettydefaultdict",
     "RankZeroOnly",
     "RankZeroFirst",
+    "Synchronized",
     "Timer",
     "all_gather",
     "concat_dict_list",
@@ -226,6 +227,43 @@ def __exit__(self, type, value, traceback):
             dist.barrier()  # Allow others to proceed


+class Synchronized(ContextDecorator):
+    """
+    A context manager/decorator that ensures code blocks are synchronized across all processes.
+
+    It calls barrier before and after the code block execution to ensure synchronization
+    among all processes.
+
+    Args:
+        enabled (bool): Whether to enable synchronization. Defaults to True.
+
+    Examples:
+        >>> import paddle
+        >>> from paddle import distributed as dist
+        >>> dist.init_parallel_env()
+        >>> with Synchronized(dist.get_world_size() > 1):
+        ...     x = paddle.randn(2, 2) * 2
+    """
+
+    def __init__(self, enabled: bool = True):
+        if enabled and not dist.is_initialized():
+            logger.warning(
+                "Distributed environment is not initialized, but `Synchronized` is enabled. "
+                "This may cause unexpected behavior. Please ensure distributed environment is properly initialized "
+                "when using `Synchronized` with `enabled=True`.",
+            )
+        self.enabled = enabled
+
+    def __enter__(self):
+        if self.enabled and dist.is_initialized():
+            dist.barrier()
+        return self
+
+    def __exit__(self, type, value, traceback):
+        if self.enabled and dist.is_initialized():
+            dist.barrier()
+
+
 class Timer(ContextDecorator):
     """Count time cost for code block within context.
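Because `Synchronized` subclasses `ContextDecorator`, it can wrap a code block or decorate a function, and it only calls `dist.barrier()` when it is enabled and the distributed environment is initialized, so single-process runs pass through unchanged. A usage sketch (the `rebuild_cache` function is hypothetical; assumes `dist.init_parallel_env()` has been called when running with more than one rank):

```python
from paddle import distributed as dist

from ppsci.utils import misc


@misc.Synchronized(enabled=dist.get_world_size() > 1)
def rebuild_cache():
    # Every call is fenced by a barrier on entry and exit across all ranks.
    ...


# Equivalent context-manager form, matching the call sites in ppsci/solver/train.py:
with misc.Synchronized(dist.get_world_size() > 1):
    pass  # rank-sensitive setup or teardown goes here
```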

requirements.txt (1 addition, 0 deletions)

@@ -6,6 +6,7 @@ imageio
 matplotlib
 meshio==5.3.4
 numpy>=1.20.0,<2.0.0
+psutil
 pydantic>=2.5.0
 pyevtk
 pyyaml
