
Fix JAX resource exhaustion/multi-threading #16

@brianreicher

Description


JAX network testing is running into two main issues; the plan is to tackle them one at a time:

  • Single-GPU training: the untrained network loads onto the GPU, but JAX is unable to allocate the resources needed to run the convolutions.
  • Multi-threaded training: the first input dimension (batch size = 16) is split into two separate dimensions of shape (4, 4), which increases the rank of the input tensors by one.
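A minimal diagnostic sketch for the single-GPU allocation failure. It assumes, without confirmation from this issue, that the failure is related to XLA's default behavior of preallocating a large share of GPU memory when the backend starts. `XLA_PYTHON_CLIENT_PREALLOCATE` and `XLA_PYTHON_CLIENT_MEM_FRACTION` are documented JAX environment variables. The tensor shapes in the convolution smoke test are arbitrary placeholders and are not the project's network.

```python
import os

# These must be set before JAX initializes its backend, i.e. before `import jax`.
# Assumption: the allocation failure is related to XLA's default GPU memory
# preallocation. This is a hypothesis to test, not a confirmed cause.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Alternative: keep preallocation but cap the fraction of GPU memory it takes.
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"

import jax
import jax.numpy as jnp

# Confirm which backend and devices JAX actually picked up.
print(jax.devices())

# Small convolution smoke test with placeholder shapes:
# NHWC input (1 image, 32x32, 3 channels) and HWIO kernel (3x3, 3 -> 8 channels).
x = jnp.ones((1, 32, 32, 3))
k = jnp.ones((3, 3, 3, 8))
y = jax.lax.conv_general_dilated(
    x,
    k,
    window_strides=(1, 1),
    padding="SAME",
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
)
# JAX dispatches asynchronously; block so allocation errors surface here.
y.block_until_ready()
print(y.shape)  # (1, 32, 32, 8)
```

If this small convolution runs but the real network still fails, that points toward the network's own activation memory rather than the allocator settings.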
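A sketch of how a batch of 16 becomes a leading (4, 4) shape. It assumes, without confirmation in this issue, that the multi-threaded path shards the batch across 4 devices or workers in the style of `jax.pmap`. The helper `shard_batch`, the placeholder `model_step`, and the (8, 8, 1) per-example shape are hypothetical. The extra leading axis is expected after sharding. It is only safe if something such as `jax.pmap` or `jax.vmap` maps over that axis. Passing the sharded array straight into the network gives it a tensor one rank higher than it expects.

```python
import jax
import jax.numpy as jnp

GLOBAL_BATCH = 16
N_SHARDS = 4  # hypothetical: number of devices/workers the batch is split across


def shard_batch(batch, n_shards):
    """Reshape (B, ...) -> (n_shards, B // n_shards, ...)."""
    b = batch.shape[0]
    if b % n_shards:
        raise ValueError(f"batch size {b} is not divisible by {n_shards}")
    return batch.reshape((n_shards, b // n_shards) + batch.shape[1:])


def model_step(x):
    # Hypothetical forward pass that expects a rank-4 NHWC tensor.
    # The shape is static under tracing, so this check runs at trace time.
    assert x.ndim == 4, f"expected rank 4, got shape {x.shape}"
    return jnp.mean(x, axis=(1, 2, 3))


batch = jnp.ones((GLOBAL_BATCH, 8, 8, 1))
sharded = shard_batch(batch, N_SHARDS)
print(sharded.shape)  # (4, 4, 8, 8, 1): rank went from 4 to 5

# Map over the leading shard axis so each call sees shape (4, 8, 8, 1).
# jax.vmap is used so the sketch runs on a single device; with N_SHARDS real
# devices, jax.pmap(model_step) would consume the same leading axis.
out = jax.vmap(model_step)(sharded)
print(out.shape)  # (4, 4)
```

Calling `model_step(sharded)` directly fails the rank check, which mirrors the symptom described in the second bullet.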

Metadata

  • Labels: bug (Something isn't working)
  • Assignees, type, projects, milestone, relationships, development: none
