LCOV - code coverage report
Current view: top level - wren_vk/src - shader.cpp (source / functions) Coverage Total Hit
Test: Wren Engine Coverage Lines: 0.0 % 195 0
Test Date: 1980-01-01 00:00:00 Functions: 0.0 % 11 0

            Line data    Source code
#include "shader.hpp"

#include <slang-com-ptr.h>
#include <slang-cpp-types.h>
#include <slang.h>
#include <spirv_reflect.h>
#include <vulkan/vulkan_core.h>

#include <algorithm>
#include <cstdint>
#include <system_error>
#include <vulkan/vulkan_enums.hpp>
#include <vulkan/vulkan_handles.hpp>
#include <wren/math/vector.hpp>
#include <wren/utils/enums.hpp>
#include <wren/utils/filesystem.hpp>
#include <wren/utils/string_reader.hpp>
#include <wren/vk/result.hpp>

#include "vulkan/vulkan_structs.hpp"
#include "wren/logging/log.hpp"
#include "wren/utils/result.hpp"
      22              : 
// Registers the enumerators of SPIRV-Reflect's C enum SpvReflectResult with
// Boost.Describe, so a SpvReflectResult value can be mapped to its
// enumerator name.
// NOTE(review): presumably this is what lets `log::error("... {}", result)`
// in ShaderModule's constructor print a readable name — confirm against the
// enum formatter in wren/utils/enums.hpp.
// This list is maintained by hand: keep it in sync with SpvReflectResult in
// spirv_reflect.h whenever SPIRV-Reflect is upgraded.
// NOLINTNEXTLINE
BOOST_DESCRIBE_ENUM(
    SpvReflectResult, SPV_REFLECT_RESULT_SUCCESS, SPV_REFLECT_RESULT_NOT_READY,
    SPV_REFLECT_RESULT_ERROR_PARSE_FAILED,
    SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED,
    SPV_REFLECT_RESULT_ERROR_RANGE_EXCEEDED,
    SPV_REFLECT_RESULT_ERROR_NULL_POINTER,
    SPV_REFLECT_RESULT_ERROR_INTERNAL_ERROR,
    SPV_REFLECT_RESULT_ERROR_COUNT_MISMATCH,
    SPV_REFLECT_RESULT_ERROR_ELEMENT_NOT_FOUND,
    SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_CODE_SIZE,
    SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_MAGIC_NUMBER,
    SPV_REFLECT_RESULT_ERROR_SPIRV_UNEXPECTED_EOF,
    SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE,
    SPV_REFLECT_RESULT_ERROR_SPIRV_SET_NUMBER_OVERFLOW,
    SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_STORAGE_CLASS,
    SPV_REFLECT_RESULT_ERROR_SPIRV_RECURSION,
    SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_INSTRUCTION,
    SPV_REFLECT_RESULT_ERROR_SPIRV_UNEXPECTED_BLOCK_DATA,
    SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_BLOCK_MEMBER_REFERENCE,
    SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ENTRY_POINT,
    SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_EXECUTION_MODE,
    SPV_REFLECT_RESULT_ERROR_SPIRV_MAX_RECURSIVE_EXCEEDED)
      46              : 
      47              : namespace wren::vk {
      48              : 
      49            0 : ShaderModule::ShaderModule(reflect::spirv_t spirv,
      50              :                            const ::vk::ShaderModule& module)
      51            0 :     : spirv_(std::move(spirv)),
      52            0 :       module_(module),
      53            0 :       reflection_(std::make_shared<spv_reflect::ShaderModule>(this->spirv_)) {
      54            0 :   if (reflection_->GetResult() != SPV_REFLECT_RESULT_SUCCESS) {
      55            0 :     log::error("Shader reflection failed: {}", reflection_->GetResult());
      56            0 :   }
      57            0 : }
      58              : 
      59            0 : auto ShaderModule::get_vertex_input_bindings() const
      60              :     -> std::vector<::vk::VertexInputBindingDescription> {
      61            0 :   uint32_t count = 0;
      62            0 :   reflection_->EnumerateInputVariables(&count, nullptr);
      63            0 :   if (count == 0) return {};
      64              : 
      65            0 :   std::vector<SpvReflectInterfaceVariable*> input_variables(count);
      66            0 :   reflection_->EnumerateInputVariables(&count, input_variables.data());
      67              : 
      68            0 :   uint32_t offset = 0;
      69            0 :   std::vector<::vk::VertexInputBindingDescription> bindings;
      70            0 :   bindings.reserve(input_variables.size());
      71            0 :   for (const auto& input : input_variables) {
      72            0 :     if (static_cast<uint32_t>(input->built_in) != UINT32_MAX) continue;
      73              : 
      74            0 :     const auto width = input->numeric.scalar.width / 8;
      75            0 :     const auto count = input->numeric.vector.component_count;
      76            0 :     offset += width * count;
      77              :   }
      78              : 
      79            0 :   if (offset == 0) {
      80            0 :     return {};
      81              :   }
      82              : 
      83            0 :   return {{0, offset}};
      84            0 : }
      85              : 
      86            0 : auto ShaderModule::get_vertex_input_attributes() const
      87              :     -> std::vector<::vk::VertexInputAttributeDescription> {
      88            0 :   const auto& input_variables = [&] {
      89            0 :     uint32_t count = 0;
      90            0 :     reflection_->EnumerateInputVariables(&count, nullptr);
      91            0 :     std::vector<SpvReflectInterfaceVariable*> input_variables(count);
      92            0 :     reflection_->EnumerateInputVariables(&count, input_variables.data());
      93              : 
      94            0 :     std::ranges::sort(input_variables, [](const auto& a, const auto& b) {
      95            0 :       return a->location < b->location;
      96              :     });
      97              : 
      98            0 :     return input_variables;
      99            0 :   }();
     100              : 
     101            0 :   uint32_t offset = 0;
     102            0 :   std::vector<::vk::VertexInputAttributeDescription> attrs;
     103            0 :   for (const auto& input : input_variables) {
     104            0 :     if (input->location == UINT32_MAX) continue;
     105            0 :     attrs.emplace_back(input->location, 0,
     106            0 :                        static_cast<::vk::Format>(input->format), offset);
     107            0 :     offset += (input->numeric.scalar.width / 8) *
     108            0 :               input->numeric.vector.component_count;
     109              :   }
     110              : 
     111            0 :   return attrs;
     112            0 : }
     113              : 
     114            0 : auto ShaderModule::get_descriptor_set_layout_bindings() const
     115              :     -> std::vector<::vk::DescriptorSetLayoutBinding> {
     116            0 :   uint32_t count = 0;
     117            0 :   reflection_->EnumerateDescriptorSets(&count, nullptr);
     118            0 :   std::vector<SpvReflectDescriptorSet*> spv_sets(count);
     119            0 :   reflection_->EnumerateDescriptorSets(&count, spv_sets.data());
     120              : 
     121            0 :   std::vector<::vk::DescriptorSetLayoutBinding> layouts;
     122            0 :   for (SpvReflectDescriptorSet* set : spv_sets) {
     123            0 :     std::span<SpvReflectDescriptorBinding*> bindings(set->bindings,
     124            0 :                                                      set->binding_count);
     125              : 
     126            0 :     for (SpvReflectDescriptorBinding* binding : bindings) {
     127            0 :       layouts.emplace_back(
     128            0 :           binding->binding,
     129            0 :           static_cast<::vk::DescriptorType>(binding->descriptor_type),
     130            0 :           binding->count, ::vk::ShaderStageFlagBits::eVertex);
     131              :     }
     132              :   }
     133              : 
     134            0 :   return layouts;
     135            0 : }
     136              : 
     137            0 : auto Shader::create(const ::vk::Device& device,
     138              :                     const std::filesystem::path& shader_path) -> expected<Ptr> {
     139            0 :   Slang::ComPtr<slang::IGlobalSession> slang_session;
     140            0 :   slang::createGlobalSession(slang_session.writeRef());
     141            0 :   if (slang_session == nullptr) {
     142            0 :     log::error("Failed to create slang global session");
     143            0 :   }
     144              : 
     145            0 :   slang::TargetDesc target_desc{
     146              :       .format = SLANG_SPIRV,
     147            0 :       .profile = slang_session->findProfile("spriv_1_5"),
     148              :   };
     149              : 
     150            0 :   slang::SessionDesc session_desc{
     151              :       .targets = &target_desc,
     152              :       .targetCount = 1,
     153              :       .compilerOptionEntryCount = 0,
     154              :   };
     155              : 
     156            0 :   const auto& output_diagnostics = [](slang::IBlob* diag) {
     157            0 :     if (diag == nullptr) return;
     158            0 :     log::error("{}", static_cast<const char*>(diag->getBufferPointer()));
     159            0 :   };
     160              : 
     161            0 :   Slang::ComPtr<slang::ISession> session;
     162            0 :   slang_session->createSession(session_desc, session.writeRef());
     163            0 :   if (slang_session == nullptr) {
     164            0 :     log::error("Failed to create slang session");
     165            0 :   }
     166              : 
     167              :   const auto compile_shader =
     168            0 :       [&](const std::string& entry_point_name) -> expected<ShaderModule> {
     169            0 :     slang::IModule* slang_module = nullptr;
     170              :     {
     171            0 :       Slang::ComPtr<slang::IBlob> diagnostic_blob;
     172            0 :       slang_module = session->loadModule(shader_path.string().c_str(),
     173            0 :                                          diagnostic_blob.writeRef());
     174            0 :       if (slang_module == nullptr) {
     175            0 :         log::error("Failed to load shader module:\n");
     176            0 :         output_diagnostics(diagnostic_blob.get());
     177            0 :       }
     178            0 :     }
     179              : 
     180            0 :     Slang::ComPtr<slang::IEntryPoint> entry_point;
     181            0 :     slang_module->findEntryPointByName(entry_point_name.c_str(),
     182            0 :                                        entry_point.writeRef());
     183            0 :     if (entry_point == nullptr)
     184            0 :       log::error("Failed to find entry point: {}", entry_point_name);
     185              : 
     186            0 :     std::array<slang::IComponentType*, 2> component_type = {slang_module,
     187            0 :                                                             entry_point};
     188              : 
     189            0 :     Slang::ComPtr<slang::IComponentType> composed_program;
     190              :     {
     191            0 :       Slang::ComPtr<slang::IBlob> diagnostics_blob;
     192            0 :       SlangResult result = session->createCompositeComponentType(
     193            0 :           component_type.data(), component_type.size(),
     194            0 :           composed_program.writeRef());
     195            0 :       if (result != SLANG_OK) {
     196            0 :         output_diagnostics(diagnostics_blob.get());
     197            0 :       }
     198            0 :     }
     199              : 
     200            0 :     Slang::ComPtr<slang::IBlob> spirv_code;
     201              :     {
     202            0 :       Slang::ComPtr<slang::IBlob> diagnostics_blob;
     203            0 :       SlangResult result = composed_program->getEntryPointCode(
     204            0 :           0, 0, spirv_code.writeRef(), diagnostics_blob.writeRef());
     205            0 :       if (result != SLANG_OK) {
     206            0 :         output_diagnostics(diagnostics_blob.get());
     207            0 :       }
     208            0 :     }
     209              : 
     210            0 :     ::vk::ShaderModuleCreateInfo create_info(
     211            0 :         {}, spirv_code->getBufferSize(),
     212            0 :         static_cast<const uint32_t*>(spirv_code->getBufferPointer()));
     213            0 :     VK_TRY_RESULT(module, device.createShaderModule(create_info));
     214              : 
     215            0 :     std::vector<uint32_t> spirv;
     216              :     {
     217            0 :       size_t count = spirv_code->getBufferSize() / sizeof(uint32_t);
     218            0 :       spirv.resize(count);
     219            0 :       std::memcpy(spirv.data(), spirv_code->getBufferPointer(),
     220            0 :                   count * sizeof(uint32_t));
     221              :     }
     222              : 
     223            0 :     return ShaderModule{spirv, module};
     224            0 :   };
     225              : 
     226            0 :   const auto shader = std::make_shared<Shader>();
     227              : 
     228            0 :   TRY_RESULT(const auto& v_module, compile_shader("vertexMain"));
     229            0 :   shader->vertex_shader(v_module);
     230              : 
     231            0 :   TRY_RESULT(const auto& f_module, compile_shader("fragmentMain"));
     232            0 :   shader->fragment_shader(f_module);
     233              : 
     234            0 :   return shader;
     235            0 : }
     236              : 
     237            0 : auto Shader::create_graphics_pipeline(const ::vk::Device& device,
     238              :                                       const ::vk::RenderPass& render_pass,
     239              :                                       const math::Vec2f& size, bool depth)
     240              :     -> expected<void> {
     241            0 :   ::vk::Result res = ::vk::Result::eSuccess;
     242              : 
     243              :   // Descriptor Sets
     244              :   const auto bindings =
     245            0 :       vertex_shader_module_.get_descriptor_set_layout_bindings();
     246            0 :   ::vk::DescriptorSetLayoutCreateInfo dl_create_info(
     247            0 :       ::vk::DescriptorSetLayoutCreateFlagBits::ePushDescriptorKHR, bindings);
     248              : 
     249            0 :   VK_TIE_ERR_PROP(descriptor_layout_,
     250              :                   device.createDescriptorSetLayout(dl_create_info));
     251              : 
     252            0 :   ::vk::PipelineLayoutCreateInfo layout_create({}, descriptor_layout_);
     253            0 :   std::tie(res, pipeline_layout_) = device.createPipelineLayout(layout_create);
     254            0 :   if (res != ::vk::Result::eSuccess)
     255            0 :     return std::unexpected(make_error_code(res));
     256              : 
     257              :   // Dynamic states
     258            0 :   std::array dynamic_states = {::vk::DynamicState::eViewport,
     259              :                                ::vk::DynamicState::eScissor};
     260            0 :   ::vk::PipelineDynamicStateCreateInfo dynamic_state({}, dynamic_states);
     261              : 
     262              :   // Input binding/attributes
     263            0 :   const auto input_bindings = vertex_shader_module_.get_vertex_input_bindings();
     264              :   const auto input_attributes =
     265            0 :       vertex_shader_module_.get_vertex_input_attributes();
     266              : 
     267            0 :   ::vk::PipelineVertexInputStateCreateInfo vertex_input_info{
     268            0 :       {}, input_bindings, input_attributes};
     269              : 
     270            0 :   ::vk::PipelineInputAssemblyStateCreateInfo input_assembly(
     271            0 :       {}, ::vk::PrimitiveTopology::eTriangleList, false);
     272              : 
     273              :   // Viewport
     274            0 :   ::vk::Viewport viewport{
     275            0 :       0, 0, static_cast<float>(size.x()), static_cast<float>(size.y()), 1, 0};
     276            0 :   ::vk::Rect2D scissor{
     277            0 :       {0, 0},
     278            0 :       {static_cast<uint32_t>(size.x()), static_cast<uint32_t>(size.y())}};
     279            0 :   ::vk::PipelineViewportStateCreateInfo viewport_state{{}, viewport, scissor};
     280              : 
     281            0 :   ::vk::PipelineRasterizationStateCreateInfo rasterization(
     282            0 :       {}, false, false, ::vk::PolygonMode::eFill, ::vk::CullModeFlagBits::eNone,
     283              :       ::vk::FrontFace::eCounterClockwise, false, {}, {}, {}, 1.0f);
     284              : 
     285            0 :   ::vk::PipelineMultisampleStateCreateInfo multisample{
     286            0 :       {}, ::vk::SampleCountFlagBits::e1, false};
     287              : 
     288              :   // Colour blending
     289            0 :   ::vk::PipelineColorBlendAttachmentState colour_blend_attachment{
     290              :       true,
     291              :       ::vk::BlendFactor::eSrcAlpha,
     292              :       ::vk::BlendFactor::eOneMinusSrcAlpha,
     293              :       ::vk::BlendOp::eAdd,
     294              :       ::vk::BlendFactor::eOne,
     295              :       ::vk::BlendFactor::eZero,
     296              :       ::vk::BlendOp::eAdd};
     297            0 :   colour_blend_attachment.setColorWriteMask(
     298            0 :       ::vk::ColorComponentFlagBits::eR | ::vk::ColorComponentFlagBits::eG |
     299            0 :       ::vk::ColorComponentFlagBits::eB | ::vk::ColorComponentFlagBits::eA);
     300            0 :   ::vk::PipelineColorBlendStateCreateInfo colour_blend(
     301            0 :       {}, false, ::vk::LogicOp::eCopy, colour_blend_attachment,
     302            0 :       {0.0, 0.0, 0.0, 0.0});
     303              : 
     304              :   // Depth / Stencil
     305            0 :   ::vk::PipelineDepthStencilStateCreateInfo depth_state(
     306            0 :       {}, depth, depth, ::vk::CompareOp::eGreaterOrEqual);
     307              : 
     308              :   // Stages
     309            0 :   const auto v_stage_create_info = ::vk::PipelineShaderStageCreateInfo(
     310            0 :       {}, ::vk::ShaderStageFlagBits::eVertex, vertex_shader_module_.module(),
     311              :       "main");
     312            0 :   const auto f_stage_create_info = ::vk::PipelineShaderStageCreateInfo(
     313            0 :       {}, ::vk::ShaderStageFlagBits::eFragment,
     314            0 :       fragment_shader_module_.module(), "main");
     315            0 :   std::array shader_stages = {v_stage_create_info, f_stage_create_info};
     316              : 
     317            0 :   auto create_info = ::vk::GraphicsPipelineCreateInfo(
     318            0 :       {}, shader_stages, &vertex_input_info, &input_assembly, {},
     319              :       &viewport_state, &rasterization, &multisample, &depth_state,
     320            0 :       &colour_blend, &dynamic_state, pipeline_layout_, render_pass);
     321              : 
     322            0 :   std::tie(res, pipeline_) = device.createGraphicsPipeline({}, create_info);
     323            0 :   if (res != ::vk::Result::eSuccess)
     324            0 :     return std::unexpected(make_error_code(res));
     325              : 
     326            0 :   return {};
     327            0 : }
     328              : 
     329            0 : auto Shader::read_wren_shader_file(const std::filesystem::path& path)
     330              :     -> expected<std::map<ShaderType, std::string>> {
     331            0 :   const auto shader_file = utils::fs::read_file_to_string(path);
     332              : 
     333            0 :   utils::StringReader reader(shader_file);
     334              : 
     335            0 :   std::map<ShaderType, std::string> shaders;
     336            0 :   while (!reader.at_end()) {
     337            0 :     reader.skip_to_text_end("##type ");
     338            0 :     const auto shader_type = reader.read_to_end_line();
     339              : 
     340            0 :     const auto shader_content = reader.read_to_text_start("##type ");
     341              : 
     342            0 :     shaders.emplace(
     343            0 :         utils::string_to_enum<ShaderType>(shader_type, true).value(),
     344              :         shader_content);
     345              :   }
     346              : 
     347            0 :   return shaders;
     348            0 : }
     349              : 
     350              : }  // namespace wren::vk
        

Generated by: LCOV version 2.3.2-1