Skip to content

Commit

Permalink
Merge pull request #5 from Giskard-AI/dataset-first
Browse files Browse the repository at this point in the history
[GSK-2355, GSK-2349, GSK-2373] DataLoader structure
  • Loading branch information
rabah-khalek authored Dec 13, 2023
2 parents c92941c + 6b6e604 commit 0f76a9f
Show file tree
Hide file tree
Showing 20 changed files with 1,347 additions and 549 deletions.
259 changes: 210 additions & 49 deletions examples/criteria1_partial_faces.ipynb

Large diffs are not rendered by default.

269 changes: 248 additions & 21 deletions examples/ex1_draw_landmarks.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,31 @@
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"from face_alignment import FaceAlignment, LandmarksType\n",
"\n",
"from loreal_poc.datasets.dataset_300W import Dataset300W\n",
"from loreal_poc.dataloaders.loaders import DataLoader300W\n",
"from loreal_poc.visualisation.draw import draw_marks\n",
"from loreal_poc.marks.facial_parts import FacialParts\n",
"\n",
"import torch\n",
"from loreal_poc.tests.performance import NMEMean, NMEs, Es, MEMean, MEStd, NMEMean, NMEStd\n",
"from loreal_poc.models.wrappers import FaceAlignmentWrapper"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"source": [
"Benchmark"
Expand All @@ -28,27 +38,59 @@
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"ds = Dataset300W(dir_path=\"300W/sample\")"
"ds = DataLoader300W(dir_path=\"300W/sample\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'cuda'"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torch\n",
"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"device"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"source": [
"L'Oreal"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"metadata": {
"collapsed": false
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
Expand All @@ -57,24 +99,202 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"NMEMean:0.06233510979950631\n",
"NMEs:[0.07531859 0.08214728 0.0453659 0.05566845 0.05317534]\n",
"MEMean:12.172810987943965\n",
"MEStd:11.185152368023797\n",
"NMEMean:0.06233510979950631\n",
"NMEStd:0.013981365495877753\n",
"Es:[[12.64911064 13.98116651 17.64966983 27.96163264 26.88603439 27.2037963\n",
" 24.99632173 18.46327352 15.49091508 20.38655008 28.96824946 18.68070154\n",
" 21.00649238 23.34713518 18.78812417 8.91617704 3.15269091 9.84471132\n",
" 7.62599167 7.46495311 9.76452231 12.31256598 9.71432288 8.47259281\n",
" 4.12252787 3.64339539 10.99422539 4.90101469 9.26316129 10.76335747\n",
" 8.19016203 1.67062533 4.17529041 5.04662194 2.44197502 6.69321634\n",
" 5.4613892 1.78589613 5.436992 6.19588936 5.81028786 3.90036639\n",
" 7.06849355 4.17897894 2.98417158 7.79152681 7.53631647 6.44722553\n",
" 5.130614 5.64957945 4.64509042 6.76971801 8.64684104 5.60776471\n",
" 4.11845894 7.29749663 6.72116456 8.54999515 10.31351356 10.24740011\n",
" 6.52120334 9.54093418 6.00003675 7.30512594 4.84540535 6.80354687\n",
" 7.78893401 7.8107821 ]\n",
" [40.93702834 41.58638769 49.05876628 62.43859257 62.05374829 52.21021142\n",
" 40.26089464 31.43202795 13.07481751 26.67030898 35.6546853 34.86199806\n",
" 48.69191203 46.26268541 47.57961433 40.69200772 39.28106355 25.04677873\n",
" 11.40572808 20.12025094 26.74874025 36.21542351 9.46296359 8.62886041\n",
" 4.04043624 12.76876192 21.41182535 10.71742852 12.47921288 19.29799505\n",
" 6.64400271 4.95700363 3.40589621 1.85204887 2.10597863 5.56869078\n",
" 3.18053911 7.35603983 3.61900553 12.1698484 3.20499641 2.34261051\n",
" 2.19290515 6.41723507 8.13547866 8.35658411 6.76000621 6.76312627\n",
" 7.29803789 5.87968026 6.25955246 5.7234369 10.44651985 11.72808672\n",
" 9.14805362 6.07040773 3.57118748 1.53121684 2.69841435 1.9834407\n",
" 2.55976034 2.33400771 6.58163901 5.57154602 3.13813081 5.57154602\n",
" 9.35553163 4.60603865]\n",
" [13.45349085 10.30106645 11.01410137 6.26892989 10.08221444 9.62870552\n",
" 13.91076306 19.44744623 4.7848186 26.53777282 39.41954837 37.27968125\n",
" 37.12093385 50.05728626 33.94179668 36.12374413 58.6922866 39.26320014\n",
" 32.05138008 26.7678165 18.47045912 18.77988892 18.35038455 12.73079404\n",
" 23.28238014 28.49325661 26.7977367 11.40577244 21.35000037 19.43445613\n",
" 12.83553992 12.91981339 7.19348177 3.78898997 2.0586855 5.38131592\n",
" 16.76414522 4.6575879 7.59055736 13.538641 9.41119509 6.31696699\n",
" 13.41539563 0.61719122 13.49035782 9.9823623 4.35848334 0.95107991\n",
" 10.27132324 12.42952827 2.63886813 4.52813703 15.06713712 4.17462094\n",
" 20.30407747 15.20066433 5.88290481 12.20885064 8.82584715 2.53652518\n",
" 6.26509122 11.55412502 10.08891 9.17508496 12.85097179 1.33498464\n",
" 9.63094829 3.62571441]\n",
" [29.04032109 17.48195241 19.03426544 27.82319209 25.30004603 23.27956516\n",
" 29.08300156 30.99744725 19.59750456 11.88061699 2.42227331 4.46436076\n",
" 3.99741879 5.80222457 5.14764383 1.94772534 8.67356634 11.53738861\n",
" 8.59115976 9.078848 14.88060768 20.33959097 14.21931334 11.10721554\n",
" 6.36396771 0.58 0.9551513 6.12015114 10.19002596 11.39568032\n",
" 8.03452114 2.99413644 6.64243367 6.92942949 8.96173477 6.78312922\n",
" 4.09730899 2.86064765 3.59209187 5.20874284 0.60418623 4.28686925\n",
" 4.38583401 2.661242 5.17148586 5.26838391 7.96725925 5.48241015\n",
" 1.20838942 3.48755287 3.12345098 2.65753363 5.37117725 3.93241529\n",
" 10.19287815 5.94690516 6.1720273 4.32236972 3.22102002 5.90958653\n",
" 1.79587305 2.53842037 3.96310282 2.10230826 3.93678359 5.14862156\n",
" 5.61863124 5.34152038]\n",
" [ 5.68849541 5.26728858 7.04663189 7.67355993 7.86368571 6.80795799\n",
" 13.33560647 22.06216139 8.82135052 14.6929951 20.49964109 19.72341809\n",
" 21.020228 19.86263628 17.12977717 12.26901308 13.65741366 12.30100492\n",
" 9.94617615 12.59364764 23.29417878 35.61680328 16.02679372 7.04881408\n",
" 6.5965112 7.31660857 3.39014012 11.39469148 18.12316829 22.73563575\n",
" 17.73809587 11.91193099 2.93207793 4.28454957 8.64823248 16.77812796\n",
" 11.6476298 2.34872646 6.37660105 9.59164079 10.75729009 6.04159515\n",
" 8.56501932 14.04172187 14.55859461 16.31637438 14.81228146 14.85418251\n",
" 4.78193873 4.26139742 2.85773512 3.99231712 8.88610207 7.30510404\n",
" 14.35054745 12.45694826 4.71412558 5.38749265 7.33167682 6.02112631\n",
" 5.59434357 2.43333927 3.83328293 2.25857411 9.75475417 5.66155076\n",
" 3.83328293 3.94273255]]\n"
]
}
],
"source": [
"prediction = model.predict(ds)\n",
"marks = ds.all_marks\n",
"for metric in [NMEMean, NMEs, MEMean, MEStd, NMEMean, NMEStd, Es]:\n",
" print(f\"{metric.__name__}:{metric.get(prediction, marks)}\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"left eye\n",
"NMEMean:0.03064837061887285\n",
"NMEs:[0.03639644 0.02607136 0.02883627 0.02288638 0.0390514 ]\n",
"MEMean:6.205209482416183\n",
"MEStd:3.7323675345101868\n",
"NMEMean:0.03064837061887285\n",
"NMEStd:0.006134046369089966\n",
"Es:[[ nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" 5.4613892 1.78589613 5.436992 6.19588936 5.81028786 3.90036639\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan]\n",
" [ nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" 3.18053911 7.35603983 3.61900553 12.1698484 3.20499641 2.34261051\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan]\n",
" [ nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" 16.76414522 4.6575879 7.59055736 13.538641 9.41119509 6.31696699\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan]\n",
" [ nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" 4.09730899 2.86064765 3.59209187 5.20874284 0.60418623 4.28686925\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan]\n",
" [ nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" 11.6476298 2.34872646 6.37660105 9.59164079 10.75729009 6.04159515\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan]]\n"
]
}
],
"source": [
"prediction = model.predict(ds, facial_part=FacialParts.left_eye)\n",
"marks = ds.all_marks\n",
"\n",
"print(FacialParts.left_eye.name)\n",
"for metric in [NMEMean, NMEs, MEMean, MEStd, NMEMean, NMEStd, Es]:\n",
" print(f\"{metric.__name__}:{metric.get(prediction, marks)}\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"chosen_idx = 4\n",
"image = ds.all_images[chosen_idx]\n",
"ground_truth_landmarks = ds.all_marks[chosen_idx, :, :]\n",
"predictions = model.predict(ds).prediction\n",
"loreal_landmarks = predictions[chosen_idx]"
"predictions = model.predict(ds)\n",
"image, ground_truth_landmarks, meta = ds[chosen_idx]\n",
"loreal_landmarks = predictions.prediction[chosen_idx]"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 8,
"metadata": {
"collapsed": false
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [
{
Expand All @@ -85,19 +305,26 @@
"<PIL.Image.Image image mode=RGB size=650x454>"
]
},
"execution_count": 5,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"draw_marks(image, [ds.all_marks[chosen_idx, :, :], loreal_landmarks], [\"green\", \"red\"])"
"draw_marks(image, [ground_truth_landmarks, loreal_landmarks], [\"green\", \"red\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -111,9 +338,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 0
"nbformat_minor": 4
}
Loading

0 comments on commit 0f76a9f

Please sign in to comment.