diff --git a/benchmark/evaluate_famous_models.py b/benchmark/evaluate_famous_models.py index 759d5d2..6f786be 100644 --- a/benchmark/evaluate_famous_models.py +++ b/benchmark/evaluate_famous_models.py @@ -1,3 +1,5 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + import torch from torchvision import models diff --git a/benchmark/evaluate_rnn_models.py b/benchmark/evaluate_rnn_models.py index d211f91..8f669b9 100644 --- a/benchmark/evaluate_rnn_models.py +++ b/benchmark/evaluate_rnn_models.py @@ -1,3 +1,5 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + import torch import torch.nn as nn diff --git a/tests/__init__.py b/tests/__init__.py index e69de29..77a19dc 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license diff --git a/tests/test_conv2d.py b/tests/test_conv2d.py index a91d9df..6b516fb 100644 --- a/tests/test_conv2d.py +++ b/tests/test_conv2d.py @@ -1,3 +1,5 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + import torch import torch.nn as nn @@ -54,6 +56,6 @@ def test_conv2d_random(self): flops, params = profile(net, inputs=(data,)) print(flops, params) - assert ( - flops == n * out_c * oh * ow // g * in_c * kh * kw - ), f"{flops} v.s. {n * out_c * oh * ow // g * in_c * kh * kw}" + assert flops == n * out_c * oh * ow // g * in_c * kh * kw, ( + f"{flops} v.s. {n * out_c * oh * ow // g * in_c * kh * kw}" + ) diff --git a/tests/test_matmul.py b/tests/test_matmul.py index 355ba5f..b05c081 100644 --- a/tests/test_matmul.py +++ b/tests/test_matmul.py @@ -1,3 +1,5 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + import torch import torch.nn as nn diff --git a/tests/test_relu.py b/tests/test_relu.py index a8293cf..fe4470c 100644 --- a/tests/test_relu.py +++ b/tests/test_relu.py @@ -1,3 +1,5 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + import torch import torch.nn as nn diff --git a/tests/test_utils.py b/tests/test_utils.py index b7f754d..f9d8ea3 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,3 +1,5 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + from thop import utils diff --git a/thop/__init__.py b/thop/__init__.py index 314d897..708b75b 100644 --- a/thop/__init__.py +++ b/thop/__init__.py @@ -1,3 +1,5 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + __version__ = "2.0.13" diff --git a/thop/fx_profile.py b/thop/fx_profile.py index 48e8ab8..fe964a5 100644 --- a/thop/fx_profile.py +++ b/thop/fx_profile.py @@ -1,3 +1,5 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + import logging from distutils.version import LooseVersion @@ -7,8 +9,7 @@ if LooseVersion(torch.__version__) < LooseVersion("1.8.0"): logging.warning( - f"torch.fx requires version higher than 1.8.0. " - f"But You are using an old version PyTorch {torch.__version__}. " + f"torch.fx requires version higher than 1.8.0. But You are using an old version PyTorch {torch.__version__}. " ) diff --git a/thop/profile.py b/thop/profile.py index f78d24f..0ffb911 100644 --- a/thop/profile.py +++ b/thop/profile.py @@ -1,3 +1,5 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + from thop.rnn_hooks import * from thop.vision.basic_hooks import * diff --git a/thop/rnn_hooks.py b/thop/rnn_hooks.py index dbf4a01..aae8a98 100644 --- a/thop/rnn_hooks.py +++ b/thop/rnn_hooks.py @@ -1,3 +1,5 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + import torch import torch.nn as nn from torch.nn.utils.rnn import PackedSequence diff --git a/thop/utils.py b/thop/utils.py index 845b54f..ee746ac 100644 --- a/thop/utils.py +++ b/thop/utils.py @@ -1,3 +1,5 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + from collections.abc import Iterable COLOR_RED = "91m" diff --git a/thop/vision/__init__.py b/thop/vision/__init__.py index e69de29..77a19dc 100644 --- a/thop/vision/__init__.py +++ b/thop/vision/__init__.py @@ -0,0 +1 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license diff --git a/thop/vision/basic_hooks.py b/thop/vision/basic_hooks.py index 00068c7..06d8336 100644 --- a/thop/vision/basic_hooks.py +++ b/thop/vision/basic_hooks.py @@ -1,3 +1,5 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + import logging import torch.nn as nn diff --git a/thop/vision/calc_func.py b/thop/vision/calc_func.py index 9e534ac..7c66a1b 100644 --- a/thop/vision/calc_func.py +++ b/thop/vision/calc_func.py @@ -1,3 +1,5 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + import warnings import numpy as np