From 1ed847f43804efaac09e67ebdd86783d801f7b62 Mon Sep 17 00:00:00 2001 From: Marius Hillenbrand Date: Wed, 8 Dec 2021 18:01:26 +0100 Subject: [PATCH] Fix endianness of string literals (#4622) * Fix endianness of string literals To get correct and consistent encoding and decoding of string literals on big-endian platforms, use spvtools::utils::MakeString and MakeVector (or wrapper functions) consistently for handling string literals. - add a variant of MakeVector that encodes a string literal into an existing vector of words - add variants of MakeString - add a wrapper spvDecodeLiteralStringOperand in source/ - fix the wrapper Operand::AsString to use MakeString (source/opt) - remove Operand::AsCString, which was broken and unused - add a variant of GetOperandAs for string literals (source/val) ... and apply those wrappers throughout the code. Fixes #149 * Extend round-trip test for StringLiterals to flip byte order In the encoding/decoding round-trip tests for string literals, include a case that flips the byte order within each word after encoding and then checks for successful decoding. That is, on a little-endian host, flip to big-endian byte order and then decode, and vice versa. * BinaryParseTest.InstructionWithStringOperand: also flip byte order Test binary parsing of string operands with both the host's byte order and the reversed byte order.
--- source/binary.cpp | 35 ++++++----- source/binary.h | 7 +++ source/disassemble.cpp | 12 ++-- source/extensions.cpp | 4 +- source/link/linker.cpp | 20 +++--- source/name_mapper.cpp | 11 ++-- source/opt/aggressive_dead_code_elim_pass.cpp | 12 ++-- source/opt/amd_ext_to_khr.cpp | 6 +- source/opt/feature_manager.cpp | 3 +- source/opt/graphics_robust_access_pass.cpp | 10 +-- source/opt/inst_debug_printf_pass.cpp | 12 ++-- source/opt/instruction.h | 11 ++-- source/opt/ir_context.cpp | 11 ++-- source/opt/ir_context.h | 13 +--- .../opt/local_access_chain_convert_pass.cpp | 12 ++-- source/opt/local_single_block_elim_pass.cpp | 12 ++-- source/opt/local_single_store_elim_pass.cpp | 12 ++-- source/opt/module.cpp | 4 +- source/opt/remove_duplicates_pass.cpp | 5 +- source/opt/replace_invalid_opc.cpp | 5 +- source/opt/strip_debug_info_pass.cpp | 13 ++-- source/opt/strip_reflect_info_pass.cpp | 15 +++-- source/opt/type_manager.cpp | 9 +-- source/opt/upgrade_memory_model.cpp | 8 +-- source/text_handler.cpp | 11 +--- source/util/string_utils.h | 63 +++++++++++++++---- source/val/instruction.cpp | 10 +++ source/val/instruction.h | 3 + source/val/validate.cpp | 9 +-- source/val/validate_decorations.cpp | 6 +- source/val/validate_extensions.cpp | 11 ++-- source/val/validation_state.cpp | 6 +- test/binary_parse_test.cpp | 53 +++++++--------- test/binary_to_text.literal_test.cpp | 17 ++++- test/opt/instruction_test.cpp | 6 -- test/test_fixture.h | 19 +++++- 36 files changed, 249 insertions(+), 227 deletions(-) diff --git a/source/binary.cpp b/source/binary.cpp index 93d5da7aba..48a94f1eed 100644 --- a/source/binary.cpp +++ b/source/binary.cpp @@ -33,6 +33,7 @@ #include "source/operand.h" #include "source/spirv_constant.h" #include "source/spirv_endian.h" +#include "source/util/string_utils.h" spv_result_t spvBinaryHeaderGet(const spv_const_binary binary, const spv_endianness_t endian, @@ -62,6 +63,15 @@ spv_result_t spvBinaryHeaderGet(const spv_const_binary binary, return SPV_SUCCESS; 
} +std::string spvDecodeLiteralStringOperand(const spv_parsed_instruction_t& inst, + const uint16_t operand_index) { + assert(operand_index < inst.num_operands); + const spv_parsed_operand_t& operand = inst.operands[operand_index]; + + return spvtools::utils::MakeString(inst.words + operand.offset, + operand.num_words); +} + namespace { // A SPIR-V binary parser. A parser instance communicates detailed parse @@ -577,27 +587,18 @@ spv_result_t Parser::parseOperand(size_t inst_offset, case SPV_OPERAND_TYPE_LITERAL_STRING: case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING: { - convert_operand_endianness = false; - const char* string = - reinterpret_cast(_.words + _.word_index); - // Compute the length of the string, but make sure we don't run off the - // end of the input. - const size_t remaining_input_bytes = - sizeof(uint32_t) * (_.num_words - _.word_index); - const size_t string_num_content_bytes = - spv_strnlen_s(string, remaining_input_bytes); - // If there was no terminating null byte, then that's an end-of-input - // error. - if (string_num_content_bytes == remaining_input_bytes) + const size_t max_words = _.num_words - _.word_index; + std::string string = + spvtools::utils::MakeString(_.words + _.word_index, max_words, false); + + if (string.length() == max_words * 4) return exhaustedInputDiagnostic(inst_offset, opcode, type); - // Account for null in the word length, so add 1 for null, then add 3 to - // make sure we round up. The following is equivalent to: - // (string_num_content_bytes + 1 + 3) / 4 - const size_t string_num_words = string_num_content_bytes / 4 + 1; + // Make sure we can record the word count without overflow. // // This error can't currently be triggered because of validity // checks elsewhere. 
+ const size_t string_num_words = string.length() / 4 + 1; if (string_num_words > std::numeric_limits::max()) { return diagnostic() << "Literal string is longer than " << std::numeric_limits::max() @@ -611,7 +612,7 @@ spv_result_t Parser::parseOperand(size_t inst_offset, // There is only one string literal argument to OpExtInstImport, // so it's sufficient to guard this just on the opcode. const spv_ext_inst_type_t ext_inst_type = - spvExtInstImportTypeGet(string); + spvExtInstImportTypeGet(string.c_str()); if (SPV_EXT_INST_TYPE_NONE == ext_inst_type) { return diagnostic() << "Invalid extended instruction import '" << string << "'"; diff --git a/source/binary.h b/source/binary.h index 66d24c7e41..eb3beacac3 100644 --- a/source/binary.h +++ b/source/binary.h @@ -15,6 +15,8 @@ #ifndef SOURCE_BINARY_H_ #define SOURCE_BINARY_H_ +#include + #include "source/spirv_definition.h" #include "spirv-tools/libspirv.h" @@ -33,4 +35,9 @@ spv_result_t spvBinaryHeaderGet(const spv_const_binary binary, // replacement for C11's strnlen_s which might not exist in all environments. size_t spv_strnlen_s(const char* str, size_t strsz); +// Decode the string literal operand with index operand_index from instruction +// inst. +std::string spvDecodeLiteralStringOperand(const spv_parsed_instruction_t& inst, + const uint16_t operand_index); + #endif // SOURCE_BINARY_H_ diff --git a/source/disassemble.cpp b/source/disassemble.cpp index c553988f38..250c2bf918 100644 --- a/source/disassemble.cpp +++ b/source/disassemble.cpp @@ -283,13 +283,11 @@ void Disassembler::EmitOperand(const spv_parsed_instruction_t& inst, case SPV_OPERAND_TYPE_LITERAL_STRING: { stream_ << "\""; SetGreen(); - // Strings are always little-endian, and null-terminated. - // Write out the characters, escaping as needed, and without copying - // the entire string. 
- auto c_str = reinterpret_cast(inst.words + operand.offset); - for (auto p = c_str; *p; ++p) { - if (*p == '"' || *p == '\\') stream_ << '\\'; - stream_ << *p; + + std::string str = spvDecodeLiteralStringOperand(inst, operand_index); + for (char const& c : str) { + if (c == '"' || c == '\\') stream_ << '\\'; + stream_ << c; } ResetColor(); stream_ << '"'; diff --git a/source/extensions.cpp b/source/extensions.cpp index a94db273e7..049a3ad10a 100644 --- a/source/extensions.cpp +++ b/source/extensions.cpp @@ -18,6 +18,7 @@ #include #include +#include "source/binary.h" #include "source/enum_string_mapping.h" namespace spvtools { @@ -30,8 +31,9 @@ std::string GetExtensionString(const spv_parsed_instruction_t* inst) { const auto& operand = inst->operands[0]; assert(operand.type == SPV_OPERAND_TYPE_LITERAL_STRING); assert(inst->num_words > operand.offset); + (void)operand; /* No unused variables in release builds. */ - return reinterpret_cast(inst->words + operand.offset); + return spvDecodeLiteralStringOperand(*inst, 0); } std::string ExtensionSetToString(const ExtensionSet& extensions) { diff --git a/source/link/linker.cpp b/source/link/linker.cpp index c4bfc763c7..21cada709c 100644 --- a/source/link/linker.cpp +++ b/source/link/linker.cpp @@ -37,6 +37,7 @@ #include "source/spirv_constant.h" #include "source/spirv_target_env.h" #include "source/util/make_unique.h" +#include "source/util/string_utils.h" #include "spirv-tools/libspirv.hpp" namespace spvtools { @@ -282,16 +283,15 @@ spv_result_t MergeModules(const MessageConsumer& consumer, memory_model_inst->Clone(linked_context))); } while (false); - std::vector> entry_points; + std::vector> entry_points; for (const auto& module : input_modules) for (const auto& inst : module->entry_points()) { const uint32_t model = inst.GetSingleWordInOperand(0); - const char* const name = - reinterpret_cast(inst.GetInOperand(2).words.data()); + const std::string name = inst.GetInOperand(2).AsString(); const auto i = std::find_if( 
entry_points.begin(), entry_points.end(), - [model, name](const std::pair& v) { - return v.first == model && strcmp(name, v.second) == 0; + [model, name](const std::pair& v) { + return v.first == model && v.second == name; }); if (i != entry_points.end()) { spv_operand_desc desc = nullptr; @@ -334,11 +334,8 @@ spv_result_t MergeModules(const MessageConsumer& consumer, // OpModuleProcessed instruction about the linking step. if (linked_module->version() >= 0x10100) { const std::string processed_string("Linked by SPIR-V Tools Linker"); - const auto num_chars = processed_string.size(); - // Compute num words, accommodate the terminating null character. - const auto num_words = (num_chars + 1 + 3) / 4; - std::vector processed_words(num_words, 0u); - std::memcpy(processed_words.data(), processed_string.data(), num_chars); + std::vector processed_words = + spvtools::utils::MakeVector(processed_string); linked_module->AddDebug3Inst(std::unique_ptr( new Instruction(linked_context, SpvOpModuleProcessed, 0u, 0u, {{SPV_OPERAND_TYPE_LITERAL_STRING, processed_words}}))); @@ -414,8 +411,7 @@ spv_result_t GetImportExportPairs(const MessageConsumer& consumer, const uint32_t type = decoration.GetSingleWordInOperand(3u); LinkageSymbolInfo symbol_info; - symbol_info.name = - reinterpret_cast(decoration.GetInOperand(2u).words.data()); + symbol_info.name = decoration.GetInOperand(2u).AsString(); symbol_info.id = id; symbol_info.type_id = 0u; diff --git a/source/name_mapper.cpp b/source/name_mapper.cpp index eb08f8fed3..3b31d33a81 100644 --- a/source/name_mapper.cpp +++ b/source/name_mapper.cpp @@ -22,10 +22,10 @@ #include #include -#include "spirv-tools/libspirv.h" - +#include "source/binary.h" #include "source/latest_version_spirv_header.h" #include "source/parsed_operand.h" +#include "spirv-tools/libspirv.h" namespace spvtools { namespace { @@ -172,7 +172,7 @@ spv_result_t FriendlyNameMapper::ParseInstruction( const auto result_id = inst.result_id; switch (inst.opcode) { case 
SpvOpName: - SaveName(inst.words[1], reinterpret_cast(inst.words + 2)); + SaveName(inst.words[1], spvDecodeLiteralStringOperand(inst, 1)); break; case SpvOpDecorate: // Decorations come after OpName. So OpName will take precedence over @@ -274,9 +274,8 @@ spv_result_t FriendlyNameMapper::ParseInstruction( SaveName(result_id, "Queue"); break; case SpvOpTypeOpaque: - SaveName(result_id, - std::string("Opaque_") + - Sanitize(reinterpret_cast(inst.words + 2))); + SaveName(result_id, std::string("Opaque_") + + Sanitize(spvDecodeLiteralStringOperand(inst, 1))); break; case SpvOpTypePipeStorage: SaveName(result_id, "PipeStorage"); diff --git a/source/opt/aggressive_dead_code_elim_pass.cpp b/source/opt/aggressive_dead_code_elim_pass.cpp index c2ed2f87cc..9827c535a6 100644 --- a/source/opt/aggressive_dead_code_elim_pass.cpp +++ b/source/opt/aggressive_dead_code_elim_pass.cpp @@ -27,6 +27,7 @@ #include "source/opt/iterator.h" #include "source/opt/reflect.h" #include "source/spirv_constant.h" +#include "source/util/string_utils.h" namespace spvtools { namespace opt { @@ -146,8 +147,7 @@ void AggressiveDCEPass::AddStores(Function* func, uint32_t ptrId) { bool AggressiveDCEPass::AllExtensionsSupported() const { // If any extension not in allowlist, return false for (auto& ei : get_module()->extensions()) { - const char* extName = - reinterpret_cast(&ei.GetInOperand(0).words[0]); + const std::string extName = ei.GetInOperand(0).AsString(); if (extensions_allowlist_.find(extName) == extensions_allowlist_.end()) return false; } @@ -156,11 +156,9 @@ bool AggressiveDCEPass::AllExtensionsSupported() const { for (auto& inst : context()->module()->ext_inst_imports()) { assert(inst.opcode() == SpvOpExtInstImport && "Expecting an import of an extension's instruction set."); - const char* extension_name = - reinterpret_cast(&inst.GetInOperand(0).words[0]); - if (0 == std::strncmp(extension_name, "NonSemantic.", 12) && - 0 != std::strncmp(extension_name, "NonSemantic.Shader.DebugInfo.100", 
- 32)) { + const std::string extension_name = inst.GetInOperand(0).AsString(); + if (spvtools::utils::starts_with(extension_name, "NonSemantic.") && + extension_name != "NonSemantic.Shader.DebugInfo.100") { return false; } } diff --git a/source/opt/amd_ext_to_khr.cpp b/source/opt/amd_ext_to_khr.cpp index d46d24379d..dd9bafda32 100644 --- a/source/opt/amd_ext_to_khr.cpp +++ b/source/opt/amd_ext_to_khr.cpp @@ -935,8 +935,7 @@ Pass::Status AmdExtensionToKhrPass::Process() { std::vector to_be_killed; for (Instruction& inst : context()->module()->extensions()) { if (inst.opcode() == SpvOpExtension) { - if (ext_to_remove.count(reinterpret_cast( - &(inst.GetInOperand(0).words[0]))) != 0) { + if (ext_to_remove.count(inst.GetInOperand(0).AsString()) != 0) { to_be_killed.push_back(&inst); } } @@ -944,8 +943,7 @@ Pass::Status AmdExtensionToKhrPass::Process() { for (Instruction& inst : context()->ext_inst_imports()) { if (inst.opcode() == SpvOpExtInstImport) { - if (ext_to_remove.count(reinterpret_cast( - &(inst.GetInOperand(0).words[0]))) != 0) { + if (ext_to_remove.count(inst.GetInOperand(0).AsString()) != 0) { to_be_killed.push_back(&inst); } } diff --git a/source/opt/feature_manager.cpp b/source/opt/feature_manager.cpp index 39a4a3486e..a590271679 100644 --- a/source/opt/feature_manager.cpp +++ b/source/opt/feature_manager.cpp @@ -39,8 +39,7 @@ void FeatureManager::AddExtension(Instruction* ext) { assert(ext->opcode() == SpvOpExtension && "Expecting an extension instruction."); - const std::string name = - reinterpret_cast(ext->GetInOperand(0u).words.data()); + const std::string name = ext->GetInOperand(0u).AsString(); Extension extension; if (GetExtensionFromString(name.c_str(), &extension)) { extensions_.Add(extension); diff --git a/source/opt/graphics_robust_access_pass.cpp b/source/opt/graphics_robust_access_pass.cpp index 1b28f9b529..336dcd83e4 100644 --- a/source/opt/graphics_robust_access_pass.cpp +++ b/source/opt/graphics_robust_access_pass.cpp @@ -559,21 +559,17 
@@ uint32_t GraphicsRobustAccessPass::GetGlslInsts() { if (module_status_.glsl_insts_id == 0) { // This string serves double-duty as raw data for a string and for a vector // of 32-bit words - const char glsl[] = "GLSL.std.450\0\0\0\0"; - const size_t glsl_str_byte_len = 16; + const char glsl[] = "GLSL.std.450"; // Use an existing import if we can. for (auto& inst : context()->module()->ext_inst_imports()) { - const auto& name_words = inst.GetInOperand(0).words; - if (0 == std::strncmp(reinterpret_cast(name_words.data()), - glsl, glsl_str_byte_len)) { + if (inst.GetInOperand(0).AsString() == glsl) { module_status_.glsl_insts_id = inst.result_id(); } } if (module_status_.glsl_insts_id == 0) { // Make a new import instruction. module_status_.glsl_insts_id = TakeNextId(); - std::vector words(glsl_str_byte_len / sizeof(uint32_t)); - std::memcpy(words.data(), glsl, glsl_str_byte_len); + std::vector words = spvtools::utils::MakeVector(glsl); auto import_inst = MakeUnique( context(), SpvOpExtInstImport, 0, module_status_.glsl_insts_id, std::initializer_list{ diff --git a/source/opt/inst_debug_printf_pass.cpp b/source/opt/inst_debug_printf_pass.cpp index c0e6bc3f04..4218138f97 100644 --- a/source/opt/inst_debug_printf_pass.cpp +++ b/source/opt/inst_debug_printf_pass.cpp @@ -16,6 +16,7 @@ #include "inst_debug_printf_pass.h" +#include "source/util/string_utils.h" #include "spirv/unified1/NonSemanticDebugPrintf.h" namespace spvtools { @@ -231,10 +232,8 @@ Pass::Status InstDebugPrintfPass::ProcessImpl() { bool non_sem_set_seen = false; for (auto c_itr = context()->module()->ext_inst_import_begin(); c_itr != context()->module()->ext_inst_import_end(); ++c_itr) { - const char* set_name = - reinterpret_cast(&c_itr->GetInOperand(0).words[0]); - const char* non_sem_str = "NonSemantic."; - if (!strncmp(set_name, non_sem_str, strlen(non_sem_str))) { + const std::string set_name = c_itr->GetInOperand(0).AsString(); + if (spvtools::utils::starts_with(set_name, "NonSemantic.")) { 
non_sem_set_seen = true; break; } @@ -242,9 +241,8 @@ Pass::Status InstDebugPrintfPass::ProcessImpl() { if (!non_sem_set_seen) { for (auto c_itr = context()->module()->extension_begin(); c_itr != context()->module()->extension_end(); ++c_itr) { - const char* ext_name = - reinterpret_cast(&c_itr->GetInOperand(0).words[0]); - if (!strcmp(ext_name, "SPV_KHR_non_semantic_info")) { + const std::string ext_name = c_itr->GetInOperand(0).AsString(); + if (ext_name == "SPV_KHR_non_semantic_info") { context()->KillInst(&*c_itr); break; } diff --git a/source/opt/instruction.h b/source/opt/instruction.h index ce568f6626..57ee70734b 100644 --- a/source/opt/instruction.h +++ b/source/opt/instruction.h @@ -24,6 +24,7 @@ #include "NonSemanticShaderDebugInfo100.h" #include "OpenCLDebugInfo100.h" +#include "source/binary.h" #include "source/common_debug_info.h" #include "source/latest_version_glsl_std_450_header.h" #include "source/latest_version_spirv_header.h" @@ -32,6 +33,7 @@ #include "source/opt/reflect.h" #include "source/util/ilist_node.h" #include "source/util/small_vector.h" +#include "source/util/string_utils.h" #include "spirv-tools/libspirv.h" const uint32_t kNoDebugScope = 0; @@ -85,15 +87,12 @@ struct Operand { spv_operand_type_t type; // Type of this logical operand. OperandData words; // Binary segments of this logical operand. - // Returns a string operand as a C-style string. - const char* AsCString() const { + // Returns a string operand as a std::string. + std::string AsString() const { assert(type == SPV_OPERAND_TYPE_LITERAL_STRING); - return reinterpret_cast(words.data()); + return spvtools::utils::MakeString(words); } - // Returns a string operand as a std::string. 
- std::string AsString() const { return AsCString(); } - // Returns a literal integer operand as a uint64_t uint64_t AsLiteralUint64() const { assert(type == SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER); diff --git a/source/opt/ir_context.cpp b/source/opt/ir_context.cpp index 612a831add..fef0f7cefa 100644 --- a/source/opt/ir_context.cpp +++ b/source/opt/ir_context.cpp @@ -623,9 +623,8 @@ void IRContext::AddCombinatorsForCapability(uint32_t capability) { void IRContext::AddCombinatorsForExtension(Instruction* extension) { assert(extension->opcode() == SpvOpExtInstImport && "Expecting an import of an extension's instruction set."); - const char* extension_name = - reinterpret_cast(&extension->GetInOperand(0).words[0]); - if (!strcmp(extension_name, "GLSL.std.450")) { + const std::string extension_name = extension->GetInOperand(0).AsString(); + if (extension_name == "GLSL.std.450") { combinator_ops_[extension->result_id()] = {GLSLstd450Round, GLSLstd450RoundEven, GLSLstd450Trunc, @@ -944,11 +943,11 @@ void IRContext::EmitErrorMessage(std::string message, Instruction* inst) { uint32_t line_number = 0; uint32_t col_number = 0; - char* source = nullptr; + std::string source; if (line_inst != nullptr) { Instruction* file_name = get_def_use_mgr()->GetDef(line_inst->GetSingleWordInOperand(0)); - source = reinterpret_cast(&file_name->GetInOperand(0).words[0]); + source = file_name->GetInOperand(0).AsString(); // Get the line number and column number. 
line_number = line_inst->GetSingleWordInOperand(1); @@ -957,7 +956,7 @@ void IRContext::EmitErrorMessage(std::string message, Instruction* inst) { message += "\n " + inst->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); - consumer()(SPV_MSG_ERROR, source, {line_number, col_number, 0}, + consumer()(SPV_MSG_ERROR, source.c_str(), {line_number, col_number, 0}, message.c_str()); } diff --git a/source/opt/ir_context.h b/source/opt/ir_context.h index 2ce16db4e4..7bef3054c0 100644 --- a/source/opt/ir_context.h +++ b/source/opt/ir_context.h @@ -43,6 +43,7 @@ #include "source/opt/type_manager.h" #include "source/opt/value_number_table.h" #include "source/util/make_unique.h" +#include "source/util/string_utils.h" namespace spvtools { namespace opt { @@ -1032,11 +1033,7 @@ void IRContext::AddCapability(std::unique_ptr&& c) { } void IRContext::AddExtension(const std::string& ext_name) { - const auto num_chars = ext_name.size(); - // Compute num words, accommodate the terminating null character. - const auto num_words = (num_chars + 1 + 3) / 4; - std::vector ext_words(num_words, 0u); - std::memcpy(ext_words.data(), ext_name.data(), num_chars); + std::vector ext_words = spvtools::utils::MakeVector(ext_name); AddExtension(std::unique_ptr( new Instruction(this, SpvOpExtension, 0u, 0u, {{SPV_OPERAND_TYPE_LITERAL_STRING, ext_words}}))); @@ -1053,11 +1050,7 @@ void IRContext::AddExtension(std::unique_ptr&& e) { } void IRContext::AddExtInstImport(const std::string& name) { - const auto num_chars = name.size(); - // Compute num words, accommodate the terminating null character. 
- const auto num_words = (num_chars + 1 + 3) / 4; - std::vector ext_words(num_words, 0u); - std::memcpy(ext_words.data(), name.data(), num_chars); + std::vector ext_words = spvtools::utils::MakeVector(name); AddExtInstImport(std::unique_ptr( new Instruction(this, SpvOpExtInstImport, 0u, TakeNextId(), {{SPV_OPERAND_TYPE_LITERAL_STRING, ext_words}}))); diff --git a/source/opt/local_access_chain_convert_pass.cpp b/source/opt/local_access_chain_convert_pass.cpp index da9ba8cc01..d2059f5c21 100644 --- a/source/opt/local_access_chain_convert_pass.cpp +++ b/source/opt/local_access_chain_convert_pass.cpp @@ -19,6 +19,7 @@ #include "ir_builder.h" #include "ir_context.h" #include "iterator.h" +#include "source/util/string_utils.h" namespace spvtools { namespace opt { @@ -328,8 +329,7 @@ bool LocalAccessChainConvertPass::AllExtensionsSupported() const { return false; // If any extension not in allowlist, return false for (auto& ei : get_module()->extensions()) { - const char* extName = - reinterpret_cast(&ei.GetInOperand(0).words[0]); + const std::string extName = ei.GetInOperand(0).AsString(); if (extensions_allowlist_.find(extName) == extensions_allowlist_.end()) return false; } @@ -339,11 +339,9 @@ bool LocalAccessChainConvertPass::AllExtensionsSupported() const { for (auto& inst : context()->module()->ext_inst_imports()) { assert(inst.opcode() == SpvOpExtInstImport && "Expecting an import of an extension's instruction set."); - const char* extension_name = - reinterpret_cast(&inst.GetInOperand(0).words[0]); - if (0 == std::strncmp(extension_name, "NonSemantic.", 12) && - 0 != std::strncmp(extension_name, "NonSemantic.Shader.DebugInfo.100", - 32)) { + const std::string extension_name = inst.GetInOperand(0).AsString(); + if (spvtools::utils::starts_with(extension_name, "NonSemantic.") && + extension_name != "NonSemantic.Shader.DebugInfo.100") { return false; } } diff --git a/source/opt/local_single_block_elim_pass.cpp b/source/opt/local_single_block_elim_pass.cpp index 
5fd4f658d4..f48c56aab0 100644 --- a/source/opt/local_single_block_elim_pass.cpp +++ b/source/opt/local_single_block_elim_pass.cpp @@ -19,6 +19,7 @@ #include #include "source/opt/iterator.h" +#include "source/util/string_utils.h" namespace spvtools { namespace opt { @@ -183,8 +184,7 @@ void LocalSingleBlockLoadStoreElimPass::Initialize() { bool LocalSingleBlockLoadStoreElimPass::AllExtensionsSupported() const { // If any extension not in allowlist, return false for (auto& ei : get_module()->extensions()) { - const char* extName = - reinterpret_cast(&ei.GetInOperand(0).words[0]); + const std::string extName = ei.GetInOperand(0).AsString(); if (extensions_allowlist_.find(extName) == extensions_allowlist_.end()) return false; } @@ -194,11 +194,9 @@ bool LocalSingleBlockLoadStoreElimPass::AllExtensionsSupported() const { for (auto& inst : context()->module()->ext_inst_imports()) { assert(inst.opcode() == SpvOpExtInstImport && "Expecting an import of an extension's instruction set."); - const char* extension_name = - reinterpret_cast(&inst.GetInOperand(0).words[0]); - if (0 == std::strncmp(extension_name, "NonSemantic.", 12) && - 0 != std::strncmp(extension_name, "NonSemantic.Shader.DebugInfo.100", - 32)) { + const std::string extension_name = inst.GetInOperand(0).AsString(); + if (spvtools::utils::starts_with(extension_name, "NonSemantic.") && + extension_name != "NonSemantic.Shader.DebugInfo.100") { return false; } } diff --git a/source/opt/local_single_store_elim_pass.cpp b/source/opt/local_single_store_elim_pass.cpp index 051bcada76..123d03bf51 100644 --- a/source/opt/local_single_store_elim_pass.cpp +++ b/source/opt/local_single_store_elim_pass.cpp @@ -19,6 +19,7 @@ #include "source/cfa.h" #include "source/latest_version_glsl_std_450_header.h" #include "source/opt/iterator.h" +#include "source/util/string_utils.h" namespace spvtools { namespace opt { @@ -48,8 +49,7 @@ bool LocalSingleStoreElimPass::LocalSingleStoreElim(Function* func) { bool 
LocalSingleStoreElimPass::AllExtensionsSupported() const { // If any extension not in allowlist, return false for (auto& ei : get_module()->extensions()) { - const char* extName = - reinterpret_cast(&ei.GetInOperand(0).words[0]); + const std::string extName = ei.GetInOperand(0).AsString(); if (extensions_allowlist_.find(extName) == extensions_allowlist_.end()) return false; } @@ -59,11 +59,9 @@ bool LocalSingleStoreElimPass::AllExtensionsSupported() const { for (auto& inst : context()->module()->ext_inst_imports()) { assert(inst.opcode() == SpvOpExtInstImport && "Expecting an import of an extension's instruction set."); - const char* extension_name = - reinterpret_cast(&inst.GetInOperand(0).words[0]); - if (0 == std::strncmp(extension_name, "NonSemantic.", 12) && - 0 != std::strncmp(extension_name, "NonSemantic.Shader.DebugInfo.100", - 32)) { + const std::string extension_name = inst.GetInOperand(0).AsString(); + if (spvtools::utils::starts_with(extension_name, "NonSemantic.") && + extension_name != "NonSemantic.Shader.DebugInfo.100") { return false; } } diff --git a/source/opt/module.cpp b/source/opt/module.cpp index c3c705982c..5983abb126 100644 --- a/source/opt/module.cpp +++ b/source/opt/module.cpp @@ -260,9 +260,7 @@ bool Module::HasExplicitCapability(uint32_t cap) { uint32_t Module::GetExtInstImportId(const char* extstr) { for (auto& ei : ext_inst_imports_) - if (!strcmp(extstr, - reinterpret_cast(&(ei.GetInOperand(0).words[0])))) - return ei.result_id(); + if (!ei.GetInOperand(0).AsString().compare(extstr)) return ei.result_id(); return 0; } diff --git a/source/opt/remove_duplicates_pass.cpp b/source/opt/remove_duplicates_pass.cpp index 0e65cc8d14..1ed8e2a046 100644 --- a/source/opt/remove_duplicates_pass.cpp +++ b/source/opt/remove_duplicates_pass.cpp @@ -72,9 +72,8 @@ bool RemoveDuplicatesPass::RemoveDuplicatesExtInstImports() const { std::unordered_map ext_inst_imports; for (auto* i = &*context()->ext_inst_import_begin(); i;) { - auto res = 
ext_inst_imports.emplace( - reinterpret_cast(i->GetInOperand(0u).words.data()), - i->result_id()); + auto res = ext_inst_imports.emplace(i->GetInOperand(0u).AsString(), + i->result_id()); if (res.second) { // Never seen before, keep it. i = i->NextNode(); diff --git a/source/opt/replace_invalid_opc.cpp b/source/opt/replace_invalid_opc.cpp index e3b9d3e403..1dcd06f591 100644 --- a/source/opt/replace_invalid_opc.cpp +++ b/source/opt/replace_invalid_opc.cpp @@ -112,8 +112,7 @@ bool ReplaceInvalidOpcodePass::RewriteFunction(Function* function, } Instruction* file_name = context()->get_def_use_mgr()->GetDef(file_name_id); - const char* source = reinterpret_cast( - &file_name->GetInOperand(0).words[0]); + const std::string source = file_name->GetInOperand(0).AsString(); // Get the line number and column number. uint32_t line_number = @@ -121,7 +120,7 @@ bool ReplaceInvalidOpcodePass::RewriteFunction(Function* function, uint32_t col_number = last_line_dbg_inst->GetSingleWordInOperand(2); // Replace the instruction. 
- ReplaceInstruction(inst, source, line_number, col_number); + ReplaceInstruction(inst, source.c_str(), line_number, col_number); } } }, diff --git a/source/opt/strip_debug_info_pass.cpp b/source/opt/strip_debug_info_pass.cpp index c86ce57828..6a0ebf2482 100644 --- a/source/opt/strip_debug_info_pass.cpp +++ b/source/opt/strip_debug_info_pass.cpp @@ -14,6 +14,7 @@ #include "source/opt/strip_debug_info_pass.h" #include "source/opt/ir_context.h" +#include "source/util/string_utils.h" namespace spvtools { namespace opt { @@ -21,9 +22,8 @@ namespace opt { Pass::Status StripDebugInfoPass::Process() { bool uses_non_semantic_info = false; for (auto& inst : context()->module()->extensions()) { - const char* ext_name = - reinterpret_cast(&inst.GetInOperand(0).words[0]); - if (0 == std::strcmp(ext_name, "SPV_KHR_non_semantic_info")) { + const std::string ext_name = inst.GetInOperand(0).AsString(); + if (ext_name == "SPV_KHR_non_semantic_info") { uses_non_semantic_info = true; } } @@ -46,9 +46,10 @@ Pass::Status StripDebugInfoPass::Process() { if (use->opcode() == SpvOpExtInst) { auto ext_inst_set = def_use->GetDef(use->GetSingleWordInOperand(0u)); - const char* extension_name = reinterpret_cast( - &ext_inst_set->GetInOperand(0).words[0]); - if (0 == std::strncmp(extension_name, "NonSemantic.", 12)) { + const std::string extension_name = + ext_inst_set->GetInOperand(0).AsString(); + if (spvtools::utils::starts_with(extension_name, + "NonSemantic.")) { // found a non-semantic use, return false as we cannot // remove this OpString return false; diff --git a/source/opt/strip_reflect_info_pass.cpp b/source/opt/strip_reflect_info_pass.cpp index 8b0f2db7ff..f9be960a83 100644 --- a/source/opt/strip_reflect_info_pass.cpp +++ b/source/opt/strip_reflect_info_pass.cpp @@ -19,6 +19,7 @@ #include "source/opt/instruction.h" #include "source/opt/ir_context.h" +#include "source/util/string_utils.h" namespace spvtools { namespace opt { @@ -60,14 +61,13 @@ Pass::Status 
StripReflectInfoPass::Process() { } for (auto& inst : context()->module()->extensions()) { - const char* ext_name = - reinterpret_cast(&inst.GetInOperand(0).words[0]); - if (0 == std::strcmp(ext_name, "SPV_GOOGLE_hlsl_functionality1")) { + const std::string ext_name = inst.GetInOperand(0).AsString(); + if (ext_name == "SPV_GOOGLE_hlsl_functionality1") { to_remove.push_back(&inst); } else if (!other_uses_for_decorate_string && - 0 == std::strcmp(ext_name, "SPV_GOOGLE_decorate_string")) { + ext_name == "SPV_GOOGLE_decorate_string") { to_remove.push_back(&inst); - } else if (0 == std::strcmp(ext_name, "SPV_KHR_non_semantic_info")) { + } else if (ext_name == "SPV_KHR_non_semantic_info") { to_remove.push_back(&inst); } } @@ -84,9 +84,8 @@ Pass::Status StripReflectInfoPass::Process() { for (auto& inst : context()->module()->ext_inst_imports()) { assert(inst.opcode() == SpvOpExtInstImport && "Expecting an import of an extension's instruction set."); - const char* extension_name = - reinterpret_cast(&inst.GetInOperand(0).words[0]); - if (0 == std::strncmp(extension_name, "NonSemantic.", 12)) { + const std::string extension_name = inst.GetInOperand(0).AsString(); + if (spvtools::utils::starts_with(extension_name, "NonSemantic.")) { non_semantic_sets.insert(inst.result_id()); to_remove.push_back(&inst); } diff --git a/source/opt/type_manager.cpp b/source/opt/type_manager.cpp index 7935ad3313..6da4b57b49 100644 --- a/source/opt/type_manager.cpp +++ b/source/opt/type_manager.cpp @@ -23,6 +23,7 @@ #include "source/opt/log.h" #include "source/opt/reflect.h" #include "source/util/make_unique.h" +#include "source/util/string_utils.h" namespace spvtools { namespace opt { @@ -349,11 +350,8 @@ uint32_t TypeManager::GetTypeInstruction(const Type* type) { } case Type::kOpaque: { const Opaque* opaque = type->AsOpaque(); - size_t size = opaque->name().size(); // Convert to null-terminated packed UTF-8 string. 
- std::vector words(size / 4 + 1, 0); - char* dst = reinterpret_cast(words.data()); - strncpy(dst, opaque->name().c_str(), size); + std::vector words = spvtools::utils::MakeVector(opaque->name()); typeInst = MakeUnique( context(), SpvOpTypeOpaque, 0, id, std::initializer_list{ @@ -781,8 +779,7 @@ Type* TypeManager::RecordIfTypeDefinition(const Instruction& inst) { } } break; case SpvOpTypeOpaque: { - const uint32_t* data = inst.GetInOperand(0).words.data(); - type = new Opaque(reinterpret_cast(data)); + type = new Opaque(inst.GetInOperand(0).AsString()); } break; case SpvOpTypePointer: { uint32_t pointee_type_id = inst.GetSingleWordInOperand(1); diff --git a/source/opt/upgrade_memory_model.cpp b/source/opt/upgrade_memory_model.cpp index ab252059fa..9d6a5bceb4 100644 --- a/source/opt/upgrade_memory_model.cpp +++ b/source/opt/upgrade_memory_model.cpp @@ -20,6 +20,7 @@ #include "source/opt/ir_context.h" #include "source/spirv_constant.h" #include "source/util/make_unique.h" +#include "source/util/string_utils.h" namespace spvtools { namespace opt { @@ -58,9 +59,7 @@ void UpgradeMemoryModel::UpgradeMemoryModelInstruction() { std::initializer_list{ {SPV_OPERAND_TYPE_CAPABILITY, {SpvCapabilityVulkanMemoryModelKHR}}})); const std::string extension = "SPV_KHR_vulkan_memory_model"; - std::vector words(extension.size() / 4 + 1, 0); - char* dst = reinterpret_cast(words.data()); - strncpy(dst, extension.c_str(), extension.size()); + std::vector words = spvtools::utils::MakeVector(extension); context()->AddExtension( MakeUnique(context(), SpvOpExtension, 0, 0, std::initializer_list{ @@ -85,8 +84,7 @@ void UpgradeMemoryModel::UpgradeInstructions() { if (ext_inst == GLSLstd450Modf || ext_inst == GLSLstd450Frexp) { auto import = get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u)); - if (reinterpret_cast(import->GetInOperand(0u).words.data()) == - std::string("GLSL.std.450")) { + if (import->GetInOperand(0u).AsString() == "GLSL.std.450") { UpgradeExtInst(inst); } } diff 
--git a/source/text_handler.cpp b/source/text_handler.cpp index 46b9845617..fe12a26e38 100644 --- a/source/text_handler.cpp +++ b/source/text_handler.cpp @@ -29,6 +29,7 @@ #include "source/util/bitutils.h" #include "source/util/hex_float.h" #include "source/util/parse_number.h" +#include "source/util/string_utils.h" namespace spvtools { namespace { @@ -307,14 +308,8 @@ spv_result_t AssemblyContext::binaryEncodeString(const char* value, << SPV_LIMIT_INSTRUCTION_WORD_COUNT_MAX << " words."; } - pInst->words.resize(newWordCount); - - // Make sure all the bytes in the last word are 0, in case we only - // write a partial word at the end. - pInst->words.back() = 0; - - char* dest = (char*)&pInst->words[oldWordCount]; - strncpy(dest, value, length + 1); + pInst->words.reserve(newWordCount); + spvtools::utils::AppendToVector(value, &pInst->words); return SPV_SUCCESS; } diff --git a/source/util/string_utils.h b/source/util/string_utils.h index 4282aa9496..03e20b3d63 100644 --- a/source/util/string_utils.h +++ b/source/util/string_utils.h @@ -16,6 +16,8 @@ #define SOURCE_UTIL_STRING_UTILS_H_ #include + +#include #include #include #include @@ -44,9 +46,10 @@ std::string CardinalToOrdinal(size_t cardinal); // string will be empty. std::pair SplitFlagArgs(const std::string& flag); -// Encodes a string as a sequence of words, using the SPIR-V encoding. -inline std::vector MakeVector(std::string input) { - std::vector result; +// Encodes a string as a sequence of words, using the SPIR-V encoding, appending +// to an existing vector. +inline void AppendToVector(const std::string& input, + std::vector* result) { uint32_t word = 0; size_t num_bytes = input.size(); // SPIR-V strings are null-terminated. The byte_index == num_bytes @@ -56,24 +59,36 @@ inline std::vector MakeVector(std::string input) { (byte_index < num_bytes ? 
uint8_t(input[byte_index]) : uint8_t(0)); word |= (new_byte << (8 * (byte_index % sizeof(uint32_t)))); if (3 == (byte_index % sizeof(uint32_t))) { - result.push_back(word); + result->push_back(word); word = 0; } } // Emit a trailing partial word. if ((num_bytes + 1) % sizeof(uint32_t)) { - result.push_back(word); + result->push_back(word); } +} + +// Encodes a string as a sequence of words, using the SPIR-V encoding. +inline std::vector MakeVector(const std::string& input) { + std::vector result; + AppendToVector(input, &result); return result; } -// Decode a string from a sequence of words, using the SPIR-V encoding. -template -inline std::string MakeString(const VectorType& words) { +// Decode a string from a sequence of words between first and last, using the +// SPIR-V encoding. Assert that a terminating 0-byte was found (unless +// assert_found_terminating_null is passed as false). +template +inline std::string MakeString(InputIt first, InputIt last, + bool assert_found_terminating_null = true) { std::string result; + constexpr size_t kCharsPerWord = sizeof(*first); + static_assert(kCharsPerWord == 4, "expect 4-byte word"); - for (uint32_t word : words) { - for (int byte_index = 0; byte_index < 4; byte_index++) { + for (InputIt pos = first; pos != last; ++pos) { + uint32_t word = *pos; + for (size_t byte_index = 0; byte_index < kCharsPerWord; byte_index++) { uint32_t extracted_word = (word >> (8 * byte_index)) & 0xFF; char c = static_cast(extracted_word); if (c == 0) { @@ -82,9 +97,33 @@ inline std::string MakeString(const VectorType& words) { result += c; } } - assert(false && "Did not find terminating null for the string."); + assert(!assert_found_terminating_null && + "Did not find terminating null for the string."); + (void)assert_found_terminating_null; /* No unused parameters in release + builds. */ return result; -} // namespace utils +} + +// Decode a string from a sequence of words in a vector, using the SPIR-V +// encoding. 
+template +inline std::string MakeString(const VectorType& words, + bool assert_found_terminating_null = true) { + return MakeString(words.cbegin(), words.cend(), + assert_found_terminating_null); +} + +// Decode a string from array words, consuming up to count words, using the +// SPIR-V encoding. +inline std::string MakeString(const uint32_t* words, size_t num_words, + bool assert_found_terminating_null = true) { + return MakeString(words, words + num_words, assert_found_terminating_null); +} + +// Check if str starts with prefix (only included since C++20) +inline bool starts_with(const std::string& str, const char* prefix) { + return 0 == str.compare(0, std::strlen(prefix), prefix); +} } // namespace utils } // namespace spvtools diff --git a/source/val/instruction.cpp b/source/val/instruction.cpp index b9155898ac..f16fcd7300 100644 --- a/source/val/instruction.cpp +++ b/source/val/instruction.cpp @@ -16,6 +16,9 @@ #include +#include "source/binary.h" +#include "source/util/string_utils.h" + namespace spvtools { namespace val { @@ -41,5 +44,12 @@ bool operator==(const Instruction& lhs, uint32_t rhs) { return lhs.id() == rhs; } +template <> +std::string Instruction::GetOperandAs(size_t index) const { + const spv_parsed_operand_t& o = operands_.at(index); + assert(o.offset + o.num_words <= inst_.num_words); + return spvtools::utils::MakeString(words_.data() + o.offset, o.num_words); +} + } // namespace val } // namespace spvtools diff --git a/source/val/instruction.h b/source/val/instruction.h index 617cb0660d..6d1f9f4f12 100644 --- a/source/val/instruction.h +++ b/source/val/instruction.h @@ -133,6 +133,9 @@ bool operator<(const Instruction& lhs, uint32_t rhs); bool operator==(const Instruction& lhs, const Instruction& rhs); bool operator==(const Instruction& lhs, uint32_t rhs); +template <> +std::string Instruction::GetOperandAs(size_t index) const; + } // namespace val } // namespace spvtools diff --git a/source/val/validate.cpp b/source/val/validate.cpp index 
45b6a463e6..7655c960d3 100644 --- a/source/val/validate.cpp +++ b/source/val/validate.cpp @@ -219,9 +219,7 @@ spv_result_t ValidateBinaryUsingContextAndValidationState( if (inst->opcode() == SpvOpEntryPoint) { const auto entry_point = inst->GetOperandAs(1); const auto execution_model = inst->GetOperandAs(0); - const char* str = reinterpret_cast( - inst->words().data() + inst->operand(2).offset); - const std::string desc_name(str); + const std::string desc_name = inst->GetOperandAs(2); ValidationState_t::EntryPointDescription desc; desc.name = desc_name; @@ -237,9 +235,8 @@ spv_result_t ValidateBinaryUsingContextAndValidationState( for (const Instruction* check_inst : visited_entry_points) { const auto check_execution_model = check_inst->GetOperandAs(0); - const char* check_str = reinterpret_cast( - check_inst->words().data() + inst->operand(2).offset); - const std::string check_name(check_str); + const std::string check_name = + check_inst->GetOperandAs(2); if (desc_name == check_name && execution_model == check_execution_model) { diff --git a/source/val/validate_decorations.cpp b/source/val/validate_decorations.cpp index 3cdb471c70..50c0db93e9 100644 --- a/source/val/validate_decorations.cpp +++ b/source/val/validate_decorations.cpp @@ -21,11 +21,13 @@ #include #include +#include "source/binary.h" #include "source/diagnostic.h" #include "source/opcode.h" #include "source/spirv_constant.h" #include "source/spirv_target_env.h" #include "source/spirv_validator_options.h" +#include "source/util/string_utils.h" #include "source/val/validate_scopes.h" #include "source/val/validation_state.h" @@ -798,8 +800,8 @@ spv_result_t CheckDecorationsOfEntryPoints(ValidationState_t& vstate) { // targeted by an OpEntryPoint instruction for (auto& decoration : vstate.id_decorations(entry_point)) { if (SpvDecorationLinkageAttributes == decoration.dec_type()) { - const char* linkage_name = - reinterpret_cast(&decoration.params()[0]); + const std::string linkage_name = + 
spvtools::utils::MakeString(decoration.params()); return vstate.diag(SPV_ERROR_INVALID_BINARY, vstate.FindDef(entry_point)) << "The LinkageAttributes Decoration (Linkage name: " diff --git a/source/val/validate_extensions.cpp b/source/val/validate_extensions.cpp index dccbe14908..479e9e4191 100644 --- a/source/val/validate_extensions.cpp +++ b/source/val/validate_extensions.cpp @@ -280,8 +280,7 @@ spv_result_t ValidateClspvReflectionKernel(ValidationState_t& _, return _.diag(SPV_ERROR_INVALID_ID, inst) << "Name must be an OpString"; } - const std::string name_str = reinterpret_cast( - name->words().data() + name->operands()[1].offset); + const std::string name_str = name->GetOperandAs(1); bool found = false; for (auto& desc : _.entry_point_descriptions(kernel_id)) { if (name_str == desc.name) { @@ -741,8 +740,7 @@ spv_result_t ValidateExtInstImport(ValidationState_t& _, const Instruction* inst) { const auto name_id = 1; if (!_.HasExtension(kSPV_KHR_non_semantic_info)) { - const std::string name(reinterpret_cast( - inst->words().data() + inst->operands()[name_id].offset)); + const std::string name = inst->GetOperandAs(name_id); if (name.find("NonSemantic.") == 0) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << "NonSemantic extended instruction sets cannot be declared " @@ -774,7 +772,7 @@ spv_result_t ValidateExtInst(ValidationState_t& _, const Instruction* inst) { assert(import_inst); std::ostringstream ss; - ss << reinterpret_cast(import_inst->words().data() + 2); + ss << import_inst->GetOperandAs(1); ss << " "; ss << desc->name; @@ -3264,8 +3262,7 @@ spv_result_t ValidateExtInst(ValidationState_t& _, const Instruction* inst) { } } else if (ext_inst_type == SPV_EXT_INST_TYPE_NONSEMANTIC_CLSPVREFLECTION) { auto import_inst = _.FindDef(inst->GetOperandAs(2)); - const std::string name(reinterpret_cast( - import_inst->words().data() + import_inst->operands()[1].offset)); + const std::string name = import_inst->GetOperandAs(1); const std::string reflection = 
"NonSemantic.ClspvReflection."; char* end_ptr; auto version_string = name.substr(reflection.size()); diff --git a/source/val/validation_state.cpp b/source/val/validation_state.cpp index 8d1a0d3f4e..9d708be4da 100644 --- a/source/val/validation_state.cpp +++ b/source/val/validation_state.cpp @@ -497,15 +497,13 @@ void ValidationState_t::RegisterDebugInstruction(const Instruction* inst) { switch (inst->opcode()) { case SpvOpName: { const auto target = inst->GetOperandAs(0); - const auto* str = reinterpret_cast(inst->words().data() + - inst->operand(1).offset); + const std::string str = inst->GetOperandAs(1); AssignNameToId(target, str); break; } case SpvOpMemberName: { const auto target = inst->GetOperandAs(0); - const auto* str = reinterpret_cast(inst->words().data() + - inst->operand(2).offset); + const std::string str = inst->GetOperandAs(2); AssignNameToId(target, str); break; } diff --git a/test/binary_parse_test.cpp b/test/binary_parse_test.cpp index 9a13f22ca5..ece750c561 100644 --- a/test/binary_parse_test.cpp +++ b/test/binary_parse_test.cpp @@ -203,16 +203,7 @@ class BinaryParseTest : public spvtest::TextToBinaryTestBase<::testing::Test> { void Parse(const SpirvVector& words, spv_result_t expected_result, bool flip_words = false) { SpirvVector flipped_words(words); - SCOPED_TRACE(flip_words ? "Flipped Endianness" : "Normal Endianness"); - if (flip_words) { - std::transform(flipped_words.begin(), flipped_words.end(), - flipped_words.begin(), [](const uint32_t raw_word) { - return spvFixWord(raw_word, - I32_ENDIAN_HOST == I32_ENDIAN_BIG - ? 
SPV_ENDIANNESS_LITTLE - : SPV_ENDIANNESS_BIG); - }); - } + MaybeFlipWords(flip_words, flipped_words.begin(), flipped_words.end()); EXPECT_EQ(expected_result, spvBinaryParse(ScopedContext().context, &client_, flipped_words.data(), flipped_words.size(), @@ -486,27 +477,27 @@ TEST_F(BinaryParseTest, EarlyReturnWithTwoPassingCallbacks) { } TEST_F(BinaryParseTest, InstructionWithStringOperand) { - const std::string str = - "the future is already here, it's just not evenly distributed"; - const auto str_words = MakeVector(str); - const auto instruction = MakeInstruction(SpvOpName, {99}, str_words); - const auto words = Concatenate({ExpectedHeaderForBound(100), instruction}); - InSequence calls_expected_in_specific_order; - EXPECT_HEADER(100).WillOnce(Return(SPV_SUCCESS)); - const auto operands = std::vector{ - MakeSimpleOperand(1, SPV_OPERAND_TYPE_ID), - MakeLiteralStringOperand(2, static_cast(str_words.size()))}; - EXPECT_CALL(client_, - Instruction(ParsedInstruction(spv_parsed_instruction_t{ - instruction.data(), static_cast(instruction.size()), - SpvOpName, SPV_EXT_INST_TYPE_NONE, 0 /*type id*/, - 0 /* No result id for OpName*/, operands.data(), - static_cast(operands.size())}))) - .WillOnce(Return(SPV_SUCCESS)); - // Since we are actually checking the output, don't test the - // endian-swapped version. 
- Parse(words, SPV_SUCCESS, false); - EXPECT_EQ(nullptr, diagnostic_); + for (bool endian_swap : kSwapEndians) { + const std::string str = + "the future is already here, it's just not evenly distributed"; + const auto str_words = MakeVector(str); + const auto instruction = MakeInstruction(SpvOpName, {99}, str_words); + const auto words = Concatenate({ExpectedHeaderForBound(100), instruction}); + InSequence calls_expected_in_specific_order; + EXPECT_HEADER(100).WillOnce(Return(SPV_SUCCESS)); + const auto operands = std::vector{ + MakeSimpleOperand(1, SPV_OPERAND_TYPE_ID), + MakeLiteralStringOperand(2, static_cast(str_words.size()))}; + EXPECT_CALL(client_, Instruction(ParsedInstruction(spv_parsed_instruction_t{ + instruction.data(), + static_cast(instruction.size()), + SpvOpName, SPV_EXT_INST_TYPE_NONE, 0 /*type id*/, + 0 /* No result id for OpName*/, operands.data(), + static_cast(operands.size())}))) + .WillOnce(Return(SPV_SUCCESS)); + Parse(words, SPV_SUCCESS, endian_swap); + EXPECT_EQ(nullptr, diagnostic_); + } } // Checks for non-zero values for the result_id and ext_inst_type members diff --git a/test/binary_to_text.literal_test.cpp b/test/binary_to_text.literal_test.cpp index 02daac761b..5956984b13 100644 --- a/test/binary_to_text.literal_test.cpp +++ b/test/binary_to_text.literal_test.cpp @@ -27,8 +27,15 @@ using ::testing::Eq; using RoundTripLiteralsTest = spvtest::TextToBinaryTestBase<::testing::TestWithParam>; +static const bool kSwapEndians[] = {false, true}; + TEST_P(RoundTripLiteralsTest, Sample) { - EXPECT_THAT(EncodeAndDecodeSuccessfully(GetParam()), Eq(GetParam())); + for (bool endian_swap : kSwapEndians) { + EXPECT_THAT( + EncodeAndDecodeSuccessfully(GetParam(), SPV_BINARY_TO_TEXT_OPTION_NONE, + SPV_ENV_UNIVERSAL_1_0, endian_swap), + Eq(GetParam())); + } } // clang-format off @@ -58,8 +65,12 @@ using RoundTripSpecialCaseLiteralsTest = spvtest::TextToBinaryTestBase< // Test case where the generated disassembly is not the same as the // assembly 
passed in. TEST_P(RoundTripSpecialCaseLiteralsTest, Sample) { - EXPECT_THAT(EncodeAndDecodeSuccessfully(std::get<0>(GetParam())), - Eq(std::get<1>(GetParam()))); + for (bool endian_swap : kSwapEndians) { + EXPECT_THAT(EncodeAndDecodeSuccessfully(std::get<0>(GetParam()), + SPV_BINARY_TO_TEXT_OPTION_NONE, + SPV_ENV_UNIVERSAL_1_0, endian_swap), + Eq(std::get<1>(GetParam()))); + } } // clang-format off diff --git a/test/opt/instruction_test.cpp b/test/opt/instruction_test.cpp index c5b92efb2e..2a48134d90 100644 --- a/test/opt/instruction_test.cpp +++ b/test/opt/instruction_test.cpp @@ -62,12 +62,6 @@ TEST(InstructionTest, CreateWithOpcodeAndNoOperands) { EXPECT_EQ(inst.end(), inst.begin()); } -TEST(InstructionTest, OperandAsCString) { - Operand::OperandData abcde{0x64636261, 0x65}; - Operand operand(SPV_OPERAND_TYPE_LITERAL_STRING, std::move(abcde)); - EXPECT_STREQ("abcde", operand.AsCString()); -} - TEST(InstructionTest, OperandAsString) { Operand::OperandData abcde{0x64636261, 0x65}; Operand operand(SPV_OPERAND_TYPE_LITERAL_STRING, std::move(abcde)); diff --git a/test/test_fixture.h b/test/test_fixture.h index 0c5bfc9c3b..029fc8543b 100644 --- a/test/test_fixture.h +++ b/test/test_fixture.h @@ -15,6 +15,7 @@ #ifndef TEST_TEST_FIXTURE_H_ #define TEST_TEST_FIXTURE_H_ +#include #include #include @@ -91,12 +92,26 @@ class TextToBinaryTestBase : public T { return diagnostic->error; } + // Potentially flip the words in the binary representation to the other + // endianness + template + void MaybeFlipWords(bool flip_words, It begin, It end) { + SCOPED_TRACE(flip_words ? "Flipped Endianness" : "Normal Endianness"); + if (flip_words) { + std::transform(begin, end, begin, [](const uint32_t raw_word) { + return spvFixWord(raw_word, I32_ENDIAN_HOST == I32_ENDIAN_BIG + ? SPV_ENDIANNESS_LITTLE + : SPV_ENDIANNESS_BIG); + }); + } + } + // Encodes SPIR-V text into binary and then decodes the binary using // given options. Returns the decoded text. 
std::string EncodeAndDecodeSuccessfully( const std::string& txt, uint32_t disassemble_options = SPV_BINARY_TO_TEXT_OPTION_NONE, - spv_target_env env = SPV_ENV_UNIVERSAL_1_0) { + spv_target_env env = SPV_ENV_UNIVERSAL_1_0, bool flip_words = false) { DestroyBinary(); DestroyDiagnostic(); ScopedContext context(env); @@ -110,6 +125,8 @@ class TextToBinaryTestBase : public T { EXPECT_EQ(SPV_SUCCESS, error); if (!binary) return ""; + MaybeFlipWords(flip_words, binary->code, binary->code + binary->wordCount); + spv_text decoded_text; error = spvBinaryToText(context.context, binary->code, binary->wordCount, disassemble_options, &decoded_text, &diagnostic);