Skip to content

Commit

Permalink
Fix conversions for 32bit size_t (#24)
Browse files Browse the repository at this point in the history
* Fix conversions for 32bit size_t

* Fix missing include

* Add gs::VarUint->std::size_t safe conversion
  • Loading branch information
RichLogan authored Mar 22, 2024
1 parent fe50144 commit 57266d8
Show file tree
Hide file tree
Showing 9 changed files with 231 additions and 72 deletions.
16 changes: 15 additions & 1 deletion include/gs_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,29 @@
#define GS_TYPES_H

#include <cstdint>
#include <limits>
#include <vector>
#include <variant>
#include <stdexcept>
#include <string>
#include <optional>

namespace gs
{
// Primitive types
struct VarUint { std::uint64_t value; };
struct VarUint
{
std::uint64_t value;
operator std::size_t() const
{
if (value > std::numeric_limits<std::size_t>::max())
{
throw std::overflow_error("VarUInt too large to convert to std::size_t");
}
return static_cast<std::size_t>(value);
}
};

struct VarInt { std::int64_t value; };
struct Float16 { float value; };
typedef float Float32;
Expand Down
2 changes: 1 addition & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ set_target_properties(gse
CXX_EXTENSIONS NO)
target_compile_options(gse PRIVATE
$<$<OR:$<CXX_COMPILER_ID:Clang>,$<CXX_COMPILER_ID:AppleClang>,$<CXX_COMPILER_ID:GNU>>: -Wpedantic -Wextra -Wall>
$<$<CXX_COMPILER_ID:MSVC>: >)
$<$<CXX_COMPILER_ID:MSVC>: /W4 /WX>)

if(WIN32)
target_link_libraries(gse PRIVATE ws2_32)
Expand Down
33 changes: 29 additions & 4 deletions src/gs_api_internal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#include <algorithm>
#include <cstddef>
#include <cstring>
#include <limits>
#include "gs_types.h"
#include "gs_api_internal.h"

Expand Down Expand Up @@ -1112,8 +1113,14 @@ int GSDeserializeObject(GS_Decoder_Context_Internal &context,
if (object.u.mesh1.num_vertices > 0)
{
// Allocate storage, storing the pointer for later deletion
const GS_VarUint vertices = object.u.mesh1.num_vertices;
if (vertices > std::numeric_limits<std::size_t>::max() / sizeof(GS_Loc1))
{
context.error = "Too many vertices to deserialize";
return -1;
}
std::uint8_t *p = new std::uint8_t[sizeof(GS_Loc1) *
object.u.mesh1.num_vertices];
static_cast<std::size_t>(vertices)];
context.allocations.push_back(p);

// Assign the pointer to the object
Expand All @@ -1130,8 +1137,14 @@ int GSDeserializeObject(GS_Decoder_Context_Internal &context,
if (object.u.mesh1.num_normals > 0)
{
// Allocate storage, storing the pointer for later deletion
const GS_VarUint normals = object.u.mesh1.num_normals;
if (normals > std::numeric_limits<std::size_t>::max() / sizeof(GS_Norm1))
{
context.error = "Too many normals to deserialize";
return -1;
}
std::uint8_t *p = new std::uint8_t[sizeof(GS_Norm1) *
object.u.mesh1.num_normals];
static_cast<std::size_t>(normals)];
context.allocations.push_back(p);

// Assign the pointer to the object
Expand All @@ -1148,8 +1161,14 @@ int GSDeserializeObject(GS_Decoder_Context_Internal &context,
if (object.u.mesh1.num_textures > 0)
{
// Allocate storage, storing the pointer for later deletion
const GS_VarUint textures = object.u.mesh1.num_textures;
if (textures > std::numeric_limits<std::size_t>::max() / sizeof(GS_TextureUV1))
{
context.error = "Too many textures to deserialize";
return -1;
}
std::uint8_t *p = new std::uint8_t[sizeof(GS_TextureUV1) *
object.u.mesh1.num_textures];
static_cast<std::size_t>(textures)];
context.allocations.push_back(p);

// Assign the pointer to the object
Expand All @@ -1166,8 +1185,14 @@ int GSDeserializeObject(GS_Decoder_Context_Internal &context,
if (object.u.mesh1.num_triangles > 0)
{
// Allocate storage, storing the pointer for later deletion
const GS_VarUint triangles = object.u.mesh1.num_triangles;
if (triangles > std::numeric_limits<std::size_t>::max() / sizeof(GS_VarUint))
{
context.error = "Too many triangles to deserialize";
return -1;
}
std::uint8_t *p = new std::uint8_t[sizeof(GS_VarUint) *
object.u.mesh1.num_triangles];
static_cast<std::size_t>(triangles)];
context.allocations.push_back(p);

// Assign the pointer to the object
Expand Down
99 changes: 50 additions & 49 deletions src/gs_decoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,14 +216,15 @@ std::size_t Decoder::Decode(DataBuffer &data_buffer, GSObject &value)
*/
std::size_t Decoder::Decode(DataBuffer &data_buffer, Head1 &value)
{
VarUint length;
VarUint extracted_length;
std::size_t read_length;
std::size_t length_field;

// Read the object length
length_field = read_length = Deserialize(data_buffer, length);
length_field = read_length = Deserialize(data_buffer, extracted_length);

if (!length.value) throw DecoderException("Invalid object length");
if (!extracted_length.value) throw DecoderException("Invalid object length");
const std::size_t length = extracted_length;

// Read all of the required fields (evaluation order matters)
read_length += Deserialize(data_buffer, value.id);
Expand All @@ -232,7 +233,7 @@ std::size_t Decoder::Decode(DataBuffer &data_buffer, Head1 &value)
read_length += Deserialize(data_buffer, value.rotation);

// Are optional elements present?
if ((read_length - length_field) < length.value)
if ((read_length - length_field) < length)
{
// Attempt to decode the HeadIPD1 object
GSObject object;
Expand All @@ -248,17 +249,17 @@ std::size_t Decoder::Decode(DataBuffer &data_buffer, Head1 &value)
value.ipd = std::get<HeadIPD1>(object);

// Discard any octets not understood
if ((read_length - length_field) < length.value)
if ((read_length - length_field) < length)
{
data_buffer.AdvanceReadLength(length.value - read_length);
data_buffer.AdvanceReadLength(length - read_length);

// Update the read_length
read_length += length.value - (read_length - length_field);
read_length += length - (read_length - length_field);
}
}

// Did we read more octets than we should have?
if ((read_length - length_field) > length.value)
if ((read_length - length_field) > length)
{
throw DecoderException("Encoded object length error");
}
Expand Down Expand Up @@ -289,14 +290,15 @@ std::size_t Decoder::Decode(DataBuffer &data_buffer, Head1 &value)
*/
std::size_t Decoder::Decode(DataBuffer &data_buffer, Hand1 &value)
{
VarUint length;
VarUint extracted_length;
std::size_t read_length;
std::size_t length_field;

// Read the object length
length_field = read_length = Deserialize(data_buffer, length);
length_field = read_length = Deserialize(data_buffer, extracted_length);

if (!length.value) throw DecoderException("Invalid object length");
if (!extracted_length.value) throw DecoderException("Invalid object length");
const std::size_t length = extracted_length;

// Read all of the required fields (evaluation order matters)
read_length += Deserialize(data_buffer, value.id);
Expand All @@ -306,17 +308,16 @@ std::size_t Decoder::Decode(DataBuffer &data_buffer, Hand1 &value)
read_length += Deserialize(data_buffer, value.rotation);

// Discard any octets not understood
if ((read_length - length_field) < length.value)
if ((read_length - length_field) < length)
{
data_buffer.AdvanceReadLength(length.value -
(read_length - length_field));
data_buffer.AdvanceReadLength(length - (read_length - length_field));

// Update the read_length
read_length += length.value - (read_length - length_field);
read_length += length - (read_length - length_field);
}

// Did we read more octets than we should have?
if ((read_length - length_field) > length.value)
if ((read_length - length_field) > length)
{
throw DecoderException("Encoded object length error");
}
Expand Down Expand Up @@ -347,14 +348,15 @@ std::size_t Decoder::Decode(DataBuffer &data_buffer, Hand1 &value)
*/
std::size_t Decoder::Decode(DataBuffer &data_buffer, Hand2 &value)
{
VarUint length;
VarUint extracted_length;
std::size_t read_length;
std::size_t length_field;

// Read the object length
length_field = read_length = Deserialize(data_buffer, length);
length_field = read_length = Deserialize(data_buffer, extracted_length);

if (!length.value) throw DecoderException("Invalid object length");
if (!extracted_length.value) throw DecoderException("Invalid object length");
const std::size_t length = extracted_length;

// Read all of the required fields (evaluation order matters)
read_length += Deserialize(data_buffer, value.id);
Expand All @@ -370,17 +372,16 @@ std::size_t Decoder::Decode(DataBuffer &data_buffer, Hand2 &value)
read_length += Deserialize(data_buffer, value.pinky);

// Discard any octets not understood
if ((read_length - length_field) < length.value)
if ((read_length - length_field) < length)
{
data_buffer.AdvanceReadLength(length.value -
(read_length - length_field));
data_buffer.AdvanceReadLength(length - (read_length - length_field));

// Update the read_length
read_length += length.value - (read_length - length_field);
read_length += length - (read_length - length_field);
}

// Did we read more octets than we should have?
if ((read_length - length_field) > length.value)
if ((read_length - length_field) > length)
{
throw DecoderException("Encoded object length error");
}
Expand Down Expand Up @@ -411,14 +412,14 @@ std::size_t Decoder::Decode(DataBuffer &data_buffer, Hand2 &value)
*/
std::size_t Decoder::Decode(DataBuffer &data_buffer, Mesh1 &value)
{
VarUint length;
VarUint extracted_length;
std::size_t read_length;
std::size_t length_field;

// Read the object length
length_field = read_length = Deserialize(data_buffer, length);

if (!length.value) throw DecoderException("Invalid object length");
length_field = read_length = Deserialize(data_buffer, extracted_length);
if (!extracted_length.value) throw DecoderException("Invalid object length");
const std::size_t length = extracted_length;

// Read all of the required fields (evaluation order matters)
read_length += Deserialize(data_buffer, value.id);
Expand All @@ -428,17 +429,16 @@ std::size_t Decoder::Decode(DataBuffer &data_buffer, Mesh1 &value)
read_length += Deserialize(data_buffer, value.triangles);

// Discard any octets not understood
if ((read_length - length_field) < length.value)
if ((read_length - length_field) < length)
{
data_buffer.AdvanceReadLength(length.value -
(read_length - length_field));
data_buffer.AdvanceReadLength(length - (read_length - length_field));

// Update the read_length
read_length += length.value - (read_length - length_field);
read_length += length - (read_length - length_field);
}

// Did we read more octets than we should have?
if ((read_length - length_field) > length.value)
if ((read_length - length_field) > length)
{
throw DecoderException("Encoded object length error");
}
Expand Down Expand Up @@ -469,30 +469,30 @@ std::size_t Decoder::Decode(DataBuffer &data_buffer, Mesh1 &value)
*/
std::size_t Decoder::Decode(DataBuffer &data_buffer, HeadIPD1 &value)
{
VarUint length;
VarUint extracted_length;
std::size_t read_length;
std::size_t length_field;

// Read the object length
length_field = read_length = Deserialize(data_buffer, length);
length_field = read_length = Deserialize(data_buffer, extracted_length);

if (!length.value) throw DecoderException("Invalid object length");
if (!extracted_length.value) throw DecoderException("Invalid object length");
const std::size_t length = extracted_length;

// Read all of the required fields (evaluation order matters)
read_length += Deserialize(data_buffer, value.ipd);

// Discard any octets not understood
if ((read_length - length_field) < length.value)
if ((read_length - length_field) < length)
{
data_buffer.AdvanceReadLength(length.value -
(read_length - length_field));
data_buffer.AdvanceReadLength(length - (read_length - length_field));

// Update the read_length
read_length += length.value - (read_length - length_field);
read_length += length - (read_length - length_field);
}

// Did we read more octets than we should have?
if ((read_length - length_field) > length.value)
if ((read_length - length_field) > length)
{
throw DecoderException("Encoded object length error");
}
Expand Down Expand Up @@ -549,14 +549,15 @@ std::size_t Decoder::Decode(DataBuffer &data_buffer, UnknownObject &value)
*/
std::size_t Decoder::Decode(DataBuffer &data_buffer, Object1 &value)
{
VarUint length;
VarUint extracted_length;
std::size_t read_length;
std::size_t length_field;

// Read the object length
length_field = read_length = Deserialize(data_buffer, length);
length_field = read_length = Deserialize(data_buffer, extracted_length);

if (!length.value) throw DecoderException("Invalid object length");
if (!extracted_length.value) throw DecoderException("Invalid object length");
const std::size_t length = extracted_length;

// Read all of the required fields (evaluation order matters)
read_length += Deserialize(data_buffer, value.id);
Expand All @@ -567,7 +568,7 @@ std::size_t Decoder::Decode(DataBuffer &data_buffer, Object1 &value)
read_length += Deserialize(data_buffer, value.active);

// Are optional elements present?
if ((read_length - length_field) < length.value)
if ((read_length - length_field) < length)
{
// Attempt to decode the Parent object
ObjectID parent;
Expand All @@ -577,17 +578,17 @@ std::size_t Decoder::Decode(DataBuffer &data_buffer, Object1 &value)
value.parent = parent;

// Discard any octets not understood
if ((read_length - length_field) < length.value)
if ((read_length - length_field) < length)
{
data_buffer.AdvanceReadLength(length.value - read_length);
data_buffer.AdvanceReadLength(length - read_length);

// Update the read_length
read_length += length.value - (read_length - length_field);
read_length += length - (read_length - length_field);
}
}

// Did we read more octets than we should have?
if ((read_length - length_field) > length.value)
if ((read_length - length_field) > length)
{
throw DecoderException("Encoded object length error");
}
Expand Down
Loading

0 comments on commit 57266d8

Please sign in to comment.