-
Notifications
You must be signed in to change notification settings - Fork 892
/
slim_model.py
82 lines (60 loc) · 2.64 KB
/
slim_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import argparse
import json
import time
import jax
import numpy as np
import optax
from mesh_transformer import util
from mesh_transformer.checkpoint import read_ckpt, write_ckpt
from mesh_transformer.transformer_shard import CausalTransformer
from smart_open import open
from mesh_transformer.util import clip_by_global_norm, to_bf16, to_f16
def parse_args():
    """Parse the command-line arguments for the checkpoint slimming script.

    Returns:
        argparse.Namespace with:
          - ``config`` (str): location of the JSON config file (required).
          - ``ckpt_step`` (int): step of the checkpoint to convert; the default
            of -1 means "use the most recent checkpoint".
          - ``f16`` (bool): cast parameters to float16 instead of bfloat16.
    """
    parser = argparse.ArgumentParser()
    # Required: the script always opens this file, so a missing value used to
    # surface only later as a failure inside `open(None)`.
    parser.add_argument("--config", type=str, required=True, help="Config file location")
    parser.add_argument("--ckpt-step", type=int, default=-1, help="Step number of the checkpoint to convert (if not specified, converts the most recent checkpoint)")
    parser.add_argument("--f16", default=False, action="store_true", help="Convert to float16 (instead of bfloat16)")
    return parser.parse_args()
if __name__ == "__main__":
    # Load a training checkpoint, drop its optimizer state, cast the params to
    # bf16 (or f16 with --f16) and write the result under "<model_dir>_slim"
    # (or "<model_dir>_slim_f16") in the same bucket.
    args = parse_args()

    # `open` here is smart_open's (it shadows the builtin); use a context
    # manager so the config handle is closed promptly.
    with open(args.config) as f:
        params = json.load(f)

    convert_fn = to_f16 if args.f16 else to_bf16

    cores_per_replica = params["cores_per_replica"]
    # Explicit check rather than `assert`, which is removed under `python -O`.
    # The lower bound also guards the integer division used for the mesh below.
    if not 1 <= cores_per_replica <= 8:
        raise ValueError(f"cores_per_replica must be between 1 and 8, got {cores_per_replica}")

    bucket = params["bucket"]
    model_dir = params["model_dir"]

    # Placeholder optimizer: presumably needed by CausalTransformer to build
    # network.state (which has an "opt_state" entry). That state is deleted
    # before writing, so these hyperparameters never affect the output.
    params["optimizer"] = optax.chain(
        optax.scale(1),
        clip_by_global_norm(1),
        optax.scale_by_adam(),
        optax.additive_weight_decay(0),
        optax.scale(-1),
        optax.scale_by_schedule(util.gpt3_schedule(0, 1, 0, 0)),
    )

    start = time.time()
    # First jax call in the script; presumably this is what triggers runtime
    # initialization, hence the timing message below.
    device_count = jax.device_count()
    print(f"jax devices: {device_count}")
    print(f"jax runtime initialized in {time.time() - start:.06}s")

    # Mesh is (dp, mp) = (device_count // cores_per_replica, cores_per_replica);
    # the reshape only works if the devices split evenly into replicas.
    if device_count % cores_per_replica != 0:
        raise ValueError(
            f"jax.device_count() ({device_count}) is not a multiple of "
            f"cores_per_replica ({cores_per_replica})"
        )
    mesh_shape = (device_count // cores_per_replica, cores_per_replica)
    devices = np.array(jax.devices()).reshape(mesh_shape)

    with open(f"gs://{bucket}/{model_dir}/meta.json", "r") as f:
        meta = json.load(f)

    if args.ckpt_step > -1:
        ckpt_step = args.ckpt_step
    else:
        # No step requested: take the last entry listed in meta.json.
        checkpoints = meta["checkpoints"]
        if not checkpoints:
            raise ValueError(f"no checkpoints listed in gs://{bucket}/{model_dir}/meta.json")
        ckpt_step = checkpoints[-1]

    print(f"using checkpoint {ckpt_step}")

    with jax.experimental.maps.mesh(devices, ('dp', 'mp')):
        network = CausalTransformer(params)

        start = time.time()
        # The last argument is the model-parallel shard count (== cores_per_replica).
        network.state = read_ckpt(network.state, f"gs://{bucket}/{model_dir}/step_{ckpt_step}/", devices.shape[1])
        print(f"network loaded in {time.time() - start:.06}s")

        start = time.time()
        # Drop the optimizer state so it is not written to the slim checkpoint.
        del network.state["opt_state"]
        network.state["params"] = convert_fn(network.state["params"])
        print(f"network converted in {time.time() - start:.06}s")

        suffix = "_slim_f16" if args.f16 else "_slim"
        out_dir = f"gs://{bucket}/{model_dir}{suffix}/step_{ckpt_step}/"
        # Write one shard per model-parallel core, matching the count read above.
        for shard in range(cores_per_replica):
            write_ckpt(network.state, out_dir, shard)
            print(f"written shard {shard}")