Skip to content

Commit

Permalink
Added Prototype.
Browse files Browse the repository at this point in the history
  • Loading branch information
ATATC committed Sep 30, 2024
1 parent 603d0ba commit d75c6fd
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 14 deletions.
22 changes: 13 additions & 9 deletions leads_jarvis/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
from torch.nn.functional import pad as _pad
from torchvision.transforms.functional import resize as _resize

from leads_jarvis.prototype import Prototype
from leads_jarvis.types import Device as _Device


def _delta_theta(a: dict[str, _Any], b: dict[str, _Any], c: dict[str, _Any]) -> float:
lat_a, lon_a, lat_b, lon_b, lat_c, lon_c = (a["latitude"], a["longitude"], b["latitude"], b["longitude"],
Expand Down Expand Up @@ -39,26 +42,27 @@ def transform_batch(x: _Tensor, img_size: int = 224) -> _Tensor:
return _stack(transformed_tensors)


class BatchDataset(_CSVDataset):
def __init__(self, file: str, batch_size: int, channels: tuple[str, ...], device: str = "cpu") -> None:
super().__init__(file, batch_size)
class BatchDataset(_CSVDataset, Prototype):
def __init__(self, file: str, batch_size: int, channels: tuple[str, ...], device: _Device = "cpu") -> None:
_CSVDataset.__init__(self, file, batch_size)
Prototype.__init__(self, device)
self._channels: tuple[str, ...] = channels
self._device: str = device

@_override
def __iter__(self) -> _Generator[tuple[_Tensor, _Tensor], None, None]:
batch = []
for i in super().__iter__():
if len(batch) >= self._chunk_size:
yield (_tensor(_Preprocessor(batch).to_tensor(self._channels), _float, self._device),
_tensor((i["throttle"], i["brake"], _delta_theta(*batch[-2:], i)), _float, self._device))
yield (_tensor(_Preprocessor(batch).to_tensor(self._channels), dtype=_float, device=self._device),
_tensor((i["throttle"], i["brake"], _delta_theta(*batch[-2:], i)), dtype=_float,
device=self._device))
batch.clear()
batch.append(i)


class OnlineDataset(BatchDataset, _Callback):
def __init__(self, server_address: str, server_port: int, batch_size: int, channels: tuple[str, ...],
device: str = "cpu") -> None:
device: _Device = "cpu") -> None:
BatchDataset.__init__(self, server_address, batch_size, channels, device)
_Callback.__init__(self)
self._address: str = server_address
Expand Down Expand Up @@ -94,6 +98,6 @@ def __iter__(self) -> _Generator[tuple[_Tensor, _Tensor], None, None]:
b = self._batch.copy()
n = b[self._chunk_size]
b = b[:self._chunk_size]
yield (_tensor(_Preprocessor(b).to_tensor(self._channels), _float, self._device),
_tensor((n["throttle"], n["brake"], _delta_theta(*b[-2:], n)), _float, self._device))
yield (_tensor(_Preprocessor(b).to_tensor(self._channels), dtype=_float, device=self._device),
_tensor((n["throttle"], n["brake"], _delta_theta(*b[-2:], n)), dtype=_float, device=self._device))
self._batch.clear()
12 changes: 12 additions & 0 deletions leads_jarvis/prototype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing import Self as _Self

from leads_jarvis.types import Device as _Device


class Prototype(object):
def __init__(self, device: _Device) -> None:
self._device: _Device = device

def to(self, device: _Device) -> _Self:
self._device = device
return self
11 changes: 6 additions & 5 deletions leads_jarvis/trainer.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
from typing import Literal as _Literal

from leads import L as _L
from torch import save as _save
from torch.nn import Module as _Module
from torch.optim import Optimizer as _Optimizer

from leads_jarvis.dataset import BatchDataset
from leads_jarvis.prototype import Prototype
from leads_jarvis.types import Device as _Device


class Trainer(object):
class Trainer(Prototype):
def __init__(self, dataset: BatchDataset, network: _Module, criterion: _Module, optimizer: _Optimizer,
weights_file: str, device: _Literal["cpu", "cuda"] = "cpu") -> None:
weights_file: str, device: _Device = "cpu") -> None:
super().__init__(device)
self._dataset: BatchDataset = dataset
self._network: _Module = network
self._criterion: _Module = criterion
self._optimizer: _Optimizer = optimizer
self._weights_file: str = weights_file
self._device: _Literal["cpu", "cuda"] = device
self._device: _Device = device

def initialize(self) -> None:
self._dataset.load()
Expand Down
3 changes: 3 additions & 0 deletions leads_jarvis/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from typing import Literal as _Literal

type Device = _Literal["cpu", "cuda"]

0 comments on commit d75c6fd

Please sign in to comment.