Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion lib/DXIL/DxilOperations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4299,6 +4299,7 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
#define A(_x) ArgTypes.emplace_back(_x)
#define RRT(_y) A(GetResRetType(_y))
#define CBRT(_y) A(GetCBufferRetType(_y))
#define VEC2(_y) A(VectorType::get(_y, 2))
#define VEC4(_y) A(GetStructVectorType(4, _y))
#define VEC9(_y) A(VectorType::get(_y, 9))

Expand Down Expand Up @@ -7058,7 +7059,7 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
case OpCode::RayQuery_CandidateTriangleObjectPosition:
case OpCode::RayQuery_CommittedTriangleObjectPosition:
case OpCode::HitObject_TriangleObjectPosition:
// These return <9 x float> vectors directly
// These return native vectors directly
return cast<VectorType>(Ty)->getElementType();
case OpCode::MatVecMul:
case OpCode::MatVecMulAdd:
Expand Down
15 changes: 8 additions & 7 deletions utils/hct/hctdb_instrhelp.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,7 @@ def print_opfunc_table(self):
"u64": "A(pI64);",
"u8": "A(pI8);",
"v": "A(pV);",
"$vec2": "VEC2(pETy);",
"$vec4": "VEC4(pETy);",
"$vec9": "VEC9(pETy);",
"SamplePos": "A(pPos);",
Expand Down Expand Up @@ -686,7 +687,7 @@ def print_opfunc_oload_type(self):
# grouped by the set of overload parameter indices.
extended_dict = collections.OrderedDict()
struct_list = []
vec9_list = [] # For $vec9 operations that return native vectors
native_vec_list = [] # For vec operations that return native vectors
extended_list = []

for instr in self.db.get_dxil_ops():
Expand All @@ -709,9 +710,9 @@ def print_opfunc_oload_type(self):
continue

if ret_ty.startswith(vec_ty):
# $vec9 returns native <9 x float> vectors, not struct wrappers
if ret_ty == "$vec9":
vec9_list.append(instr.name)
# $vecX returns native vectors, not struct wrappers
if ret_ty in ["$vec2", "$vec9"]:
native_vec_list.append(instr.name)
else:
struct_list.append(instr.name)
continue
Expand Down Expand Up @@ -831,11 +832,11 @@ def print_opfunc_oload_type(self):
print(line)

# Generate code for $vec9 operations (native <9 x float> vectors)
if vec9_list:
if native_vec_list:
line = ""
for opcode in vec9_list:
for opcode in native_vec_list:
line = line + "case OpCode::{name}".format(name=opcode + ":\n")
line = line + " // These return <9 x float> vectors directly\n"
line = line + " // These return native vectors directly\n"
line = line + " return cast<VectorType>(Ty)->getElementType();"
print(line)

Expand Down