We should evaluate the performance benefit of implementing a custom adjoint calculation for the backward pass of the forward() method in WaveCell, instead of relying on autograd to differentiate through it. This could reduce memory use during gradient computation, because PyTorch would not need to construct as large an autograd graph.
The approach (subclassing torch.autograd.Function and supplying a hand-written backward) is described here: https://pytorch.org/docs/stable/notes/extending.html
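
To make the idea concrete, here is a minimal sketch of what such a custom backward could look like. It is not the actual WaveCell code: the names `WaveStep` and `laplacian_1d`, the 1D leapfrog update rule, and the zero-valued boundaries are all assumptions chosen for illustration, and the real forward() may differ. The point is only that the Function saves just `y` and `c` for the backward pass rather than every intermediate tensor. Whether this saves memory in practice would still need to be measured.

```python
import torch
import torch.nn.functional as F


def laplacian_1d(u):
    # Hypothetical 1D second-difference stencil with zero (Dirichlet) boundaries.
    # The corresponding matrix is symmetric, so it is its own adjoint.
    padded = F.pad(u, (1, 1))
    return padded[..., :-2] - 2 * u + padded[..., 2:]


class WaveStep(torch.autograd.Function):
    """Illustrative single time step: y_next = 2*y - y_prev + (c*dt)^2 * lap(y)."""

    @staticmethod
    def forward(ctx, y, y_prev, c, dt):
        # Save only what the hand-written backward needs.
        ctx.save_for_backward(y, c)
        ctx.dt = dt
        return 2 * y - y_prev + (c * dt) ** 2 * laplacian_1d(y)

    @staticmethod
    def backward(ctx, grad_out):
        y, c = ctx.saved_tensors
        dt = ctx.dt
        # Adjoint of the update with respect to each input.
        grad_y = 2 * grad_out + laplacian_1d((c * dt) ** 2 * grad_out)
        grad_y_prev = -grad_out
        grad_c = 2 * c * dt ** 2 * laplacian_1d(y) * grad_out
        # dt is a plain Python float, so it gets no gradient.
        return grad_y, grad_y_prev, grad_c, None


# Check the hand-written gradients numerically (double precision is required).
torch.manual_seed(0)
y = torch.randn(8, dtype=torch.double, requires_grad=True)
y_prev = torch.randn(8, dtype=torch.double, requires_grad=True)
c = torch.rand(8, dtype=torch.double, requires_grad=True)
ok = torch.autograd.gradcheck(
    lambda a, b, s: WaveStep.apply(a, b, s, 0.1), (y, y_prev, c)
)
```

`torch.autograd.gradcheck` compares the analytic backward against finite differences, which is a cheap way to validate an adjoint before benchmarking its memory footprint against the plain autograd version.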