Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add E2E inference script #5

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions data/dataset_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,9 @@ def config_eval_dataloader(args):
val_loader = DataLoader(val_dataset, num_workers=8, batch_size=1, shuffle=False, pin_memory=True, drop_last=False)

elif args.dataset == "example": # No annotations (annotation path) given
val_data_dirs = ["/path/to/example_dataset/FlowImages_gap1/",
"/path/to/example_dataset/JPEGImages/"]
val_seq = None
val_data_dirs = [f"{args.flow_output}",
f"{args.img_output}"]
val_seq = [args.name]
val_dataset = Example_eval_dataset(data_dirs=val_data_dirs, seqs=val_seq, ref_sam=ref_sam,
dataset=args.dataset, flow_gaps=flow_gaps, num_gridside=args.num_gridside)
val_loader = DataLoader(val_dataset, num_workers=8, batch_size=1, shuffle=False, pin_memory=True, drop_last=False)
Expand Down
21 changes: 19 additions & 2 deletions evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,24 @@ def eval(args, val_loader, flowsam):
choices=['dvs17', 'dvs17m', 'dvs16', 'ytvos', 'example'],
help="evaluation datasets",
)

parser.add_argument(
'--flow_output',
type=str,
default="output/flow/FlowImages_gap1/sample",
help="flow frame gaps, a string without spacing",
)
parser.add_argument(
'--img_output',
type=str,
default="output/images/sample",
help="flow frame gaps, a string without spacing",
)
parser.add_argument(
'--name',
type=str,
default="sample",
help="flow frame gaps, a string without spacing"
)
# Output configuration
parser.add_argument(
'--max_obj',
Expand All @@ -236,7 +253,7 @@ def eval(args, val_loader, flowsam):
'--save_path',
default=None,
help="path to save masks",
)
)

args = parser.parse_args()

Expand Down
1 change: 1 addition & 0 deletions flow/predict.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys
sys.path.append('core')
sys.path.append('flow/core')

import os
import cv2
Expand Down
163 changes: 163 additions & 0 deletions inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
import os
import glob as gb
import argparse
import cv2
import os
import requests
from PIL import Image

def extract_frames(video_path, output_folder):
# Create the output folder if it doesn't exist
os.makedirs(output_folder, exist_ok=True)

if video_path.endswith(".gif"):
with Image.open(video_path) as img:
while True:
try:
frame_count = img.tell()
frame_img = img.convert("RGB")
frame_img.save(os.path.join(output_folder, f"{frame_count+1:05d}.jpg"))
img.seek(frame_count + 1)
except EOFError:
break
else:
# Open the video file
cap = cv2.VideoCapture(video_path)

# Variable to keep track of frame count
frame_count = 0

# Read frames until there are no more
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break

# Save the frame as an image
frame_path = os.path.join(output_folder, f"{frame_count+1:05d}.jpg")
cv2.imwrite(frame_path, frame)

frame_count += 1

# Release the video capture object
cap.release()

def extract_flow(rgb_path, flow_output_path):
gap = [1, 2]
reverse = [0, 1]
batch_size = 4

folder = gb.glob(os.path.join(rgb_path, '*'))
for r in reverse:
for g in gap:
for f in folder:
print('===> Running {}, gap {}'.format(f, g))
mode = 'flow/raft-things.pth' # model
if r==1:
raw_outroot = flow_output_path + '/Flows_gap-{}/'.format(g) # where to raw flow
outroot = flow_output_path + '/FlowImages_gap-{}/'.format(g) # where to save the image flow
elif r==0:
raw_outroot = flow_output_path + '/Flows_gap{}/'.format(g) # where to raw flow
outroot = flow_output_path + '/FlowImages_gap{}/'.format(g) # where to save the image flow

os.system("python flow/predict.py "
"--gap {} --mode {} --path {} --batch_size {} "
"--outroot {} --reverse {} --raw_outroot {}".format(g, mode, f, batch_size, outroot, r, raw_outroot))

def create_video_from_images(input_folder, output_video_path, fps):
# Get the list of image files in the input folder
image_files = [os.path.join(input_folder, file) for file in os.listdir(input_folder) if file.endswith('.png')]

# Sort the image files by name
image_files.sort()

# Get the dimensions of the first image to set the video size
first_image = cv2.imread(image_files[0])
height, width, _ = first_image.shape

# Define the codec and create VideoWriter object
fourcc = cv2.VideoWriter_fourcc(*'mp4v') # Use appropriate codec based on the output video format
out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))

# Iterate over each image and add it to the video
for image_file in image_files:
img = cv2.imread(image_file)
out.write(img)

# Release VideoWriter object
out.release()

def download_weight(filename, url):
# Check if the file exists locally
if not os.path.exists(filename):
print(f"File '{filename}' not found locally. Proceeding with download.")

# Download the file
response = requests.get(url)
with open(filename, 'wb') as f:
f.write(response.content)

print("File downloaded successfully!")
else:
print(f"File '{filename}' already exists locally. No need to download.")

def inference(args):
"""
User should change the configuration path to appropriate path.
Install segment-anything
"""
ckpt = 'frame_level_flowpsam_vitbvith_train_on_oclrsyn_dvs17m.pth'
rgb_encoder_ckpt_path = 'sam_vit_h_4b8939.pth'
download_weight(rgb_encoder_ckpt_path, 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth')
flow_encoder_ckpt_path = 'sam_vit_b_01ec64.pth'
download_weight(flow_encoder_ckpt_path, 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth')
# Split the path into directory and filename
directory, filename = os.path.split(args.video_output_path)
flow_path = os.path.join(args.flow_output_path, f'FlowImages_gap1')
os.system("python evaluation.py "
"--model flowpsam --ckpt {} --rgb_encoder_ckpt_path {} --flow_encoder_ckpt_path {} --flow_gaps 1,-1,2,-2 --num_gridside 20 "
"--dataset example --flow_output {} --img_output {} --name {} --max_obj 3 --save_path {}".format(
ckpt, rgb_encoder_ckpt_path, flow_encoder_ckpt_path, \
flow_path, directory, filename, args.flowsam_output_path)
)

"""
python inference.py --video_file_path sample.mp4 --video_output_path output/images/sample --extract_frames --flow_output_path output/flow --extract_flow --visualize_flow --run_flowsam --flowsam_output_path output --visualize_output
python inference.py --video_file_path siren.mp4 --video_output_path output/images/siren --extract_frames --flow_output_path output/flow --extract_flow --visualize_flow --run_flowsam --flowsam_output_path output --visualize_output
python inference.py --video_file_path highway.mp4 --video_output_path output/images/highway --extract_frames --flow_output_path output/flow --extract_flow --visualize_flow --run_flowsam --flowsam_output_path output --visualize_output
python inference.py --video_file_path bird.gif --video_output_path output/images/bird --extract_frames --flow_output_path output/flow --extract_flow --visualize_flow --run_flowsam --flowsam_output_path output --visualize_output
"""
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--video_file_path', type=str, help="restore checkpoint")
parser.add_argument('--video_output_path', type=str, help="restore checkpoint")
parser.add_argument('--extract_frames', action='store_true', help='convert video file to image file folder')

parser.add_argument('--flow_output_path', type=str, help="restore checkpoint")
parser.add_argument('--extract_flow', action='store_true', help='whether to run flow ')
parser.add_argument('--visualize_flow', action='store_true', help='whether to run flow ')

parser.add_argument('--flowsam_output_path', type=str, help="restore checkpoint")
parser.add_argument('--run_flowsam', action='store_true', help='whether to run flow ')
parser.add_argument('--visualize_output', action='store_true', help='whether to run flow ')
args = parser.parse_args()

# Split the path into directory and filename
directory, filename = os.path.split(args.video_output_path)

if args.extract_frames:
extract_frames(args.video_file_path, args.video_output_path)

if args.extract_flow:
extract_flow(directory, args.flow_output_path)
# (Optional) For debug and visualization purpose
if args.visualize_flow:
flow_path = os.path.join(args.flow_output_path, f'FlowImages_gap-1/{filename}')
create_video_from_images(flow_path, f'{args.flow_output_path}/flow.mp4', fps=30)

if args.run_flowsam:
inference(args=args)

if args.visualize_output:
output_path = os.path.join(args.flowsam_output_path, f"nonhung/{filename}")
create_video_from_images(output_path, 'output.mp4', fps=30)