Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache the number of elements in the action space #68

Open
wants to merge 3 commits into
base: 0.7
Choose a base branch
from

Conversation

thatguy11325
Copy link
Contributor

You probably dont need to dispatch to numpy everytime you call split to calculate the number of elements in the space. This PR caches the sizes (in a less than nice way imo) as an example. Before and after pictures below

Screenshot 2024-01-28 at 9 12 20 PM Screenshot 2024-01-28 at 9 12 52 PM

@jsuarez5341
Copy link
Contributor

This looks reasonable, waiting for hardware to test end to end. Any other optimization ideas for split? It's the main bottleneck right now. From before your patch:

🐡 python tests/test_extensions.py
0.00000032: Flatten time
0.00000294: Concatenate time
0.00001958: Split time
0.00000056: Unflatten time

@thatguy11325
Copy link
Contributor Author

thatguy11325 commented Feb 6, 2024

You could try to vectorize the generation of samps -> leaves a la what was done in evaluate? Though I'm unsure if that'll work if the sz's can vary.

@thatguy11325
Copy link
Contributor Author

thatguy11325 commented Feb 6, 2024

I think it'd look like

leaves = stacked_sample.reshape(len(flat_space), batch, *next(flat_space.values()).shape)

I assume since sz is the same across all flat spaces, shape will be the same too.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants