Pong-misc: TypeError: select cases must have the same shapes, got [(30, 40), ()]. #76

Open
HelgeS opened this issue May 15, 2024 · 0 comments


HelgeS commented May 15, 2024

When running the Pong-misc environment, the following error is raised from move_paddles.

I tried both the example notebook and gymnax-blines to rule out a usage error on my side.

Below are the stack trace and the gymnax-blines configuration I used.

$ python train.py -config agents/Pong-misc/ppo.yaml

PPO:   0%|          | 0/18751 [00:00<?, ?it/s]
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/helge/Sandbox/pt/gymnax-blines/train.py", line 76, in <module>
    main(
  File "/home/helge/Sandbox/pt/gymnax-blines/train.py", line 24, in main
    log_steps, log_return, network_ckpt = train_fn(
  File "/home/helge/Sandbox/pt/gymnax-blines/utils/ppo.py", line 271, in train_ppo
    train_state, obs, state, batch, rng_step = get_transition(
  File "/home/helge/Sandbox/pt/gymnax-blines/utils/ppo.py", line 252, in get_transition
    next_obs, next_state, reward, done, _ = rollout_manager.batch_step(
  File "/home/helge/Sandbox/pt/gymnax-blines/utils/ppo.py", line 138, in batch_step
    return jax.vmap(self.env.step, in_axes=(0, 0, 0, None))(
  File "/home/helge/Sandbox/pt/code/.venv/lib/python3.10/site-packages/gymnax/environments/environment.py", line 45, in step
    obs_st, state_st, reward, done, info = self.step_env(key, state, action, params)
  File "/home/helge/Sandbox/pt/code/.venv/lib/python3.10/site-packages/gymnax/environments/misc/pong.py", line 75, in step_env
    state = move_paddles(
  File "/home/helge/Sandbox/pt/code/.venv/lib/python3.10/site-packages/gymnax/environments/misc/pong.py", line 356, in move_paddles
    new_center_p2 = jax.lax.select(use_ai_policy, new_center_ai, new_center_self)
TypeError: select cases must have the same shapes, got [(30, 40), ()].
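The message comes from a constraint of jax.lax.select itself: on_true and on_false must have identical shapes (and dtypes), and unlike jnp.where it does not broadcast. A minimal sketch that triggers the same TypeError outside gymnax is below. The arrays are stand-ins chosen to match the shapes in the message, not the actual values computed inside move_paddles, and the helper select_or_error is hypothetical.

```python
import jax
import jax.numpy as jnp


def select_or_error(pred, on_true, on_false):
    """Call jax.lax.select; return (result, None) on success or (None, message) on TypeError."""
    try:
        return jax.lax.select(pred, on_true, on_false), None
    except TypeError as err:
        return None, str(err)


pred = jnp.array(True)             # scalar predicate, playing the role of use_ai_policy
scalar_center = jnp.array(15.0)    # stand-in for a scalar paddle center
grid_shaped = jnp.zeros((30, 40))  # stand-in for an operand that ended up with shape (30, 40)

# Mismatched operand shapes: lax.select raises a "select cases must have the same shapes" TypeError.
_, mismatch_msg = select_or_error(pred, scalar_center, grid_shaped)
print(mismatch_msg)

# Matching operand shapes: lax.select succeeds and returns the on_true branch.
ok_result, _ = select_or_error(pred, scalar_center, jnp.array(20.0))
print(ok_result)

# jnp.where broadcasts instead of raising, so the wrong shape propagates silently.
broadcast_result = jnp.where(pred, scalar_center, grid_shaped)
print(broadcast_result.shape)  # (30, 40)
```

Swapping in jnp.where would make the error disappear, but only by turning the paddle center into a (30, 40) array, so it would hide the bug rather than fix it. Presumably the real fix is to make sure both branches in move_paddles reduce to a scalar, which depends on parts of pong.py not shown in this report.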

Configuration (copied from CartPole-v1):

train_config:
  train_type: "PPO"
  num_train_steps: 150000
  evaluate_every_epochs: 1000

  env_name: "Pong-misc"
  env_kwargs: {}
  env_params: {}
  num_test_rollouts: 164
  
  num_train_envs: 8  # Number of parallel env workers
  max_grad_norm: 0.5  # Global norm to clip gradients by
  gamma: 0.99  # Discount factor
  n_steps: 32 # "GAE n-steps"
  n_minibatch: 4 # "Number of PPO minibatches"
  lr_begin: 5e-04  # Start PPO learning rate
  lr_end: 5e-04 #  End PPO learning rate
  lr_warmup: 0.05 # Prop epochs until warmup is completed 
  epoch_ppo: 4  # "Number of PPO epochs on a single batch"
  clip_eps: 0.2 # "Clipping range"
  gae_lambda: 0.95 # "GAE lambda"
  entropy_coeff: 0.01 # "Entropy loss coefficient"
  critic_coeff: 0.5  # "Value loss coefficient"

  network_name: "Categorical-MLP"
  network_config:
    num_hidden_units: 64
    num_hidden_layers: 2

log_config:
  time_to_track: ["num_steps"]
  what_to_track: ["return"]
  verbose: false
  print_every_k_updates: 1
  overwrite: 1
  model_type: "jax"

device_config:
  num_devices: 1
  device_type: "gpu"