Skip to content

Commit

Permalink
view metrics in example
Browse files Browse the repository at this point in the history
  • Loading branch information
bw4sz committed Dec 12, 2024
1 parent 2e7af6d commit 17adc54
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 31 deletions.
105 changes: 78 additions & 27 deletions docs/examples/baseline_boxes.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion milliontrees/common/metrics/all_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def _compute_element_wise(self, y_pred, y_true):

gt_boxes = gt[self.geometry_name]
pred_boxes = target_boxes[target_scores > self.score_threshold]
det_accuracy = torch.mean(torch.stack([ self._accuracy(gt_boxes,pred_boxes,iou_thr) for iou_thr in np.arange(0.5,0.51,0.05)]))
det_accuracy = torch.mean(torch.stack([ self._accuracy(gt_boxes,pred_boxes,iou_thr) for iou_thr in np.arange(0.4,0.41,0.05)]))
batch_results.append(det_accuracy)

return torch.tensor(batch_results)
Expand Down
2 changes: 1 addition & 1 deletion milliontrees/datasets/TreePoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def __init__(self,
self._input_lookup = df.groupby('filename').apply(lambda x: x.index.values).to_dict()

# Point labels
self._y_array = df[["x", "y"]].values.astype(float)
self._y_array = df[["x", "y"]].values.astype(int)

# Labels -> just 'Tree'
self._n_classes = 1
Expand Down
3 changes: 1 addition & 2 deletions weights/DeepForest.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def main():

## Evaluate the model
box_dataset = get_dataset("TreeBoxes", root_dir="/orange/ewhite/DeepForest/MillionTrees/")
box_test_data = box_dataset.get_subset("test")
box_test_data = box_dataset.get_subset("test",frac=0.1)

test_loader = get_eval_loader("standard", box_test_data, batch_size=16)

Expand All @@ -75,7 +75,6 @@ def main():
metadata, images, targets = batch
for image_metadata, image, image_targets in zip(metadata,images, targets):
basename = box_dataset._filename_id_to_code[int(image_metadata[0])]
#image_path = os.path.join(box_dataset._data_dir._str, "images",basename)
# Deepforest likes 0-255 data, channels first
channels_first = image.permute(1, 2, 0).numpy() * 255
pred = m.predict_image(channels_first)
Expand Down

0 comments on commit 17adc54

Please sign in to comment.