Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
70 commits
Select commit Hold shift + click to select a range
da4cdba
feature(pu): add unizero/muzero multitask pipeline and net plasticity…
Apr 25, 2025
a6eed25
fix(pu): fix some adaptation bug
Apr 25, 2025
67a0e9a
feature(pu): add unizero multitask balance pipeline for atari and dmc
Apr 29, 2025
f083096
fix(pu): fix some adaptation bug
Apr 29, 2025
37eb118
feature(pu): add vit encoder for unizero
Apr 29, 2025
f32d63e
polish(pu): polish moe layer in transformer
May 1, 2025
c0aa747
feature(pu): add eval norm mean/medium for atari
May 5, 2025
8b3cff6
fix(pu): fix atari norm mean/median, fix collect in balance pipeline
May 7, 2025
f2c158b
polish(pu): polish config
May 7, 2025
20b42f7
fix(pu): fix dmc multitask to be compatiable with timestep (which is …
May 7, 2025
39ee55e
polish(pu): polish config
May 13, 2025
e85c449
fix(pu): fix task_id bug in balance pipeline, and polish benchmark_na…
May 14, 2025
c16d564
fix(pu): fix benchmark_name option
May 14, 2025
474b81c
polish(pu): fix norm score computation, adapt config to aliyun
May 21, 2025
50e367e
polish(pu): polish unizero_mt balance pipeline use CurriculumControll…
May 23, 2025
9171c3e
tmp
May 30, 2025
bc5003a
Merge branch 'dev-multitask-balance-clean' of https://github.com/open…
May 30, 2025
158e4a0
tmp
Jun 1, 2025
d66b986
tmp
Jun 4, 2025
0d5ede0
test(pu): add vit moe test
Jun 5, 2025
ca6ddb6
polish(pu): add adapter_scales to tb
Jun 11, 2025
7dd6c04
feature(pu): add atari uz balance config
Jun 12, 2025
c8e7cb8
polish(pu): add stable_adaptor_scale
Jun 19, 2025
0313335
tmp
Jun 23, 2025
ef170fd
sync code
Jun 25, 2025
bbec353
polish(pu): use freeze_non_lora_parameters in transformer, not use Le…
zjowowen Jul 30, 2025
20648d5
feature(pu): add vit-encoder lora in balance pipeline
zjowowen Jul 30, 2025
db6032a
polish(pu): fix reanalyze index bug, fix global_solved bug, add apply…
Aug 5, 2025
f63b544
polish(pu): add collect/eval_num_simulations option
Aug 5, 2025
bbbe505
polish(pu): polish comments and style in entry of scalezero
puyuan1996 Sep 28, 2025
bf9f965
polish(pu): polish comments and style of ctree/tree_search/buffer/com…
puyuan1996 Sep 28, 2025
fb04c7a
polish(pu): polish comments and style of files in lzero.model
puyuan1996 Sep 28, 2025
06148e7
polish(pu): polish comments and style of files in lzero.model.unizero…
puyuan1996 Sep 28, 2025
471ae6a
polish(pu): polish comments and style of unizero_world_models
puyuan1996 Sep 28, 2025
07933a5
polish(pu): polish comments and style of files in policy/
puyuan1996 Sep 28, 2025
df3b644
polish(pu): polish comments and style of files in worker
puyuan1996 Sep 28, 2025
4f89dcc
polish(pu): polish comments and style of files in configs
puyuan1996 Sep 28, 2025
e7a8796
Merge remote-tracking branch 'origin/main' into dev-multitask-balance…
puyuan1996 Sep 28, 2025
ab746d1
fix(pu): fix some merge typo
tAnGjIa520 Sep 28, 2025
0476aca
fix(pu): fix ln norm_type, fix kv_cache rewrite bug, add value_priori…
tAnGjIa520 Sep 28, 2025
2c0a965
fix(pu): fix unizero_mt
tAnGjIa520 Sep 28, 2025
84e6094
polish(pu): add LN in head, polish init_weight, polish adamw
tAnGjIa520 Sep 29, 2025
05da638
fix(pu): fix configure_optimizer_unizero in unizero_mt
tAnGjIa520 Oct 2, 2025
06ad080
feature(pu): add encoder-clip, label smooth, analyze_latent_represent…
tAnGjIa520 Oct 9, 2025
9f69f5a
feature(pu): add encoder-clip, label smooth option in unizero_multit…
tAnGjIa520 Oct 9, 2025
af99278
fix(pu): fix tb log when gpu_num<task_num, fix total_loss += bug, polish
tAnGjIa520 Oct 9, 2025
bf91ca2
polish(pu):polish config
tAnGjIa520 Oct 9, 2025
b18f892
fix(pu): fix encoder-clip bug and num_channel/res bug
tAnGjIa520 Oct 11, 2025
bf3cd12
polish(pu): polish scale_factor in DPS
tAnGjIa520 Oct 12, 2025
b1efa60
tmp
tAnGjIa520 Oct 18, 2025
c2f9817
feature(pu): add some analysis metrics in tensorboard for unizero and…
tAnGjIa520 Oct 23, 2025
b081379
polish(pu): abstract a KVCacheManager for world model
tAnGjIa520 Oct 23, 2025
2eff68d
tmp
Oct 23, 2025
27075c1
polish(pu): polish unizero obs_loss to cos_sim loss
Oct 23, 2025
b4c3ba8
tmp
Oct 24, 2025
3788eb7
polish(pu): polish minotor-log and adapt to ale/xxx-v5 style game
Oct 25, 2025
6d7761a
feature(pu): add decode_loss for unizero atari
Oct 25, 2025
a7ed590
test(pu): test unizero-mt
Oct 25, 2025
be07791
fix(pu): fix Deep Copy Before Storag bug when Use KVCacheManager
Oct 28, 2025
74ff3d6
sync code
Oct 31, 2025
aefa082
feature(pu): add iter_policy_evaluation demo in grid-world
puyuan1996 Nov 4, 2025
a8be15e
Merge branch 'dev-multitask-balance-clean-kvcachemanager' of https://…
puyuan1996 Nov 4, 2025
08f3a29
polish(pu): polish atari uz config
Nov 5, 2025
16ca8d4
polish(pu): polish policy logits stability
Nov 13, 2025
5cff9eb
sync code
Nov 17, 2025
3c820ef
polish(pu): polish policy logits stability
Nov 17, 2025
b9b8d26
Merge branch 'dev-multitask-balance-clean-kvcachemanager' of https://…
Nov 17, 2025
bd67cdf
fix(pu): fix exp_name and task_id bug in dmc pipeline, fix some configs
Nov 20, 2025
39a9c8c
feature(pu): add head-clip manager
Dec 2, 2025
32d7f36
fix(pu): fix head-clip log
Dec 2, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -1450,4 +1450,4 @@ events.*
!/assets/pooltool/**
lzero/mcts/ctree/ctree_alphazero/pybind11

zoo/jericho/envs/z-machine-games-master
zoo/jericho/envs/z-machine-games-master
161 changes: 161 additions & 0 deletions iter_policy_evaluation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
import numpy as np
import matplotlib.pyplot as plt

def setup_matplotlib_for_chinese():
"""
为 matplotlib 设置中文字体,以解决中文显示为方块的问题。
此函数会尝试为 macOS, Windows, 和 Linux 设置合适的字体。
"""
try:
# 优先选择 macOS 的苹方字体
plt.rcParams['font.sans-serif'] = ['PingFang SC', 'Arial Unicode MS', 'SimHei']
except Exception:
# 如果失败,可能是其他系统,继续尝试
try:
# Windows 的黑体或微软雅黑
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei']
except Exception:
# Linux 的文泉驿正黑
try:
plt.rcParams['font.sans-serif'] = ['WenQuanYi Zen Hei']
except Exception:
print("警告:未能找到可用的中文字体。图表中的中文可能无法正常显示。")

# 解决负号'-'显示为方块的问题
plt.rcParams['axes.unicode_minus'] = False
print("Matplotlib 中文字体设置完成。")


def calculate_policy_evaluation():
"""
使用迭代策略评估解决小格子世界问题。

返回:
V (np.ndarray): 最终收敛的价值函数网格。
iterations_history (list): 每次迭代的序号列表。
deltas_history (list): 每次迭代的 delta 值列表。
"""
# --- 1. 环境设置 ---
grid_size = 4
V = np.zeros((grid_size, grid_size))
gamma = 1.0
reward = -1.0
theta = 1e-6 # 收敛阈值

# --- 2. 用于可视化的数据收集 ---
iterations_history = []
deltas_history = []

iteration = 0
while True:
iteration += 1
delta = 0
V_old = V.copy()

for i in range(grid_size):
for j in range(grid_size):
if (i == 0 and j == 0) or (i == grid_size - 1 and j == grid_size - 1):
continue

# --- 3. 应用贝尔曼期望方程 ---
current_v = 0

# 动作: 上
next_i_up = i - 1 if i > 0 else i
current_v += 0.25 * (reward + gamma * V_old[next_i_up, j])

# 动作: 下
next_i_down = i + 1 if i < grid_size - 1 else i
current_v += 0.25 * (reward + gamma * V_old[next_i_down, j])

# 动作: 左
next_j_left = j - 1 if j > 0 else j
current_v += 0.25 * (reward + gamma * V_old[i, next_j_left])

# 动作: 右
next_j_right = j + 1 if j < grid_size - 1 else j
current_v += 0.25 * (reward + gamma * V_old[i, next_j_right])

V[i, j] = current_v
delta = max(delta, abs(V[i, j] - V_old[i, j]))

# --- 4. 记录历史并检查收敛 ---
iterations_history.append(iteration)
deltas_history.append(delta)

if iteration % 10 == 0 or delta < theta:
print(f"迭代 {iteration}: 最大变化量 (delta) = {delta:.8f}")

if delta < theta:
break

return V, iterations_history, deltas_history

def print_value_grid(V):
"""以可读的网格格式打印最终的价值函数。"""
print("\n--- 迭代策略评估已收敛! ---")
print("最终价值函数 V(s):")

grid_size = V.shape[0]
print("+------------------------------------------------+")
for i in range(grid_size):
row_str = "| "
for j in range(grid_size):
row_str += f"{V[i, j]:>8.3f} | "
print(row_str)
print("+------------------------------------------------+")

def save_delta_plot_as_image(iterations, deltas, filename="delta_convergence_curve.png"):
"""
使用 matplotlib 生成 delta 收敛曲线图,并将其保存为图像文件。

参数:
iterations (list): 迭代次数的列表。
deltas (list): delta 值的列表。
filename (str): 保存图像的文件名。
"""
print(f"\n正在生成图表并保存为 '{filename}'...")

# 1. 创建一个图形和一个坐标轴
plt.figure(figsize=(10, 6), dpi=100)

# 2. 绘制数据
plt.plot(iterations, deltas, label='Delta (价值函数最大变化量)', color='dodgerblue', linewidth=2)

# 3. 设置 Y 轴为对数刻度
plt.yscale('log')

# 4. 添加标题和轴标签
plt.title('迭代策略评估的收敛过程', fontsize=16)
plt.xlabel('迭代次数 (Iteration)', fontsize=12)
plt.ylabel('Delta (对数刻度)', fontsize=12)

# 5. 添加图例
plt.legend()

# 6. 添加网格以便于阅读
plt.grid(True, which="both", linestyle='--', linewidth=0.5)

# 7. 保存图像
try:
plt.savefig(filename, bbox_inches='tight')
print(f"图表已成功保存为 '{filename}'")
except Exception as e:
print(f"保存图表时出错: {e}")

# 8. 关闭图形以释放内存
plt.close()


if __name__ == '__main__':
# 在所有绘图操作之前,首先设置好字体环境
setup_matplotlib_for_chinese()

# 1. 执行计算
final_V, iterations_hist, deltas_hist = calculate_policy_evaluation()

# 2. 在控制台打印最终的价值网格
print_value_grid(final_V)

# 3. 生成并保存 delta 曲线图
save_delta_plot_as_image(iterations_hist, deltas_hist, "delta_convergence_curve.png")
6 changes: 5 additions & 1 deletion lzero/entry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,8 @@
from .train_rezero import train_rezero
from .train_unizero import train_unizero
from .train_unizero_segment import train_unizero_segment
from .utils import *
from .train_muzero_multitask_segment_ddp import train_muzero_multitask_segment_ddp
from .train_unizero_multitask_segment_ddp import train_unizero_multitask_segment_ddp
from .train_unizero_multitask_segment_eval import train_unizero_multitask_segment_eval
from .train_unizero_multitask_balance_segment_ddp import train_unizero_multitask_balance_segment_ddp
from .utils import *
Loading
Loading