Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changed imresized to PIL image resize #81

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ __pycache__/
# Video and data folders
kitti_data/
model_data_keras*/
nfs_data/
PredNet-20200617T064538Z-001/
PredNet-DO-results/

# Distribution / packaging
.Python
Expand Down
1 change: 1 addition & 0 deletions GitHub
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"GitHub","provenance":[],"mount_file_id":"1OxEIirVv0ZF6hQOA3Y7-EmIeju8uDxoH","authorship_tag":"ABX9TyOjl3RxuGE6iNEfNzsivZox"},"kernelspec":{"name":"python3","display_name":"Python 3"}},"cells":[{"cell_type":"code","metadata":{"id":"7EvCibyDJR3o","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":34},"executionInfo":{"status":"ok","timestamp":1592364774354,"user_tz":-600,"elapsed":2571,"user":{"displayName":"Edward Ho","photoUrl":"","userId":"18197519276397766373"}},"outputId":"8ef250aa-09d3-43f1-b847-1a9550d5890f"},"source":["!ls"],"execution_count":1,"outputs":[{"output_type":"stream","text":["drive sample_data\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"CYcAolnAJUxA","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":34},"executionInfo":{"status":"ok","timestamp":1592364774355,"user_tz":-600,"elapsed":2555,"user":{"displayName":"Edward Ho","photoUrl":"","userId":"18197519276397766373"}},"outputId":"90dbbdf3-40ec-4073-c243-3f4b25eb2dee"},"source":["%cd drive/My Drive/PredNet/"],"execution_count":2,"outputs":[{"output_type":"stream","text":["/content/drive/My Drive/PredNet\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"mqydiaXBJ-Jz","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":34},"executionInfo":{"status":"ok","timestamp":1592364776352,"user_tz":-600,"elapsed":4516,"user":{"displayName":"Edward Ho","photoUrl":"","userId":"18197519276397766373"}},"outputId":"1f1ca4b9-d28a-4798-bc3f-9af9c458ffd5"},"source":["!git clone https://github.com/edwardmfho/prednet.git"],"execution_count":3,"outputs":[{"output_type":"stream","text":["fatal: destination path 'prednet' already exists and is not an empty directory.\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"hTFj6t0lKPOK","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":34},"executionInfo":{"status":"ok","timestamp":1592364778804,"user_tz":-600,"elapsed":6943,"user":{"displayName":"Edward Ho","photoUrl":"","userId":"18197519276397766373"}},"outputId":"a39cc145-fc86-47d6-f668-278508228a30"},"source":["!ls"],"execution_count":4,"outputs":[{"output_type":"stream","text":["GitHub\tprednet\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"2QwNJDovKUx8","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":34},"executionInfo":{"status":"ok","timestamp":1592364780918,"user_tz":-600,"elapsed":9029,"user":{"displayName":"Edward Ho","photoUrl":"","userId":"18197519276397766373"}},"outputId":"16a8b471-46fa-4b5b-a871-34df78fd2610"},"source":["!python process_kitti.py"],"execution_count":5,"outputs":[{"output_type":"stream","text":["python3: can't open file 'process_kitti.py': [Errno 2] No such file or directory\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"aLJ57TQMKjKq","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":34},"executionInfo":{"status":"ok","timestamp":1592364782668,"user_tz":-600,"elapsed":10761,"user":{"displayName":"Edward Ho","photoUrl":"","userId":"18197519276397766373"}},"outputId":"236d3ab0-802b-499c-8d1b-c217195798f6"},"source":["!ls"],"execution_count":6,"outputs":[{"output_type":"stream","text":["GitHub\tprednet\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"vzoOi47XKlb0","colab_type":"code","colab":{},"executionInfo":{"status":"ok","timestamp":1592364784114,"user_tz":-600,"elapsed":12175,"user":{"displayName":"Edward Ho","photoUrl":"","userId":"18197519276397766373"}}},"source":["!cd prednet"],"execution_count":7,"outputs":[]},{"cell_type":"code","metadata":{"id":"5gpCQwH5Knqj","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":34},"executionInfo":{"status":"ok","timestamp":1592364786035,"user_tz":-600,"elapsed":14066,"user":{"displayName":"Edward Ho","photoUrl":"","userId":"18197519276397766373"}},"outputId":"a5183976-dbb4-4271-e252-dbafea011c15"},"source":["!pwd"],"execution_count":8,"outputs":[{"output_type":"stream","text":["/content/drive/My Drive/PredNet\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"Xq9xpG0QKqpn","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":34},"executionInfo":{"status":"ok","timestamp":1592364788283,"user_tz":-600,"elapsed":16291,"user":{"displayName":"Edward Ho","photoUrl":"","userId":"18197519276397766373"}},"outputId":"259febba-23df-4a01-8c22-5a0a4b75ae3e"},"source":["!cd /prednet"],"execution_count":9,"outputs":[{"output_type":"stream","text":["/bin/bash: line 0: cd: /prednet: No such file or directory\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"AbUCgYDLKsb3","colab_type":"code","colab":{},"executionInfo":{"status":"ok","timestamp":1592364790216,"user_tz":-600,"elapsed":18185,"user":{"displayName":"Edward Ho","photoUrl":"","userId":"18197519276397766373"}}},"source":["!cd prednet"],"execution_count":10,"outputs":[]},{"cell_type":"code","metadata":{"id":"JUz7aFENKztc","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":34},"executionInfo":{"status":"ok","timestamp":1592364790217,"user_tz":-600,"elapsed":18168,"user":{"displayName":"Edward Ho","photoUrl":"","userId":"18197519276397766373"}},"outputId":"c5c7d00e-7bb9-46fb-aee9-42e01f1e806c"},"source":["pwd"],"execution_count":11,"outputs":[{"output_type":"execute_result","data":{"text/plain":["'/content/drive/My Drive/PredNet'"]},"metadata":{"tags":[]},"execution_count":11}]},{"cell_type":"code","metadata":{"id":"hTfEjVUJK1Tr","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":34},"executionInfo":{"status":"ok","timestamp":1592364792318,"user_tz":-600,"elapsed":20250,"user":{"displayName":"Edward Ho","photoUrl":"","userId":"18197519276397766373"}},"outputId":"abb7232a-c628-4366-ccc5-0879ca80fb75"},"source":["!python /prednet/process_kitti.py"],"execution_count":12,"outputs":[{"output_type":"stream","text":["python3: can't open file '/prednet/process_kitti.py': [Errno 2] No such file or directory\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"pdNHFEBhK9_h","colab_type":"code","colab":{},"executionInfo":{"status":"ok","timestamp":1592364792319,"user_tz":-600,"elapsed":20243,"user":{"displayName":"Edward Ho","photoUrl":"","userId":"18197519276397766373"}}},"source":[""],"execution_count":12,"outputs":[]}]}
Empty file modified License.txt
100755 → 100644
Empty file.
Empty file modified README.md
100755 → 100644
Empty file.
Empty file modified data_utils.py
100755 → 100644
Empty file.
22 changes: 22 additions & 0 deletions download_20bn_data.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
curl -c https://20bn-data-packages.s3.eu-west-1.amazonaws.com/something-something/v2/20bn-something-something-v2-00?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAJ2PXKOHYMBEGX4UA%2F20200624%2Feu-west-1%2Fs3%2Faws4_request&X-Amz-Date=20200624T141133Z&X-Amz-Expires=3600&X-Amz-SignedHeaders=host&X-Amz-Signature=42e35eacc0aa19106ac451f5e1507eb27af7009697eb7cb631e2f9c202390f04 --output 20bn-something-something-v2-00
curl -c https://20bn-data-packages.s3.eu-west-1.amazonaws.com/something-something/v2/20bn-something-something-v2-01?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAJ2PXKOHYMBEGX4UA%2F20200624%2Feu-west-1%2Fs3%2Faws4_request&X-Amz-Date=20200624T141133Z&X-Amz-Expires=3600&X-Amz-SignedHeaders=host&X-Amz-Signature=f3587eff473816ec00c85aa79dc5605d8dbf5ff14b6f0c379d20ce7bbd656546 --output 20bn-something-something-v2-01
curl -c https://20bn-data-packages.s3.eu-west-1.amazonaws.com/something-something/v2/20bn-something-something-v2-02?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAJ2PXKOHYMBEGX4UA%2F20200624%2Feu-west-1%2Fs3%2Faws4_request&X-Amz-Date=20200624T141133Z&X-Amz-Expires=3600&X-Amz-SignedHeaders=host&X-Amz-Signature=4efcfb2814977280ed6ec125af810b3a63ec844cc783f0d30dd977c2d092793d --output 20bn-something-something-v2-02
curl -c https://20bn-data-packages.s3.eu-west-1.amazonaws.com/something-something/v2/20bn-something-something-v2-03?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAJ2PXKOHYMBEGX4UA%2F20200624%2Feu-west-1%2Fs3%2Faws4_request&X-Amz-Date=20200624T141133Z&X-Amz-Expires=3600&X-Amz-SignedHeaders=host&X-Amz-Signature=6c4f7ed6264e7d16be22746f2d05329d08fe7c2b00fcf481ea0b8ac7e43f2dc9 --output 20bn-something-something-v2-03
curl -c https://20bn-data-packages.s3.eu-west-1.amazonaws.com/something-something/v2/20bn-something-something-v2-04?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAJ2PXKOHYMBEGX4UA%2F20200624%2Feu-west-1%2Fs3%2Faws4_request&X-Amz-Date=20200624T141133Z&X-Amz-Expires=3600&X-Amz-SignedHeaders=host&X-Amz-Signature=4d524b90f85c3cc265fd7df02a6c30b5d52511051f23b185e6cde67d4e9746af --output 20bn-something-something-v2-04
curl -c https://20bn-data-packages.s3.eu-west-1.amazonaws.com/something-something/v2/20bn-something-something-v2-05?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAJ2PXKOHYMBEGX4UA%2F20200624%2Feu-west-1%2Fs3%2Faws4_request&X-Amz-Date=20200624T141133Z&X-Amz-Expires=3600&X-Amz-SignedHeaders=host&X-Amz-Signature=deb49d0e6b7c20211f6a10783549e487ca9b0c0116e41543b9ac84fd5f735e10 --output 20bn-something-something-v2-05
curl -c https://20bn-data-packages.s3.eu-west-1.amazonaws.com/something-something/v2/20bn-something-something-v2-06?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAJ2PXKOHYMBEGX4UA%2F20200624%2Feu-west-1%2Fs3%2Faws4_request&X-Amz-Date=20200624T141133Z&X-Amz-Expires=3600&X-Amz-SignedHeaders=host&X-Amz-Signature=0cc3227629bd2506221d1938f12784714efcf4718c3e406715cb52c3d14c0cd8 --output 20bn-something-something-v2-06
curl -c https://20bn-data-packages.s3.eu-west-1.amazonaws.com/something-something/v2/20bn-something-something-v2-07?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAJ2PXKOHYMBEGX4UA%2F20200624%2Feu-west-1%2Fs3%2Faws4_request&X-Amz-Date=20200624T141133Z&X-Amz-Expires=3600&X-Amz-SignedHeaders=host&X-Amz-Signature=9f5f9011317179238d101d5596632524c5dd771074e3ec745aabdf9f1e4e2db1 --output 20bn-something-something-v2-07
curl -c https://20bn-data-packages.s3.eu-west-1.amazonaws.com/something-something/v2/20bn-something-something-v2-08?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAJ2PXKOHYMBEGX4UA%2F20200624%2Feu-west-1%2Fs3%2Faws4_request&X-Amz-Date=20200624T141133Z&X-Amz-Expires=3600&X-Amz-SignedHeaders=host&X-Amz-Signature=400d0ecbd013688f620e6162ad9f231b92b136c3e9d23756b190f890978d5f7d --output 20bn-something-something-v2-08
curl -c https://20bn-data-packages.s3.eu-west-1.amazonaws.com/something-something/v2/20bn-something-something-v2-09?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAJ2PXKOHYMBEGX4UA%2F20200624%2Feu-west-1%2Fs3%2Faws4_request&X-Amz-Date=20200624T141411Z&X-Amz-Expires=3600&X-Amz-SignedHeaders=host&X-Amz-Signature=b50f40ec2bcfc06b8530c35b85b0871f113c2b1ef659027b5dc1057553e1bed1 --output 20bn-something-something-v2-09
curl -c https://20bn-data-packages.s3.eu-west-1.amazonaws.com/something-something/v2/20bn-something-something-v2-10?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAJ2PXKOHYMBEGX4UA%2F20200624%2Feu-west-1%2Fs3%2Faws4_request&X-Amz-Date=20200624T141411Z&X-Amz-Expires=3600&X-Amz-SignedHeaders=host&X-Amz-Signature=c95ca1e2e40a9238c1d0e918f2be6f23d35aa3873985d7b9a2a42c829d9363e6 --output 20bn-something-something-v2-10
curl -c https://20bn-data-packages.s3.eu-west-1.amazonaws.com/something-something/v2/20bn-something-something-v2-11?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAJ2PXKOHYMBEGX4UA%2F20200624%2Feu-west-1%2Fs3%2Faws4_request&X-Amz-Date=20200624T141411Z&X-Amz-Expires=3600&X-Amz-SignedHeaders=host&X-Amz-Signature=fdd3b45f0f8cde0e9ba41944e027f139a6fd8477fb4a6acede36ce1466359955 --output 20bn-something-something-v2-11
curl -c https://20bn-data-packages.s3.eu-west-1.amazonaws.com/something-something/v2/20bn-something-something-v2-12?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAJ2PXKOHYMBEGX4UA%2F20200624%2Feu-west-1%2Fs3%2Faws4_request&X-Amz-Date=20200624T141411Z&X-Amz-Expires=3600&X-Amz-SignedHeaders=host&X-Amz-Signature=4949019d51e5c16823592fbccf1a6712c27b3e92f06ace0114a6f123d0c5ecec --output 20bn-something-something-v2-12
curl -c ttps://20bn-data-packages.s3.eu-west-1.amazonaws.com/something-something/v2/20bn-something-something-v2-13?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAJ2PXKOHYMBEGX4UA%2F20200624%2Feu-west-1%2Fs3%2Faws4_request&X-Amz-Date=20200624T141411Z&X-Amz-Expires=3600&X-Amz-SignedHeaders=host&X-Amz-Signature=2db571736b025fb8f3a49d5947a4b336a09b89323eaf16906a99d03f66bd932c --output 20bn-something-something-v2-13
curl -c https://20bn-data-packages.s3.eu-west-1.amazonaws.com/something-something/v2/20bn-something-something-v2-14?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAJ2PXKOHYMBEGX4UA%2F20200624%2Feu-west-1%2Fs3%2Faws4_request&X-Amz-Date=20200624T141411Z&X-Amz-Expires=3600&X-Amz-SignedHeaders=host&X-Amz-Signature=334d7b8f07d9c162cef6c4dc03ea9471d23d30101c89e4811f44d0152eb25a7d --output 20bn-something-something-v2-14
curl -c https://20bn-data-packages.s3.eu-west-1.amazonaws.com/something-something/v2/20bn-something-something-v2-15?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAJ2PXKOHYMBEGX4UA%2F20200624%2Feu-west-1%2Fs3%2Faws4_request&X-Amz-Date=20200624T141411Z&X-Amz-Expires=3600&X-Amz-SignedHeaders=host&X-Amz-Signature=12455968e41d8a736d37d1ef16c84a2ec9215e6dfa2ab55614f34875b27a25bf --output 20bn-something-something-v2-15
curl -c https://20bn-data-packages.s3.eu-west-1.amazonaws.com/something-something/v2/20bn-something-something-v2-16?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAJ2PXKOHYMBEGX4UA%2F20200624%2Feu-west-1%2Fs3%2Faws4_request&X-Amz-Date=20200624T141411Z&X-Amz-Expires=3600&X-Amz-SignedHeaders=host&X-Amz-Signature=bf4d2852e0159d61f6ee637251fa82c173bc32af7c4ca634bcdc0d6aa08f2a75 --output 20bn-something-something-v2-16
curl -c https://20bn-data-packages.s3.eu-west-1.amazonaws.com/something-something/v2/20bn-something-something-v2-17?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAJ2PXKOHYMBEGX4UA%2F20200624%2Feu-west-1%2Fs3%2Faws4_request&X-Amz-Date=20200624T141411Z&X-Amz-Expires=3600&X-Amz-SignedHeaders=host&X-Amz-Signature=8f3b59c1ef1bda14cbe5c772ac8522c8e6122667c567a0fcaba5f3b29f046715 --output 20bn-something-something-v2-17
curl -c https://20bn-data-packages.s3.eu-west-1.amazonaws.com/something-something/v2/20bn-something-something-v2-18?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAJ2PXKOHYMBEGX4UA%2F20200624%2Feu-west-1%2Fs3%2Faws4_request&X-Amz-Date=20200624T141411Z&X-Amz-Expires=3600&X-Amz-SignedHeaders=host&X-Amz-Signature=e057ac8f17d2d7486b80108e422bab5290e02360bb1dd3ff9ace26d35785e42d --output 20bn-something-something-v2-18
curl -c https://20bn-data-packages.s3.eu-west-1.amazonaws.com/something-something/v2/20bn-something-something-v2-19?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAJ2PXKOHYMBEGX4UA%2F20200624%2Feu-west-1%2Fs3%2Faws4_request&X-Amz-Date=20200624T141411Z&X-Amz-Expires=3600&X-Amz-SignedHeaders=host&X-Amz-Signature=bca3c648abfa7214925 --output 20bn-something-something-v2-19

cat 20bn-something-something-v2-?? | tar zx
Empty file modified download_data.sh
100755 → 100644
Empty file.
Empty file modified download_models.sh
100755 → 100644
Empty file.
Empty file modified environment.yml
100755 → 100644
Empty file.
Empty file modified keras_utils.py
100755 → 100644
Empty file.
Binary file added kitti_data - Shortcut.lnk
Binary file not shown.
Empty file modified kitti_evaluate.py
100755 → 100644
Empty file.
Empty file modified kitti_extrap_finetune.py
100755 → 100644
Empty file.
Empty file modified kitti_settings.py
100755 → 100644
Empty file.
7 changes: 7 additions & 0 deletions kitti_train.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,14 @@
history = model.fit_generator(train_generator, samples_per_epoch / batch_size, nb_epoch, callbacks=callbacks,
validation_data=val_generator, validation_steps=N_seq_val / batch_size)



if save_model:
json_string = model.to_json()
with open(json_file, "w") as f:
f.write(json_string)

loss_history = history.history["loss"]
numpy_loss_history = np.array(loss_history)
np.savetxt("loss_history_prednet.txt", numpy_loss_history, delimiter=",")

88 changes: 88 additions & 0 deletions kitti_train_prednet_plus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
'''
Train PredNet on KITTI sequences. (Geiger et al. 2013, http://www.cvlibs.net/datasets/kitti/)
'''

import os
import numpy as np
np.random.seed(123)
from six.moves import cPickle

from keras import backend as K
from keras.models import Model
from keras.layers import Input, Dense, Flatten
from keras.layers import LSTM
from keras.layers import TimeDistributed
from keras.callbacks import LearningRateScheduler, ModelCheckpoint
from keras.optimizers import Adam

from prednet-plus import PredNet
from data_utils import SequenceGenerator
from kitti_settings import *


save_model = True # if weights will be saved
weights_file = os.path.join(WEIGHTS_DIR, 'prednet_kitti_weights.hdf5') # where weights will be saved
json_file = os.path.join(WEIGHTS_DIR, 'prednet_kitti_model.json')

# Data files
train_file = os.path.join(DATA_DIR, 'X_train.hkl')
train_sources = os.path.join(DATA_DIR, 'sources_train.hkl')
val_file = os.path.join(DATA_DIR, 'X_val.hkl')
val_sources = os.path.join(DATA_DIR, 'sources_val.hkl')

# Training parameters
nb_epoch = 150
batch_size = 4
samples_per_epoch = 500
N_seq_val = 100 # number of sequences to use for validation

# Model parameters
n_channels, im_height, im_width = (3, 128, 160)
input_shape = (n_channels, im_height, im_width) if K.image_data_format() == 'channels_first' else (im_height, im_width, n_channels)
stack_sizes = (n_channels, 48, 96, 192)
R_stack_sizes = stack_sizes
A_filt_sizes = (3, 3, 3)
Ahat_filt_sizes = (3, 3, 3, 3)
R_filt_sizes = (3, 3, 3, 3)
layer_loss_weights = np.array([1., 0., 0., 0.]) # weighting for each layer in final loss; "L_0" model: [1, 0, 0, 0], "L_all": [1, 0.1, 0.1, 0.1]
layer_loss_weights = np.expand_dims(layer_loss_weights, 1)
nt = 10 # number of timesteps used for sequences in training
time_loss_weights = 1./ (nt - 1) * np.ones((nt,1)) # equally weight all timesteps except the first
time_loss_weights[0] = 0


prednet = PredNet(stack_sizes, R_stack_sizes,
A_filt_sizes, Ahat_filt_sizes, R_filt_sizes,
output_mode='error', return_sequences=True)

inputs = Input(shape=(nt,) + input_shape)
errors = prednet(inputs) # errors will be (batch_size, nt, nb_layers)
errors_by_time = TimeDistributed(Dense(1, trainable=False), weights=[layer_loss_weights, np.zeros(1)], trainable=False)(errors) # calculate weighted error by layer
errors_by_time = Flatten()(errors_by_time) # will be (batch_size, nt)
final_errors = Dense(1, weights=[time_loss_weights, np.zeros(1)], trainable=False)(errors_by_time) # weight errors by time
model = Model(inputs=inputs, outputs=final_errors)
model.compile(loss='mean_absolute_error', optimizer='adam')

train_generator = SequenceGenerator(train_file, train_sources, nt, batch_size=batch_size, shuffle=True)
val_generator = SequenceGenerator(val_file, val_sources, nt, batch_size=batch_size, N_seq=N_seq_val)

lr_schedule = lambda epoch: 0.001 if epoch < 75 else 0.0001 # start with lr of 0.001 and then drop to 0.0001 after 75 epochs
callbacks = [LearningRateScheduler(lr_schedule)]
if save_model:
if not os.path.exists(WEIGHTS_DIR): os.mkdir(WEIGHTS_DIR)
callbacks.append(ModelCheckpoint(filepath=weights_file, monitor='val_loss', save_best_only=True))

history = model.fit_generator(train_generator, samples_per_epoch / batch_size, nb_epoch, callbacks=callbacks,
validation_data=val_generator, validation_steps=N_seq_val / batch_size)



if save_model:
json_string = model.to_json()
with open(json_file, "w") as f:
f.write(json_string)

loss_history = history.history["loss"]
numpy_loss_history = np.array(loss_history)
np.savetxt("loss_history_prednet.txt", numpy_loss_history, delimiter=",")

10 changes: 10 additions & 0 deletions nfs_settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Where KITTI data will be saved if you run process_kitti.py
# If you directly download the processed data, change to the path of the data.
DATA_DIR = './nfs_data/'

# Where model weights and config will be saved if you run kitti_train.py
# If you directly download the trained weights, change to appropriate path.
WEIGHTS_DIR = './model_data_keras2/'

# Where results (prediction plots and evaluation file) will be saved.
RESULTS_SAVE_DIR = './nfs_results/'
Loading