diff --git a/conda_env.yml b/conda_env.yml index e638c7b..e70d284 100644 --- a/conda_env.yml +++ b/conda_env.yml @@ -1,96 +1,244 @@ name: video_features channels: - pytorch + - nvidia - conda-forge - defaults dependencies: - - _libgcc_mutex=0.1=main + - _libgcc_mutex=0.1=conda_forge + - _openmp_mutex=4.5=2_gnu + - alsa-lib=1.2.8=h166bdaf_0 - antlr-python-runtime=4.9.3=pyhd8ed1ab_1 - - attrs=22.1.0=pyh71513ae_1 - - av=8.0.2=py38he20a9df_1 - - backports=1.0=py_2 - - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0 + - aom=3.5.0=h27087fc_0 + - attr=2.5.1=h166bdaf_1 + - av=10.0.0=py311h84dbf73_3 - blas=1.0=mkl - - bzip2=1.0.8=h516909a_3 - - ca-certificates=2022.6.15=ha878542_0 - - certifi=2022.6.15=py38h578d9bd_0 - - cffi=1.14.6=py38h400218f_0 - - cudatoolkit=11.0.221=h6bb024c_0 - - ffmpeg=4.3.1=h167e202_0 - - freetype=2.10.2=h5ab3b9f_0 - - ftfy=6.1.1=pyhd8ed1ab_0 - - gettext=0.19.8.1=h5e8e0c9_1 - - gmp=6.2.0=he1b5a44_2 - - gnutls=3.6.13=h79a8f9a_0 - - iniconfig=1.1.1=pyh9f0ad1d_0 - - intel-openmp=2020.2=254 - - jpeg=9b=h024ee3a_2 - - lame=3.100=h14c3975_1001 - - lcms2=2.11=h396b838_0 - - ld_impl_linux-64=2.33.1=h53a641e_7 - - libedit=3.1.20191231=h14c3975_1 - - libffi=3.3=he6710b0_2 - - libflac=1.3.3=he1b5a44_0 - - libgcc-ng=9.1.0=hdf63c60_0 - - libgfortran-ng=7.3.0=hdf63c60_0 - - libiconv=1.16=h516909a_0 - - libllvm10=10.0.1=he513fc3_3 - - libogg=1.3.2=h516909a_1002 - - libpng=1.6.37=hbc83047_0 - - libsndfile=1.0.29=he1b5a44_0 - - libstdcxx-ng=9.1.0=hdf63c60_0 - - libtiff=4.1.0=h2733197_1 - - libuv=1.40.0=h7b6447c_0 - - libvorbis=1.3.7=he1b5a44_0 - - llvmlite=0.36.0=py38h612dafd_4 - - lz4-c=1.9.2=he6710b0_1 - - mkl=2020.2=256 - - mkl-service=2.3.0=py38he904b0f_0 - - mkl_fft=1.2.0=py38h23d657b_0 - - mkl_random=1.1.1=py38h0573a6f_0 - - ncurses=6.2=he6710b0_1 - - nettle=3.4.1=h1bed415_1002 - - ninja=1.10.1=py38hfd86e86_0 - - numba=0.53.1=py38ha9443f7_0 - - numpy=1.19.1=py38hbc911f0_0 - - numpy-base=1.19.1=py38hfa32c7d_0 - - olefile=0.46=py_0 - - omegaconf=2.1.1=py38h578d9bd_1 - - openh264=2.1.1=h8b12597_0 - - openssl=1.1.1q=h7f8727e_0 - - packaging=21.3=pyhd8ed1ab_0 - - pillow=7.2.0=py38hb39fc2d_0 - - pip=20.2.2=py38_0 - - pluggy=1.0.0=py38h578d9bd_3 - - py=1.11.0=pyh6c4a22f_0 - - pycparser=2.21=pyhd8ed1ab_0 - - pyparsing=3.0.9=pyhd8ed1ab_0 - - pysoundfile=0.10.3.post1=pyhd3deb0d_0 - - pytest=7.1.2=py38h578d9bd_0 - - python=3.8.5=h7579374_1 - - python_abi=3.8=1_cp38 - - pytorch=1.7.1=py3.8_cuda11.0.221_cudnn8.0.5_0 - - pyyaml=5.3.1=py38h8df0ef7_1 - - readline=8.0=h7b6447c_0 - - regex=2022.3.15=py38h7f8727e_0 - - resampy=0.2.2=py_0 - - scipy=1.5.2=py38h0b6359f_0 - - setuptools=49.6.0=py38_0 - - six=1.15.0=py_0 - - sqlite=3.33.0=h62c20be_0 - - tbb=2020.2=hc9558a2_0 - - tk=8.6.10=hbc83047_0 + - brotli-python=1.0.9=py311h6a678d5_7 + - bzip2=1.0.8=h7b6447c_0 + - c-ares=1.25.0=hd590300_0 + - ca-certificates=2023.12.12=h06a4308_0 + - cairo=1.16.0=ha61ee94_1014 + - certifi=2023.11.17=pyhd8ed1ab_0 + - cffi=1.16.0=py311h5eee18b_0 + - charset-normalizer=2.0.4=pyhd3eb1b0_0 + - colorama=0.4.6=pyhd8ed1ab_0 + - cryptography=41.0.7=py311hdda0065_0 + - cuda-cudart=12.1.105=0 + - cuda-cupti=12.1.105=0 + - cuda-libraries=12.1.0=0 + - cuda-nvrtc=12.1.105=0 + - cuda-nvtx=12.1.105=0 + - cuda-opencl=12.3.101=0 + - cuda-runtime=12.1.0=0 + - dbus=1.13.6=h5008d03_3 + - exceptiongroup=1.2.0=pyhd8ed1ab_2 + - expat=2.5.0=hcb278e6_1 + - ffmpeg=5.1.2=gpl_h8dda1f0_106 + - fftw=3.3.10=nompi_hc118613_108 + - filelock=3.13.1=py311h06a4308_0 + - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 + - font-ttf-inconsolata=3.000=h77eed37_0 + 
- font-ttf-source-code-pro=2.038=h77eed37_0 + - font-ttf-ubuntu=0.83=h77eed37_1 + - fontconfig=2.14.2=h14ed4e7_0 + - fonts-conda-ecosystem=1=0 + - fonts-conda-forge=1=0 + - freeglut=3.2.2=h9c3ff4c_1 + - freetype=2.12.1=h4a9f257_0 + - ftfy=6.1.3=pyhd8ed1ab_0 + - gettext=0.21.1=h27087fc_0 + - giflib=5.2.1=h5eee18b_3 + - glib=2.78.3=hfc55251_0 + - glib-tools=2.78.3=hfc55251_0 + - gmp=6.2.1=h295c915_3 + - gmpy2=2.1.2=py311hc9b5ff0_0 + - gnutls=3.7.9=hb077bed_0 + - graphite2=1.3.13=h58526e2_1001 + - gst-plugins-base=1.22.0=h4243ec0_2 + - gstreamer=1.22.0=h25f0c4b_2 + - gstreamer-orc=0.4.34=hd590300_0 + - harfbuzz=6.0.0=h8e241bc_0 + - hdf5=1.14.0=nompi_hb72d44e_103 + - icu=70.1=h27087fc_0 + - idna=3.4=py311h06a4308_0 + - iniconfig=2.0.0=pyhd8ed1ab_0 + - intel-openmp=2023.1.0=hdb19cb5_46306 + - jack=1.9.22=h11f4161_0 + - jasper=2.0.33=h0ff4b12_1 + - jinja2=3.1.2=py311h06a4308_0 + - jpeg=9e=h5eee18b_1 + - keyutils=1.6.1=h166bdaf_0 + - krb5=1.20.1=h81ceb04_0 + - lame=3.100=h7b6447c_0 + - lcms2=2.12=h3be6417_0 + - ld_impl_linux-64=2.38=h1181459_1 + - lerc=3.0=h295c915_0 + - libaec=1.1.2=h59595ed_1 + - libblas=3.9.0=1_h86c2bf4_netlib + - libcap=2.67=he9d0100_0 + - libcblas=3.9.0=5_h92ddd45_netlib + - libclang=15.0.7=default_hb11cfb5_4 + - libclang13=15.0.7=default_ha2b6cf4_4 + - libcublas=12.1.0.26=0 + - libcufft=11.0.2.4=0 + - libcufile=1.8.1.2=0 + - libcups=2.3.3=h36d4200_3 + - libcurand=10.3.4.107=0 + - libcurl=8.1.2=h409715c_0 + - libcusolver=11.4.4.55=0 + - libcusparse=12.0.2.55=0 + - libdb=6.2.32=h9c3ff4c_0 + - libdeflate=1.17=h5eee18b_1 + - libdrm=2.4.114=h166bdaf_0 + - libedit=3.1.20191231=he28a2e2_2 + - libev=4.33=hd590300_2 + - libevent=2.1.10=h28343ad_4 + - libexpat=2.5.0=hcb278e6_1 + - libffi=3.4.4=h6a678d5_0 + - libflac=1.4.3=h59595ed_0 + - libgcc-ng=13.2.0=h807b86a_3 + - libgcrypt=1.10.3=hd590300_0 + - libgfortran-ng=13.2.0=h69a702a_3 + - libgfortran5=13.2.0=ha4646dd_3 + - libglib=2.78.3=h783c2da_0 + - libglu=9.0.0=he1b5a44_1001 + - libgomp=13.2.0=h807b86a_3 + - libgpg-error=1.47=h71f35ed_0 + - libiconv=1.17=hd590300_2 + - libidn2=2.3.4=h5eee18b_0 + - libjpeg-turbo=2.0.0=h9bf148f_0 + - liblapack=3.9.0=5_h92ddd45_netlib + - liblapacke=3.9.0=5_h92ddd45_netlib + - libllvm14=14.0.6=hcd5def8_4 + - libllvm15=15.0.7=hadd5161_1 + - libnghttp2=1.58.0=h47da74e_0 + - libnpp=12.0.2.50=0 + - libnsl=2.0.1=hd590300_0 + - libnvjitlink=12.1.105=0 + - libnvjpeg=12.1.1.14=0 + - libogg=1.3.4=h7f98852_1 + - libopencv=4.7.0=py311h7a0761e_1 + - libopus=1.3.1=h7f98852_1 + - libpciaccess=0.17=h166bdaf_0 + - libpng=1.6.39=h5eee18b_0 + - libpq=15.3=hbcd7760_1 + - libprotobuf=3.21.12=hfc55251_2 + - libsndfile=1.2.2=hc60ed4a_1 + - libsqlite=3.44.2=h2797004_0 + - libssh2=1.11.0=h0841786_0 + - libstdcxx-ng=13.2.0=h7e041cc_3 + - libsystemd0=253=h8c4010b_1 + - libtasn1=4.19.0=h5eee18b_0 + - libtiff=4.5.1=h6a678d5_0 + - libtool=2.4.7=h27087fc_0 + - libudev1=253=h0b41bf4_1 + - libunistring=0.9.10=h27cfd23_0 + - libuuid=2.38.1=h0b41bf4_0 + - libva=2.18.0=h0b41bf4_0 + - libvorbis=1.3.7=h9c3ff4c_0 + - libvpx=1.11.0=h9c3ff4c_3 + - libwebp=1.3.2=h11a3e52_0 + - libwebp-base=1.3.2=h5eee18b_0 + - libxcb=1.13=h7f98852_1004 + - libxkbcommon=1.5.0=h79f4944_1 + - libxml2=2.10.3=hca2bb57_4 + - libzlib=1.2.13=hd590300_5 + - llvm-openmp=14.0.6=h9e868ea_0 + - llvmlite=0.41.1=py311ha6695c7_0 + - lz4-c=1.9.4=h6a678d5_0 + - markupsafe=2.1.3=py311h5eee18b_0 + - mkl=2023.1.0=h213fc3f_46344 + - mkl-service=2.4.0=py311h5eee18b_1 + - mkl_fft=1.3.8=py311h5eee18b_0 + - mkl_random=1.2.4=py311hdb19cb5_0 + - mpc=1.1.0=h10f8cd9_1 + - 
mpfr=4.0.2=hb69a4c5_1 + - mpg123=1.32.4=h59595ed_0 + - mpmath=1.3.0=py311h06a4308_0 + - mysql-common=8.0.33=hf1915f5_6 + - mysql-libs=8.0.33=hca2cd23_6 + - ncurses=6.4=h6a678d5_0 + - nettle=3.9.1=h7ab15ed_0 + - networkx=3.1=py311h06a4308_0 + - nspr=4.35=h27087fc_0 + - nss=3.96=h1d7d5a4_0 + - numba=0.58.1=py311h96b013e_0 + - numpy=1.26.3=py311h08b1b3b_0 + - numpy-base=1.26.3=py311hf175353_0 + - omegaconf=2.3.0=pyhd8ed1ab_0 + - opencv=4.7.0=py311h38be061_1 + - openh264=2.3.1=hcb278e6_2 + - openjpeg=2.4.0=h3ad879b_0 + - openssl=3.1.4=hd590300_0 + - p11-kit=0.24.1=hc5aa10d_0 + - packaging=23.2=pyhd8ed1ab_0 + - pcre2=10.42=hcad00b1_0 + - pillow=10.0.1=py311ha6cbd5a_0 + - pip=23.3.1=py311h06a4308_0 + - pixman=0.43.0=h59595ed_0 + - pluggy=1.3.0=pyhd8ed1ab_0 + - pthread-stubs=0.4=h36c2ea0_1001 + - pulseaudio=16.1=hcb278e6_3 + - pulseaudio-client=16.1=h5195f5e_3 + - pulseaudio-daemon=16.1=ha8d29e2_3 + - py-opencv=4.7.0=py311h781c19f_1 + - pycparser=2.21=pyhd3eb1b0_0 + - pyopenssl=23.2.0=py311h06a4308_0 + - pysocks=1.7.1=py311h06a4308_0 + - pysoundfile=0.12.1=pyhd8ed1ab_0 + - pytest=7.4.4=pyhd8ed1ab_0 + - python=3.11.6=hab00c5b_0_cpython + - python_abi=3.11=4_cp311 + - pytorch=2.1.2=py3.11_cuda12.1_cudnn8.9.2_0 + - pytorch-cuda=12.1=ha16c6d3_5 + - pytorch-mutex=1.0=cuda + - pyyaml=6.0.1=py311h5eee18b_0 + - qt-main=5.15.8=h5d23da1_6 + - readline=8.2=h5eee18b_0 + - regex=2023.12.25=py311h459d7ec_0 + - requests=2.31.0=py311h06a4308_0 + - resampy=0.4.2=pyhd8ed1ab_0 + - scipy=1.11.4=py311h64a7726_0 + - setuptools=68.2.2=py311h06a4308_0 + - sqlite=3.41.2=h5eee18b_0 + - svt-av1=1.4.1=hcb278e6_0 + - sympy=1.12=pypyh9d50eac_103 + - tbb=2021.8.0=hdb19cb5_0 + - tk=8.6.13=noxft_h4845f30_101 - tomli=2.0.1=pyhd8ed1ab_0 - - torchaudio=0.7.2=py38 - - torchvision=0.8.2=py38_cu110 - - tqdm=4.49.0=py_0 - - typing_extensions=4.0.1=pyha770c72_0 - - wcwidth=0.2.5=pyh9f0ad1d_2 - - wheel=0.35.1=py_0 - - x264=1!152.20180806=h14c3975_0 - - xz=5.2.5=h7b6447c_0 - - yaml=0.2.5=h516909a_0 - - zlib=1.2.11=h7b6447c_3 - - zstd=1.4.5=h9ceee32_0 - - pip: - - opencv-python==4.4.0.44 + - torchaudio=2.1.2=py311_cu121 + - torchtriton=2.1.0=py311 + - torchvision=0.16.2=py311_cu121 + - tqdm=4.66.1=pyhd8ed1ab_0 + - typing_extensions=4.9.0=py311h06a4308_1 + - tzdata=2023d=h04d1e81_0 + - urllib3=1.26.18=py311h06a4308_0 + - wcwidth=0.2.13=pyhd8ed1ab_0 + - wheel=0.41.2=py311h06a4308_0 + - x264=1!164.3095=h166bdaf_2 + - x265=3.5=h924138e_3 + - xcb-util=0.4.0=h516909a_0 + - xcb-util-image=0.4.0=h166bdaf_0 + - xcb-util-keysyms=0.4.0=h516909a_0 + - xcb-util-renderutil=0.3.9=h166bdaf_0 + - xcb-util-wm=0.4.1=h516909a_0 + - xkeyboard-config=2.38=h0b41bf4_0 + - xorg-fixesproto=5.0=h7f98852_1002 + - xorg-inputproto=2.3.2=h7f98852_1002 + - xorg-kbproto=1.0.7=h7f98852_1002 + - xorg-libice=1.1.1=hd590300_0 + - xorg-libsm=1.2.4=h7391055_0 + - xorg-libx11=1.8.4=h0b41bf4_0 + - xorg-libxau=1.0.11=hd590300_0 + - xorg-libxdmcp=1.1.3=h7f98852_0 + - xorg-libxext=1.3.4=h0b41bf4_2 + - xorg-libxfixes=5.0.3=h7f98852_1004 + - xorg-libxi=1.7.10=h7f98852_0 + - xorg-libxrender=0.9.10=h7f98852_1003 + - xorg-renderproto=0.11.1=h7f98852_1002 + - xorg-xextproto=7.3.0=h0b41bf4_1003 + - xorg-xproto=7.0.31=h7f98852_1007 + - xz=5.4.5=h5eee18b_0 + - yaml=0.2.5=h7b6447c_0 + - zlib=1.2.13=hd590300_5 + - zstd=1.5.5=hc292b87_0 diff --git a/docs/meta/install_conda.md b/docs/meta/install_conda.md new file mode 100644 index 0000000..fa2af2c --- /dev/null +++ b/docs/meta/install_conda.md @@ -0,0 +1,10 @@ +Just steps to install conda and create a new environment from scratch. 
+```bash +conda create -n video_features +conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia +conda install -c conda-forge omegaconf scipy tqdm pytest opencv +# +CLIP +conda install -c conda-forge ftfy regex +# vggish +conda install -c conda-forge resampy pysoundfile +``` diff --git a/docs/models/i3d.md b/docs/models/i3d.md index 26df519..1ecac49 100644 --- a/docs/models/i3d.md +++ b/docs/models/i3d.md @@ -61,12 +61,12 @@ Activate the environment conda activate video_features ``` -and extract features from `./sample/v_ZNVhz7ctTq0.mp4` video and show the predicted classes +and extract features from `./sample/v_GGSY1Qvo990.mp4` video and show the predicted classes ```bash python main.py \ feature_type=i3d \ device="cuda:0" \ - video_paths="[./sample/v_ZNVhz7ctTq0.mp4]" \ + video_paths="[./sample/v_GGSY1Qvo990.mp4]" \ show_pred=true ``` diff --git a/docs/models/vggish.md b/docs/models/vggish.md index 61d6c34..19ef588 100644 --- a/docs/models/vggish.md +++ b/docs/models/vggish.md @@ -74,15 +74,21 @@ python main.py \ --- -## Difference between Tensorflow and PyTorch implementations +## Difference between TensorFlow and PyTorch implementations +VGGish was originally implemented in TensorFlow. +We use the PyTorch implementation by +[harritaylor/torchvggish](https://github.com/harritaylor/torchvggish/tree/f70241ba). +The difference in values between the PyTorch and TensorFlow implementations is negligible. +However, after updating the versions of the dependencies, the values differ slightly. +If you wish to use the old implementation, check out the conda environment from the `b21f330` commit or earlier. +The outputs below show the difference in values. ``` python main.py \ feature_type=vggish \ - on_extraction=save_numpy \ - file_with_video_paths=./sample/sample_video_paths.txt + video_paths="[./sample/v_GGSY1Qvo990.mp4]" -TF (./sample/v_GGSY1Qvo990.mp4): +Original (./sample/v_GGSY1Qvo990.mp4): [[0. 0.04247099 0.09079538 ... 0. 0.18485409 0. ] [0. 0. 0. ... 0. 0.5720243 0.5475726 ] [0. 0.00705254 0.15173683 ... 0. 0.33540994 0.10572422] @@ -92,7 +98,7 @@ TF (./sample/v_GGSY1Qvo990.mp4): [0. 0.31638345 0. ... 0. 0. 0. ]] max: 2.31246495; mean: 0.13741589; min: 0.00000000 -PyTorch (./sample/v_GGSY1Qvo990.mp4): +b21f330 and earlier (./sample/v_GGSY1Qvo990.mp4): [[0. 0.04247095 0.09079528 ... 0. 0.18485469 0. ] [0. 0. 0. ... 0. 0.5720252 0.5475726 ] [0. 0.0070536 0.1517372 ... 0. 0.33541012 0.10572463] @@ -102,21 +108,16 @@ PyTorch (./sample/v_GGSY1Qvo990.mp4): [0. 0.31638315 0. ... 0. 0. 0. ]] max: 2.31246495; mean: 0.13741589; min: 0.00000000 -(PyTorch - TensorFlow).abs() -tensor([[0.0000e+00, 4.4703e-08, 1.0431e-07, ..., 0.0000e+00, 5.9605e-07, - 0.0000e+00], - [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, 8.9407e-07, - 0.0000e+00], - [0.0000e+00, 1.0580e-06, 3.7253e-07, ..., 0.0000e+00, 1.7881e-07, - 4.1723e-07], - ..., - [0.0000e+00, 0.0000e+00, 8.6427e-07, ..., 0.0000e+00, 2.3097e-07, - 0.0000e+00], - [0.0000e+00, 1.4454e-06, 8.0466e-07, ..., 0.0000e+00, 0.0000e+00, - 0.0000e+00], - [0.0000e+00, 2.9802e-07, 0.0000e+00, ..., 0.0000e+00, 0.0000e+00, - 0.0000e+00]]) -max: 4.0531e-06; mean: 2.2185e-07; min: 0.00000000 +Current (./sample/v_GGSY1Qvo990.mp4): +[[0. 0.0752698 0.12985817 ... 0. 0.18340725 0.00647891] + [0. 0. 0. ... 0. 0.5479691 0.6105871 ] + [0. 0.03563304 0.1507446 ... 0. 0.20983526 0.15856776] + ... + [0. 0. 0.3077196 ... 0. 0.08271158 0.03223182] + [0. 0.15476668 0.25240228 ... 0. 0. 0. ] + [0. 0.3711498 0. ... 0. 0. 0.
]] +max: 2.41924119; mean: 0.13830526; min: 0.00000000 + ``` --- diff --git a/models/r21d/extract_r21d.py b/models/r21d/extract_r21d.py index b630e72..680907f 100644 --- a/models/r21d/extract_r21d.py +++ b/models/r21d/extract_r21d.py @@ -2,11 +2,14 @@ import numpy as np import torch + import torchvision +from torchvision.io.video import read_video +import torchvision.models as models + from models._base.base_extractor import BaseExtractor from models.transforms import (CenterCrop, Normalize, Resize, ToFloatTensorInZeroOne) -from torchvision.io.video import read_video from utils.io import reencode_video_with_diff_fps from utils.utils import form_slices, show_predictions_on_dataset @@ -47,12 +50,6 @@ def __init__(self, args) -> None: self.step_size = self.model_def['step_size'] if self.stack_size is None: self.stack_size = self.model_def['stack_size'] - self.transforms = torchvision.transforms.Compose([ - ToFloatTensorInZeroOne(), - Resize((128, 171)), - Normalize(mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989]), - CenterCrop((112, 112)) - ]) self.show_pred = args.show_pred self.output_feat_keys = [self.feature_type] self.name2module = self.load_model() @@ -102,8 +99,16 @@ def load_model(self) -> Dict[str, torch.nn.Module]: Returns: Dict[str, torch.nn.Module]: model-agnostic dict holding modules for extraction and show_pred """ + self.transforms = torchvision.transforms.Compose([ + ToFloatTensorInZeroOne(), + Resize((128, 171)), + Normalize(mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989]), + CenterCrop((112, 112)), + ]) + if self.model_name == 'r2plus1d_18_16_kinetics': - model = torchvision.models.video.r2plus1d_18(pretrained=True) + weights_key = 'DEFAULT' + model = models.get_model('r2plus1d_18', weights=weights_key) else: model = torch.hub.load( self.model_def['repo'], @@ -111,6 +116,7 @@ def load_model(self) -> Dict[str, torch.nn.Module]: num_classes=self.model_def['num_classes'], pretrained=True, ) + model = model.to(self.device) model.eval() # save the pre-trained classifier for show_preds and replace it in the net with identity diff --git a/models/raft/raft_src/corr.py b/models/raft/raft_src/corr.py index e30e3f3..e914aee 100644 --- a/models/raft/raft_src/corr.py +++ b/models/raft/raft_src/corr.py @@ -36,7 +36,7 @@ def __call__(self, coords): corr = self.corr_pyramid[i] dx = torch.linspace(-r, r, 2 * r + 1) dy = torch.linspace(-r, r, 2 * r + 1) - delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) + delta = torch.stack(torch.meshgrid(dy, dx, indexing='ij'), axis=-1).to(coords.device) centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2 ** i delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) diff --git a/models/raft/raft_src/utils/utils.py b/models/raft/raft_src/utils/utils.py index e6f2467..2bc475b 100644 --- a/models/raft/raft_src/utils/utils.py +++ b/models/raft/raft_src/utils/utils.py @@ -73,7 +73,7 @@ def bilinear_sampler(img, coords, mode='bilinear', mask=False): def coords_grid(batch, ht, wd): - coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) + coords = torch.meshgrid(torch.arange(ht), torch.arange(wd), indexing='ij') coords = torch.stack(coords[::-1], dim=0).float() return coords[None].repeat(batch, 1, 1, 1) diff --git a/models/resnet/extract_resnet.py b/models/resnet/extract_resnet.py index c81b5ef..25a684c 100644 --- a/models/resnet/extract_resnet.py +++ b/models/resnet/extract_resnet.py @@ -24,16 +24,8 @@ def __init__(self, args: omegaconf.DictConfig) -> None: 
extraction_total=args.extraction_total, show_pred=args.show_pred, ) - self.transforms = transforms.Compose([ - transforms.ToPILImage(), - transforms.Resize(256), - transforms.CenterCrop(224), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) - ]) self.name2module = self.load_model() - def load_model(self) -> Dict[str, torch.nn.Module]: """Defines the models, loads checkpoints, sends them to the device. @@ -43,12 +35,14 @@ def load_model(self) -> Dict[str, torch.nn.Module]: Returns: Dict[str, torch.nn.Module]: model-agnostic dict holding modules for extraction and show_pred """ - try: - model = getattr(models, self.model_name) - except AttributeError: - raise NotImplementedError(f'Model {self.model_name} not found.') + # TODO: could be 'DEFAULT' to unify with other models from tv + weights_key = 'IMAGENET1K_V1' + model = models.get_model(self.model_name, weights=weights_key) + self.transforms = transforms.Compose([ + transforms.ToPILImage(), + models.get_model_weights(self.model_name)[weights_key].transforms(), + ]) - model = model(pretrained=True) model = model.to(self.device) model.eval() # save the pre-trained classifier for show_preds and replace it in the net with identity diff --git a/models/vggish/vggish_src/vggish_input.py b/models/vggish/vggish_src/vggish_input.py index 21e97c0..ec38e41 100644 --- a/models/vggish/vggish_src/vggish_input.py +++ b/models/vggish/vggish_src/vggish_input.py @@ -61,18 +61,13 @@ def waveform_to_examples(data, sample_rate, return_tensor=True): # Frame features into examples. features_sample_rate = 1.0 / vggish_params.STFT_HOP_LENGTH_SECONDS - example_window_length = int(round( - vggish_params.EXAMPLE_WINDOW_SECONDS * features_sample_rate)) - example_hop_length = int(round( - vggish_params.EXAMPLE_HOP_SECONDS * features_sample_rate)) - log_mel_examples = mel_features.frame( - log_mel, - window_length=example_window_length, - hop_length=example_hop_length) + example_window_length = int(round(vggish_params.EXAMPLE_WINDOW_SECONDS * features_sample_rate)) + example_hop_length = int(round(vggish_params.EXAMPLE_HOP_SECONDS * features_sample_rate)) + log_mel_examples = mel_features.frame(log_mel, window_length=example_window_length, + hop_length=example_hop_length) if return_tensor: - log_mel_examples = torch.tensor( - log_mel_examples, requires_grad=True)[:, None, :, :].float() + log_mel_examples = torch.tensor(log_mel_examples, requires_grad=True)[:, None, :, :].float() return log_mel_examples
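
Note on the torchvision changes above: `pretrained=True` and attribute lookup on `torchvision.models` are replaced with the model registry (`get_model` / `get_model_weights`), and `torch.meshgrid` now gets an explicit `indexing='ij'`. Below is a minimal sketch of both patterns, assuming torchvision >= 0.14; the model name, weight key, and tensor shapes are illustrative only and not taken from the repository code.

```python
# Sketch of the torchvision model-registry pattern adopted in extract_resnet.py /
# extract_r21d.py, plus the explicit meshgrid indexing used in the RAFT code.
# Assumes torchvision >= 0.14; 'resnet50' and the input shape are illustrative.
import torch
import torchvision.models as models

model_name = 'resnet50'
weights = models.get_model_weights(model_name)['IMAGENET1K_V1']  # or 'DEFAULT'
model = models.get_model(model_name, weights=weights).eval()

preprocess = weights.transforms()  # the weights object carries its own preprocessing
frame = torch.randint(0, 256, (1, 3, 256, 256), dtype=torch.uint8)  # fake RGB frame batch
with torch.no_grad():
    logits = model(preprocess(frame))
print(logits.shape)  # torch.Size([1, 1000])

# torch.meshgrid warns unless `indexing` is passed explicitly; 'ij' keeps the
# historical default behaviour that the RAFT correlation lookup was written against.
dy, dx = torch.meshgrid(torch.arange(3), torch.arange(5), indexing='ij')
print(dy.shape, dx.shape)  # torch.Size([3, 5]) torch.Size([3, 5])
```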