diff --git a/tools/spec.py b/tools/spec.py index dce02ab..787f879 100644 --- a/tools/spec.py +++ b/tools/spec.py @@ -220,6 +220,9 @@ class bitmask(type): class handle(type): + parents: List[str] + type: str + def __init__(self, node): assert node.tag == 'type' assert node.attrib['category'] == 'handle' @@ -229,7 +232,11 @@ class handle(type): super().__init__(name, depends=[type]) self.type = type - self.parent = node.attrib.get('parent', None) + + parents = node.attrib.get('parent', None) + self.parents = parents.split(',') if parents else [] + + assert type def declare(self): return "struct %(name)s_t; using %(name)s = %(name)s_t*;" % { @@ -244,12 +251,18 @@ class handle(type): assert name assert reg - if not self.parent: - return False - if self.parent == name: + if self.name == name: return True - return reg.types[self.parent].has_parent(name, reg) + if not self.parents: + return False + if name in self.parents: + return True + + for p in self.parents: + if reg.types[p].has_parent(name, reg): + return True + return False class enum(type): @@ -432,7 +445,7 @@ class command(type): self.depends += p.depends def declare(self): - return "%(result)s %(name)s (%(params)s) noexcept;" % { + return 'extern "C" %(result)s %(name)s (%(params)s) noexcept;' % { 'name': rename(self.name), 'result': self.result, 'params': ", ".join(p.param for p in self.params) @@ -773,13 +786,45 @@ if __name__ == '__main__': with open(args.dst, 'w') as dst: dst.write("#pragma once\n") - dst.write('extern "C" {\n') for obj in q: dst.write(obj.declare()) dst.write('\n') dst.write(obj.define(reg)) dst.write('\n') - dst.write('}\n') + + + dst.write(""" + #include + + template + struct is_instance: + public std::false_type + {}; + + template + struct is_device: + public std::false_type + {}; + + + template + constexpr auto is_instance_v = is_instance::value; + + template + constexpr auto is_device_v = is_device::value; + """) + + for obj in q: + if not isinstance(obj,handle): + continue + + device_value = "true_type" if obj.has_parent("VkDevice", reg) else "false_type" + instance_value = "true_type" if obj.has_parent("VkInstance", reg) else "false_type" + + dst.write(f""" + template <> struct is_instance<{obj.name}>: public std::{instance_value} {{ }}; + template <> struct is_device<{obj.name}>: public std::{device_value} {{ }}; + """) with open(args.icd, 'w') as icd: commands = [i for i in q if isinstance(i, command)] @@ -798,7 +843,7 @@ if __name__ == '__main__': #define MAP_DEVICE_COMMANDS(FUNC) MAP0(FUNC,{",".join(i.name for i in device_commands)}) namespace cruft::vk::icd {{ - struct vendor; + class vendor; struct func {{ void *handle; @@ -828,7 +873,7 @@ if __name__ == '__main__': with open(args.dispatch, 'w') as dispatch: dispatch.write(""" - #include "vk.hpp" + #include "../vk.hpp" #include "vtable.hpp" #include @@ -840,20 +885,37 @@ if __name__ == '__main__': """) for obj in commands: - if obj.is_instance(reg): - table = "i_table" - elif obj.is_device(reg): - table = "d_table" + first_arg = reg.types[obj.params[0].type] + + if not isinstance(first_arg, handle): + dispatch.write(f""" + extern "C" {obj.result} {rename(obj.name)} ({", ".join(p.param for p in obj.params)}) noexcept {{ + unimplemented (); + }}""") + continue + + if first_arg.has_parent('VkDevice', reg): + table = "d_table"; + elif first_arg.has_parent('VkInstance', reg): + table = 'i_table' else: - raise Exception("unhandled command type") + raise Exception("Unknown param type") + dispatch.write(f""" extern "C" {obj.result} {rename(obj.name)} ({", ".join(p.param for p in obj.params)}) noexcept {{ - auto const entry = reinterpret_cast ({obj.params[0].name}); - auto const *table = reinterpret_cast (entry->table); + using first_arg_t = std::decay_t; + + if constexpr (is_instance_v) {{ + auto const entry = reinterpret_cast ({obj.params[0].name}); + auto const *table = reinterpret_cast (entry->table); - return (table->{obj.name})( - reinterpret_cast (entry->handle), - {", ".join(p.name for p in obj.params[1:])} - ); + return (table->{obj.name})( + reinterpret_cast (entry->handle) + {", ".join([''] + [p.name for p in obj.params[1:]])} + ); + }} else {{ + unimplemented (); + }} + }} """)