#!/usr/bin/python3

#
# playbook_tuner.py
#
# Copyright (C) 2021 by University of Southern California
# Written by ASM Rizvi<asmrizvi@usc.edu>
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License,
# version 2, as published by the Free Software Foundation.
# 
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# 
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 59 Temple Place, Suite 330, Boston, MA 02111-1307, USA.
#


import argparse
import operator
import sys

# Playbook loaded by readSetupFile: policy name -> tab-separated per-site
# traffic shares for that policy (findSolution divides each share by 100,
# i.e. treats it as a percentage of the total load).
policyToDistribution = {}
# Working set of candidate policies used by findSolution: rejected policies
# are blanked to "" and "Baseline" is replaced by the last suggested policy.
# checkAttack sets it from policyToDistribution before the first attempt.
foundPolicies = {}
# Number of anycast sites (first line of the setup file).
noOfSites = 1
# Site names in column order (second line of the setup file).
sites = []
# Per-site load limits, kept as strings, same order as `sites` (third line).
limits = []


class Program:
    """Command-line front end; parses the arguments once on construction.

    The parsed ``argparse.Namespace`` is stored on ``self.arg``.
    """

    def __init__(self):
        self.arg = self.parse_args()

    def parse_args(self):
        """Build the argument parser and parse ``sys.argv``.

        Returns:
            argparse.Namespace with ``setup`` (path of the setup/playbook
            file) and ``retries`` (int: how many mitigation attempts are made
            before the problem is handed to the operator).
        """
        parser = argparse.ArgumentParser(
            description = 'Given the anycast site list, capacities, and BGP playbook, this program suggests the best BGP configuration during a DDoS attack.',
            # The epilog is hand-formatted (numbered steps, tab-separated
            # sample file); keep its line breaks instead of re-wrapping it.
            formatter_class=argparse.RawDescriptionHelpFormatter,
            epilog="""

1. Build the playbook from a given file, 
2. Find out the sites that require reduction in load, 
3. Find out the policies that will reduce traffic from the overloaded sites, 
4. Find out the policy that will keep the maximum loaded site with a minimum load, 
5. If some other sites overloaded, pick a method that will reduce traffic from the 
previously overwhelmed sites, but give least possible load to the newly overloaded sites, 
6. If the previous overloaded sites are still overloaded, select other rules 
that will reduce more traffic, 
7. If no method works, escalate the problem to the operators

# Input: setup/playbook directory, and number of retries
# Output: suggested BGP change
# playbook structure:
# No. of sites
# site-1	site-2	site-3
# limit-1	limit-2	limit-3
# bgp_change_1	site-1-traffic	site-2-traffic	site-3-traffic
# bgp_change_2	site-1-traffic	site-2-traffic	site-3-traffic
# bgp_change_3	site-1-traffic	site-2-traffic	site-3-traffic

# Sample setup/playbook: given playbook.txt file
# Sample load distribution: 85000.2895	16000.1707	16978.5398
# Sample output: You have these sites: 3
Your playbook has 22 policies
Overloaded site: AMS
Suggested config: 1AMS,	Estimated load distribution: 41292.65	29494.75	41292.65
Other configs: Poison-Tier-1,	Estimated load distribution: 41292.65	29494.75	41292.65
Other configs: Poison-Tier-2,	Estimated load distribution: 41292.65	29494.75	41292.65

Command: 
Giving input about an anycast setup, playbook, and attack to find the best BGP config:

        playbook_tuner
        playbook_tuner --setup "playbook.txt"
        playbook_tuner --setup "playbook.txt" --retries 3

Output: Best BGP config.

        """)

        parser.add_argument('--setup', '-s', default='playbook.txt', help='Enter the setup file with sites, capacities, and BGP configurations.')
        # Validate as an int up front; callers still apply int() to it,
        # which is a no-op on an int.
        parser.add_argument('--retries', '-r', type=int, default=3, help='Number of mitigation attempts before the problem is escalated to the operator.')

        args = parser.parse_args()

        return args
        
def readSetupFile(file):
    """Load the anycast setup and BGP playbook from *file* into module globals.

    Expected layout (tab-separated fields, one record per line):
        record 1: number of sites
        record 2: site names
        record 3: per-site load limits
        rest:     <policy name> <tab> <site-1 share> <tab> <site-2 share> ...

    Side effects: rebinds ``noOfSites`` (int), ``sites`` and ``limits``
    (lists of strings) and adds one ``policyToDistribution[policy]`` entry per
    policy record, whose value is the tab-joined per-site shares.

    Blank lines are ignored, so a trailing empty line does not register an
    empty-named policy and does not shift the header records.
    """
    global noOfSites
    global sites
    global limits

    # `with` guarantees the file is closed even if parsing fails later.
    with open(file, 'r') as handle:
        records = [line.strip() for line in handle if line.strip()]

    for count, line in enumerate(records):
        if count == 0:
            noOfSites = int(line)
        elif count == 1:
            sites = line.split("\t")
        elif count == 2:
            limits = line.split("\t")
        else:
            fields = line.split("\t")
            policyToDistribution[fields[0]] = "\t".join(fields[1:]).strip()

def checkAttack(retries):
    """Read observed per-site loads from stdin and react to each overload.

    Each stdin line holds tab-separated observed loads, one per site, in the
    same order as ``sites``. Blank lines and lines containing "#" are skipped.

    For every remaining line:
    - a warning is printed when the summed load exceeds the summed limits;
    - if no site is above its limit, "No overloaded site!" is printed;
    - otherwise ``findSolution`` is called. Once ``retries`` attempts have
      been made, the next overloaded line prints "Send the problem to the
      operator." and the process exits with status 0.

    Args:
        retries: maximum number of mitigation attempts; any value ``int()``
            accepts (the CLI may pass it as a string).
    """
    global foundPolicies

    numberOfTries = 0
    for line in sys.stdin:
        # Lines read from a stream keep their trailing "\n", so test the
        # stripped text; otherwise a blank line reaches float("") and crashes.
        if not line.strip() or "#" in line:
            continue

        observed = line.strip().split("\t")
        siteToExcessLoad = {}
        total = 0.0
        totalLimit = 0.0

        for j in range(noOfSites):
            site = sites[j]
            load = float(observed[j])
            limit = float(limits[j])
            total += load
            totalLimit += limit
            if load > limit:
                siteToExcessLoad[site] = load - limit

        if total > totalLimit:
            print("Total load exceeds total limit. It's probably best to absorb.")

        if not siteToExcessLoad:
            print("No overloaded site!")
            continue

        if numberOfTries >= int(retries):
            print("Send the problem to the operator.")
            sys.exit(0)

        if numberOfTries == 0:
            # Work on a copy: findSolution blanks rejected policies and
            # rewrites "Baseline", which must not alter the loaded playbook.
            foundPolicies = dict(policyToDistribution)
        findSolution(siteToExcessLoad, total, totalLimit)
        numberOfTries += 1
       


def findSolution(siteToExcessLoad, total, totalLimit):
    """Print the best BGP policy (plus viable alternatives) for an overload.

    Operates on the module-level ``foundPolicies`` working set:

    1. Every policy whose share at some overloaded site is not strictly
       lower than the "Baseline" share at that site is blanked to "" (and
       therefore also ignored on later retries).
    2. Each remaining policy's per-site shares (percentages) are scaled by
       ``total`` to estimate per-site load. Policies are ranked by their
       smallest headroom (limit - estimated load), largest first.
    3. The best-ranked non-"Baseline" policy is printed as the suggestion
       and copied into ``foundPolicies["Baseline"]`` for the next retry.
       Following policies are printed as alternatives (this can include
       "Baseline" itself) until one would push a site over its limit. If
       the top-ranked policy already overloads a site, "No method works!"
       is printed instead.

    Args:
        siteToExcessLoad: overloaded site name -> excess load; only the
            keys (site names) are used.
        total: total observed load, used to turn shares into loads.
        totalLimit: sum of all site limits; accepted for interface
            compatibility but not used here.

    Raises:
        KeyError: if an overloaded site is given and ``foundPolicies`` has
            no "Baseline" entry.
    """
    if not foundPolicies:
        # Nothing to rank: stop here instead of failing on the missing
        # "Baseline" entry below.
        print("No routing policy can mitigate!")
        return

    _discardNonReducingPolicies(siteToExcessLoad)
    estimated, loads, ranked = _rankPolicies(total)
    _reportRankedPolicies(ranked, estimated, loads)


def _discardNonReducingPolicies(siteToExcessLoad):
    """Blank out policies that do not strictly lower each overloaded site's Baseline share."""
    for key in siteToExcessLoad:
        print("Overloaded site: " + key)
        baseParts = foundPolicies["Baseline"].strip().split("\t")
        # float() so fractional percentages are accepted (int() rejected them).
        baseline = [float(baseParts[j]) for j in range(noOfSites)]
        columns = [j for j in range(noOfSites) if sites[j] == key]

        for policy, distribution in list(foundPolicies.items()):
            if distribution == "" or policy == "Baseline":
                continue
            shares = distribution.strip().split("\t")
            if any(float(shares[j]) >= baseline[j] for j in columns):
                foundPolicies[policy] = ""


def _rankPolicies(total):
    """Estimate per-site loads of the surviving policies and rank them by headroom.

    Returns:
        (estimated, loads, ranked): ``estimated`` maps policy -> tab-joined
        load string for printing, ``loads`` maps policy -> list of float
        loads, and ``ranked`` is a list of (policy, min headroom) pairs
        sorted largest headroom first.
    """
    estimated = {}
    loads = {}
    headroom = {}
    for policy, value in foundPolicies.items():
        if value == "":
            continue
        shares = value.strip().split("\t")
        siteLoads = [total * (float(shares[j]) / 100.0) for j in range(noOfSites)]
        loads[policy] = siteLoads
        estimated[policy] = "\t".join(str(load) for load in siteLoads)
        headroom[policy] = min(float(limits[j]) - siteLoads[j] for j in range(noOfSites))

    # Sort once after every policy is scored; the sort is stable, so ties
    # keep playbook order.
    ranked = sorted(headroom.items(), key=operator.itemgetter(1), reverse=True)
    return estimated, loads, ranked


def _reportRankedPolicies(ranked, estimated, loads):
    """Print the suggestion and alternatives, stopping at the first overloading policy."""
    suggested = False
    for policy, _ in ranked:
        overloads = any(loads[policy][j] > float(limits[j]) for j in range(noOfSites))
        if overloads:
            if not suggested:
                print("No method works! It's probably better to absorb.")
            break

        if not suggested:
            # "Baseline" is the reference configuration, so it is never the suggestion.
            if policy == "Baseline":
                continue
            print("Suggested config: " + policy + ",\tEstimated load distribution: " + estimated[policy])
            # The suggestion becomes the reference point for the next retry.
            foundPolicies["Baseline"] = policyToDistribution[policy]
            suggested = True
        else:
            print("Other configs: " + policy + ",\tEstimated load distribution: " + estimated[policy])
            
def main():
    """Entry point: parse the CLI, load the playbook, then process stdin loads."""
    options = Program().arg
    readSetupFile(options.setup)

    # Module globals are only read here, so no `global` declarations are needed.
    siteCount = len(sites)
    policyCount = len(policyToDistribution)
    print("You have these sites: " + str(siteCount))
    print("Your playbook has " + str(policyCount) + " policies")

    checkAttack(options.retries)

    
if __name__ == "__main__":
    # Run the tool, then exit explicitly with status 0.
    main()
    raise SystemExit(0)
