Thanks for this outstanding work! May I ask why you used the class TimeStep(torch.autograd.Function) with @staticmethod forward and backward methods (in cell.py)? In this class the backward gradients are designed manually, so why is automatic differentiation not applied here directly?
In cell.py I tested replacing y = TimeStep.apply(b, c, h1, h2, self.dt, self.geom.h) with y = _time_step(b, c, h1, h2, self.dt, self.geom.h) to bypass your manually designed gradients, and the results are pretty similar. I am quite interested in and curious about why you coded it this way.
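For reference, here is a minimal toy sketch of the two patterns I am comparing. This is only an illustration with a made-up function (SquareStep / _square_step), not the actual TimeStep / _time_step code from cell.py:

```python
import torch

# Toy illustration only (not the real TimeStep): a custom autograd.Function
# with a hand-written backward, versus a plain function left to autograd.
class SquareStep(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # manually designed gradient: d(x^2)/dx = 2x
        return grad_output * 2 * x

def _square_step(x):
    # same forward computation, gradient derived automatically by autograd
    return x * x

x1 = torch.randn(4, requires_grad=True)
g1, = torch.autograd.grad(SquareStep.apply(x1).sum(), x1)

x2 = x1.detach().clone().requires_grad_(True)
g2, = torch.autograd.grad(_square_step(x2).sum(), x2)

print(torch.allclose(g1, g2))  # True: both paths give the same gradient here
```

In this toy case the two approaches give identical gradients, which matches what I observed when switching to _time_step in cell.py.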
Thanks a lot! It is really great work!