diff -r adda406d3960 -r 13ee5f4ab612 xmlclass/xmlclass.py --- a/xmlclass/xmlclass.py Wed Aug 07 11:47:14 2013 +0200 +++ b/xmlclass/xmlclass.py Mon Aug 26 10:55:03 2013 +0200 @@ -28,7 +28,9 @@ from types import * from xml.dom import minidom from xml.sax.saxutils import escape, unescape, quoteattr +from lxml import etree from new import classobj +from collections import OrderedDict def CreateNode(name): node = minidom.Node() @@ -591,15 +593,21 @@ def GetElementInitialValue(factory, infos): infos["elmt_type"] = FindTypeInfos(factory, infos["elmt_type"]) - if infos["minOccurs"] == 0 and infos["maxOccurs"] == 1: - if infos.has_key("default"): - return infos["elmt_type"]["extract"](infos["default"], False) - else: - return None - elif infos["minOccurs"] == 1 and infos["maxOccurs"] == 1: - return infos["elmt_type"]["initial"]() + if infos["minOccurs"] == 1: + element_name = factory.etreeNamespaceFormat % infos["name"] + if infos["type"] == SIMPLETYPE: + def initial_value(): + value = etree.Element(element_name) + value.text = (infos["elmt_type"]["generate"](infos["elmt_type"]["initial"]())) + return value + else: + def initial_value(): + value = infos["elmt_type"]["initial"]() + DefaultElementClass.__setattr__(value, "tag", element_name) + return value + return [initial_value() for i in xrange(infos["minOccurs"])] else: - return [infos["elmt_type"]["initial"]() for i in xrange(infos["minOccurs"])] + return [] def HandleError(message, raise_exception): if raise_exception: @@ -691,6 +699,10 @@ if choices_dict.has_key(choice_name): raise ValueError("'%s' element defined two times in choice" % choice_name) choices_dict[choice_name] = infos + choices_xpath = "|".join(map(lambda x: "%s:%s" % (factory.TargetNamespace, x), choices_dict.keys())) + + def GetContentChoicesXPath(): + return choices_xpath def GetContentInitial(): content_name, infos = choices[0] @@ -828,6 +840,7 @@ return { "type": COMPLEXTYPE, + "choices_xpath": GetContentChoicesXPath, "initial": GetContentInitial, "check": CheckContent, "extract": ExtractContent, @@ -901,9 +914,11 @@ self.XMLClassDefinitions = {} self.DefinedNamespaces = {} + self.NSMAP = {} self.Namespaces = {} self.SchemaNamespace = None self.TargetNamespace = None + self.etreeNamespaceFormat = "%s" self.CurrentCompilations = [] @@ -914,6 +929,8 @@ else: self.ComputedClasses = {} self.ComputedClassesInfos = {} + self.ComputedClassesLookUp = {} + self.EquivalentClassesParent = {} self.AlreadyComputed = {} def GetQualifiedNameInfos(self, name, namespace=None, canbenone=False): @@ -1016,7 +1033,9 @@ attrs[name] = infos["extract"]["default"](attr) elif namespace == "xmlns": infos = self.GetQualifiedNameInfos("anyURI", self.SchemaNamespace) - self.DefinedNamespaces[infos["extract"](attr)] = name + value = infos["extract"](attr) + self.DefinedNamespaces[value] = name + self.NSMAP[name] = value else: raise ValueError("Invalid attribute \"%s\" for member \"%s\"!" % (qualified_name, node.nodeName)) for attr in valid_attrs: @@ -1063,20 +1082,44 @@ def ParseSchema(self): pass - + + def AddEquivalentClass(self, name, base): + equivalences = self.EquivalentClassesParent.setdefault(self.etreeNamespaceFormat % name, {}) + equivalences[self.etreeNamespaceFormat % base] = True + + def AddToLookupClass(self, name, parent, typeinfos): + lookup_name = self.etreeNamespaceFormat % name + if isinstance(typeinfos, (StringType, UnicodeType)): + self.AddEquivalentClass(name, typeinfos) + typeinfos = self.etreeNamespaceFormat % typeinfos + lookup_classes = self.ComputedClassesLookUp.get(lookup_name) + if lookup_classes is None: + self.ComputedClassesLookUp[lookup_name] = (typeinfos, parent) + elif isinstance(lookup_classes, DictType): + lookup_classes[self.etreeNamespaceFormat % parent + if parent is not None else None] = typeinfos + else: + lookup_classes = {self.etreeNamespaceFormat % lookup_classes[1] + if lookup_classes[1] is not None else None: lookup_classes[0]} + lookup_classes[self.etreeNamespaceFormat % parent + if parent is not None else None] = typeinfos + self.ComputedClassesLookUp[lookup_name] = lookup_classes + def ExtractTypeInfos(self, name, parent, typeinfos): if isinstance(typeinfos, (StringType, UnicodeType)): - namespace, name = DecomposeQualifiedName(typeinfos) - infos = self.GetQualifiedNameInfos(name, namespace) + namespace, type_name = DecomposeQualifiedName(typeinfos) + if namespace == self.TargetNamespace and name != "base": + self.AddToLookupClass(name, parent, type_name) + infos = self.GetQualifiedNameInfos(type_name, namespace) if infos["type"] == COMPLEXTYPE: - name, parent = self.SplitQualifiedName(name, namespace) - result = self.CreateClass(name, parent, infos) + type_name, parent = self.SplitQualifiedName(type_name, namespace) + result = self.CreateClass(type_name, parent, infos) if result is not None and not isinstance(result, (UnicodeType, StringType)): self.Namespaces[self.TargetNamespace][result["name"]] = result return result elif infos["type"] == ELEMENT and infos["elmt_type"]["type"] == COMPLEXTYPE: - name, parent = self.SplitQualifiedName(name, namespace) - result = self.CreateClass(name, parent, infos["elmt_type"]) + type_name, parent = self.SplitQualifiedName(type_name, namespace) + result = self.CreateClass(type_name, parent, infos["elmt_type"]) if result is not None and not isinstance(result, (UnicodeType, StringType)): self.Namespaces[self.TargetNamespace][result["name"]] = result return result @@ -1141,9 +1184,12 @@ bases = [] base_infos = classinfos.get("base", None) if base_infos is not None: + namespace, base_name = DecomposeQualifiedName(base_infos) + if namespace == self.TargetNamespace: + self.AddEquivalentClass(name, base_name) result = self.ExtractTypeInfos("base", name, base_infos) if result is None: - namespace, base_name = DecomposeQualifiedName(base_infos) + namespace, base_name = DecomposeQualifiedName(base_infos) if self.AlreadyComputed.get(base_name, False): self.ComputeAfter.append((name, parent, classinfos)) if self.TargetNamespace is not None: @@ -1164,7 +1210,7 @@ if classinfos["base"] is None: raise ValueError("No class found for base type") bases.append(classinfos["base"]) - bases.append(object) + bases.append(DefaultElementClass) bases = tuple(bases) classmembers = {"__doc__": classinfos.get("doc", ""), "IsBaseClass": baseclass} @@ -1177,11 +1223,8 @@ raise ValueError("\"%s\" type is not a simple type!" % attribute["attr_type"]) attrname = attribute["name"] if attribute["use"] == "optional": - classmembers[attrname] = None classmembers["add%s"%attrname] = generateAddMethod(attrname, self, attribute) classmembers["delete%s"%attrname] = generateDeleteMethod(attrname) - else: - classmembers[attrname] = infos["initial"]() classmembers["set%s"%attrname] = generateSetMethod(attrname) classmembers["get%s"%attrname] = generateGetMethod(attrname) else: @@ -1212,24 +1255,18 @@ if infos is not None: element["elmt_type"] = infos if element["maxOccurs"] == "unbounded" or element["maxOccurs"] > 1: - classmembers[elmtname] = [] classmembers["append%s" % elmtname] = generateAppendMethod(elmtname, element["maxOccurs"], self, element) classmembers["insert%s" % elmtname] = generateInsertMethod(elmtname, element["maxOccurs"], self, element) classmembers["remove%s" % elmtname] = generateRemoveMethod(elmtname, element["minOccurs"]) classmembers["count%s" % elmtname] = generateCountMethod(elmtname) else: if element["minOccurs"] == 0: - classmembers[elmtname] = None classmembers["add%s" % elmtname] = generateAddMethod(elmtname, self, element) classmembers["delete%s" % elmtname] = generateDeleteMethod(elmtname) - elif not isinstance(element["elmt_type"], (UnicodeType, StringType)): - classmembers[elmtname] = element["elmt_type"]["initial"]() - else: - classmembers[elmtname] = None classmembers["set%s" % elmtname] = generateSetMethod(elmtname) classmembers["get%s" % elmtname] = generateGetMethod(elmtname) - classmembers["__init__"] = generateInitMethod(self, classinfos) + classmembers["_init"] = generateInitMethod(self, classinfos) classmembers["getStructure"] = generateStructureMethod(classinfos) classmembers["loadXMLTree"] = generateLoadXMLTree(self, classinfos) classmembers["generateXMLText"] = generateGenerateXMLText(self, classinfos) @@ -1241,6 +1278,7 @@ classmembers["extraAttrs"] = {} class_definition = classobj(str(classname), bases, classmembers) + setattr(class_definition, "__getattr__", generateGetattrMethod(self, class_definition, classinfos)) setattr(class_definition, "__setattr__", generateSetattrMethod(self, class_definition, classinfos)) class_infos = {"type": COMPILEDCOMPLEXTYPE, "name": classname, @@ -1255,6 +1293,9 @@ self.ComputedClasses[classname] = class_definition self.ComputedClassesInfos[classname] = class_infos + self.AddToLookupClass(name, parent, class_definition) + self.AddToLookupClass(classname, None, class_definition) + return class_infos """ @@ -1306,6 +1347,46 @@ return instance return classExtractfunction +def generateGetattrMethod(factory, class_definition, classinfos): + attributes = dict([(attr["name"], attr) for attr in classinfos["attributes"] if attr["use"] != "prohibited"]) + optional_attributes = dict([(attr["name"], True) for attr in classinfos["attributes"] if attr["use"] == "optional"]) + elements = dict([(element["name"], element) for element in classinfos["elements"]]) + + def getattrMethod(self, name): + if attributes.has_key(name): + attribute_infos = attributes[name] + attribute_infos["attr_type"] = FindTypeInfos(factory, attribute_infos["attr_type"]) + value = self.get(name) + if value is not None: + return attribute_infos["attr_type"]["extract"](value, extract=False) + elif attribute_infos.has_key("fixed"): + return attribute_infos["attr_type"]["extract"](attribute_infos["fixed"], extract=False) + return attribute_infos["attr_type"]["initial"]() + + elif elements.has_key(name): + element_infos = elements[name] + element_infos["elmt_type"] = FindTypeInfos(factory, element_infos["elmt_type"]) + if name == "content": + content = self.xpath(element_infos["elmt_type"]["choices_xpath"](), namespaces=factory.NSMAP) + if element_infos["maxOccurs"] == "unbounded" or element_infos["maxOccurs"] > 1: + return content + elif len(content) > 0: + return content[0] + return None + else: + element_name = factory.etreeNamespaceFormat % name + if element_infos["maxOccurs"] == "unbounded" or element_infos["maxOccurs"] > 1: + return self.findall(element_name) + else: + return self.find(element_name) + + elif classinfos.has_key("base"): + return classinfos["base"].__getattr__(self, name) + + return DefaultElementClass.__getattribute__(self, name) + + return getattrMethod + """ Method that generate the method for loading an xml tree by following the attributes list defined @@ -1313,31 +1394,52 @@ def generateSetattrMethod(factory, class_definition, classinfos): attributes = dict([(attr["name"], attr) for attr in classinfos["attributes"] if attr["use"] != "prohibited"]) optional_attributes = dict([(attr["name"], True) for attr in classinfos["attributes"] if attr["use"] == "optional"]) - elements = dict([(element["name"], element) for element in classinfos["elements"]]) + elements = OrderedDict([(element["name"], element) for element in classinfos["elements"]]) def setattrMethod(self, name, value): if attributes.has_key(name): - attributes[name]["attr_type"] = FindTypeInfos(factory, attributes[name]["attr_type"]) - if value is None: - if optional_attributes.get(name, False): - return object.__setattr__(self, name, None) - else: - raise ValueError("Attribute '%s' isn't optional." % name) - elif attributes[name].has_key("fixed") and value != attributes[name]["fixed"]: - raise ValueError, "Value of attribute '%s' can only be '%s'."%(name, str(attributes[name]["fixed"])) - elif attributes[name]["attr_type"]["check"](value): - return object.__setattr__(self, name, value) - else: - raise ValueError("Invalid value for attribute '%s'." % (name)) + attribute_infos = attributes[name] + attribute_infos["attr_type"] = FindTypeInfos(factory, attribute_infos["attr_type"]) + if optional_attributes.get(name, False): + default = attribute_infos.get("default", None) + if value is None or value == default: + self.attrib.pop(name) + return + elif attribute_infos.has_key("fixed"): + return + return self.set(name, attribute_infos["attr_type"]["generate"](value)) + elif elements.has_key(name): - if CheckElementValue(factory, name, elements[name], value): - return object.__setattr__(self, name, value) - else: - raise ValueError("Invalid value for attribute '%s'." % (name)) + element_infos = elements[name] + element_infos["elmt_type"] = FindTypeInfos(factory, element_infos["elmt_type"]) + element_xpath = ("%s:%s" % (factory.TargetNamespace, name) + if name != "content" + else elements["content"]["elmt_type"]["choices_xpath"]()) + + for element in self.xpath(element_xpath, namespaces=factory.NSMAP): + self.remove(element) + + if value is not None: + previous_elements_xpath = "|".join(map( + lambda x: "%s:%s" % (factory.TargetNamespace, x) + if x != "content" + else elements["content"]["elmt_type"]["choices_xpath"](), + elements.keys()[elements.keys().index(name)])) + + insertion_point = len(self.xpath(previous_elements_xpath, namespaces=factory.NSMAP)) + + if not isinstance(value, ListType): + value = [value] + + for element in reversed(value): + self.insert(insertion_point, element) + elif classinfos.has_key("base"): return classinfos["base"].__setattr__(self, name, value) + elif class_definition.__dict__.has_key(name): - return object.__setattr__(self, name, value) + return DefaultElementClass.__setattr__(self, name, value) + else: raise AttributeError("'%s' can't have an attribute '%s'." % (self.__class__.__name__, name)) @@ -1725,18 +1827,21 @@ def initMethod(self): self.extraAttrs = {} if classinfos.has_key("base"): - classinfos["base"].__init__(self) + classinfos["base"]._init(self) for attribute in classinfos["attributes"]: attribute["attr_type"] = FindTypeInfos(factory, attribute["attr_type"]) - if attribute["use"] == "required": - setattr(self, attribute["name"], attribute["attr_type"]["initial"]()) - elif attribute["use"] == "optional": - if attribute.has_key("default"): - setattr(self, attribute["name"], attribute["attr_type"]["extract"](attribute["default"], False)) - else: - setattr(self, attribute["name"], None) + if attribute["use"] == "required" and self.get(attribute["name"]) is None: + self.set(attribute["name"], attribute["attr_type"]["generate"](attribute["attr_type"]["initial"]())) for element in classinfos["elements"]: - setattr(self, element["name"], GetElementInitialValue(factory, element)) + if element["name"] != "content": + element_name = ( + etree.QName(factory.NSMAP["xhtml"], "p") + if element["type"] == ANY + else factory.etreeNamespaceFormat % element["name"]) + if self.find(element_name) is None: + initial = GetElementInitialValue(factory, element) + if initial is not None: + map(self.append, initial) return initMethod def generateSetMethod(attr): @@ -1761,9 +1866,7 @@ extract = infos["elmt_type"]["extract"] else: raise ValueError("Invalid class attribute!") - if infos.has_key("default"): - setattr(self, attr, extract(infos["default"], False)) - else: + if not infos.has_key("default"): setattr(self, attr, initial()) return addMethod @@ -1777,10 +1880,10 @@ infos["elmt_type"] = FindTypeInfos(factory, infos["elmt_type"]) attr_list = getattr(self, attr) if maxOccurs == "unbounded" or len(attr_list) < maxOccurs: - if infos["elmt_type"]["check"](value): - attr_list.append(value) + if len(attr_list) == 0: + setattr(self, attr, [value]) else: - raise ValueError("\"%s\" value isn't valid!" % attr) + attr_list[-1].addnext(value) else: raise ValueError("There can't be more than %d values in \"%s\"!" % (maxOccurs, attr)) return appendMethod @@ -1790,10 +1893,12 @@ infos["elmt_type"] = FindTypeInfos(factory, infos["elmt_type"]) attr_list = getattr(self, attr) if maxOccurs == "unbounded" or len(attr_list) < maxOccurs: - if infos["elmt_type"]["check"](value): - attr_list.insert(index, value) + if len(attr_list) == 0: + setattr(self, attr, [value]) + elif index == 0: + attr_list[0].addprevious(value) else: - raise ValueError("\"%s\" value isn't valid!" % attr) + attr_list[min(index - 1, len(attr_list) - 1)].addnext(value) else: raise ValueError("There can't be more than %d values in \"%s\"!" % (maxOccurs, attr)) return insertMethod @@ -1805,24 +1910,24 @@ def generateSetChoiceByTypeMethod(factory, choice_types): choices = dict([(choice["name"], choice) for choice in choice_types]) - def setChoiceMethod(self, type): - if not choices.has_key(type): - raise ValueError("Unknown \"%s\" choice type for \"content\"!" % type) - choices[type]["elmt_type"] = FindTypeInfos(factory, choices[type]["elmt_type"]) - new_element = choices[type]["elmt_type"]["initial"]() - self.content = {"name": type, "value": new_element} - return new_element + def setChoiceMethod(self, content_type): + if not choices.has_key(content_type): + raise ValueError("Unknown \"%s\" choice type for \"content\"!" % content_type) + choices[content_type]["elmt_type"] = FindTypeInfos(factory, choices[content_type]["elmt_type"]) + new_content = choices[content_type]["elmt_type"]["initial"]() + self.content = new_content + return new_content return setChoiceMethod def generateAppendChoiceByTypeMethod(maxOccurs, factory, choice_types): choices = dict([(choice["name"], choice) for choice in choice_types]) - def appendChoiceMethod(self, type): - if not choices.has_key(type): - raise ValueError("Unknown \"%s\" choice type for \"content\"!" % type) - choices[type]["elmt_type"] = FindTypeInfos(factory, choices[type]["elmt_type"]) + def appendChoiceMethod(self, content_type): + if not choices.has_key(content_type): + raise ValueError("Unknown \"%s\" choice type for \"content\"!" % content_type) + choices[content_type]["elmt_type"] = FindTypeInfos(factory, choices[content_type]["elmt_type"]) if maxOccurs == "unbounded" or len(self.content) < maxOccurs: - new_element = choices[type]["elmt_type"]["initial"]() - self.content.append({"name": type, "value": new_element}) + new_element = choices[content_type]["elmt_type"]["initial"]() + self.appendcontent(new_element) return new_element else: raise ValueError("There can't be more than %d values in \"content\"!" % maxOccurs) @@ -1830,13 +1935,13 @@ def generateInsertChoiceByTypeMethod(maxOccurs, factory, choice_types): choices = dict([(choice["name"], choice) for choice in choice_types]) - def insertChoiceMethod(self, index, type): - if not choices.has_key(type): - raise ValueError("Unknown \"%s\" choice type for \"content\"!" % type) - choices[type]["elmt_type"] = FindTypeInfos(factory, choices[type]["elmt_type"]) + def insertChoiceMethod(self, index, content_type): + if not choices.has_key(content_type): + raise ValueError("Unknown \"%s\" choice type for \"content\"!" % content_type) + choices[type]["elmt_type"] = FindTypeInfos(factory, choices[content_type]["elmt_type"]) if maxOccurs == "unbounded" or len(self.content) < maxOccurs: - new_element = choices[type]["elmt_type"]["initial"]() - self.content.insert(index, {"name" : type, "value" : new_element}) + new_element = choices[content_type]["elmt_type"]["initial"]() + self.insertcontent(index, new_element) return new_element else: raise ValueError("There can't be more than %d values in \"content\"!" % maxOccurs) @@ -1846,7 +1951,7 @@ def removeMethod(self, index): attr_list = getattr(self, attr) if len(attr_list) > minOccurs: - getattr(self, attr).pop(index) + self.remove(attr_list[index]) else: raise ValueError("There can't be less than %d values in \"%s\"!" % (minOccurs, attr)) return removeMethod @@ -1857,16 +1962,102 @@ return countMethod """ -This function generate the classes from a class factory -""" -def GenerateClasses(factory): +This function generate a xml parser from a class factory +""" + +class DefaultElementClass(etree.ElementBase): + toto = True + + def getLocalTag(self): + return etree.QName(self.tag).localname + + def tostring(self): + return etree.tostring(self, pretty_print=True) + +class XMLElementClassLookUp(etree.PythonElementClassLookup): + + def __init__(self, classes, class_equivalence, *args, **kwargs): + etree.PythonElementClassLookup.__init__(self, *args, **kwargs) + self.LookUpClasses = classes + self.ClassEquivalence = class_equivalence + + def GetElementClass(self, element_tag, parent_tag=None, default=DefaultElementClass): + element_class = self.LookUpClasses.get(element_tag, (default, None)) + if not isinstance(element_class, DictType): + if isinstance(element_class[0], (StringType, UnicodeType)): + return self.GetElementClass(element_class[0], default=default) + return element_class[0] + + element_with_parent_class = element_class.get(parent_tag, default) + if isinstance(element_with_parent_class, (StringType, UnicodeType)): + return self.GetElementClass(element_with_parent_class, default=default) + elif element_with_parent_class == DefaultElementClass: + for equivalent_parent in self.ClassEquivalence.get(parent_tag, {}).keys(): + return self.GetElementClass(element_tag, equivalent_parent, default) + return element_with_parent_class + + def lookup(self, document, element): + parent = element.getparent() + return self.GetElementClass(element.tag, + parent.tag if parent is not None else None) + +class XMLClassParser(etree.XMLParser): + + def __init__(self, namespaces, default_namespace_format, base_class, *args, **kwargs): + etree.XMLParser.__init__(self, *args, **kwargs) + self.DefaultNamespaceFormat = default_namespace_format + self.NSMAP = namespaces + targetNamespace = etree.QName(default_namespace_format % "d").namespace + if targetNamespace is not None: + self.RootNSMAP = { + name if targetNamespace != uri else None: uri + for name, uri in namespaces.iteritems()} + else: + self.RootNSMAP = namespaces + self.BaseClass = base_class + + def set_element_class_lookup(self, class_lookup): + etree.XMLParser.set_element_class_lookup(self, class_lookup) + self.ClassLookup = class_lookup + + def CreateRoot(self): + if self.BaseClass is not None: + return self.makeelement( + self.DefaultNamespaceFormat % self.BaseClass[0], + nsmap=self.RootNSMAP) + return None + + def GetElementClass(self, element_tag, parent_tag=None): + return self.ClassLookup.GetElementClass( + self.DefaultNamespaceFormat % element_tag, + self.DefaultNamespaceFormat % parent_tag + if parent_tag is not None else parent_tag, + None) + + def CreateElement(self, element_tag, parent_tag=None): + new_element = self.GetElementClass(element_tag, parent_tag)() + DefaultElementClass.__setattr__(new_element, "tag", self.DefaultNamespaceFormat % element_tag) + return new_element + +def GenerateParser(factory, xsdstring): ComputedClasses = factory.CreateClasses() if factory.FileName is not None and len(ComputedClasses) == 1: - UpdateXMLClassGlobals(ComputedClasses[factory.FileName]) - return ComputedClasses[factory.FileName] + ComputedClasses = ComputedClasses[factory.FileName] + BaseClass = [(name, XSDclass) for name, XSDclass in ComputedClasses.items() if XSDclass.IsBaseClass] else: - UpdateXMLClassGlobals(ComputedClasses) - return ComputedClasses + BaseClass = [] + UpdateXMLClassGlobals(ComputedClasses) + + parser = XMLClassParser( + factory.NSMAP, + factory.etreeNamespaceFormat, + BaseClass[0] if len(BaseClass) == 1 else None, + schema = etree.XMLSchema(etree.fromstring(xsdstring)), + strip_cdata = False, remove_blank_text=True) + class_lookup = XMLElementClassLookUp(factory.ComputedClassesLookUp, factory.EquivalentClassesParent) + parser.set_element_class_lookup(class_lookup) + return parser def UpdateXMLClassGlobals(classes): globals().update(classes) +