diff --git a/trinity/common/workflows/__init__.py b/trinity/common/workflows/__init__.py
index ea7390b4a4..ad38cb935c 100644
--- a/trinity/common/workflows/__init__.py
+++ b/trinity/common/workflows/__init__.py
@@ -48,6 +48,8 @@
         # on-policy distillation workflows
         "on_policy_distill_workflow": "trinity.common.workflows.on_policy_distill_workflow.OnPolicyDistillWorkflow",
         "on_policy_distill_math_workflow": "trinity.common.workflows.on_policy_distill_workflow.OnPolicyDistillMathWorkflow",
+        # custom workflows
+        "sudoku_workflow": "trinity.common.workflows.sudoku_workflow.SudokuWorkflow",
     },
 )
 
diff --git a/trinity/common/workflows/sudoku_generator.py b/trinity/common/workflows/sudoku_generator.py
new file mode 100644
index 0000000000..761268d6dc
--- /dev/null
+++ b/trinity/common/workflows/sudoku_generator.py
@@ -0,0 +1,101 @@
+import random
+
+
+class SudokuGenerator:
+    """
+    Lightweight Sudoku generator.
+
+    This generator avoids relying on a single canonical solution by applying
+    randomized, validity-preserving transformations to a solved grid before
+    removing values to create a puzzle. The difficulty is controlled by the
+    number of removed cells (holes).
+
+    Generated puzzles are not checked for uniqueness, so a puzzle may admit
+    more than one valid completion.
+    """
+
+    BASE_SOLUTION = [
+        [5, 3, 4, 6, 7, 8, 9, 1, 2],
+        [6, 7, 2, 1, 9, 5, 3, 4, 8],
+        [1, 9, 8, 3, 4, 2, 5, 6, 7],
+        [8, 5, 9, 7, 6, 1, 4, 2, 3],
+        [4, 2, 6, 8, 5, 3, 7, 9, 1],
+        [7, 1, 3, 9, 2, 4, 8, 5, 6],
+        [9, 6, 1, 5, 3, 7, 2, 8, 4],
+        [2, 8, 7, 4, 1, 9, 6, 3, 5],
+        [3, 4, 5, 2, 8, 6, 1, 7, 9],
+    ]
+
+    @staticmethod
+    def _shuffled_indices():
+        """
+        Return a random order of 0..8 that keeps each band of three intact.
+
+        The three bands (0-2, 3-5, 6-8) are reordered as whole units and the
+        indices inside each band are shuffled independently. Applying such an
+        order to rows or columns maps every 3x3 block onto a whole 3x3 block,
+        so block validity is preserved.
+        """
+        bands = [0, 1, 2]
+        random.shuffle(bands)
+
+        order = []
+        for band in bands:
+            members = [band * 3 + offset for offset in range(3)]
+            random.shuffle(members)
+            order.extend(members)
+        return order
+
+    def _shuffle_solution(self, board):
+        """
+        Randomize a solved Sudoku grid while preserving validity.
+
+        This follows common Sudoku generation techniques:
+        - permuting numbers
+        - shuffling rows within a band, and the bands themselves
+        - shuffling columns within a stack, and the stacks themselves
+
+        Rows and columns are never moved freely across bands or stacks:
+        that would keep rows and columns valid but break the 3x3 blocks.
+
+        Args:
+            board: A solved 9x9 grid (list of lists). It is not modified.
+
+        Returns:
+            list: A new 9x9 grid that is also a valid solved Sudoku.
+        """
+        # Shuffle numbers 1-9
+        numbers = list(range(1, 10))
+        shuffled_numbers = numbers[:]
+        random.shuffle(shuffled_numbers)
+        mapping = dict(zip(numbers, shuffled_numbers))
+
+        # Shuffle rows and columns without crossing band/stack boundaries
+        row_order = self._shuffled_indices()
+        col_order = self._shuffled_indices()
+
+        return [[mapping[board[r][c]] for c in col_order] for r in row_order]
+
+    def generate(self, holes=40):
+        """
+        Generate a Sudoku puzzle.
+
+        Args:
+            holes (int): Number of empty cells (0s) in the puzzle.
+                Larger values correspond to higher difficulty.
+                Values outside 0..81 are clamped to that range.
+
+        Returns:
+            tuple: (puzzle, solution)
+        """
+        solution = self._shuffle_solution(self.BASE_SOLUTION)
+        puzzle = [row[:] for row in solution]
+
+        # Sample distinct cells so that exactly `holes` cells are emptied;
+        # drawing coordinates independently can hit the same cell twice.
+        holes = max(0, min(81, holes))
+        for cell in random.sample(range(81), holes):
+            r, c = divmod(cell, 9)
+            puzzle[r][c] = 0
+
+        return puzzle, solution
diff --git a/trinity/common/workflows/sudoku_judge.py b/trinity/common/workflows/sudoku_judge.py
new file mode 100644
index 0000000000..9fee423710
--- /dev/null
+++ b/trinity/common/workflows/sudoku_judge.py
@@ -0,0 +1,60 @@
+class SudokuJudge:
+    """
+    Judge Sudoku board state.
+
+    - Checks row validity
+    - Checks column validity
+    - Checks 3x3 block validity
+    """
+
+    @staticmethod
+    def _has_duplicates(values):
+        """Return True if any non-zero value occurs more than once."""
+        filled = [v for v in values if v != 0]
+        return len(filled) != len(set(filled))
+
+    @staticmethod
+    def is_valid(board):
+        """
+        Check that no row, column, or 3x3 block repeats a non-zero value.
+
+        Empty cells (0) are ignored, so a partially filled board can be valid.
+
+        Args:
+            board: 9x9 grid (list of lists) of integers.
+
+        Returns:
+            bool: True if the board contains no conflicts.
+        """
+        # Check rows
+        for row in board:
+            if SudokuJudge._has_duplicates(row):
+                return False
+
+        # Check columns
+        for col in range(9):
+            if SudokuJudge._has_duplicates(board[row][col] for row in range(9)):
+                return False
+
+        # Check 3x3 sub-grids
+        for br in range(0, 9, 3):
+            for bc in range(0, 9, 3):
+                block = [
+                    board[r][c]
+                    for r in range(br, br + 3)
+                    for c in range(bc, bc + 3)
+                ]
+                if SudokuJudge._has_duplicates(block):
+                    return False
+
+        return True
+
+    @staticmethod
+    def is_complete(board):
+        """Return True if the board has no empty (0) cells left."""
+        return all(v != 0 for row in board for v in row)
+
+    @staticmethod
+    def is_solved(board, solution):
+        """Return True if the board is identical to the reference solution."""
+        return board == solution
diff --git a/trinity/common/workflows/sudoku_workflow.py b/trinity/common/workflows/sudoku_workflow.py
new file mode 100644
index 0000000000..dcc1df28dd
--- /dev/null
+++ b/trinity/common/workflows/sudoku_workflow.py
@@ -0,0 +1,185 @@
+import re
+
+from trinity.common.experience import Experience
+from trinity.common.workflows.workflow import Workflow
+
+from .sudoku_generator import SudokuGenerator
+from .sudoku_judge import SudokuJudge
+
+
+class SudokuWorkflow(Workflow):
+    """
+    Multi-step Sudoku solving workflow.
+
+    This workflow follows a FrozenLake-style agentic interaction pattern:
+    - Maintains an internal environment state (Sudoku board)
+    - Interacts with the model step by step
+    - Provides explicit rules, task description, and strict output format
+    - Gives feedback on invalid or ineffective actions
+    - Terminates on success or failure
+
+    Each step yields one experience: reward -1.0 for an unparsable,
+    ineffective, or rule-breaking move (which ends the episode), 1.0 for the
+    move that completes the board, and 0.0 for any other accepted move.
+    """
+
+    can_reset = True
+
+    # Lower bound on the per-episode step budget.
+    DEFAULT_MAX_STEPS = 20
+
+    def __init__(self, task, model, auxiliary_models=None):
+        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
+        self.judge = SudokuJudge()
+        self._load_puzzle(task)
+
+    def reset(self, task):
+        """Reset the workflow state for a new task."""
+        self._load_puzzle(task)
+
+    def _load_puzzle(self, task):
+        """
+        Initialize the board, solution, and step tracking from a task.
+
+        Uses the "puzzle" and "solution" entries of ``task.raw_task`` when both
+        are present; otherwise a fresh puzzle is generated. Shared by
+        ``__init__`` and ``reset`` so both treat tasks the same way.
+        """
+        raw_task = task.raw_task or {}
+        if "puzzle" in raw_task and "solution" in raw_task:
+            self.board = [row[:] for row in raw_task["puzzle"]]
+            self.solution = [row[:] for row in raw_task["solution"]]
+        else:
+            self.board, self.solution = SudokuGenerator().generate()
+
+        # Every accepted move fills exactly one cell, so the budget must cover
+        # all empty cells or the puzzle could never be completed.
+        empty_cells = sum(row.count(0) for row in self.board)
+        self.max_steps = max(self.DEFAULT_MAX_STEPS, empty_cells)
+
+        # State tracking (FrozenLake-style)
+        self.current_step = 0
+        self.last_board = None
+        self.last_action = None
+
+    def _build_prompt(self):
+        """
+        Build a detailed, step-aware prompt inspired by the Frozen Lake example.
+        """
+        prompt = (
+            "You are playing a Sudoku game.\n\n"
+            "Game Rules:\n"
+            "- The board is a 9x9 grid.\n"
+            "- A value of 0 represents an empty cell.\n"
+            "- Each row must contain the numbers 1 through 9 exactly once.\n"
+            "- Each column must contain the numbers 1 through 9 exactly once.\n"
+            "- Each 3x3 sub-grid must contain the numbers 1 through 9 exactly once.\n"
+            "- You may only place numbers in empty cells.\n\n"
+            "Task:\n"
+            "- At each step, output ONE valid move to progress toward solving the puzzle.\n\n"
+            "Output Format (STRICT):\n"
+            "```row col value```\n\n"
+            "Example:\n"
+            "```0 2 4```\n\n"
+            f"Current Step: {self.current_step}\n"
+            f"Remaining Steps: {self.max_steps - self.current_step}\n\n"
+            f"Current Board:\n{self.board}\n"
+        )
+
+        # NOTE(review): run() stops at the first invalid or ineffective move,
+        # so this only fires if run() is called again without reset().
+        if self.last_board is not None and self.board == self.last_board:
+            prompt += (
+                "\nYour last response was invalid or had no effect. "
+                "Please recheck the Sudoku rules and the required output format."
+            )
+
+        return prompt
+
+    def parse_action(self, text):
+        """
+        Parse model output.
+
+        Expected format:
+        ```row col value```
+
+        Returns:
+            tuple | None: (row, col, value) with row/col in 0..8 and value in
+            1..9, taken from the last fenced span; None if it is malformed.
+        """
+        matches = re.findall(r"```(.*?)```", text, re.DOTALL)
+        if not matches:
+            return None
+
+        try:
+            parts = matches[-1].strip().split()
+            if len(parts) != 3:
+                return None
+
+            r, c, v = map(int, parts)
+            if not (0 <= r <= 8 and 0 <= c <= 8 and 1 <= v <= 9):
+                return None
+
+            return r, c, v
+        except ValueError:
+            return None
+
+    def apply_move(self, r, c, v):
+        """Apply a move to the board if the cell is empty."""
+        if self.board[r][c] == 0:
+            self.board[r][c] = v
+
+    def _make_experience(self, resp, reward):
+        """Wrap one model response and its reward in an Experience."""
+        return Experience(
+            tokens=resp.tokens,
+            prompt_length=resp.prompt_length,
+            reward=reward,
+            logprobs=resp.logprobs,
+        )
+
+    def run(self):
+        """
+        Execute the Sudoku workflow step by step.
+
+        Returns:
+            list: One Experience per model response, in order.
+        """
+        experiences = []
+
+        for _ in range(self.max_steps):
+            prompt = self._build_prompt()
+
+            responses = self.model.chat([{"role": "user", "content": prompt}])
+            resp = responses[0]
+
+            self.last_board = [row[:] for row in self.board]
+
+            action = self.parse_action(resp.response_text)
+            if action is None:
+                experiences.append(self._make_experience(resp, -1.0))
+                break
+
+            r, c, v = action
+            self.apply_move(r, c, v)
+
+            # Invalid or ineffective action
+            if self.board == self.last_board or not self.judge.is_valid(self.board):
+                experiences.append(self._make_experience(resp, -1.0))
+                break
+
+            # Solved. The board is conflict-free at this point, so a full board
+            # is a correct answer even when it differs from the stored solution
+            # (puzzles are not guaranteed to have a unique completion).
+            matches_reference = self.judge.is_solved(self.board, self.solution)
+            if matches_reference or self.judge.is_complete(self.board):
+                experiences.append(self._make_experience(resp, 1.0))
+                break
+
+            # Intermediate step
+            experiences.append(self._make_experience(resp, 0.0))
+
+            self.last_action = action
+            self.current_step += 1
+
+        return experiences