#!/usr/bin/env python3
"""Transform the Vulkan XML API registry into a C++ header.

The registry is parsed into a dictionary of ``type`` objects keyed by name;
features and extensions then select which objects are required, and the
required objects are written out so that every object follows the objects it
depends on.
"""

import sys
import logging
import argparse
from typing import List, Dict, Set, Optional
import xml.etree.ElementTree as ET
import re


# Extension-assigned enum values: base + (extension_number - 1) * block + offset
# (Vulkan spec, "Assigning Extension Token Values").
_EXT_ENUM_BASE = 1000000000
_EXT_ENUM_BLOCK = 1000


###############################################################################
def rename(name: str):
    """Map a registry identifier to the identifier emitted in C++ (identity)."""
    return name


###############################################################################
class registry:
    """Everything parsed from the XML registry.

    Attributes:
        types: every emittable object, keyed by name.
        extensions: ``extension`` objects keyed by extension name.
        features: ``feature`` objects keyed by feature name.
        applied: names of extensions whose requirements were already applied.
    """

    def __init__(self):
        self.types = {}
        self.extensions = {}
        self.features = {}
        self.types['API Constants'] = unscoped('API Constants')
        self.applied = set()

    def _serialise(self, name: str, queued: Set[str]):
        """Return ``name`` and its not-yet-queued dependencies, dependencies first."""
        return enqueue_type(name, queued, self.types)

    def serialise(self, platform: str):
        """Apply every feature and extension, then order the required objects.

        Extensions bound to a platform other than ``platform`` are skipped.
        Returns a list of objects in which each object appears after all of
        its dependencies, with no object repeated.
        """
        required = []
        for f in self.features.values():
            required += f.apply(self)
            required.append(f.name)
        for e in self.extensions.values():
            required += e.apply(self, platform)

        queued = set()
        result = []
        for r in required:
            result += self._serialise(r, queued)
        return result


###############################################################################
class type(object):
    """
    The base class for all objects defined in the Vulkan API.

    This includes (but is not limited to) types, like structures; and values,
    like constants.
    """

    def __init__(self, name: str, depends: Optional[List[str]] = None):
        self.name = name
        # Copy so instances never share or mutate the caller's list.
        self.depends = list(depends or [])

    def depends(self):
        # NOTE: shadowed on every instance by the ``depends`` attribute set in
        # __init__; only reachable as ``type.depends(obj)``.
        return self.depends

    def declare(self):
        """Return the C++ declaration for this object ('' if none)."""
        return ""

    def define(self, reg):
        """Return the C++ definition for this object ('' if none)."""
        return ""


###############################################################################
class aliastype(type):
    """A type that is another name for ``target``."""

    def __init__(self, name: str, target: str):
        super().__init__(name, depends=[target])
        self.target = target

    def declare(self):
        return "using %(name)s = %(target)s;" % {
            "name": self.name,
            "target": self.target
        }


class aliasvalue(type):
    """A value (or command) that is another name for ``target``."""

    def __init__(self, name: str, target: str):
        super().__init__(name, depends=[target])
        self.target = target
        # Enum definitions print ``.value``; an alias prints its target name.
        self.value = target

    def declare(self):
        return "constexpr auto %(name)s = %(target)s;" % {
            "name": self.name,
            "target": self.target
        }


##-----------------------------------------------------------------------------
class placeholder(type):
    """A name that must exist for dependency ordering but emits nothing."""

    def __init__(self, name: str):
        super().__init__(name)


##-----------------------------------------------------------------------------
class unscoped(type):
    """A named group of values emitted at namespace scope (e.g. 'API Constants')."""

    def __init__(self, name: str):
        super().__init__(name)
        self.values = []

    def declare(self):
        return "\n".join(t.declare() for t in self.values)

    def define(self, reg):
        # Members take the registry, like every other define().
        return "\n".join(t.define(reg) for t in self.values)


###############################################################################
class include(type):
    """A ``category="include"`` type: an include directive."""

    def __init__(self, node):
        assert node.tag == 'type'
        assert node.attrib['category'] == 'include'
        super().__init__(node.attrib['name'])
        self.directive = node.text

    def declare(self):
        # Fall back to a system include of the type's name when no text is given.
        return self.directive or "#include <%s>" % self.name


class define(type):
    """A ``category="define"`` type: emitted verbatim from the node's text."""

    def __init__(self, node):
        assert node.tag == 'type'
        assert node.attrib['category'] == 'define'
        name = node.attrib.get('name') or node.find('name').text
        super().__init__(name)
        self.directive = "".join(node.itertext())

    def declare(self):
        return self.directive


class bitmask(type):
    """A ``category="bitmask"`` type, emitted as an alias of its base type."""

    def __init__(self, node):
        assert node.tag == 'type'
        assert node.attrib['category'] == 'bitmask'
        name = node.find('name').text
        base = node.find('type').text
        super().__init__(name, depends=[base])
        self.type = base
        self.requires = node.attrib.get('requires')
        if self.requires:
            self.depends.append(self.requires)

    def declare(self):
        return "using %(name)s = %(type)s;" % {
            "name": self.name,
            "type": self.type
        }

    def define(self, reg: registry):
        # The flags type is a plain alias of its base type; the type named by
        # ``requires`` (if any) is emitted separately as a dependency.
        return self.declare()


class handle(type):
    """A ``category="handle"`` type, emitted as a pointer to an opaque struct."""

    def __init__(self, node):
        assert node.tag == 'type'
        assert node.attrib['category'] == 'handle'
        name = node.find('name').text
        kind = node.find('type').text
        super().__init__(name, depends=[kind])
        self.type = kind

    def declare(self):
        return "struct %(name)s_t; using %(name)s = %(name)s_t*;" % {
            "name": self.name,
            "type": self.type
        }


class enum(type):
    """A ``category="enum"`` type; its members are filled in by name later."""

    def __init__(self, node):
        assert node.tag == 'type'
        assert node.attrib['category'] == 'enum'
        name = node.attrib['name']
        super().__init__(name, depends=["VkEnum"])
        self.values = {}

    def __setitem__(self, key: str, value):
        assert isinstance(value, (constant, aliasvalue))
        self.values[key] = value

    def declare(self):
        # No forward declaration is emitted; define() produces the full enum.
        return ""

    def define(self, reg: registry):
        values = (
            "%(name)s = %(value)s" % {"name": k, "value": v.value}
            for (k, v) in self.values.items()
        )
        return "enum %(name)s : int32_t { %(values)s };" % {
            "name": self.name,
            "values": ", ".join(values)
        }


class basetype(aliastype):
    """
    Represents fundamental types that are aliases of system-provided types
    and are used extensively by the base API.

    eg, VkBool32
    """

    def __init__(self, node):
        assert node.tag == 'type'
        assert node.attrib['category'] == 'basetype'
        super().__init__(
            node.find('name').text,
            node.find('type').text
        )


class funcpointer(type):
    """A ``category="funcpointer"`` type, emitted verbatim from the node's text."""

    def __init__(self, node):
        assert node.tag == 'type'
        assert node.attrib['category'] == 'funcpointer'
        name = node.find('name').text
        self.params = [x.text for x in node.findall('./type')]
        self.text = "".join(node.itertext())
        super().__init__(name, depends=['VkBool32'] + self.params)

    def declare(self):
        return self.text


class pod(type):
    """A struct or union, emitted member by member from the XML text."""

    def __init__(self, node):
        assert node.tag == 'type'
        assert node.attrib['category'] in ['struct', 'union']

        super().__init__(node.attrib['name'])
        self._node = node
        self._category = node.attrib['category']

        # Sometimes there are enums hiding in the member fields being used as
        # array sizes.
        self.depends += [e.text for e in node.findall('.//enum')]
        self.depends += [t.text for t in node.findall('.//type')]

        self._members = []
        for member in node.findall('./member'):
            member_type = member.find('type').text
            member_name = member.find('name').text

            # Drop every <comment> so its text does not leak into the C++ code.
            for comment in member.findall('comment'):
                member.remove(comment)

            code = " ".join(member.itertext())
            self._members.append(
                {'code': code, 'type': member_type, 'name': member_name}
            )

    def declare(self):
        return "%(category)s %(name)s;" % {
            'category': self._category,
            'name': rename(self.name)
        }

    def define(self, reg: registry):
        return "%(category)s %(name)s {\n%(members)s\n};" % {
            'category': self._category,
            'name': rename(self.name),
            'members': "\n".join(m['code'] + ';' for m in self._members)
        }


class struct(pod):
    def __init__(self, node):
        super().__init__(node)


class union(pod):
    def __init__(self, node):
        super().__init__(node)


class constant(type):
    """A single named value taken from an ``<enum>`` node.

    The value comes from the first matching attribute:
      * ``offset`` (with ``extends``; negated when ``dir`` is present): an
        extension-assigned value computed from the ``extnumber`` keyword;
      * ``value``: the literal attribute text;
      * ``bitpos``: the expression ``1 << bitpos``.

    Raises:
        ValueError: no value attribute, or an offset enum without extnumber.
    """

    def __init__(self, node, **kwargs):
        assert node.tag == 'enum'
        name = node.attrib['name']
        super().__init__(name)

        if 'offset' in node.attrib:
            assert 'extends' in node.attrib
            if kwargs.get('extnumber') is None:
                raise ValueError("No extension number for enum %s" % name)
            number = int(kwargs['extnumber'])
            offset = int(node.attrib['offset'])
            self.value = _EXT_ENUM_BASE + _EXT_ENUM_BLOCK * (number - 1) + offset
            if 'dir' in node.attrib:
                self.value *= -1
        elif 'value' in node.attrib:
            self.value = node.attrib['value']
        elif 'bitpos' in node.attrib:
            self.value = "1 << %s" % node.attrib['bitpos']
        else:
            raise ValueError("Unknown constant value type for enum %s" % name)

    def declare(self):
        return "constexpr auto %(name)s = %(value)s;" % {
            "name": self.name,
            "value": self.value
        }


class command(type):
    """A ``<command>``: emitted as a noexcept function declaration."""

    def __init__(self, node):
        assert node.tag == "command"

        proto = node.find('proto')
        name = proto.find('name').text
        super().__init__(name)

        self.result = proto.find('type').text
        self.depends += [self.result]

        self.params = []
        for p in node.findall('./param'):
            self.depends.append(p.find('type').text)
            self.params.append("".join(p.itertext()))

    def declare(self):
        return "%(result)s %(name)s (%(params)s) noexcept;" % {
            'name': rename(self.name),
            'result': self.result,
            'params': ", ".join(self.params)
        }


class require(object):
    """One ``<require>`` block of a feature or extension."""

    def __init__(self, root):
        self.values = []   # <enum> nodes, resolved by apply()
        self.depends = []  # names of required commands and types

        for node in root:
            if node.tag == 'enum':
                self.values.append(node)
            elif node.tag in ('command', 'type'):
                self.depends.append(node.attrib['name'])
            elif node.tag == 'comment':
                continue
            else:
                raise ValueError("Unknown requires node: %s" % node.tag)

    def apply(self, reg: registry, extnumber=None):
        """Register this block's enum values in ``reg`` and list required names.

        Args:
            reg: registry whose types receive the new values.
            extnumber: number of the enclosing extension (None for features);
                an enum's own ``extnumber`` attribute takes precedence.

        Returns:
            Names of the types/commands/enum owners that must be emitted.
        """
        required = list(self.depends)

        for value in self.values:
            name = value.attrib['name']

            if len(value.attrib) == 1:
                # A bare reference to a value defined elsewhere.
                assert 'name' in value.attrib
                required.append(name)
                continue

            if 'extends' not in value.attrib:
                # A new value outside any enum type; recorded on 'API Constants'.
                owner = reg.types['API Constants']
                if 'alias' in value.attrib:
                    owner.values.append(aliasvalue(name, value.attrib['alias']))
                else:
                    owner.values.append(constant(value))
                continue

            owner = reg.types[value.attrib['extends']]
            if 'alias' in value.attrib:
                owner[name] = aliasvalue(name, value.attrib['alias'])
            else:
                number = value.attrib.get('extnumber', extnumber)
                owner[name] = constant(value, extnumber=number)
            required.append(owner.name)

        return required


class feature(type):
    """A ``<feature>`` (API version) and the requirements it brings in."""

    def __init__(self, root):
        assert root.tag == 'feature'

        name = root.attrib['name']
        super().__init__(name)

        self.requires = []
        for node in root:
            if node.tag == 'require':
                self.requires.append(require(node))
            else:
                raise ValueError("Unhandled feature node: %s" % node.tag)

    def define(self, reg: registry):
        return "#define %s" % self.name

    def apply(self, reg: registry):
        """Apply every require block; return the names they need emitted."""
        logging.info("Applying feature: %s", self.name)
        result = []
        for r in self.requires:
            result += r.apply(reg)
        return result


class extension(type):
    """An ``<extension>`` and the requirements it brings in."""

    def __init__(self, root):
        assert root.tag == 'extension'

        name = root.attrib['name']
        super().__init__(name)

        if 'requires' in root.attrib:
            self.depends += root.attrib['requires'].split(',')

        self.number = int(root.attrib['number'])
        self.platform = root.attrib.get('platform')

        self.requires = []
        for node in root:
            if node.tag == 'require':
                self.requires.append(require(node))
            else:
                raise ValueError("Unknown extension node: %s" % node.tag)

    def apply(self, reg: registry, platform: str):
        """Apply this extension (and its dependencies) at most once.

        Returns the names that must be emitted; [] when already applied or
        when the extension belongs to a different platform.
        """
        if self.name in reg.applied:
            return []
        reg.applied.add(self.name)

        if self.platform and self.platform != platform:
            return []

        required = []
        for dep in self.depends:
            # Accumulate: a dependency is applied only once, so anything it
            # returns here would otherwise be lost for good.
            required += reg.extensions[dep].apply(reg, platform)

        logging.info("Applying extension: %s", self.name)
        for node in self.requires:
            required += node.apply(reg, extnumber=self.number)
        return required


###############################################################################
def ignore_node(reg: registry, root):
    """Parser for top-level registry nodes that carry nothing we emit."""
    pass


parse_comment = ignore_node
parse_vendorids = ignore_node
parse_platforms = ignore_node
parse_tags = ignore_node


def parse_types(reg: registry, root):
    """Create a registry entry for every ``<type>`` under a ``<types>`` node."""
    assert root.tag == 'types'

    # Whitelist the known types so we don't accidentally instantiate
    # something whacky; each name matches a class defined in this module.
    supported_categories = (
        'include', 'define', 'bitmask', 'basetype', 'handle', 'enum',
        'funcpointer', 'struct', 'union',
    )

    for t in root.findall('type'):
        name = t.attrib.get('name') or t.find('name').text
        assert name not in reg.types

        if 'alias' in t.attrib:
            name = t.attrib['name']
            reg.types[name] = aliastype(name, t.attrib['alias'])
            continue

        category = t.attrib.get('category')

        # If we don't have a category we should have a bare type that has a
        # dependency on something like a header.
        #
        # eg, 'Display' depends on 'X11/Xlib.h'
        if not category:
            reg.types[name] = placeholder(name)
        elif category in supported_categories:
            reg.types[name] = globals()[category](t)
        else:
            raise ValueError("unhandled type category: %s" % category)

        if 'requires' in t.attrib:
            reg.types[name].depends.append(t.attrib['requires'])


##-----------------------------------------------------------------------------
def parse_enums(reg: registry, root):
    """Add the values of an ``<enums>`` block to its owning type.

    'API Constants' values are stored directly in ``reg.types`` so that other
    objects (e.g. struct array sizes) can depend on them by name.
    """
    assert root.tag == 'enums'

    ownername = root.attrib['name']
    owner = reg.types[ownername] if ownername != 'API Constants' else reg.types

    for node in root.findall('./enum'):
        valuename = node.attrib.get('name')

        assert 'requires' not in node.attrib
        if 'alias' in node.attrib:
            owner[valuename] = aliasvalue(valuename, node.attrib['alias'])
        else:
            owner[valuename] = constant(node)


##-----------------------------------------------------------------------------
def parse_commands(reg: registry, root):
    """Create a registry entry for every ``<command>`` under ``<commands>``."""
    assert root.tag == 'commands'

    for node in root.findall('./command'):
        name = node.attrib.get('name') or node.find('./proto/name').text
        assert name not in reg.types

        if 'alias' in node.attrib:
            reg.types[name] = aliasvalue(name, node.attrib['alias'])
            continue

        reg.types[name] = command(node)


##-----------------------------------------------------------------------------
def parse_feature(reg: registry, root):
    """Register a ``<feature>`` both as a feature and as an emittable type."""
    assert root.tag == 'feature'

    name = root.attrib['name']
    assert name not in reg.features

    reg.features[name] = feature(root)
    reg.types[name] = reg.features[name]


##-----------------------------------------------------------------------------
def parse_extensions(reg: registry, root):
    """Register every ``<extension>`` under an ``<extensions>`` node."""
    assert root.tag == 'extensions'

    for node in root.findall('./extension'):
        name = node.attrib['name']
        assert name not in reg.extensions
        reg.extensions[name] = extension(node)


###############################################################################
def enqueue_type(name: str, queued: Set[str], types: Dict[str, type]):
    """Return ``name`` and its not-yet-queued dependencies, dependencies first.

    ``queued`` is updated in place so each object is returned only once
    across calls.
    """
    if name in queued:
        return []

    result = []
    obj = types[name]

    for d in obj.depends:
        if d == name:
            continue
        result += enqueue_type(name=d, queued=queued, types=types)

    assert name not in queued
    queued.add(name)
    result += [obj]

    return result


##-----------------------------------------------------------------------------
def main(argv=None):
    """Parse the registry at ``src`` and write the generated header to ``dst``."""
    logging.getLogger().setLevel(logging.WARNING)

    parser = argparse.ArgumentParser(description='Transform XML API specification into C++ headers')
    parser.add_argument('src', type=str, help='the path to the XML file to transform')
    parser.add_argument('dst', type=str, help='the output path for the result')
    parser.add_argument('--platform', type=str, default='xcb',
                        help='platform whose extensions are included (default: xcb)')
    args = parser.parse_args(argv)

    # Parse before opening dst so a bad input does not truncate the output.
    root = ET.parse(args.src).getroot()

    reg = registry()
    for node in root:
        handler = globals().get("parse_%s" % node.tag)
        if handler is None:
            logging.warning("Ignoring unknown registry node: %s", node.tag)
            continue
        handler(reg, node)

    # Names referenced by the registry that it does not define itself.
    reg.types['void*'] = placeholder('void*')
    reg.types['nullptr'] = placeholder('nullptr')
    reg.types['VkEnum'] = aliastype('VkEnum', 'int32_t')
    reg.types['VK_DEFINE_NON_DISPATCHABLE_HANDLE'] = aliastype("VK_DEFINE_NON_DISPATCHABLE_HANDLE", "uint64_t")
    reg.types['VK_DEFINE_HANDLE'] = aliastype("VK_DEFINE_HANDLE", "void*")
    reg.types['VK_NULL_HANDLE'] = aliasvalue("VK_NULL_HANDLE", "nullptr")

    objects = reg.serialise(platform=args.platform)

    with open(args.dst, 'w') as dst:
        dst.write("#pragma once\n")
        dst.write('extern "C" {\n')
        for obj in objects:
            dst.write(obj.declare())
            dst.write('\n')
            dst.write(obj.define(reg))
            dst.write('\n')
        dst.write('}\n')


if __name__ == '__main__':
    main()