diff --git a/ostream.cpp b/ostream.cpp index 66c12b2..a623fed 100644 --- a/ostream.cpp +++ b/ostream.cpp @@ -41,7 +41,6 @@ operator<< (std::ostream &os, VkExtent3D val) //----------------------------------------------------------------------------- -#if 0 std::ostream& operator<< (std::ostream &os, VkQueueFlags val) { @@ -55,7 +54,6 @@ operator<< (std::ostream &os, VkQueueFlags val) return os << "[ " << util::make_infix (util::make_view (names, names + cursor)) << " ]"; } -#endif /////////////////////////////////////////////////////////////////////////////// diff --git a/ostream.hpp b/ostream.hpp index 0dc5e1b..ec8c521 100644 --- a/ostream.hpp +++ b/ostream.hpp @@ -27,6 +27,7 @@ /////////////////////////////////////////////////////////////////////////////// +std::ostream& operator<< (std::ostream&, VkQueueFlags); std::ostream& operator<< (std::ostream&, VkExtent2D); std::ostream& operator<< (std::ostream&, VkExtent3D); std::ostream& operator<< (std::ostream&, VkPhysicalDeviceType); diff --git a/tools/spec.py b/tools/spec.py index 4fb9256..3c8b4cc 100755 --- a/tools/spec.py +++ b/tools/spec.py @@ -3,6 +3,34 @@ import logging import xml.etree.ElementTree as ET +import re + +############################################################################### +def camel_to_snake(name): + name = re.sub('([a-z])([A-Z])', r'\1_\2', name) + return name.lower() + + +def remove_namespace(name): + name = re.sub('^VK_', '', name) + name = re.sub('^[vV][kK]', '', name) + return name + + +############################################################################### +def rename(name): + return name + name = remove_namespace(name) + name = camel_to_snake(name) + return name + + +##----------------------------------------------------------------------------- +def rename_enum(type, value): + return value + value = rename(value) + value = re.sub("^%s_" % type, '', value) + return value ############################################################################### @@ -16,7 +44,7 @@ class type(object): def declare(self): return "" - def define(self): + def define(self,types): return "" @@ -32,8 +60,11 @@ class basetype(type): def depends(self): return [self._type] - def define(self): - return "using %s = %s;" % (self.name, self._type) + def define(self, types): + return "using %(name)s = %(type)s;" % { + 'name': rename(self.name), + 'type': self._type + } ##----------------------------------------------------------------------------- @@ -45,8 +76,9 @@ class handle(type): super().__init__(node.find('name').text) def declare(self): - return "typedef struct object_%(name)s* %(name)s;" % { 'name': self.name } - return "using %(name)s = struct object_%(name)s*;" % { 'name': self.name } + return "using %(name)s = struct _%(name)s*;" % { + 'name': rename(self.name) + } ##----------------------------------------------------------------------------- @@ -58,8 +90,11 @@ class constant(type): self._value = node.attrib['value'] - def define(self): - return "constexpr auto %s = %s;" % (self.name, self._value) + def define(self, types): + return "constexpr auto %(name)s = %(value)s;" % { + 'name': self.name, + 'value': self._value + } ##----------------------------------------------------------------------------- @@ -75,17 +110,34 @@ class bitmask(type): self._requires = node.attrib.get('requires', None) self._type = node.find('type').text - def declare(self): - return "" - def depends(self): if self._requires: return [self._type, self._requires] else: return [self._type] - def define(self): - return "using %s = %s;" % (self.name, self._type) + def declare(self): + return "" + + def define(self, types): + if not self._requires: + return "using %(name)s = %(type)s;" % { + 'name': rename(self.name), + 'type': rename(self._type) + } + + return "using %(name)s = %(requires)s;" % { + 'name': rename(self.name), + 'requires': rename(self._requires) + } + + members = types[self._requires].values + + return "enum class %(name)s : %(type)s { %(members)s };" % { + 'name': rename(self.name), + 'type': rename(self._type), + 'members': "\n".join("%(name)s = %(value)s," % x for x in members) + } ##----------------------------------------------------------------------------- @@ -112,24 +164,38 @@ class bitflag(type): def add(self, name, value=None): self.values.append({'name': name, 'value': value}) + def values(self): + return self.values + def depends(self): if self._depends: - return [self._depends] + return ['VkFlags',self._depends] else: - return [] + return ['VkFlags'] def declare(self): return "" - def define(self): + def define(self,types): values = [] for v in self.values: if 'value' in v: - values.append("%(name)s = %(value)s" % v) + values.append( + "%(name)s = %(value)s" % { + 'name': rename(v['name']), + 'value': v['value'] + } + ) else: - values.append(v['name']) + values.append(rename(v['name'])) - return "enum %s { %s };" % (self.name, ", ".join(values)) + return """ + enum %(name)s : %(vkflags)s { %(members)s }; + """ % { + 'name': rename(self.name), + 'vkflags': rename('VkFlags'), + 'members': ", ".join(values) + } ##----------------------------------------------------------------------------- @@ -154,18 +220,26 @@ class enum(type): def declare(self): return "" - def define(self): + def define(self,types): values = [] for v in self.values: if v['value']: - values.append("%(name)s = %(value)s" % v) + values.append( + "%(name)s = %(value)s" % { + 'name': rename(v['name']), + 'value': v['value'] + } + ) else: - values.append(v['name']) + values.append(rename(v['name'])) - return "enum %s { %s };" % ( - self.name, - ", ".join(values) - ) + attribute = '[[nodiscard]]' if self.name == 'VkResult' else '' + + return "enum %(attribute)s %(name)s : int32_t { %(values)s };" % ({ + 'name': rename(self.name), + 'attribute': attribute, + 'values': ", ".join(values), + }) ##----------------------------------------------------------------------------- @@ -185,7 +259,7 @@ class funcpointer(type): def declare(self): return self.text - def define(self): + def define(self,types): return ""; @@ -196,6 +270,7 @@ class pod(type): assert(node.attrib['category'] in ['struct', 'union']) super().__init__(node.attrib['name']) + self._node = node self._category = node.attrib['category'] # sometimes there are enums hiding in the member fields being used as array sizes @@ -209,7 +284,7 @@ class pod(type): 'name': x.find('name').text, # we must include a space separator otherwise we get run-on # types/names in the member definitions. - 'code': " ".join(x.itertext()) + 'code': " ".join(rename(c) for c in x.itertext()) }, node.findall('./member')) ) @@ -218,12 +293,15 @@ class pod(type): return self._depends def declare(self): - return "%s %s;" % (self._category, self.name) + return "%(category)s %(name)s;" % { + 'category': self._category, + 'name': rename(self.name) + } - def define(self): + def define(self,types): return "%(category)s %(name)s {\n%(members)s\n};" % { 'category': self._category, - 'name': self.name, + 'name': rename(self.name), 'members': "\n".join(m['code'] + ';' for m in self._members) } @@ -305,7 +383,7 @@ class command(type): def declare(self): return "%(result)s %(name)s (%(params)s) noexcept;" % { - 'name': self.name, + 'name': rename(self.name), 'result': self._result, 'params': ", ".join(self._params) } @@ -335,30 +413,31 @@ def parse_extension(types, node): ############################################################################### -def write_type(dst, types, t): +def write_type(dst, all, pending, t): logging.info("writing: %s", t.name) for d in t.depends(): - if d in types: - write_type(dst, types, types[d]) + if d in pending: + write_type(dst, all, pending, pending[d]) - if t.name in types: + if t.name in pending: dst.write(t.declare()) dst.write('\n') - dst.write(t.define()) + dst.write(t.define(all)) dst.write('\n') - del types[t.name] + del pending[t.name] ##----------------------------------------------------------------------------- +import copy def write_types(dst, types): + all = types + pending = copy.deepcopy(types) keys = list(types.keys()) for k in keys: if k in types: - write_type(dst, types, types[k]) - - + write_type(dst, all, pending, types[k]) ##----------------------------------------------------------------------------- @@ -366,6 +445,11 @@ def write_root(dst, node): dst.write (""" #ifndef __VK_HPP #define __VK_HPP + #include + + template + struct enable_bitops : public std::false_type { }; + #include #include @@ -405,7 +489,7 @@ def write_root(dst, node): // extract the value for VK_NULL_HANDLE from the XML we'll just // hard code it here. #define VK_NULL_HANDLE nullptr - + // TODO: make this correspond to a required version #define VK_VERSION_1_0 """) @@ -430,7 +514,30 @@ def write_root(dst, node): #if defined(__cplusplus) } #endif - #endif + """) + + for x in types.values(): + if isinstance(x,bitflag): + dst.write(""" + template <> + struct enable_bitops<%(name)s> : + public std::true_type + { }; + """ % { 'name': x.name } + ) + + dst.write(""" + template + std::enable_if_t::value, T> + operator| (T a, T b) + { + return T ( + static_cast> (a) | + static_cast> (b) + ); + } + + #endif """)