#!/usr/bin/python3

#
# babackup
#
# Copyright (C) 2024 by John Heidemann <johnh@isi.edu>
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License,
# version 2, as published by the Free Software Foundation.
# 
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# 
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 59 Temple Place, Suite 330, Boston, MA 02111-1307, USA.
#


import argparse
import sys
import os
import os.path
import subprocess
import tempfile
import datetime
import time
import logging
import re
import hashlib
import atexit
import base64
# from systemd import journal
import pdb
# pdb.set_trace()

import yaml


class Program:
    """Client-side driver for babackup.

    Constructing the object does all the work: parse the command line,
    then either set up a new backup (--new-path) or run the configured
    backups.
    """

    __version__ = '1.3'

    def __init__(self):
        self.parse_args()
        # --new-path switches us into one-time setup mode; otherwise we
        # run every memorized backup.
        if self.new_path is None:
            self.run_backups()
        else:
            self.configure_new_backup()

    def verbose_log(self, s, status = 'info'):
        """our common logging function, both to stdout and the log"""
        if self.verbose:
            print(s)
        logging.info(s)


    def configure_logging(self):
        """set up formal logging, as given in the config"""
        logging_path = None
        logging_conf = self.conf.get('logging')
        if logging_conf is not None and logging_conf.get("filename") is not None:
            logging_path = self.conf['logging']['filename']
        if logging_path is None:
            logging_path = self.conf_dir + "/client.log"
        logging.basicConfig(filename  = logging_path, level = 'INFO', format="%(asctime)s: %(message)s")


    def parse_args(self):
        parser = argparse.ArgumentParser(description = 'backup things via rsync to a remote server from a client', epilog="""
babackup

Primary use case:

        babackup

will read the configuration file and backup each memorized tree.

To select a specific tree, use -N name.

To begin backing up a new tree, do:

        babackup --new-path=/local/path/to/tree/root --new-server=user@server:server/full/or/user/path/to/destination

which will generate a new ssh key and show the command to run on the server
to set it up.

Babackup avoids concurrent runs on backups with the name name.
However, it will override apparently stale runs, or force it with -f -f.

        """)
        # see https://docs.python.org/2/library/argparse.html
        #  ArgumentParser.add_argument(name or flags...[, action][, nargs][, const][, default][, type][, choices][, required][, help][, metavar][, dest])

        #  parser.add_argument('--focus', help='focus on a given TARGET', choices=['us', 'nynj', 'coverage'], default='us')
        #  parser.add_argument('--output', '-o', help='output FILE')
        #  parser.add_argument('--duty-cycle', help='duty cycle (a float)', type=float)
        #  parser.add_argument('--type', '-t', choices=['pdf', 'png'], help='type of output (pdf or png)', default = 'pdf')
        #  parser.add_argument('--day', type=int, help='day to plot', default = None)
        parser.add_argument('--name', '-N', help='use backup NAME, or define new backup NAME')
        parser.add_argument('--conf', '-c', help='use configuration FILE.yaml (default: ~/.config/babackup/client.yaml or /etc/babackup/client.yaml)')
        parser.add_argument('--new-path', help='create a new backup configuration for PATH')
        parser.add_argument('--new-mode', help='a new backup should use mode rrsync or ssh', default='unspecified')
        parser.add_argument('--new-server', '-s', help='a new backup will go to USER@SERVER:PATH')
        parser.add_argument('--new-interval', help='the desired minimum backup interval, in seconds (in minutes, hours or days with m, h d)')
        parser.add_argument('--new-exclude-from', help='a new backup will exclude from FILE')
        parser.add_argument('--new-keyfile', help='a new backup will use this ssh keyfile (default: generate one)')
        parser.add_argument('--new-pass', help='the passphrase for the keyfile (default: generate one)')
        parser.add_argument('--force', '-f', help='force a backup run, ignoring any minimum interval', action='store_true', action=count, default=0)
        parser.add_argument('--debug', '-d', help='debugging mode', action='store_true', default=False)
        parser.add_argument('--verbose', '-v', action='count', default=0)
        args = parser.parse_args()
        self.debug = args.debug
        self.verbose = args.verbose
        self.conf_path = args.conf
        self.new_path = args.new_path
        self.new_userserverpath = args.new_server
        self.new_interval = args.new_interval
        self.name = args.name
        self.force = args.force
        self.new_exclude_from = args.new_exclude_from
        self.new_keyfile = args.new_keyfile
        self.new_pass = args.new_pass
        self.new_mode = args.new_mode
        self.temp_dir = None
        return args

    
    def read_remember(self):
        """read our other config file remembering when we backed stuff up"""
        #
        # and load our memory
        #
        self.remember_path = re.sub(r'.yaml$', r'_remember.yaml', self.conf_path)
        try:
            remember_stream = open(self.remember_path, 'r')
            self.remember = yaml.safe_load(remember_stream)
        except IOError:
            self.remember = dict()
        if self.remember.get('backup_times') is None:
            self.remember['backup_times'] = dict()
        self.remember_changed = False


    def write_remember(self):
        """write a changed remember file in the conf_dir"""
        if not self.remember_changed:
            return
        if not os.path.isdir(self.conf_dir):
            os.makedirs(self.conf_dir, mode=0o755)
        with open(self.remember_path, 'w+') as remember_stream:
            yaml.dump(self.remember, remember_stream)
        self.remember_changed = False


    def read_conf(self):
        """figure out what configuration file we're using, then read and return it
Also sets up logging."""
        # where?
        self.conf_dir = "/etc/babackup"
        if os.getuid() != 0:
            self.conf_dir = os.path.expanduser("~") + "/.config/babackup"
        # what?
        if self.conf_path is None:
            self.conf_path = self.conf_dir + "/client.yaml"
        # read it
        try:
            conf_stream = open(self.conf_path, 'r')
            self.conf = yaml.safe_load(conf_stream)
        except IOError:
            self.conf = dict()
        if self.conf.get('backups') is None:
            self.conf['backups'] = dict()
        self.read_remember()
        # also set up logging
        logging_path = None
        logging_conf = self.conf.get('logging')
        if logging_conf is not None and logging_conf.get("filename") is not None:
            longging_path = self.conf['logging']['filename']
        if logging_path is None:
            logging_path = self.conf_dir + "client.log"
        logging.basicConfig(filename  = logging_path)


    def write_conf(self):
        """Store a (presumably changed) client config, creating the config dir if needed."""
        if not os.path.isdir(self.conf_dir):
            os.makedirs(self.conf_dir, mode=0o755)
        # Use an opener so the file is created with restrictive (0600)
        # permissions -- it can contain an encoded key passphrase.
        # Pass the flags computed from the 'w+' mode straight through:
        # the original hard-coded O_WRONLY|O_CREAT, dropping O_TRUNC, so
        # writing a shorter config left stale bytes from the old one.
        def _restrictive_open(path, flags):
            return os.open(path, flags, mode=0o600)
        with open(self.conf_path, 'w+', opener=_restrictive_open) as conf_stream:
            yaml.dump(self.conf, conf_stream)


    def create_ssh_keyfile(self, path, passphrase=None):
        """Generate a new ssh keyfile at PATH, protected by PASSPHRASE.

        If PASSPHRASE is None or empty, a random one is generated.
        Records PATH in self.new_keyfile and returns the "bpass": the
        passphrase base64-encoded so it is not stored in cleartext in
        the config.  Exits on ssh-keygen failure.

        (Fixed: the parameter was named `pass`, a Python keyword, and
        the body referenced undefined names instead of PATH.)
        """
        #
        # generate a passphrase, if necessary
        #
        if passphrase is None or passphrase == '':
            # 24 random bytes, base64'd into printable text
            with open("/dev/urandom", 'rb') as ur_stream:
                ur = ur_stream.read(24)
            passphrase = base64.b64encode(ur).decode('ascii')

        # now the key; the comment is the key's basename so the server
        # side can tell which backup it belongs to
        result = subprocess.run(['ssh-keygen', '-t', 'ed25519',
                                 '-C', os.path.basename(path),
                                 '-N', passphrase, '-f', path])
        # remember what key we used
        self.new_keyfile = path
        if result.returncode != 0:
            sys.exit("babackup: ssh-keygen failed with code " + str(result.returncode))
        #
        # finally compute the bpass
        # (the passphrase, base64 encoded so it's not cleartext in the conf)
        #
        bpass = base64.b64encode(passphrase.encode()).decode('ascii')
        return bpass
    

    def configure_new_backup(self):
        """Configure a new backup.

        Update configuration files, generate keys, say what to do on the
        server, etc.  Exits with a message on any invalid combination of
        --new-* options.
        """
        self.read_conf()

        if self.new_path is None:
            sys.exit("babackup: attempt to add new backup without specifying --new-path=/client/new/path")

        if self.name is None:
            # default name: a short hash of the path (hashing needs bytes)
            m = hashlib.sha256()
            m.update(self.new_path.encode())
            self.name = m.hexdigest()[0:16]
        name = self.name
        if len(list(filter(lambda backup: backup.get('name') == name, self.conf['backups']))) > 0:
            sys.exit(f"babackup: attempt to add new backup named {name} that already exists")

        mode = self.new_mode
        if not (mode == 'rrsync' or mode == 'ssh'):
            sys.exit("babackup: must select --new-mode=ssh or --new-mode=rrsync for a new backup")

        # our new baby
        backup = dict()
        backup["name"] = name
        backup["path"] = self.new_path
        backup["mode"] = mode

        if self.new_userserverpath is None:
            sys.exit("babackup: a new backup needs --new-server=USER@SERVER:PATH")
        if mode == 'rrsync':
            # because we're going to be using rrsync, only the server
            # knows the server-side path
            m = re.match(r'^([^:]*:)(.*)$', self.new_userserverpath)
            if m is None:
                sys.exit(f"babackup: cannot parse --new-server={self.new_userserverpath}")
            userserver_only = m.group(1)
            server_side_path = m.group(2)
            if server_side_path[0] != '/':
                server_side_path = "~/" + server_side_path
            backup["userserverpath"] = userserver_only
            backup["server_side_path"] = server_side_path
        elif mode == 'ssh':
            server_side_path = backup["userserverpath"] = self.new_userserverpath
        else:
            sys.exit(f"babackup: internal error, bad mode {mode}")

        if self.new_exclude_from is not None:
            backup["exclude_from"] = self.new_exclude_from

        if self.new_interval is not None:
            # NN optionally followed by a scale letter; a bare number is
            # seconds, per the --new-interval help text
            m = re.match(r'^(\d+)([a-z]?)$', self.new_interval)
            if m is None:
                sys.exit(f"babackup: cannot parse --new-interval={self.new_interval}")
            v = int(m.group(1))
            scale = m.group(2)
            if scale == '' or scale == 's':
                pass
            elif scale == 'm':
                v *= 60
            elif scale == 'h':
                v *= 60*60
            elif scale == 'd':
                v *= 24*60*60
            else:
                sys.exit(f"babackup: cannot parse scale on --new-interval={self.new_interval} (I accept s/m/h/d)")
            backup["interval"] = v

        public_key_info = ''
        if self.new_keyfile is None and mode == 'rrsync':
            # no keyfile, so make one:
            # ssh-keygen -t ed25519 -C 'babackup-{name}-{date}' -N pass -f path
            now_isodate = datetime.datetime.now(tz = datetime.timezone.utc).date().isoformat()
            keyfilename = f"babackup-{name}-{now_isodate}"
            keyfile_path = os.path.expanduser("~") + f"/.ssh/{keyfilename}"
            if os.path.exists(keyfile_path):
                self.verbose_log(f"babackup: reusing existing ssh key in {keyfile_path}")
            else:
                self.verbose_log(f"babackup: generating new public key to {keyfile_path}")
                # pass the user's --new-pass along (it was dropped before)
                bpass = self.create_ssh_keyfile(keyfile_path, self.new_pass)
                backup['bpass'] = bpass
            self.new_keyfile = keyfile_path
        if self.new_keyfile is not None:
            backup['keyfile'] = self.new_keyfile

            # and read the public side so we can hand it to the server
            with open(f"{self.new_keyfile}.pub", "r") as keyfile_stream:
                lines = keyfile_stream.readlines()
                if len(lines) != 1:
                    sys.exit(f"babackup: cannot parse public key in {self.new_keyfile}")
                public_key_info = " --new-pub-key='" + lines[0].rstrip() + "'"

        #
        # inform the user what to do on the server
        #
        print("\nto complete configuration of babackup, run this command on the server:")
        print(f"babackup_server --name={name} --new-mode={mode} --new-path={server_side_path}{public_key_info}\n")

        self.conf['backups'].append(backup)
        if self.debug:
            return
        self.write_conf()
        
        
    def write_temp_file(self, filename, contents):
        """write CONTENTS to FILENAME in a (possibly new) temp directory"""
        if self.temp_dir is None:
            self.temp_dir = tempfile.TemporaryDirectory()
        temp_file_path = self.temp_dir.name + "/" + filename
        with open(temp_file_path , "w+") as tf:
            tf.write(contents)
            tf.close()
        # let IOErrors propagate
        return temp_file_path
    
        
    def run_rsync(self, args):
        """run rsync, successfully, with ARGLINE"""
        args.insert(0, '/usr/bin/rsync')
        self.verbose_log(" ".join(args))
        if self.debug:
            return
        result = subprocess.run(args)
        if result.returncode != 0:
            sys.exit("rsync with " + " ".join(args) + " failed with code " + str(result.returncode))

    def backup_run_completes(self, sentinel_path):
        """remove a sentinel file to indicate we're no longer running"""
        if os.path.exists(sentinel_path):
            os.unlink(sentinel_path)

    def backup_already_running(self, backup):
        """check if a BACKUP is already running"""
        sanitized_name = backup['name']
        santizied_name = re.sub(r'[^\w\s-]', '', sanitized_name).strip("-_ \t\n")
        # sigh race
        sentinel_path = f"{self.conf_dir}/on.{sanitized_name}"
        if os.path.exists(sentinel_path):
            if time.time() - os.path.getmtime(sentinel_path) < 24*60*60:
                return True
            else:
                self.verbose_log(f"babackup: overriding old sentinel {sentinel_path}")
        with open(sentinel_path, "a+") as sen_stream:
            sen_stream.write(datetime.datetime.now(tz = datetime.timezone.utc).isoformat() + "\n")
        atexit.register(lambda: self.backup_run_complets(sentinel_path))
        return False

    def run_backup(self, backup):
        """Run one backup with configuration BACKUP (a dict from the config file).

        Skips the backup if another run holds the sentinel (unless
        forced twice) or if the last run is more recent than the
        configured interval (unless forced at least once).
        """
        name = backup.get("name")
        if name is None:
            sys.exit("babackup: backup is missing 'name:'")
        if backup.get('path') is None:
            sys.exit(f"babackup: backup {name} is missing path")
        mode = backup.get('mode', "ssh")

        #
        # are we already running it?
        #
        if self.backup_already_running(backup):
            if self.force <= 1:
                self.verbose_log(f"babackup: skipping backup {name} that is in progress")
                return
            self.verbose_log(f"babackup: overriding skip of backup {name} that is in progress")

        #
        # do we need to run it?
        # (fixed: stray mis-indented else/pass clauses here were a
        # SyntaxError, and the interval was read under the wrong key)
        #
        now_timestamp = datetime.datetime.now(datetime.timezone.utc).timestamp()
        if self.force > 0:
            self.verbose_log(f"babackup: forcing backing up {name}")
        elif self.remember.get('backup_times') is None or self.remember['backup_times'].get(name) is None:
            self.verbose_log(f"babackup: no record of backing up {name} before")
        elif now_timestamp - float(self.remember['backup_times'][name]) > backup.get("interval", 23*60*60):
            # configure_new_backup stores the interval under "interval"
            self.verbose_log(f"babackup: need to do fresh backup of {name}")
        else:
            self.verbose_log(f"babackup: no need for backup of {name}, prior is recent")
            return

        #
        # figure out ssh
        #
        rsync_rsh = None
        userserver_sep = ''
        server_path = '.'
        if backup.get('userserverpath') is not None:
            (userserver, server_path) = backup.get('userserverpath').split(":")
            if userserver != 'localhost':
                userserver_sep = userserver + ":"
                rsync_rsh = "ssh"
                if backup.get('keyfile'):
                    rsync_rsh += f" -i {backup['keyfile']} "
                # rsync reads RSYNC_RSH from the environment
                os.environ['RSYNC_RSH'] = rsync_rsh
            else:
                # localhost: a plain path, with ~ expanded locally
                server_path = os.path.expanduser(server_path)
        if rsync_rsh is not None:
            self.verbose_log(f"RSYNC_RSH={rsync_rsh}")

        #
        # we're going to begin: record the start time on the server
        #
        begin_time = datetime.datetime.now(datetime.timezone.utc)
        begin_file_path = self.write_temp_file('begin', begin_time.isoformat(timespec='seconds') + "\n")
        self.run_rsync([begin_file_path, f"{userserver_sep}{server_path}/."])

        #
        # first, create the "data" dir (so we know it exists)
        #
        os.mkdir(f"{self.temp_dir.name}/data")
        self.run_rsync([f"{self.temp_dir.name}/data", f"{userserver_sep}{server_path}/."])

        #
        # now the real work
        #
        rsync_args = ['-aHb', '-x']
        if backup.get('exclude_from') is not None:
            rsync_args.append("--exclude-from=" + backup['exclude_from'])
        # link-dest is actually a relative path to the actual destination
        if mode == 'rrsync':
            # as a special case, if we're running under rrsync, we have
            # to anchor the link-dest; rrsync unanchors it for us on the
            # server side
            rsync_args.append("--link-dest=/last/data")
        elif mode == 'ssh':
            # not under rrsync, so link dest is above the user destination
            rsync_args.append("--link-dest=../last/data")
        else:
            sys.exit(f"babackup: unknown mode {mode}")
        self.run_rsync(['--delete'] + rsync_args + [backup['path'] + "/.", f"{userserver_sep}{server_path}/data"])

        #
        # and we're done: record the end time
        # (fixed: the 'end' file used to contain begin_time)
        #
        end_time = datetime.datetime.now(datetime.timezone.utc)
        end_file_path = self.write_temp_file('end', end_time.isoformat(timespec='seconds') + "\n")
        self.run_rsync([end_file_path, f"{userserver_sep}{server_path}/."])

        self.remember['backup_times'][name] = str(begin_time.timestamp())
        self.remember_changed = True
        self.write_remember()  # sync after every backup, in case of failure

        
    def run_backups(self):
        """run all backups (or whatever was specified with -N)"""
        self.read_conf()
        for backup in self.conf['backups']:
            if (self.name is None or self.name == backup['name']) and backup.get("enabled", True):
                self.run_backup(backup)
        self.write_remember()



def _main() -> int:
    """Script entry point: constructing Program does all the work."""
    Program()
    return 0


if __name__ == '__main__':
    sys.exit(_main())

