# message_storage.py
#
# Copyright 2024 Christopher Talbot
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later

import contextlib
import gettext
import hashlib
import json
import os
import sqlite3
import time
from datetime import datetime, timezone

from gi.repository import Adw
from gi.repository import Gtk

import gtk_meshtastic_client.utils as utils
import gtk_meshtastic_client.message_row as message_row
import gtk_meshtastic_client.chat_page as chat_page
import gtk_meshtastic_client.node_row as node_row
import gtk_meshtastic_client.channels_page as channels_page


"""
Sometimes an ACK arrives so late that it cannot plausibly belong to the
message it names: it could alias an older message with the same msg_id. To
avoid marking the wrong message as delivered, only accept an ACK for a
message whose stored time is within the last 30 seconds.
"""
# In seconds. update_to_delivered_in_database() only marks a message delivered
# when its stored time is newer than (now - delivery_timeout).
delivery_timeout = 30

"""
I want a way to ID a channel based on both its name and its AES key.
However, I don't want to store the AES key in the database, since that is a
liability. A cryptographic hash is one-way, so I concatenate the channel
name, an underscore, and the base64 AES key, and use the SHA-256 hash of the
result as the ID. That way, the only way to recover the AES key from the ID
is to brute-force the key anyway.

As an example, the default LongFast channel will be:
<title>_<base64 key>, i.e. LongFast_AQ== , and the hash is:
399df22f29297e5dd9b23d8ddc229bb6e64f2d05365b5a478cf6527dc302652c

The ID is returned as a 64-character lowercase hex string (256 bits), not an int.
"""

def create_channel_id(title, psk_base64):
    """Return a key-dependent identifier for a channel.

    Both arguments are passed through str(), joined as "<title>_<psk_base64>",
    UTF-8 encoded and hashed with SHA-256. The result is the hex digest
    (64 lowercase hex characters), so the key itself does not appear in the ID.
    """
    material = "_".join((str(title), str(psk_base64)))
    return hashlib.sha256(material.encode('utf-8')).hexdigest()

def find_node_by_id(num, list_box, list_box_children):
    """Return the first row of ``list_box`` whose ``num`` equals ``num``.

    Rows are fetched in order with ``list_box.get_row_at_index`` for indices
    0 .. list_box_children - 1. Returns None if none of them match.
    """
    candidates = (list_box.get_row_at_index(index) for index in range(list_box_children))
    return next((candidate for candidate in candidates if num == candidate.num), None)

class MessageStorage():
    """SQLite-backed storage for text messages, the direct-message list and
    favorite nodes of one connected node.

    Every database method opens its own short-lived connection through
    ``_cursor``. That helper commits when the method body completes (including
    an early ``return``) and always closes the connection.
    """

    # Placeholder until __init__ computes the per-node database path.
    database_name = 'invalid'
    # Schema version; it is embedded in the database file name.
    version = 0
    # Maximum number of messages fetched per load_messages() call.
    load_limit = 20

    # Parameters: (shortName, longName, publicKey, lastHeard, num).
    _FAVORITE_UPDATE_SQL = "UPDATE favorite_nodes SET shortName = ? , longName = ? , publicKey = ? , lastHeard = ? WHERE num = ? "

    # One UPDATE per DM attribute that can be refreshed from a node.
    _DM_FIELD_UPDATES = {
        'longName': "UPDATE direct_message_list SET longName = ? WHERE dm_id = ?",
        'shortName': "UPDATE direct_message_list SET shortName = ? WHERE dm_id = ?",
        'publicKey': "UPDATE direct_message_list SET publicKey = ? WHERE dm_id = ?",
    }

    @contextlib.contextmanager
    def _cursor(self):
        """Yield a cursor on a fresh connection to ``self.database_name``.

        The connection is committed only if the ``with`` body finishes without
        raising. The cursor and connection are closed in every case, so
        uncommitted changes are discarded when an exception propagates.
        """
        con = sqlite3.connect(self.database_name)
        try:
            cur = con.cursor()
            try:
                yield cur
                con.commit()
            finally:
                cur.close()
        finally:
            con.close()

    @staticmethod
    def _merge_favorite_fields(node, stored):
        """Combine a node's live attributes with its stored favorite_nodes row.

        ``stored`` is a favorite_nodes row (num, longName, shortName,
        publicKey, lastHeard). Any attribute missing from ``node`` falls back to
        the stored value. ``lastHeard`` only replaces the stored value when it
        is greater than 0.

        Returns (shortName, longName, publicKey, lastHeard), matching the
        parameter order of _FAVORITE_UPDATE_SQL.
        """
        long_name = getattr(node, 'longName', stored[1])
        short_name = getattr(node, 'shortName', stored[2])
        public_key = getattr(node, 'publicKey', stored[3])
        heard = getattr(node, 'lastHeard', 0)
        last_heard = heard if heard > 0 else stored[4]
        return short_name, long_name, public_key, last_heard

    def _refresh_dm_field(self, cur, dm_id, node, field, stored_value):
        """Persist ``node``'s value of ``field`` for DM ``dm_id`` if it changed.

        Returns ``(changed, value)``. When ``node`` has attribute ``field`` and
        it differs from ``stored_value``, the direct_message_list row is updated
        and ``(True, new_value)`` is returned. Otherwise the result is
        ``(False, stored_value)``.
        """
        if hasattr(node, field):
            current = getattr(node, field)
            if current != stored_value:
                cur.execute(self._DM_FIELD_UPDATES[field], (current, dm_id, ))
                return True, current
        return False, stored_value

    def update_favorite_node(self, node_row):
        """Refresh the stored favorite_nodes row for ``node_row``.

        Does nothing if ``node_row.num`` is not stored as a favorite.
        """
        with self._cursor() as cur:
            stored = cur.execute("SELECT * FROM favorite_nodes WHERE num = ? ", (node_row.num, )).fetchone()
            if stored is None:
                return
            cur.execute(self._FAVORITE_UPDATE_SQL,
                        (*self._merge_favorite_fields(node_row, stored), stored[0], ))

    def retrieve_favorite_nodes(self, list_row, list_row_children):
        """Apply stored favorites to the node list.

        For each favorite_nodes row, the first ``list_row_children`` rows of
        ``list_row`` are searched with find_node_by_id:
        - if a match is found, it is toggled to favorite and its current
          attributes are written back to the database;
        - otherwise a NodeRow is built from the stored values and added to the
          active window's nearby-nodes page.
        """
        with self._cursor() as cur:
            stored_rows = cur.execute("SELECT * FROM favorite_nodes").fetchall()
            for stored in stored_rows:
                existing = find_node_by_id(stored[0], list_row, list_row_children)
                if existing:
                    existing.toggle_favorite_node(False)
                    cur.execute(self._FAVORITE_UPDATE_SQL,
                                (*self._merge_favorite_fields(existing, stored), stored[0], ))
                else:
                    app = Gtk.Application.get_default()
                    win = Gtk.Application.get_active_window(app)
                    new_node = node_row.NodeRow()
                    new_node.add_favorite_node_attributes(stored[0], stored[1], stored[2], stored[3], stored[4])
                    new_node.toggle_favorite_node(False)
                    win.nearby_nodes_page_bin.add_node_to_nearby_nodes(new_node)

    def remove_node_from_database(self, num):
        """Delete the favorite_nodes row(s) for node ``num``."""
        with self._cursor() as cur:
            cur.execute("DELETE FROM favorite_nodes WHERE num = ? ", (num, ))

    def add_node_to_database(self, num, longName, shortName, publicKey, lastHeard):
        """Insert a favorite_nodes row.

        Falsy names default to "", a falsy public key to "MA==", and a falsy
        lastHeard to 0.
        """
        # favorite_nodes(num, longName, shortName, publicKey, lastHeard)
        values = (num, longName or "", shortName or "", publicKey or "MA==", lastHeard or 0)
        with self._cursor() as cur:
            cur.execute("INSERT INTO favorite_nodes VALUES(?, ?, ?, ?, ?)", values)

    def delete_message(self, msg_id, timestamp):
        """Delete text message(s) matching ``msg_id`` and ``timestamp``.

        Returns the channel_title of the last matching row, or None when no
        row matched.
        """
        channel_id = None
        with self._cursor() as cur:
            matches = cur.execute("SELECT channel_title FROM text_messages WHERE msg_id = ?  AND time = ? ", (msg_id, timestamp, )).fetchall()
            if matches:
                channel_id = matches[-1][0]
            cur.execute("DELETE FROM text_messages WHERE msg_id = ?  AND time = ? ", (msg_id, timestamp, ))
        return channel_id

    def delete_dm(self, dm_id):
        """Remove ``dm_id`` from direct_message_list."""
        with self._cursor() as cur:
            cur.execute("DELETE FROM direct_message_list WHERE dm_id = ? ", (dm_id, ))

    def clear_history(self, channel_id):
        """Delete every text message stored for ``channel_id``."""
        with self._cursor() as cur:
            cur.execute("DELETE FROM text_messages WHERE channel_title = ? ", (channel_id, ))

    def get_message_count(self, channel_id):
        """Return the number of text messages stored for ``channel_id``."""
        with self._cursor() as cur:
            (count, ) = cur.execute("SELECT COUNT(*) FROM text_messages WHERE channel_title = ?", (channel_id, )).fetchone()
        return count

    def load_messages(self, nav_page, channel_id, offset):
        """Add one page of ``channel_id`` messages to ``nav_page``, newest first.

        Reads at most ``load_limit`` rows ordered by time descending, skipping
        the first ``offset``. Each row is wrapped in a MessageRow and passed to
        ``nav_page.add_message``.
        """
        # text_messages(msg_id, from_id, from_short_name, from_long_name, to_id,
        #               direction, text, time, delivered, channel_title)
        with self._cursor() as cur:
            for row in cur.execute("SELECT * FROM text_messages WHERE channel_title = ? ORDER BY time DESC LIMIT ? OFFSET ?", (channel_id, self.load_limit, offset, )):
                chat_row = message_row.MessageRow()
                chat_row.set_message_id(row[0])
                chat_row.set_short_name(row[2])
                chat_row.set_long_name(row[3])
                chat_row.set_message_direction(row[5])
                chat_row.set_message_content(row[6])
                chat_row.set_date(row[7])
                chat_row.update_delivery(row[8])

                nav_page.add_message(chat_row)

    def update_to_delivered_in_database(self, request_id):
        """Mark the recent message(s) with msg_id ``request_id`` as delivered.

        Only rows whose time is newer than now minus ``delivery_timeout``
        seconds are updated. This keeps a late ACK from aliasing an older
        message with the same id.

        Returns the channel_title of the last matching row. Returns None when
        ``request_id`` is not an int or no recent row matched.
        """
        # For extra assurance, just make sure that the request_id is an int.
        if not isinstance(request_id, int):
            self.logger.warning("request_id not int!")
            return None

        cutoff = int(datetime.now(timezone.utc).timestamp()) - delivery_timeout

        channel_id = None
        with self._cursor() as cur:
            cur.execute("UPDATE text_messages SET delivered = true WHERE msg_id = ? AND time > ? ", (request_id, cutoff, ))
            matches = cur.execute("SELECT channel_title FROM text_messages WHERE msg_id = ? AND time > ? ", (request_id, cutoff, )).fetchall()
            if matches:
                channel_id = matches[-1][0]
        return channel_id

    def add_sent_message_to_database(self, packet, interface, from_id, short_name, long_name, unix_timestamp, to_address, message_text, channel_id):
        """Store an outgoing text message, initially not delivered."""
        new_message = (
            packet.id,                       # msg_id
            from_id,                         # from_id (int form; nodes are looked up by this int id)
            short_name,                      # from_short_name
            long_name,                       # from_long_name
            to_address,                      # to_id
            utils.MsgDirection.Out.value,    # direction
            message_text,                    # text
            unix_timestamp,                  # time
            False,                           # delivered
            channel_id,                      # channel_title
        )
        with self._cursor() as cur:
            cur.execute("INSERT INTO text_messages VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", new_message)

    def add_received_message_to_database(self, packet, interface, short_name, long_name, channel_id):
        """Store an incoming text message from the ``packet`` dict.

        The text defaults to "no_text" when ``packet["decoded"]["text"]`` is
        missing or empty. The time uses ``packet["rxTime"]`` unless it is
        absent or 0, in which case the current UTC timestamp is used.
        """
        rx_time = packet.get("rxTime", 0)
        if rx_time == 0:
            rx_time = int(datetime.now(timezone.utc).timestamp())

        new_message = (
            packet["id"],                                       # msg_id
            packet["from"],                                     # from_id (int form; nodes are looked up by this int id)
            short_name,                                         # from_short_name
            long_name,                                          # from_long_name
            packet["to"],                                       # to_id
            utils.MsgDirection.In.value,                        # direction
            packet.get("decoded", {}).get("text") or "no_text", # text
            rx_time,                                            # time
            False,                                              # delivered
            channel_id,                                         # channel_title
        )
        with self._cursor() as cur:
            cur.execute("INSERT INTO text_messages VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", new_message)

    def add_direct_message(self, num, longName, shortName, publicKey):
        """Insert a direct_message_list row.

        Falsy names default to "" and a falsy public key to "MA==".
        """
        # direct_message_list(dm_id, longName, shortName, publicKey)
        values = (num, longName or "", shortName or "", publicKey or "MA==")
        with self._cursor() as cur:
            cur.execute("INSERT INTO direct_message_list VALUES(?, ?, ?, ?)", values)

    def update_direct_message(self, node_to_check):
        """Sync a stored DM's names and key from ``node_to_check`` and its GUI row.

        Does nothing if the node has no stored DM or the channels page has no
        matching DM row. Otherwise, each changed attribute is written to the
        database and pushed to the DM row:
        - longName via set_channel_title;
        - shortName via set_index;
        - publicKey via set_public_key.
        """
        app = Gtk.Application.get_default()
        win = Gtk.Application.get_active_window(app)

        with self._cursor() as cur:
            # See if this node has a DM; if not, don't bother trying to update.
            stored = cur.execute("SELECT * FROM direct_message_list WHERE dm_id= ? ", (node_to_check.num, )).fetchone()
            if stored is None:
                return

            found_dm = win.channels_page_bin.find_dm_by_id(node_to_check.num)
            if not found_dm:
                return

            dm_id = node_to_check.num
            changed, long_name = self._refresh_dm_field(cur, dm_id, node_to_check, 'longName', stored[1])
            if changed:
                found_dm.set_channel_title(long_name, "")

            changed, short_name = self._refresh_dm_field(cur, dm_id, node_to_check, 'shortName', stored[2])
            if changed:
                found_dm.set_index(short_name)

            changed, public_key = self._refresh_dm_field(cur, dm_id, node_to_check, 'publicKey', stored[3])
            if changed:
                found_dm.set_public_key(public_key)

    def populate_direct_messages(self, interface):
        """Create a DM entry for every stored direct_message_list row.

        If the matching nearby node is known, its current names and key are
        written back to the database first. This covers nodes that were
        renamed or whose names were learned later.
        """
        app = Gtk.Application.get_default()
        win = Gtk.Application.get_active_window(app)

        with self._cursor() as cur:
            # direct_message_list(dm_id, longName, shortName, publicKey)
            stored_rows = cur.execute("SELECT * FROM direct_message_list").fetchall()
            for row in stored_rows:
                dm_id, long_name, short_name, public_key = row[0], row[1], row[2], row[3]

                node_to_check = win.nearby_nodes_page_bin.find_node_by_id(dm_id)
                if node_to_check:
                    _, long_name = self._refresh_dm_field(cur, dm_id, node_to_check, 'longName', long_name)
                    _, short_name = self._refresh_dm_field(cur, dm_id, node_to_check, 'shortName', short_name)
                    _, public_key = self._refresh_dm_field(cur, dm_id, node_to_check, 'publicKey', public_key)

                node_row.create_direct_message(dm_id, long_name, short_name, public_key, False, False)

    def __init__(self, interface):
        """Open the per-node message database, creating or migrating it as needed.

        The database file is named after the connected node's hex id and
        ``version``, and lives in utils.meshtastic_config_dir. Missing tables
        are created. A direct_message_list table without a publicKey column
        gets one, filled with 'MA=='.
        """
        app = Gtk.Application.get_default()

        self.logger = app.logger

        node_id = utils.idToHex(interface.myInfo.my_node_num)
        self.database_name = utils.meshtastic_config_dir + '/' + node_id + '_messages_v' + str(self.version) + '.db'

        if not os.path.exists(self.database_name):
            self.logger.debug("Creating new database")

        with self._cursor() as cur:
            # IF NOT EXISTS also repairs a database file that lacks a table.
            cur.execute("CREATE TABLE IF NOT EXISTS text_messages(msg_id, from_id, from_short_name, from_long_name, to_id, direction, text, time, delivered, channel_title)")
            cur.execute("CREATE TABLE IF NOT EXISTS direct_message_list(dm_id, longName, shortName, publicKey)")

            # Older databases were created without the publicKey column.
            try:
                cur.execute("SELECT publicKey FROM direct_message_list")
            except sqlite3.OperationalError:
                cur.execute("ALTER TABLE direct_message_list ADD publicKey")
                cur.execute("UPDATE direct_message_list SET publicKey = 'MA=='")

            cur.execute("CREATE TABLE IF NOT EXISTS favorite_nodes(num, longName, shortName, publicKey, lastHeard)")
