import lxml.etree as ET
import sys
import os
import hashlib

# -----------------------------------------------------------------------------
# Parse XML file
# -----------------------------------------------------------------------------

def parse_xml(file):
    """Parse *file* into an lxml ElementTree, exiting the process on failure.

    Whitespace-only text is discarded while parsing (``remove_blank_text``)
    so that ``pretty_print`` can re-indent the tree when it is saved later.
    Any exception is printed and the script exits with status 1.
    """
    try:
        return ET.parse(file, ET.XMLParser(remove_blank_text=True))
    except Exception as err:
        print(str(err))
        sys.exit(1)


# -----------------------------------------------------------------------------
# Calculate MD5 hash of a file
# -----------------------------------------------------------------------------

def file_md5(file):
    """Return the hex MD5 digest of the contents of *file*.

    In this script the digest is only compared to decide whether a rewritten
    file differs from the original; it is not used for security.
    On any error the exception text is printed and the script exits with
    status 1.
    """
    try:
        # 'with' guarantees the handle is closed even if read() fails.
        with open(file, 'rb') as fp:
            return hashlib.md5(fp.read()).hexdigest()
    except Exception as ex:
        print(str(ex))
        sys.exit(1)


# -----------------------------------------------------------------------------
# Save XML file
# -----------------------------------------------------------------------------

def save_xml(tree, file):
    """Write *tree* to *file*, keeping a ``.bak`` copy when the content changed.

    The tree is first written to ``<file>.new``. If its MD5 matches the
    existing file, the temporary file is deleted and *file* is left alone.
    Otherwise *file* is renamed to ``<file>.bak`` and the new file takes its
    place. Any error is reported and the script exits with status 1.
    """
    try:
        tmp_path = file + ".new"
        backup_path = file + ".bak"

        # lxml pretty printing always indents by 2 and only re-indents when
        # blank text was stripped at parse time.
        tree.write(tmp_path, pretty_print=True, xml_declaration=True, encoding="UTF-8")

        unchanged = file_md5(file) == file_md5(tmp_path)
        if unchanged:
            # Nothing changed: throw the temporary copy away.
            os.remove(tmp_path)
        else:
            # Keep the previous version as .bak, then move the new one in.
            os.rename(file, backup_path)
            os.rename(tmp_path, file)

    except Exception as ex:
        print("ERROR: Could not save file '" + file + "': " + str(ex))
        sys.exit(1)


# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------

def find_element(parent, tag_name):
    """Return the first direct child of *parent* whose local name is *tag_name*.

    Matching uses XPath ``local-name()``, so the namespace is ignored.
    Returns None when there is no such child.
    """
    matches = parent.xpath("*[local-name() = $name]", name=tag_name)
    return matches[0] if matches else None


# -----------------------------------------------------------------------------
# Upgrade web section
# -----------------------------------------------------------------------------

def upgrade_web(T):
    """Apply the web upgrades to *T*: SSL settings first, then virtual-server rewrites."""
    print("  Upgrading WEB")
    for step in (upgrade_web_ssl, upgrade_web_virtsrv):
        step(T)


# -----------------------------------------------------------------------------

def upgrade_web_ssl(T):
    """Force TLSv1.2 and one ECDHE/AES-256-GCM cipher on every ``ssl`` element.

    Elements are matched by local name anywhere in the document. Prints a
    warning and changes nothing if none are found.
    """
    ssl_elements = T.xpath("//*[local-name() = $name]", name="ssl")
    if not ssl_elements:
        print("WARNING: Could not find 'connector/ssl' element.")
        return

    settings = (
        ("protocol", "TLSv1.2"),
        ("cipher-suite", "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384"),
    )
    for el_ssl in ssl_elements:
        for key, value in settings:
            el_ssl.attrib[key] = value

# -----------------------------------------------------------------------------

def upgrade_web_virtsrv(T):
    """Add a request-method blocking rewrite rule to every virtual server.

    For each ``virtual-server`` element (matched by local name) that has no
    direct ``rewrite`` child, append::

        <rewrite pattern=".*" substitution="-" flags="F">
            <condition test="%{REQUEST_METHOD}"
                       pattern="^(TRACE|PATCH|CONNECT|PROPFIND|MKCOL)$"
                       flags="NC"/>
        </rewrite>

    NOTE(review): presumably "F" = forbidden and "NC" = case-insensitive,
    following mod_rewrite flag conventions; confirm against the JBoss web
    subsystem docs. Virtual servers that already have any ``rewrite`` child
    are left untouched. Prints a warning if no virtual server exists.
    """
    L = T.xpath("//*[local-name() = $name]", name="virtual-server")
    if len(L) == 0:
        print("WARNING: Could not find 'virtual-server' element.")
        return

    for el_vs in L:
        if find_element(el_vs, "rewrite") is not None:
            continue

        el_rewrite = ET.Element("rewrite")
        el_rewrite.attrib["pattern"] = ".*"
        el_rewrite.attrib["substitution"] = "-"
        el_rewrite.attrib["flags"] = "F"
        el_vs.append(el_rewrite)

        el_condition = ET.Element("condition")
        el_condition.attrib["test"] = "%{REQUEST_METHOD}"
        el_condition.attrib["pattern"] = "^(TRACE|PATCH|CONNECT|PROPFIND|MKCOL)$"
        el_condition.attrib["flags"] = "NC"
        el_rewrite.append(el_condition)


# -----------------------------------------------------------------------------
# Upgrade datasources section
# -----------------------------------------------------------------------------

# Remove unused datasources
def ds_remove_unused(el_ds):
    """Detach *el_ds* from its parent if its ``jndi-name`` is a known unused datasource.

    The unused names are ``java:jboss/datasources/datasource`` and
    ``java:/RcsDB``. A datasource without a ``jndi-name`` attribute is kept
    and does not raise KeyError.
    """
    unused_jndi_names = ("java:jboss/datasources/datasource", "java:/RcsDB")
    if el_ds.attrib.get("jndi-name") in unused_jndi_names:
        el_ds.getparent().remove(el_ds)

# -----------------------------------------------------------------------------

# Replace EDB driver with PostgreSQL
def ds_upgrade_con(el_ds):
    """Point a datasource at the PostgreSQL server and driver.

    Sets the text of a direct ``connection-url`` child to
    ``jdbc:postgresql://dbserver:5432/ttv`` and of a direct ``driver`` child
    to ``postgres.Driver``. Children that are missing are not created.
    """
    el_con = find_element(el_ds, "connection-url")
    if el_con is not None:
        el_con.text = "jdbc:postgresql://dbserver:5432/ttv"

    el_drv = find_element(el_ds, "driver")
    if el_drv is not None:
        # Must match the driver name registered by ds_upgrade_drivers.
        el_drv.text = "postgres.Driver"

# -----------------------------------------------------------------------------

# Set pool size
def ds_set_pool_size(el_ds):
    """Set a datasource's ``pool/min-pool-size`` to 5 and ``pool/max-pool-size`` to 250.

    Does nothing if there is no direct ``pool`` child. Missing size
    elements are not created.
    """
    el_pool = find_element(el_ds, "pool")
    if el_pool is None:
        return

    el_min = find_element(el_pool, "min-pool-size")
    if el_min is not None:
        el_min.text = "5"

    el_max = find_element(el_pool, "max-pool-size")
    if el_max is not None:
        el_max.text = "250"

# -----------------------------------------------------------------------------

# Set idle timeout
def ds_set_timeout(el_ds):
    """Set a datasource's ``timeout/idle-timeout-minutes`` to 10.

    Creates the ``timeout`` and ``idle-timeout-minutes`` elements when they
    are missing.
    """
    el_timeout = find_element(el_ds, "timeout")
    if el_timeout is None:
        # NOTE(review): appended as the last child; confirm this position is
        # valid for the datasources schema version in use.
        el_timeout = ET.Element("timeout")
        el_ds.append(el_timeout)

    el_timeout_idle = find_element(el_timeout, "idle-timeout-minutes")
    if el_timeout_idle is None:
        el_timeout_idle = ET.Element("idle-timeout-minutes")
        el_timeout.append(el_timeout_idle)

    el_timeout_idle.text = "10"


# -----------------------------------------------------------------------------

def ds_upgrade_drivers(el_dss):
    """Replace EDB/PostgreSQL driver declarations with a single PostgreSQL XA driver.

    Ensures *el_dss* (the ``datasources`` element) has a ``drivers`` child.
    Removes any direct ``driver`` children named ``com.edb.Driver``,
    ``postgress.Driver`` or ``postgres.Driver``, then appends a fresh
    ``postgres.Driver`` entry (module ``org.postgresql``, XA class
    ``org.postgresql.xa.PGXADataSource``). Because any existing
    ``postgres.Driver`` is removed first, re-running does not duplicate it.
    """
    el_drivers = find_element(el_dss, "drivers")
    if el_drivers is None:
        el_drivers = ET.Element("drivers")
        el_dss.append(el_drivers)

    # Remove stale driver entries ("postgress.Driver" is presumably a
    # misspelled legacy name).
    # NOTE: Don't use "//*[local-name() = $name]". It would also select
    # "datasources/datasource/driver" elements.
    stale_names = ("com.edb.Driver", "postgress.Driver", "postgres.Driver")
    for el_drv in el_drivers.xpath("*[local-name() = $name]", name="driver"):
        # .get(): a driver without a name attribute is kept instead of raising.
        if el_drv.attrib.get("name") in stale_names:
            el_drivers.remove(el_drv)

    # Create the PostgreSQL driver entry.
    el_ds_class = ET.Element("xa-datasource-class")
    el_ds_class.text = "org.postgresql.xa.PGXADataSource"

    el_drv = ET.Element("driver", name="postgres.Driver", module="org.postgresql")
    el_drv.append(el_ds_class)

    el_drivers.append(el_drv)

# -----------------------------------------------------------------------------

def upgrade_datasources(T):
    """Upgrade the datasources section of *T* from EDB to PostgreSQL.

    Rewrites the driver list of the first ``datasources`` element, then, for
    every ``datasource`` element in the document:

    * removes it if it is a known unused datasource;
    * otherwise points it at PostgreSQL and sets its pool size and idle
      timeout.

    Prints a warning and stops early if either element kind is missing.
    """
    print("  Upgrading datasources")

    # Upgrade "datasources/drivers" element
    L = T.xpath("//*[local-name() = $name]", name="datasources")
    if len(L) == 0:
        print("WARNING: Could not find 'datasources' element.")
        return

    ds_upgrade_drivers(L[0])

    # Upgrade "datasources/datasource" elements
    L = T.xpath("//*[local-name() = $name]", name="datasource")
    if len(L) == 0:
        print("WARNING: Could not find 'datasource' element.")
        return

    for el_ds in L:
        ds_remove_unused(el_ds)
        # A removed datasource is detached from the tree; editing it further
        # would be wasted work.
        if el_ds.getparent() is None:
            continue

        ds_upgrade_con(el_ds)
        ds_set_pool_size(el_ds)
        ds_set_timeout(el_ds)


# -----------------------------------------------------------------------------
# Upgrade EJB section
# -----------------------------------------------------------------------------
def upgrade_ejb(T):
    """Run the EJB upgrade steps in order: pool size, new bindings, renamed lookups."""
    print("  Upgrading EJB")
    for step in (ejb_upgrade_pool_size, add_new_ejb, update_ejb):
        step(T)
# -----------------------------------------------------------------------------
# update EJB pool size
# -----------------------------------------------------------------------------   
def ejb_upgrade_pool_size(T):
    """Set ``max-pool-size="80"`` on every ``strict-max-pool`` element in *T*.

    Prints a warning and does nothing if no such element exists.
    """
    # upgrade_ejb already prints "Upgrading EJB"; name the specific step here
    # instead of repeating that line.
    print("  Upgrading EJB pool size")
    L = T.xpath("//*[local-name() = $name]", name="strict-max-pool")
    if len(L) == 0:
        print("WARNING: Could not find 'bean-instance-pools/strict-max-pool' element.")
        return

    for el in L:
        el.attrib["max-pool-size"] = "80"


# -----------------------------------------------------------------------------
# Add new EJBs
# -----------------------------------------------------------------------------
def add_new_ejb(T):
    """Add the SiteTitleManager ``lookup`` binding to the single ``bindings`` element.

    Does nothing, and prints a message, when there is no ``bindings``
    element, when there is more than one, or when a child named
    ``java:global/cms/SiteTitleManager/local`` already exists.
    """
    print("  Adding new EJBs")

    L = T.xpath("//*[local-name() = $name]", name="bindings")
    if len(L) == 0:
        print("WARNING: Could not find 'bindings' element.")
        return

    if len(L) != 1:
        print("WARNING: There are more than one 'bindings' element.")
        return

    el_bindings = L[0]
    binding_name = "java:global/cms/SiteTitleManager/local"
    # .get(): children may be comments or entries without a name attribute,
    # which would otherwise raise KeyError.
    for el in el_bindings:
        if el.attrib.get("name") == binding_name:
            print("The new EJB for SiteTitleManager is already there.")
            return

    el_lkp = ET.Element("lookup")
    el_lkp.attrib["name"] = binding_name
    el_lkp.attrib["lookup"] = "java:global/cms/cms_sites_impl/SiteTitleManager!com.ericsson.cms.sites.core.ISiteTitleManager"

    el_bindings.append(el_lkp)

def change_namespace(T, elname, oldns, newns):
    """Move elements with local name *elname* from namespace *oldns* to *newns*.

    Each matching element whose tag is in *oldns* is serialized, every
    occurrence of the *oldns* string in that text is replaced with *newns*,
    and the re-parsed copy replaces the original in the tree. When the
    element is the document root, the copy becomes the new root. Because this
    is a plain text substitution, descendants and any attribute or text
    containing *oldns* are rewritten too.

    Prints a warning and returns if no element named *elname* exists.
    """
    L = T.xpath("//*[local-name() = $name]", name=elname)
    if len(L) == 0:
        print("WARNING: Could not find '"+elname+"' element.")
        return

    prefix = u'{%s}' % oldns
    for el in L:
        if not el.tag.startswith(prefix):
            continue
        # Serialize to a text string: the default tostring() returns bytes on
        # Python 3, where bytes.replace(str, str) raises TypeError.
        # with_tail=False keeps trailing text out of the re-parsed fragment;
        # the tail is copied over explicitly instead of being dropped.
        xml_text = ET.tostring(el, encoding='unicode', with_tail=False)
        el_new = ET.fromstring(xml_text.replace(oldns, newns))
        el_new.tail = el.tail
        el_parent = el.getparent()
        if el_parent is None:
            T._setroot(el_new)
        else:
            el_parent.replace(el, el_new)
            

# -----------------------------------------------------------------------------
# update jboss domain from 1.4 to 1.8
# -----------------------------------------------------------------------------
def update_server_jboss_domain(T):
    """Move ``server`` elements from namespace urn:jboss:domain:1.4 to 1.8."""
    print("  Upgrading server jboss domain")
    old_ns, new_ns = u'urn:jboss:domain:1.4', 'urn:jboss:domain:1.8'
    change_namespace(T, 'server', oldns=old_ns, newns=new_ns)

# -----------------------------------------------------------------------------
# add system properties
# -----------------------------------------------------------------------------
def add_system_properties(T):
    """Create a ``system-properties`` section on the root element if it is missing.

    The new section holds one ``property`` setting
    ``org.apache.tomcat.util.http.Parameters.MAX_COUNT`` to ``5000``. It is
    inserted right after the root's ``extensions`` child, or as the first
    child when there is no ``extensions`` element. An existing
    ``system-properties`` element is left untouched.
    """
    print("  Add system properties")
    el_root = T.getroot()
    if find_element(el_root, 'system-properties') is not None:
        return

    el_system_properties = ET.Element("system-properties")
    el_property = ET.Element("property")
    el_property.attrib["name"] = "org.apache.tomcat.util.http.Parameters.MAX_COUNT"
    el_property.attrib["value"] = '5000'
    el_system_properties.append(el_property)

    # Without <extensions>, el_root.index(None) would raise TypeError.
    el_extensions = find_element(el_root, 'extensions')
    if el_extensions is None:
        position = 0
    else:
        position = el_root.index(el_extensions) + 1
    el_root.insert(position, el_system_properties)

# -----------------------------------------------------------------------------
# update jboss domain logging from 1.2 to 1.5
# -----------------------------------------------------------------------------
def update_subsystem_jboss_domain_logging(T):
    """Move the logging ``subsystem`` from namespace version 1.2 to 1.5."""
    print("  Upgrading server jboss domain logging")
    old_ns, new_ns = u'urn:jboss:domain:logging:1.2', 'urn:jboss:domain:logging:1.5'
    change_namespace(T, 'subsystem', oldns=old_ns, newns=new_ns)
    
# -----------------------------------------------------------------------------
# Remove log4j cmsrappender
# -----------------------------------------------------------------------------
def remove_logger_cmsrappender(T):
    """Remove the ``cmsrappender`` file handler and the ``com.tandbergtv.tstv`` logger.

    The two removals are independent. A missing handler only prints a
    warning and no longer prevents the logger from being removed. Elements
    without the ``name``/``category`` attribute are kept instead of raising
    KeyError.
    """
    print("  Remove cmsrappender size-rotating-file-handler")
    L = T.xpath("//*[local-name() = $name]", name="size-rotating-file-handler")
    if len(L) == 0:
        print("WARNING: Could not find 'size-rotating-file-handler' element.")

    for el in L:
        if el.attrib.get('name') == 'cmsrappender':
            el.getparent().remove(el)

    print("  Remove com.tandbergtv.tstv logger")
    L = T.xpath("//*[local-name() = $name]", name="logger")
    if len(L) == 0:
        print("WARNING: Could not find 'logger' element.")
        return

    for el in L:
        if el.attrib.get('category') == 'com.tandbergtv.tstv':
            el.getparent().remove(el)
    
# -----------------------------------------------------------------------------
# update subsystem jboss domain datasources from 1.1 to 1.2
# -----------------------------------------------------------------------------
def update_subsystem_jboss_domain_datasources(T):
    """Move the datasources ``subsystem`` from namespace version 1.1 to 1.2."""
    # Log message typo fixed ("subsustem" -> "subsystem").
    print("  Upgrading subsystem jboss domain datasources")
    change_namespace(T, 'subsystem', u'urn:jboss:domain:datasources:1.1', 'urn:jboss:domain:datasources:1.2')
            
# -----------------------------------------------------------------------------
# update subsystem jboss domain ee from 1.1 to 1.2
# -----------------------------------------------------------------------------
def update_subsystem_jboss_domain_ee(T):
    """Move the ee ``subsystem`` from namespace version 1.1 to 1.2."""
    print("  Upgrading subsystem jboss domain ee")
    old_ns, new_ns = u'urn:jboss:domain:ee:1.1', 'urn:jboss:domain:ee:1.2'
    change_namespace(T, 'subsystem', oldns=old_ns, newns=new_ns)

# -----------------------------------------------------------------------------
# add postgresql module
# -----------------------------------------------------------------------------
def add_postgresql_module(T):
    """Add ``<module name="org.postgresql" slot="main"/>`` to the first ``global-modules`` element.

    Prints a warning and returns if there is no ``global-modules`` element.
    Does nothing, and prints a message, if that element already has a direct
    ``module`` child named ``org.postgresql``.
    """
    print("  Add postgresql module")
    L = T.xpath("//*[local-name() = $name]", name="global-modules")
    if len(L) == 0:
        print("WARNING: Could not find 'global-modules' element.")
        return
    global_modules = L[0]

    # Relative path: only direct <module> children of global-modules.
    # "//" would search the whole document.
    modules = global_modules.xpath("*[local-name() = $name]", name="module")
    for el in modules:
        if el.attrib.get("name") == "org.postgresql":
            print("The new postgresql module is already there.")
            return

    postgresql_el = ET.Element("module")
    postgresql_el.attrib["name"] = "org.postgresql"
    postgresql_el.attrib["slot"] = "main"
    global_modules.append(postgresql_el)
       
# -----------------------------------------------------------------------------
# update subsystem jboss domain ejb3 from 1.4 to 1.5
# -----------------------------------------------------------------------------
def update_subsystem_jboss_domain_ejb3(T):
    """Move the ejb3 ``subsystem`` from namespace version 1.4 to 1.5."""
    print("  Upgrading subsystem jboss domain ejb3")
    old_ns, new_ns = u'urn:jboss:domain:ejb3:1.4', 'urn:jboss:domain:ejb3:1.5'
    change_namespace(T, 'subsystem', oldns=old_ns, newns=new_ns)
            
# -----------------------------------------------------------------------------
# update timer service
# -----------------------------------------------------------------------------
def update_timer_service(T):
    """Convert ``timer-service`` elements that use thread pool ``default`` to the data-stores layout.

    For each such element:

    * sets ``default-data-store="default-file-store"`` unless the attribute
      is already present;
    * if a legacy ``data-store`` child exists, replaces it with
      ``<data-stores><file-data-store name="default-file-store"
      path="timer-service-data" relative-to="jboss.server.data.dir"/>
      </data-stores>``. The legacy element's own attributes are discarded.

    Prints a warning if no ``timer-service`` element exists.
    """
    print("  Upgrading time service")
    L = T.xpath("//*[local-name() = $name]", name="timer-service")
    if len(L) == 0:
        print("WARNING: Could not find 'timer-service' element.")
        return
    for el in L:
        if el.attrib.get('thread-pool-name') != 'default':
            continue
        # Check the XML attributes. hasattr() inspects Python attributes of
        # the lxml object and was always False, so values were overwritten.
        if 'default-data-store' not in el.attrib:
            el.attrib['default-data-store'] = 'default-file-store'
        el_data_store = find_element(el, 'data-store')
        if el_data_store is None:
            continue
        el.remove(el_data_store)
        el_data_stores = ET.Element("data-stores")
        el.append(el_data_stores)
        el_file_data_store = ET.Element('file-data-store')
        el_file_data_store.attrib['name'] = 'default-file-store'
        el_file_data_store.attrib['path'] = 'timer-service-data'
        el_file_data_store.attrib['relative-to'] = 'jboss.server.data.dir'
        el_data_stores.append(el_file_data_store)
            
# -----------------------------------------------------------------------------
# update subsystem jboss domain infinispan from 1.4 to 1.5
# -----------------------------------------------------------------------------
def update_subsystem_jboss_domain_infinispan(T):
    """Move the infinispan ``subsystem`` from namespace version 1.4 to 1.5."""
    print("  Upgrading subsystem jboss domain infinispan")
    old_ns, new_ns = u'urn:jboss:domain:infinispan:1.4', 'urn:jboss:domain:infinispan:1.5'
    change_namespace(T, 'subsystem', oldns=old_ns, newns=new_ns)

# -----------------------------------------------------------------------------
# update cache container, add statistics-enabled="true"
# -----------------------------------------------------------------------------         
def update_cache_container(T):
    """Default ``statistics-enabled="true"`` on cache containers and their local caches.

    The attribute is added only where it is absent; explicit existing values
    are preserved. ``local-cache`` elements are searched only within each
    container. Prints a warning if no ``cache-container`` exists.
    """
    print("  Upgrading cache container")
    L = T.xpath("//*[local-name() = $name]", name="cache-container")
    if len(L) == 0:
        print("WARNING: Could not find 'cache-container' element.")
        return
    for el in L:
        # Check the XML attributes. hasattr() on the lxml object was always
        # False, so existing values were overwritten.
        if 'statistics-enabled' not in el.attrib:
            el.attrib['statistics-enabled'] = 'true'
        # ".//" keeps the search inside this container; "//" scanned the whole document.
        local_caches = el.xpath(".//*[local-name() = $name]", name="local-cache")
        for local_cache in local_caches:
            if 'statistics-enabled' not in local_cache.attrib:
                local_cache.attrib['statistics-enabled'] = 'true'

# -----------------------------------------------------------------------------
# update subsystem jboss domain jacorb from 1.3 to 1.4
# -----------------------------------------------------------------------------
def update_subsystem_jboss_domain_jacorb(T):
    """Move the jacorb ``subsystem`` from namespace version 1.3 to 1.4."""
    print("  Upgrading subsystem jboss domain jacorb")
    old_ns, new_ns = u'urn:jboss:domain:jacorb:1.3', 'urn:jboss:domain:jacorb:1.4'
    change_namespace(T, 'subsystem', oldns=old_ns, newns=new_ns)
    
# -----------------------------------------------------------------------------
# update subsystem jboss domain jmx from 1.2 to 1.3
# -----------------------------------------------------------------------------
def update_subsystem_jboss_domain_jmx(T):
    """Move the jmx ``subsystem`` from namespace version 1.2 to 1.3."""
    print("  Upgrading subsystem jboss domain jmx")
    old_ns, new_ns = u'urn:jboss:domain:jmx:1.2', 'urn:jboss:domain:jmx:1.3'
    change_namespace(T, 'subsystem', oldns=old_ns, newns=new_ns)

# -----------------------------------------------------------------------------
# update subsystem jboss domain mail from 1.1 to 1.2
# -----------------------------------------------------------------------------
def update_subsystem_jboss_domain_mail(T):
    """Move the mail ``subsystem`` from namespace version 1.1 to 1.2."""
    print("  Upgrading subsystem jboss domain mail")
    old_ns, new_ns = u'urn:jboss:domain:mail:1.1', 'urn:jboss:domain:mail:1.2'
    change_namespace(T, 'subsystem', oldns=old_ns, newns=new_ns)
                
# -----------------------------------------------------------------------------
# update mail-session
# -----------------------------------------------------------------------------
def update_mail_session(T):
    """Give the first ``mail-session`` element a ``name`` attribute if it has none.

    The value set is ``java:jboss/mail/Default``; an existing ``name`` is
    preserved. Prints a warning if no ``mail-session`` element exists, where
    the original raised IndexError.
    """
    print("  Update mail session")
    L = T.xpath("//*[local-name() = $name]", name="mail-session")
    if len(L) == 0:
        print("WARNING: Could not find 'mail-session' element.")
        return
    mail_session = L[0]
    # Check the XML attributes. hasattr() on the lxml object was always False.
    if 'name' not in mail_session.attrib:
        mail_session.attrib['name'] = 'java:jboss/mail/Default'

# -----------------------------------------------------------------------------
# update subsystem jboss domain messaging from 1.3 to 1.4
# -----------------------------------------------------------------------------
def update_subsystem_jboss_domain_messaging(T):
    """Move the messaging ``subsystem`` from namespace version 1.3 to 1.4."""
    print("  Upgrading subsystem jboss domain messaging")
    old_ns, new_ns = u'urn:jboss:domain:messaging:1.3', 'urn:jboss:domain:messaging:1.4'
    change_namespace(T, 'subsystem', oldns=old_ns, newns=new_ns)
    
# -----------------------------------------------------------------------------
# update subsystem jboss domain naming from 1.3 to 1.4
# -----------------------------------------------------------------------------
def update_subsystem_jboss_domain_naming(T):
    """Move the naming ``subsystem`` from namespace version 1.3 to 1.4."""
    print("  Upgrading subsystem jboss domain naming")
    old_ns, new_ns = u'urn:jboss:domain:naming:1.3', 'urn:jboss:domain:naming:1.4'
    change_namespace(T, 'subsystem', oldns=old_ns, newns=new_ns)
    
# -----------------------------------------------------------------------------
# update subsystem jboss domain remoting from 1.1 to 1.2
# -----------------------------------------------------------------------------
def update_subsystem_jboss_domain_remoting(T):
    """Move the remoting ``subsystem`` from namespace version 1.1 to 1.2."""
    print("  Upgrading subsystem jboss domain remoting")
    old_ns, new_ns = u'urn:jboss:domain:remoting:1.1', 'urn:jboss:domain:remoting:1.2'
    change_namespace(T, 'subsystem', oldns=old_ns, newns=new_ns)
    
# -----------------------------------------------------------------------------
# update subsystem jboss domain transactions from 1.3 to 1.5
# -----------------------------------------------------------------------------
def update_subsystem_jboss_domain_transactions(T):
    """Move the transactions ``subsystem`` from namespace version 1.3 to 1.5."""
    print("  Upgrading subsystem jboss domain transactions")
    old_ns, new_ns = u'urn:jboss:domain:transactions:1.3', 'urn:jboss:domain:transactions:1.5'
    change_namespace(T, 'subsystem', oldns=old_ns, newns=new_ns)
    
# -----------------------------------------------------------------------------
# update subsystem jboss domain web from 1.4 to 2.2
# -----------------------------------------------------------------------------
def update_subsystem_jboss_domain_web(T):
    """Move the web ``subsystem`` from namespace version 1.4 to 2.2."""
    print("  Upgrading subsystem jboss domain web")
    old_ns, new_ns = u'urn:jboss:domain:web:1.4', 'urn:jboss:domain:web:2.2'
    change_namespace(T, 'subsystem', oldns=old_ns, newns=new_ns)

# -----------------------------------------------------------------------------
# Upgrade the changes from cms5.2
# -----------------------------------------------------------------------------
def upgrade_jboss_domain(T):
    """Apply every server/subsystem namespace upgrade to *T*, in a fixed order."""
    steps = (
        update_server_jboss_domain,
        update_subsystem_jboss_domain_logging,
        update_subsystem_jboss_domain_datasources,
        update_subsystem_jboss_domain_ee,
        update_subsystem_jboss_domain_ejb3,
        update_subsystem_jboss_domain_infinispan,
        update_subsystem_jboss_domain_jacorb,
        update_subsystem_jboss_domain_jmx,
        update_subsystem_jboss_domain_mail,
        update_subsystem_jboss_domain_messaging,
        update_subsystem_jboss_domain_naming,
        update_subsystem_jboss_domain_remoting,
        update_subsystem_jboss_domain_transactions,
        update_subsystem_jboss_domain_web,
    )
    for step in steps:
        step(T)
# -----------------------------------------------------------------------------
# update EJB
# -----------------------------------------------------------------------------
def update_ejb(T):
    """Apply in-place updates to existing EJB lookup bindings in *T*."""
    print("  Update EJB")
    return update_partnerusermanager_name(T)

# -----------------------------------------------------------------------------
# Upgrade the partnerUserManager Name
# -----------------------------------------------------------------------------
def update_partnerusermanager_name(T):
    """Rename the legacy PartnerUserManager ``lookup`` binding.

    Every ``lookup`` element named
    ``java:global/PartnerServices/PartnerUserManager`` is renamed to
    ``java:global/cms/PartnerUserManager/local`` and pointed at
    ``java:global/cms/cms_contentmgmt/PartnerUserManagerProxy``. Prints a
    warning if the document has no ``lookup`` elements at all.
    """
    legacy_name = 'java:global/PartnerServices/PartnerUserManager'
    found = T.xpath("//*[local-name() = $name]", name="lookup")
    if len(found) == 0:
        print("WARNING: Could not find 'lookup' element.")
    for el_lookup in found:
        if el_lookup.attrib["name"] != legacy_name:
            continue
        el_lookup.attrib["name"] = 'java:global/cms/PartnerUserManager/local'
        el_lookup.attrib["lookup"] = 'java:global/cms/cms_contentmgmt/PartnerUserManagerProxy'

# -----------------------------------------------------------------------------
# Main function
# -----------------------------------------------------------------------------

def main(file):
    """Upgrade the JBoss configuration *file* in place.

    Parses the file, runs every upgrade step in a fixed order, and saves the
    result. A ``.bak`` copy is kept only if the content changed.
    """
    print("Upgrading " + file)

    tree = parse_xml(file)

    steps = (
        upgrade_datasources,
        upgrade_ejb,
        upgrade_web,
        upgrade_jboss_domain,
        add_system_properties,
        remove_logger_cmsrappender,
        add_postgresql_module,
        update_timer_service,
        update_cache_container,
        update_mail_session,
    )
    for step in steps:
        step(tree)

    save_xml(tree, file)

    print("Done")

# -----------------------------------------------------------------------------

if __name__ == "__main__":
    # An explicit path on the command line overrides the default install location.
    if len(sys.argv) > 1:
        file_path = sys.argv[1]
    else:
        file_path = "/opt/tandbergtv/cms/conf/jboss/standalone.xml"

    main(file_path)

