Line data Source code
#include "shader.hpp"

#include <slang-com-ptr.h>
#include <slang-cpp-types.h>
#include <slang.h>
#include <spirv_reflect.h>
#include <vulkan/vulkan_core.h>

#include <algorithm>
#include <cstdint>
#include <expected>
#include <system_error>
#include <utility>
#include <vulkan/vulkan_enums.hpp>
#include <vulkan/vulkan_handles.hpp>
#include <wren/math/vector.hpp>
#include <wren/utils/enums.hpp>
#include <wren/utils/filesystem.hpp>
#include <wren/utils/string_reader.hpp>
#include <wren/vk/result.hpp>

#include "vulkan/vulkan_structs.hpp"
#include "wren/logging/log.hpp"
#include "wren/utils/result.hpp"
22 :
23 : // NOLINTNEXTLINE
24 : BOOST_DESCRIBE_ENUM(
25 : SpvReflectResult, SPV_REFLECT_RESULT_SUCCESS, SPV_REFLECT_RESULT_NOT_READY,
26 : SPV_REFLECT_RESULT_ERROR_PARSE_FAILED,
27 : SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED,
28 : SPV_REFLECT_RESULT_ERROR_RANGE_EXCEEDED,
29 : SPV_REFLECT_RESULT_ERROR_NULL_POINTER,
30 : SPV_REFLECT_RESULT_ERROR_INTERNAL_ERROR,
31 : SPV_REFLECT_RESULT_ERROR_COUNT_MISMATCH,
32 : SPV_REFLECT_RESULT_ERROR_ELEMENT_NOT_FOUND,
33 : SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_CODE_SIZE,
34 : SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_MAGIC_NUMBER,
35 : SPV_REFLECT_RESULT_ERROR_SPIRV_UNEXPECTED_EOF,
36 : SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE,
37 : SPV_REFLECT_RESULT_ERROR_SPIRV_SET_NUMBER_OVERFLOW,
38 : SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_STORAGE_CLASS,
39 : SPV_REFLECT_RESULT_ERROR_SPIRV_RECURSION,
40 : SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_INSTRUCTION,
41 : SPV_REFLECT_RESULT_ERROR_SPIRV_UNEXPECTED_BLOCK_DATA,
42 : SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_BLOCK_MEMBER_REFERENCE,
43 : SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ENTRY_POINT,
44 : SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_EXECUTION_MODE,
45 : SPV_REFLECT_RESULT_ERROR_SPIRV_MAX_RECURSIVE_EXCEEDED)
46 :
47 : namespace wren::vk {
48 :
49 0 : ShaderModule::ShaderModule(reflect::spirv_t spirv,
50 : const ::vk::ShaderModule& module)
51 0 : : spirv_(std::move(spirv)),
52 0 : module_(module),
53 0 : reflection_(std::make_shared<spv_reflect::ShaderModule>(this->spirv_)) {
54 0 : if (reflection_->GetResult() != SPV_REFLECT_RESULT_SUCCESS) {
55 0 : log::error("Shader reflection failed: {}", reflection_->GetResult());
56 0 : }
57 0 : }
58 :
59 0 : auto ShaderModule::get_vertex_input_bindings() const
60 : -> std::vector<::vk::VertexInputBindingDescription> {
61 0 : uint32_t count = 0;
62 0 : reflection_->EnumerateInputVariables(&count, nullptr);
63 0 : if (count == 0) return {};
64 :
65 0 : std::vector<SpvReflectInterfaceVariable*> input_variables(count);
66 0 : reflection_->EnumerateInputVariables(&count, input_variables.data());
67 :
68 0 : uint32_t offset = 0;
69 0 : std::vector<::vk::VertexInputBindingDescription> bindings;
70 0 : bindings.reserve(input_variables.size());
71 0 : for (const auto& input : input_variables) {
72 0 : if (static_cast<uint32_t>(input->built_in) != UINT32_MAX) continue;
73 :
74 0 : const auto width = input->numeric.scalar.width / 8;
75 0 : const auto count = input->numeric.vector.component_count;
76 0 : offset += width * count;
77 : }
78 :
79 0 : if (offset == 0) {
80 0 : return {};
81 : }
82 :
83 0 : return {{0, offset}};
84 0 : }
85 :
86 0 : auto ShaderModule::get_vertex_input_attributes() const
87 : -> std::vector<::vk::VertexInputAttributeDescription> {
88 0 : const auto& input_variables = [&] {
89 0 : uint32_t count = 0;
90 0 : reflection_->EnumerateInputVariables(&count, nullptr);
91 0 : std::vector<SpvReflectInterfaceVariable*> input_variables(count);
92 0 : reflection_->EnumerateInputVariables(&count, input_variables.data());
93 :
94 0 : std::ranges::sort(input_variables, [](const auto& a, const auto& b) {
95 0 : return a->location < b->location;
96 : });
97 :
98 0 : return input_variables;
99 0 : }();
100 :
101 0 : uint32_t offset = 0;
102 0 : std::vector<::vk::VertexInputAttributeDescription> attrs;
103 0 : for (const auto& input : input_variables) {
104 0 : if (input->location == UINT32_MAX) continue;
105 0 : attrs.emplace_back(input->location, 0,
106 0 : static_cast<::vk::Format>(input->format), offset);
107 0 : offset += (input->numeric.scalar.width / 8) *
108 0 : input->numeric.vector.component_count;
109 : }
110 :
111 0 : return attrs;
112 0 : }
113 :
114 0 : auto ShaderModule::get_descriptor_set_layout_bindings() const
115 : -> std::vector<::vk::DescriptorSetLayoutBinding> {
116 0 : uint32_t count = 0;
117 0 : reflection_->EnumerateDescriptorSets(&count, nullptr);
118 0 : std::vector<SpvReflectDescriptorSet*> spv_sets(count);
119 0 : reflection_->EnumerateDescriptorSets(&count, spv_sets.data());
120 :
121 0 : std::vector<::vk::DescriptorSetLayoutBinding> layouts;
122 0 : for (SpvReflectDescriptorSet* set : spv_sets) {
123 0 : std::span<SpvReflectDescriptorBinding*> bindings(set->bindings,
124 0 : set->binding_count);
125 :
126 0 : for (SpvReflectDescriptorBinding* binding : bindings) {
127 0 : layouts.emplace_back(
128 0 : binding->binding,
129 0 : static_cast<::vk::DescriptorType>(binding->descriptor_type),
130 0 : binding->count, ::vk::ShaderStageFlagBits::eVertex);
131 : }
132 : }
133 :
134 0 : return layouts;
135 0 : }
136 :
/// Compiles the Slang source at `shader_path` into a vertex + fragment Shader.
///
/// The entry points `vertexMain` and `fragmentMain` are each compiled to
/// SPIR-V (profile `spirv_1_5`). Each result is wrapped in a Vulkan shader
/// module created on `device` and stored on the returned Shader.
///
/// @param device      Device used to create the Vulkan shader modules.
/// @param shader_path Path of the Slang module to load.
/// @return The shader, or an error if any Slang step or shader-module
///         creation fails. Slang diagnostics are logged when available.
auto Shader::create(const ::vk::Device& device,
                    const std::filesystem::path& shader_path) -> expected<Ptr> {
  // Slang failures carry no std::error_code of their own. Report them
  // uniformly so the caller gets an error instead of a null dereference.
  const auto slang_failure = [] {
    return std::unexpected(std::make_error_code(std::errc::invalid_argument));
  };

  Slang::ComPtr<slang::IGlobalSession> slang_session;
  slang::createGlobalSession(slang_session.writeRef());
  if (slang_session == nullptr) {
    log::error("Failed to create slang global session");
    return slang_failure();
  }

  slang::TargetDesc target_desc{
      .format = SLANG_SPIRV,
      .profile = slang_session->findProfile("spirv_1_5"),
  };

  slang::SessionDesc session_desc{
      .targets = &target_desc,
      .targetCount = 1,
      .compilerOptionEntryCount = 0,
  };

  const auto output_diagnostics = [](slang::IBlob* diag) {
    if (diag == nullptr) return;
    log::error("{}", static_cast<const char*>(diag->getBufferPointer()));
  };

  Slang::ComPtr<slang::ISession> session;
  slang_session->createSession(session_desc, session.writeRef());
  if (session == nullptr) {
    log::error("Failed to create slang session");
    return slang_failure();
  }

  // Loads the module, selects one entry point, compiles it to SPIR-V and
  // wraps the result in a ShaderModule.
  const auto compile_shader =
      [&](const std::string& entry_point_name) -> expected<ShaderModule> {
    slang::IModule* slang_module = nullptr;
    {
      Slang::ComPtr<slang::IBlob> diagnostic_blob;
      slang_module = session->loadModule(shader_path.string().c_str(),
                                         diagnostic_blob.writeRef());
      if (slang_module == nullptr) {
        log::error("Failed to load shader module:\n");
        output_diagnostics(diagnostic_blob.get());
        return slang_failure();
      }
    }

    Slang::ComPtr<slang::IEntryPoint> entry_point;
    slang_module->findEntryPointByName(entry_point_name.c_str(),
                                       entry_point.writeRef());
    if (entry_point == nullptr) {
      log::error("Failed to find entry point: {}", entry_point_name);
      return slang_failure();
    }

    std::array<slang::IComponentType*, 2> component_types = {slang_module,
                                                             entry_point};

    Slang::ComPtr<slang::IComponentType> composed_program;
    {
      Slang::ComPtr<slang::IBlob> diagnostics_blob;
      // Pass the diagnostics blob so composition errors are actually reported.
      const SlangResult result = session->createCompositeComponentType(
          component_types.data(), component_types.size(),
          composed_program.writeRef(), diagnostics_blob.writeRef());
      if (SLANG_FAILED(result) || composed_program == nullptr) {
        output_diagnostics(diagnostics_blob.get());
        return slang_failure();
      }
    }

    Slang::ComPtr<slang::IBlob> spirv_code;
    {
      Slang::ComPtr<slang::IBlob> diagnostics_blob;
      const SlangResult result = composed_program->getEntryPointCode(
          0, 0, spirv_code.writeRef(), diagnostics_blob.writeRef());
      if (SLANG_FAILED(result) || spirv_code == nullptr) {
        output_diagnostics(diagnostics_blob.get());
        return slang_failure();
      }
    }

    // codeSize is in bytes; pCode points at 32-bit SPIR-V words.
    ::vk::ShaderModuleCreateInfo create_info(
        {}, spirv_code->getBufferSize(),
        static_cast<const uint32_t*>(spirv_code->getBufferPointer()));
    VK_TRY_RESULT(module, device.createShaderModule(create_info));

    // Keep an owned copy of the words for reflection; the Slang blob is
    // released when this lambda returns.
    const auto* const words =
        static_cast<const uint32_t*>(spirv_code->getBufferPointer());
    std::vector<uint32_t> spirv(
        words, words + (spirv_code->getBufferSize() / sizeof(uint32_t)));

    return ShaderModule{std::move(spirv), module};
  };

  const auto shader = std::make_shared<Shader>();

  // NOTE(review): if the fragment stage fails, the vertex stage's
  // ::vk::ShaderModule is only released if Shader's destructor does so.
  TRY_RESULT(const auto& v_module, compile_shader("vertexMain"));
  shader->vertex_shader(v_module);

  TRY_RESULT(const auto& f_module, compile_shader("fragmentMain"));
  shader->fragment_shader(f_module);

  return shader;
}
236 :
/// Creates the descriptor-set layout, the pipeline layout and the graphics
/// pipeline for this shader's vertex and fragment modules.
///
/// @param device      Device that owns the created objects.
/// @param render_pass Render pass the pipeline is used with (subpass 0).
/// @param size        Initial viewport/scissor extent. Viewport and scissor
///                    are dynamic states, so Vulkan ignores these static
///                    values; they only need to be recorded at draw time.
/// @param depth       Enables depth test and depth write, using the
///                    greater-or-equal compare op.
/// @return Empty on success, or the error of the failing Vulkan create call.
auto Shader::create_graphics_pipeline(const ::vk::Device& device,
                                      const ::vk::RenderPass& render_pass,
                                      const math::Vec2f& size, bool depth)
    -> expected<void> {
  ::vk::Result res = ::vk::Result::eSuccess;

  // Descriptor Sets
  // Merge the bindings of both stages into a single layout. Each stage's
  // bindings are tagged with that stage explicitly. A binding number used by
  // both stages becomes visible to both, and binding numbers stay unique as
  // Vulkan requires. Previously only the vertex module was consulted, so
  // fragment-only descriptors were missing from the layout.
  std::vector<::vk::DescriptorSetLayoutBinding> bindings;
  const auto add_stage_bindings = [&bindings](const auto& shader_module,
                                              ::vk::ShaderStageFlagBits stage) {
    for (auto binding : shader_module.get_descriptor_set_layout_bindings()) {
      const auto existing =
          std::ranges::find(bindings, binding.binding,
                            &::vk::DescriptorSetLayoutBinding::binding);
      if (existing != bindings.end()) {
        // NOTE(review): assumes both stages declare the same descriptor type
        // and count for a shared binding number.
        existing->stageFlags |= stage;
        continue;
      }
      binding.stageFlags = stage;
      bindings.push_back(binding);
    }
  };
  add_stage_bindings(vertex_shader_module_, ::vk::ShaderStageFlagBits::eVertex);
  add_stage_bindings(fragment_shader_module_,
                     ::vk::ShaderStageFlagBits::eFragment);

  ::vk::DescriptorSetLayoutCreateInfo dl_create_info(
      ::vk::DescriptorSetLayoutCreateFlagBits::ePushDescriptorKHR, bindings);

  VK_TIE_ERR_PROP(descriptor_layout_,
                  device.createDescriptorSetLayout(dl_create_info));

  ::vk::PipelineLayoutCreateInfo layout_create({}, descriptor_layout_);
  std::tie(res, pipeline_layout_) = device.createPipelineLayout(layout_create);
  if (res != ::vk::Result::eSuccess)
    return std::unexpected(make_error_code(res));

  // Dynamic states
  std::array dynamic_states = {::vk::DynamicState::eViewport,
                               ::vk::DynamicState::eScissor};
  ::vk::PipelineDynamicStateCreateInfo dynamic_state({}, dynamic_states);

  // Input binding/attributes
  const auto input_bindings = vertex_shader_module_.get_vertex_input_bindings();
  const auto input_attributes =
      vertex_shader_module_.get_vertex_input_attributes();

  ::vk::PipelineVertexInputStateCreateInfo vertex_input_info{
      {}, input_bindings, input_attributes};

  ::vk::PipelineInputAssemblyStateCreateInfo input_assembly(
      {}, ::vk::PrimitiveTopology::eTriangleList, false);

  // Viewport: minDepth = 1, maxDepth = 0 (a reversed depth range). Because
  // viewport/scissor are dynamic, only the counts here matter to Vulkan.
  ::vk::Viewport viewport{
      0, 0, static_cast<float>(size.x()), static_cast<float>(size.y()), 1, 0};
  ::vk::Rect2D scissor{
      {0, 0},
      {static_cast<uint32_t>(size.x()), static_cast<uint32_t>(size.y())}};
  ::vk::PipelineViewportStateCreateInfo viewport_state{{}, viewport, scissor};

  ::vk::PipelineRasterizationStateCreateInfo rasterization(
      {}, false, false, ::vk::PolygonMode::eFill, ::vk::CullModeFlagBits::eNone,
      ::vk::FrontFace::eCounterClockwise, false, {}, {}, {}, 1.0f);

  ::vk::PipelineMultisampleStateCreateInfo multisample{
      {}, ::vk::SampleCountFlagBits::e1, false};

  // Colour blending: standard "source over" alpha blending.
  ::vk::PipelineColorBlendAttachmentState colour_blend_attachment{
      true,
      ::vk::BlendFactor::eSrcAlpha,
      ::vk::BlendFactor::eOneMinusSrcAlpha,
      ::vk::BlendOp::eAdd,
      ::vk::BlendFactor::eOne,
      ::vk::BlendFactor::eZero,
      ::vk::BlendOp::eAdd};
  colour_blend_attachment.setColorWriteMask(
      ::vk::ColorComponentFlagBits::eR | ::vk::ColorComponentFlagBits::eG |
      ::vk::ColorComponentFlagBits::eB | ::vk::ColorComponentFlagBits::eA);
  ::vk::PipelineColorBlendStateCreateInfo colour_blend(
      {}, false, ::vk::LogicOp::eCopy, colour_blend_attachment,
      {0.0, 0.0, 0.0, 0.0});

  // Depth / Stencil
  ::vk::PipelineDepthStencilStateCreateInfo depth_state(
      {}, depth, depth, ::vk::CompareOp::eGreaterOrEqual);

  // Stages. Slang names the SPIR-V entry point "main" by default.
  const auto v_stage_create_info = ::vk::PipelineShaderStageCreateInfo(
      {}, ::vk::ShaderStageFlagBits::eVertex, vertex_shader_module_.module(),
      "main");
  const auto f_stage_create_info = ::vk::PipelineShaderStageCreateInfo(
      {}, ::vk::ShaderStageFlagBits::eFragment,
      fragment_shader_module_.module(), "main");
  std::array shader_stages = {v_stage_create_info, f_stage_create_info};

  auto create_info = ::vk::GraphicsPipelineCreateInfo(
      {}, shader_stages, &vertex_input_info, &input_assembly, {},
      &viewport_state, &rasterization, &multisample, &depth_state,
      &colour_blend, &dynamic_state, pipeline_layout_, render_pass);

  std::tie(res, pipeline_) = device.createGraphicsPipeline({}, create_info);
  if (res != ::vk::Result::eSuccess)
    return std::unexpected(make_error_code(res));

  return {};
}
328 :
/// Splits a wren shader file into per-stage sources.
///
/// Each section starts with a line `##type <name>`. The text up to the next
/// `##type ` marker (or end of file) is that section's source. `<name>` is
/// converted to a ShaderType via utils::string_to_enum (the `true` argument
/// is passed through unchanged).
///
/// @param path Path of the shader file to read.
/// @return Map from shader type to its source text, or an error if a
///         section names an unknown shader type. Previously an unknown type
///         threw from `.value()` despite the `expected` return type.
///         If a type appears twice, the first section is kept (std::map::emplace).
auto Shader::read_wren_shader_file(const std::filesystem::path& path)
    -> expected<std::map<ShaderType, std::string>> {
  const auto shader_file = utils::fs::read_file_to_string(path);

  utils::StringReader reader(shader_file);

  std::map<ShaderType, std::string> shaders;
  while (!reader.at_end()) {
    reader.skip_to_text_end("##type ");
    const auto shader_type = reader.read_to_end_line();

    const auto shader_content = reader.read_to_text_start("##type ");

    const auto parsed_type =
        utils::string_to_enum<ShaderType>(shader_type, true);
    if (!parsed_type) {
      log::error("Unknown shader type: {}", shader_type);
      return std::unexpected(
          std::make_error_code(std::errc::invalid_argument));
    }

    shaders.emplace(*parsed_type, shader_content);
  }

  return shaders;
}
349 :
350 : } // namespace wren::vk
|