Skip to content

Commit

Permalink
[no ci] map device location to cpu/cuda
Browse files Browse the repository at this point in the history
  • Loading branch information
ajkdrag committed Apr 10, 2024
1 parent d48aaa0 commit 53e8356
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/ocrtoolkit/integrations/doctr.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@ def __init__(self, model, path, device, **kwargs):
from doctr.models.detection.predictor import DetectionPredictor

super().__init__(model, path, device)
self.doctr_base_predictor = _OCRPredictor()
kwargs["mean"] = kwargs.get("mean", model.cfg["mean"])
kwargs["std"] = kwargs.get("std", model.cfg["std"])
kwargs["batch_size"] = kwargs.get("batch_size", 2)
input_shape = model.cfg["input_shape"][1:]

self.doctr_base_predictor = _OCRPredictor()
self.predictor = DetectionPredictor(PreProcessor(input_shape, **kwargs), model)

def _predict(self, images: List[np.ndarray], **kwargs) -> List[DetectionResults]:
Expand Down
3 changes: 3 additions & 0 deletions src/ocrtoolkit/integrations/gcv.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ class GCVModel(DetectionModel):
def __init__(self, client, path):
super().__init__(model=client, path=path)

def _map_location(self):
pass

def _predict(self, images: List[np.ndarray], **kwargs) -> List[DetectionResults]:
l_results = []
for image in images:
Expand Down
4 changes: 4 additions & 0 deletions src/ocrtoolkit/wrappers/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ def __init__(self, model, path=None, device="cpu", **kwargs):
self.model = model
self.path = path
self.device = device
self._map_location()

def _map_location(self):
self.model = self.model.to(self.device)
logger.info(f"Loaded model from {self.path}, to {self.device}")

def preprocess(self, images: List[np.ndarray], **kwargs) -> List[np.ndarray]:
Expand Down

0 comments on commit 53e8356

Please sign in to comment.