Fix tests for mps support #2005
base: feat/mps-support
Attempt fix ci: only cast reward from float64 to float32
Hello, (In theory, if PyTorch supports MPS properly, you would only need to specify the device)
hey 👋 yes that works, I have also tested A2C on this branch. I'm still a beginner in this so I can't really say if all advanced use cases also work, but I think having the tests passing is a good indicator.
I think most issues are related to NumPy v2, and should be fixed in #2041 too.
@@ -81,7 +81,7 @@ def convert_if_needed(self, state: np.ndarray) -> Union[int, np.ndarray]:
             state = state.astype(np.int32)
             # The internal state is the binary representation of the
             # observed one
-            return int(sum(state[i] * 2**i for i in range(len(state))))
+            return int(sum(int(state[i]) * 2**i for i in range(len(state))))
should not be needed anymore (because of the cast)
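A minimal sketch of why the reviewer considers the extra `int(...)` redundant, assuming the `OverflowError: Python integer 255 out of bounds for int8` failures come from multiplying `int8` elements by Python integers under NumPy v2. The function name `bits_to_int` is hypothetical and only mirrors the logic in the hunk above; it is not the library's method.

```python
import numpy as np


def bits_to_int(state: np.ndarray) -> int:
    """Interpret a 0/1 vector as a little-endian binary number (hypothetical helper)."""
    # Widen first: after this cast, 2**i fits in the element dtype for
    # i < 31, so no per-element int(...) conversion is required.
    state = state.astype(np.int32)
    return int(sum(state[i] * 2**i for i in range(len(state))))


# An 8-bit observation stored as int8, as a bit-flipping environment might emit.
obs = np.array([1, 0, 0, 0, 0, 0, 0, 1], dtype=np.int8)
value = bits_to_int(obs)  # bit 0 and bit 7 are set: 1 + 128
```

Without the `astype(np.int32)` line, the term for `i = 7` would multiply an `int8` scalar by 128, which does not fit in `int8`.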
Description
closes #914
When I started on the base branch `feat/mps-support` there were 45 failing tests that I now consider fixed. A few things to note:
- `test_float64_action_space` tests are skipped entirely since float64 is not supported
- `test_save_load[True-SAC]` only fails when running the full suite or running all test_save_load tests (`make pytest` or `python3 -m pytest -v -k 'test_save_load'`). If I instead run the single breaking test (`python3 -m pytest -v -k 'test_save_load[True-SAC]'`) then it passes 🤷‍♂️ I also ran the test file in PyCharm and it passes there too, so I'm not sure what the issue is. I can add the stack trace of the failing test in a comment if needed.

Here is the full list of fixed tests:
FAILED tests/test_dict_env.py::test_dict_vec_framestack[False-PPO] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_dict_env.py::test_dict_vec_framestack[False-A2C] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_dict_env.py::test_dict_vec_framestack[False-DQN] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_dict_env.py::test_dict_vec_framestack[False-DDPG] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_dict_env.py::test_dict_vec_framestack[False-SAC] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_dict_env.py::test_dict_vec_framestack[False-TD3] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_dict_env.py::test_dict_vec_framestack[True-PPO] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_dict_env.py::test_dict_vec_framestack[True-A2C] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_dict_env.py::test_dict_vec_framestack[True-DQN] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_dict_env.py::test_dict_vec_framestack[True-DDPG] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_dict_env.py::test_dict_vec_framestack[True-SAC] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_dict_env.py::test_dict_vec_framestack[True-TD3] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_envs.py::test_bit_flipping[kwargs1] - OverflowError: Python integer 128 out of bounds for int8
FAILED tests/test_envs.py::test_bit_flipping[kwargs2] - OverflowError: Python integer 255 out of bounds for int8
FAILED tests/test_envs.py::test_bit_flipping[kwargs3] - OverflowError: Python integer 255 out of bounds for int8
FAILED tests/test_her.py::test_her[True-SAC] - OverflowError: Python integer 255 out of bounds for int8
FAILED tests/test_her.py::test_her[True-TD3] - OverflowError: Python integer 255 out of bounds for int8
FAILED tests/test_her.py::test_her[True-DDPG] - OverflowError: Python integer 255 out of bounds for int8
FAILED tests/test_her.py::test_her[True-DQN] - OverflowError: Python integer 255 out of bounds for int8
FAILED tests/test_her.py::test_multiprocessing[True-TD3] - EOFError
FAILED tests/test_her.py::test_multiprocessing[True-DQN] - EOFError
FAILED tests/test_train_eval_mode.py::test_td3_train_with_batch_norm - AssertionError: assert ~tensor(True, device='mps:0')
FAILED tests/test_vec_normalize.py::test_get_original - AssertionError: assert dtype('float32') == dtype('float64')
FAILED tests/test_vec_normalize.py::test_get_original_dict - AssertionError: assert dtype('float32') == dtype('float64')
FAILED tests/test_her.py::test_save_load[True-SAC] - ValueError: Expected parameter scale (Tensor of shape (64, 4)) of distribution Normal(loc: torch.Size([64, 4]), scale: torch.Size([64, 4])) to satisfy the constraint GreaterThan(lower_bound=0.0), but found invalid values:
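Many of the failures above report that MPS cannot hold float64 tensors. As a rough sketch of the kind of cast the commit message describes ("only cast reward from float64 to float32"), the following downcasts a NumPy array before it would be moved to the device. The helper name `as_float32_if_float64` is an assumption made for illustration, not code from this PR.

```python
import numpy as np


def as_float32_if_float64(array: np.ndarray) -> np.ndarray:
    """Return a float32 copy of float64 arrays; leave other dtypes untouched."""
    if array.dtype == np.float64:
        return array.astype(np.float32)
    return array


rewards = np.array([0.5, 1.0, -2.0])  # NumPy creates float64 by default
rewards32 = as_float32_if_float64(rewards)

actions = np.array([0, 1, 2], dtype=np.int64)
actions_unchanged = as_float32_if_float64(actions)  # integer dtype is preserved
```

Restricting the cast to float64 inputs keeps integer and already-float32 data untouched, which limits the change to the values that MPS rejects.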
Unsupported tests fixed by skipping
Motivation and Context
Types of changes
Checklist
- `make format` (required)
- `make check-codestyle` and `make lint` (required)
- `make pytest` and `make type` both pass. (required)
- `make doc` (required)

Note: You can run most of the checks using `make commit-checks`.

Note: we are using a maximum length of 127 characters per line