#!/usr/bin/python3

#
# babackup_server
#
# Copyright (C) 2024 by John Heidemann <johnh@isi.edu>
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License,
# version 2, as published by the Free Software Foundation.
# 
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# 
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 59 Temple Place, Suite 330, Boston, MA 02111-1307, USA.
#


import argparse
import sys
import os
import os.path
import re
import shutil
import tempfile
import datetime
import subprocess
import logging
import pdb
# pdb.set_trace()

import yaml


class Program:
    """Server side of babackup: configure, roll, and age client backups.

    Construction does all the work: parse the command line, then either
    define a new backup (when --new-path is given) or check the existing
    configured backups.
    """
    def __init__(self):
        # if assertion_fails:
        #     raise Exception("Assertion failed")
        # or better sys.exit("Assertion failed")
        self.parse_args()
        # --new-path selects "define a new backup" mode;
        # otherwise we check (and possibly roll and age) existing backups
        if self.new_path is not None:
            self.configure_new_backup()
        else:
            self.check_backups()

    def verbose_log(self, s, status = 'info'):
        """Report S to the user (when --verbose) and to the log.

        STATUS names the logging severity ('info' or 'warn'/'warning');
        previously it was accepted but ignored, so the 'warn' passed by
        check_backup_archive() was logged at info level.
        """
        if self.verbose:
            print(s)
        # map the status name to a logging level; unknown names fall
        # back to INFO rather than crashing on a bad caller
        level = logging.getLevelName(status.upper())
        if not isinstance(level, int):
            level = logging.INFO
        logging.log(level, s)

    def parse_args(self):
        """Define and parse the command line, recording options on self.

        Returns the parsed argparse namespace (also kept as attributes).
        """
        parser = argparse.ArgumentParser(description = 'backup things via rsync to a remote server from a client', epilog="""
babackup

        """)
        parser.add_argument('--conf', '-c', help='use configuration FILE.yaml (default is ~/.config/babackup/server.yaml or /etc/babackup/server.yaml)')
        parser.add_argument('--name', '-N', help='use backup NAME, or define new backup NAME')
        parser.add_argument('--near-to-keep', help='how many near (daily) backups to keep', type=int, default=0)
        parser.add_argument('--far-to-keep', help='how many far (weekly) backups to keep', type=int, default=0)
        parser.add_argument('--new-path', help='location of a new backup')
        parser.add_argument('--new-pub-key', help='public ssh key for new backup')
        parser.add_argument('--new-mode', help='new backup is rrsync or ssh')
        parser.add_argument('--debug', '-d', help='debugging mode', action='store_true', default=False)
        parser.add_argument('--verbose', '-v', action='count')
        args = parser.parse_args()
        # record the options on self for the rest of the program
        self.debug, self.verbose = args.debug, args.verbose
        self.conf_path, self.name = args.conf, args.name
        self.near_to_keep, self.far_to_keep = args.near_to_keep, args.far_to_keep
        self.new_path, self.new_pub_key = args.new_path, args.new_pub_key
        self.new_mode = args.new_mode
        # fixed policy default and scratch state
        self.DEFAULT_TO_KEEP = 10
        self.temp_dir = None
        return args


    def read_conf(self):
        """Figure out what configuration file we're using, then read it.

        Sets self.conf_dir, self.conf_path (if not given on the command
        line), and self.conf, and configures logging.  A missing or
        unreadable configuration file yields an empty configuration.
        """
        # where?  root uses the system directory, everyone else their own
        self.conf_dir = "/etc/babackup"
        if os.getuid() != 0:
            self.conf_dir = os.path.expanduser("~") + "/.config/babackup"
        # what?
        if self.conf_path is None:
            self.conf_path = self.conf_dir + "/server.yaml"
        # read it (with closes the handle; the old code leaked it)
        try:
            with open(self.conf_path, 'r') as conf_stream:
                self.conf = yaml.safe_load(conf_stream)
        except IOError:
            self.conf = dict()
        if self.conf is None:
            # an empty yaml file parses as None, not {}
            self.conf = dict()
        if self.conf.get('backups') is None:
            # a list, to match configure_new_backup()'s append and
            # check_backups()'s iteration (the old dict() default crashed
            # the first configure_new_backup on a fresh config)
            self.conf['backups'] = []
        # also set up logging
        logging_path = None
        logging_conf = self.conf.get('logging')
        if logging_conf is not None and logging_conf.get("filename") is not None:
            # was assigned to a misspelled variable ("longging_path"),
            # so the configured filename never took effect
            logging_path = logging_conf["filename"]
        if logging_path is None:
            # was missing the "/", yielding ".../babackupserver.log"
            logging_path = self.conf_dir + "/server.log"
        logging.basicConfig(filename  = logging_path)

    def write_conf(self):
        """Persist the current configuration back to its YAML file."""
        conf_stream = open(self.conf_path, 'w+')
        try:
            yaml.dump(self.conf, conf_stream)
        finally:
            conf_stream.close()


    def configure_new_backup(self):
        """configure a new backup
Update configuration files, generate keys, say what to do on the server, etc.

        Requires --name, --new-mode, and --new-path; rrsync mode also
        requires --new-pub-key.  Creates the target directory tree, adds
        the key to ~/.ssh/authorized_keys (rrsync mode), and records the
        backup in the configuration file.  --debug reports but does not act.
        """
        self.read_conf()
        if self.name is None:
            sys.exit("babackup: attempt to add new backup without specifying --name")
        if self.new_mode is None:
            sys.exit("babackup: attempt to add new backup without specifying --new-mode=MODE (rrsync or ssh)")
        mode = self.new_mode
        if mode not in ('rrsync', 'ssh'):
            # previously any string was silently accepted and stored
            sys.exit("babackup_server: --new-mode must be rrsync or ssh")
        if self.new_path is None:
            sys.exit("babackup: attempt to add new backup without specifying --new-path=server/partial/or/full/path")
        if mode == 'rrsync' and self.new_pub_key is None:
            sys.exit("babackup_server: attempt to add new rrsync backup without specifying --new-pub-key='ssh-foo BASE64 keyname'")

        # and some sanity checking, since we're going to put stuff in authorized_keys
        if re.search(r"\s", self.new_path):
            sys.exit("babackup_server: rejecting --new-path that contains whitespace")
        # guard for None: in ssh mode a public key is optional
        # (the old code dereferenced it unconditionally and crashed)
        if self.new_pub_key is not None and self.new_pub_key.find("\n") >= 0:
            sys.exit("babackup_server: rejecting --new-pub-key that contains newline")

        backup = dict()
        backup["name"] = self.name
        backup["path"] = self.new_path
        backup["mode"] = mode
        if mode == 'rrsync':
            backup["pub-key"] = self.new_pub_key

        #
        # create the target directory tree
        # (current/last/data is where the client's rsync lands)
        #
        full_new_path = self.new_path
        if full_new_path[0] == '~':
            full_new_path = os.path.expanduser(full_new_path)
        if not os.path.isdir(full_new_path):
            if self.verbose:
                print(f"babackup_server: mkdir {full_new_path}")
            if not self.debug:
                os.mkdir(full_new_path)
                os.mkdir(full_new_path + "/current")
                os.mkdir(full_new_path + "/current/last")
                os.mkdir(full_new_path + "/current/last/data")

        #
        # change authorized_keys: restrict the key to write-only rrsync
        # into this backup's current/ directory
        #
        if mode == 'rrsync':
            auth = 'command="/usr/local/bin/rrsync -wo ' + self.new_path + '/current",no-agent-forwarding,no-port-forwarding,no-pty,no-user-rc,no-X11-forwarding ' + self.new_pub_key
            if self.verbose:
                print(f"babackup_server: adding public key to ~/.ssh/authorized_keys, with rrsync\n\t{auth}")
            if not self.debug:
                auth_path = os.path.expanduser("~") + "/.ssh/authorized_keys"
                with open(auth_path, "a") as auth_stream:
                    auth_stream.write(auth + "\n")

        # record it in the configuration; backups are kept as a list
        # (read_conf may have left an empty dict default here, which the
        # old append crashed on)
        backups = self.conf.get('backups')
        if not isinstance(backups, list):
            backups = []
            self.conf['backups'] = backups
        backups.append(backup)
        if self.debug:
            return
        self.write_conf()


        
    def check_backup(self, backup):
        """check one backup with configuration BACKUP"""

        name = backup.get("name")
        if name is None:
            sys.exit("babackup: backup is missing 'name:'")
        path = backup.get("path")
        if path is None:
            sys.exit(f"babackup_server: backup {name} has no path:")
        if path[0] == "~":
            path = os.path.expanduser(path)

        #
        # see if a this one finished a backup since last check
        #
        # 0. no begin: never started
        # 1. begin no end => in progress (or failed)
        # 2. begin and end, but end before begin => missed it and it's running again
        # 3. begin and end, but end after begin => good!
        #

        # no backup
        if not os.path.exists(f"{path}/current/begin"):
            self.verbose_log(f"babackup_server: {name} backup idle")
            return

        if not os.path.exists(f"{path}/current/end"):
            self.verbose_log(f"babackup_server: {name} backup is active")
            return

        begin_mtime = os.path.getmtime(f"{path}/current/begin")
        end_mtime = os.path.getmtime(f"{path}/current/end")
        if begin_mtime > end_mtime:
            self.verbose_log(f"babackup_server: {name} backup was complete but is active again")
            return

        #
        # commit!
        #
        # (Note that there is a test-to-use race going on here :-( )
        #
        self.verbose_log(f"babackup_server: {name} is rolling current to last")
        if not os.path.exists(f"{path}/new"):
            os.mkdir(f"{path}/new")
        os.rename(f"{path}/current", f"{path}/new/last")
        os.rename(f"{path}/new", f"{path}/current")

        #
        # move the old last into archive
        #
        # Note that we trust the file mtime rather than contents,
        # since the contents came from the user.
        #
        # Only do this if that backup looks good (has begin and end files).
        # Otherwise we let that last linger, eventually to be garbage collected
        # when current/last goes away
        #
        if os.path.isdir(f"{path}/current/last/last") and os.path.exists(f"{path}/current/last/last/begin") and os.path.exists(f"{path}/current/last/last/end"):
            self.verbose_log(f"babackup_server: {name} is moving last/last to archive")
            last_begin_mtime = os.path.getmtime(f"{path}/current/last/last/begin")
            last_isotime = datetime.datetime.fromtimestamp(last_begin_mtime, datetime.timezone.utc).isoformat(timespec = 'minutes')
            # get rid of : in time, to be more filename friendly
            last_isotime = last_isotime.replace(":", "_")
            if not os.path.isdir(f"{path}/archive"):
                os.mkdir(f"{path}/archive")
            if os.path.isdir(f"{path}/archive/{last_isotime}"):
                # xxx: we have two things with the same time, so the rename will fail.
                # So just skip it.
                return
            os.rename(f"{path}/current/last/last", f"{path}/archive/{last_isotime}")
            # we want to come back to this archive and check it later
            backup['archive_check'] = True
        
            
    def check_backup_archive(self, backup):
        """check the archive of one backup (configuration BACKUP) for outdated entries"""

        name = backup.get("name")
        if name is None:
            sys.exit("babackup: backup is missing 'name:'")
        path = backup.get("path")
        if path[0] == "~":
            path = os.path.expanduser(path)
        if path is None:
            sys.exit(f"babackup_server: backup {name} has no path:")
        if not os.path.isdir(f"{path}/archive"):
            self.verbose_log(f"babackup_server: {name} check has no archive ({path}/archive)")
            return

        #
        # walk the archives
        # finding things to delete (ending ~)
        # and to date check.
        #
        to_remove = []
        to_date_check = []
        part_to_timestamp = dict()
        timestamp_to_part = dict()

        iso_matcher = re.compile(r"^\d{4}-?\d{2}-?\d{2}([tT]\d{2}[-:_]?\d{2})?")
        for part in os.listdir(f"{path}/archive"):
            if part.endswith("~"):
                to_remove.append(part)
            elif iso_matcher.match(part):
                to_date_check.append(part)
                # undo our prior : removal
                clean_part = part.replace("_", ":")
#                # seems to require seconds, also, for fromisofromat with timezone to work
#                if len(clean_part) == 22:
#                    clean_part = clean_part[:16] + ":00" + clean_part[16:]
                part_to_timestamp[part] = datetime.datetime.fromisoformat(clean_part).timestamp()
                timestamp_to_part[part_to_timestamp[part]] = part
            else:
                self.verbose_log(f"babackup_server: {name} check finds suprising file {part} in {path}/archive", "warn")

        #
        # apply the priorization algorithm
        #
        if self.near_to_keep > 0:
            near_to_keep = self.near_to_keep
        else:
            near_to_keep = backup.get("near_to_keep", self.DEFAULT_TO_KEEP)
        if self.far_to_keep > 0:
            far_to_keep = self.far_to_keep
        else:
            far_to_keep = backup.get("far_to_keep", self.DEFAULT_TO_KEEP)

        if len(to_date_check) < near_to_keep + far_to_keep:
            self.verbose_log(f"babackup_server: not enough archives to require aging")
            return

        newest_to_oldest_timestamps = sorted(part_to_timestamp.values(), reverse=True)
        keeping_what = 'near'
        count_to_keep = near_to_keep
        distance_too_close = 23*60*60   # near is just less than 1 day
        last_timestamp_kept = None
        timestamps_to_keep = []
        timestamps_to_retire = []
        for timestamp in newest_to_oldest_timestamps:
            if count_to_keep > 0 and ((last_timestamp_kept is None) or (timestamp - last_timestamp_kept > distance_too_close)):
                # keep this timestamp
                self.verbose_log(f"babackup_server: keep {keeping_what} " + timestamp_to_part[timestamp])
                timestamps_to_keep.append(timestamp)
                last_timestamp_kept = timestamp
                count_to_keep -= 1
                if count_to_keep <= 0:
                    if keeping_what == 'near':
                        keeping_what = 'far'
                        count_to_keep = far_to_keep
                        distance_too_close = 23*60*60 * 7
                    elif keeping_what == 'far':
                    	keeping_what = 'no_more'
                    else:
                        sys.exit("babackup_server: internal error, got past no_more to keep")
            else:
                # too close, so drop it
                self.verbose_log(f"babackup_server: aging {keeping_what} " + timestamp_to_part[timestamp])
                timestamps_to_retire.append(timestamp)

        #
        # take action
        #
        if self.debug:
            return
        path_to_remove = []
        # first rename them (easy)
        for timestamp in timestamps_to_retire:
            part = timestamp_to_part[timestamp]
            self.verbose_log(f"babackup_server: {name} will retire {path}/archive/{part}")
            os.rename(f"{path}/archive/{part}", f"{path}/archive/{part}~")
            path_to_remove.append(f"{path}/archive/{part}~")
        # then remove them
        for old_path in path_to_remove:
            self.verbose_log(f"babackup_server: {name} is removing tree {old_path}")
            shutil.rmtree(old_path)
            
        
    def check_backups(self):
        """check all backups (or whatever was specified with -N)"""
        self.read_conf()
        # first pass: roll any freshly-finished backups
        for entry in self.conf['backups']:
            if self.name is not None and self.name != entry['name']:
                continue
            self.check_backup(entry)
        #
        # Now that we've checked each, go back and look at aging.
        #
        for entry in self.conf['backups']:
            if entry.get("archive_check"):
                self.check_backup_archive(entry)


                
if __name__ == '__main__':
    # Program's constructor does all the work: parse arguments, then
    # either configure a new backup or check the existing ones.
    Program()
    sys.exit(0)

