Line data Source code
1 : #include "wren/reflect/reflect.hpp"
2 :
3 : #include <cstdint>
4 : #include <spirv/unified1/spirv.hpp>
5 : #include <wren/utils/binary_reader.hpp>
6 :
7 : #include "wren/logging/log.hpp"
8 :
9 : namespace r = std::ranges;
10 : namespace rv = r::views;
11 :
12 : namespace wren::reflect {
13 :
14 : auto parse_string(utils::BinaryReader& reader)
15 : -> std::pair<std::string, uint32_t>;
16 :
17 0 : Reflect::Reflect(const std::span<const std::byte>& spirv) {
18 0 : utils::BinaryReader reader(spirv);
19 :
20 0 : magic_number_ = reader.read<uint32_t>();
21 0 : version_ = reader.read<uint32_t>();
22 0 : generator_ = reader.read<uint32_t>();
23 0 : bound_ = reader.read<uint32_t>();
24 0 : reader.read<uint32_t>();
25 :
26 : // Read instructions
27 0 : while (!reader.at_end()) {
28 0 : const auto op = reader.read<uint32_t>();
29 0 : const auto wordcount =
30 0 : (op >> 16) & 0xFFFF; // We already read the first word above
31 0 : const auto opcode = static_cast<spv::Op>(op & 0xFFFF);
32 :
33 0 : size_t skipped_words = 1;
34 :
35 0 : skipped_words += parse_op_code(reader, wordcount - skipped_words, opcode);
36 0 : for (size_t i = skipped_words; i < wordcount; ++i) {
37 0 : reader.read<uint32_t>();
38 0 : }
39 : }
40 :
41 0 : resolve();
42 0 : }
43 :
44 0 : void Reflect::resolve() {
45 0 : for (const auto& [id, annotations] : annotations_) {
46 0 : if (!names_.contains(id)) continue;
47 0 : const auto& name = names_.at(id);
48 0 : log::error("Finding annotation for: {}", name);
49 : }
50 0 : }
51 :
52 0 : auto parse_string(utils::BinaryReader& reader)
53 : -> std::pair<std::string, uint32_t> {
54 0 : std::string s;
55 :
56 0 : bool ended = false;
57 :
58 0 : uint32_t consumed = 0;
59 0 : while (true) {
60 0 : char c = reader.read<char>();
61 0 : s += c;
62 0 : if (c == '\0') ended = true;
63 0 : ++consumed;
64 :
65 0 : if (ended && s.size() % sizeof(uint32_t) == 0) break;
66 : }
67 :
68 0 : s.resize(strlen(s.data()));
69 :
70 0 : return {s, consumed / sizeof(uint32_t)};
71 0 : }
72 :
73 0 : auto Reflect::parse_op_code(utils::BinaryReader& reader,
74 : const uint32_t wordcount, const spv::Op& op)
75 : -> uint32_t {
76 0 : uint32_t words_consumed = 0;
77 :
78 0 : switch (op) {
79 : case spv::Op::OpCapability: {
80 0 : const auto cap [[maybe_unused]] =
81 0 : static_cast<spv::Capability>(reader.read<uint32_t>());
82 0 : ++words_consumed;
83 0 : break;
84 : }
85 : case spv::Op::OpDecorate: {
86 0 : Annotation annotation{};
87 :
88 0 : const auto& id = reader.read<uint32_t>();
89 0 : ++words_consumed;
90 :
91 0 : annotation.decoration =
92 0 : static_cast<spv::Decoration>(reader.read<uint32_t>());
93 0 : ++words_consumed;
94 :
95 0 : switch (annotation.decoration) {
96 : case spv::Decoration::DecorationLocation: {
97 0 : annotation.location = reader.read<uint32_t>();
98 0 : ++words_consumed;
99 0 : break;
100 : }
101 : case spv::Decoration::DecorationDescriptorSet: {
102 0 : annotation.descriptor_set = reader.read<uint32_t>();
103 0 : ++words_consumed;
104 0 : break;
105 : }
106 : case spv::Decoration::DecorationBinding: {
107 0 : annotation.binding = reader.read<uint32_t>();
108 0 : ++words_consumed;
109 0 : break;
110 : }
111 : default:
112 0 : break;
113 : }
114 :
115 0 : if (!annotations_.contains(id))
116 0 : annotations_.emplace(id, std::vector<Annotation>{});
117 0 : annotations_.at(id).push_back(annotation);
118 :
119 0 : break;
120 : }
121 : case spv::Op::OpEntryPoint: {
122 0 : EntryPoint entry;
123 :
124 0 : entry.execution_model =
125 0 : static_cast<spv::ExecutionModel>(reader.read<uint32_t>());
126 0 : ++words_consumed;
127 :
128 0 : entry.entry_point = reader.read<uint32_t>();
129 0 : ++words_consumed;
130 :
131 0 : const auto& [name, consumed] = parse_string(reader);
132 0 : entry.name = name;
133 0 : words_consumed += consumed;
134 :
135 0 : const uint32_t id_count = wordcount - words_consumed;
136 0 : entry.forward_references.reserve(id_count);
137 0 : for (size_t i = 0; i < id_count; ++i) {
138 0 : entry.forward_references.push_back(reader.read<uint32_t>());
139 0 : ++words_consumed;
140 0 : }
141 :
142 0 : entry_points_.push_back(entry);
143 : break;
144 0 : }
145 : case spv::Op::OpName: {
146 0 : const auto& result_id = reader.read<uint32_t>();
147 0 : ++words_consumed;
148 0 : const auto& [name, consumed] = parse_string(reader);
149 0 : words_consumed += consumed;
150 :
151 0 : names_.emplace(result_id, name);
152 : break;
153 0 : }
154 : case spv::Op::OpTypeVoid: {
155 : // const auto& result_id = reader.read<uint32_t>();
156 : // ++words_consumed;
157 : // types_.emplace(result_id, {});
158 0 : break;
159 : }
160 : case spv::Op::OpTypeInt:
161 : case spv::Op::OpTypeFloat: {
162 0 : const auto& result_id = reader.read<uint32_t>();
163 0 : ++words_consumed;
164 0 : const auto width = reader.read<uint32_t>();
165 0 : ++words_consumed;
166 :
167 0 : if (op == spv::Op::OpTypeInt) {
168 0 : const auto signedness [[maybe_unused]] = reader.read<uint32_t>();
169 0 : ++words_consumed;
170 0 : }
171 :
172 0 : types_.emplace(result_id, SimpleType{.width = width});
173 :
174 0 : break;
175 : }
176 : case spv::Op::OpTypeVector: {
177 0 : const auto& result_id = reader.read<uint32_t>();
178 0 : ++words_consumed;
179 :
180 0 : const auto& component_type = reader.read<uint32_t>();
181 0 : ++words_consumed;
182 :
183 0 : const auto& count = reader.read<uint32_t>();
184 0 : ++words_consumed;
185 :
186 0 : types_.emplace(result_id, types_.at(component_type).width * count);
187 :
188 0 : break;
189 : }
190 : case spv::Op::OpTypePointer: {
191 0 : const auto& result_id = reader.read<uint32_t>();
192 0 : ++words_consumed;
193 :
194 0 : const auto& _ = reader.read<uint32_t>(); // Storage class
195 0 : ++words_consumed;
196 :
197 0 : const auto& type = reader.read<uint32_t>();
198 0 : ++words_consumed;
199 :
200 0 : if (types_.contains(type))
201 0 : types_.emplace(result_id, types_.at(type).width);
202 0 : break;
203 : }
204 : case spv::Op::OpVariable: {
205 0 : Variable var;
206 0 : const auto& result_type = reader.read<uint32_t>();
207 0 : ++words_consumed;
208 0 : const auto& result_id = reader.read<uint32_t>();
209 0 : ++words_consumed;
210 :
211 0 : if (names_.contains(result_id)) {
212 0 : var.name = names_.at(result_id);
213 0 : }
214 :
215 0 : var.storage_class =
216 0 : static_cast<spv::StorageClass>(reader.read<uint32_t>());
217 0 : ++words_consumed;
218 :
219 0 : if (types_.contains(result_type))
220 0 : var.width = types_.at(result_type).width;
221 :
222 0 : if (!annotations_.contains(result_id)) {
223 : // Generic variable
224 0 : variables_.emplace(result_id, var);
225 0 : break;
226 : }
227 :
228 0 : const auto& get_annotation =
229 0 : [&](const spv::Decoration& dec) -> std::optional<Annotation> {
230 0 : auto res = r::find_if(annotations_.at(result_id),
231 0 : [dec](const auto& annotation) {
232 0 : return annotation.decoration == dec;
233 : });
234 0 : if (res == annotations_.at(result_id).end()) return std::nullopt;
235 0 : return *res;
236 0 : };
237 :
238 0 : const auto location = get_annotation(spv::Decoration::DecorationLocation);
239 0 : if (location.has_value()) {
240 0 : var.location = location.value().location.value();
241 0 : variables_.emplace(result_id, var);
242 0 : break;
243 : }
244 :
245 : const auto descriptor_set =
246 0 : get_annotation(spv::Decoration::DecorationDescriptorSet);
247 0 : if (descriptor_set.has_value()) {
248 0 : descriptor_sets_.emplace(result_id, std::vector<Binding>{});
249 0 : break;
250 : }
251 :
252 0 : const auto binding = get_annotation(spv::Decoration::DecorationBinding);
253 0 : if (binding.has_value()) {
254 0 : log::warn("Found binding: {}", binding->binding.value());
255 0 : break;
256 : }
257 :
258 : // TODO potentially an initializer
259 :
260 0 : break;
261 0 : }
262 : default:
263 0 : break;
264 : }
265 :
266 0 : if (words_consumed > 0) {
267 0 : if (words_consumed != wordcount)
268 0 : log::error("OpCode '{}' not finished: consumed: {} of word count: {}", op,
269 : words_consumed, wordcount);
270 0 : words_consumed += (wordcount - words_consumed);
271 0 : }
272 0 : return words_consumed;
273 0 : }
274 :
275 : } // namespace wren::reflect
|