#!/usr/bin/env python3
#+
# Copyright 2010 iXsystems, Inc.
# All rights reserved
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted providing that the following conditions
# are met:
# 1. Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
# OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING
# IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
# $FreeBSD$
#####################################################################

import copy
import getpass
import os
import platform
import re
import signal
import sys
import subprocess
import time
import traceback
import ipaddr
import sqlite3

# Only root can run it.  sys.exit() is used instead of the exit() helper
# injected by the site module: that helper is meant for the interactive
# interpreter and does not exist when Python runs with -S.
if os.geteuid() != 0:
    sys.exit("This command must be run with root privileges.")

FREENAS_DATA_PATH = "/data"
FREENAS_DB = os.path.join(FREENAS_DATA_PATH, "freenas-v1.db")
IS_LINUX = platform.system().lower() == 'linux'


def _(s):
    """Identity stand-in for a gettext-style translator: return ``s`` unchanged."""
    return s

from middlewared.client import Client
from middlewared.utils.generate import random_string

from requests.packages.urllib3 import poolmanager
from requests.packages.urllib3.connectionpool import HTTPSConnectionPool

# (value, label) pairs for the link aggregation protocols the console offers.
# get_lagg_proto() lists these by their value (first element) and returns
# the value that was picked.
LAGGType = (
    ('failover', 'Failover'),
    ('lacp', 'LACP'),
    ('loadbalance', 'Load Balance'),
    ('roundrobin', 'Round Robin'),
    ('none', 'None'),
)


class NICChoices(object):
    """Populate a list of NIC choices.

    Iterating an instance yields ``(name, name)`` pairs, one per interface
    reported by ``/sbin/ifconfig -l`` that survives the filters selected by
    the constructor flags.  The list is rebuilt on every iteration; the
    result is also kept in ``self._NIClist`` and its length in
    ``self.max_choices``.
    """

    def __init__(self, nolagg=False, novlan=False, noloopback=True, notap=True,
                 exclude_configured=True, include_vlan_parent=False,
                 exclude_unconfigured_vlan_parent=False,
                 with_alias=False, nobridge=True, noepair=True, include_lagg_parent=True):

        self.nolagg = nolagg
        self.novlan = novlan
        self.noloopback = noloopback
        self.notap = notap
        self.exclude_configured = exclude_configured
        self.include_vlan_parent = include_vlan_parent
        self.exclude_unconfigured_vlan_parent = exclude_unconfigured_vlan_parent
        self.with_alias = with_alias
        self.nobridge = nobridge
        self.noepair = noepair
        self.include_lagg_parent = include_lagg_parent

    @staticmethod
    def _first_column(cursor, sql):
        """Run ``sql`` and return the first column of each row, or None on failure.

        Queries fail with sqlite3.OperationalError when this runs before the
        database is created (during syncdb); callers then skip that filter.
        """
        try:
            cursor.execute(sql)
        except sqlite3.OperationalError:
            return None
        return [row[0] for row in cursor]

    def __iter__(self):
        with os.popen("/sbin/ifconfig -l") as pipe:
            # split() rather than split(' '): empty output must not produce ''.
            nics = pipe.read().split()
        nics = [nic for nic in nics if nic not in ('lo0', 'pfsync0', 'pflog0', 'ipfw0')]
        if self.noloopback is False:
            nics.append('lo0')

        # Remove internal interfaces for failover
        with Client() as c:
            if not c.call('system.is_freenas'):
                internal = set(c.call('failover.internal_interfaces'))
                nics = [nic for nic in nics if nic not in internal]

        conn = sqlite3.connect(FREENAS_DB)
        try:
            cursor = conn.cursor()

            if self.include_lagg_parent:
                # Remove interfaces that are parent devices of a lagg
                lagg_members = set(self._first_column(
                    cursor, "SELECT lagg_physnic FROM network_lagginterfacemembers") or ())
                nics = [nic for nic in nics if nic not in lagg_members]

            if self.nolagg:
                # vlan devices are not valid parents of laggs
                nics = [nic for nic in nics if not nic.startswith(("lagg", "vlan"))]

            if self.novlan:
                nics = [nic for nic in nics if not nic.startswith("vlan")]
            else:
                # This removes devices that are parents of vlans.  We don't
                # remove these devices if we are adding a vlan since multiple
                # vlan devices may share the same parent.
                # The exception to this case is when we are getting the NIC
                # list for the GUI, in which case we want the vlan parents
                # as they may have a valid config on them.
                if not self.include_vlan_parent or self.exclude_unconfigured_vlan_parent:
                    vlan_parents = set(self._first_column(
                        cursor, "SELECT vlan_pint FROM network_vlan") or ())
                    nics = [nic for nic in nics if nic not in vlan_parents]

                if self.exclude_unconfigured_vlan_parent:
                    # Add the configured VLAN parents back in
                    configured_parents = self._first_column(
                        cursor,
                        "SELECT vlan_pint FROM network_vlan "
                        "INNER JOIN network_interfaces ON "
                        "network_vlan.vlan_pint=network_interfaces.int_interface "
                        "WHERE network_interfaces.int_interface IS NOT NULL "
                        "AND ((network_interfaces.int_ipv4address != '' "
                        "AND network_interfaces.int_ipv4address IS NOT NULL) "
                        "OR network_interfaces.int_dhcp = 1)",
                    ) or ()
                    for parent in configured_parents:
                        if parent not in nics:
                            nics.append(parent)

            if self.with_alias:
                aliased = self._first_column(
                    cursor,
                    "SELECT int_interface "
                    "FROM network_interfaces as ni "
                    "INNER JOIN network_alias as na "
                    "ON na.alias_interface_id = ni.id",
                )
                # Unlike the other filters, a failed query keeps every NIC
                # rather than dropping them all.
                if aliased is not None:
                    aliased = set(aliased)
                    nics = [nic for nic in nics if nic in aliased]

            if self.exclude_configured:
                # Exclude any configured interfaces
                configured = set(self._first_column(
                    cursor, "SELECT int_interface FROM network_interfaces") or ())
                nics = [nic for nic in nics if nic not in configured]
        finally:
            conn.close()

        if self.nobridge:
            nics = [nic for nic in nics if not nic.startswith("bridge")]

        if self.noepair:
            nics = [nic for nic in nics if not nic.startswith(('epair', 'vnet'))]

        if self.notap:
            nics = [nic for nic in nics if not nic.startswith('tap')]

        self._NIClist = nics
        self.max_choices = len(nics)

        return iter([(nic, nic) for nic in nics])


class DjangoModelMeta(type):
    """Metaclass that gives every model class a Django-like ``objects`` manager."""

    @property
    def objects(model):
        """Return a new manager bound to the model class ``model``."""
        manager = DjangoModelManager(model)
        return manager


class DjangoModel(metaclass=DjangoModelMeta):
    """Minimal Django-model look-alike backed by a middleware datastore.

    Column values live in the ``row`` dict and are read and written as
    attributes.  Reading a column that is not in ``row`` yields None.
    Subclasses set ``datastore`` to the name passed to the ``datastore.*``
    middleware calls.
    """
    datastore = NotImplementedError

    def __init__(self, **kwargs):
        self.row = kwargs

    def __getattr__(self, item):
        # __getattr__ only runs when normal lookup fails.  If ``row`` itself
        # is missing (an instance made via __new__, as copy/pickle do),
        # touching self.row here would re-enter __getattr__ without end.
        if item == 'row':
            raise AttributeError(item)
        return self.row.get(item)

    def __setattr__(self, key, value):
        # Only ``row`` is a real instance attribute; everything else is a column.
        if key == 'row':
            self.__dict__[key] = value
        else:
            self.row[key] = value

    def save(self):
        """Insert or update this row through the middleware.

        Column values that are DjangoModel instances are saved first when
        they have no ``id`` yet, and are then stored by their ``id``.  When
        ``row`` has no ``id`` a new row is inserted and the returned id is
        recorded in ``row``.
        """
        data = self.row.copy()

        for k, v in data.items():
            if isinstance(v, DjangoModel):
                if v.id is None:
                    v.save()

                data[k] = v.id

        with Client() as c:
            if 'id' in self.row:
                c.call('datastore.update', self.datastore, self.row['id'], data)
            else:
                self.row['id'] = c.call('datastore.insert', self.datastore, data)

    def delete(self):
        """Delete this row (by ``row['id']``) through the middleware."""
        with Client() as c:
            c.call('datastore.delete', self.datastore, self.row['id'])


class DjangoModelManager:
    """Django-style ``objects`` manager for one DjangoModel subclass."""

    def __init__(self, cls):
        self.cls = cls

    def all(self):
        """Return every row of the model's datastore."""
        return self.filter()

    def create(self, **kwargs):
        """Build an instance from ``kwargs``, save it and return it.

        Returning the instance mirrors Django's ``Manager.create``; callers
        that ignore the return value are unaffected.
        """
        obj = self.cls(**kwargs)
        obj.save()
        return obj

    def filter(self, **kwargs):
        """Return the rows whose columns equal every ``column=value`` given."""
        filters = [[k, '=', v] for k, v in kwargs.items()]
        with Client() as c:
            rows = c.call('datastore.query', self.cls.datastore, filters)
        return DjangoModelResultSet([self.cls(**row) for row in rows])


class DjangoModelResultSet:
    """Thin wrapper around a list of model objects, QuerySet-style."""

    def __init__(self, objects):
        self.objects = objects

    def __bool__(self):
        return bool(self.objects)

    def exists(self):
        """Return True when at least one object is held."""
        return bool(self)

    def __iter__(self):
        return iter(self.objects)

    def __getitem__(self, item):
        return self.objects[item]


def produce_django_model(ds):
    """Return a new DjangoModel subclass bound to the datastore named ``ds``."""

    class Model(DjangoModel):
        # The generated models differ only in the datastore they target.
        datastore = ds

    model_class = Model
    return model_class


# Models that need nothing beyond their datastore binding.
(
    LAGGInterfaceMembers,
    VLAN,
    GlobalConfiguration,
    StaticRoute,
    Support,
) = map(produce_django_model, (
    "network.lagginterfacemembers",
    "network.vlan",
    "network.globalconfiguration",
    "network.staticroute",
    "system.support",
))


class Interfaces(DjangoModel):
    """Row of the ``network.interfaces`` datastore."""
    datastore = "network.interfaces"

    def __str__(self):
        # Prefer the user-given name and fall back to the device name.  The
        # str() wrapper (with '' as last resort) keeps __str__ returning a
        # string even when neither column is set.
        return str(self.int_name or self.int_interface or '')


class LAGGInterface(DjangoModel):
    """Row of ``network.lagginterface`` with its interface row wrapped."""
    datastore = "network.lagginterface"

    def __init__(self, **kwargs):
        # A query result embeds the referenced interface as a plain dict;
        # wrap it so that attribute access (e.g. .int_interface) works.
        nested = kwargs.get('lagg_interface')
        if isinstance(nested, dict):
            kwargs['lagg_interface'] = Interfaces(**nested)

        super().__init__(**kwargs)


class FailoverIsEnabledException(Exception):
    """Raised by ensure_failover_is_disabled() when failover is active."""


def ensure_failover_is_disabled():
    """Raise FailoverIsEnabledException if failover is licensed and not disabled.

    Non-FreeNAS systems only: on FreeNAS nothing is checked.
    """
    with Client() as client:
        if client.call('system.is_freenas') or not client.call('failover.licensed'):
            return
        if not client.call('failover.config')['disabled']:
            raise FailoverIsEnabledException()


# Custom class to avoid warning about unverified SSL, see #16474
class HTTPSConnectionPoolNoWarn(HTTPSConnectionPool):
    """HTTPS pool that skips HTTPSConnectionPool's connection validation step."""

    def _validate_conn(self, conn):
        """Prepare ``conn`` right before a request, after its socket is created.

        ``super(HTTPSConnectionPool, self)`` deliberately starts the method
        lookup *after* HTTPSConnectionPool, so that class's own override
        (which emits the unverified-SSL warning this class exists to avoid)
        is bypassed.
        """
        super(HTTPSConnectionPool, self)._validate_conn(conn)

        # Force connect early to allow us to validate the connection.
        # AppEngine connections might not have a ``sock`` attribute.
        sock = getattr(conn, 'sock', None)
        if not sock:
            conn.connect()


poolmanager.pool_classes_by_scheme['https'] = HTTPSConnectionPoolNoWarn


def quad_to_cidr(quad):
    """Convert a dotted-quad netmask (e.g. ``255.255.255.0``) to a prefix length.

    Returns False when the input is not a valid netmask: it does not have
    four octets, an octet is not a netmask value (0, 128, 192, ..., 255), or
    the one-bits are not contiguous (e.g. ``255.0.255.0``).  Raises
    ValueError when an octet is not an integer.
    """
    # number of leading one-bits for each octet value a netmask may contain
    octet_bits = {0: 0, 128: 1, 192: 2, 224: 3, 240: 4,
                  248: 5, 252: 6, 254: 7, 255: 8}

    octets = quad.split('.')
    if len(octets) != 4:
        return False

    count = 0
    saw_partial = False
    for octet in octets:
        bits = octet_bits.get(int(octet))
        if bits is None:
            return False
        # once an octet is not all ones, every following octet must be zero
        if saw_partial and bits != 0:
            return False
        count += bits
        if bits != 8:
            saw_partial = True

    return count


def hex_to_cidr(_hex):
    """Convert a hexadecimal netmask (e.g. ``0xffffff00``) to a prefix length.

    A leading ``0x``/``0X`` is optional.  Returns False when the value is
    not a contiguous 32-bit netmask, as quad_to_cidr does for invalid masks.
    Raises ValueError when the text is not hexadecimal.
    """
    value = int(_hex.replace("0x", "").replace("0X", ""), 16)
    if not 0 <= value <= 0xFFFFFFFF:
        return False

    prefix = bin(value).count('1')
    # a contiguous mask has exactly its top `prefix` bits set
    if value != (0xFFFFFFFF << (32 - prefix)) & 0xFFFFFFFF:
        return False

    return prefix


def prompt(prompt_str, default_value=None):
    """Build an input() prompt: the text, ``[default]`` when one is set, then ':'."""
    suffix = ' [{}]'.format(str(default_value)) if default_value else ''
    return prompt_str + suffix + ':'


def get_nic(choices=None):
    """Let the user pick one interface from ``choices``.

    ``choices`` is an iterable of ``(name, label)`` pairs and defaults to a
    fresh NICChoices(); it is iterated again every time the menu is shown.
    Returns the chosen name, or False when the user enters q.
    """
    if choices is None:
        choices = NICChoices()

    while True:
        names = [choice[0] for choice in choices]
        for number, name in enumerate(names, start=1):
            print("%d) %s" % (number, name))

        answer = input(_("Select an interface (q to quit): "))
        if answer.isdigit() and 1 <= int(answer) <= len(names):
            return names[int(answer) - 1]
        if answer.lower().startswith("q"):
            return False


def get_lagg_proto():
    """Let the user pick a link aggregation protocol from LAGGType.

    Returns the protocol value (e.g. ``'lacp'``), or False when the user
    enters q.
    """
    protos = [value for value, _label in LAGGType]

    while True:
        for number, proto in enumerate(protos, start=1):
            print("%d) %s" % (number, proto))

        answer = input(_("Select a lagg protocol (q to quit): "))
        if answer.isdigit() and 1 <= int(answer) <= len(protos):
            return protos[int(answer) - 1]
        if answer.lower().startswith("q"):
            return False


def get_lagg_nics():
    """Let the user pick the member interfaces of a new lagg, one at a time.

    Returns the list of chosen interface names.  The list is empty when no
    interface is available or the user quits before picking any.
    """
    remaining = list(NICChoices(nolagg=True))
    if not remaining:
        print("All interfaces are already allocated to LAGG or VLAN, cannot proceed.")
        return []

    members = []
    while remaining:
        picked = get_nic(remaining)
        if not picked:
            break
        members.append(picked)
        # only offer interfaces that have not been picked yet
        remaining = [choice for choice in remaining if choice[0] != picked]

    return members


def configure_interface_stub(*args):
    """Run configure_interface(), restart the http service, and return its result."""
    result = configure_interface()
    with Client() as client:
        client.call('service.restart', 'http')
    return result


def _restart_network_service():
    """Start the ``network`` service through the middleware, printing the outcome."""
    print(_("Restarting network:"), end=' ')
    try:
        with Client() as c:
            c.call('service.start', 'network')
    except Exception:
        print(_("Failed"))
    else:
        print(_("ok"))


def _restart_routing_service():
    """Sync routes, regenerate rc and restart ``routing``, printing the outcome."""
    print(_("Restarting routing:"), end=' ')
    try:
        with Client() as c:
            c.call('route.sync')
            c.call('etc.generate', 'rc')
            c.call('service.restart', 'routing')
    except Exception:
        print(_("Failed"))
    else:
        print(_("ok"))


def _parse_ipv4_netmask(raw, default):
    """Turn a user-entered IPv4 netmask into a prefix length.

    Accepted forms are a dotted quad (255.255.255.0), eight hex digits with
    an optional 0x prefix (0xffffff00), or a prefix length 1-32 with an
    optional leading '/'.  An empty answer yields ``default`` when that is
    set.  Returns None, after printing a hint, when the input is not
    acceptable.
    """
    mask = raw.lstrip("/")
    if not mask and default:
        return default

    if re.match(r"^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9]"
                r"[0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|"
                r"[01]?[0-9][0-9]?)$", mask):
        try:
            ipaddr.IPAddress(mask, version=4)
        except ValueError:
            print(_("Invalid Netmask"))
            return None
        prefix = quad_to_cidr(mask)
    elif re.match(r"^(0[xX])?([0-9a-fA-F]){8}$", mask):
        prefix = hex_to_cidr(mask)
    else:
        try:
            bits = int(mask, 10)
        except ValueError:
            bits = None
        if bits is not None and 0 < bits < 33:
            return bits
        print(_("""Enter netmask as a dotted quad, a hex number,
or CIDR prefix
Acceptable formats are 255.255.255.0,
0xffffff00,
/24, or 24"""))
        return None

    # quad_to_cidr/hex_to_cidr report a value that is not a netmask with False
    if prefix is False:
        print(_("Invalid Netmask"))
        return None
    return prefix


def configure_interface():
    """Interactively configure one network interface from the console menu.

    The user picks an interface and can then delete its configuration, clear
    its current settings, or configure DHCP, static IPv4, IPv6 and (when
    failover is licensed) failover settings.  Each step is saved through the
    datastore, and the network is restarted at the end when anything changed.

    Returns True when the user quits or the flow completes, and False when
    saving the interface configuration fails.
    """
    while True:
        nics = []
        choices = NICChoices(include_vlan_parent=True, exclude_configured=False)
        for i, c in enumerate(choices):
            nics.append(c[0])
            print("%d) %s" % (i + 1, nics[i]))

        _input = input(_("Select an interface (q to quit): "))
        if _input.isdigit() and int(_input) in range(1, len(nics) + 1):
            nic = nics[int(_input) - 1]
            break
        elif _input.lower().startswith("q"):
            return True

    iface = Interfaces.objects.filter(int_interface=nic)
    if iface:
        iface = iface[0]
    else:
        iface = Interfaces()

    # Only an interface that already has a saved configuration can be deleted.
    while iface.id:
        _input = input(_("Delete interface? (y/n) ")).lower()
        if _input.startswith("y"):
            print(_("Deleting interface configuration:"), end=' ')
            iface.delete()
            print(_("Ok"))
            _restart_network_service()
            _restart_routing_service()
            return True
        elif _input.startswith('n'):
            break

    while True:
        _input = input(_("Remove the current settings of this interface? (This causes a momentary disconnection of the network.) (y/n) ")).lower()
        if _input.startswith("y"):
            print(_("Removing interface configuration:"), end=' ')
            try:
                if iface.id is not None:
                    iface.int_ipv4address = ''
                    iface.int_ipv4address_b = ''
                    iface.int_v4netmaskbit = ''
                    iface.int_dhcp = False
                    iface.int_v6netmaskbit = ''
                    iface.int_ipv6address = ''
                    iface.int_ipv6auto = False
                    iface.int_vip = ''
                    iface.int_vhid = None
                    iface.int_pass = ''
                    iface.int_critical = False
                    iface.save()
            except Exception as err:
                print(_("Failed %s") % str(err))
                return False
            print(_("Ok"))
            _restart_network_service()
            _restart_routing_service()
            return True
        elif _input.startswith('n'):
            break

    need_restart = False
    while True:
        # DHCP is only offered while no interface is configured for it yet.
        if Interfaces.objects.filter(int_dhcp=True):
            break
        _input = input(_("Configure interface for DHCP? (y/n) "))
        if _input.lower().startswith("y"):
            int_name_prompt = prompt(_("Interface name"), iface.int_name)
            int_name = input(int_name_prompt)
            if not int_name and iface.int_name:
                int_name = iface.int_name
            iface.int_interface = nic
            iface.int_name = int_name
            iface.int_dhcp = True
            iface.int_ipv4address = ''
            iface.int_ipv4address_b = ''
            iface.int_v4netmaskbit = ''
            iface.int_v6netmaskbit = ''
            iface.int_ipv6address = ''
            iface.int_ipv6auto = False
            print(_("Saving interface configuration:"), end=' ')
            try:
                iface.save()
            except Exception as err:
                print(_("Failed %s") % str(err))
                return False
            print(_("Ok"))
            need_restart = "DHCP"
            break
        elif _input.lower().startswith('n'):
            break

    with Client() as c:
        is_freenas = c.call('system.is_freenas')
        failover_node = failover_licensed = False
        if not is_freenas:
            failover_licensed = c.call('failover.licensed')
            failover_node = c.call('failover.node')

    while need_restart != "DHCP":
        yes = input(_("Configure IPv4? (y/n) "))
        if yes.lower().startswith("y"):
            int_name_prompt = prompt(_("Interface name"), iface.int_name)
            int_name = input(int_name_prompt)
            if not int_name and iface.int_name:
                int_name = iface.int_name
            if not is_freenas and failover_licensed:
                if failover_node == 'A':
                    ip_label = "IPv4 Address (This TrueNAS Controller)"
                    ip_b_label = "IPv4 Address (TrueNAS Controller 2)"
                else:
                    ip_label = "IPv4 Address (TrueNAS Controller 1)"
                    ip_b_label = "IPv4 Address (This TrueNAS Controller)"
                ip_prompt = prompt(ip_label, iface.int_ipv4address)
                ip_b_prompt = prompt(ip_b_label, iface.int_ipv4address_b)
            else:
                ip_prompt = prompt("IPv4 Address", iface.int_ipv4address)

            mask_prompt = prompt("IPv4 Netmask", iface.int_v4netmaskbit)
            print(_("Several input formats are supported"))
            print(_("Example 1 CIDR Notation:"))
            print("    192.168.1.1/24")
            print(_("Example 2 IP and Netmask separate:"))
            print("    IP: 192.168.1.1")
            print("    Netmask: 255.255.255.0, /24 or 24")

            ip_b = None
            while True:
                ip = input(ip_prompt)
                if not ip and iface.int_ipv4address:
                    ip = iface.int_ipv4address

                if failover_licensed:
                    ip_b = input(ip_b_prompt)
                    if not ip_b and iface.int_ipv4address_b:
                        ip_b = iface.int_ipv4address_b

                # Validate every entered address (the second controller's
                # too) before any CIDR suffix is split off.
                invalid = None
                for address in ([ip, ip_b] if failover_licensed else [ip]):
                    try:
                        ipaddr.IPNetwork(address, version=4)
                    except ValueError:
                        invalid = address
                        break
                if invalid is not None:
                    print(_("Invalid value entered: %s") % invalid)
                    continue

                ip, _sep, ip_prefix = ip.partition('/')
                if failover_licensed:
                    ip_b, _sep, ip_b_prefix = ip_b.partition('/')
                    # The prefix is taken from the input only when both
                    # addresses carry one; the second address's prefix wins.
                    cidr = ip_prefix and ip_b_prefix
                else:
                    cidr = ip_prefix
                if cidr:
                    mask = cidr
                    break

                mask = _parse_ipv4_netmask(input(mask_prompt), iface.int_v4netmaskbit)
                if mask is not None:
                    break

            iface.int_interface = nic
            iface.int_name = int_name
            iface.int_ipv4address = ip
            iface.int_ipv4address_b = ip_b or ''
            iface.int_v4netmaskbit = mask
            iface.int_dhcp = False
            iface.int_v6netmaskbit = ''
            iface.int_ipv6address = ''
            iface.int_ipv6auto = False
            print(_("Saving interface configuration:"), end=' ')
            try:
                iface.save()
            except Exception:
                print(_("Failed"))
                return False
            print(_("Ok"))
            need_restart = True
            break
        elif yes.lower().startswith("n"):
            break

    while True:
        yes = input(_("Configure IPv6? (y/n) "))
        if yes.lower().startswith("y"):
            ip_prompt = prompt("IPv6 Address", iface.int_ipv6address)
            mask_prompt = prompt("IPv6 Prefixlen", iface.int_v6netmaskbit)

            ip = input(ip_prompt)
            if not ip and iface.int_ipv6address:
                ip = iface.int_ipv6address

            try:
                ipaddr.IPNetwork(ip, version=6)
            except ValueError:
                print(_("Invalid value entered"))
                continue

            ip, _sep, mask = ip.partition('/')
            if not mask:
                # No prefix in the address: ask for it until a valid one
                # (1-128) or the existing default is given.
                while True:
                    mask = input(mask_prompt)
                    if not mask and iface.int_v6netmaskbit:
                        mask = iface.int_v6netmaskbit
                        break
                    mask = mask.lstrip("/")
                    try:
                        bits = int(mask)
                    except ValueError:
                        bits = 0
                    if 0 < bits <= 128:
                        mask = bits
                        break
                    print(_("Enter ipv6 prefixlen as number of bits, eg: 64"))

            iface.int_interface = nic
            iface.int_ipv6address = ip
            iface.int_v6netmaskbit = mask

            print(_("Saving interface configuration:"), end=' ')
            try:
                iface.save()
            except Exception:
                print(_("Failed"))
                return False
            print(_("Ok"))
            if not need_restart:
                need_restart = True
            break

        elif yes.lower().startswith("n"):
            break

    while failover_licensed:
        yes = input(_("Configure failover settings? (y/n) "))
        if yes.lower().startswith("y"):

            while True:
                vip_prompt = prompt("Virtual IP", iface.int_vip)
                vip = input(vip_prompt)
                if not vip and iface.int_vip:
                    vip = iface.int_vip
                elif vip:
                    try:
                        ipaddr.IPAddress(vip)
                    except ValueError:
                        print('Invalid Virtual IP')
                        continue
                iface.int_vip = vip
                break

            while True:
                vhid_prompt = prompt("VHID", iface.int_vhid)
                vhid = input(vhid_prompt)
                if not vhid and iface.int_vhid:
                    vhid = iface.int_vhid
                elif vhid:
                    try:
                        vhid = int(vhid)
                    except ValueError:
                        print('VHID should be an integer')
                        continue
                    if vhid < 1 or vhid > 255:
                        print('VHID should be between 1 and 255')
                        continue
                iface.int_vhid = vhid
                break

            while True:
                intcrit = input(_("Mark interface as Critical for Failover? (y/n) "))
                if intcrit.lower().startswith("y"):
                    # `nic` rather than iface.int_interface: the latter is
                    # still None for a new interface with no IP configured.
                    if nic.startswith('lagg'):
                        with Client() as c:
                            data = c.call('datastore.query', 'network.lagginterface', [('lagg_interface', '=', iface.id)])
                            if data and data[0]['lagg_protocol'] == 'failover':
                                print('A lagg interface using the "Failover" protocol is not allowed to be marked critical for failover.')
                                continue
                    iface.int_critical = True
                elif intcrit.lower().startswith("n"):
                    iface.int_critical = False
                else:
                    print(_("Please select y or n"))
                    continue
                break

            while True:
                passwd = input("Password for the VIP (none to automatically generate one): ")
                if not passwd:
                    if iface.int_pass:
                        # means an interface is being updated
                        # and already has a generated password
                        # so just use that one
                        passwd = iface.int_pass
                    else:
                        # generate one
                        passwd = random_string(string_size=16)
                    break
                elif len(passwd) > 16:
                    print("Password must be between 1 and 16 characters long")
                    continue
                else:
                    break

            # finally, we need to set the db column
            # to what `passwd` is set to since we're
            # writing directly to the database
            iface.int_pass = passwd
            iface.int_interface = nic

            print(_("Saving interface configuration:"), end=' ')
            try:
                iface.save()
            except Exception:
                print(_("Failed"))
                return False
            print(_("Ok"))
            if not need_restart:
                need_restart = True
            break

        elif yes.lower().startswith("n"):
            break

    if need_restart:
        _restart_network_service()
    if need_restart == "DHCP":
        _restart_routing_service()
    return True

def reset_root_pw(*args):
    """Change the root password and disable two-factor authentication.

    Without arguments the password is read twice with getpass until both
    entries match, and the function waits for enter before returning.  With
    arguments, ``args[0]`` is the new password and there is no wait.  An
    empty password aborts the change.

    The change goes through the middleware.  When the middleware socket is
    missing or refuses the connection, the configuration database is updated
    directly instead.
    """
    print()
    print(_("Changing password for root"))
    print(_("This action will disable 2FA"))
    print()

    if not args:
        waituser = True
        while True:
            p1 = getpass.getpass()
            p2 = getpass.getpass(_('Retype password: '))
            if p1 == p2:
                break
            print()
            print(_('Passwords do not match. Try again.'))
            print()
    else:
        p1 = args[0]
        waituser = False

    if p1:
        try:
            with Client() as c:
                user_id = c.call("user.query", [["username", "=", "root"]], {"get": True})["id"]
                c.call("user.update", user_id, {"password": p1})
                c.call("auth.twofactor.update", {"enabled": False})
        except (FileNotFoundError, ConnectionRefusedError):
            # The middleware is not reachable: write to the database directly.
            from middlewared.plugins.account import crypted_password
            conn = sqlite3.connect(FREENAS_DB)
            try:
                # "with conn" commits on success and rolls back on error but
                # does not close the connection; the finally clause does.
                with conn:
                    conn.execute(
                        "UPDATE account_bsdusers SET bsdusr_unixhash = ? WHERE bsdusr_username = 'root'",
                        (crypted_password(p1),),
                    )
                    conn.execute("UPDATE system_twofactorauthentication SET enabled = 0")
            finally:
                conn.close()
        print()
        print(_('Password successfully changed.'))
        print()
    else:
        print()
        print(_('Password change aborted.'))
        print()

    if waituser:
        print(_('Press enter to continue'))
        print()
        input("")


def reset_factory_defaults(*args):
    """After confirmation, reset the configuration to defaults and reboot.

    Any failure raised by the reset is printed rather than propagated.
    """
    question = (_('The configuration will be erased and reset to defaults. Are you sure?')
                + '(yes/no): ')
    answer = input(question)
    if not answer.lower().startswith('y'):
        return

    try:
        with Client() as client:
            client.call('config.reset', {'reboot': False}, job=True)
        os.system("/sbin/shutdown -r now")
    except Exception as e:
        print("Failed to reset configuration: {}".format(str(e)))


def configure_lagg_interface(*args):
    """Show the link aggregation submenu and dispatch to the chosen action.

    Loops until the user picks a valid number or enters something starting
    with ``q``.  A valid number runs that action once, and the function then
    returns None.  ``q`` returns False.  Any other input redraws the menu.
    """
    entries = [
        (_("Create Link Aggregation"), create_lagg_interface),
        (_("Delete Link Aggregation"), delete_lagg_interface),
    ]
    menu_map = dict(enumerate(entries, start=1))
    menu_max = len(entries)

    while True:
        print()

        for index, entry in menu_map.items():
            print("%d) %s" % (index, entry[0]))

        choice = input(_("Enter an option from 1-%d (enter q to quit): ") % (menu_max))
        if choice.isdigit() and int(choice) in menu_map:
            menu_map[int(choice)][1]()
            return None
        if choice.lower().startswith("q"):
            return False

def create_lagg_interface():
    """Interactively create a new ``laggN`` link aggregation and restart networking.

    Asks for the aggregation protocol and member NICs (via ``get_lagg_proto``
    and ``get_lagg_nics``).  It then saves the interface row, the aggregation
    row and one member row per NIC, and creates a ``member of laggN``
    interface row for any member NIC that has none.

    Returns:
        True on success, or when protocol/NIC selection yields nothing
        (treated as an abort).  False if saving or restarting networking
        fails.
    """
    # Use one past the *highest* existing lagg number.  Taking the number of
    # whichever row the queryset yields last can reuse an existing name,
    # because nothing here orders the queryset by its numeric suffix.
    lagg_index = 0
    for li in LAGGInterface.objects.all():
        existing = int(re.split(r'([0-9]+)$', li.lagg_interface.int_interface)[1])
        lagg_index = max(lagg_index, existing + 1)

    lagg_proto = get_lagg_proto()
    if not lagg_proto:
        return True

    lagg_nics = get_lagg_nics()
    if not lagg_nics:
        return True

    lagg_name = 'lagg%d' % lagg_index
    iface = Interfaces(int_interface=lagg_name, int_name=lagg_name,
                       int_dhcp=False, int_ipv6auto=False)

    print(_("Saving interface configuration:"), end=' ')
    try:
        iface.save()
    except Exception:
        print(_("Failed"))
        return False
    print(_("Ok"))

    lagg_iface = LAGGInterface(lagg_interface=iface,
                               lagg_protocol=lagg_proto)

    print(_("Saving Link Aggregation configuration:"), end=' ')
    try:
        lagg_iface.save()
    except Exception:
        print(_("Failed"))
        return False
    print(_("Ok"))

    # ``order`` records each NIC's position in the user's selection.
    for order, nic in enumerate(lagg_nics):
        lagg_iface_member = LAGGInterfaceMembers(
            lagg_interfacegroup=lagg_iface, lagg_ordernum=order, lagg_physnic=nic,
        )
        # Give a member NIC without an Interfaces row one labelled with its lagg.
        if not Interfaces.objects.filter(int_interface=nic).exists():
            Interfaces.objects.create(int_interface=nic, int_name=f'member of {lagg_name}')

        print(_("Saving Link Aggregation member configuration:"), end=' ')
        try:
            lagg_iface_member.save()
        except Exception:
            print(_("Failed"))
            return False
        print(_("Ok"))

    print(_("Restarting network:"), end=' ')
    try:
        with Client() as c:
            c.call('service.start', 'network')
    except Exception:
        print(_("Failed"))
        return False
    print(_("ok"))
    return True


def delete_lagg_interface():
    """Let the user pick a configured lagg interface, delete it, and restart networking.

    Returns:
        False if no lagg exists, the user quits with ``q``, or deleting or
        restarting networking fails.  True on success.
    """
    lagg_interfaces = LAGGInterface.objects.all()

    if not lagg_interfaces.exists():
        print()
        print("No lagg interfaces configured")
        print()
        print("Press enter to continue")
        input()
        return False

    # Number the choices once (1..n).  The menu below is redrawn on invalid
    # input but the numbering does not change.
    lagg_map = dict(enumerate(lagg_interfaces, start=1))
    while True:
        print()
        print("Select which lagg interface you would like to delete:")
        print()

        for number, li in lagg_map.items():
            print("%d) lagg%s" % (number, re.split(r'([0-9]+)$', str(li.lagg_interface))[1]))

        print()

        _input = input(_("Select an interface (q to quit): "))
        if _input.isdigit() and int(_input) in lagg_map:
            lagg = lagg_map[int(_input)]
            break
        elif _input.lower().startswith("q"):
            return False

    print(_("Deleting lagg interface:"), end=' ')
    try:
        lagg.delete()
    except Exception:
        print(_("Failed"))
        return False
    print(_("Ok"))
    print(_("Restarting network:"), end=' ')
    try:
        with Client() as c:
            c.call('service.start', 'network')
    except Exception:
        print(_("Failed"))
        return False
    print(_("ok"))

    return True


def configure_vlan(*args):
    """Show the VLAN submenu and run the selected action.

    A valid number runs that action once, and the function then returns
    None.  Input starting with ``q`` returns False.  Anything else redraws
    the menu.
    """
    actions = {
        1: [_("Create VLAN Interface"), create_vlan],
        2: [_("Delete VLAN Interface"), delete_vlan],
    }
    highest = len(actions)

    while True:
        print()

        for number in actions:
            print("%d) %s" % (number, actions[number][0]))

        answer = input(_("Enter an option from 1-%d (enter q to quit): ") % (highest))
        if answer.isdigit() and 1 <= int(answer) <= highest:
            actions[int(answer)][1]()
            return None
        if answer.lower().startswith("q"):
            return False


def create_vlan():
    """Interactively create a VLAN interface and restart networking.

    Asks for the parent NIC, the VLAN interface name (``vlanX``), the VLAN
    tag and a description, then saves the VLAN.  It also makes sure both the
    parent and the VLAN interface have an Interfaces row whose options
    include ``up``.

    Returns:
        True on success, or when the user quits parent selection with ``q``.
        None when the user aborts with ``a`` at the name or tag prompt.
        False if saving or restarting networking fails.
    """
    vlan = VLAN()

    # Select parent interface
    while True:
        nics = []
        choices = NICChoices(novlan=True, exclude_configured=False)
        for i, choice in enumerate(choices):
            nics.append(choice[0])
            print("%2d) %s" % (i + 1, nics[i]))

        _input = input(_("Select a parent interface (q to quit): "))
        if _input.isdigit() and int(_input) in range(1, len(nics) + 1):
            vlan_pint = nics[int(_input) - 1]
            break
        elif _input.lower().startswith("q"):
            return True

    # Get interface name; the number is normalised (e.g. "vlan05" -> "vlan5").
    while True:
        vlan_vint = input(_("Enter an interface name ")
                          + _("(vlanX) or a to abort: "))
        if vlan_vint == "a":
            return
        reg = re.search(r'vlan(\d+)', vlan_vint)
        if not reg:
            print(_("Interface name must be vlanX where X is a number"))
            continue
        vlan_vint = f'vlan{int(reg.group(1))}'
        break

    # Get VLAN tag.  The whole answer must be digits (a prefix match would
    # accept "10abc"), and IEEE 802.1Q only allows VLAN IDs 1-4094.
    while True:
        vlan_tag = input(_("Enter a VLAN Tag or a to abort: "))
        if vlan_tag == "a":
            return
        if not re.fullmatch(r'[0-9]+', vlan_tag):
            print(_("VLAN Tag must be an integer"))
            continue
        if not 1 <= int(vlan_tag) <= 4094:
            print(_("VLAN Tag must be between 1 and 4094"))
            continue
        break

    # Get VLAN description
    vlan_description = input(_("Enter VLAN description: "))

    vlan.vlan_pint = vlan_pint
    vlan.vlan_vint = vlan_vint
    vlan.vlan_tag = vlan_tag
    vlan.vlan_description = vlan_description
    print(_("Saving VLAN interface:"), end=' ')
    try:
        vlan.save()
        # Both the parent and the VLAN interface must exist with "up" in
        # their options.
        for vlan_interface_name in [vlan_pint, vlan_vint]:
            qs = Interfaces.objects.filter(int_interface=vlan_interface_name)
            if not qs.exists():
                vlan_interface = Interfaces(
                    int_interface=vlan_interface_name,
                    int_name=vlan_interface_name,
                    int_dhcp=False,
                    int_ipv6auto=False,
                    int_options='up',
                )
                vlan_interface.save()
            else:
                vlan_interface = qs[0]
                if 'up' not in vlan_interface.int_options:
                    vlan_interface.int_options += ' up'
                    vlan_interface.save()
    except Exception:
        print(_("Failed"))
        return False
    print(_("Ok"))
    print(_("Restarting network:"), end=' ')
    try:
        with Client() as c:
            c.call('service.start', 'network')
    except Exception:
        print(_("Failed"))
        return False
    print(_("ok"))
    return True


def delete_vlan():
    """Let the user pick a configured VLAN interface and delete it.

    Networking is not restarted here.  Returns False when no VLAN exists, the
    user quits with ``q``, or the delete fails.  Returns True after a
    successful delete.
    """
    vlan_interfaces = VLAN.objects.all()

    if not vlan_interfaces.exists():
        print()
        print("No VLAN interfaces configured")
        print()
        print("Press enter to continue")
        input()
        return False

    choices = {}
    while True:
        print()
        print("Select which VLAN interface you would like to delete:")
        print()

        for number, vi in enumerate(vlan_interfaces, start=1):
            choices[number] = vi
            print("%d) %s" % (number, str(vi.vlan_vint)))

        print()

        selection = input(_("Select an interface (q to quit): "))
        if selection.isdigit() and int(selection) in range(1, number + 1):
            picked = int(selection)
            if picked in choices:
                vlan = choices[picked]
                break
        elif selection.lower().startswith("q"):
            return False

    print(_("Deleting VLAN interface:"), end=' ')
    try:
        vlan.delete()
    except BaseException:
        print(_("Failed"))
        return False
    print(_("Ok"))

    return True


def configure_ipv4_default_route(gc):
    """Prompt for an IPv4 default gateway and save it on ``gc``.

    ``gc`` is the GlobalConfiguration row.  Its current ``gc_ipv4gateway`` is
    passed to ``prompt()`` (presumably shown as the current value).

    Returns True when a valid IPv4 address was entered and saved.  Returns
    False when the input is empty, not IPv4, or the save fails.
    """
    gateway = input(prompt(_("IPv4 Default Route"), gc.gc_ipv4gateway))
    if not gateway:
        print(_("No default route entered."))
        return False

    try:
        ipaddr.IPAddress(gateway, version=4)
    except ValueError:
        print(_("Invalid value entered"))
        return False

    gc.gc_ipv4gateway = gateway
    print(_("Saving IPv4 gateway:"), end=' ')
    try:
        gc.save()
    except BaseException:
        print(_("Failed"))
        return False
    print(_("Ok"))
    return True

def configure_ipv6_default_route(gc):
    """Prompt for an IPv6 default gateway and save it on ``gc``.

    ``gc`` is the GlobalConfiguration row.  Its current ``gc_ipv6gateway`` is
    passed to ``prompt()`` (presumably shown as the current value).

    Returns True when a valid IPv6 address was entered and saved.  Returns
    False when the input is empty, not IPv6, or the save fails.
    """
    question = prompt(_("IPv6 Default Route"), gc.gc_ipv6gateway)
    entered = input(question)
    if not entered:
        print(_("No default route entered."))
        return False

    try:
        ipaddr.IPAddress(entered, version=6)
    except ValueError:
        print(_("Invalid value entered"))
        return False

    gc.gc_ipv6gateway = entered
    print(_("Saving IPv6 gateway:"), end=' ')
    try:
        gc.save()
    except BaseException:
        print(_("Failed"))
        return False
    print(_("Ok"))
    return True

def configure_default_route(*args):
    """Offer to set the IPv4 and/or IPv6 default route, then apply the change.

    Routing is resynced and restarted only if at least one gateway was
    actually saved by ``configure_ipv4_default_route`` or
    ``configure_ipv6_default_route``.

    Returns True when routing was restarted or nothing changed.  Returns
    False if restarting routing failed.
    """
    gc = GlobalConfiguration.objects.all()[0]
    need_save = False

    while True:
        yes = input(_("Configure IPv4 Default Route? (y/n)"))
        if yes.lower().startswith("y"):
            # The helper returns False on empty/invalid input or a failed
            # save; in those cases there is nothing to apply.
            need_save = configure_ipv4_default_route(gc) or need_save
            break
        elif yes.lower().startswith("n"):
            break

    while True:
        yes = input(_("Configure IPv6 Default Route? (y/n)"))
        if yes.lower().startswith("y"):
            need_save = configure_ipv6_default_route(gc) or need_save
            break
        elif yes.lower().startswith("n"):
            break

    if need_save:
        print(_("Restarting routing:"), end=' ')
        try:
            with Client() as c:
                c.call('route.sync')
                c.call('etc.generate', 'rc')
                c.call('service.restart', 'routing')
        except Exception:
            print(_("Failed"))
            return False
        print(_("ok"))
        return True
    else:
        print(_("Routing configuration unchanged."))
        return True


def configure_static_routes(*args):
    """Prompt for one static route, save it, and restart routing.

    The destination must parse as a network and the gateway as an address
    (IPv4 or IPv6).  The description is optional.  Returns False on invalid
    input or when saving/restarting fails, True otherwise.
    """
    dest = input(_("Destination network: "))
    gateway = input(_("Gateway: "))
    desc = input(_("Description: "))

    try:
        ipaddr.IPNetwork(dest)
    except ValueError:
        print(_("Invalid destination network"))
        return False
    try:
        ipaddr.IPAddress(gateway)
    except ValueError:
        print(_("Invalid gateway"))
        return False

    route = StaticRoute()
    route.sr_destination = dest
    route.sr_gateway = gateway
    if desc:
        route.sr_description = desc

    print(_("Saving static route:"), end=' ')
    try:
        route.save()
    except BaseException:
        print(_("Failed"))
        return False
    print(_("ok"))

    try:
        print(_("Restarting routing:"), end=' ')
        with Client() as client:
            for method, *params in (('route.sync',),
                                    ('etc.generate', 'rc'),
                                    ('service.restart', 'routing')):
                client.call(method, *params)
    except BaseException:
        print(_("Failed"))
        return False
    print(_("ok"))
    return True


def configure_dns(*args):
    """Prompt for the DNS domain and up to three IPv4 nameservers, then apply them.

    An empty domain answer keeps the current domain.  Nameservers are read in
    order, and an empty answer ends input.  Entering a first nameserver clears
    nameservers 2 and 3 before they are re-read.  An invalid first nameserver
    aborts without saving anything.  An invalid second or third nameserver
    stops input but keeps what was entered before it.

    Returns True after saving and reloading the network configuration.
    Returns False when nothing changed, the first nameserver is invalid, or
    saving/reloading fails.
    """
    gc = GlobalConfiguration.objects.all()[0]

    domain_prompt = prompt(_("DNS Domain"), gc.gc_domain)
    ns1_prompt = prompt(_("DNS Nameserver 1"), gc.gc_nameserver1)
    ns2_prompt = prompt(_("DNS Nameserver 2"), gc.gc_nameserver2)
    ns3_prompt = prompt(_("DNS Nameserver 3"), gc.gc_nameserver3)

    need_save = False
    domain = input(domain_prompt)
    if domain:
        gc.gc_domain = domain
        # Mark for saving so a domain-only change is not lost when the
        # nameserver prompts are left empty.
        need_save = True

    print(_("Enter nameserver IPs, an empty value ends input"))
    # Single pass: ``break`` ends nameserver input early.
    while True:
        ns1 = input(ns1_prompt)
        if not ns1:
            break
        try:
            ipaddr.IPAddress(ns1, version=4)
        except ValueError:
            print(_("Invalid nameserver"))
            return False
        gc.gc_nameserver1 = ns1
        gc.gc_nameserver2 = ''
        gc.gc_nameserver3 = ''
        need_save = True

        ns2 = input(ns2_prompt)
        if not ns2:
            break
        try:
            ipaddr.IPAddress(ns2, version=4)
        except ValueError:
            print(_("Invalid nameserver"))
            break
        gc.gc_nameserver2 = ns2

        ns3 = input(ns3_prompt)
        if ns3:
            try:
                ipaddr.IPAddress(ns3, version=4)
            except ValueError:
                print(_("Invalid nameserver"))
                break
            gc.gc_nameserver3 = ns3
        break

    if not need_save:
        return False

    print(_("Saving DNS configuration:"), end=' ')
    try:
        gc.save()
    except Exception:
        print(_("Failed"))
        return False
    print(_("ok"))

    print(_("Reloading network config:"), end=' ')
    try:
        with Client() as c:
            c.call('service.reload', 'networkgeneral')
    except Exception:
        print(_("Failed"))
        return False
    print(_("ok"))
    return True

def shell(*args):
    """Start a root login shell and return the ``os.system`` exit status."""
    command = "/usr/bin/su -l root"
    status = os.system(command)
    return status


def automatic_ix_alert(*args):
    """Toggle automatic support alerts to iXsystems after a y/n confirmation.

    Uses the Support row with the highest id.  Returns True if the setting
    was flipped and saved, False if the user declined.
    """
    support = Support.objects.order_by('-id')[0]
    state = _("enabled") if support.enabled else _("disabled")
    print(_("Automatic support alerts to iXsystems: %s") % state)

    # ``support.enabled`` only changes right before returning, so the
    # question can be chosen once.
    question = _("Disable (y/n): ") if support.enabled else _("Enable (y/n): ")
    while True:
        answer = input(question).lower()
        if answer.startswith("y"):
            support.enabled = not support.enabled
            support.save()
            return True
        if answer.startswith("n"):
            return False


def reboot(*args):
    """Reboot the machine after a y/n confirmation.

    On confirmation, runs ``/sbin/shutdown -r now`` and then sleeps 60 seconds
    (presumably so the menu is not redrawn while the system goes down).
    Always returns False.
    """
    while True:
        answer = input(_("Confirm reboot (y/n): ")).lower()
        if answer.startswith("n"):
            return False
        if answer.startswith("y"):
            os.system("/sbin/shutdown -r now")
            time.sleep(60)
            return False

def shutdown(*args):
    """Power the machine off after a y/n confirmation.

    Runs ``shutdown -P`` on Linux and ``shutdown -p`` elsewhere, then sleeps
    60 seconds.  Always returns False.
    """
    poweroff_flag = 'P' if IS_LINUX else 'p'
    while True:
        answer = input(_("Confirm Shutdown (y/n): ")).lower()
        if answer.startswith("n"):
            return False
        if answer.startswith("y"):
            os.system(f"/sbin/shutdown -{poweroff_flag} now")
            time.sleep(60)
            return False


def show_ip():
    """Print the web UI URLs reported by the middleware.

    When the middleware returns no URLs, a hint to check the network
    configuration is printed instead.  Errors from the middleware client
    propagate to the caller.
    """
    with Client() as client:
        ui_urls = client.call('system.general.get_ui_urls')

    print()
    if ui_urls:
        print(_("The web user interface is at:"))
        print()
        for ui_url in ui_urls:
            print(ui_url)
    else:
        print(_("The web interface could not be accessed."))
        print(_("Please check network configuration."))
    print()


def netcli_title():
    """Return the console banner ``"<product> | <version> | <serial>"``.

    Values come from the middleware.  If any lookup fails, the remaining
    fields stay empty; in the worst case the banner is just ``" | "``.
    """
    product = serial = version = ''

    try:
        with Client() as client:
            dmi = client.call('system.dmidecode_info')
            product = dmi['system-product-name']
            serial = dmi['system-serial-number']
            # The separator travels with the version so that an unknown
            # version does not leave a dangling " | ".
            version = client.call('system.version') + ' | '
    except Exception:
        pass

    return '{} | {}{}'.format(product, version, serial)


def main_menu():
    """Run the interactive console setup menu (never returns normally).

    If ``sys.argv[1]`` names one of the menu action functions, that action is
    called with ``sys.argv[2:]`` and the process exits with status 0 without
    showing the menu.  Any other name falls through to the interactive menu.

    On non-FreeNAS systems (where the failover status is queried), each pass
    first checks ``/tmp/.failover_needop``.  That file lists one pool name
    per line.  If neither this node nor, when reachable, its peer reports
    MASTER, the user is offered to enter passphrases for those pools, which
    are passed to ``failover.unlock``.
    """
    menu = [
        [ _("Configure Network Interfaces"), configure_interface_stub ],
        [ _("Configure Link Aggregation"), configure_lagg_interface ],
        [ _("Configure VLAN Interface"), configure_vlan ],
        [ _("Configure Default Route"), configure_default_route ],
        [ _("Configure Static Routes"), configure_static_routes ],
        [ _("Configure DNS"), configure_dns ],
        [ _("Reset Root Password"), reset_root_pw ],
        [ _("Reset Configuration to Defaults"), reset_factory_defaults ],
        [ _("Shell"), shell ],
        [ _("Reboot"), reboot ],
        [ _("Shut Down"), shutdown],
    ]

    failover_status = 'SINGLE'
    try:
        with Client() as c:
            support_available = c.call('support.is_available')
            is_freenas = c.call('system.is_freenas')
            if not is_freenas:
                failover_status = c.call('failover.status')
    except Exception:
        support_available = False
        is_freenas = True

    if not is_freenas and support_available:
        menu.insert(12, [ _("Toggle automatic support alerts to iXsystems"), automatic_ix_alert ])

    # Resolve a command-line action name once.  ``.get`` makes an unknown
    # name fall through to the interactive menu instead of raising KeyError.
    requested = globals().get(sys.argv[1]) if len(sys.argv) > 1 else None

    menu_map = {}
    menu_max = 0
    for item in menu:
        menu_max = menu_max + 1
        menu_map[menu_max] = item
        # If this was requested on the command line, then we just call it
        if requested is not None and requested == item[1]:
            item[1](*sys.argv[2:])
            sys.exit(0)

    while True:
        if not is_freenas and os.path.exists('/tmp/.failover_needop'):
            # Currently we can have multiple pools.  One encrypted pool
            # failing to import does not mean this node is supposed/allowed
            # to unlock, so only offer it when no node is MASTER.
            passphrase = (failover_status != 'MASTER')
            if passphrase:
                try:
                    with Client() as c:
                        passphrase &= (
                            c.call('failover.call_remote', 'failover.status') != 'MASTER'
                        )

                except Exception:
                    pass

            with open('/tmp/.failover_needop', 'r') as f:
                failed_pools = [p.strip() for p in f.read().strip().split('\n') if p.strip()]

            if passphrase and failed_pools:
                yes = input('Enter passphrase for pools which failed to decrypt ? (y/n)')
                if yes.lower().startswith('y'):
                    cache = []
                    for pool in failed_pools:
                        key = None
                        while not key:
                            try:
                                key = getpass.getpass(f'Please enter passphrase for "{pool}" pool: ')
                            except EOFError:
                                print('Please provide a valid key')
                            else:
                                if not key:
                                    print('Provided key is empty, please provide a valid key')
                        cache.append({'name': pool, 'passphrase': key})
                    print('Forcing current node to become MASTER')

                    with Client() as c:
                        c.call('failover.unlock', {'pools': cache})

                    time.sleep(5)
                    continue

        if not is_freenas:
            print(netcli_title())

        print()
        print(_("Console setup"))
        print("-------------")
        print()

        for index in menu_map:
            print("%d) %s" % (index, menu_map[index][0]))

        # Showing the URLs is best-effort; the menu must work without it.
        try:
            show_ip()
        except Exception:
            pass

        try:
            ch = int(input(_("Enter an option from 1-%d: ") % (menu_max)))
        except ValueError:
            ch = None
        if ch in menu_map:
            f = menu_map[ch][1]
            try:
                # These actions stay available while failover is enabled.
                if f not in [reset_root_pw, reset_factory_defaults, shell, reboot, shutdown]:
                    ensure_failover_is_disabled()
                f()
            except FailoverIsEnabledException:
                print("You can't perform this action unless failover is administratively disabled")
                print("Press Enter to return to console")
                input()


#
# SIGINT is ignored so Ctrl-C cannot break out of the console menu.  Apart
# from that there is no signal handling: the script is assumed to be started
# from /etc/ttys (or an equivalent getty setup), so on a fatal error we just
# print the traceback and exit.
#
if __name__ == '__main__':
    signal.signal(signal.SIGINT, signal.SIG_IGN)

    while True:
        try:
            console = True
            # Only when launched directly by init (parent PID 1) does the
            # ``consolemenu`` setting decide whether the menu is shown.
            if os.getppid() == 1:
                try:
                    with Client() as c:
                        console = c.call('system.advanced.config')['consolemenu']
                except Exception:
                    pass
            if console:
                main_menu()
            else:
                show_ip()
                if IS_LINUX:
                    proc = subprocess.run(['/usr/bin/tty'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
                    if proc.returncode:
                        print(f'Failed to determine active tty: {proc.stderr.decode()}')
                        os.execv('/bin/login', ['/bin/login'])
                    else:
                        terminal = os.path.relpath(proc.stdout.decode().strip(), '/dev')
                        os.execv('/sbin/agetty', ['/sbin/agetty', '--noclear', '--keep-baud', terminal])
                else:
                    os.execv('/usr/libexec/getty', ['/usr/libexec/getty', 'Pc'])
        # SystemExit (from sys.exit in main_menu) is a BaseException and
        # propagates untouched; only real errors are reported here.
        except Exception:
            traceback.print_exc()
            sys.exit(1)
