Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Filter NALs globally #1403

Merged
merged 7 commits into from
Jan 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 82 additions & 48 deletions alvr/server/cpp/alvr_server/ClientConnection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,65 +8,99 @@
#include "Utils.h"
#include "Settings.h"

static const uint8_t NAL_TYPE_SPS = 7;
static const char NAL_HEADER[] = {0x00, 0x00, 0x00, 0x01};

static const uint8_t H264_NAL_TYPE_SPS = 7;
static const uint8_t H265_NAL_TYPE_VPS = 32;

ClientConnection::ClientConnection() {
m_Statistics = std::make_shared<Statistics>();
static const uint8_t H264_NAL_TYPE_AUD = 9;
static const uint8_t H265_NAL_TYPE_AUD = 35;

ClientConnection::ClientConnection() {
m_Statistics = std::make_shared<Statistics>();
}

int findVPSSPS(const uint8_t *frameBuffer, int frameByteSize) {
int zeroes = 0;
int foundNals = 0;
for (int i = 0; i < frameByteSize; i++) {
if (frameBuffer[i] == 0) {
zeroes++;
} else if (frameBuffer[i] == 1) {
if (zeroes >= 2) {
foundNals++;
if (Settings::Instance().m_codec == ALVR_CODEC_H264 && foundNals >= 3) {
// Find end of SPS+PPS on H.264.
return i - 3;
} else if (Settings::Instance().m_codec == ALVR_CODEC_H265 && foundNals >= 4) {
// Find end of VPS+SPS+PPS on H.264.
return i - 3;
}
}
zeroes = 0;
} else {
zeroes = 0;
}
}
return -1;
/*
Sends the (VPS + )SPS + PPS video configuration headers from H.264 or H.265 stream as a sequence of NALs.
(VPS + )SPS + PPS have short size (8bytes + 28bytes in some environment), so we can
assume SPS + PPS is contained in first fragment.
*/
void sendHeaders(uint8_t **buf, int *len, int nalNum) {
uint8_t *b = *buf;
uint8_t *end = b + *len;

int headersLen = 0;
int foundHeaders = -1; // Offset by 1 header to find the length until the next header
while (b != end) {
if (b + sizeof(NAL_HEADER) <= end && memcmp(b, NAL_HEADER, sizeof(NAL_HEADER)) == 0) {
foundHeaders++;
if (foundHeaders == nalNum) {
break;
}
b += sizeof(NAL_HEADER);
headersLen += sizeof(NAL_HEADER);
}

b++;
headersLen++;
}
if (foundHeaders != nalNum) {
return;
}
InitializeDecoder((const unsigned char *)*buf, headersLen);

// move the cursor forward excluding config NALs
*buf = b;
*len -= headersLen;
}

void processH264Nals(uint8_t **buf, int *len) {
uint8_t *b = *buf;
int l = *len;
uint8_t nalType = b[4] & 0x1F;
nowrep marked this conversation as resolved.
Show resolved Hide resolved

if (nalType == H264_NAL_TYPE_AUD && l > sizeof(NAL_HEADER) * 2 + 2) {
b += sizeof(NAL_HEADER) + 2;
l -= sizeof(NAL_HEADER) + 2;
nalType = b[4] & 0x1F;
nowrep marked this conversation as resolved.
Show resolved Hide resolved
}
if (nalType == H264_NAL_TYPE_SPS) {
sendHeaders(&b, &l, 2); // 2 headers SPS and PPS
}
*buf = b;
*len = l;
Vixea marked this conversation as resolved.
Show resolved Hide resolved
}

void processH265Nals(uint8_t **buf, int *len) {
uint8_t *b = *buf;
int l = *len;
uint8_t nalType = (b[4] >> 1) & 0x3F;

if (nalType == H265_NAL_TYPE_AUD && l > sizeof(NAL_HEADER) * 2 + 3) {
b += sizeof(NAL_HEADER) + 3;
l -= sizeof(NAL_HEADER) + 3;
nalType = (b[4] >> 1) & 0x3F;
}
if (nalType == H265_NAL_TYPE_VPS) {
sendHeaders(&b, &l, 3); // 3 headers VPS, SPS and PPS
}
*buf = b;
*len = l;
}

void ClientConnection::SendVideo(uint8_t *buf, int len, uint64_t targetTimestampNs) {
// Report before the frame is packetized
ReportEncoded(targetTimestampNs);

uint8_t NALType;
if (Settings::Instance().m_codec == ALVR_CODEC_H264)
NALType = buf[4] & 0x1F;
else
NALType = (buf[4] >> 1) & 0x3F;

if ((Settings::Instance().m_codec == ALVR_CODEC_H264 && NALType == NAL_TYPE_SPS) ||
(Settings::Instance().m_codec == ALVR_CODEC_H265 && NALType == H265_NAL_TYPE_VPS)) {
// This frame contains (VPS + )SPS + PPS + IDR on NVENC H.264 (H.265) stream.
// (VPS + )SPS + PPS has short size (8bytes + 28bytes in some environment), so we can
// assume SPS + PPS is contained in first fragment.

int end = findVPSSPS(buf, len);
if (end == -1) {
// Invalid frame.
return;
}

InitializeDecoder((const unsigned char *)buf, end);
if (len < sizeof(NAL_HEADER)) {
return;
}

// move the cursor forward excluding config NALs
buf = &buf[end];
len = len - end;
int codec = Settings::Instance().m_codec;
if (codec == ALVR_CODEC_H264) {
processH264Nals(&buf, &len);
} else if (codec == ALVR_CODEC_H265) {
processH265Nals(&buf, &len);
}

VideoSend(targetTimestampNs, buf, len);
Expand Down
9 changes: 3 additions & 6 deletions alvr/server/cpp/platform/linux/CEncoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,6 @@ void CEncoder::Run() {

fprintf(stderr, "CEncoder starting to read present packets");
present_packet frame_info;
std::vector<uint8_t> encoded_data;
while (not m_exiting) {
read_latest(client, (char *)&frame_info, sizeof(frame_info), m_exiting);

Expand All @@ -250,9 +249,8 @@ void CEncoder::Run() {

static_assert(sizeof(frame_info.pose) == sizeof(vr::HmdMatrix34_t&));

encoded_data.clear();
uint64_t pts;
if (!encode_pipeline->GetEncoded(encoded_data, &pts)) {
alvr::FramePacket packet;
if (!encode_pipeline->GetEncoded(packet)) {
Error("Failed to get encoded data!");
continue;
}
Expand All @@ -279,10 +277,9 @@ void CEncoder::Run() {
ReportPresent(pose->targetTimestampNs, present_offset);
ReportComposed(pose->targetTimestampNs, composed_offset);

m_listener->SendVideo(encoded_data.data(), encoded_data.size(), pts);
m_listener->SendVideo(packet.data, packet.size, packet.pts);

m_listener->GetStatistics()->EncodeOutput();

}
}
catch (std::exception &e) {
Expand Down
74 changes: 12 additions & 62 deletions alvr/server/cpp/platform/linux/EncodePipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,59 +12,6 @@ extern "C" {
#include <libavcodec/avcodec.h>
}

namespace {

bool should_keep_nal_h264(const uint8_t * header_start)
{
uint8_t nal_type = (header_start[2] == 0 ? header_start[4] : header_start[3]) & 0x1F;
switch (nal_type)
{
case 6: // supplemental enhancement information
case 9: // access unit delimiter
return false;
default:
return true;
}
}

bool should_keep_nal_h265(const uint8_t * header_start)
{
uint8_t nal_type = ((header_start[2] == 0 ? header_start[4] : header_start[3]) >> 1) & 0x3F;
switch (nal_type)
{
case 35: // access unit delimiter
case 39: // supplemental enhancement information
return false;
default:
return true;
}
}

void filter_NAL(const uint8_t* input, size_t input_size, std::vector<uint8_t> &out)
{
if (input_size < 4)
return;
auto codec = Settings::Instance().m_codec;
std::array<uint8_t, 3> header = {{0, 0, 1}};
auto end = input + input_size;
auto header_start = input;
while (header_start != end)
{
auto next_header = std::search(header_start + 3, end, header.begin(), header.end());
if (next_header != end and next_header[-1] == 0)
{
next_header--;
}
if (codec == ALVR_CODEC_H264 and should_keep_nal_h264(header_start))
out.insert(out.end(), header_start, next_header);
if (codec == ALVR_CODEC_H265 and should_keep_nal_h265(header_start))
out.insert(out.end(), header_start, next_header);
header_start = next_header;
}
}

}

void alvr::EncodePipeline::SetBitrate(int64_t bitrate) {
encoder_ctx->bit_rate = bitrate;
encoder_ctx->rc_buffer_size = bitrate / Settings::Instance().m_refreshRate;
Expand Down Expand Up @@ -111,17 +58,20 @@ alvr::EncodePipeline::~EncodePipeline()
avcodec_free_context(&encoder_ctx);
}

bool alvr::EncodePipeline::GetEncoded(std::vector<uint8_t> &out, uint64_t *pts)
bool alvr::EncodePipeline::GetEncoded(FramePacket &packet)
{
AVPacket * enc_pkt = av_packet_alloc();
int err = avcodec_receive_packet(encoder_ctx, enc_pkt);
if (err == AVERROR(EAGAIN)) {
return false;
} else if (err) {
av_packet_free(&encoder_packet);
encoder_packet = av_packet_alloc();
int err = avcodec_receive_packet(encoder_ctx, encoder_packet);
if (err != 0) {
av_packet_free(&encoder_packet);
if (err == AVERROR(EAGAIN)) {
return false;
}
throw alvr::AvException("failed to encode", err);
}
filter_NAL(enc_pkt->data, enc_pkt->size, out);
*pts = enc_pkt->pts;
av_packet_free(&enc_pkt);
packet.data = encoder_packet->data;
packet.size = encoder_packet->size;
packet.pts = encoder_packet->pts;
return true;
}
10 changes: 9 additions & 1 deletion alvr/server/cpp/platform/linux/EncodePipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <vector>

extern "C" struct AVCodecContext;
extern "C" struct AVPacket;

class Renderer;

Expand All @@ -14,6 +15,12 @@ class VkFrame;
class VkFrameCtx;
class VkContext;

struct FramePacket {
uint8_t *data;
int size;
uint64_t pts;
};

class EncodePipeline
{
public:
Expand All @@ -25,13 +32,14 @@ class EncodePipeline
virtual ~EncodePipeline();

virtual void PushFrame(uint64_t targetTimestampNs, bool idr) = 0;
virtual bool GetEncoded(std::vector<uint8_t> & out, uint64_t *pts);
virtual bool GetEncoded(FramePacket &data);
virtual Timestamp GetTimestamp() { return timestamp; }

virtual void SetBitrate(int64_t bitrate);
static std::unique_ptr<EncodePipeline> Create(Renderer *render, VkContext &vk_ctx, VkFrame &input_frame, VkFrameCtx &vk_frame_ctx, uint32_t width, uint32_t height);
protected:
AVCodecContext *encoder_ctx = nullptr; //shall be initialized by child class
AVPacket *encoder_packet = NULL;
Timestamp timestamp = {};
};

Expand Down
20 changes: 8 additions & 12 deletions alvr/server/cpp/platform/linux/EncodePipelineAMF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -456,26 +456,27 @@ void EncodePipelineAMF::PushFrame(uint64_t targetTimestampNs, bool idr)
m_amfComponents.front()->SubmitInput(surface);
}

bool EncodePipelineAMF::GetEncoded(std::vector<uint8_t> &out, uint64_t *pts)
bool EncodePipelineAMF::GetEncoded(FramePacket &packet)
{
m_frameBuffer = NULL;
if (m_hasQueryTimeout) {
m_pipeline->Run();
} else {
uint32_t timeout = 4 * 1000; // 1 second
while (m_outBuffer.empty() && --timeout != 0) {
while (m_frameBuffer == NULL && --timeout != 0) {
std::this_thread::sleep_for(std::chrono::microseconds(250));
m_pipeline->Run();
}
}

if (m_outBuffer.empty()) {
if (m_frameBuffer == NULL) {
Error("Timed out waiting for encoder data");
return false;
}

out = m_outBuffer;
*pts = m_targetTimestampNs;
m_outBuffer.clear();
packet.data = reinterpret_cast<uint8_t *>(m_frameBuffer->GetNative());
packet.size = static_cast<int>(m_frameBuffer->GetSize());
packet.pts = m_targetTimestampNs;

uint64_t query;
VK_CHECK(vkGetQueryPoolResults(m_render->m_dev, m_queryPool, 0, 1, sizeof(uint64_t), &query, sizeof(uint64_t), VK_QUERY_RESULT_64_BIT));
Expand All @@ -499,12 +500,7 @@ void EncodePipelineAMF::SetBitrate(int64_t bitrate)

void EncodePipelineAMF::Receive(amf::AMFDataPtr data)
{
amf::AMFBufferPtr buffer(data); // query for buffer interface

char *p = reinterpret_cast<char*>(buffer->GetNative());
int length = static_cast<int>(buffer->GetSize());

m_outBuffer = std::vector<uint8_t>(p, p + length);
m_frameBuffer = amf::AMFBufferPtr(data); // query for buffer interface
}

void EncodePipelineAMF::ApplyFrameProperties(const amf::AMFSurfacePtr &surface, bool insertIDR)
Expand Down
4 changes: 2 additions & 2 deletions alvr/server/cpp/platform/linux/EncodePipelineAMF.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class EncodePipelineAMF : public EncodePipeline
~EncodePipelineAMF();

void PushFrame(uint64_t targetTimestampNs, bool idr) override;
bool GetEncoded(std::vector<uint8_t> &out, uint64_t *pts) override;
bool GetEncoded(FramePacket &packet) override;
void SetBitrate(int64_t bitrate) override;

private:
Expand Down Expand Up @@ -96,7 +96,7 @@ class EncodePipelineAMF : public EncodePipeline
int m_bitrateInMBits;

bool m_hasQueryTimeout = false;
std::vector<uint8_t> m_outBuffer;
amf::AMFBufferPtr m_frameBuffer;
uint64_t m_targetTimestampNs;
};

Expand Down