import lxml.etree as ET
import sys
import os
import hashlib
import json
from lxml import html



# -----------------------------------------------------------------------------
# Get standalone.xml file path
# -----------------------------------------------------------------------------

def get_standalone_filepath():
    """Return the path of the standalone.xml file to upgrade.

    An explicit command-line argument (sys.argv[1]) always wins. Otherwise the
    known install locations are probed in order and the first existing file is
    returned. Returns None when no argument is given and neither file exists.
    """
    if len(sys.argv) > 1:
        return sys.argv[1]

    candidates = (
        "/opt/tandbergtv/cms/conf/jboss-workflow/standalone.xml",
        # For workflow 5.2
        "/opt/tandbergtv/watchpoint/jboss/standalone/configuration/standalone.xml",
    )
    for candidate in candidates:
        if os.path.isfile(candidate):
            return candidate
    return None



# -----------------------------------------------------------------------------
# Parse XML file
# -----------------------------------------------------------------------------

def parse_xml(file):
    """Parse *file* into an lxml ElementTree, or print the error and exit(1).

    Blank (whitespace-only) text is dropped while parsing. Without that,
    lxml's pretty_print cannot re-indent the document when it is saved.
    """
    try:
        return ET.parse(file, ET.XMLParser(remove_blank_text=True))
    except Exception as err:
        print(str(err))
        sys.exit(1)


# -----------------------------------------------------------------------------
# Calculate MD5 hash of a file
# -----------------------------------------------------------------------------

def file_md5(file):
    """Return the hexadecimal MD5 digest of the contents of *file*.

    In this script the digest only detects whether a rewritten file differs
    from the original; it is not used for anything security-related.

    On any error (missing file, permission problem, ...) the message is
    printed and the process exits with status 1.
    """
    try:
        digest = hashlib.md5()
        # 'with' closes the handle even when read() raises.
        with open(file, 'rb') as fp:
            # Hash in fixed-size chunks instead of loading the whole file.
            for chunk in iter(lambda: fp.read(65536), b''):
                digest.update(chunk)
        return digest.hexdigest()
    except Exception as ex:
        print(str(ex))
        sys.exit(1)


# -----------------------------------------------------------------------------
# Save XML file
# -----------------------------------------------------------------------------

def save_xml(tree, file):
    """Write *tree* back to *file*, keeping the previous version as '<file>.bak'.

    The tree is first serialized to '<file>.new'. If its MD5 differs from the
    current file, the current file is renamed to '<file>.bak' and the new file
    takes its place. Otherwise the temporary file is discarded and *file* is
    left untouched.

    Any failure is printed and the process exits with status 1. In every case
    the temporary '<file>.new' is removed. If the final rename fails, the
    original is moved back so the configuration file is never left missing.

    NOTE(review): the replacement is a newly created file, so it gets default
    ownership/permissions rather than the original's. Confirm that is
    acceptable for the deployment.
    """
    new_file = file + ".new"
    bak_file = file + ".bak"
    try:
        # LXML pretty print limitations:
        #   * indent = 2 and is not configurable
        #   * if you don't remove blanks when you parse XML, pretty print will not work.
        tree.write(new_file, pretty_print=True, xml_declaration=True, encoding="UTF-8")

        if file_md5(file) == file_md5(new_file):
            # Nothing changed: leave the original file untouched.
            os.remove(new_file)
        else:
            os.rename(file, bak_file)
            try:
                os.rename(new_file, file)
            except Exception:
                # Put the original back so the configuration never goes missing.
                os.rename(bak_file, file)
                raise

    except Exception as ex:
        print("ERROR: Could not save file '" + file + "': " + str(ex))
        sys.exit(1)
    finally:
        # Also runs when sys.exit() is raised above or inside file_md5(); a
        # leftover '<file>.new' would otherwise survive a failed run.
        if os.path.exists(new_file):
            try:
                os.remove(new_file)
            except OSError:
                pass  # best-effort cleanup; the real error was already reported


# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------

def find_element(parent, tag_name):
    """Return the first direct child of *parent* whose local name is *tag_name*.

    The match ignores namespaces (XPath local-name()) and only looks at
    children, not deeper descendants. Returns None when nothing matches.
    """
    matches = parent.xpath("*[local-name() = $name]", name=tag_name)
    return matches[0] if matches else None


# -----------------------------------------------------------------------------

def get_all_xmlns(T):
    """Return the set of namespace URIs used by the elements of tree *T*.

    Returns an empty set, after printing a warning, when no element has a
    namespace. Previously None was returned, which broke callers that iterate
    the result.
    """
    elements = T.xpath("//*[namespace-uri()]")
    if len(elements) == 0:
        print("WARNING: Could not find any element that contains namespace.")
        return set()

    namespaces = set()
    for el_ns in elements:
        tag = el_ns.tag
        # Namespaced tags are in Clark notation: "{uri}localname".
        if tag.startswith("{"):
            namespaces.add(tag[1:].partition("}")[0])
    return namespaces


# -----------------------------------------------------------------------------

def upgrade_element_namespace(element, ns):
    """Move *element* from its current namespace to *ns*, in place in the tree.

    A copy of *element* is built in which its tag, and the tag of every
    descendant in the same (old) namespace, uses *ns* instead. Descendants in
    other namespaces, or in none, keep their tags. Attributes, text, tails,
    comments and processing instructions are copied. The copy replaces
    *element* at the same position in its parent, so sibling order is kept.

    This uses the XML tree API rather than re-parsing the element as HTML. The
    HTML parser lower-cases tag and attribute names and applies HTML nesting
    rules, both of which can corrupt an XML configuration.

    Args:
        element: lxml element that has a parent (e.g. a <subsystem>).
        ns: namespace URI the element should be moved into.
    """
    import copy

    old_ns = ET.QName(element).namespace

    def new_tag(tag):
        # Only names in the namespace being upgraded are moved.
        qname = ET.QName(tag)
        if qname.namespace == old_ns:
            return ET.QName(ns, qname.localname).text
        return tag

    def copy_children(src, dst):
        # Recursively append copies of src's children under dst.
        for child in src:
            if isinstance(child.tag, str):
                # Keep namespace declarations the child adds on top of its
                # parent's (except the one being replaced), so elements from
                # other namespaces keep their original prefixes.
                child_nsmap = dict(
                    (prefix, uri) for prefix, uri in child.nsmap.items()
                    if uri != old_ns and src.nsmap.get(prefix) != uri)
                new_child = ET.SubElement(dst, new_tag(child.tag),
                                          attrib=dict(child.attrib),
                                          nsmap=child_nsmap or None)
                new_child.text = child.text
                copy_children(child, new_child)
            else:
                # Comments / processing instructions carry no namespace.
                new_child = copy.deepcopy(child)
                dst.append(new_child)
            new_child.tail = child.tail

    # Swap the old URI for the new one in the declarations so the upgraded
    # namespace keeps the prefix (usually the default, unprefixed one) it had.
    nsmap = dict((prefix, ns if uri == old_ns else uri)
                 for prefix, uri in element.nsmap.items())
    if ns not in nsmap.values():
        nsmap[None] = ns

    new_element = ET.Element(new_tag(element.tag), attrib=dict(element.attrib), nsmap=nsmap)
    new_element.text = element.text
    copy_children(element, new_element)
    new_element.tail = element.tail

    # replace() keeps the element's position among its siblings.
    element.getparent().replace(element, new_element)


# -----------------------------------------------------------------------------
# Upgrade web section
# -----------------------------------------------------------------------------

def upgrade_web(T):
    """Apply the web subsystem upgrades (SSL settings, then virtual-server rewrite)."""
    print("  Upgrading WEB")
    for step in (upgrade_web_ssl, upgrade_web_virtsrv):
        step(T)


# -----------------------------------------------------------------------------

def upgrade_web_ssl(T):
    """Pin every <ssl> element to TLSv1.2 and a single cipher suite.

    Elements are matched by local name anywhere in the document. When none
    exist, a warning is printed and nothing is changed.
    """
    ssl_elements = T.xpath("//*[local-name() = $name]", name="ssl")
    if not ssl_elements:
        print("WARNING: Could not find 'connector/ssl' element.")
        return

    settings = (
        ("protocol", "TLSv1.2"),
        ("cipher-suite", "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384"),
    )
    for el_ssl in ssl_elements:
        for key, value in settings:
            el_ssl.attrib[key] = value

# -----------------------------------------------------------------------------

def upgrade_web_virtsrv(T):
    """Add a request-method <rewrite> rule to each <virtual-server> lacking one.

    Every <virtual-server> (matched by local name) with no direct <rewrite>
    child gets:

        <rewrite pattern=".*" substitution="-" flags="F">
          <condition test="%{REQUEST_METHOD}"
                     pattern="^(TRACE|PATCH|CONNECT|PROPFIND|MKCOL)$" flags="NC"/>
        </rewrite>

    A server that already has a <rewrite> child is left as is, whatever that
    rule contains. A warning is printed when no <virtual-server> exists.

    NOTE(review): presumably flag "F" means forbidden and "NC" case-insensitive,
    following mod_rewrite conventions. Confirm against the JBoss Web docs.
    """
    virtual_servers = T.xpath("//*[local-name() = $name]", name="virtual-server")
    if not virtual_servers:
        print("WARNING: Could not find 'virtual-server' element.")
        return

    for el_vs in virtual_servers:
        if find_element(el_vs, "rewrite") is not None:
            continue  # already configured

        el_rewrite = ET.SubElement(el_vs, "rewrite")
        for key, value in (("pattern", ".*"), ("substitution", "-"), ("flags", "F")):
            el_rewrite.attrib[key] = value

        el_condition = ET.SubElement(el_rewrite, "condition")
        for key, value in (("test", "%{REQUEST_METHOD}"),
                           ("pattern", "^(TRACE|PATCH|CONNECT|PROPFIND|MKCOL)$"),
                           ("flags", "NC")):
            el_condition.attrib[key] = value


# -----------------------------------------------------------------------------
# Upgrade Cache section
# -----------------------------------------------------------------------------

def upgrade_cache(T):
    """Apply the cache (cache-container / local-cache) upgrades to tree *T*."""
    print("  Upgrading Cache")
    return upgrade_cache_container(T)


# -----------------------------------------------------------------------------

def upgrade_cache_container(T):
    """Set statistics-enabled="true" on every <cache-container> and <local-cache>.

    Both tag names are matched by local name anywhere in the document. Each
    kind is handled independently: a missing kind is reported with a warning
    but does not stop the other kind from being updated. Previously the
    absence of any <local-cache> meant no <cache-container> was updated.
    """
    for tag_name in ("cache-container", "local-cache"):
        elements = T.xpath("//*[local-name() = $name]", name=tag_name)
        if not elements:
            print("WARNING: Could not find '" + tag_name + "' element.")
            continue
        for el in elements:
            el.attrib["statistics-enabled"] = "true"
    

# -----------------------------------------------------------------------------
# Upgrade datasources section
# -----------------------------------------------------------------------------

# Remove unused datasources
def ds_remove_unused(el_ds):
    """Detach *el_ds* from its parent if it is the 'ExampleDS' datasource.

    The element is removed only when jndi-name is
    "java:jboss/datasources/ExampleDS" and pool-name is "ExampleDS".
    Any other element is left untouched, including one missing either
    attribute.
    """
    attrib = el_ds.attrib
    # .get(): callers select every element named "datasource"; one without
    # these attributes is simply not ExampleDS and must not raise KeyError.
    if (attrib.get("jndi-name") == "java:jboss/datasources/ExampleDS"
            and attrib.get("pool-name") == "ExampleDS"):
        parent = el_ds.getparent()
        if parent is not None:  # already detached -> nothing to do
            parent.remove(el_ds)
    return

# -----------------------------------------------------------------------------

# Replace EDB driver with PostgreSQL
def ds_upgrade_con(el_ds):
    """Point the 'java:jboss/datasources/datasource' datasource at PostgreSQL.

    Only the datasource with that jndi-name is changed:
      * its <connection-url> child (if present) becomes the PostgreSQL URL;
      * its <driver> child (if present) becomes "postgres.Driver".
    Other datasources, and elements without a jndi-name attribute, are ignored.
    """
    # .get(): callers pass every element named "datasource"; one without a
    # jndi-name must not abort the whole upgrade with a KeyError.
    if el_ds.attrib.get("jndi-name") != "java:jboss/datasources/datasource":
        return

    el_con = find_element(el_ds, "connection-url")
    if el_con is not None:
        el_con.text = "jdbc:postgresql://dbserver:5432/ttv"

    el_drv = find_element(el_ds, "driver")
    if el_drv is not None:
        # Same name as the driver entry registered by ds_upgrade_drivers().
        el_drv.text = "postgres.Driver"

# -----------------------------------------------------------------------------

# Set pool size
def ds_set_pool_size(el_ds):
    """Set min-pool-size to 5 and max-pool-size to 250 in the datasource's <pool>.

    Only existing elements are updated. Nothing is created when <pool> or
    either size element is missing.
    """
    el_pool = find_element(el_ds, "pool")
    if el_pool is None:
        return

    for tag_name, size in (("min-pool-size", "5"), ("max-pool-size", "250")):
        el_size = find_element(el_pool, tag_name)
        if el_size is not None:
            el_size.text = size

# -----------------------------------------------------------------------------

# Set idle timeout
def ds_set_timeout(el_ds):
    """Ensure the datasource has <timeout><idle-timeout-minutes>10</…></timeout>.

    Missing <timeout> / <idle-timeout-minutes> elements are appended
    (un-namespaced). An existing idle timeout is overwritten with "10".
    """
    el_timeout = find_element(el_ds, "timeout")
    if el_timeout is None:
        el_timeout = ET.SubElement(el_ds, "timeout")

    el_idle = find_element(el_timeout, "idle-timeout-minutes")
    if el_idle is None:
        el_idle = ET.SubElement(el_timeout, "idle-timeout-minutes")

    el_idle.text = "10"


# -----------------------------------------------------------------------------

def ds_upgrade_drivers(el_dss):
    """Register a single PostgreSQL XA driver under <datasources>/<drivers>.

    The <drivers> element is created if missing. Existing <driver> entries named
    "com.edb.Driver", "postgress.Driver" (misspelled) or "postgres.Driver" are
    removed. A fresh "postgres.Driver" entry (module "org.postgresql",
    XA class org.postgresql.xa.PGXADataSource) is then appended, so repeated
    runs leave exactly one such entry.
    """
    el_drivers = find_element(el_dss, "drivers")
    if el_drivers is None:
        el_drivers = ET.Element("drivers")
        el_dss.append(el_drivers)

    # Remove EDB / old PostgreSQL driver entries.
    # NOTE: Don't use "//*[local-name() = $name]". It will select "datasources/datasource/driver" also !!!
    obsolete_names = ("com.edb.Driver", "postgress.Driver", "postgres.Driver")
    for el_drv in el_drivers.xpath("*[local-name() = $name]", name="driver"):
        # .get(): a <driver> without a name attribute is kept rather than
        # aborting the upgrade with a KeyError.
        if el_drv.attrib.get("name") in obsolete_names:
            el_drivers.remove(el_drv)

    # Create PostgreSQL driver entry
    el_ds_class = ET.Element("xa-datasource-class")
    el_ds_class.text = "org.postgresql.xa.PGXADataSource"

    el_drv = ET.Element("driver", name="postgres.Driver", module="org.postgresql")
    el_drv.append(el_ds_class)

    el_drivers.append(el_drv)

# -----------------------------------------------------------------------------

def upgrade_datasources(T):
    """Upgrade the datasources subsystem of tree *T* from EDB to PostgreSQL.

    The driver list of the first <datasources> element is rewritten. Then every
    <datasource> element in the document goes through these steps, in order:
    remove ExampleDS, switch the connection to PostgreSQL, set the pool size,
    set the idle timeout. A warning is printed, and processing stops, when
    either element kind is missing.
    """
    print("  Upgrading datasources")

    containers = T.xpath("//*[local-name() = $name]", name="datasources")
    if not containers:
        print("WARNING: Could not find 'datasources' element.")
        return
    ds_upgrade_drivers(containers[0])

    datasources = T.xpath("//*[local-name() = $name]", name="datasource")
    if not datasources:
        print("WARNING: Could not find 'datasource' element.")
        return

    steps = (ds_remove_unused, ds_upgrade_con, ds_set_pool_size, ds_set_timeout)
    for el_ds in datasources:
        for step in steps:
            step(el_ds)
        

# -----------------------------------------------------------------------------
# Upgrade ejb section
# -----------------------------------------------------------------------------

def upgrade_ejb(T):
    """Apply the ejb subsystem upgrades (currently only the timer service)."""
    print("  Upgrading ejb")
    return upgrade_timer_service(T)


# -----------------------------------------------------------------------------

def upgrade_timer_service(T):
    """Convert the first <timer-service> to the data-stores layout.

    If it already has a <data-stores> child, nothing is done. Otherwise it gets
    default-data-store="default-file-store", all of its children are removed,
    and it ends up with:

        <data-stores>
          <file-data-store name="default-file-store" path="timer-service-data"
                           relative-to="jboss.server.data.dir"/>
        </data-stores>

    A warning is printed when no <timer-service> exists.
    """
    services = T.xpath("//*[local-name() = $name]", name="timer-service")
    if not services:
        print("WARNING: Could not find 'timer-service' element.")
        return

    el_timer_service = services[0]
    if find_element(el_timer_service, "data-stores") is not None:
        # timer_service is already up-to-date
        return

    el_timer_service.set("default-data-store", "default-file-store")
    # Iterate over a snapshot so removal cannot disturb the iteration.
    for child in list(el_timer_service):
        el_timer_service.remove(child)

    # Create children again
    el_data_stores = ET.SubElement(el_timer_service, "data-stores")
    el_file_data_store = ET.SubElement(el_data_stores, "file-data-store")
    for key, value in (("name", "default-file-store"),
                       ("path", "timer-service-data"),
                       ("relative-to", "jboss.server.data.dir")):
        el_file_data_store.set(key, value)
        

# -----------------------------------------------------------------------------
# Upgrade namespace section
# -----------------------------------------------------------------------------

def upgrade_namespace(T):
    """Move subsystems whose namespace version is outdated to the required one.

    The (old, new) pairs come from get_upgrade_required_namespaces(). When
    there are none, or that function returns None, a note is printed and the
    tree is left unchanged.
    """
    print("  Upgrading namespace")

    pending = get_upgrade_required_namespaces(T)
    if not pending:
        print("Namespace upgrading is ignored")
        return
    upgrade_subsystem_namespaces(T, pending)


# -----------------------------------------------------------------------------

def get_upgrade_required_namespaces(T):
    """Return (old_namespace, required_namespace) pairs that need upgrading.

    The required versions come from 'workflow-standalone-xmlns-version.json',
    located next to this script. It is read as a JSON object whose "xmlns"
    entry maps a namespace without its version (e.g. "urn:jboss:domain:web")
    to the required version string. An optional "enabled": false disables the
    upgrade.

    Returns None when the file is missing/unreadable, disabled, or has no
    "xmlns" entry. Otherwise returns a (possibly empty) list of tuples, one for
    each namespace used in *T* that has the configured prefix but a different
    version.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    xmlns_version_file = os.path.join(script_dir, "workflow-standalone-xmlns-version.json")

    try:
        with open(xmlns_version_file) as json_file:
            data = json.load(json_file)
    except (IOError, OSError):
        print("Configuration file for upgrading 'jboss-workflow/standalone.xml' namespaces not found")
        return
    except Exception as ex:
        # e.g. malformed JSON: report the actual reason instead of "not found".
        print("Could not read '" + xmlns_version_file + "': " + str(ex))
        return

    if 'enabled' in data and not data['enabled']:
        return
    if 'xmlns' not in data:
        return

    # Guard against a None result (document without namespaced elements).
    namespaces = get_all_xmlns(T) or set()

    upgrade_required_namespaces = []
    for prefix, version in data['xmlns'].items():
        required_ns = prefix + ':' + version
        # Compare the namespace minus its trailing ":<version>" for equality.
        # A plain substring test would let "urn:jboss:domain:web" also match
        # "urn:jboss:domain:webservices:1.2".
        matches = sorted(ns for ns in namespaces if ns.rpartition(':')[0] == prefix)
        for ns in matches:
            # Value comparison: 'is not' would compare object identity and be
            # true even when the namespace is already at the required version.
            if ns != required_ns:
                # tuple (old_namespace, required_namespace)
                upgrade_required_namespaces.append((ns, required_ns))

    return upgrade_required_namespaces


# -----------------------------------------------------------------------------

def upgrade_subsystem_namespaces(T, namespaces):
    """Move each <subsystem> from an old namespace to its required one.

    Args:
        T: the parsed document tree.
        namespaces: iterable of (old_namespace, new_namespace) tuples, as
            produced by get_upgrade_required_namespaces().

    For each pair, the first <subsystem> element in old_namespace is passed
    to upgrade_element_namespace(). A warning is printed when no such element
    exists.
    """
    for old_ns, new_ns in namespaces:
        # Clark notation "{uri}tag" matches the namespace URI directly. This
        # avoids the malformed, warning-emitting path "//*{uri}subsystem".
        el_subsystem = next(T.iter("{%s}subsystem" % old_ns), None)

        if el_subsystem is None:
            print("WARNING: Could not find subsystem element with namespace " + old_ns)
            continue

        upgrade_element_namespace(el_subsystem, new_ns)


# -----------------------------------------------------------------------------
# Main function
# -----------------------------------------------------------------------------

def main(file):
    """Run every upgrade step on the standalone.xml at *file*, then save it.

    *file* may be None: get_standalone_filepath() returns None when no path was
    given and none of the default locations exist. In that case, or when
    *file* is empty, an error is printed and the process exits with status 1,
    instead of failing with a TypeError.
    """
    if not file:
        print("ERROR: Could not find standalone.xml; pass its path as the first argument.")
        sys.exit(1)

    print("Upgrading " + file)

    T = parse_xml(file)

    upgrade_web(T)
    upgrade_cache(T)
    upgrade_datasources(T)
    upgrade_ejb(T)
    upgrade_namespace(T)

    save_xml(T, file)

    print("Done")

# -----------------------------------------------------------------------------

if __name__ == "__main__":
    # The path comes from argv[1] or one of the known install locations.
    main(get_standalone_filepath())

