# ==============================================================
# Most frequently modified configuration parameters
# ==============================================================
# Multi-GPU settings
use_multi_gpu = True  # Enable multi-GPU training
# Number of GPUs to use; adjust to the actual machine.
# NOTE(review): not referenced by the config dicts visible in this file - confirm where it is consumed.
gpu_num = 8
batch_size = 128

collector_env_num = 4
n_episode = 128
evaluator_env_num = 10
# Raise to 400 for better search quality; kept at 50 for now while running quick tests.
num_simulations = 50
update_per_collect = 50
# Original note: exploit MuZero's reanalyze to improve sample efficiency.
# NOTE(review): the value is 0.0, which presumably turns reanalysis off - confirm the intended ratio.
reanalyze_ratio = 0.0
max_env_step = int(1e8)  # Chinese chess needs more training steps
# ==============================================================
# End of configuration parameters
# ==============================================================
def _sanitize_log_buffer_for_ndarray(data):
    """Recursively replace NumPy values in *data* with plain Python objects.

    Dicts and lists are rebuilt (as a plain ``dict`` / ``list``) with every
    element sanitized. A NumPy array holding exactly one element becomes a
    Python scalar, any other array becomes a (possibly nested) Python list,
    and a NumPy scalar object such as ``np.float32(1.0)`` becomes the
    matching Python scalar. Every other value is returned unchanged.

    :param data: Arbitrarily nested structure of dicts, lists and leaf values.
    :return: The same structure with NumPy arrays and scalars converted.
    """
    if isinstance(data, dict):
        return {key: _sanitize_log_buffer_for_ndarray(value) for key, value in data.items()}
    if isinstance(data, list):
        return [_sanitize_log_buffer_for_ndarray(item) for item in data]
    if isinstance(data, np.generic):
        # NumPy scalar objects are not ndarrays, so they need their own
        # branch; otherwise they would pass through unconverted.
        return data.item()
    if isinstance(data, np.ndarray):
        # Single-element array -> scalar; vector/matrix -> Python list.
        return data.item() if data.size == 1 else data.tolist()
    return data
class Status(enum.IntFlag):
    """
    Bit flags describing problems with a board position.

    Each member other than ``VALID`` occupies its own bit, so several problems
    can be combined in one value; ``VALID`` is ``0``, i.e. no flag set.
    NOTE(review): presumably produced by a position validity check defined
    outside this section - confirm against the caller.
    """

    VALID = 0
    EMPTY = 1 << 0
    # Overall piece counts per side.
    TOO_MANY_RED_PIECES = 1 << 1
    TOO_MANY_BLACK_PIECES = 1 << 2
    # Kings: presence, count and placement.
    NO_RED_KING = 1 << 3
    NO_BLACK_KING = 1 << 4
    TOO_MANY_RED_KINGS = 1 << 5
    TOO_MANY_BLACK_KINGS = 1 << 6
    RED_KING_PLACE_WRONG = 1 << 7
    BLACK_KING_PLACE_WRONG = 1 << 8
    # Pawns: count and placement.
    TOO_MANY_RED_PAWNS = 1 << 9
    TOO_MANY_BLACK_PAWNS = 1 << 10
    RED_PAWNS_PLACE_WRONG = 1 << 11
    BLACK_PAWNS_PLACE_WRONG = 1 << 12
    # Rooks and knights: count only.
    TOO_MANY_RED_ROOKS = 1 << 13
    TOO_MANY_BLACK_ROOKS = 1 << 14
    TOO_MANY_RED_KNIGHTS = 1 << 15
    TOO_MANY_BLACK_KNIGHTS = 1 << 16
    # Bishops: count and placement.
    TOO_MANY_RED_BISHOPS = 1 << 17
    TOO_MANY_BLACK_BISHOPS = 1 << 18
    RED_BISHOPS_PLACE_WRONG = 1 << 19
    BLACK_BISHOPS_PLACE_WRONG = 1 << 20
    # Advisors: count and placement.
    TOO_MANY_RED_ADVISORS = 1 << 21
    TOO_MANY_BLACK_ADVISORS = 1 << 22
    RED_ADVISORS_PLACE_WRONG = 1 << 23
    BLACK_ADVISORS_PLACE_WRONG = 1 << 24
    # Cannons: count only.
    TOO_MANY_RED_CANNONS = 1 << 25
    TOO_MANY_BLACK_CANNONS = 1 << 26
    # NOTE(review): presumably "the side that is not to move is in check" and
    # "the two kings face each other with nothing in between" - confirm
    # against the code that sets these flags.
    OPPOSITE_CHECK = 1 << 27
    KING_LINE_OF_SIGHT = 1 << 28
Status.RED_BISHOPS_PLACE_WRONG +STATUS_BLACK_BISHOPS_PLACE_WRONG = Status.BLACK_BISHOPS_PLACE_WRONG +STATUS_TOO_MANY_RED_ADVISORS = Status.TOO_MANY_RED_ADVISORS +STATUS_TOO_MANY_BLACK_ADVISORS = Status.TOO_MANY_BLACK_ADVISORS +STATUS_RED_ADVISORS_PLACE_WRONG = Status.RED_ADVISORS_PLACE_WRONG +STATUS_BLACK_ADVISORS_PLACE_WRONG = Status.BLACK_ADVISORS_PLACE_WRONG +STATUS_TOO_MANY_RED_CANNONS = Status.TOO_MANY_RED_CANNONS +STATUS_TOO_MANY_BLACK_CANNONS = Status.TOO_MANY_BLACK_CANNONS +STATUS_OPPOSITE_CHECK = Status.OPPOSITE_CHECK +STATUS_KING_LINE_OF_SIGHT = Status.KING_LINE_OF_SIGHT + + +class Termination(enum.Enum): + """Enum with reasons for a game to be over.""" + + CHECKMATE = enum.auto() + """See :func:`cchess.Board.is_checkmate()`.""" + STALEMATE = enum.auto() + """See :func:`cchess.Board.is_stalemate()`.""" + INSUFFICIENT_MATERIAL = enum.auto() + """See :func:`cchess.Board.is_insufficient_material()`.""" + FOURFOLD_REPETITION = enum.auto() + """See :func:`cchess.Board.is_fourfold_repetition()`.""" + SIXTY_MOVES = enum.auto() + """See :func:`cchess.Board.is_sixty_moves()`.""" + PERPETUAL_CHECK = enum.auto() + """See :func:`cchess.Board.is_perpetual_check()`.""" + + +@dataclasses.dataclass +class Outcome: + """ + Information about the outcome of an ended game, usually obtained from + :func:`cchess.Board.outcome()`. 
def square(column_index: int, row_index: int):
    """Returns the index of the square at the given column and row.

    Squares are numbered row by row with nine squares per row, so column
    ``0`` of row ``0`` is square ``0`` and column ``8`` of row ``9`` is
    square ``89``.
    """
    row_offset = 9 * row_index
    return row_offset + column_index
def square_mirror(square: Square) -> Square:
    """Flips the square top-to-bottom: the column is kept and row ``r`` becomes row ``9 - r``."""
    row, column = divmod(square, 9)
    return (9 - row) * 9 + column
def _knight_attacks(square: Square, occupied: BitBoard):
    """
    Returns the bitboard of squares a knight standing on *square* attacks.

    A knight first steps one square orthogonally (its "leg") and then
    continues diagonally. Each entry of ``KNIGHT_LEG_DELTAS`` owns the two
    consecutive entries of ``KNIGHT_ATTACK_DELTAS`` that pass over that leg.
    Both jumps are dropped when the leg square is occupied.
    """
    reachable = BB_EMPTY

    for leg_index, leg_step in enumerate(KNIGHT_LEG_DELTAS):
        leg = square + leg_step
        if leg < 0 or leg >= 90 or occupied & BB_SQUARES[leg]:
            # Leg square is off the board or blocked: nothing this way.
            continue

        first = 2 * leg_index
        for jump in KNIGHT_ATTACK_DELTAS[first:first + 2]:
            target = square + jump
            # The distance test discards targets whose index arithmetic
            # wrapped around a board edge.
            if 0 <= target < 90 and square_distance(target, square) <= 2:
                reachable |= BB_SQUARES[target]

    return reachable
def _bishop_attacks(square: Square, occupied: BitBoard, color: int):
    """
    Returns the bitboard of squares a bishop of *color* on *square* attacks.

    A bishop moves two squares diagonally, and only if the square it passes
    over (its "eye") is empty. The result is limited to ``BB_BISHOP_POS[color]``,
    the squares a bishop of that color may stand on.
    """
    reachable = BB_EMPTY

    for jump, eye_step in zip(BISHOP_ATTACK_DELTAS, BISHOP_EYE_DELTAS):
        eye = square + eye_step
        if eye < 0 or eye >= 90 or occupied & BB_SQUARES[eye]:
            # Eye square is off the board or blocked.
            continue

        target = square + jump
        # The distance test discards targets whose index arithmetic wrapped
        # around a board edge.
        if 0 <= target < 90 and square_distance(target, square) <= 2:
            reachable |= BB_SQUARES[target]

    return reachable & BB_BISHOP_POS[color]
def lsb(bb: BitBoard):
    """Returns the index of the least significant set bit of *bb* (``-1`` if *bb* is zero)."""
    # ``bb & -bb`` keeps only the lowest set bit.
    lowest_bit = bb & -bb
    return lowest_bit.bit_length() - 1
@dataclasses.dataclass
class Move:
    """
    A move from one square to another.

    Two moves compare equal exactly when both their squares match, and equal
    moves hash alike, so moves can be compared with ``==`` and used in sets
    or as dictionary keys. The null move goes from square ``0`` to square
    ``0`` and is falsy.
    """

    # Declared as dataclass fields so that the generated ``__eq__`` compares
    # them. With no declared field, ``@dataclass`` generates an ``__eq__``
    # that treats every pair of moves as equal.
    from_square: Square
    to_square: Square

    def __init__(self, from_square: Square, to_square: Square):
        # The hand-written ``__init__`` is kept (``@dataclass`` does not
        # replace one defined in the class) so the range checks still run.
        assert from_square in SQUARES, f"from_square out of range: {from_square!r}"
        assert to_square in SQUARES, f"to_square out of range: {to_square!r}"
        self.from_square = from_square
        self.to_square = to_square

    def uci(self) -> str:
        """
        Gets a UCI string for the move.

        The UCI representation of a null move is ``0000``.
        """
        if self:
            return SQUARE_NAMES[self.from_square] + SQUARE_NAMES[self.to_square]
        else:
            return "0000"

    @classmethod
    def from_uci(cls, uci: str):
        """
        Parses a UCI string such as ``a0a1``; ``0000`` gives the null move.

        :raises: :exc:`ValueError` if the string is not four characters long
            or names an unknown square.
        """
        if uci == "0000":
            return cls.null()
        elif len(uci) == 4:
            from_square = SQUARE_NAMES.index(uci[0:2])
            to_square = SQUARE_NAMES.index(uci[2:4])
            return cls(from_square, to_square)
        else:
            raise ValueError(f"expected uci string to be of length 4: {uci!r}")

    def __repr__(self) -> str:
        return f"Move.from_uci({self.uci()!r})"

    def __str__(self) -> str:
        return self.uci()

    def xboard(self) -> str:
        """Gets the move in XBoard notation; the null move is ``@@@@``."""
        return self.uci() if self else "@@@@"

    def __bool__(self):
        # Only the null move (square 0 to square 0) is falsy.
        return bool(self.from_square or self.to_square)

    @classmethod
    def null(cls):
        """Gets the null move, which goes from square ``0`` to square ``0``."""
        return cls(0, 0)

    def __hash__(self):
        # Consistent with the field-wise ``__eq__`` generated by ``@dataclass``.
        return hash((self.from_square, self.to_square))
def _set_board_fen(self, fen: str):
    """
    Replaces the board contents with the position described by *fen*.

    *fen* must be only the piece-placement part of a FEN: ten ``/``-separated
    rows, each describing exactly nine columns with piece letters and the
    digits ``1``-``9`` for runs of empty squares. The whole string is
    validated before the board is cleared.

    :raises: :exc:`ValueError` if *fen* contains a space, does not have ten
        rows, has two digits in a row, contains an unknown character, or has
        a row that does not add up to nine columns.
    """
    # Compatibility with set_fen().
    fen = fen.strip()
    if " " in fen:
        raise ValueError(f"expected position part of fen, got multiple parts: {fen!r}")

    ranks = fen.split("/")
    if len(ranks) != 10:
        raise ValueError(f"expected 10 rows in position part of fen: {fen!r}")

    # Check every row before touching the board.
    for rank in ranks:
        width = 0
        after_digit = False

        for ch in rank:
            if ch in ARABIC_NUMBERS:
                if after_digit:
                    raise ValueError(f"two subsequent digits in position part of fen: {fen!r}")
                width += int(ch)
                after_digit = True
                continue

            if ch.lower() not in PIECE_SYMBOLS:
                raise ValueError(f"invalid character in position part of fen: {fen!r}")
            width += 1
            after_digit = False

        if width != 9:
            raise ValueError(f"expected 9 columns per row in position part of fen: {fen!r}")

    self._clear_board()

    # Walk the FEN in reading order; SQUARES_180 turns that running position
    # into a square index. Row separators match neither branch and are skipped.
    cursor = 0
    for ch in fen:
        if ch in ARABIC_NUMBERS:
            cursor += int(ch)
        elif ch.lower() in PIECE_SYMBOLS:
            parsed = Piece.from_symbol(ch)
            self._set_piece_at(SQUARES_180[cursor], parsed.piece_type, parsed.color)
            cursor += 1

    self._starting_board_fen = fen
def __str__(self) -> str:
    """
    Returns a plain-text diagram of the board.

    Squares are visited in ``SQUARES_180`` order, one text line per row.
    Pieces are shown by their symbol and empty squares by ``.``. Squares in
    a row are separated by a space, rows by a newline, and there is no
    trailing newline.
    """
    cells = []

    for sq in SQUARES_180:
        occupant = self.piece_at(sq)
        cells.append(occupant.symbol() if occupant else ".")

        if not BB_SQUARES[sq] & BB_COLUMN_I:
            cells.append(" ")
        elif sq != I0:
            # End of a row; I0 is skipped so the output has no trailing newline.
            cells.append("\n")

    return "".join(cells)
def pieces_mask(self, piece_type: PieceType, color: Color) -> BitBoard:
    """
    Returns the bitboard of all pieces of the given type and color.

    An unknown *piece_type* trips the assertion at the end.
    """
    bitboard_names = (
        (PAWN, "pawns"),
        (ROOK, "rooks"),
        (KNIGHT, "knights"),
        (BISHOP, "bishops"),
        (ADVISOR, "advisors"),
        (KING, "kings"),
        (CANNON, "cannons"),
    )

    for candidate, attribute in bitboard_names:
        if piece_type == candidate:
            return getattr(self, attribute) & self.occupied_co[color]

    assert False, f"expected PieceType, got {piece_type!r}"
def king(self, color: Color):
    """
    Returns the square of the king of the given side, or ``None`` if that
    side has no king on the board.

    If several kings of that color are present, the one on the highest
    square index is returned.
    """
    own_kings = self.kings & self.occupied_co[color]
    if not own_kings:
        return None
    return msb(own_kings)
def _attackers_mask(self, color: Color, square: Square, occupied: BitBoard) -> BitBoard:
    """
    Returns the bitboard of pieces of *color* that attack *square*.

    :param color: The side whose attackers are collected.
    :param square: The target square.
    :param occupied: Occupancy used for blockers (rook and cannon lines,
        knight legs, bishop eyes). It is passed in separately rather than
        read from ``self.occupied``.
    """
    # The precomputed row/column tables are keyed by the occupied squares
    # inside the relevant mask for this square.
    row_pieces = BB_ROW_MASKS[square] & occupied
    column_pieces = BB_COLUMN_MASKS[square] & occupied

    # For each piece type, work out the squares from which such a piece
    # would attack `square`, then keep those that hold a piece of that type.
    # `&` binds tighter than `|`, so every line is one "candidates & pieces" term.
    attackers = (
        (BB_ROW_ATTACKS[square][row_pieces] & self.rooks) |
        (BB_COLUMN_ATTACKS[square][column_pieces] & self.rooks) |
        _knights_can_attack(square, occupied) & self.knights |
        _bishop_attacks(square, occupied, color) & self.bishops |
        BB_ADVISOR_ATTACKS[color][square] & self.advisors |
        BB_KING_ATTACKS[color][square] & self.kings |
        BB_PAWNS_CAN_ATTACK[color][square] & self.pawns |
        _cannon_attacks(square, occupied) & self.cannons
    )

    # The piece bitboards hold both colors; keep only the requested side.
    return attackers & self.occupied_co[color]
+ """ + return SquareSet(self.attackers_mask(color, square)) + + def _remove_piece_at(self, square: Square): + piece_type = self.piece_type_at(square) + mask = BB_SQUARES[square] + + if piece_type == PAWN: + self.pawns ^= mask + elif piece_type == ROOK: + self.rooks ^= mask + elif piece_type == KNIGHT: + self.knights ^= mask + elif piece_type == BISHOP: + self.bishops ^= mask + elif piece_type == ADVISOR: + self.advisors ^= mask + elif piece_type == KING: + self.kings ^= mask + elif piece_type == CANNON: + self.cannons ^= mask + else: + return None + + self.occupied ^= mask + self.occupied_co[RED] &= ~mask + self.occupied_co[BLACK] &= ~mask + + return piece_type + + def remove_piece_at(self, square: Square): + color = bool(self.occupied_co[RED] & BB_SQUARES[square]) + piece_type = self._remove_piece_at(square) + return Piece(piece_type, color) if piece_type else None + + def _set_piece_at(self, square: Square, piece_type: PieceType, color: Color): + self._remove_piece_at(square) + + mask = BB_SQUARES[square] + + if piece_type == PAWN: + self.pawns |= mask + elif piece_type == ROOK: + self.rooks |= mask + elif piece_type == KNIGHT: + self.knights |= mask + elif piece_type == BISHOP: + self.bishops |= mask + elif piece_type == ADVISOR: + self.advisors |= mask + elif piece_type == KING: + self.kings |= mask + elif piece_type == CANNON: + self.cannons |= mask + else: + return + + self.occupied |= mask + self.occupied_co[color] |= mask + + def set_piece_at(self, square: Square, piece): + if piece is None: + self._remove_piece_at(square) + else: + self._set_piece_at(square, piece.piece_type, piece.color) + + def board_fen(self) -> str: + """ + Gets the board FEN (e.g., + ``rnbakabnr/9/1c5c1/p1p1p1p1p/9/9/P1P1P1P1P/1C5C1/9/RNBAKABNR``). 
+ """ + builder = [] + empty = 0 + + for square in SQUARES_180: + piece = self.piece_at(square) + + if not piece: + empty += 1 + else: + if empty: + builder.append(str(empty)) + empty = 0 + builder.append(piece.symbol()) + + if BB_SQUARES[square] & BB_COLUMN_I: + if empty: + builder.append(str(empty)) + empty = 0 + + if square != I0: + builder.append("/") + + return "".join(builder) + + def __eq__(self, board: object) -> bool: + if isinstance(board, BaseBoard): + return ( + self.occupied == board.occupied and + self.occupied_co[RED] == board.occupied_co[RED] and + self.pawns == board.pawns and + self.rooks == board.rooks and + self.knights == board.knights and + self.bishops == board.bishops and + self.advisors == board.advisors and + self.kings == board.kings and + self.cannons == board.cannons) + else: + return NotImplemented + + def copy(self): + """Creates a copy of the board.""" + board = type(self)(None) + + board.pawns = self.pawns + board.knights = self.knights + board.rooks = self.rooks + board.bishops = self.bishops + board.advisors = self.advisors + board.kings = self.kings + board.cannons = self.cannons + + board.occupied_co[RED] = self.occupied_co[RED] + board.occupied_co[BLACK] = self.occupied_co[BLACK] + board.occupied = self.occupied + + return board + + def __copy__(self): + return self.copy() + + def __deepcopy__(self, memo: Dict[int, object]): + board = self.copy() + memo[id(self)] = board + return board + + @classmethod + def empty(cls): + """ + Creates a new empty board. Also see + :func:`~cchess.BaseBoard.clear_board()`. 
+ """ + return cls(None) + + +class _BoardState: + + def __init__(self, board) -> None: + self.pawns = board.pawns + self.rooks = board.rooks + self.knights = board.knights + self.bishops = board.bishops + self.advisors = board.advisors + self.kings = board.kings + self.cannons = board.cannons + + self.occupied_r = board.occupied_co[RED] + self.occupied_b = board.occupied_co[BLACK] + self.occupied = board.occupied + + self.turn = board.turn + self.halfmove_clock = board.halfmove_clock + self.fullmove_number = board.fullmove_number + + def __eq__(self, other) -> bool: + return all([self.turn == other.turn, + self.pawns == other.pawns, + self.rooks == other.rooks, + self.knights == other.knights, + self.bishops == other.bishops, + self.advisors == other.advisors, + self.kings == other.kings, + self.cannons == other.cannons, + self.occupied_r == other.occupied_r, + self.occupied_b == other.occupied_b]) + + def restore(self, board) -> None: + board.pawns = self.pawns + board.rooks = self.rooks + board.knights = self.knights + board.bishops = self.bishops + board.advisors = self.advisors + board.kings = self.kings + board.cannons = self.cannons + + board.occupied_co[RED] = self.occupied_r + board.occupied_co[BLACK] = self.occupied_b + board.occupied = self.occupied + + board.turn = self.turn + board.halfmove_clock = self.halfmove_clock + board.fullmove_number = self.fullmove_number + + +class Board(BaseBoard): + starting_fen = STARTING_FEN + turn: Color + """The side to move (``cchess.RED`` or ``cchess.BLACK``).""" + fullmove_number: int + """ + Counts move pairs. Starts at `1` and is incremented after every move + of the black side. + """ + halfmove_clock: int + """The number of half-moves since the last capture.""" + move_stack: List[Move] + """ + The move stack. Use :func:`Board.push() `, + :func:`Board.pop() `, + :func:`Board.peek() ` and + :func:`Board.clear_stack() ` for + manipulation. 
+ """ + + def __init__(self, fen: Optional[str] = STARTING_FEN): + super(Board, self).__init__(None) + self.move_stack = [] + self._stack = [] + self._starting_fen = "" + + if fen is None: + self.clear() + elif fen == type(self).starting_fen: + self.reset() + else: + self.set_fen(fen) + + def __repr__(self): + return f"{type(self).__name__}({self.fen()!r})" + + def _repr_svg_(self) -> str: + import cchess.svg + return cchess.svg.board(board=self, + size=450, + axes_type=self._axes_type, + lastmove=self.peek() if self.move_stack else None, + checkers=self.checkers() if self.is_check() else None, + style=self._svg_css) + + def fen(self) -> str: + """ + Gets a FEN representation of the position. + + A FEN string (e.g., + ``rnbakabnr/9/1c5c1/p1p1p1p1p/9/9/P1P1P1P1P/1C5C1/9/RNBAKABNR w - - 0 1``) consists + of the board part :func:`~cchess.Board.board_fen()`, the + :data:`~cchess.Board.turn`, + the :data:`~cchess.Board.halfmove_clock` + and the :data:`~cchess.Board.fullmove_number`. + """ + return " ".join([ + self.epd(), + str(self.halfmove_clock), + str(self.fullmove_number) + ]) + + def epd(self) -> str: + """ + Gets an EPD representation of the current position. + """ + epd = [self.board_fen(), + "w" if self.turn == RED else "b", + '-', "-"] + + return " ".join(epd) + + def set_fen(self, fen: str) -> None: + """ + Parses a FEN and sets the position from it. + + :raises: :exc:`ValueError` if syntactically invalid. Use + :func:`~cchess.Board.is_valid()` to detect invalid positions. + """ + parts = fen.split() + + # Board part. + try: + board_part = parts.pop(0) + except IndexError: + raise ValueError("empty fen") + + # Turn. 
+ try: + turn_part = parts.pop(0) + except IndexError: + turn = RED + else: + if turn_part == "w": + turn = RED + elif turn_part == "b": + turn = BLACK + else: + raise ValueError(f"expected 'w' or 'b' for turn part of fen: {fen!r}") + + try: + parts.pop(0) + except IndexError: + pass + try: + parts.pop(0) + except IndexError: + pass + + # Check that the half-move part is valid. + try: + halfmove_part = parts.pop(0) + except IndexError: + halfmove_clock = 0 + else: + try: + halfmove_clock = int(halfmove_part) + except ValueError: + raise ValueError(f"invalid half-move clock in fen: {fen!r}") + + if halfmove_clock < 0: + raise ValueError(f"half-move clock cannot be negative: {fen!r}") + + # Check that the full-move number part is valid. + # 0 is allowed for compatibility, but later replaced with 1. + try: + fullmove_part = parts.pop(0) + except IndexError: + fullmove_number = 1 + else: + try: + fullmove_number = int(fullmove_part) + except ValueError: + raise ValueError(f"invalid fullmove number in fen: {fen!r}") + + if fullmove_number < 0: + raise ValueError(f"fullmove number cannot be negative: {fen!r}") + + fullmove_number = max(fullmove_number, 1) + + # All parts should be consumed now. + if parts: + raise ValueError(f"fen string has more parts than expected: {fen!r}") + + # Validate the board part and set it. + self._set_board_fen(board_part) + + # Apply. + self.turn = turn + self.halfmove_clock = halfmove_clock + self.fullmove_number = fullmove_number + self.clear_stack() + self._starting_fen = self.fen() + + @property + def legal_moves(self): + """ + A dynamic list of legal moves. + """ + return LegalMoveGenerator(self) + + @property + def pseudo_legal_moves(self): + """ + A dynamic list of pseudo-legal moves, much like the legal move list. + """ + return PseudoLegalMoveGenerator(self) + + def clear(self): + """ + Clears the board. + + Resets move stack and move counters. The side to move is red. 
+ + In order to be in a valid :func:`~cchess.Board.status()`, at least kings + need to be put on the board. + """ + self.turn = RED + self.halfmove_clock = 0 + self.fullmove_number = 1 + + self.clear_board() + self._starting_fen = "" + + def clear_board(self): + super().clear_board() + self.clear_stack() + + def clear_stack(self): + """Clears the move stack.""" + self.move_stack.clear() + self._stack.clear() + + def reset(self) -> None: + """Restores the starting position.""" + self.turn = RED + self.halfmove_clock = 0 + self.fullmove_number = 1 + + self.reset_board() + self._starting_fen = type(self).starting_fen + + def reset_board(self) -> None: + """ + Resets only pieces to the starting position. Use + :func:`~cchess.Board.reset()` to fully restore the starting position + (including turn, castling rights, etc.). + """ + super().reset_board() + self.clear_stack() + + def root(self): + """Returns a copy of the root position.""" + if self._stack: + board = type(self)(None) + self._stack[0].restore(board) + return board + else: + return self.copy(stack=False) + + def copy(self, *, stack: Union[bool, int] = True): + """ + Creates a copy of the board. + + Defaults to copying the entire move stack. Alternatively, *stack* can + be ``False``, or an integer to copy a limited number of moves. + """ + board = super().copy() + + board.turn = self.turn + board.fullmove_number = self.fullmove_number + board.halfmove_clock = self.halfmove_clock + + if stack: + stack = len(self.move_stack) if stack is True else stack + board.move_stack = [copy.copy(move) for move in self.move_stack[-stack:]] + board._stack = self._stack[-stack:] + + return board + + @classmethod + def empty(cls): + """Creates a new empty board. 
Also see :func:`~cchess.Board.clear()`.""" + return cls(None) + + def ply(self) -> int: + return 2 * (self.fullmove_number - 1) + (self.turn == BLACK) + + def remove_piece_at(self, square: Square, clear_stack=True): + piece = super().remove_piece_at(square) + if clear_stack: + self.clear_stack() + return piece + + def set_piece_at(self, square: Square, piece: Optional[Piece], clear_stack=True): + super().set_piece_at(square, piece) + if clear_stack: + self.clear_stack() + + def checkers_mask(self) -> BitBoard: + king = self.king(self.turn) + return BB_EMPTY if king is None else self.attackers_mask(not self.turn, king) + + def checkers(self): + """ + Gets the pieces currently giving check. + + Returns a :class:`set of squares `. + """ + return SquareSet(self.checkers_mask()) + + def is_check(self) -> bool: + """Tests if the current side to move is in check.""" + return bool(self.checkers_mask()) + + def is_king_line_of_sight(self) -> bool: + red_king, black_king = self.king(RED), self.king(BLACK) + if red_king is None or black_king is None: + return False + between_kings = between(red_king, black_king) + if not between_kings: + return False + return not bool(between(red_king, black_king) & self.occupied) + + def gives_check(self, move: Move) -> bool: + """ + Probes if the given move would put the opponent in check. The move + must be at least pseudo-legal. 
+ """ + self.push(move) + try: + return self.is_check() + finally: + self.pop() + + def _is_safe(self, move: Move) -> bool: + try: + self.push(move) + return not (bool(self.attackers_mask(self.turn, self.king(not self.turn))) | self.is_king_line_of_sight()) + finally: + self.pop() + + def generate_pseudo_legal_moves(self, from_mask: BitBoard = BB_ALL, to_mask: BitBoard = BB_ALL) -> Iterator[Move]: + our_pieces = self.occupied_co[self.turn] + + for from_square in scan_reversed(our_pieces & from_mask): + moves = self.attacks_mask(from_square) & ~our_pieces & to_mask + for to_square in scan_reversed(moves): + yield Move(from_square, to_square) # pieces attack + if BB_SQUARES[from_square] & self.cannons: + slides = _cannon_slides(from_square, self.occupied) & to_mask + for to_square in scan_reversed(slides): + yield Move(from_square, to_square) # cannons slide + + def generate_pseudo_legal_captures(self, from_mask: BitBoard = BB_ALL, to_mask: BitBoard = BB_ALL) -> Iterator[ + Move]: + return self.generate_pseudo_legal_moves(from_mask, to_mask & self.occupied_co[not self.turn]) + + def generate_legal_moves(self, from_mask: BitBoard = BB_ALL, to_mask: BitBoard = BB_ALL) -> Iterator[Move]: + + king = self.king(self.turn) + oppo_king = self.king(not self.turn) + if king: + for move in self.generate_pseudo_legal_moves(from_mask, to_mask): + if move.to_square == oppo_king: + yield move + elif self._is_safe(move): + yield move + else: + yield from self.generate_pseudo_legal_moves(from_mask, to_mask) + + def generate_legal_captures(self, from_mask: BitBoard = BB_ALL, to_mask: BitBoard = BB_ALL) -> Iterator[Move]: + return self.generate_legal_moves(from_mask, to_mask & self.occupied_co[not self.turn]) + + def is_into_check(self, move): + king = self.king(self.turn) + if king is None: + return False + return not self._is_safe(move) + + def was_into_check(self) -> bool: + king = self.king(not self.turn) + return king is not None and self.is_attacked_by(self.turn, king) + + def 
is_pseudo_legal(self, move: Move) -> bool: + # Null moves are not pseudo-legal. + if not move: + return False + + # Source square must not be vacant. + piece = self.piece_type_at(move.from_square) + if not piece: + return False + + # Get square masks. + from_mask = BB_SQUARES[move.from_square] + to_mask = BB_SQUARES[move.to_square] + + # Check turn. + if not self.occupied_co[self.turn] & from_mask: + return False + + # Destination square can not be occupied by own piece. + if self.occupied_co[self.turn] & to_mask: + return False + + # Cannon + if piece == CANNON: + slides = _cannon_slides(move.from_square, self.occupied) + if to_mask & slides: + return True + + # Handle all other pieces. + return bool(self.attacks_mask(move.from_square) & to_mask) + + def is_legal(self, move: Move) -> bool: + if self.is_pseudo_legal(move): + if move.to_square == self.king(not self.turn): + return True + return not self.is_into_check(move) + return False + + def _board_state(self): + return _BoardState(self) + + def is_zeroing(self, move: Move) -> bool: + """Checks if the given pseudo-legal move is a capture.""" + to_square = BB_SQUARES[move.to_square] + return bool(to_square & self.occupied_co[not self.turn]) + + def push(self, move: Move) -> None: + """ + Updates the position with the given *move* and puts it onto the + move stack. + + Null moves just increment the move counters, switch turns. + + .. warning:: + Moves are not checked for legality. It is the caller's + responsibility to ensure that the move is at least pseudo-legal or + a null move. + """ + # Push move and remember board state. + board_state = self._board_state() + self.move_stack.append(move) + self._stack.append(board_state) + + # Increment move counters. + self.halfmove_clock += 1 + if self.turn == BLACK: + self.fullmove_number += 1 + + # Zero the half-move clock. 
+ if self.is_zeroing(move): + self.halfmove_clock = 0 + + piece = self.remove_piece_at(move.from_square, clear_stack=False) + assert piece is not None, f"push() expects move to be pseudo-legal, but got {move} in {self.board_fen()}" + self.set_piece_at(move.to_square, piece, clear_stack=False) + + # Swap turn. + self.turn = not self.turn + + def pop(self) -> Move: + """ + Restores the previous position and returns the last move from the stack. + + :raises: :exc:`IndexError` if the move stack is empty. + """ + move = self.move_stack.pop() + self._stack.pop().restore(self) + return move + + def peek(self) -> Move: + """ + Gets the last move from the move stack. + + :raises: :exc:`IndexError` if the move stack is empty. + """ + return self.move_stack[-1] + + def push_notation(self, notation: str): + try: + move = self.parse_notation(notation) + if self.is_legal(move): + self.push(move) + return move + else: + raise ValueError(f"illegal notation: {notation!r} in {self.fen()!r}") + except (AssertionError, ValueError): + raise ValueError(f"illegal notation: {notation!r} in {self.fen()!r}") + + def push_uci(self, uci: str): + move = Move.from_uci(uci) + if not self.is_legal(move): + raise ValueError(f"illegal uci: {uci!r} in {self.fen()!r}") + self.push(move) + return move + + def find_move(self, from_square: Square, to_square: Square) -> Move: + """ + Finds a matching legal move for an origin square and a target square. + + :raises: :exc:`ValueError` if no matching legal move is found. 
+ """ + + move = Move(from_square=from_square, to_square=to_square) + if not self.is_legal(move): + raise ValueError( + f"no matching legal move for {move.uci()} ({SQUARE_NAMES[from_square]} -> {SQUARE_NAMES[to_square]}) in {self.fen()}") + + return move + + def is_checkmate(self) -> bool: + """Checks if the current position is a checkmate.""" + if not self.is_check(): + return False + + return not any(self.generate_legal_moves()) + + def is_stalemate(self) -> bool: + """Checks if the current position is a stalemate.""" + if self.is_check(): + return False + + return not any(self.generate_legal_moves()) + + def is_insufficient_material(self) -> bool: + """Checks if neither side has sufficient winning material. + For simplicity, it returns True if and only if neither side has pieces that can cross the river. + """ + if self.pawns == self.rooks == self.knights == self.cannons == BB_EMPTY: + return True + return False + + def is_halfmoves(self, n: int) -> bool: + return self.halfmove_clock >= n and any(self.generate_legal_moves()) + + def is_forty_moves(self) -> bool: + return self.is_halfmoves(80) + + def is_fifty_moves(self) -> bool: + return self.is_halfmoves(100) + + def is_sixty_moves(self) -> bool: + return self.is_halfmoves(120) + + def _transposition_key(self): + return (self.pawns, self.rooks, self.knights, self.bishops, + self.advisors, self.kings, self.cannons, + self.occupied_co[RED], self.occupied_co[BLACK], + self.turn) + + def is_irreversible(self, move: Move) -> bool: + return self.is_zeroing(move) + + def is_repetition(self, count: int = 3) -> bool: + """ + Checks if the current position has repeated 3 (or a given number of) + times. + + Note that checking this can be slow: In the worst case, the entire + game has to be replayed because there is no incremental transposition + table. + """ + # Fast check, based on occupancy only. 
+ maybe_repetitions = 1 + for state in reversed(self._stack): + if state.occupied == self.occupied: + maybe_repetitions += 1 + if maybe_repetitions >= count: + break + if maybe_repetitions < count: + return False + + # Check full replay. + transposition_key = self._transposition_key() + switchyard = [] + + try: + while True: + if count <= 1: + return True + + if len(self.move_stack) < count - 1: + break + + move = self.pop() + switchyard.append(move) + + if self.is_irreversible(move): + break + + if self._transposition_key() == transposition_key: + count -= 1 + finally: + while switchyard: + self.push(switchyard.pop()) + + return False + + def is_perpetual_check(self) -> bool: + if not self.is_check(): + return False + if len(self._stack) <= 6: + return False + state = self._transposition_key() + oppo_is_perpetual_check = True + check_num = 1 + switchyard = [] + is_repetition = False + try: + move = self.pop() + switchyard.append(move) + if self.is_irreversible(move): + return False + while True: + if oppo_is_perpetual_check and not self.is_check(): + oppo_is_perpetual_check = False + switchyard.append(self.pop()) + if not self.is_check(): + return False + check_num += 1 + if not is_repetition and self._transposition_key() == state: + is_repetition = True + move = self.pop() + switchyard.append(move) + if self.is_irreversible(move): + return False + if check_num >= 4 and is_repetition and not oppo_is_perpetual_check: + return True + except IndexError: + return False + finally: + while switchyard: + self.push(switchyard.pop()) + + def is_sixfold_repetition(self) -> bool: + return self.is_repetition(6) + + def is_fivefold_repetition(self) -> bool: + return self.is_repetition(5) + + def is_fourfold_repetition(self) -> bool: + return self.is_repetition(4) + + def is_threefold_repetition(self) -> bool: + return self.is_repetition(3) + + def is_capture(self, move: Move) -> bool: + touched = BB_SQUARES[move.from_square] ^ BB_SQUARES[move.to_square] + return bool(touched & 
self.occupied_co[not self.turn]) + + def outcome(self) -> Optional[Outcome]: + """ + Checks if the game is over due to + :func:`checkmate `, + :func:`insufficient_material `, + :func:`stalemate `, + :func:`perpetual_check `, + the :func:`sixty-move rule `, + :func:`sixfold repetition `, + Returns the :class:`cchess.Outcome` if the game has ended, otherwise + ``None``. + + Alternatively, use :func:`~cchess.Board.is_game_over()` if you are not + interested in who won the game and why. + """ + + # Normal game end. + if self.is_checkmate(): + return Outcome(Termination.CHECKMATE, not self.turn) + if self.is_insufficient_material(): + return Outcome(Termination.INSUFFICIENT_MATERIAL, None) + if not any(self.generate_legal_moves()): + return Outcome(Termination.STALEMATE, not self.turn) + if self.is_perpetual_check(): # 单方长将 + return Outcome(Termination.PERPETUAL_CHECK, self.turn) + + # Automatic draws. + if self.is_fourfold_repetition(): + return Outcome(Termination.FOURFOLD_REPETITION, None) + if self.is_sixty_moves(): + return Outcome(Termination.SIXTY_MOVES, None) + + return None + + def result(self) -> str: + outcome = self.outcome() + return outcome.result() if outcome else "*" + + def is_game_over(self): + return self.outcome() is not None + + def status(self) -> Status: + """ + Gets a bitmask of possible problems with the position. + + :data:`~cchess.STATUS_VALID` if all basic validity requirements are met. + This does not imply that the position is actually reachable with a + series of legal moves from the starting position. + """ + errors = STATUS_VALID + + # There must be at least one piece. + if not self.occupied: + errors |= STATUS_EMPTY + + # There can not be more than 16 pieces of any color. + if popcount(self.occupied_co[RED]) > 16: + errors |= STATUS_TOO_MANY_RED_PIECES + if popcount(self.occupied_co[BLACK]) > 16: + errors |= STATUS_TOO_MANY_BLACK_PIECES + + # There must be exactly one king of each color. 
+ if not self.occupied_co[RED] & self.kings: + errors |= STATUS_NO_RED_KING + if not self.occupied_co[BLACK] & self.kings: + errors |= STATUS_NO_BLACK_KING + # There can not be more than 1 king of any color. + if popcount(self.occupied_co[RED] & self.kings) > 1: + errors |= STATUS_TOO_MANY_RED_KINGS + if popcount(self.occupied_co[BLACK] & self.kings) > 1: + errors |= STATUS_TOO_MANY_BLACK_KINGS + + # Kings are placed on wrong places. + if self.kings & self.occupied_co[RED] & ~BB_PALACES[RED]: + errors |= STATUS_RED_KING_PLACE_WRONG + if self.kings & self.occupied_co[BLACK] & ~BB_PALACES[BLACK]: + errors |= STATUS_BLACK_KING_PLACE_WRONG + + # There can not be more than 5 pawns of any color. + if popcount(self.occupied_co[RED] & self.pawns) > 5: + errors |= STATUS_TOO_MANY_RED_PAWNS + if popcount(self.occupied_co[BLACK] & self.pawns) > 5: + errors |= STATUS_TOO_MANY_BLACK_PAWNS + + # Pawns are placed on wrong places. + if self.pawns & self.occupied_co[RED] & ~BB_PAWN_POS[RED]: + errors |= STATUS_RED_PAWNS_PLACE_WRONG + if self.pawns & self.occupied_co[BLACK] & ~BB_PAWN_POS[BLACK]: + errors |= STATUS_BLACK_PAWNS_PLACE_WRONG + + # There can not be more than 2 rooks of any color. + if popcount(self.occupied_co[RED] & self.rooks) > 2: + errors |= STATUS_TOO_MANY_RED_ROOKS + if popcount(self.occupied_co[BLACK] & self.rooks) > 2: + errors |= STATUS_TOO_MANY_BLACK_ROOKS + + # There can not be more than 2 knights of any color. + if popcount(self.occupied_co[RED] & self.knights) > 2: + errors |= STATUS_TOO_MANY_RED_KNIGHTS + if popcount(self.occupied_co[BLACK] & self.knights) > 2: + errors |= STATUS_TOO_MANY_BLACK_KNIGHTS + + # There can not be more than 2 bishops of any color. + if popcount(self.occupied_co[RED] & self.bishops) > 2: + errors |= STATUS_TOO_MANY_RED_BISHOPS + if popcount(self.occupied_co[BLACK] & self.bishops) > 2: + errors |= STATUS_TOO_MANY_BLACK_BISHOPS + + # Bishops are placed on wrong places. 
+ if self.bishops & self.occupied_co[RED] & ~BB_BISHOP_POS[RED]: + errors |= STATUS_RED_BISHOPS_PLACE_WRONG + if self.bishops & self.occupied_co[BLACK] & ~BB_BISHOP_POS[BLACK]: + errors |= STATUS_BLACK_BISHOPS_PLACE_WRONG + + # There can not be more than 2 advisors of any color. + if popcount(self.occupied_co[RED] & self.advisors) > 2: + errors |= STATUS_TOO_MANY_RED_ADVISORS + if popcount(self.occupied_co[BLACK] & self.advisors) > 2: + errors |= STATUS_TOO_MANY_BLACK_ADVISORS + + # Advisors are placed on wrong places. + if self.advisors & self.occupied_co[RED] & ~BB_ADVISOR_POS[RED]: + errors |= STATUS_RED_ADVISORS_PLACE_WRONG + if self.advisors & self.occupied_co[BLACK] & ~BB_ADVISOR_POS[BLACK]: + errors |= STATUS_BLACK_ADVISORS_PLACE_WRONG + + # There can not be more than 2 cannons of any color. + if popcount(self.occupied_co[RED] & self.cannons) > 2: + errors |= STATUS_TOO_MANY_RED_CANNONS + if popcount(self.occupied_co[BLACK] & self.cannons) > 2: + errors |= STATUS_TOO_MANY_BLACK_CANNONS + + # Side to move giving check. + if self.was_into_check(): + errors |= STATUS_OPPOSITE_CHECK + + if self.is_king_line_of_sight(): + errors |= STATUS_KING_LINE_OF_SIGHT + + return errors + + def is_valid(self) -> bool: + """ + Checks some basic validity requirements. + + See :func:`~cchess.Board.status()` for details. 
+ """ + return self.status() == STATUS_VALID + + def parse_uci(self, uci: str) -> Move: + move = Move.from_uci(uci) + + if not move: + return move + + if not self.is_legal(move): + raise ValueError(f"illegal uci: {uci!r} in {self.fen()!r}") + + return move + + def parse_notation(self, notation: str) -> Move: + assert len(notation) == 4, "记号的长度不为4" + notation = notation.translate(PIECE_SYMBOL_TRANSLATOR[self.turn]) + if notation in ADVISOR_BISHOP_MOVES_TRADITIONAL_TO_MODERN: + move = Move.from_uci(ADVISOR_BISHOP_MOVES_TRADITIONAL_TO_MODERN[notation]) + piece = self.piece_type_at(move.from_square) + if piece in [BISHOP, ADVISOR]: + return move + raise ValueError("未找到仕(士)或相(象)") + piece_notation = notation[:2] + direction_move_notation = notation[2:] + if piece_notation[0] in UNICODE_PIECE_SYMBOLS.values(): + piece = Piece.from_unicode(piece_notation[0]) + piece_type = piece.piece_type + color = piece.color + from_column_notation = piece_notation[1] + assert from_column_notation in COORDINATES_MODERN_TO_TRADITIONAL[ + color].values(), f"起始列记号错误: {from_column_notation!r}" + column_index = COORDINATES_TRADITIONAL_TO_MODERN[color][from_column_notation] + from_square = get_unique_piece_square(self, piece_type, color, piece_notation[0], column_index) + elif piece_notation[0] in ['前', '后']: + pawn_col = None + if piece_notation[1] in ['俥', '傌', '炮', '兵', + '車', '馬', '砲', '卒']: + piece = Piece.from_unicode(piece_notation[1]) + piece_type = piece.piece_type + color = piece.color + elif piece_notation[1] in CHINESE_NUMBERS: + piece_type = PAWN + color = RED + pawn_col = CHINESE_NUMBERS.index(piece_notation[1]) + elif piece_notation[1] in ARABIC_NUMBERS: + piece_type = PAWN + color = BLACK + pawn_col = ARABIC_NUMBERS.index(piece_notation[1]) + else: + raise ValueError(f"棋子种类记号错误: {piece_notation[1]!r}") + if piece_type != PAWN: + rank = ['前', '后'].index(piece_notation[0]) + from_square = get_double_piece_square(self, piece_type, color, piece_notation[1], rank) + else: + 
from_square = get_multiply_pawn_square(self, color, piece_notation[0], pawn_column=pawn_col) + elif piece_notation[0] in ['中', '二', '三', '四', '五']: + pawn_col = None + if piece_notation[1] in ['兵', '卒']: + color = piece_notation[1] == '兵' + elif piece_notation[1] in CHINESE_NUMBERS: + color = RED + pawn_col = CHINESE_NUMBERS.index(piece_notation[1]) + elif piece_notation[1] in ARABIC_NUMBERS: + color = BLACK + pawn_col = ARABIC_NUMBERS.index(piece_notation[1]) + else: + raise ValueError(f"棋子种类记号错误: {piece_notation[1]!r}") + piece_type = PAWN + from_square = get_multiply_pawn_square(self, color, piece_notation[0], pawn_column=pawn_col) + else: + raise ValueError(f'记号首字符错误: {piece_notation[0]!r}') + direction = direction_move_notation[0] + if direction == '平': + assert piece_type in [ROOK, CANNON, PAWN, KING], "只有俥(車)、炮(砲)、兵(卒)、帥(將)可以使用移动方向“平”" + to_column_notation = direction_move_notation[1] + from_row = square_row(from_square) + from_column = square_column(from_square) + assert to_column_notation in COORDINATES_MODERN_TO_TRADITIONAL[ + color].values(), f"到达列记号错误: {to_column_notation!r}" + to_column = COORDINATES_TRADITIONAL_TO_MODERN[color][to_column_notation] + assert from_column != to_column, "使用“平”时,不能移动到同一列上。" + return Move(from_square, square(to_column, from_row)) + elif direction in ['进', '退']: + move = direction_move_notation[1] + if piece_type in [ROOK, CANNON, PAWN, KING]: + if color: + assert move in CHINESE_NUMBERS, f"前进、后退步数错误: {move!r}" + move = VERTICAL_MOVE_CHINESE_TO_ARABIC[move] + else: + assert move in ARABIC_NUMBERS, f"前进、后退步数错误: {move!r}" + if color ^ (direction == '退'): + to_square = from_square + 9 * int(move) + else: + to_square = from_square - 9 * int(move) + return Move(from_square, to_square) + assert piece_type == KNIGHT # 只需要额外处理马的情况 + assert move in COORDINATES_MODERN_TO_TRADITIONAL[color].values(), f"到达列记号错误: {move!r}" + to_column = COORDINATES_TRADITIONAL_TO_MODERN[color][move] + to_squares = _knight_attacks(from_square, BB_EMPTY) + 
def move_to_notation(self, move: Move):
    """
    Convert *move* into traditional Chinese move notation.

    The result is the concatenation of three parts: a piece designator
    (piece symbol plus column, or a front/middle/back style marker when
    several identical pieces share a column), a direction character and
    either the destination column or the number of rows travelled.

    The moving piece is looked up on the current position at
    ``move.from_square``.

    :param move: The move to describe.
    :return: The notation string, or ``""`` when the origin square is empty
        or the origin and destination squares are identical.
    :raises AssertionError: If an advisor/bishop move is missing from
        ``ADVISOR_BISHOP_MOVES_MODERN_TO_TRADITIONAL``.
    """
    from_square, to_square = move.from_square, move.to_square
    piece = self.piece_at(from_square)
    if not piece or from_square == to_square:
        return ""

    piece_type = piece.piece_type
    if piece_type in [BISHOP, ADVISOR]:
        # Advisors and bishops only have a handful of moves; they are
        # translated through a fixed lookup table keyed by the UCI string.
        uci = move.uci()
        assert uci in ADVISOR_BISHOP_MOVES_MODERN_TO_TRADITIONAL, "仕(士)、相(象)着法错误"
        return ADVISOR_BISHOP_MOVES_MODERN_TO_TRADITIONAL[uci]

    from_column = square_column(from_square)
    from_row = square_row(from_square)
    to_column = square_column(to_square)
    to_row = square_row(to_square)
    symbol = piece.unicode_symbol()
    color = piece.color
    columns = COORDINATES_MODERN_TO_TRADITIONAL[color]

    def designate(bb_pieces):
        # Symbol + column when the piece is alone on its column, otherwise a
        # vertical-position marker (from TRADITIONAL_VERTICAL_POS) + symbol.
        same = bb_pieces & self.occupied_co[color] & BB_COLUMNS[from_column] & ~BB_SQUARES[from_square]
        if same == 0:
            return symbol + columns[from_column]
        same_row = square_row(msb(same))
        return TRADITIONAL_VERTICAL_POS[color][from_row > same_row] + symbol

    def straight_move():
        # Direction and distance/column for pieces that move along lines.
        if from_row == to_row:
            return '平', columns[to_column]
        direction = TRADITIONAL_VERTICAL_DIRECTION[color][to_row > from_row]
        steps = str(abs(to_row - from_row))
        if color:
            # A truthy colour gets the step count translated through
            # VERTICAL_MOVE_ARABIC_TO_CHINESE.
            steps = VERTICAL_MOVE_ARABIC_TO_CHINESE[steps]
        return direction, steps

    if piece_type == KING:
        piece_notation = symbol + columns[from_column]
        direction_notation, move_notation = straight_move()
    elif piece_type in [ROOK, CANNON]:
        piece_notation = designate(self.rooks if piece_type == ROOK else self.cannons)
        direction_notation, move_notation = straight_move()
    elif piece_type == KNIGHT:
        # Bishops and advisors were handled above, so only knights reach
        # this branch. A knight always changes row, and its target is given
        # as a column rather than a step count.
        piece_notation = designate(self.knights)
        direction_notation = TRADITIONAL_VERTICAL_DIRECTION[color][to_row > from_row]
        move_notation = columns[to_column]
    else:
        # Pawns: up to five may share a column, and two columns may both
        # hold several pawns, in which case the column replaces the symbol.
        pawns = self.pawns & self.occupied_co[color]
        same = pawns & BB_COLUMNS[from_column] & ~BB_SQUARES[from_square]
        if color:
            front_count = sum(1 for s in scan_forward(same) if s > from_square)
        else:
            front_count = sum(1 for s in scan_forward(same) if s < from_square)
        count = popcount(same)
        if count == 0:
            piece_notation = symbol + columns[from_column]
        elif count in (1, 2):
            ranks = ['前', '后'] if count == 1 else ['前', '中', '后']
            other_columns_gt_one = any(
                popcount(BB_COLUMNS[col] & pawns) >= 2
                for col in range(9) if col != from_column
            )
            suffix = columns[from_column] if other_columns_gt_one else symbol
            piece_notation = ranks[front_count] + suffix
        elif count == 3:
            piece_notation = ['前', '二', '三', '四'][front_count] + symbol
        else:
            piece_notation = ['前', '二', '三', '四', '五'][front_count] + symbol
        direction_notation, move_notation = straight_move()

    return "".join([piece_notation, direction_notation, move_notation])
@classmethod
def from_pgn(cls, pgn_file: str, *,
             to_gif=False, gif_file=None, duration=2,
             to_html=False, html_file=None):
    """
    Build a board by replaying the moves stored in a PGN file.

    The ``[FEN "..."]`` tag selects the starting position (the default
    starting FEN is used, with a warning, when it is absent) and the
    ``[Format "..."]`` tag selects between Chinese notation and ICCS
    coordinates. Text inside ``{...}`` comments is ignored.

    :param pgn_file: Path of the PGN file to read.
    :param to_gif: When true, also render the game with ``cchess.svg.to_gif``.
    :param gif_file: Output path for the GIF; defaults to the PGN path with
        its extension replaced by ``.gif``.
    :param duration: Passed through to ``cchess.svg.to_gif``.
    :param to_html: When true, also render the game with
        ``cchess.svg.to_html``.
    :param html_file: Output path for the HTML file; defaults to the PGN path
        with its extension replaced by ``.html``.
    :return: The board after all moves found in the file have been pushed.
    :raises ValueError: If the format is Chinese and no notation is found.
    """
    # Local import so this method does not depend on a module-level import.
    import os

    try:
        with open(pgn_file, 'r') as f:
            data = f.read()
    except UnicodeDecodeError:
        # Fall back to GBK when the default encoding cannot decode the file.
        with open(pgn_file, 'r', encoding='gbk') as f:
            data = f.read()

    fen_match = re.search("\\[FEN \"(.+)\"\\]", data)
    if fen_match:
        end = fen_match.end()
        fen = fen_match.groups()[0]
    else:
        warnings.warn("No FEN string found! Use default starting fen.")
        fen = STARTING_FEN
        end = - 1  # so that data[end + 1:] below covers the whole file

    fmt_match = re.search("\\[Format \"(.+)\"\\]", data)
    if fmt_match:
        fmt = fmt_match.groups()[0]
        if fmt not in ['Chinese', 'ICCS']:
            warnings.warn(f"Unsupported Format: {fmt!r}, Use default 'Chinese'.")
            fmt = 'Chinese'
    else:
        fmt = 'Chinese'

    board = cls(fen=fen)
    move_lines = data[end + 1:]
    move_lines = re.sub("{(?:.|\n)*?}", "", move_lines)
    if fmt == 'Chinese':
        # Normalise full-width digits (U+FF11..U+FF19) to ASCII so the
        # notation regex below only has to deal with one digit form.
        fullwidth_digits = "\uff11\uff12\uff13\uff14\uff15\uff16\uff17\uff18\uff19"
        move_lines = move_lines.translate(str.maketrans(fullwidth_digits, "123456789"))
        notations = re.findall("(?:(?:[兵卒车俥車马馬傌炮砲仕士象相帅帥将將][1-9一二三四五六七八九])|"
                               "(?:[前后][车俥車马馬傌炮砲])|"
                               "(?:[前中后一二三四五][兵卒1-9一二三四五六七八九]))"
                               "[进退平][1-9一二三四五六七八九]", move_lines)
        if not notations:
            raise ValueError("Find no legal notations!")
        for notation in notations:
            board.push_notation(notation)
    elif fmt == 'ICCS':
        moves = re.findall("[a-i]\\d-[a-i]\\d", move_lines.lower())
        for move in moves:
            board.push_uci(move.replace('-', ''))

    # splitext only strips a real extension of the final path component, so
    # paths without a dot (or with a dot only in a directory name) are kept
    # intact instead of being truncated.
    filename = os.path.splitext(pgn_file)[0]
    if to_gif:
        import cchess.svg
        gif_file = gif_file or f'{filename}.gif'
        cchess.svg.to_gif(board, filename=gif_file, axes_type=1, duration=duration)
        print(f"GIF generated: {gif_file!r}")
    if to_html:
        import cchess.svg
        title = re.search("\\[Event \"(.+)\"\\]", data)
        if title:
            title = title.groups()[0]
        html_file = html_file or f'{filename}.html'
        cchess.svg.to_html(board, filename=html_file, title=title)
        print(f"HTML generated: {html_file!r}")
    return board
def scan_forward(bb: BitBoard) -> Iterator[Square]:
    """Yield the indices of the set bits of *bb*, lowest bit first."""
    remaining = bb
    while remaining:
        # Isolate the lowest set bit; its bit length minus one is its index.
        lowest = remaining & -remaining
        yield lowest.bit_length() - 1
        # Clear that bit and continue with the rest.
        remaining &= remaining - 1
class LegalMoveGenerator:
    """Lazy, set-like view over the legal moves of a board."""

    def __init__(self, board: Board) -> None:
        self.board = board

    def __bool__(self) -> bool:
        return any(self.board.generate_legal_moves())

    def count(self) -> int:
        # List conversion is faster than iterating.
        return len(list(self))

    def __iter__(self) -> Iterator[Move]:
        return self.board.generate_legal_moves()

    def __contains__(self, move: Move) -> bool:
        return self.board.is_legal(move)

    def __repr__(self) -> str:
        # Include the moves themselves so the representation is useful when
        # debugging, instead of returning an empty string.
        sans = ", ".join(move.uci() for move in self)
        return f"<LegalMoveGenerator at {id(self):#x} ({sans})>"


class PseudoLegalMoveGenerator:
    """Lazy, set-like view over the pseudo-legal moves of a board."""

    def __init__(self, board: Board) -> None:
        self.board = board

    def __bool__(self) -> bool:
        return any(self.board.generate_pseudo_legal_moves())

    def count(self) -> int:
        # List conversion is faster than iterating.
        return len(list(self))

    def __iter__(self) -> Iterator[Move]:
        return self.board.generate_pseudo_legal_moves()

    def __contains__(self, move: Move) -> bool:
        return self.board.is_pseudo_legal(move)

    def __repr__(self) -> str:
        # Same format as LegalMoveGenerator.__repr__.
        sans = ", ".join(move.uci() for move in self)
        return f"<PseudoLegalMoveGenerator at {id(self):#x} ({sans})>"
def __str__(self) -> str:
    """
    Render the set as a text grid.

    Squares are visited in ``SQUARES_180`` order; members are drawn as
    ``1`` and non-members as ``.``. Cells are separated by a space, and a
    newline follows every square on ``BB_COLUMN_I`` except ``I0``.
    """
    def cells():
        for square in SQUARES_180:
            bit = BB_SQUARES[square]
            yield "1" if self.mask & bit else "."
            if bit & BB_COLUMN_I:
                # End of a row: break the line unless this is the last square.
                if square != I0:
                    yield "\n"
            else:
                yield " "

    return "".join(cells())
def remove(self, square: Square) -> None:
    """
    Remove *square* from the set.

    :raises: :exc:`KeyError` if *square* is not a member of the set.
    """
    bit = BB_SQUARES[square]
    if not self.mask & bit:
        raise KeyError(square)
    # The bit is known to be set, so toggling it clears it.
    self.mask ^= bit
def tolist(self) -> List[bool]:
    """Return a list of 90 bools, ``True`` at the index of every member square."""
    flags = [False for _ in range(90)]
    for member in self:
        flags[member] = True
    return flags
def run_in_background(coroutine: Callable[[concurrent.futures.Future[T]], Coroutine[Any, Any, None]], *, name: Optional[str] = None, debug: Optional[bool] = None) -> T:
    """
    Runs ``coroutine(future)`` in a new event loop on a background thread.

    Blocks on *future* and returns the result as soon as it is resolved.
    The coroutine and all remaining tasks continue running in the background
    until complete.

    :param coroutine: An ``async def`` function taking the future to resolve.
    :param name: Optional name for the background thread.
    :param debug: Passed through to :func:`asyncio.run`.
    :return: The result set on the future by the coroutine.
    :raises Exception: Any exception raised by the coroutine before the
        future was resolved is re-raised in the calling thread.
    """
    # asyncio.iscoroutinefunction() is deprecated since Python 3.14; the
    # inspect variant is the documented replacement. Imported locally so the
    # function does not rely on a module-level import.
    import inspect

    assert inspect.iscoroutinefunction(coroutine)

    future: concurrent.futures.Future[T] = concurrent.futures.Future()

    def background() -> None:
        try:
            asyncio.run(coroutine(future), debug=debug)
            # If the coroutine finished without resolving the future, cancel
            # it so that the caller blocked in result() is released.
            future.cancel()
        except Exception as exc:
            future.set_exception(exc)

    threading.Thread(target=background, name=name).start()
    return future.result()
+ + +--------+-----+------+------------------------------------------------+ + | type | UCI | CECP | value | + +========+=====+======+================================================+ + | check | X | X | ``True`` or ``False`` | + +--------+-----+------+------------------------------------------------+ + | spin | X | X | integer, between *min* and *max* | + +--------+-----+------+------------------------------------------------+ + | combo | X | X | string, one of *var* | + +--------+-----+------+------------------------------------------------+ + | button | X | X | ``None`` | + +--------+-----+------+------------------------------------------------+ + | reset | | X | ``None`` | + +--------+-----+------+------------------------------------------------+ + | save | | X | ``None`` | + +--------+-----+------+------------------------------------------------+ + | string | X | X | string without line breaks | + +--------+-----+------+------------------------------------------------+ + | file | | X | string, interpreted as the path to a file | + +--------+-----+------+------------------------------------------------+ + | path | | X | string, interpreted as the path to a directory | + +--------+-----+------+------------------------------------------------+ + """ + + default: ConfigValue + """The default value of the option.""" + + min: Optional[int] + """The minimum integer value of a *spin* option.""" + + max: Optional[int] + """The maximum integer value of a *spin* option.""" + + var: Optional[List[str]] + """A list of allowed string values for a *combo* option.""" + + def parse(self, value: ConfigValue) -> ConfigValue: + if self.type == "check": + return value and value != "false" + elif self.type == "spin": + try: + value = int(value) # type: ignore + except ValueError: + raise EngineError(f"expected integer for spin option {self.name!r}, got: {value!r}") + if self.min is not None and value < self.min: + raise EngineError(f"expected value for option {self.name!r} to be at 
@dataclasses.dataclass
class Limit:
    """Condition that tells the engine when to stop searching."""

    time: Optional[float] = None
    """Search for exactly *time* seconds."""

    depth: Optional[int] = None
    """Search to a depth of *depth* ply only."""

    nodes: Optional[int] = None
    """Search at most this number of *nodes*."""

    mate: Optional[int] = None
    """Search for a mate in *mate* moves."""

    red_clock: Optional[float] = None
    """Remaining time for Red, in seconds."""

    black_clock: Optional[float] = None
    """Remaining time for Black, in seconds."""

    red_inc: Optional[float] = None
    """Fischer increment for Red, in seconds."""

    black_inc: Optional[float] = None
    """Fischer increment for Black, in seconds."""

    remaining_moves: Optional[int] = None
    """
    Number of moves until the next time control. If this is unset while
    *red_clock* and *black_clock* are set, the game is sudden death.
    """

    clock_id: object = None
    """
    An identifier used with XBoard engines to signal that the time control
    has changed. When this field changes, XBoard engines are sent level or
    st commands as appropriate. Otherwise, only time and otim commands are
    sent to update the engine's clock.
    """

    def __repr__(self) -> str:
        # Same shape as the generated dataclass repr, but fields that are
        # None are left out (clock_id is never shown).
        shown = ("time", "depth", "nodes", "mate", "red_clock", "black_clock",
                 "red_inc", "black_inc", "remaining_moves")
        parts = []
        for attr in shown:
            value = getattr(self, attr)
            if value is not None:
                parts.append(f"{attr}={value!r}")
        return f"{type(self).__name__}({', '.join(parts)})"
class Info(enum.IntFlag):
    """Bit flags selecting which information sent by the engine is kept."""
    NONE = 0
    BASIC = 1 << 0
    SCORE = 1 << 1
    PV = 1 << 2
    REFUTATION = 1 << 3
    CURRLINE = 1 << 4
    ALL = BASIC | SCORE | PV | REFUTATION | CURRLINE


# Module-level aliases for the individual flags.
INFO_NONE = Info.NONE
INFO_BASIC = Info.BASIC
INFO_SCORE = Info.SCORE
INFO_PV = Info.PV
INFO_REFUTATION = Info.REFUTATION
INFO_CURRLINE = Info.CURRLINE
INFO_ALL = Info.ALL
def pov(self, color: Color) -> Score:
    """Return the score as seen by *color*: negated unless *color* is the side it is relative to."""
    if self.turn == color:
        return self.relative
    return -self.relative
+ + >>> Cp(-300).score() + -300 + >>> Mate(5).score() is None + True + >>> Mate(5).score(mate_score=100000) + 99995 + """ + + @abc.abstractmethod + def mate(self) -> Optional[int]: + """ + Returns the number of plies to mate, negative if we are getting + mated, or ``None``. + + .. warning:: + This conflates ``Mate(0)`` (we lost) and ``MateGiven`` + (we won) to ``0``. + """ + + def is_mate(self) -> bool: + """Tests if this is a mate score.""" + return self.mate() is not None + + @abc.abstractmethod + def wdl(self, *, model: WdlModel = "sf", ply: int = 30) -> Wdl: + """ + Returns statistics for the expected outcome of this game, based on + a *model*, given that this score is reached at *ply*. + + Scores have a total order, but it makes little sense to compute + the difference between two scores. For example, going from + ``Cp(-100)`` to ``Cp(+100)`` is much more significant than going + from ``Cp(+300)`` to ``Cp(+500)``. It is better to compute differences + of the expectation values for the outcome of the game (based on winning + chances and drawing chances). + + >>> Cp(100).wdl().expectation() - Cp(-100).wdl().expectation() # doctest: +ELLIPSIS + 0.379... + + >>> Cp(500).wdl().expectation() - Cp(300).wdl().expectation() # doctest: +ELLIPSIS + 0.015... + + :param model: + * ``sf``, the WDL model used by the latest Stockfish + (currently ``sf16``). + * ``sf16``, the WDL model used by Stockfish 16. + * ``sf15.1``, the WDL model used by Stockfish 15.1. + * ``sf15``, the WDL model used by Stockfish 15. + * ``sf14``, the WDL model used by Stockfish 14. + * ``sf12``, the WDL model used by Stockfish 12. + * ``licchess``, the win rate model used by Licchess. + Does not use *ply*, and does not consider drawing chances. + :param ply: The number of half-moves played since the starting + position. Models may scale scores slightly differently based on + this. Defaults to middle game. + """ + + @abc.abstractmethod + def __neg__(self) -> Score: + ... 
+ + @abc.abstractmethod + def __pos__(self) -> Score: + ... + + @abc.abstractmethod + def __abs__(self) -> Score: + ... + + def _score_tuple(self) -> Tuple[bool, bool, bool, int, Optional[int]]: + mate = self.mate() + return ( + isinstance(self, MateGivenType), + mate is not None and mate > 0, + mate is None, + -(mate or 0), + self.score(), + ) + + def __eq__(self, other: object) -> bool: + if isinstance(other, Score): + return self._score_tuple() == other._score_tuple() + else: + return NotImplemented + + def __lt__(self, other: object) -> bool: + if isinstance(other, Score): + return self._score_tuple() < other._score_tuple() + else: + return NotImplemented + + def __le__(self, other: object) -> bool: + if isinstance(other, Score): + return self._score_tuple() <= other._score_tuple() + else: + return NotImplemented + + def __gt__(self, other: object) -> bool: + if isinstance(other, Score): + return self._score_tuple() > other._score_tuple() + else: + return NotImplemented + + def __ge__(self, other: object) -> bool: + if isinstance(other, Score): + return self._score_tuple() >= other._score_tuple() + else: + return NotImplemented + +def _sf16_1_wins(cp: int, *, ply: int) -> int: + # https://github.com/official-stockfish/Stockfish/blob/sf_16.1/src/uci.cpp#L48 + NormalizeToPawnValue = 356 + # https://github.com/official-stockfish/Stockfish/blob/sf_16.1/src/uci.cpp#L383-L384 + m = min(120, max(8, ply / 2 + 1)) / 32 + a = (((-1.06249702 * m + 7.42016937) * m + 0.89425629) * m) + 348.60356174 + b = (((-5.33122190 * m + 39.57831533) * m + -90.84473771) * m) + 123.40620748 + x = min(4000, max(cp * NormalizeToPawnValue / 100, -4000)) + return int(0.5 + 1000 / (1 + math.exp((a - x) / b))) + +def _sf16_wins(cp: int, *, ply: int) -> int: + # https://github.com/official-stockfish/Stockfish/blob/sf_16/src/uci.h#L38 + NormalizeToPawnValue = 328 + # https://github.com/official-stockfish/Stockfish/blob/sf_16/src/uci.cpp#L200-L224 + m = min(240, max(ply, 0)) / 64 + a = 
(((0.38036525 * m + -2.82015070) * m + 23.17882135) * m) + 307.36768407 + b = (((-2.29434733 * m + 13.27689788) * m + -14.26828904) * m) + 63.45318330 + x = min(4000, max(cp * NormalizeToPawnValue / 100, -4000)) + return int(0.5 + 1000 / (1 + math.exp((a - x) / b))) + +def _sf15_1_wins(cp: int, *, ply: int) -> int: + # https://github.com/official-stockfish/Stockfish/blob/sf_15.1/src/uci.h#L38 + NormalizeToPawnValue = 361 + # https://github.com/official-stockfish/Stockfish/blob/sf_15.1/src/uci.cpp#L200-L224 + m = min(240, max(ply, 0)) / 64 + a = (((-0.58270499 * m + 2.68512549) * m + 15.24638015) * m) + 344.49745382 + b = (((-2.65734562 * m + 15.96509799) * m + -20.69040836) * m) + 73.61029937 + x = min(4000, max(cp * NormalizeToPawnValue / 100, -4000)) + return int(0.5 + 1000 / (1 + math.exp((a - x) / b))) + +def _sf15_wins(cp: int, *, ply: int) -> int: + # https://github.com/official-stockfish/Stockfish/blob/sf_15/src/uci.cpp#L200-L220 + m = min(240, max(ply, 0)) / 64 + a = (((-1.17202460e-1 * m + 5.94729104e-1) * m + 1.12065546e+1) * m) + 1.22606222e+2 + b = (((-1.79066759 * m + 11.30759193) * m + -17.43677612) * m) + 36.47147479 + x = min(2000, max(cp, -2000)) + return int(0.5 + 1000 / (1 + math.exp((a - x) / b))) + +def _sf14_wins(cp: int, *, ply: int) -> int: + # https://github.com/official-stockfish/Stockfish/blob/sf_14/src/uci.cpp#L200-L220 + m = min(240, max(ply, 0)) / 64 + a = (((-3.68389304 * m + 30.07065921) * m + -60.52878723) * m) + 149.53378557 + b = (((-2.01818570 * m + 15.85685038) * m + -29.83452023) * m) + 47.59078827 + x = min(2000, max(cp, -2000)) + return int(0.5 + 1000 / (1 + math.exp((a - x) / b))) + +def _sf12_wins(cp: int, *, ply: int) -> int: + # https://github.com/official-stockfish/Stockfish/blob/sf_12/src/uci.cpp#L198-L218 + m = min(240, max(ply, 0)) / 64 + a = (((-8.24404295 * m + 64.23892342) * m + -95.73056462) * m) + 153.86478679 + b = (((-3.37154371 * m + 28.44489198) * m + -56.67657741) * m) + 72.05858751 + x = min(1000, max(cp, 
-1000)) + return int(0.5 + 1000 / (1 + math.exp((a - x) / b))) + +def _licchess_raw_wins(cp: int) -> int: + # https://github.com/licchess-org/lila/pull/11148 + # https://github.com/licchess-org/lila/blob/2242b0a08faa06e7be5508d338ede7bb09049777/modules/analyse/src/main/WinPercent.scala#L26-L30 + return round(1000 / (1 + math.exp(-0.00368208 * cp))) + + +class Cp(Score): + """Centi-pawn score.""" + + def __init__(self, cp: int) -> None: + self.cp = cp + + def mate(self) -> None: + return None + + def score(self, *, mate_score: Optional[int] = None) -> int: + return self.cp + + def wdl(self, *, model: WdlModel = "sf", ply: int = 30) -> Wdl: + if model == "licchess": + wins = _licchess_raw_wins(max(-1000, min(self.cp, 1000))) + losses = 1000 - wins + elif model == "sf12": + wins = _sf12_wins(self.cp, ply=ply) + losses = _sf12_wins(-self.cp, ply=ply) + elif model == "sf14": + wins = _sf14_wins(self.cp, ply=ply) + losses = _sf14_wins(-self.cp, ply=ply) + elif model == "sf15": + wins = _sf15_wins(self.cp, ply=ply) + losses = _sf15_wins(-self.cp, ply=ply) + elif model == "sf15.1": + wins = _sf15_1_wins(self.cp, ply=ply) + losses = _sf15_1_wins(-self.cp, ply=ply) + elif model == "sf16": + wins = _sf16_wins(self.cp, ply=ply) + losses = _sf16_wins(-self.cp, ply=ply) + else: + wins = _sf16_1_wins(self.cp, ply=ply) + losses = _sf16_1_wins(-self.cp, ply=ply) + draws = 1000 - wins - losses + return Wdl(wins, draws, losses) + + def __str__(self) -> str: + return f"+{self.cp:d}" if self.cp > 0 else str(self.cp) + + def __repr__(self) -> str: + return f"Cp({self})" + + def __neg__(self) -> Cp: + return Cp(-self.cp) + + def __pos__(self) -> Cp: + return Cp(self.cp) + + def __abs__(self) -> Cp: + return Cp(abs(self.cp)) + + +class Mate(Score): + """Mate score.""" + + def __init__(self, moves: int) -> None: + self.moves = moves + + def mate(self) -> int: + return self.moves + + @typing.overload + def score(self, *, mate_score: int) -> int: ... 
class PovWdl:
    """
    Relative :class:`win/draw/loss statistics <cchess.engine.Wdl>` together
    with the point of view they are expressed from.

    .. deprecated:: 1.2
        Behaves like a tuple
        ``(wdl.relative.wins, wdl.relative.draws, wdl.relative.losses)``
        for backwards compatibility. But it is recommended to use the provided
        fields and methods instead.
    """

    relative: Wdl
    """The relative :class:`~cchess.engine.Wdl`."""

    turn: Color
    """The point of view (``cchess.RED`` or ``cchess.BLACK``)."""

    def __init__(self, relative: Wdl, turn: Color) -> None:
        self.relative = relative
        self.turn = turn

    def red(self) -> Wdl:
        """Gets the :class:`~cchess.engine.Wdl` from Red's point of view."""
        return self.pov(cchess.RED)

    def black(self) -> Wdl:
        """Gets the :class:`~cchess.engine.Wdl` from Black's point of view."""
        return self.pov(cchess.BLACK)

    def pov(self, color: Color) -> Wdl:
        """
        Gets the :class:`~cchess.engine.Wdl` from the point of view of the given
        *color*.
        """
        if self.turn == color:
            return self.relative
        # Seen from the other side, the statistics are mirrored.
        return -self.relative

    def __bool__(self) -> bool:
        return True if self.relative else False

    def __repr__(self) -> str:
        side = "RED" if self.turn else "BLACK"
        return f"PovWdl({self.relative!r}, {side})"

    # In python-chess v1.1.0, info["wdl"] was a plain tuple of the relative
    # permille values. To stay compatible, __iter__, __len__, __getitem__
    # and equality comparison with tuples are supported. Ordering is
    # deliberately not provided, because it is not a sensible operation.

    def __iter__(self) -> Iterator[int]:
        for field in ("wins", "draws", "losses"):
            yield getattr(self.relative, field)

    def __len__(self) -> int:
        return 3

    def __getitem__(self, idx: int) -> int:
        triple = (self.relative.wins, self.relative.draws, self.relative.losses)
        return triple[idx]

    def __eq__(self, other: object) -> bool:
        if isinstance(other, PovWdl):
            # Compare from a common point of view.
            return self.red() == other.red()
        if isinstance(other, tuple):
            triple = (self.relative.wins, self.relative.draws, self.relative.losses)
            return triple == other
        return NotImplemented
class MockTransport(asyncio.SubprocessTransport, asyncio.WriteTransport):
    """
    Scripted stand-in for an engine subprocess transport, for tests.

    Lines written to the fake stdin are matched, in order, against the
    registered expectations. The canned responses of a matched expectation
    are fed back to the protocol as if the engine had printed them on stdout.
    """

    def __init__(self, protocol: Protocol) -> None:
        super().__init__()
        self.protocol = protocol
        # Queue of (expected input line, lines to answer with).
        self.expectations: Deque[Tuple[str, List[str]]] = collections.deque()
        self.expected_pings = 0
        self.stdin_buffer = bytearray()
        self.protocol.connection_made(self)

    def expect(self, expectation: str, responses: Optional[List[str]] = None) -> None:
        """
        Registers the next expected input line.

        :param expectation: The exact line the protocol is expected to send.
        :param responses: Optional. Lines to send back once the expectation
            is met. Defaults to no response.
        """
        # A fresh list per call: a shared mutable default would leak
        # responses between unrelated expectations if anyone mutated it.
        self.expectations.append((expectation, [] if responses is None else responses))

    def expect_ping(self) -> None:
        """Allows one more ``ping ...`` line, answered with ``pong ...``."""
        self.expected_pings += 1

    def assert_done(self) -> None:
        """Asserts that every registered expectation has been consumed."""
        assert not self.expectations, f"pending expectations: {self.expectations}"

    def get_pipe_transport(self, fd: int) -> Optional[asyncio.BaseTransport]:
        # Only stdin (fd 0) is modelled; the mock itself acts as the pipe.
        assert fd == 0, f"expected 0 for stdin, got {fd}"
        return self

    def write(self, data: bytes) -> None:
        """Consumes *data* line by line and checks it against expectations."""
        self.stdin_buffer.extend(data)
        while b"\n" in self.stdin_buffer:
            line_bytes, self.stdin_buffer = self.stdin_buffer.split(b"\n", 1)
            line = line_bytes.decode("utf-8")

            if line.startswith("ping ") and self.expected_pings:
                # Pings are answered synchronously.
                self.expected_pings -= 1
                self.protocol.pipe_data_received(1, (line.replace("ping ", "pong ") + "\n").encode("utf-8"))
            else:
                assert self.expectations, f"unexpected: {line!r}"
                expectation, responses = self.expectations.popleft()
                assert expectation == line, f"expected {expectation}, got: {line}"
                if responses:
                    # Scheduled on the loop rather than delivered inline.
                    self.protocol.loop.call_soon(self.protocol.pipe_data_received, 1, "\n".join(responses + [""]).encode("utf-8"))

    def get_pid(self) -> int:
        return id(self)

    def get_returncode(self) -> Optional[int]:
        # Reports 0 once the script is exhausted, otherwise no exit code yet.
        return None if self.expectations else 0
def pipe_data_received(self, fd: int, data: Union[bytes, str]) -> None:
    """
    Buffers *data* received on pipe *fd* and dispatches every complete line.

    Lines end with ``\\n``; one trailing ``\\r`` is stripped. Lines from
    fd 1 go to :func:`_line_received`, all others to
    :func:`error_line_received`. Lines that are not valid UTF-8 are logged
    and dropped. An incomplete trailing line stays in the buffer.
    """
    self.buffer[fd].extend(data)  # type: ignore
    while True:
        # The buffer is re-read on every pass, because handlers may touch it.
        head, newline, rest = self.buffer[fd].partition(b"\n")
        if not newline:
            break
        self.buffer[fd] = rest
        raw = head[:-1] if head.endswith(b"\r") else head
        try:
            text = raw.decode("utf-8")
        except UnicodeDecodeError as err:
            LOGGER.warning("%s: >> %r (%s)", self, bytes(raw), err)
            continue
        handler = self._line_received if fd == 1 else self.error_line_received
        handler(text)
= command + + def previous_command_finished() -> None: + self.command, self.next_command = self.next_command, None + if self.command is not None: + cmd = self.command + + def cancel_if_cancelled(result: asyncio.Future[T]) -> None: + if result.cancelled(): + cmd._cancel() + + cmd.result.add_done_callback(cancel_if_cancelled) + cmd._start() + cmd.add_finished_callback(previous_command_finished) + + if self.command is None: + previous_command_finished() + elif not self.command.result.done(): + self.command.result.cancel() + elif not self.command.result.cancelled(): + self.command._cancel() + + return await command.result + + def __repr__(self) -> str: + pid = self.transport.get_pid() if self.transport is not None else "?" + return f"<{type(self).__name__} (pid={pid})>" + + @abc.abstractmethod + async def initialize(self) -> None: + """Initializes the engine.""" + + @abc.abstractmethod + async def ping(self) -> None: + """ + Pings the engine and waits for a response. Used to ensure the engine + is still alive and idle. + """ + + @abc.abstractmethod + async def configure(self, options: ConfigMapping) -> None: + """ + Configures global engine options. + + :param options: A dictionary of engine options where the keys are + names of :data:`~cchess.engine.Protocol.options`. Do not set options + that are managed automatically + (:func:`cchess.engine.Option.is_managed()`). + """ + + @abc.abstractmethod + async def send_opponent_information(self, *, opponent: Optional[Opponent] = None, engine_rating: Optional[int] = None) -> None: + """ + Sends the engine information about its opponent. The information will + be sent after a new game is announced and before the first move. This + method should be called before the first move of a game--i.e., the + first call to :func:`cchess.engine.Protocol.play()`. + + :param opponent: Optional. An instance of :class:`cchess.engine.Opponent` that has the opponent's information. + :param engine_rating: Optional. This engine's own rating. 
Only used by XBoard engines. + """ + + @abc.abstractmethod + async def play(self, board: cchess.Board, limit: Limit, *, game: object = None, info: Info = INFO_NONE, ponder: bool = False, draw_offered: bool = False, root_moves: Optional[Iterable[cchess.Move]] = None, options: ConfigMapping = {}, opponent: Optional[Opponent] = None) -> PlayResult: + """ + Plays a position. + + :param board: The position. The entire move stack will be sent to the + engine. + :param limit: An instance of :class:`cchess.engine.Limit` that + determines when to stop thinking. + :param game: Optional. An arbitrary object that identifies the game. + Will automatically inform the engine if the object is not equal + to the previous game (e.g., ``ucinewgame``, ``new``). + :param info: Selects which additional information to retrieve from the + engine. ``INFO_NONE``, ``INFO_BASIC`` (basic information that is + trivial to obtain), ``INFO_SCORE``, ``INFO_PV``, + ``INFO_REFUTATION``, ``INFO_CURRLINE``, ``INFO_ALL`` or any + bitwise combination. Some overhead is associated with parsing + extra information. + :param ponder: Whether the engine should keep analysing in the + background even after the result has been returned. + :param draw_offered: Whether the engine's opponent has offered a draw. + Ignored by UCI engines. + :param root_moves: Optional. Consider only root moves from this list. + :param options: Optional. A dictionary of engine options for the + analysis. The previous configuration will be restored after the + analysis is complete. You can permanently apply a configuration + with :func:`~cchess.engine.Protocol.configure()`. + :param opponent: Optional. Information about a new opponent. Information + about the original opponent will be restored once the move is + complete. New opponent information can be made permanent with + :func:`~cchess.engine.Protocol.send_opponent_information()`. 
+ """ + + @typing.overload + async def analyse(self, board: cchess.Board, limit: Limit, *, game: object = None, info: Info = INFO_ALL, root_moves: Optional[Iterable[cchess.Move]] = None, options: ConfigMapping = {}) -> InfoDict: ... + @typing.overload + async def analyse(self, board: cchess.Board, limit: Limit, *, multipv: int, game: object = None, info: Info = INFO_ALL, root_moves: Optional[Iterable[cchess.Move]] = None, options: ConfigMapping = {}) -> List[InfoDict]: ... + @typing.overload + async def analyse(self, board: cchess.Board, limit: Limit, *, multipv: Optional[int] = None, game: object = None, info: Info = INFO_ALL, root_moves: Optional[Iterable[cchess.Move]] = None, options: ConfigMapping = {}) -> Union[List[InfoDict], InfoDict]: ... + async def analyse(self, board: cchess.Board, limit: Limit, *, multipv: Optional[int] = None, game: object = None, info: Info = INFO_ALL, root_moves: Optional[Iterable[cchess.Move]] = None, options: ConfigMapping = {}) -> Union[List[InfoDict], InfoDict]: + """ + Analyses a position and returns a dictionary of + :class:`information `. + + :param board: The position to analyse. The entire move stack will be + sent to the engine. + :param limit: An instance of :class:`cchess.engine.Limit` that + determines when to stop the analysis. + :param multipv: Optional. Analyse multiple root moves. Will return + a list of at most *multipv* dictionaries rather than just a single + info dictionary. + :param game: Optional. An arbitrary object that identifies the game. + Will automatically inform the engine if the object is not equal + to the previous game (e.g., ``ucinewgame``, ``new``). + :param info: Selects which information to retrieve from the + engine. ``INFO_NONE``, ``INFO_BASIC`` (basic information that is + trivial to obtain), ``INFO_SCORE``, ``INFO_PV``, + ``INFO_REFUTATION``, ``INFO_CURRLINE``, ``INFO_ALL`` or any + bitwise combination. Some overhead is associated with parsing + extra information. 
+ :param root_moves: Optional. Limit analysis to a list of root moves. + :param options: Optional. A dictionary of engine options for the + analysis. The previous configuration will be restored after the + analysis is complete. You can permanently apply a configuration + with :func:`~cchess.engine.Protocol.configure()`. + """ + analysis = await self.analysis(board, limit, multipv=multipv, game=game, info=info, root_moves=root_moves, options=options) + + with analysis: + await analysis.wait() + + return analysis.info if multipv is None else analysis.multipv + + @abc.abstractmethod + async def analysis(self, board: cchess.Board, limit: Optional[Limit] = None, *, multipv: Optional[int] = None, game: object = None, info: Info = INFO_ALL, root_moves: Optional[Iterable[cchess.Move]] = None, options: ConfigMapping = {}) -> AnalysisResult: + """ + Starts analysing a position. + + :param board: The position to analyse. The entire move stack will be + sent to the engine. + :param limit: Optional. An instance of :class:`cchess.engine.Limit` + that determines when to stop the analysis. Analysis is infinite + by default. + :param multipv: Optional. Analyse multiple root moves. + :param game: Optional. An arbitrary object that identifies the game. + Will automatically inform the engine if the object is not equal + to the previous game (e.g., ``ucinewgame``, ``new``). + :param info: Selects which information to retrieve from the + engine. ``INFO_NONE``, ``INFO_BASIC`` (basic information that is + trivial to obtain), ``INFO_SCORE``, ``INFO_PV``, + ``INFO_REFUTATION``, ``INFO_CURRLINE``, ``INFO_ALL`` or any + bitwise combination. Some overhead is associated with parsing + extra information. + :param root_moves: Optional. Limit analysis to a list of root moves. + :param options: Optional. A dictionary of engine options for the + analysis. The previous configuration will be restored after the + analysis is complete. 
You can permanently apply a configuration + with :func:`~cchess.engine.Protocol.configure()`. + + Returns :class:`~cchess.engine.AnalysisResult`, a handle that allows + asynchronously iterating over the information sent by the engine + and stopping the analysis at any time. + """ + + @abc.abstractmethod + async def send_game_result(self, board: cchess.Board, winner: Optional[Color] = None, game_ending: Optional[str] = None, game_complete: bool = True) -> None: + """ + Sends the engine the result of the game. + + XBoard engines receive the final moves and a line of the form + ``result {}``. The ```` field is one of ``1-0``, + ``0-1``, ``1/2-1/2``, or ``*`` to indicate red won, black won, draw, + or adjournment, respectively. The ```` field is a description + of the specific reason for the end of the game: "Red mates", + "Time forfeiture", "Stalemate", etc. + + UCI engines do not expect end-of-game information and so are not + sent anything. + + :param board: The final state of the board. + :param winner: Optional. Specify the winner of the game. This is useful + if the result of the game is not evident from the board--e.g., time + forfeiture or draw by agreement. If not ``None``, this parameter + overrides any winner derivable from the board. + :param game_ending: Optional. Text describing the reason for the game + ending. Similarly to the winner parameter, this overrides any game + result derivable from the board. + :param game_complete: Optional. Whether the game reached completion. + """ + + @abc.abstractmethod + async def quit(self) -> None: + """Asks the engine to shut down.""" + + @classmethod + async def popen(cls: Type[ProtocolT], command: Union[str, List[str]], *, setpgrp: bool = False, **popen_args: Any) -> Tuple[asyncio.SubprocessTransport, ProtocolT]: + if not isinstance(command, list): + command = [command] + + if setpgrp: + try: + # Windows. 
class BaseCommand(Generic[T]):
    """
    Base class for a single command sent to an engine.

    A command moves through the states of :class:`CommandState`. ``result``
    resolves with the command's return value (or exception), while
    ``finished`` resolves once the engine is done with the command, which may
    be later than the result. Subclasses override :func:`start`,
    :func:`line_received`, :func:`cancel` and optionally
    :func:`check_initialized` and :func:`engine_terminated`.
    """

    def __init__(self, engine: Protocol) -> None:
        self._engine = engine

        self.state = CommandState.NEW

        self.result: asyncio.Future[T] = asyncio.Future()
        self.finished: asyncio.Future[None] = asyncio.Future()

        self._finished_callbacks: List[Callable[[], None]] = []

    def add_finished_callback(self, callback: Callable[[], None]) -> None:
        """Registers *callback*; it runs immediately if already finished."""
        self._finished_callbacks.append(callback)
        self._dispatch_finished()

    def _dispatch_finished(self) -> None:
        # Callbacks are consumed (most recently added first), so each one
        # runs at most once.
        if self.finished.done():
            while self._finished_callbacks:
                self._finished_callbacks.pop()()

    def _engine_terminated(self, code: int) -> None:
        """Reacts to the engine process exiting with exit *code*."""
        # -4 (SIGILL) and 0xc000001d (Windows illegal instruction) are
        # flagged with a hint about an incompatible binary.
        hint = ", binary not compatible with cpu?" if code in [-4, 0xc000001d] else ""
        exc = EngineTerminatedError(f"engine process died unexpectedly (exit code: {code}{hint})")
        if self.state == CommandState.ACTIVE:
            self.engine_terminated(exc)
        elif self.state == CommandState.CANCELLING:
            self.finished.set_result(None)
            self._dispatch_finished()
        elif self.state == CommandState.NEW:
            self._handle_exception(exc)

    def _handle_exception(self, exc: Exception) -> None:
        """Fails the command with *exc* and marks it as finished."""
        if not self.result.done():
            self.result.set_exception(exc)
        else:
            # The result was already delivered, so the error can only be
            # reported through the loop's exception handler.
            self._engine.loop.call_exception_handler({  # XXX
                "message": f"{type(self).__name__} failed after returning preliminary result ({self.result!r})",
                "exception": exc,
                "protocol": self._engine,
                "transport": self._engine.transport,
            })

        if not self.finished.done():
            self.finished.set_result(None)
            self._dispatch_finished()

    def set_finished(self) -> None:
        """Marks an active or cancelling command as done."""
        assert self.state in [CommandState.ACTIVE, CommandState.CANCELLING], self.state
        if not self.result.done():
            self.result.set_exception(EngineError(f"engine command finished before returning result: {self!r}"))
        self.state = CommandState.DONE
        self.finished.set_result(None)
        self._dispatch_finished()

    def _cancel(self) -> None:
        # No-op when already cancelling or done.
        if self.state != CommandState.CANCELLING and self.state != CommandState.DONE:
            assert self.state == CommandState.ACTIVE, self.state
            self.state = CommandState.CANCELLING
            self.cancel()

    def _start(self) -> None:
        assert self.state == CommandState.NEW, self.state
        self.state = CommandState.ACTIVE
        try:
            self.check_initialized()
            self.start()
        except EngineError as err:
            self._handle_exception(err)

    def _line_received(self, line: str) -> None:
        assert self.state in [CommandState.ACTIVE, CommandState.CANCELLING], self.state
        try:
            self.line_received(line)
        except EngineError as err:
            self._handle_exception(err)

    def cancel(self) -> None:
        """Hook: asks the engine to stop working on this command."""
        pass

    def check_initialized(self) -> None:
        """Hook: raises :exc:`EngineError` unless the engine is initialized."""
        if not self._engine.initialized:
            raise EngineError("tried to run command, but engine is not initialized")

    def start(self) -> None:
        """Hook: sends the command to the engine. Must be overridden."""
        raise NotImplementedError

    def line_received(self, line: str) -> None:
        """Hook: handles one line of engine output."""
        pass

    def engine_terminated(self, exc: Exception) -> None:
        """Hook: handles engine termination while the command is active."""
        self._handle_exception(exc)

    def __repr__(self) -> str:
        # The opening parenthesis is now balanced (the original emitted "...finished={}>").
        return "<{} at {:#x} (state={}, result={}, finished={})>".format(type(self).__name__, id(self), self.state, self.result, self.finished)
for token in re.split(option_regex, arg.strip()): + if token == "var" or (token in option_parts and not option_parts[token]): + current_parameter = token + elif current_parameter == "var": + var.append(token) + elif current_parameter: + option_parts[current_parameter] = token + + def parse_min_max_value(option_parts: dict[str, str], which: Literal["min", "max"]) -> Optional[int]: + try: + number = option_parts[which] + return int(number) if number else None + except ValueError: + LOGGER.exception(f"Exception parsing option {which}") + return None + + name = option_parts["name"] + type = option_parts["type"] + default = option_parts["default"] + min = parse_min_max_value(option_parts, "min") + max = parse_min_max_value(option_parts, "max") + + without_default = Option(name, type, None, min, max, var) + option = Option(without_default.name, without_default.type, without_default.parse(default), min, max, var) + self.engine.options[option.name] = option + + if option.default is not None: + self.engine.config[option.name] = option.default + if option.default is not None and not option.is_managed() and option.name.lower() != "uci_analysemode": + self.engine.target_config[option.name] = option.default + + def _id(self, arg: str) -> None: + key, value = _next_token(arg) + self.engine.id[key] = value.strip() + + return await self.communicate(UciInitializeCommand) + + def _isready(self) -> None: + self.send_line("isready") + + def _opponent_info(self) -> None: + opponent_info = self.config.get("UCI_Opponent") or self.target_config.get("UCI_Opponent") + if opponent_info: + self.send_line(f"setoption name UCI_Opponent value {opponent_info}") + + def _ucinewgame(self) -> None: + self.send_line("ucinewgame") + self._opponent_info() + self.first_game = False + self.ponderhit = False + + def debug(self, on: bool = True) -> None: + """ + Switches debug mode of the engine on or off. This does not interrupt + other ongoing operations. 
+ """ + if on: + self.send_line("debug on") + else: + self.send_line("debug off") + + async def ping(self) -> None: + class UciPingCommand(BaseCommand[None]): + def __init__(self, engine: UciProtocol) -> None: + super().__init__(engine) + self.engine = engine + + def start(self) -> None: + self.engine._isready() + + @override + def line_received(self, line: str) -> None: + if line.strip() == "readyok": + self.result.set_result(None) + self.set_finished() + else: + LOGGER.warning("%s: Unexpected engine output: %r", self.engine, line) + + return await self.communicate(UciPingCommand) + + def _changed_options(self, options: ConfigMapping) -> bool: + return any(value is None or value != self.config.get(name) for name, value in _chain_config(options, self.target_config)) + + def _setoption(self, name: str, value: ConfigValue) -> None: + try: + value = self.options[name].parse(value) + except KeyError: + raise EngineError("engine does not support option {} (available options: {})".format(name, ", ".join(self.options))) + + if value is None or value != self.config.get(name): + builder = ["setoption name", name] + if value is False: + builder.append("value false") + elif value is True: + builder.append("value true") + elif value is not None: + builder.append("value") + builder.append(str(value)) + + if name != "UCI_Opponent": # sent after ucinewgame + self.send_line(" ".join(builder)) + self.config[name] = value + + def _configure(self, options: ConfigMapping) -> None: + for name, value in _chain_config(options, self.target_config): + if name.lower() in MANAGED_OPTIONS: + raise EngineError("cannot set {} which is automatically managed".format(name)) + self._setoption(name, value) + + async def configure(self, options: ConfigMapping) -> None: + class UciConfigureCommand(BaseCommand[None]): + def __init__(self, engine: UciProtocol): + super().__init__(engine) + self.engine = engine + + def start(self) -> None: + self.engine._configure(options) + 
def _opponent_configuration(self, *, opponent: Optional[Opponent] = None) -> ConfigMapping:
    """
    Translate opponent details into a ``UCI_Opponent`` option mapping.

    Returns an empty mapping when no opponent is given, the opponent has no
    name, or the engine did not declare a ``UCI_Opponent`` option. Otherwise
    the value has the form ``"<title> <rating> <computer|human> <name>"``,
    with ``none`` substituted for a missing title or rating.
    """
    if not opponent or not opponent.name or "UCI_Opponent" not in self.options:
        return {}

    shown_rating = opponent.rating or "none"
    shown_title = opponent.title or "none"
    if opponent.is_engine:
        kind = "computer"
    else:
        kind = "human"
    return {"UCI_Opponent": f"{shown_title} {shown_rating} {kind} {opponent.name}"}
def _go(self, limit: Limit, *, root_moves: Optional[Iterable[cchess.Move]] = None, ponder: bool = False, infinite: bool = False) -> None:
    """
    Build and send a UCI ``go`` command for the given search limit.

    Only the fields of *limit* that are not ``None`` are transmitted. Clock
    and increment values are converted from seconds to milliseconds; the red
    side's values are sent as ``wtime``/``winc`` and the black side's as
    ``btime``/``binc``. Clocks, depth, nodes, mate and movetime are clamped
    to a minimum of 1.

    :param limit: Search limit to translate into ``go`` arguments.
    :param root_moves: If not ``None``, restrict the search with
        ``searchmoves``. Any iterable is accepted, including one-shot
        iterators; an empty one results in ``searchmoves 0000``.
    :param ponder: Add the ``ponder`` flag.
    :param infinite: Add the ``infinite`` flag.
    """
    builder = ["go"]
    if ponder:
        builder.append("ponder")
    if limit.red_clock is not None:
        builder.append("wtime")
        builder.append(str(max(1, round(limit.red_clock * 1000))))
    if limit.black_clock is not None:
        builder.append("btime")
        builder.append(str(max(1, round(limit.black_clock * 1000))))
    if limit.red_inc is not None:
        builder.append("winc")
        builder.append(str(round(limit.red_inc * 1000)))
    if limit.black_inc is not None:
        builder.append("binc")
        builder.append(str(round(limit.black_inc * 1000)))
    if limit.remaining_moves is not None and int(limit.remaining_moves) > 0:
        builder.append("movestogo")
        builder.append(str(int(limit.remaining_moves)))
    if limit.depth is not None:
        builder.append("depth")
        builder.append(str(max(1, int(limit.depth))))
    if limit.nodes is not None:
        builder.append("nodes")
        builder.append(str(max(1, int(limit.nodes))))
    if limit.mate is not None:
        builder.append("mate")
        builder.append(str(max(1, int(limit.mate))))
    if limit.time is not None:
        builder.append("movetime")
        builder.append(str(max(1, round(limit.time * 1000))))
    if infinite:
        builder.append("infinite")
    if root_moves is not None:
        builder.append("searchmoves")
        # Materialise before testing for emptiness: root_moves may be a
        # one-shot iterator, which is truthy even when it yields nothing.
        move_tokens = [move.uci() for move in root_moves]
        # Work around searchmoves followed by nothing.
        builder.extend(move_tokens or ["0000"])
    self.send_line(" ".join(builder))
self.engine.game != game or opponent_changed: + self.engine.game = game + self.engine._ucinewgame() + self.sent_isready = True + self.engine._isready() + else: + self._readyok() + + @override + def line_received(self, line: str) -> None: + token, remaining = _next_token(line) + if token == "info": + self._info(remaining) + elif token == "bestmove": + self._bestmove(remaining) + elif line.strip() == "readyok" and self.sent_isready: + self._readyok() + else: + LOGGER.warning("%s: Unexpected engine output: %r", self.engine, line) + + def _readyok(self) -> None: + self.sent_isready = False + engine._position(board) + engine._go(limit, root_moves=root_moves) + + def _info(self, arg: str) -> None: + if not self.pondering: + self.info.update(_parse_uci_info(arg, self.engine.board, info)) + + def _bestmove(self, arg: str) -> None: + if self.pondering: + self.pondering = None + elif not self.result.cancelled(): + best = _parse_uci_bestmove(self.engine.board, arg) + self.result.set_result(PlayResult(best.move, best.ponder, self.info)) + + if ponder and best.move and best.ponder: + self.pondering = board.copy() + self.pondering.push(best.move) + self.pondering.push(best.ponder) + self.engine._position(self.pondering) + + # Adjust clocks for pondering. 
+ time_used = time.perf_counter() - self.start_time + ponder_limit = copy.copy(limit) + if ponder_limit.red_clock is not None: + ponder_limit.red_clock += (ponder_limit.red_inc or 0.0) + if self.pondering.turn == cchess.RED: + ponder_limit.red_clock -= time_used + if ponder_limit.black_clock is not None: + ponder_limit.black_clock += (ponder_limit.black_inc or 0.0) + if self.pondering.turn == cchess.BLACK: + ponder_limit.black_clock -= time_used + if ponder_limit.remaining_moves: + ponder_limit.remaining_moves -= 1 + + self.engine._go(ponder_limit, ponder=True) + + if not self.pondering: + self.end() + + def end(self) -> None: + engine.may_ponderhit = None + self.set_finished() + + @override + def cancel(self) -> None: + if self.engine.may_ponderhit and self.pondering and self.engine.may_ponderhit.move_stack == self.pondering.move_stack and self.engine.may_ponderhit == self.pondering: + self.engine.ponderhit = True + self.end() + else: + self.engine.send_line("stop") + + @override + def engine_terminated(self, exc: Exception) -> None: + # Allow terminating engine while pondering. 
+ if not self.result.done(): + super().engine_terminated(exc) + + return await self.communicate(UciPlayCommand) + + async def analysis(self, board: cchess.Board, limit: Optional[Limit] = None, *, multipv: Optional[int] = None, game: object = None, info: Info = INFO_ALL, root_moves: Optional[Iterable[cchess.Move]] = None, options: ConfigMapping = {}) -> AnalysisResult: + class UciAnalysisCommand(BaseCommand[AnalysisResult]): + def __init__(self, engine: UciProtocol): + super().__init__(engine) + self.engine = engine + + def start(self) -> None: + self.analysis = AnalysisResult(stop=lambda: self.cancel()) + self.sent_isready = False + + if "Ponder" in self.engine.options: + self.engine._setoption("Ponder", False) + if "UCI_AnalyseMode" in self.engine.options and "UCI_AnalyseMode" not in self.engine.target_config and all(name.lower() != "uci_analysemode" for name in options): + self.engine._setoption("UCI_AnalyseMode", True) + if "MultiPV" in self.engine.options or (multipv and multipv > 1): + self.engine._setoption("MultiPV", 1 if multipv is None else multipv) + + self.engine._configure(options) + + if self.engine.first_game or self.engine.game != game: + self.engine.game = game + self.engine._ucinewgame() + self.sent_isready = True + self.engine._isready() + else: + self._readyok() + + @override + def line_received(self, line: str) -> None: + token, remaining = _next_token(line) + if token == "info": + self._info(remaining) + elif token == "bestmove": + self._bestmove(remaining) + elif line.strip() == "readyok" and self.sent_isready: + self._readyok() + else: + LOGGER.warning("%s: Unexpected engine output: %r", self.engine, line) + + def _readyok(self) -> None: + self.sent_isready = False + self.engine._position(board) + + if limit: + self.engine._go(limit, root_moves=root_moves) + else: + self.engine._go(Limit(), root_moves=root_moves, infinite=True) + + self.result.set_result(self.analysis) + + def _info(self, arg: str) -> None: + 
# Recognises a token that begins with a coordinate move in UCI notation
# (for example "h2e2", "b0c2" or "i9i8"), or that is exactly the null move
# "0000". The board has nine files (a-i) and ten ranks (0-9), so the
# chess ranges a-h / 1-8 would reject legal moves and cut variation lines
# short at the first move touching file i or rank 0/9. Promotion suffixes
# and piece drops do not exist in this game and are therefore not matched.
UCI_REGEX = re.compile(r"^[a-i][0-9][a-i][0-9]|0000\Z")
IndexError): + LOGGER.error("Exception parsing %s from info: %r", parameter, arg) + elif parameter == "time": + try: + time_ms, remaining_line = _next_token(remaining_line) + info["time"] = int(time_ms) / 1000.0 + except (ValueError, IndexError): + LOGGER.error("Exception parsing %s from info: %r", parameter, arg) + elif parameter == "ebf": + try: + number, remaining_line = _next_token(remaining_line) + info["ebf"] = float(number) + except (ValueError, IndexError): + LOGGER.error("Exception parsing %s from info: %r", parameter, arg) + elif parameter == "score" and selector & INFO_SCORE: + try: + kind, remaining_line = _next_token(remaining_line) + value, remaining_line = _next_token(remaining_line) + token, remaining_after_token = _next_token(remaining_line) + if token in ["lowerbound", "upperbound"]: + info[token] = True # type: ignore + remaining_line = remaining_after_token + if kind == "cp": + info["score"] = PovScore(Cp(int(value)), root_board.turn) + elif kind == "mate": + info["score"] = PovScore(Mate(int(value)), root_board.turn) + else: + LOGGER.error("Unknown score kind %r in info (expected cp or mate): %r", kind, arg) + except (ValueError, IndexError): + LOGGER.error("Exception parsing score from info: %r", arg) + elif parameter == "currmove": + try: + current_move, remaining_line = _next_token(remaining_line) + info["currmove"] = cchess.Move.from_uci(current_move) + except (ValueError, IndexError): + LOGGER.error("Exception parsing currmove from info: %r", arg) + elif parameter == "currline" and selector & INFO_CURRLINE: + try: + if "currline" not in info: + info["currline"] = {} + + cpunr_text, remaining_line = _next_token(remaining_line) + cpunr = int(cpunr_text) + currline, remaining_line = _create_variation_line(root_board, remaining_line) + info["currline"][cpunr] = currline + except (ValueError, IndexError): + LOGGER.error("Exception parsing currline from info: %r, position at root: %s", arg, root_board.fen()) + elif parameter == "refutation" and 
selector & INFO_REFUTATION: + try: + if "refutation" not in info: + info["refutation"] = {} + + board = root_board.copy(stack=False) + refuted_text, remaining_line = _next_token(remaining_line) + refuted = board.push_uci(refuted_text) + + refuted_by, remaining_line = _create_variation_line(board, remaining_line) + info["refutation"][refuted] = refuted_by + except (ValueError, IndexError): + LOGGER.error("Exception parsing refutation from info: %r, position at root: %s", arg, root_board.fen()) + elif parameter == "pv" and selector & INFO_PV: + try: + pv, remaining_line = _create_variation_line(root_board, remaining_line) + info["pv"] = pv + except (ValueError, IndexError): + LOGGER.error("Exception parsing pv from info: %r, position at root: %s", arg, root_board.fen()) + elif parameter == "wdl": + try: + wins, remaining_line = _next_token(remaining_line) + draws, remaining_line = _next_token(remaining_line) + losses, remaining_line = _next_token(remaining_line) + info["wdl"] = PovWdl(Wdl(int(wins), int(draws), int(losses)), root_board.turn) + except (ValueError, IndexError): + LOGGER.error("Exception parsing wdl from info: %r", arg) + + return info + +def _parse_uci_bestmove(board: cchess.Board, args: str) -> BestMove: + tokens = args.split() + + move = None + ponder = None + + if tokens and tokens[0] not in ["(none)", "NULL"]: + try: + # AnMon 5.75 uses uppercase letters to denote promotion types. + move = board.push_uci(tokens[0].lower()) + except ValueError as err: + raise EngineError(err) + + try: + # Houdini 1.5 sends NULL instead of skipping the token. 
class UciOptionMap(MutableMapping[str, T]):
    """
    Mapping whose string keys are looked up case-insensitively.

    Entries are stored under the lower-cased key, together with the key as it
    was last written; iteration yields those originally-cased keys.
    """

    def __init__(self, data: Optional[Iterable[Tuple[str, T]]] = None, **kwargs: T) -> None:
        self._store: Dict[str, Tuple[str, T]] = {}
        self.update({} if data is None else data, **kwargs)

    def __setitem__(self, key: str, value: T) -> None:
        folded = key.lower()
        self._store[folded] = (key, value)

    def __getitem__(self, key: str) -> T:
        _, value = self._store[key.lower()]
        return value

    def __delitem__(self, key: str) -> None:
        folded = key.lower()
        del self._store[folded]

    def __iter__(self) -> Iterator[str]:
        return (entry[0] for entry in self._store.values())

    def __len__(self) -> int:
        return len(self._store)

    def __eq__(self, other: object) -> bool:
        # Equal when both sides agree on every key/value pair, in both
        # directions. Anything that cannot be compared this way is left to
        # the other operand.
        try:
            if any(key not in other or other[key] != value for key, value in self.items()):  # type: ignore
                return False
            if any(key not in self or self[key] != value for key, value in other.items()):  # type: ignore
                return False
        except (TypeError, AttributeError):
            return NotImplemented
        return True

    def copy(self) -> UciOptionMap[T]:
        cls = type(self)
        return cls(self._store.values())

    def __copy__(self) -> UciOptionMap[T]:
        return self.copy()

    def __repr__(self) -> str:
        return "{}({!r})".format(type(self).__name__, dict(self.items()))
def __init__(self) -> None:
    """Create the protocol state with empty features/config and the built-in options."""
    super().__init__()

    # (name, type, default) of the options this class registers itself,
    # independent of anything the engine declares.
    predefined = (
        ("random", "check", False),
        ("computer", "check", False),
        ("name", "string", ""),
        ("engine_rating", "spin", 0),
        ("opponent_rating", "spin", 0),
    )

    self.features: Dict[str, Union[int, str]] = {}
    self.id = {}
    self.options = {
        option_name: Option(option_name, option_type, default, None, None, None)
        for option_name, option_type, default in predefined
    }
    self.config: Dict[str, ConfigValue] = {}
    self.target_config: Dict[str, ConfigValue] = {}
    self.board = cchess.Board()
    self.game: object = None
    self.clock_id: object = None
    self.first_game = True
self.engine.features: + self.timeout_handle.cancel() + if self.engine.features.get("done"): + self.end() + + def end(self) -> None: + if not self.engine.features.get("ping", 0): + self.result.set_exception(EngineError("xboard engine did not declare required feature: ping")) + self.set_finished() + return + if not self.engine.features.get("setboard", 0): + self.result.set_exception(EngineError("xboard engine did not declare required feature: setboard")) + self.set_finished() + return + + if not self.engine.features.get("reuse", 1): + LOGGER.warning("%s: Rejecting feature reuse=0", self.engine) + self.engine.send_line("rejected reuse") + if not self.engine.features.get("sigterm", 1): + LOGGER.warning("%s: Rejecting feature sigterm=0", self.engine) + self.engine.send_line("rejected sigterm") + if self.engine.features.get("san", 0): + LOGGER.warning("%s: Rejecting feature san=1", self.engine) + self.engine.send_line("rejected san") + + if "myname" in self.engine.features: + self.engine.id["name"] = str(self.engine.features["myname"]) + + if self.engine.features.get("memory", 0): + self.engine.options["memory"] = Option("memory", "spin", 16, 1, None, None) + self.engine.send_line("accepted memory") + if self.engine.features.get("smp", 0): + self.engine.options["cores"] = Option("cores", "spin", 1, 1, None, None) + self.engine.send_line("accepted smp") + if self.engine.features.get("egt"): + for egt in str(self.engine.features["egt"]).split(","): + name = f"egtpath {egt}" + self.engine.options[name] = Option(name, "path", None, None, None, None) + self.engine.send_line("accepted egt") + + for option in self.engine.options.values(): + if option.default is not None: + self.engine.config[option.name] = option.default + if option.default is not None and not option.is_managed(): + self.engine.target_config[option.name] = option.default + + self.engine.initialized = True + self.result.set_result(None) + self.set_finished() + + return await 
def _variant(self, variant: Optional[str]) -> None:
    """
    Send a ``variant`` command for *variant*.

    The supported names are taken from the comma-separated ``variants``
    feature recorded in ``self.features``.

    :raises EngineError: if *variant* is empty or not among those names.
    """
    declared = self.features.get("variants", "")
    supported = str(declared).split(",")
    if variant and variant in supported:
        self.send_line(f"variant {variant}")
        return

    raise EngineError("unsupported xboard variant: {} (available: {})".format(variant, ", ".join(supported)))
async def ping(self) -> None:
    """
    Send ``ping <n>`` and wait for the matching ``pong <n>`` line.

    While waiting, lines starting with ``#`` are ignored, lines matching
    ``XBOARD_ERROR_REGEX`` raise, and any other output is logged as
    unexpected.

    :raises EngineError: if the engine reports an error before answering.
    """
    class XBoardPingCommand(BaseCommand[None]):
        def __init__(self, engine: XBoardProtocol):
            super().__init__(engine)
            self.engine = engine

        @override
        def start(self) -> None:
            n = id(self) & 0xffff
            self.pong = f"pong {n}"
            self.engine._ping(n)

        @override
        def line_received(self, line: str) -> None:
            if line == self.pong:
                self.result.set_result(None)
                self.set_finished()
            elif line.startswith("#"):
                # Comment/debug output from the engine.
                pass
            elif XBOARD_ERROR_REGEX.match(line):
                # Must be tested before the generic warning: previously this
                # branch was only reached for "#" lines, which the error
                # pattern can never match, so errors were never raised.
                raise EngineError(line)
            else:
                LOGGER.warning("%s: Unexpected engine output: %r", self.engine, line)

    return await self.communicate(XBoardPingCommand)
self.play_result = PlayResult(None, None) + self.stopped = False + self.pong_after_move: Optional[str] = None + self.pong_after_ponder: Optional[str] = None + + # Set game, position and configure. + self.engine._new(board, game, options, opponent) + + # Limit or time control. + clock = limit.red_clock if board.turn else limit.black_clock + increment = limit.red_inc if board.turn else limit.black_inc + if limit.clock_id is None or limit.clock_id != self.engine.clock_id: + self._send_time_control(clock, increment) + self.engine.clock_id = limit.clock_id + if limit.nodes is not None: + if limit.time is not None or limit.red_clock is not None or limit.black_clock is not None or increment is not None: + raise EngineError("xboard does not support mixing node limits with time limits") + + if "nps" not in self.engine.features: + LOGGER.warning("%s: Engine did not explicitly declare support for node limits (feature nps=?)") + elif not self.engine.features["nps"]: + raise EngineError("xboard engine does not support node limits (feature nps=0)") + + self.engine.send_line("nps 1") + self.engine.send_line(f"st {max(1, int(limit.nodes))}") + if limit.depth is not None: + self.engine.send_line(f"sd {max(1, int(limit.depth))}") + if limit.red_clock is not None: + self.engine.send_line("{} {}".format("time" if board.turn else "otim", max(1, round(limit.red_clock * 100)))) + if limit.black_clock is not None: + self.engine.send_line("{} {}".format("otim" if board.turn else "time", max(1, round(limit.black_clock * 100)))) + + if draw_offered and self.engine.features.get("draw", 1): + self.engine.send_line("draw") + + # Start thinking. 
+ self.engine.send_line("post" if info else "nopost") + self.engine.send_line("hard" if ponder else "easy") + self.engine.send_line("go") + + @override + def line_received(self, line: str) -> None: + token, remaining = _next_token(line) + if token == "move": + self._move(remaining.strip()) + elif token == "Hint:": + self._hint(remaining.strip()) + elif token == "pong": + pong_line = f"{token} {remaining.strip()}" + if pong_line == self.pong_after_move: + if not self.result.done(): + self.result.set_result(self.play_result) + if not ponder: + self.set_finished() + elif pong_line == self.pong_after_ponder: + if not self.result.done(): + self.result.set_result(self.play_result) + self.set_finished() + elif f"{token} {remaining.strip()}" == "offer draw": + if not self.result.done(): + self.play_result.draw_offered = True + self._ping_after_move() + elif line.strip() == "resign": + if not self.result.done(): + self.play_result.resigned = True + self._ping_after_move() + elif token in ["1-0", "0-1", "1/2-1/2"]: + if "resign" in line and not self.result.done(): + self.play_result.resigned = True + self._ping_after_move() + elif token.startswith("#"): + pass + elif XBOARD_ERROR_REGEX.match(line): + self.engine.first_game = True # Board state might no longer be in sync + raise EngineError(line) + elif len(line.split()) >= 4 and line.lstrip()[0].isdigit(): + self._post(line) + else: + LOGGER.warning("%s: Unexpected engine output: %r", self.engine, line) + + def _send_time_control(self, clock: Optional[float], increment: Optional[float]) -> None: + if limit.remaining_moves or clock is not None or increment is not None: + base_mins, base_secs = divmod(int(clock or 0), 60) + self.engine.send_line(f"level {limit.remaining_moves or 0} {base_mins}:{base_secs:02d} {increment or 0}") + if limit.time is not None: + self.engine.send_line(f"st {max(0.01, limit.time)}") + + def _post(self, line: str) -> None: + if not self.result.done(): + self.play_result.info = 
_parse_xboard_post(line, self.engine.board, info) + + def _move(self, arg: str) -> None: + if not self.result.done() and self.play_result.move is None: + try: + self.play_result.move = self.engine.board.push_xboard(arg) + except ValueError as err: + self.result.set_exception(EngineError(err)) + else: + self._ping_after_move() + else: + try: + self.engine.board.push_xboard(arg) + except ValueError: + LOGGER.exception("Exception playing unexpected move") + + def _hint(self, arg: str) -> None: + if not self.result.done() and self.play_result.move is not None and self.play_result.ponder is None: + try: + self.play_result.ponder = self.engine.board.parse_xboard(arg) + except ValueError: + LOGGER.exception("Exception parsing hint") + else: + LOGGER.warning("Unexpected hint: %r", arg) + + def _ping_after_move(self) -> None: + if self.pong_after_move is None: + n = id(self) & 0xffff + self.pong_after_move = f"pong {n}" + self.engine._ping(n) + + @override + def cancel(self) -> None: + if self.stopped: + return + self.stopped = True + + if self.result.cancelled(): + self.engine.send_line("?") + + if ponder: + self.engine.send_line("easy") + + n = (id(self) + 1) & 0xffff + self.pong_after_ponder = f"pong {n}" + self.engine._ping(n) + + @override + def engine_terminated(self, exc: Exception) -> None: + # Allow terminating engine while pondering. 
async def analysis(self, board: cchess.Board, limit: Optional[Limit] = None, *, multipv: Optional[int] = None, game: object = None, info: Info = INFO_ALL, root_moves: Optional[Iterable[cchess.Move]] = None, options: ConfigMapping = {}) -> AnalysisResult:
    """
    Start analysing *board* using the XBoard ``analyze`` mode.

    The returned :class:`AnalysisResult` is made available as soon as
    ``analyze`` has been sent; parsed ``post`` lines are then delivered to it
    until the analysis is stopped.

    :param board: Position to analyse; transmitted via ``self._new``.
    :param limit: Optional limit. ``time``, ``nodes``, ``depth`` and ``mate``
        are checked against every parsed post line and stop the analysis
        once reached. Clock limits are rejected.
    :param multipv: Must be ``None``.
    :param game: Passed through to ``self._new``.
    :param info: Selector passed to ``_parse_xboard_post``.
    :param root_moves: If not ``None``, restrict the analysis to these moves
        using ``exclude all`` followed by one ``include`` per move.
    :param options: Passed through to ``self._new``.
    :raises EngineError: if *multipv* or a clock limit is given, if
        *root_moves* is given but the engine did not declare the ``exclude``
        feature, or if the engine reports an error line.
    """
    if multipv is not None:
        raise EngineError("xboard engine does not support multipv")

    if limit is not None and (limit.red_clock is not None or limit.black_clock is not None):
        raise EngineError("xboard analysis does not support clock limits")

    class XBoardAnalysisCommand(BaseCommand[AnalysisResult]):
        def __init__(self, engine: XBoardProtocol):
            super().__init__(engine)
            self.engine = engine

        @override
        def start(self) -> None:
            self.stopped = False
            self.best_move: Optional[cchess.Move] = None
            self.analysis = AnalysisResult(stop=lambda: self.cancel())
            # Set by cancel(); the pong line that marks the end of the analysis.
            self.final_pong: Optional[str] = None

            self.engine._new(board, game, options)

            if root_moves is not None:
                if not self.engine.features.get("exclude", 0):
                    raise EngineError("xboard engine does not support root_moves (feature exclude=0)")

                # Exclude everything, then re-include only the requested moves.
                self.engine.send_line("exclude all")
                for move in root_moves:
                    self.engine.send_line(f"include {self.engine.board.xboard(move)}")

            self.engine.send_line("post")
            self.engine.send_line("analyze")

            # Hand the analysis handle to the caller now; the command itself
            # keeps running until end() is reached.
            self.result.set_result(self.analysis)

            # A time limit is additionally enforced with a timer on the event
            # loop, independent of the times reported in post lines.
            if limit is not None and limit.time is not None:
                self.time_limit_handle: Optional[asyncio.Handle] = self.engine.loop.call_later(limit.time, lambda: self.cancel())
            else:
                self.time_limit_handle = None

        @override
        def line_received(self, line: str) -> None:
            token, remaining = _next_token(line)
            if token.startswith("#"):
                pass
            elif len(line.split()) >= 4 and line.lstrip()[0].isdigit():
                # At least four fields and a leading digit: treat as a post line.
                self._post(line)
            elif f"{token} {remaining.strip()}" == self.final_pong:
                self.end()
            elif XBOARD_ERROR_REGEX.match(line):
                self.engine.first_game = True  # Board state might no longer be in sync
                raise EngineError(line)
            else:
                LOGGER.warning("%s: Unexpected engine output: %r", self.engine, line)

        def _post(self, line: str) -> None:
            post_info = _parse_xboard_post(line, self.engine.board, info)
            self.analysis.post(post_info)

            # Remember the first move of the latest principal variation; it is
            # reported as the best move when the analysis ends.
            pv = post_info.get("pv")
            if pv:
                self.best_move = pv[0]

            # Stop once the first applicable limit has been reached.
            if limit is not None:
                if limit.time is not None and post_info.get("time", 0) >= limit.time:
                    self.cancel()
                elif limit.nodes is not None and post_info.get("nodes", 0) >= limit.nodes:
                    self.cancel()
                elif limit.depth is not None and post_info.get("depth", 0) >= limit.depth:
                    self.cancel()
                elif limit.mate is not None and "score" in post_info:
                    if post_info["score"].relative >= Mate(limit.mate):
                        self.cancel()

        def end(self) -> None:
            if self.time_limit_handle:
                self.time_limit_handle.cancel()

            self.set_finished()
            self.analysis.set_finished(BestMove(self.best_move, None))

        @override
        def cancel(self) -> None:
            # Only the first call has an effect.
            if self.stopped:
                return
            self.stopped = True

            self.engine.send_line(".")
            self.engine.send_line("exit")

            # The matching pong is recognised in line_received() and triggers end().
            n = id(self) & 0xffff
            self.final_pong = f"pong {n}"
            self.engine._ping(n)

        @override
        def engine_terminated(self, exc: Exception) -> None:
            LOGGER.debug("%s: Closing analysis because engine has been terminated (error: %s)", self.engine, exc)

            if self.time_limit_handle:
                self.time_limit_handle.cancel()

            self.analysis.set_exception(exc)

    return await self.communicate(XBoardAnalysisCommand)
def _opponent_configuration(self, *, opponent: Optional[Opponent] = None, engine_rating: Optional[int] = None) -> ConfigMapping:
    """
    Translate opponent details into a mapping of engine options.

    Returns an empty mapping when no *opponent* is given. Otherwise the
    mapping holds ``engine_rating`` (the explicit argument, else the value
    from ``target_config``, else ``0``), ``opponent_rating`` and
    ``computer``. A ``name`` entry (title and name joined by a space) is
    added only when the opponent has a name and the ``name`` feature is
    not disabled.
    """
    if opponent is None:
        return {}

    own_rating = engine_rating or self.target_config.get("engine_rating") or 0

    settings: Dict[str, Union[int, bool, str]] = {}
    settings["engine_rating"] = own_rating
    settings["opponent_rating"] = opponent.rating or 0
    settings["computer"] = opponent.is_engine or False

    # Nameless opponents, or engines that turned the feature off, get no name.
    if not opponent.name or not self.features.get("name", True):
        return settings

    prefix = opponent.title or ''
    settings["name"] = f"{prefix} {opponent.name}".strip()
    return settings
async def quit(self) -> None:
    """Send ``quit`` to the engine and wait for its return code."""
    self.send_line("quit")
    # shield() keeps the return code future itself from being cancelled
    # if the waiting task is cancelled.
    exit_code = asyncio.shield(self.returncode)
    await exit_code
"check": + default = int(params[2]) + elif type in ["string", "file", "path"]: + if len(params) > 2: + default = params[2] + else: + default = "" + elif type == "spin": + default = int(params[2]) + min = int(params[3]) + max = int(params[4]) + + return Option(name, type, default, min, max, var) + + +def _parse_xboard_post(line: str, root_board: cchess.Board, selector: Info = INFO_ALL) -> InfoDict: + # Format: depth score time nodes [seldepth [nps [tbhits]]] pv + info: InfoDict = {} + + # Split leading integer tokens from pv. + pv_tokens = line.split() + integer_tokens = [] + while pv_tokens: + token = pv_tokens.pop(0) + try: + integer_tokens.append(int(token)) + except ValueError: + pv_tokens.insert(0, token) + break + + if len(integer_tokens) < 4: + return info + + # Required integer tokens. + info["depth"] = integer_tokens.pop(0) + cp = integer_tokens.pop(0) + info["time"] = int(integer_tokens.pop(0)) / 100 + info["nodes"] = int(integer_tokens.pop(0)) + + # Score. + if cp <= -100000: + score: Score = Mate(cp + 100000) + elif cp == 100000: + score = MateGiven + elif cp >= 100000: + score = Mate(cp - 100000) + else: + score = Cp(cp) + info["score"] = PovScore(score, root_board.turn) + + # Optional integer tokens. + if integer_tokens: + info["seldepth"] = integer_tokens.pop(0) + if integer_tokens: + info["nps"] = integer_tokens.pop(0) + + while len(integer_tokens) > 1: + # Reserved for future extensions. + integer_tokens.pop(0) + + if integer_tokens: + info["tbhits"] = integer_tokens.pop(0) + + # Principal variation. + pv = [] + board = root_board.copy(stack=False) + for token in pv_tokens: + if token.rstrip(".").isdigit(): + continue + + try: + pv.append(board.push_xboard(token)) + except ValueError: + break + + if not (selector & INFO_PV): + break + info["pv"] = pv + + return info + + +def _next_token(line: str) -> tuple[str, str]: + """ + Get the next token in a whitespace-delimited line of text. + + The result is returned as a 2-part tuple of strings. 
class BestMove:
    """Returned by :func:`cchess.engine.AnalysisResult.wait()`."""

    move: Optional[cchess.Move]
    """The best move according to the engine, or ``None``."""

    ponder: Optional[cchess.Move]
    """The response that the engine expects after *move*, or ``None``."""

    def __init__(self, move: Optional[cchess.Move], ponder: Optional[cchess.Move]):
        self.move = move
        self.ponder = ponder

    def __repr__(self) -> str:
        # The parenthesis opened before "move=" is now closed before ">".
        return "<{} at {:#x} (move={}, ponder={})>".format(
            type(self).__name__, id(self), self.move, self.ponder)
def set_exception(self, exc: Exception) -> None:
    """
    Fail the analysis with *exc* and post the end-of-stream marker.

    Mirrors ``set_finished()``: if the result future is already done, it
    is left untouched instead of raising ``asyncio.InvalidStateError``.
    The end marker is posted either way, so consumers waiting on the
    queue are still woken up.
    """
    if not self._finished.done():
        self._finished.set_exception(exc)
    self._kork()
def empty(self) -> bool:
    """
    Report whether every piece of information posted so far was consumed.

    An empty queue does not mean the analysis is over: while it is still
    running, more information can arrive later.
    """
    if self._seen_kork:
        return self._seen_kork
    # A posted end marker occupies one queue slot without being information.
    unread = self._queue.qsize()
    return unread <= self._posted_kork
+ """ + transport, protocol = await UciProtocol.popen(command, setpgrp=setpgrp, **popen_args) + try: + await protocol.initialize() + except: + transport.close() + raise + return transport, protocol + + +async def popen_xboard(command: Union[str, List[str]], *, setpgrp: bool = False, **popen_args: Any) -> Tuple[asyncio.SubprocessTransport, XBoardProtocol]: + """ + Spawns and initializes an XBoard engine. + + :param command: Path of the engine executable, or a list including the + path and arguments. + :param setpgrp: Open the engine process in a new process group. This will + stop signals (such as keyboard interrupts) from propagating from the + parent process. Defaults to ``False``. + :param popen_args: Additional arguments for + `popen `_. + Do not set ``stdin``, ``stdout``, ``bufsize`` or + ``universal_newlines``. + + Returns a subprocess transport and engine protocol pair. + """ + transport, protocol = await XBoardProtocol.popen(command, setpgrp=setpgrp, **popen_args) + try: + await protocol.initialize() + except: + transport.close() + raise + return transport, protocol + + +async def _async(sync: Callable[[], T]) -> T: + return sync() + + +class SimpleEngine: + """ + Synchronous wrapper around a transport and engine protocol pair. Provides + the same methods and attributes as :class:`cchess.engine.Protocol` + with blocking functions instead of coroutines. + + You may not concurrently modify objects passed to any of the methods. Other + than that, :class:`~cchess.engine.SimpleEngine` is thread-safe. When sending + a new command to the engine, any previous running command will be cancelled + as soon as possible. + + Methods will raise :class:`asyncio.TimeoutError` if an operation takes + *timeout* seconds longer than expected (unless *timeout* is ``None``). + + Automatically closes the transport when used as a context manager. 
+ """ + + def __init__(self, transport: asyncio.SubprocessTransport, protocol: Protocol, *, timeout: Optional[float] = 10.0) -> None: + self.transport = transport + self.protocol = protocol + self.timeout = timeout + + self._shutdown_lock = threading.Lock() + self._shutdown = False + self.shutdown_event = asyncio.Event() + + self.returncode: concurrent.futures.Future[int] = concurrent.futures.Future() + + def _timeout_for(self, limit: Optional[Limit]) -> Optional[float]: + if self.timeout is None or limit is None or limit.time is None: + return None + return self.timeout + limit.time + + @contextlib.contextmanager + def _not_shut_down(self) -> Generator[None, None, None]: + with self._shutdown_lock: + if self._shutdown: + raise EngineTerminatedError("engine event loop dead") + yield + + @property + def options(self) -> MutableMapping[str, Option]: + with self._not_shut_down(): + coro = _async(lambda: copy.copy(self.protocol.options)) + future = asyncio.run_coroutine_threadsafe(coro, self.protocol.loop) + return future.result() + + @property + def id(self) -> Mapping[str, str]: + with self._not_shut_down(): + coro = _async(lambda: self.protocol.id.copy()) + future = asyncio.run_coroutine_threadsafe(coro, self.protocol.loop) + return future.result() + + def communicate(self, command_factory: Callable[[Protocol], BaseCommand[T]]) -> T: + with self._not_shut_down(): + coro = self.protocol.communicate(command_factory) + future = asyncio.run_coroutine_threadsafe(coro, self.protocol.loop) + return future.result() + + def configure(self, options: ConfigMapping) -> None: + with self._not_shut_down(): + coro = asyncio.wait_for(self.protocol.configure(options), self.timeout) + future = asyncio.run_coroutine_threadsafe(coro, self.protocol.loop) + return future.result() + + def send_opponent_information(self, *, opponent: Optional[Opponent] = None, engine_rating: Optional[int] = None) -> None: + with self._not_shut_down(): + coro = asyncio.wait_for( + 
def ping(self) -> None:
    """
    Blocking wrapper around the protocol's ``ping()`` coroutine.

    The coroutine is wrapped in ``asyncio.wait_for`` with the engine
    timeout and scheduled on the protocol's event loop. Only the
    scheduling happens under the shutdown guard; waiting for the result
    does not.
    """
    with self._not_shut_down():
        probe = asyncio.wait_for(self.protocol.ping(), self.timeout)
        pending = asyncio.run_coroutine_threadsafe(probe, self.protocol.loop)
    return pending.result()
def send_game_result(self, board: cchess.Board, winner: Optional[Color] = None, game_ending: Optional[str] = None, game_complete: bool = True) -> None:
    """
    Blocking wrapper around the protocol's ``send_game_result()`` coroutine.

    All arguments are forwarded positionally. The coroutine is wrapped in
    ``asyncio.wait_for`` with the engine timeout and scheduled on the
    protocol's event loop under the shutdown guard; the result is awaited
    after the guard has been released.
    """
    with self._not_shut_down():
        request = self.protocol.send_game_result(board, winner, game_ending, game_complete)
        bounded = asyncio.wait_for(request, self.timeout)
        pending = asyncio.run_coroutine_threadsafe(bounded, self.protocol.loop)
    return pending.result()
+ """ + def _shutdown() -> None: + self.transport.close() + self.shutdown_event.set() + + with self._shutdown_lock: + if not self._shutdown: + self._shutdown = True + self.protocol.loop.call_soon_threadsafe(_shutdown) + + @classmethod + def popen(cls, Protocol: Type[Protocol], command: Union[str, List[str]], *, timeout: Optional[float] = 10.0, debug: Optional[bool] = None, setpgrp: bool = False, **popen_args: Any) -> SimpleEngine: + async def background(future: concurrent.futures.Future[SimpleEngine]) -> None: + transport, protocol = await Protocol.popen(command, setpgrp=setpgrp, **popen_args) + threading.current_thread().name = f"{cls.__name__} (pid={transport.get_pid()})" + simple_engine = cls(transport, protocol, timeout=timeout) + try: + await asyncio.wait_for(protocol.initialize(), timeout) + future.set_result(simple_engine) + returncode = await protocol.returncode + simple_engine.returncode.set_result(returncode) + finally: + simple_engine.close() + await simple_engine.shutdown_event.wait() + + return run_in_background(background, name=f"{cls.__name__} (command={command!r})", debug=debug) + + @classmethod + def popen_uci(cls, command: Union[str, List[str]], *, timeout: Optional[float] = 10.0, debug: Optional[bool] = None, setpgrp: bool = False, **popen_args: Any) -> SimpleEngine: + """ + Spawns and initializes a UCI engine. + Returns a :class:`~cchess.engine.SimpleEngine` instance. + """ + return cls.popen(UciProtocol, command, timeout=timeout, debug=debug, setpgrp=setpgrp, **popen_args) + + @classmethod + def popen_xboard(cls, command: Union[str, List[str]], *, timeout: Optional[float] = 10.0, debug: Optional[bool] = None, setpgrp: bool = False, **popen_args: Any) -> SimpleEngine: + """ + Spawns and initializes an XBoard engine. + Returns a :class:`~cchess.engine.SimpleEngine` instance. 
+ """ + return cls.popen(XBoardProtocol, command, timeout=timeout, debug=debug, setpgrp=setpgrp, **popen_args) + + def __enter__(self) -> SimpleEngine: + return self + + def __exit__(self, exc_type: Optional[Type[BaseException]], exc_value: Optional[BaseException], traceback: Optional[TracebackType]) -> None: + self.close() + + def __repr__(self) -> str: + pid = self.transport.get_pid() # This happens to be thread-safe + return f"<{type(self).__name__} (pid={pid})>" + + +class SimpleAnalysisResult: + """ + Synchronous wrapper around :class:`~cchess.engine.AnalysisResult`. Returned + by :func:`cchess.engine.SimpleEngine.analysis()`. + """ + + def __init__(self, simple_engine: SimpleEngine, inner: AnalysisResult) -> None: + self.simple_engine = simple_engine + self.inner = inner + + @property + def info(self) -> InfoDict: + with self.simple_engine._not_shut_down(): + coro = _async(lambda: self.inner.info.copy()) + future = asyncio.run_coroutine_threadsafe(coro, self.simple_engine.protocol.loop) + return future.result() + + @property + def multipv(self) -> List[InfoDict]: + with self.simple_engine._not_shut_down(): + coro = _async(lambda: [info.copy() for info in self.inner.multipv]) + future = asyncio.run_coroutine_threadsafe(coro, self.simple_engine.protocol.loop) + return future.result() + + def stop(self) -> None: + with self.simple_engine._not_shut_down(): + self.simple_engine.protocol.loop.call_soon_threadsafe(self.inner.stop) + + def wait(self) -> BestMove: + with self.simple_engine._not_shut_down(): + future = asyncio.run_coroutine_threadsafe(self.inner.wait(), self.simple_engine.protocol.loop) + return future.result() + + def would_block(self) -> bool: + with self.simple_engine._not_shut_down(): + future = asyncio.run_coroutine_threadsafe(_async(self.inner.would_block), self.simple_engine.protocol.loop) + return future.result() + + def empty(self) -> bool: + with self.simple_engine._not_shut_down(): + future = 
asyncio.run_coroutine_threadsafe(_async(self.inner.empty), self.simple_engine.protocol.loop) + return future.result() + + def get(self) -> InfoDict: + with self.simple_engine._not_shut_down(): + future = asyncio.run_coroutine_threadsafe(self.inner.get(), self.simple_engine.protocol.loop) + return future.result() + + def next(self) -> Optional[InfoDict]: + with self.simple_engine._not_shut_down(): + future = asyncio.run_coroutine_threadsafe(self.inner.next(), self.simple_engine.protocol.loop) + return future.result() + + def __iter__(self) -> Iterator[InfoDict]: + with self.simple_engine._not_shut_down(): + self.simple_engine.protocol.loop.call_soon_threadsafe(self.inner.__aiter__) + return self + + def __next__(self) -> InfoDict: + try: + with self.simple_engine._not_shut_down(): + future = asyncio.run_coroutine_threadsafe(self.inner.__anext__(), self.simple_engine.protocol.loop) + return future.result() + except StopAsyncIteration: + raise StopIteration + + def __enter__(self) -> SimpleAnalysisResult: + return self + + def __exit__(self, exc_type: Optional[Type[BaseException]], exc_value: Optional[BaseException], traceback: Optional[TracebackType]) -> None: + self.stop() diff --git a/zoo/board_games/chinesechess/envs/cchess/svg.py b/zoo/board_games/chinesechess/envs/cchess/svg.py new file mode 100644 index 000000000..96db3e904 --- /dev/null +++ b/zoo/board_games/chinesechess/envs/cchess/svg.py @@ -0,0 +1,405 @@ +import xml.etree.ElementTree as ET +import os +import copy +import json +import cchess + +from typing import Optional, Tuple, Dict, Union + +SQUARE_SIZE = 100 +MARGIN = 20 + +PIECES = { + "P": """""",# noqa: E501 + "R": """""",# noqa: E501 + "N": """""", # noqa: E501 + "B": """""", # noqa: E501 + "A": """""", # noqa: E501 + "K": """""", # noqa: E501 + "C": """""", # noqa: E501 + "p": """""", # noqa: E501 + "r": """""", # noqa: E501 + "n": """""", # noqa: E501 + "b": """""", # noqa: E501 + "a": """""", # noqa: E501 + "k": """""", # noqa: E501 + "c": """""" 
def _svg(viewbox: tuple, size: Optional[Tuple[Optional[int], Optional[int]]]) -> ET.Element:
    """
    Create the root ``<svg>`` element.

    :param viewbox: ``(x, y, width, height)`` integers for the ``viewBox``
        attribute.
    :param size: ``(width, height)`` of the rendered image, or ``None``.
        A dimension that is ``None`` is left out, so callers that pass
        ``(size, size)`` with an unset size no longer produce
        ``width="None"`` / ``height="None"``.
    :return: The new element, without children.
    """
    x, y, w, h = viewbox
    svg = ET.Element("svg", {
        "xmlns": "http://www.w3.org/2000/svg",
        "xmlns:xlink": "http://www.w3.org/1999/xlink",
        "version": "1.2",
        "baseProfile": "tiny",
        "viewBox": f"{x:d} {y:d} {w:d} {h:d}",
    })

    if size is not None:
        # Index instead of unpacking so any indexable of length >= 2 works,
        # as before.
        width, height = size[0], size[1]
        if width is not None:
            svg.set("width", str(width))
        if height is not None:
            svg.set("height", str(height))

    return svg
+ """ + svg = _svg((0, 0, SQUARE_SIZE, SQUARE_SIZE), (size, size)) + svg.append(ET.fromstring(PIECES[piece.symbol()])) + return SvgWrapper(ET.tostring(svg).decode("utf-8")) + + +def board(board: cchess.BaseBoard, *, + size: Optional[int] = 450, + orientation: cchess.Color = cchess.RED, + coordinates: bool = True, + axes_type: int = 0, + lastmove: Optional[cchess.Move] = None, + checkers: Optional[cchess.IntoSquareSet] = None, + squares: Optional[cchess.IntoSquareSet] = None, + style: Optional[str] = None) -> str: + assert axes_type in [0, 1], f"axes_type must value 0 or 1, got {axes_type}" + # Board + svg = _svg((-600, -600, 1200, 1200), (size, size)) + + if style: + ET.SubElement(svg, "style").text = style + + # def + defs = ET.SubElement(svg, "defs") + mark4 = ET.SubElement(defs, "g", {"id": "mark4"}) + mino32 = ET.SubElement(defs, "g", {"id": "mino32"}) + for piece_color in cchess.COLORS: + for piece_type in cchess.PIECE_TYPES: + if board.pieces_mask(piece_type, piece_color): + defs.append(ET.fromstring(PIECES[cchess.Piece(piece_type, piece_color).symbol()])) + if lastmove: + defs.append( + ET.fromstring("""""")) # noqa: E501 + if squares: + defs.append(ET.fromstring(XX)) + ET.SubElement(svg, "rect", {"id": "board", "width": "1200", "height": "1200", + "x": "-600", "y": "-600", + "stroke-width": "15", "stroke": "#cd853f", + "fill": "#eb5"}) + boarder = ET.SubElement(svg, "g", _attrs({"fill": "none", "stroke-width": 4, "stroke": "#000"})) + ET.SubElement(boarder, "rect", _attrs({"width": 816, "height": 916, "x": -408, "y": -458})) + ET.SubElement(boarder, "rect", _attrs({"width": 800, "height": 900, "x": -400, "y": -450})) + mark2R = ET.SubElement(mark4, "g", {"id": "mark2R"}) + mark2R.append(ET.fromstring("""""")) + ET.SubElement(mark2R, "use", {"xlink:href": "#mark", "transform": "rotate(-90,0,0)"}) + mark2L = ET.SubElement(mark4, "g", {"id": "mark2L"}) + ET.SubElement(mark2L, "use", {"xlink:href": "#mark2R", "transform": "rotate(180)"}) + mino16 = 
ET.SubElement(mino32, "g", {"id": "mino16"}) + ET.SubElement(mino32, "use", {"xlink:href": "#mino16", "transform": "scale(-1,1)"}) + line4 = ET.SubElement(mino16, "g", {"id": "line4"}) + line4.append(ET.fromstring("""""")) + ET.SubElement(line4, "use", _attrs({"xlink:href": "#line", "x": 100})) + ET.SubElement(line4, "use", _attrs({"xlink:href": "#line", "x": 200})) + ET.SubElement(line4, "use", _attrs({"xlink:href": "#line", "x": 300})) + ET.SubElement(mino16, "use", {"xlink:href": "#line4", "transform": "rotate(90,200,200)"}) + ET.SubElement(mino16, "use", _attrs({"xlink:href": "#mark2R", "x": 0, "y": 100})) + ET.SubElement(mino16, "use", _attrs({"xlink:href": "#mark4", "x": 200, "y": 100})) + ET.SubElement(mino16, "use", _attrs({"xlink:href": "#mark4", "x": 300, "y": 200})) + ET.SubElement(mino16, "use", _attrs({"xlink:href": "#mark2L", "x": 400, "y": 100})) + mino16.append(ET.fromstring("""""")) + ET.SubElement(mino16, "use", {"xlink:href": "#diagonal", "transform": "rotate(90,0,300)"}) + board_obj = ET.SubElement(defs, "g", {"id": "board"}) + halfboard = ET.SubElement(board_obj, "g", {"id": "halfboard"}) + ET.SubElement(halfboard, "use", _attrs({"xlink:href": "#mino32", "y": 50})) + ET.SubElement(board_obj, "use", {"xlink:href": "#halfboard", "transform": "rotate(180)"}) + ET.SubElement(svg, "use", {"xlink:href": "#board"}) + + # River + svg.append(ET.fromstring(RIVER)) + + if coordinates: + if axes_type == 0: + # Column Coordinate + col_coord = ET.SubElement(svg, "g") + upper_coord = ET.SubElement(col_coord, "g", {"transform": "translate(-428,-560)", "id": "upper-coord"}) + for index, col in enumerate(cchess.COLUMN_NAMES): + x = index if orientation else 8 - index + y = COORDS_DELTA_Y[col] + col_letter = _coord(col, x * 102, y) + upper_coord.append(col_letter) + ET.SubElement(col_coord, "use", {"xlink:href": "#upper-coord", "transform": "translate(0, 1080)"}) + + # Row Coordinate + row_coord = ET.SubElement(svg, "g") + left_coord = ET.SubElement(row_coord, 
"g", {"transform": "translate(-500,428)", "id": "left-coord"}) + for index, row in enumerate(cchess.ROW_NAMES): + y = index if orientation else 9 - index + row_number = _coord(row, 0, - 102 * y) + left_coord.append(row_number) + ET.SubElement(row_coord, "use", {"xlink:href": "#left-coord", "transform": "translate(970,0)"}) + else: + svg.append(ET.fromstring(TRADITIONAL_COORDINATES[orientation])) + # Pieces + for square, bb in enumerate(cchess.BB_SQUARES): + col_index = cchess.square_column(square) + row_index = cchess.square_row(square) + x = (col_index if orientation else 8 - col_index) * SQUARE_SIZE - 450 + y = (9 - row_index if orientation else row_index) * SQUARE_SIZE - 500 + piece = board.piece_at(square) + if piece: + href = f"#{cchess.COLOR_NAMES[piece.color]}-{cchess.PIECE_NAMES[piece.piece_type]}" + ET.SubElement(svg, "use", { + "href": href, + "xlink:href": href, + "transform": f"translate({x:d}, {y:d})", + }) + if squares is not None and square in squares: + ET.SubElement(svg, "use", _attrs({ + "href": "#xx", + "xlink:href": "#xx", + "x": x + 5, + "y": y + 5 + })) + + # Lastmove + if lastmove is not None: + col_index = cchess.square_column(lastmove.from_square) + row_index = cchess.square_row(lastmove.from_square) + x = (col_index if orientation else 8 - col_index) * SQUARE_SIZE - 450 + y = (9 - row_index if orientation else row_index) * SQUARE_SIZE - 500 + corners = ET.SubElement(svg, "g", {"transform": f"translate({x},{y})"}) + top_corners = ET.SubElement(corners, "g", {"id": "from-top-corners"}) + ET.SubElement(top_corners, "use", {"xlink:href": "#corner"}) + ET.SubElement(top_corners, "use", {"xlink:href": "#corner", + "transform": "translate(100,0) scale(-1,1)"}) + ET.SubElement(corners, "use", {"xlink:href": "#from-top-corners", + "transform": "translate(0, 100) scale(1,-1)"}) + col_index = cchess.square_column(lastmove.to_square) + row_index = cchess.square_row(lastmove.to_square) + x = (col_index if orientation else 8 - col_index) * SQUARE_SIZE - 
def to_gif(board: cchess.Board, filename, *,
           size: Optional[int] = 450,
           duration: int = 3,
           orientation: cchess.Color = cchess.RED,
           coordinates: bool = True,
           axes_type: int = 0,
           lastmove: Optional[bool] = True,
           checkers: Optional[bool] = True,
           style: Optional[str] = None):
    """Render the game recorded on *board* as an animated GIF.

    The first frame shows the position stored at the bottom of the board's
    internal state stack; one further frame is rendered after every move in
    ``board.move_stack``. Each frame is drawn as SVG, rasterised to PNG with
    ``cairosvg`` and collected as a numpy array before ``imageio.mimsave``
    writes them to *filename*.

    Args:
        board: Board whose move history is rendered.
        filename: Output target handed to ``imageio.mimsave``.
        size, orientation, coordinates, axes_type, style: Forwarded unchanged
            to the SVG board renderer for every frame.
        duration: Forwarded to ``imageio.mimsave``.
        lastmove: When truthy, each frame highlights the move just played.
        checkers: When truthy, each frame after the first marks the pieces
            returned by ``Board.checkers()``.

    Returns:
        None. Nothing is written when the board has no moves, or when one of
        the optional dependencies (numpy, Pillow, cairosvg, imageio, tqdm) is
        missing; the latter case now emits a ``RuntimeWarning`` instead of
        failing silently.
    """
    try:
        import copy
        import io

        import cairosvg
        import imageio
        import numpy as np
        import tqdm
        from PIL import Image
    except ImportError as exc:
        # Stay best-effort (no exception), but tell the caller why no file
        # was produced instead of returning without a trace.
        import warnings
        warnings.warn(
            f"to_gif() skipped, optional dependency missing: {exc}",
            RuntimeWarning,
            stacklevel=2,
        )
        return
    if not board.move_stack:
        return

    # Copy the list of saved states so appending the current state does not
    # modify the list held by the caller's board.
    snapshots = copy.copy(getattr(board, "_stack"))
    snapshots.append(getattr(board, "_board_state")())
    replay = cchess.Board()
    snapshots[0].restore(replay)

    def _frame(move=None, **extra):
        """Rasterise the current ``replay`` position into an image array."""
        markup = cchess.svg.board(replay, size=size,
                                  orientation=orientation,
                                  coordinates=coordinates,
                                  axes_type=axes_type,
                                  lastmove=move,
                                  style=style,
                                  **extra)
        return np.array(Image.open(io.BytesIO(cairosvg.svg2png(markup))))

    # Opening frame: no last-move highlight and no explicit checkers argument.
    frames = [_frame()]
    moves = board.move_stack
    for ply, move in tqdm.tqdm(enumerate(moves), total=len(moves)):
        snapshots[ply + 1].restore(replay)
        frames.append(_frame(
            move if lastmove else None,
            checkers=replay.checkers() if checkers else None,
        ))
    imageio.mimsave(filename, frames, duration=duration)
a/zoo/board_games/chinesechess/envs/cchess_env.py b/zoo/board_games/chinesechess/envs/cchess_env.py new file mode 100644 index 000000000..e3d55ba19 --- /dev/null +++ b/zoo/board_games/chinesechess/envs/cchess_env.py @@ -0,0 +1,870 @@ +""" +cchess库来自:https://github.com/windshadow233/python-chinese-chess/tree/main/cchess +修改了cchess库一处代码,提升位运算速度 +2346行:def popcount(x: BitBoard) -> int: + \"\"\" + 计算 BitBoard 中 1 的个数 + Python 3.10+ 原生 bit_count() 比 bin().count('1') 快 10+ 倍 + \"\"\" + return x.bit_count() + + +pikafish引擎可以自行去下载: +https://github.com/official-pikafish/Pikafish +https://www.pikafish.com/ + +Overview: + 中国象棋环境,封装 cchess 库以适配 LightZero 的 BaseEnv 接口 + 中国象棋是一个双人对弈游戏,棋盘为 9x10(9列10行) + +Mode: + - ``self_play_mode``: 自对弈模式,用于 AlphaZero/MuZero 数据生成 + - ``play_with_bot_mode``: 与内置 bot 对战模式 + - ``eval_mode``: 评估模式 + +Observation Space: + 字典结构,包含以下键: + - ``observation``: shape (N, 10, 9), float32. + - N = 14 * stack_obs_num + 1 = 14 * 4 + 1 = 57 + - 前 56 个通道为 4 帧历史观测堆叠,每一帧包含 14 个特征平面 (7种棋子 x 2种颜色) + - 最后一个通道为当前玩家颜色平面 (全1表示红方/先手,全0表示黑方/后手) + - 采用 Canonical View (规范视角):始终以当前玩家视角观察棋盘 (自己棋子在下方/前7层) + - ``action_mask``: shape (8100,), int8. 合法动作掩码,1表示合法,0表示非法 + - ``board``: shape (10, 9), int8. 棋盘可视化表示,用于调试或渲染 + - ``to_play``: shape (1,), int32. 当前该谁走 (-1: 结束/未知, 0: 黑方, 1: 红方) + +Action Space: + - Discrete(8100). 动作是移动的索引 (from_square * 90 + to_square) + - 棋盘有 90 个位置 (0-89),动作空间涵盖所有可能的起点-终点组合 (90 * 90 = 8100) + - 实际合法动作远小于 8100 (通常几十到一百多) + +Reward Space: + - Box(-1, 1, (1,), float32). + - +1: 当前玩家获胜 (Checkmate) + - -1: 当前玩家失败 (被Checkmate或长将违规) + - 0: 平局 (长闲循环、自然限招、无子可动等) 或 游戏未结束 +""" + +import copy +import os +from typing import List, Any, Tuple, Optional +from collections import deque + +import numpy as np +from ding.envs import BaseEnv, BaseEnvTimestep +from ding.utils import ENV_REGISTRY +from ditk import logging +from easydict import EasyDict +from gymnasium import spaces + +from . 
import cchess + + +def move_to_action(move: cchess.Move) -> int: + """将 Move 对象转换为动作索引""" + return move.from_square * 90 + move.to_square + + +def action_to_move(action: int) -> cchess.Move: + """将动作索引转换为 Move 对象""" + from_square = action // 90 + to_square = action % 90 + return cchess.Move(from_square, to_square) + + +@ENV_REGISTRY.register('cchess') +class ChineseChessEnv(BaseEnv): + config = dict( + env_id="ChineseChess", + battle_mode='self_play_mode', + battle_mode_in_simulation_env='self_play_mode', + render_mode=None, # 'human', 'svg', 'rgb_array' + replay_path=None, + agent_vs_human=False, + prob_random_agent=0, + prob_expert_agent=0, + uci_engine_path=None, # UCI引擎路径,如 'pikafish' 或 '/path/to/pikafish' + engine_depth=5, # 引擎搜索深度,通常1-20,深度越大越强 + channel_last=False, + scale=False, + stop_value=2, + max_episode_steps=500, # 最大回合数限制,防止无限回合 + ) + + @classmethod + def default_config(cls: type) -> EasyDict: + cfg = EasyDict(copy.deepcopy(cls.config)) + cfg.cfg_type = cls.__name__ + 'Dict' + return cfg + + def __init__(self, cfg: dict = None) -> None: + self.cfg = cfg + self.channel_last = cfg.channel_last + self.scale = cfg.scale + + self.render_mode = cfg.render_mode + self.replay_path = cfg.replay_path + + self.battle_mode = cfg.battle_mode + assert self.battle_mode in ['self_play_mode', 'play_with_bot_mode', 'eval_mode'] + self.battle_mode_in_simulation_env = 'self_play_mode' + + self.agent_vs_human = cfg.agent_vs_human + self.prob_random_agent = cfg.prob_random_agent + self.prob_expert_agent = cfg.prob_expert_agent + + # UCI引擎配置 + self.uci_engine_path = cfg.get('uci_engine_path', None) + self.engine_depth = cfg.get('engine_depth', 5) + self.engine = None + + # 初始化UCI引擎(如果配置了) + if self.uci_engine_path: + try: + from .cchess import engine + self.engine = engine.SimpleEngine.popen_uci(self.uci_engine_path) + logging.info(f"UCI引擎加载成功: {self.uci_engine_path}") + except Exception as e: + logging.warning(f"UCI引擎加载失败: {e},将使用随机策略") + self.engine = None + + # 最大步数限制 + 
self.max_episode_steps = cfg.max_episode_steps + self.current_step = 0 + + # 渲染相关 + self.frames = [] # 用于保存渲染图像帧 + + # 初始化棋盘 + self.board = cchess.Board() + + self.players = [1, 2] # 1: 红方(RED), 2: 黑方(BLACK) + self._current_player = 1 + self._env = self + + # 历史观测堆叠 + self.stack_obs_num = 4 + self.obs_buffer = deque(maxlen=self.stack_obs_num) + + # 预计算:Board 棋子遍历所需的查找表 + self._piece_types = [cchess.PAWN, cchess.ROOK, cchess.KNIGHT, cchess.CANNON, + cchess.ADVISOR, cchess.BISHOP, cchess.KING] + self._colors = [cchess.RED, cchess.BLACK] + + # 预计算:BitBoard位索引到(row, col)的映射 + self._square_to_coord = np.array([(s // 9, s % 9) for s in range(90)], dtype=np.int32) + + def _mirror_action(self, action: int) -> int: + """ + 将动作在镜像坐标系统中转换(用于黑方视角转换) + + 当黑方观测被旋转180度时,动作空间也需要相应转换。 + 使用 cchess.square_mirror() 对起点和终点坐标进行镜像。 + + Args: + action: 原始动作索引 (from_square * 90 + to_square) + + Returns: + 镜像后的动作索引 + """ + from_square = action // 90 + to_square = action % 90 + from_square_mirror = cchess.square_mirror(from_square) + to_square_mirror = cchess.square_mirror(to_square) + return from_square_mirror * 90 + to_square_mirror + + def _get_raw_planes(self) -> np.ndarray: + """ + 获取当前棋盘的原始平面表示(固定语义:前7层红方,后7层黑方) + 不包含颜色通道,不进行视角转换 + + 优化: + 使用 lookup table 替代 python scan_forward 循环中的重复除法/取模计算 + 虽然 scan_forward 本身在 Python 中循环,但减少了内部计算 + """ + state = np.zeros((14, 10, 9), dtype=np.float32) + + # 红方棋子 (前7层) + for i, piece_type in enumerate(self._piece_types): + mask = self.board.pieces_mask(piece_type, cchess.RED) + if mask: + # cchess.scan_forward 是 generator,我们手动解开以稍微加速 + # 或者更简单的:获取所有 set bits + # 由于 cchess 库限制,这里还是使用 scan_forward,但后续坐标计算查表 + for square in cchess.scan_forward(mask): + r, c = self._square_to_coord[square] + state[i, r, c] = 1 + + # 黑方棋子 (后7层) + for i, piece_type in enumerate(self._piece_types): + mask = self.board.pieces_mask(piece_type, cchess.BLACK) + if mask: + for square in cchess.scan_forward(mask): + r, c = self._square_to_coord[square] + state[i + 7, r, c] = 1 + 
+ return state + + def _update_obs_buffer(self): + """更新观测缓存""" + planes = self._get_raw_planes() + self.obs_buffer.append(planes) + + def _player_step(self, action: int, flag: str, is_canonical_action: bool = True) -> BaseEnvTimestep: + """ + 执行一步棋 + + Args: + action: 动作索引 + flag: 标识字符串,用于日志记录 + is_canonical_action: 动作是否来自规范视角(Canonical View) + - True: 动作来自策略网络(规范视角),黑方时需要镜像转换 + - False: 动作来自真实棋盘(如bot),不需要转换 + """ + # 关键修复:只有规范视角的动作在黑方时才需要转换 + if self._current_player == 2 and is_canonical_action: # 黑方且是规范视角动作 + action_real = self._mirror_action(action) + else: # 红方 或 非规范视角动作(如bot) + action_real = action + + legal_actions = self.legal_actions + + if action_real not in legal_actions: + logging.warning( + f"非法动作: {action} (real: {action_real}), 合法动作有 {len(legal_actions)} 个。" + f"标志: {flag}, 玩家: {self._current_player}. 随机选择一个合法动作。" + ) + action_real = self.random_action(canonical=False) # 回退时使用真实坐标 + + # 保存执行动作的玩家(用于奖励计算) + acting_player = self._current_player + + move = action_to_move(action_real) # 使用真实坐标 + self.board.push(move) + + # 增加步数计数 + self.current_step += 1 + + # board.push() 已经自动切换了 turn,需要同步更新 _current_player + self._current_player = 1 if self.board.turn else 2 + + # 更新观测历史 + self._update_obs_buffer() + + # 检查游戏是否结束 + done = self.board.is_game_over() + outcome = self.board.outcome() + + # 检查是否达到最大步数 + if self.current_step >= self.max_episode_steps: + done = True + outcome = None # 达到最大步数视为平局 + + # 默认 reward 为 0.0 (游戏未结束或和棋) + reward_scalar = 0.0 + + if done: + # [DEBUG] 详细打印游戏结束原因,排查全和棋问题 + termination_reason = outcome.termination if outcome else "MaxSteps/Unknown" + + if outcome and outcome.winner is not None: + # 有明确的胜者,奖励从执行动作的玩家视角计算 + winner_info = "RED" if outcome.winner == cchess.RED else "BLACK" + if outcome.winner == cchess.RED: + # 红方胜 + reward_scalar = 1.0 if acting_player == 1 else -1.0 + else: + # 黑方胜 + reward_scalar = -1.0 if acting_player == 1 else 1.0 + logging.info(f"[ENV_DEBUG] Game Won! 
Winner: {winner_info}, ActingPlayer: {acting_player}, Reward: {reward_scalar}, Reason: {termination_reason}, Steps: {self.current_step}") + else: + # [优化策略] 统一判和逻辑:循环局面、最大步数、自然限招等均视为和棋 (0.0) + logging.info(f"[ENV_DEBUG] Game Ended. Reason: {termination_reason}, Game DRAW, Steps: {self.current_step}") + + # 对外接口仍然使用 shape (1,) 的 ndarray + reward = np.array([reward_scalar], dtype=np.float32) + + info = {} + obs = self.observe() + + return BaseEnvTimestep(obs, reward, done, info) + + def step(self, action: int) -> BaseEnvTimestep: + """ + 环境的 step 函数 + """ + if self.battle_mode == 'self_play_mode': + if self.prob_random_agent > 0: + if np.random.rand() < self.prob_random_agent: + action = self.random_action(canonical=True) # 规范视角的随机动作 + elif self.prob_expert_agent > 0: + if np.random.rand() < self.prob_expert_agent: + action = self.random_action(canonical=True) # TODO: 可以接入更强的 bot + + flag = "agent" + # 自我对弈模式:动作来自策略网络(规范视角),需要转换 + timestep = self._player_step(action, flag, is_canonical_action=True) + + if timestep.done: + # 【修复】在自我对弈中,使用规范视角(canonical view) + # reward 已经是从执行动作的玩家(当前玩家)视角,直接使用 + # 不需要转换为 player 1 视角,因为观察也是规范视角 + reward_scalar = float(timestep.reward[0]) + timestep.info['eval_episode_return'] = reward_scalar + + return timestep + + elif self.battle_mode == 'play_with_bot_mode': + # 玩家1的回合 (agent) + flag = "bot_agent" + timestep_player1 = self._player_step(action, flag, is_canonical_action=True) + + if timestep_player1.done: + # player 1 执行后游戏结束,reward 已经是 player 1 视角 + timestep_player1.info['eval_episode_return'] = float(timestep_player1.reward[0]) + timestep_player1.obs['to_play'] = np.array([-1], dtype=np.int32) + return timestep_player1 + + # 玩家2(bot)的回合 - bot的动作来自真实棋盘,不需要转换 + bot_action = self.bot_action() # 使用UCI引擎或随机策略 + flag = "bot_bot" + timestep_player2 = self._player_step(bot_action, flag, is_canonical_action=False) + + # player 2 执行后游戏结束,reward 是 player 2 视角,需要转换为 player 1 视角 + reward_scalar = float(timestep_player2.reward[0]) + 
timestep_player2.info['eval_episode_return'] = -reward_scalar + timestep_player2 = timestep_player2._replace(reward=-timestep_player2.reward) + # [修正] 在 eval_mode 下,返回给 agent 的 observation 应该是轮到 agent (Player 1) 走 + # 所以 to_play 应该是 1 (RED),而不是 -1 + timestep_player2.obs['to_play'] = np.array([1], dtype=np.int32) + + return timestep_player2 + + elif self.battle_mode == 'eval_mode': + # 玩家1的回合 (agent) + flag = "eval_agent" + timestep_player1 = self._player_step(action, flag, is_canonical_action=True) + + if timestep_player1.done: + # player 1 执行后游戏结束,reward 已经是 player 1 视角 + timestep_player1.info['eval_episode_return'] = float(timestep_player1.reward[0]) + timestep_player1.obs['to_play'] = np.array([-1], dtype=np.int32) + return timestep_player1 + + # 玩家2的回合 (bot 或 human) - bot/human的动作来自真实棋盘,不需要转换 + if self.agent_vs_human: + bot_action = self.human_to_action() + else: + bot_action = self.bot_action() # 使用UCI引擎或随机策略 + + flag = "eval_bot" + timestep_player2 = self._player_step(bot_action, flag, is_canonical_action=False) + + # player 2 执行后游戏结束,reward 是 player 2 视角,需要转换为 player 1 视角 + reward_scalar = float(timestep_player2.reward[0]) + timestep_player2.info['eval_episode_return'] = -reward_scalar + timestep_player2 = timestep_player2._replace(reward=-timestep_player2.reward) + # [修正] 在 eval_mode 下,返回给 agent 的 observation 应该是轮到 agent (Player 1) 走 + # 所以 to_play 应该是 1 (RED),而不是 -1 + timestep_player2.obs['to_play'] = np.array([1], dtype=np.int32) + + return timestep_player2 + + def reset(self, start_player_index: int = 0, init_state: Optional[str] = None) -> dict: + """ + 重置环境 + """ + if init_state is None: + self.board = cchess.Board() + else: + self.board = cchess.Board(fen=init_state) + + self.players = [1, 2] + self.start_player_index = start_player_index + + # 重置步数计数器 + self.current_step = 0 + + # 清空渲染帧 + self.frames = [] + + # 确保 _current_player 与 board.turn 保持一致 + # board.turn: RED=True, BLACK=False + self._current_player = 1 if self.board.turn else 2 + + # 重置历史观测 
+ self.obs_buffer.clear() + # 填充初始帧 (使用全0或初始状态重复) + init_planes = self._get_raw_planes() + for _ in range(self.stack_obs_num): + self.obs_buffer.append(init_planes) + + # 设置动作空间和观察空间 + self._action_space = spaces.Discrete(90 * 90) # 8100 个可能的动作 + self._reward_space = spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float32) + + # 计算Observation Shape: (14 * stack + 1, 10, 9) + obs_channels = 14 * self.stack_obs_num + 1 + self._observation_space = spaces.Dict( + { + "observation": spaces.Box(low=0, high=1, shape=(obs_channels, 10, 9), dtype=np.float32), + "action_mask": spaces.Box(low=0, high=1, shape=(90 * 90,), dtype=np.int8), + "board": spaces.Box(low=0, high=7, shape=(10, 9), dtype=np.int8), + "current_player_index": spaces.Box(low=0, high=1, shape=(1,), dtype=np.int32), # 0 或 1 + "to_play": spaces.Box(low=-1, high=2, shape=(1,), dtype=np.int32), # -1, 1, 或 2 + } + ) + + obs = self.observe() + return obs + + def current_state(self) -> Tuple[np.ndarray, np.ndarray]: + """ + 获取当前堆叠和转换后的状态 + """ + # 1. 转换视角 (Canonical View) + # 如果是黑方,需要将红方/黑方通道互换,并旋转棋盘 + stacked_obs = [] + for planes in self.obs_buffer: + if self._current_player == 2: # 黑方 + # 原始: [0-6: 红, 7-13: 黑] + # 目标: [0-6: 黑, 7-13: 红] (视角转换: 己方在前) + red_planes = planes[:7] + black_planes = planes[7:] + + # 交换并旋转 180 度 + # np.rot90(x, 2, axes=(1, 2)) 等价于旋转180度 + new_own = np.rot90(black_planes, 2, axes=(1, 2)) + new_opp = np.rot90(red_planes, 2, axes=(1, 2)) + + transformed_planes = np.concatenate([new_own, new_opp], axis=0) + stacked_obs.append(transformed_planes) + else: # 红方 + # 原始即为目标: [0-6: 红(己), 7-13: 黑(敌)] + stacked_obs.append(planes) + + # 2. 堆叠历史帧 + # shape: (14 * stack, 10, 9) + state = np.concatenate(stacked_obs, axis=0) + + # 3. 添加颜色/ToPlay通道 (1层) + # 在 Canonical View 下,通常网络总是视为"执红先手"视角 + # 但添加一个 feature map 全 1 (current) 或其他标记也是常见的 + # 这里保持原逻辑,如果是 player 1 (Red) 则全1,否则全0? + # 不,既然已经旋转了视角,颜色通道应该表示 "当前是谁的回合" 还是 "我是谁"? + # AlphaZero中,颜色通道是 constant 1 (if P1) or 0 (if P2). 
+ # 但如果视角统一了,这个通道可以帮助区分先后手优势。 + color_plane = np.zeros((1, 10, 9), dtype=np.float32) + if self._current_player == 1: + color_plane[:] = 1.0 + + state = np.concatenate([state, color_plane], axis=0) + + if self.scale: + scale_state = state / 2 # 简单缩放,实际上binary plane不需要 + else: + scale_state = state + + if self.channel_last: + return np.transpose(state, [1, 2, 0]), np.transpose(scale_state, [1, 2, 0]) + else: + return state, scale_state + + def observe(self) -> dict: + """ + 返回观察 + + 关键修复:对于黑方,需要将action_mask也进行镜像转换, + 使其与旋转后的观测空间保持一致。 + """ + legal_actions_list = self.legal_actions + + action_mask = np.zeros(90 * 90, dtype=np.int8) + + # 关键修复:如果是黑方,action_mask 需要镜像 + if self._current_player == 2: # 黑方 + for action in legal_actions_list: + action_mirror = self._mirror_action(action) + action_mask[action_mirror] = 1 + else: # 红方 + for action in legal_actions_list: + action_mask[action] = 1 + + # 获取棋盘的可视化表示 + board_visual = np.zeros((10, 9), dtype=np.int8) + for square in range(90): + piece = self.board.piece_at(square) + if piece: + row = cchess.square_row(square) + col = cchess.square_column(square) + # 棋子类型编码:1-7 + board_visual[row, col] = piece.piece_type + + if self.battle_mode in ['play_with_bot_mode', 'eval_mode']: + return { + "observation": self.current_state()[1], + "action_mask": action_mask, + "board": board_visual, + "current_player_index": np.array([self.players.index(self._current_player)], dtype=np.int32), + "to_play": np.array([-1], dtype=np.int32) + } + else: # self_play_mode + return { + "observation": self.current_state()[1], + "action_mask": action_mask, + "board": board_visual, + "current_player_index": np.array([self.players.index(self._current_player)], dtype=np.int32), + "to_play": np.array([self._current_player], dtype=np.int32) + } + + @property + def legal_actions(self) -> List[int]: + """ + 返回所有合法动作的索引列表 + """ + legal_moves = list(self.board.legal_moves) + return [move_to_action(move) for move in legal_moves] + + def get_done_winner(self) -> 
Tuple[bool, int]: + """ + 检查游戏是否结束并返回胜者 + Returns: + - done: 游戏是否结束 + - winner: 胜者,1 表示红方,2 表示黑方,-1 表示和棋或游戏未结束 + """ + # 检查是否达到最大步数 + if self.current_step >= self.max_episode_steps: + return True, -1 # 达到最大步数,视为平局 + + done = self.board.is_game_over() + if not done: + return False, -1 + + outcome = self.board.outcome() + if outcome is None: + return done, -1 + + if outcome.winner is None: + return True, -1 # 和棋 + elif outcome.winner == cchess.RED: + return True, 1 # 红方胜 + else: + return True, 2 # 黑方胜 + + def get_done_reward(self) -> Tuple[bool, Optional[int]]: + """ + 检查游戏是否结束并从玩家1的视角返回奖励 + """ + done, winner = self.get_done_winner() + if not done: + return False, None + + if winner == 1: + reward = 1 + elif winner == 2: + reward = -1 + else: + reward = 0 + + return done, reward + + def random_action(self, canonical: bool = False) -> int: + """ + 随机选择一个合法动作 + + Args: + canonical: 是否返回规范视角的动作 + - False: 返回真实坐标(默认,用于bot等) + - True: 返回规范视角坐标(用于self_play_mode中的随机agent) + + Returns: + 动作索引(真实坐标或规范视角坐标) + """ + legal_actions_list = self.legal_actions # 真实坐标 + action_real = np.random.choice(legal_actions_list) + + # 如果需要规范视角且当前是黑方,转换为镜像坐标 + if canonical and self._current_player == 2: + return self._mirror_action(action_real) + else: + return action_real + + def bot_action(self) -> int: + """ + 使用UCI引擎或随机策略选择动作 + """ + if self.engine is not None: + try: + from .cchess import engine as engine_module + # 使用引擎计算最佳走法,按深度限制 + limit = engine_module.Limit(depth=self.engine_depth) + result = self.engine.play(self.board, limit) + return move_to_action(result.move) + except Exception as e: + logging.warning(f"引擎调用失败: {e},使用随机策略") + return self.random_action() + else: + return self.random_action() + + def human_to_action(self) -> int: + """ + 从人类输入获取动作 + """ + print(self.board.unicode(axes=True, axes_type=0)) + while True: + try: + uci = input(f"请输入走法(UCI格式,如 h2e2): ") + move = cchess.Move.from_uci(uci) + action = move_to_action(move) + if action in self.legal_actions: + return action 
+ else: + print("非法走法,请重新输入") + except KeyboardInterrupt: + print("退出") + import sys + sys.exit(0) + except Exception as e: + print(f"输入错误: {e},请重新输入") + + def seed(self, seed: int, dynamic_seed: bool = True) -> None: + self._seed = seed + self._dynamic_seed = dynamic_seed + np.random.seed(self._seed) + + def __repr__(self) -> str: + return "LightZero ChineseChess Env" + + @property + def current_player(self) -> int: + return self._current_player + + @property + def current_player_index(self) -> int: + return 0 if self._current_player == 1 else 1 + + @property + def next_player(self) -> int: + return self.players[0] if self._current_player == self.players[1] else self.players[1] + + @property + def observation_space(self) -> spaces.Space: + return self._observation_space + + @property + def action_space(self) -> spaces.Space: + return self._action_space + + @property + def reward_space(self) -> spaces.Space: + return self._reward_space + + def copy(self) -> 'ChineseChessEnv': + """ + 高效复制环境 + 替代 copy.deepcopy(self),只复制必要的动态状态 + """ + cls = self.__class__ + new_env = cls.__new__(cls) + + # 复制不可变配置 + new_env.cfg = self.cfg + new_env.channel_last = self.channel_last + new_env.scale = self.scale + new_env.render_mode = self.render_mode + new_env.replay_path = self.replay_path + new_env.battle_mode = self.battle_mode + new_env.battle_mode_in_simulation_env = self.battle_mode_in_simulation_env + new_env.agent_vs_human = self.agent_vs_human + new_env.prob_random_agent = self.prob_random_agent + new_env.prob_expert_agent = self.prob_expert_agent + new_env.uci_engine_path = self.uci_engine_path + new_env.engine_depth = self.engine_depth + new_env.max_episode_steps = self.max_episode_steps + new_env.players = self.players + new_env.start_player_index = self.start_player_index + + # 预计算表 + new_env._piece_types = self._piece_types + new_env._colors = self._colors + new_env._square_to_coord = self._square_to_coord + + # 复制动态状态 (需要拷贝) + new_env.current_step = self.current_step + 
new_env._current_player = self._current_player + new_env.frames = [] # frames 一般不需要在 simulate 中复制 + new_env.engine = None # simulator 不需要 engine + new_env._env = new_env + + # 关键:Board 的 copy,cchess.Board.copy() 已经是浅拷贝优化过的 + new_env.board = self.board.copy() + + # 关键:obs_buffer 的 copy + # deque 本身浅拷贝即可,里面的 numpy array 是新的 + new_env.stack_obs_num = self.stack_obs_num + new_env.obs_buffer = copy.copy(self.obs_buffer) + + # 空间定义 + new_env._action_space = self._action_space + new_env._reward_space = self._reward_space + new_env._observation_space = self._observation_space + + return new_env + + def simulate_action(self, action: int) -> Any: + """ + 模拟执行动作并返回新的模拟环境(用于 AlphaZero/MuZero 的 MCTS) + + Args: + action: 动作索引。如果当前是黑方,这个动作是基于镜像坐标系统的,需要转回真实坐标 + """ + # 关键修复:如果是黑方,将镜像动作转换回真实坐标 + if self._current_player == 2: # 黑方 + action_real = self._mirror_action(action) + else: # 红方 + action_real = action + + if action_real not in self.legal_actions: + raise ValueError(f"动作 {action} (real: {action_real}) 不合法,当前玩家: {self._current_player}") + + # 创建新环境 (使用高效拷贝) + new_env = self.copy() + + move = action_to_move(action_real) # 使用真实坐标 + new_env.board.push(move) + + # 增加步数计数 + new_env.current_step += 1 + + # board.push() 已经自动切换了 turn,需要同步更新 _current_player + # board.turn: RED=True(1), BLACK=False(0) + new_env._current_player = 1 if new_env.board.turn else 2 + + # 关键:同步更新历史观测 + new_env._update_obs_buffer() + + return new_env + + @staticmethod + def create_collector_env_cfg(cfg: dict) -> List[dict]: + collector_env_num = cfg.pop('collector_env_num') + cfg = copy.deepcopy(cfg) + return [cfg for _ in range(collector_env_num)] + + @staticmethod + def create_evaluator_env_cfg(cfg: dict) -> List[dict]: + evaluator_env_num = cfg.pop('evaluator_env_num') + cfg = copy.deepcopy(cfg) + cfg.battle_mode = 'eval_mode' + return [cfg for _ in range(evaluator_env_num)] + + def render(self, mode: str = None) -> None: + """ + 渲染棋盘 + + 
根据LightZero官方文档:https://opendilab.github.io/LightZero/tutorials/envs/customize_envs.html + + Args: + mode: 渲染模式 + - 'state_realtime_mode': 实时打印棋盘状态(文本) + - 'image_realtime_mode': 实时显示SVG图像(暂不支持窗口显示) + - 'image_savefile_mode': 保存SVG到frames,游戏结束后可转为文件 + - 'human': 等同于'state_realtime_mode' + - 'svg': 返回SVG字符串(棋类游戏特有) + """ + mode = mode or self.render_mode + + if mode is None: + return None + + # LightZero标准模式:state_realtime_mode + if mode in ['state_realtime_mode', 'human']: + # 实时打印Unicode棋盘到控制台 + print("\n" + "=" * 50) + print(f"步数: {self.current_step} | 当前玩家: {'红方' if self._current_player == 1 else '黑方'}") + print(self.board.unicode(axes=True, axes_type=1)) + print("=" * 50) + return None + + # LightZero标准模式:image_savefile_mode + elif mode == 'image_savefile_mode': + # 保存SVG到frames列表,游戏结束后可用save_render_output转为文件 + try: + from .cchess import svg + last_move = self.board.peek() if self.board.move_stack else None + svg_str = svg.board( + self.board, + lastmove=last_move, + size=400 + ) + self.frames.append(svg_str) + except Exception as e: + logging.warning(f"SVG渲染失败: {e}") + return None + + # LightZero标准模式:image_realtime_mode + elif mode == 'image_realtime_mode': + # 实时显示图像(对于SVG,暂不支持窗口显示) + logging.warning("image_realtime_mode暂不支持实时窗口显示,请使用image_savefile_mode") + return None + + # 棋类游戏特有:直接返回SVG字符串 + elif mode == 'svg': + try: + from .cchess import svg + last_move = self.board.peek() if self.board.move_stack else None + svg_str = svg.board( + self.board, + lastmove=last_move, + size=400 + ) + return svg_str + except Exception as e: + logging.warning(f"SVG渲染失败: {e}") + return None + + # 其他模式 + else: + logging.warning(f"不支持的渲染模式: {mode}") + return None + + def save_render_output(self, replay_path: str = None, format: str = 'svg') -> None: + """ + 保存渲染输出到文件 + + Args: + replay_path: 保存路径,如果为None则使用self.replay_path + format: 保存格式,目前支持'svg' + """ + if not self.frames: + logging.warning("没有可保存的渲染帧") + return + + save_path = replay_path or self.replay_path + if save_path 
is None: + save_path = './replay_output' + + os.makedirs(save_path, exist_ok=True) + + if format == 'svg': + for i, svg_str in enumerate(self.frames): + file_path = os.path.join(save_path, f'step_{i:04d}.svg') + with open(file_path, 'w', encoding='utf-8') as f: + f.write(svg_str) + logging.info(f"已保存 {len(self.frames)} 个SVG文件到 {save_path}") + else: + logging.warning(f"不支持的保存格式: {format}") + + # 清空frames + self.frames = [] + + def close(self) -> None: + """关闭环境,释放资源""" + if self.engine is not None: + try: + self.engine.quit() + logging.info("UCI引擎已关闭") + except Exception as e: + logging.warning(f"关闭引擎时出错: {e}") + finally: + self.engine = None diff --git a/zoo/board_games/chinesechess/envs/test_action_mirror.py b/zoo/board_games/chinesechess/envs/test_action_mirror.py new file mode 100644 index 000000000..ab5311b2c --- /dev/null +++ b/zoo/board_games/chinesechess/envs/test_action_mirror.py @@ -0,0 +1,179 @@ +""" +测试中国象棋环境的动作镜像转换功能 + +验证黑方的观测旋转和动作镜像转换是否正确对应 +""" + +import numpy as np +import sys +sys.path.append('.') + +from zoo.board_games.chinesechess.envs.cchess_env import ChineseChessEnv +from easydict import EasyDict + + +def test_action_mirror(): + """测试动作镜像转换的正确性""" + + # 创建环境 + cfg = EasyDict({ + 'battle_mode': 'self_play_mode', + 'battle_mode_in_simulation_env': 'self_play_mode', + 'render_mode': None, + 'replay_path': None, + 'agent_vs_human': False, + 'prob_random_agent': 0, + 'prob_expert_agent': 0, + 'uci_engine_path': None, + 'engine_depth': 5, + 'channel_last': False, + 'scale': False, + 'stop_value': 2, + 'max_episode_steps': 500, + }) + + env = ChineseChessEnv(cfg) + obs = env.reset() + + print("=" * 80) + print("测试中国象棋环境的动作镜像转换功能") + print("=" * 80) + + # 测试1:初始状态(红方) + print("\n【测试1】初始状态 - 红方回合") + print(f"当前玩家: {env.current_player} (1=红方, 2=黑方)") + print(f"观测形状: {obs['observation'].shape}") + print(f"合法动作数: {obs['action_mask'].sum()}") + + legal_actions = env.legal_actions + print(f"前5个合法动作: {legal_actions[:5]}") + + # 验证action_mask和legal_actions一致 
+ action_mask_indices = np.where(obs['action_mask'] == 1)[0] + print(f"action_mask中的前5个合法动作: {action_mask_indices[:5]}") + assert len(action_mask_indices) == len(legal_actions), "红方: action_mask数量与legal_actions不一致!" + print("✓ 红方: action_mask 与 legal_actions 一致") + + # 测试2:执行一步后切换到黑方 + print("\n【测试2】执行一步后 - 黑方回合") + action = legal_actions[0] + print(f"红方执行动作: {action}") + + timestep = env.step(action) + obs = timestep.obs + + print(f"当前玩家: {env.current_player} (1=红方, 2=黑方)") + print(f"观测形状: {obs['observation'].shape}") + print(f"合法动作数: {obs['action_mask'].sum()}") + + # 获取黑方的合法动作(真实坐标) + legal_actions_black_real = env.legal_actions + print(f"黑方合法动作(真实坐标)前5个: {legal_actions_black_real[:5]}") + + # 获取黑方的action_mask(镜像坐标) + action_mask_indices_black = np.where(obs['action_mask'] == 1)[0] + print(f"黑方action_mask(镜像坐标)前5个: {action_mask_indices_black[:5]}") + + # 验证:将action_mask中的镜像动作转回真实坐标,应该等于legal_actions + action_mask_to_real = [] + for mirror_action in action_mask_indices_black: + real_action = env._mirror_action(mirror_action) + action_mask_to_real.append(real_action) + + action_mask_to_real_sorted = sorted(action_mask_to_real) + legal_actions_sorted = sorted(legal_actions_black_real) + + print(f"\n验证镜像转换:") + print(f" action_mask转回真实坐标后: {action_mask_to_real_sorted[:5]}...") + print(f" legal_actions(真实坐标): {legal_actions_sorted[:5]}...") + + assert action_mask_to_real_sorted == legal_actions_sorted, "黑方: action_mask镜像转换后与legal_actions不一致!" 
+ print("✓ 黑方: action_mask 镜像转换正确") + + # 测试3:验证镜像函数的对称性 + print("\n【测试3】验证镜像函数的对称性") + test_actions = [0, 45, 89, 100, 500, 1000, 8099] + for test_action in test_actions: + mirror_once = env._mirror_action(test_action) + mirror_twice = env._mirror_action(mirror_once) + print(f"动作 {test_action:4d} -> 镜像 {mirror_once:4d} -> 再镜像 {mirror_twice:4d}") + assert mirror_twice == test_action, f"镜像函数不对称!{test_action} != {mirror_twice}" + print("✓ 镜像函数对称性验证通过") + + # 测试4:执行黑方动作并切换回红方 + print("\n【测试4】执行黑方动作后 - 切换回红方") + black_mirror_action = action_mask_indices_black[0] + print(f"黑方执行动作(镜像坐标): {black_mirror_action}") + + # 应该自动转换为真实坐标执行 + timestep = env.step(black_mirror_action) + obs = timestep.obs + + print(f"当前玩家: {env.current_player} (1=红方, 2=黑方)") + print(f"合法动作数: {obs['action_mask'].sum()}") + + # 验证切换回红方后,action_mask又回到真实坐标 + legal_actions_red = env.legal_actions + action_mask_indices_red = np.where(obs['action_mask'] == 1)[0] + + assert sorted(action_mask_indices_red) == sorted(legal_actions_red), "切换回红方后: action_mask与legal_actions不一致!" 
@pytest.mark.envtest
class TestChineseChessEnv:
    """Smoke tests for ``ChineseChessEnv``.

    Three tests play a whole game with ``env.random_action()`` until ``step``
    reports ``done``; the others check the reset observation, the legal-action
    list and ``simulate_action``. Every test closes its environment in a
    ``finally`` clause so a failing assertion does not leave it open.
    """

    @staticmethod
    def _make_env(battle_mode):
        """Build an environment with the settings shared by every test.

        Args:
            battle_mode: Value stored under the ``battle_mode`` config key.

        Returns:
            A newly constructed ``ChineseChessEnv``; ``reset()`` is left to
            the caller.
        """
        cfg = EasyDict(
            battle_mode=battle_mode,
            channel_last=False,
            scale=False,
            agent_vs_human=False,
            prob_random_agent=0,
            prob_expert_agent=0,
            render_mode=None,
            replay_path=None,
            uci_engine_path=None,
            engine_depth=5,
            max_episode_steps=200,
        )
        return ChineseChessEnv(cfg)

    @staticmethod
    def _print_banner(title):
        """Print ``title`` framed by two 50-character rules."""
        print('=' * 50)
        print(title)
        print('=' * 50)

    def _run_agent_vs_bot(self, battle_mode, title):
        """Play one random-agent game in a mode where the env owns the opponent.

        Args:
            battle_mode: ``'play_with_bot_mode'`` or ``'eval_mode'``.
            title: Banner text printed before the game starts.
        """
        env = self._make_env(battle_mode)
        try:
            env.reset()
            self._print_banner(title)
            env.render(mode='human')

            step_count = 0
            done = False
            while not done:
                # The agent plays red.
                action = env.random_action()
                print(f'第 {step_count + 1} 步 - Agent (红方) 走棋')
                _, reward, done, info = env.step(action)
                env.render(mode='human')
                # Per the original author's note, step() also plays the bot's
                # reply in these modes, so each call is counted as two plies
                # (an approximation, hence "约" in the summary line).
                step_count += 2

            # Prefer the episode return reported by the env, if present.
            eval_return = info.get('eval_episode_return', reward)
            if eval_return > 0:
                print('Agent (红方) 获胜!')
            elif eval_return < 0:
                print('Bot (黑方) 获胜!')
            else:
                print('和棋!')

            print(f'游戏结束,共约 {step_count} 步')
        finally:
            env.close()

    def test_self_play_mode(self):
        """Play a full self-play game with random moves for both sides."""
        env = self._make_env('self_play_mode')
        try:
            env.reset()
            self._print_banner('自对弈模式测试')
            env.render(mode='human')

            step_count = 0
            done = False
            while not done:
                # Red moves on even plies, black on odd plies.
                if step_count % 2 == 0:
                    mover, opponent = '红方', '黑方'
                else:
                    mover, opponent = '黑方', '红方'
                action = env.random_action()
                print(f'第 {step_count + 1} 步 - {mover}走棋')
                _, reward, done, _ = env.step(action)
                env.render(mode='human')
                step_count += 1

            # The final reward is read from the side that just moved:
            # positive means the mover won, negative means the opponent won.
            if reward > 0:
                print(f'{mover}获胜!')
            elif reward < 0:
                print(f'{opponent}获胜!')
            else:
                print('和棋!')

            print(f'游戏结束,共 {step_count} 步')
        finally:
            env.close()

    def test_play_with_bot_mode(self):
        """Play a full game in ``play_with_bot_mode`` (agent vs. bot)."""
        self._run_agent_vs_bot('play_with_bot_mode', '人机对战模式测试 (Agent vs Random Bot)')

    def test_eval_mode(self):
        """Play a full game in ``eval_mode``."""
        self._run_agent_vs_bot('eval_mode', '评估模式测试')

    def test_observation_space(self):
        """Check the shapes of the arrays in the reset observation."""
        env = self._make_env('self_play_mode')
        try:
            obs = env.reset()

            self._print_banner('观测空间测试')
            print(f"observation shape: {obs['observation'].shape}")
            print(f"action_mask shape: {obs['action_mask'].shape}")
            print(f"action_mask sum (合法动作数): {obs['action_mask'].sum()}")
            print(f"board shape: {obs['board'].shape}")
            print(f"to_play: {obs['to_play']}")
            print(f"current_player_index: {obs['current_player_index']}")

            assert obs['observation'].shape == (57, 10, 9), f"Expected (57, 10, 9), got {obs['observation'].shape}"
            assert obs['action_mask'].shape == (8100,), f"Expected (8100,), got {obs['action_mask'].shape}"
            assert obs['board'].shape == (10, 9), f"Expected (10, 9), got {obs['board'].shape}"

            print('观测空间测试通过!')
        finally:
            env.close()

    def test_legal_actions(self):
        """Check that the opening position offers at least one legal action."""
        env = self._make_env('self_play_mode')
        try:
            env.reset()

            self._print_banner('合法动作测试')

            legal_actions = env.legal_actions
            print(f'初始局面合法动作数: {len(legal_actions)}')
            print(f'前10个合法动作索引: {legal_actions[:10]}')

            # The original author notes red has 44 legal moves in the opening
            # position; only non-emptiness is asserted here.
            assert len(legal_actions) > 0, "合法动作数不应为0"
            print('合法动作测试通过!')
        finally:
            env.close()

    def test_simulate_action(self):
        """Check that ``simulate_action`` advances a copy, not the original."""
        env = self._make_env('self_play_mode')
        try:
            env.reset()

            self._print_banner('MCTS 模拟动作测试')

            # Snapshot the state of the original environment.
            original_step = env.current_step
            original_player = env.current_player

            # Run the simulation.
            action = env.random_action()
            simulated_env = env.simulate_action(action)

            # The original environment must be untouched.
            assert env.current_step == original_step, "原环境步数被修改"
            assert env.current_player == original_player, "原环境玩家被修改"

            # The simulated environment must have advanced by one ply.
            assert simulated_env.current_step == original_step + 1, "模拟环境步数未更新"
            assert simulated_env.current_player != original_player, "模拟环境玩家未切换"

            print(f'原环境步数: {env.current_step}, 模拟环境步数: {simulated_env.current_step}')
            print(f'原环境玩家: {env.current_player}, 模拟环境玩家: {simulated_env.current_player}')
            print('MCTS 模拟动作测试通过!')
        finally:
            env.close()
def play_human_vs_bot(engine_path: str = None):
    """Play an interactive game: the human is red (moves first), the bot black.

    Args:
        engine_path: Path to a UCI engine such as pikafish, passed to the
            environment as ``uci_engine_path``. According to the original
            author's note, the bot plays randomly when this is ``None``.
    """
    cfg = EasyDict(
        battle_mode='eval_mode',
        channel_last=False,
        scale=False,
        agent_vs_human=False,  # False: the bot is black; True: the human is black
        prob_random_agent=0,
        prob_expert_agent=0,
        render_mode=None,
        replay_path=None,
        uci_engine_path=engine_path,  # point at pikafish for a stronger bot
        engine_depth=10,
        max_episode_steps=500,
    )
    env = ChineseChessEnv(cfg)
    # try/finally so the environment is closed even when the player aborts
    # with Ctrl-C or a step raises.
    try:
        env.reset()

        bot_name = "UCI引擎" if engine_path else "随机Bot"
        print('=' * 60)
        print(f'人类 vs {bot_name} 对战')
        print('你执红方 (先手),Bot 执黑方 (后手)')
        print('走法格式: UCI 格式,如 h2e2 (炮二平五)')
        print('棋盘坐标: 列 a-i (左到右), 行 0-9 (下到上)')
        print('=' * 60)
        env.render(mode='human')

        step_count = 0
        done = False
        while not done:
            # Read the human's (red) move.
            action = env.human_to_action()
            print(f'\n第 {step_count + 1} 步 - 你 (红方) 走棋')
            # Per the original author's note, step() also plays the bot's
            # (black) reply internally, so two plies are counted per call.
            _, reward, done, info = env.step(action)
            step_count += 2
            # NOTE(review): this line is printed even if the human's move
            # already ended the game -- confirm whether that matters.
            print(f'第 {step_count} 步 - Bot (黑方) 走棋')
            env.render(mode='human')

        # Prefer the episode return reported by the env, if present.
        eval_return = info.get('eval_episode_return', reward)
        if eval_return > 0:
            print('\n恭喜!你 (红方) 获胜!')
        elif eval_return < 0:
            print(f'\n{bot_name} (黑方) 获胜!')
        else:
            print('\n和棋!')

        print(f'\n游戏结束,共 {step_count} 步')
    finally:
        env.close()


def play_bot_vs_bot():
    """Let two random players play one self-play game, rendering every ply."""
    cfg = EasyDict(
        battle_mode='self_play_mode',
        channel_last=False,
        scale=False,
        agent_vs_human=False,
        prob_random_agent=0,
        prob_expert_agent=0,
        render_mode=None,
        replay_path=None,
        uci_engine_path=None,
        engine_depth=10,
        max_episode_steps=500,
    )
    env = ChineseChessEnv(cfg)
    try:
        env.reset()

        print('=' * 60)
        print('Bot vs Bot 对战 (观战模式)')
        print('=' * 60)
        env.render(mode='human')

        step_count = 0
        done = False
        while not done:
            # Red moves on even plies, black on odd plies.
            if step_count % 2 == 0:
                mover, opponent = '红方', '黑方'
            else:
                mover, opponent = '黑方', '红方'
            action = env.random_action()
            _, reward, done, _ = env.step(action)
            step_count += 1
            print(f'\n第 {step_count} 步 - {mover}')
            env.render(mode='human')

        # Positive final reward: the side that just moved won.
        if reward > 0:
            print(f'\n{mover}获胜!')
        elif reward < 0:
            print(f'\n{opponent}获胜!')
        else:
            print('\n和棋!')

        print(f'\n游戏结束,共 {step_count} 步')
    finally:
        env.close()


if __name__ == '__main__':
    import sys

    # Interactive menu: run the automated tests or start one of the demo games.
    print('\n' + '=' * 60)
    print('中国象棋环境测试')
    print('=' * 60)
    print('1. 运行自动化测试')
    print('2. 人类 vs 随机Bot 对战')
    print('3. 人类 vs UCI引擎 对战 (需要输入引擎路径)')
    print('4. Bot vs Bot 观战')
    print('=' * 60)

    choice = input('请选择 (1/2/3/4): ').strip()

    if choice == '1':
        test = TestChineseChessEnv()
        print('\n开始自动化测试...\n')
        # Cheap structural checks first, then the full-game tests; a blank
        # line separates the output of consecutive tests.
        test_methods = (
            test.test_observation_space,
            test.test_legal_actions,
            test.test_simulate_action,
            test.test_self_play_mode,
            test.test_play_with_bot_mode,
            test.test_eval_mode,
        )
        for index, run_test in enumerate(test_methods):
            if index:
                print()
            run_test()
        print('\n所有测试完成!')

    elif choice == '2':
        play_human_vs_bot(engine_path=None)

    elif choice == '3':
        engine_path = input('请输入 UCI 引擎路径 (如 pikafish 或完整路径): ').strip()
        if not engine_path:
            print('引擎路径不能为空,退出')
            sys.exit(1)
        play_human_vs_bot(engine_path=engine_path)

    elif choice == '4':
        play_bot_vs_bot()

    else:
        print('无效选择,退出')
        sys.exit(1)
"""Evaluation entry point for the Chinese-chess MuZero model.

Variables set in the ``__main__`` block:

- model_path: path of the pretrained checkpoint (a ckpt file).
- seeds: list of environment seeds.
- num_episodes_each_seed: number of episodes played per seed.
- total_test_episodes: total number of evaluated episodes.
- returns_mean_seeds: mean return of each seed.
- returns_seeds: per-episode returns of each seed.

Run with:

    python -m LightZero.zoo.board_games.chinesechess.eval.cchess_muzero_eval
"""
from zoo.board_games.chinesechess.config.cchess_muzero_sp_mode_config import main_config, create_config
from lzero.entry import eval_muzero
import numpy as np

if __name__ == '__main__':
    # model_path = './ckpt/ckpt_best.pth.tar'
    model_path = r'./data_muzero/cchess_self-play-mode_seed0/ckpt/iteration_0.pth.tar'
    seeds = [0]
    num_episodes_each_seed = 1

    # When True, a human can play against the agent.
    main_config.env.agent_vs_human = True

    # Rendering mode; set to None to disable rendering.
    main_config.env.render_mode = 'image_realtime_mode'
    main_config.env.replay_path = './video'

    # Single evaluator environment under the 'base' env manager.
    create_config.env_manager.type = 'base'
    main_config.env.evaluator_env_num = 1
    main_config.env.n_evaluator_episode = 1

    total_test_episodes = num_episodes_each_seed * len(seeds)
    returns_mean_seeds = []
    returns_seeds = []

    for seed in seeds:
        returns_mean, returns = eval_muzero(
            [main_config, create_config],
            seed=seed,
            num_episodes_each_seed=num_episodes_each_seed,
            print_seed_details=True,
            model_path=model_path
        )
        returns_mean_seeds.append(returns_mean)
        returns_seeds.append(returns)

    returns_mean_seeds = np.array(returns_mean_seeds)
    returns_seeds = np.array(returns_seeds)

    # Outcomes are classified by the exact episode return: 1 / 0 / -1.
    win_count = np.count_nonzero(returns_seeds == 1.)
    draw_count = np.count_nonzero(returns_seeds == 0.)
    loss_count = np.count_nonzero(returns_seeds == -1.)

    print("=" * 20)
    print(f"总共评估了 {len(seeds)} 个种子。每个种子评估了 {num_episodes_each_seed} 局。")
    print(f"种子 {seeds} 的平均回报为 {returns_mean_seeds},回报为 {returns_seeds}。")
    print("所有种子的平均奖励:", returns_mean_seeds.mean())
    print(
        f'胜率: {win_count / total_test_episodes:.2%}, '
        f'和率: {draw_count / total_test_episodes:.2%}, '
        f'负率: {loss_count / total_test_episodes:.2%}'
    )
    print("=" * 20)