# By: Riasat Ullah
# This file contains all billing related functions.

from analytics.live_call_analytics import LiveCallAnalytics
from dbqueries import db_accounts, db_billings, db_live_call_routing
from billings.bill import Bill, BillItem, DiscountItem
from utils import constants, helpers, var_names
import calendar
import configuration as configs
import datetime
import math


class BillManager(object):

    def __init__(self, conn, timestamp, billing_year, billing_month, with_billable_accounts=None,
                 with_billing_info=None, with_subscriptions=None, with_organization_cards=None,
                 with_discounts=None, with_exchange_rates=None, with_vat_rates=None, perform_duplication_check=True,
                 call_routing_logs=None, with_active_credits=None, with_tax_exemptions=None,
                 with_external_communication_details=None, with_monitoring_checks=None):
        '''
        Constructor. Every "with_..." argument (and call_routing_logs) lets the caller supply data
        that would otherwise be retrieved from the database; pass None to have it retrieved here.
        The live call routing country rates are always retrieved from the database.
        :param conn: (DBConn object) db connection
        :param timestamp: (datetime.datetime) timestamp when this request is being made
        :param billing_year: (int) the year this bill is for
        :param billing_month: (int) the month this bill is for
        :param with_billable_accounts: (dict of dict) keyed by org id; each entry is read by the
            subscriptions, stakeholder count and live call routing count keys
        :param with_billing_info: (dict) -> {org id: {iso code: .., currency: ..}, ...}
        :param with_subscriptions: (dict of dict) -> {sub id: {...}, ...}
        :param with_organization_cards: (dict) -> {org id: card id, ...}
        :param with_discounts: (dict) of tuples -> {org id: (discount id, discount percent), ...}
        :param with_exchange_rates: (dict) -> { (from curr, to curr): rate, ... }
        :param with_vat_rates: (dict) -> { iso code: VAT rate, ... }
        :param perform_duplication_check: (boolean) if True runs the avoid_duplication() function
        :param call_routing_logs: (dict of list of lists) call routing logs from the billing period -> {org ID: [[], ]}
        :param with_active_credits: credits looked up by org id when the bills are created
            (presumably {org id: [{type: , name: , amount: , used: , ...}, ...]} - TODO confirm)
        :param with_tax_exemptions: (list of org IDs) -> [org_id, ...]
        :param with_external_communication_details: (dict of dict) ->
            { org_id: {'SEND EXTERNAL EMAIL': ..., 'SEND EXTERNAL SMS': {ISO code 1: , ..} .. } ..}
        :param with_monitoring_checks: (dict) -> {org_id: count, ...}
        '''
        self.conn = conn
        self.timestamp = timestamp
        self.billing_year = billing_year
        self.billing_month = billing_month
        self.billing_start, self.billing_end = self.get_billing_period()

        if with_billable_accounts is None:
            with_billable_accounts = db_billings.get_billable_accounts(
                self.conn, self.billing_start, self.billing_end)
        self.billable_accounts = with_billable_accounts

        self.organization_ids = list(self.billable_accounts.keys())

        # Already billed organizations must be dropped before anything is retrieved for them.
        if perform_duplication_check:
            self.avoid_duplication()

        if with_billing_info is None:
            with_billing_info = db_billings.get_basic_billing_info(
                self.conn, self.billing_start, self.billing_end, self.organization_ids)
        self.billing_info = with_billing_info

        if with_subscriptions is None:
            with_subscriptions = db_accounts.get_subscriptions(self.conn, self.timestamp)
        self.subscriptions = with_subscriptions

        if with_organization_cards is None:
            with_organization_cards = db_billings.get_billable_cards(
                self.conn, self.billing_start, self.billing_end, self.organization_ids)
        self.organization_cards = with_organization_cards

        if with_discounts is None:
            with_discounts = db_billings.get_discounts(
                self.conn, self.timestamp, self.organization_ids,
                discount_type=[constants.all_discount_type, constants.user_plan_discount_type])
        self.discounts = with_discounts

        if with_exchange_rates is None:
            with_exchange_rates = db_billings.get_exchange_rates(self.conn, self.billing_year, self.billing_month)
        self.exchange_rates = with_exchange_rates

        if with_vat_rates is None:
            with_vat_rates = db_billings.get_vat_rates(self.conn, self.timestamp)
        self.vat_rates = with_vat_rates

        if call_routing_logs is None:
            call_routing_logs = db_live_call_routing.get_live_call_routing_analytics_details(
                self.conn, self.billing_start, self.billing_end, org_id=self.organization_ids)
        self.call_routing_logs = call_routing_logs

        self.live_call_country_rates = db_live_call_routing.get_live_call_routing_country_specific_rates(
            self.conn, self.billing_start)

        if with_active_credits is None:
            with_active_credits = db_billings.get_organization_credits(
                self.conn, self.timestamp, self.organization_ids, active=True, split_by_org=True)
        self.active_credits = with_active_credits

        if with_tax_exemptions is None:
            with_tax_exemptions = db_billings.get_tax_exempted_organizations(self.conn, self.timestamp)
        self.tax_exemptions = with_tax_exemptions

        if with_external_communication_details is None:
            with_external_communication_details = db_billings.get_external_communication_details(
                self.conn, self.billing_start, self.billing_end, self.organization_ids)
        self.external_communication_details = with_external_communication_details

        if with_monitoring_checks is None:
            with_monitoring_checks = db_billings.get_monitor_checks_run(
                self.conn, self.billing_start, self.billing_end, self.organization_ids)
        self.monitoring_checks = with_monitoring_checks

    def avoid_duplication(self):
        '''
        Looks up the organizations that already have a bill for the billing year and month and,
        if any of them are in this run, drops them from both the organization ID list and the
        billable accounts so that they are not billed twice. Nothing is changed when there is no overlap.
        '''
        already_billed = set(
            db_billings.get_billed_organization_ids(self.conn, self.billing_year, self.billing_month))
        candidates = set(self.organization_ids)

        repeated = candidates.intersection(already_billed)
        if len(repeated) == 0:
            return

        self.organization_ids = list(candidates.difference(already_billed))
        for org_id in repeated:
            self.billable_accounts.pop(org_id, None)

    def get_billing_period(self):
        '''
        Works out the first and last calendar day of the billing month.
        :return: (tuple of datetime.date) -> (first day of the month, last day of the month)
        '''
        _, last_day = calendar.monthrange(self.billing_year, self.billing_month)
        first_of_month = datetime.date(self.billing_year, self.billing_month, 1)
        end_of_month = datetime.date(self.billing_year, self.billing_month, last_day)
        return first_of_month, end_of_month

    def get_fx_adjusted_value(self, bill_currency, item_currency, item_value):
        '''
        Converts the value of an item into the currency of the bill.
        The value is returned untouched (and unrounded) when both currencies are the same;
        otherwise it is multiplied by the (item currency, bill currency) exchange rate
        and rounded to 2 decimal places.
        :param bill_currency: currency to be billed in
        :param item_currency: currency of the item
        :param item_value: value of the item
        :return: exchange rate adjusted value
        :errors: KeyError when the needed exchange rate is not available
        '''
        if item_currency == bill_currency:
            return item_value

        conversion_rate = self.exchange_rates[(item_currency, bill_currency)]
        return round(item_value * conversion_rate, 2)

    def get_description_and_fee(self, item, adjusted=False):
        '''
        Builds the bill description and the unit price of a subscription item.
        Trials are free. Otherwise the full subscription fee applies when the item counts
        for both halves of the month, and half of it when it only counts for one half.
        :param item: (dict) subscription item
        :param adjusted: (boolean) if True ' - Adjusted' is added to the description (ignored for trials)
        :return: (tuple) -> (description, unit price)
        '''
        description = item[var_names.description]
        full_fee = item[var_names.subscription_fee]

        if item[var_names.is_trial]:
            return description + ' - Free Trial', 0

        in_first_half, in_second_half = self.first_and_second_half_of_month_billing_eligibility(item)
        if in_first_half and in_second_half:
            unit_price = full_fee
            usage_note = ' - Month usage (Full)'
        else:
            # Pro-rate the fee to half when the subscription was only used for half the month
            unit_price = full_fee / 2
            usage_note = ''
            if in_first_half:
                usage_note += ' - Month usage (First half)'
            if in_second_half:
                usage_note += ' - Month usage (Second half)'

        description += usage_note
        if adjusted:
            description += ' - Adjusted'

        return description, unit_price

    def first_and_second_half_of_month_billing_eligibility(self, item):
        '''
        Determines which halves of the billing month a subscription item should be billed for.
        Items that are not eligible for half month billing count for both halves. Eligible items
        count for the first half only if they end on or before the 15th, otherwise the second half only.
        :param item: (dict) subscription item
        :return: (tuple of boolean) -> (bill for first half, bill for second half)
        '''
        mid_month = datetime.date(self.billing_year, self.billing_month, 15)
        if not self.is_eligible_for_half_month_billing(item):
            return True, True

        ends_in_first_half = item[var_names.end_period] <= mid_month
        if ends_in_first_half:
            return True, False
        return False, True

    @staticmethod
    def is_eligible_for_half_month_billing(item):
        '''
        Checks if a subscription item was held for fewer days than the half bill period
        (configs.half_bill_period), in which case it is only billed for half of the month.
        :param item: (dict) subscription item with its start and end period
        :return: (boolean) True if it is; False otherwise
        '''
        held_for = item[var_names.end_period] - item[var_names.start_period]
        return held_for.days < configs.half_bill_period

    def create_user_plan_billing_items(self, org_id):
        '''
        Creates billing items for user plans only. All billing items must be converted
        to the billing currency that the organization chose to be billed in.

        User plans are processed from the highest user count to the lowest. The first and second
        half of the month are tracked separately; for each half calculate_adjusted_allocations()
        gives the user count and the half month fee that are still billable after what earlier
        plans have already covered. When both halves of a plan come out the same, a single item
        is created at twice the half month fee; otherwise one item is created per applicable half.
        :param org_id: ID of the organization getting billed
        :return: (list of BillItem) -> [BillItem 1, BillItem 2, ...]
        :errors: RuntimeError
        '''
        if self.billable_accounts is None:
            raise RuntimeError('Billable accounts have not been provided or retrieved.')
        if org_id not in self.billable_accounts:
            raise RuntimeError('Organization ID was not found in billable accounts - ' + str(org_id))
        # NOTE(review): this checks the size of the organization's account entry,
        # not of its subscriptions list - confirm that is the intent.
        if len(self.billable_accounts[org_id]) == 0:
            raise RuntimeError('Organization account subscriptions are missing. ID - ' + str(org_id))

        bill_currency = self.billing_info[org_id][var_names.billing_currency]
        user_plans = [
            item for item in self.billable_accounts[org_id][var_names.subscriptions]
            if item[var_names.subscription_type] == constants.base_subscription_type
        ]
        user_plans = helpers.sorted_list_of_dict(user_plans, var_names.count, descending=True)

        bill_items = []
        # Running (last user count, max fee) state for each half of the month
        last_u1, max_f1 = 0, 0
        last_u2, max_f2 = 0, 0
        for item in user_plans:
            h1_count, h1_fee, is_h1_adj = 0, 0, False
            h2_count, h2_fee, is_h2_adj = 0, 0, False
            for_h1, for_h2 = self.first_and_second_half_of_month_billing_eligibility(item)
            if for_h1:
                last_u1, max_f1, h1_count, h1_fee, is_h1_adj =\
                    self.calculate_adjusted_allocations(item, last_u1, max_f1)

            if for_h2:
                last_u2, max_f2, h2_count, h2_fee, is_h2_adj =\
                    self.calculate_adjusted_allocations(item, last_u2, max_f2)

            sub_id = item[var_names.subscription_id]
            sub_currency = self.subscriptions[sub_id][var_names.subscription_currency]
            if h1_count == h2_count and h1_fee == h2_fee:
                # Both halves match -> bill once; the fees are per half month so double them
                desc = self.get_description_and_fee(item, adjusted=is_h1_adj)[0]
                unit_price = round(self.get_fx_adjusted_value(bill_currency, sub_currency, h1_fee * 2), 2)
                item_total = round(unit_price * h1_count, 2)
                bill_items.append(BillItem(org_id, sub_id, desc, h1_count, bill_currency, unit_price, item_total))
            else:
                if for_h1:
                    desc = self.get_description_and_fee(item, adjusted=is_h1_adj)[0]
                    unit_price = round(self.get_fx_adjusted_value(bill_currency, sub_currency, h1_fee), 2)
                    item_total = round(unit_price * h1_count, 2)
                    bill_items.append(BillItem(org_id, sub_id, desc, h1_count, bill_currency, unit_price, item_total))
                if for_h2:
                    desc = self.get_description_and_fee(item, adjusted=is_h2_adj)[0]
                    unit_price = round(self.get_fx_adjusted_value(bill_currency, sub_currency, h2_fee), 2)
                    item_total = round(unit_price * h2_count, 2)
                    bill_items.append(BillItem(org_id, sub_id, desc, h2_count, bill_currency, unit_price, item_total))
        return bill_items

    @staticmethod
    def calculate_adjusted_allocations(item, last_u, max_f):
        '''
        Works out how many users and what fee are still billable for a subscription, so that
        nothing is double charged when several subscriptions fall in the same billing period.
        Trials keep their full user count at no fee and leave the running values untouched.
        For the rest, only the part of the fee above the maximum fee charged so far is billable,
        and the fee that is returned is halved.
        :param item: (dict) subscription item
        :param last_u: (int) user count of the last subscription processed
        :param max_f: (double) maximum fee charged for user plans in this billing period up until now
        :return: (tuple) -> (updated last user count, updated max fee, billable user count, billable fee, if adjusted)
        '''
        if item[var_names.is_trial]:
            return last_u, max_f, item[var_names.count], 0, False

        plan_fee = item[var_names.subscription_fee]
        plan_users = item[var_names.count]

        # Only the fee in excess of what has already been charged is billable
        fee_charged_in_full = False
        if plan_fee > max_f:
            billable_fee = plan_fee - max_f
            max_f = plan_fee
            fee_charged_in_full = billable_fee == plan_fee
        else:
            billable_fee = 0

        if plan_users > last_u:
            billable_user_count = plan_users - last_u
            last_u = billable_user_count
        elif billable_fee > 0:
            billable_user_count = plan_users
            last_u = billable_user_count
        else:
            billable_user_count = 0

        # Flagged as adjusted only when both the fee and the user count were reduced
        is_adjusted = not (fee_charged_in_full or billable_user_count == plan_users)

        return last_u, max_f, billable_user_count, billable_fee / 2, is_adjusted

    def create_stakeholder_billing_items(self, org_id):
        '''
        Builds the charge for the stakeholder add-on. Nothing is charged when
        the organization has no stakeholder count or the count is not positive.
        :param org_id: ID of the organization getting billed
        :return: (list of BillItem) -> [BillItem 1] or an empty list
        '''
        stakeholders = self.billable_accounts[org_id][var_names.stakeholder_count]
        if stakeholders is None or not stakeholders > 0:
            return []

        add_on = self.subscriptions[configs.stakeholder_add_on_id]
        currency = self.billing_info[org_id][var_names.billing_currency]
        converted_fee = self.get_fx_adjusted_value(
            currency, add_on[var_names.subscription_currency], add_on[var_names.subscription_fee])
        unit_price = round(converted_fee, 2)
        item_total = round(unit_price * stakeholders, 2)

        return [
            BillItem(org_id, configs.stakeholder_add_on_id, add_on[var_names.description],
                     stakeholders, currency, unit_price, item_total)
        ]

    def create_live_call_routing_billing_items(self, org_id):
        '''
        Builds the live call routing charges of an organization: the add-on subscription, the
        monthly cost of the phone numbers held, incoming call minutes (domestic and international),
        call recording minutes, and any volume discount reported by the usage summary.
        :param org_id: ID of the organization getting billed
        :return: (tuple) -> ([BillItem 1, BillItem 2, ...], [DiscountItem 1, ...])
        '''
        charges, discounts = [], []
        currency = self.billing_info[org_id][var_names.billing_currency]
        account = self.billable_accounts[org_id]

        # Live call routing add-on (for plans that do not have the advanced features);
        # charged on the highest user count found among the add-on's subscription entries
        number_holdings = account[var_names.live_call_routing_count]
        add_on_id = configs.live_call_routing_add_on_id
        add_on_users = max(
            (sub[var_names.count] for sub in account[var_names.subscriptions]
             if sub[var_names.subscription_id] == add_on_id),
            default=0
        )
        if add_on_users > 0:
            add_on = self.subscriptions[add_on_id]
            add_on_fee = self.get_fx_adjusted_value(
                currency, add_on[var_names.subscription_currency], add_on[var_names.subscription_fee])
            add_on_total = round(add_on_users * add_on_fee, 2)
            charges.append(BillItem(org_id, add_on_id, add_on[var_names.description],
                                    add_on_users, currency, add_on_fee, add_on_total))

        # Monthly cost of maintaining phone numbers, priced by country and phone type
        if number_holdings is not None and len(number_holdings) > 0:
            for holding in number_holdings:
                iso = holding[var_names.iso_country_code]
                phone_type = holding[var_names.phone_type]
                quantity = holding[var_names.count]
                country_rates = self.live_call_country_rates[iso]

                numbers_id = configs.tier_1_phone_numbers_supplement_id
                numbers_supp = self.subscriptions[numbers_id]
                label = numbers_supp[var_names.description] + ' - ' + iso

                if phone_type == constants.local:
                    monthly_cost = country_rates[var_names.local_number_rate]
                    label += ' Local'
                elif phone_type == constants.mobile:
                    monthly_cost = country_rates[var_names.mobile_number_rate]
                    label += ' Mobile'
                elif phone_type == constants.toll_free:
                    monthly_cost = country_rates[var_names.toll_free_number_rate]
                    label += ' Toll Free'
                else:
                    # Unrecognized phone type -> fall back on the supplement's own fee
                    monthly_cost = numbers_supp[var_names.subscription_fee]

                number_fee = self.get_fx_adjusted_value(currency, constants.usd_curr, monthly_cost)
                charges.append(BillItem(org_id, numbers_id, label, quantity, currency,
                                        number_fee, round(quantity * number_fee, 2)))

        # Incoming calls, recordings and volume discount
        if org_id in self.call_routing_logs:
            usage = LiveCallAnalytics(
                self.conn, self.timestamp, org_id, self.billing_start, self.billing_end,
                call_logs=self.call_routing_logs[org_id], with_mappings=False, with_volume_discount=True
            ).org_usage_summary()

            domestic_minutes = usage[var_names.domestic_minutes]
            domestic_cost = self.get_fx_adjusted_value(
                currency, constants.usd_curr, usage[var_names.domestic_cost])
            if domestic_minutes > 0:
                domestic_rate = round(float(domestic_cost/domestic_minutes), 2)
            else:
                domestic_rate = 0

            international_minutes = usage[var_names.international_minutes]
            international_cost = self.get_fx_adjusted_value(
                currency, constants.usd_curr, usage[var_names.international_cost])
            if international_minutes > 0:
                international_rate = round(float(international_cost/international_minutes), 2)
            else:
                international_rate = 0

            recorded_minutes = usage[var_names.recording_minutes]
            volume_discount = usage[var_names.volume_discount]

            if domestic_cost > 0:
                charges.append(BillItem(org_id, configs.live_call_routing_add_on_id,
                                        'Live call routing - Call minutes (US/Canada)',
                                        domestic_minutes, currency, domestic_rate, domestic_cost))
            if international_cost > 0:
                charges.append(BillItem(org_id, configs.live_call_routing_add_on_id,
                                        'Live call routing - Call minutes (International)',
                                        international_minutes, currency, international_rate, international_cost))
            if recorded_minutes > 0:
                recording_id = configs.call_recording_supplement_id
                recording_supp = self.subscriptions[recording_id]
                recording_fee = self.get_fx_adjusted_value(
                    currency, recording_supp[var_names.subscription_currency],
                    recording_supp[var_names.subscription_fee])
                charges.append(BillItem(org_id, recording_id, recording_supp[var_names.description],
                                        recorded_minutes, currency, recording_fee,
                                        round(recorded_minutes * recording_fee, 2)))

            if volume_discount is not None:
                discounts.append(DiscountItem(org_id, volume_discount[0], volume_discount[1],
                                              currency, volume_discount[2]))

        return charges, discounts

    def create_external_communication_billing_items(self, org_id):
        '''
        Create the external communication (email and SMS) billing items of an organization.
        Emails are only billed for the count above the free allowance (configs.free_external_email_count)
        and are priced from the external email supplement. SMS are billed per ISO country code using the
        country specific text rate; the external SMS supplement's fee is used when that rate is None.
        :param org_id: ID of the organization getting billed
        :return: (list of BillItem) -> [BillItem 1, BillItem 2, ...]
        '''
        bill_items = []
        if org_id not in self.external_communication_details:
            return bill_items

        org_details = self.external_communication_details[org_id]
        email_count = org_details[constants.send_external_email_event]
        sms_breakdown = org_details[constants.send_external_sms_event]
        bill_currency = self.billing_info[org_id][var_names.billing_currency]

        if email_count > configs.free_external_email_count:
            billable_email_count = email_count - configs.free_external_email_count
            # The fee and description must come from the email supplement; they were previously
            # read from the SMS supplement while the item was recorded under the email supplement ID.
            sub_id = configs.external_email_supplement_id
            supp = self.subscriptions[sub_id]
            supp_fee = self.get_fx_adjusted_value(bill_currency, supp[var_names.subscription_currency],
                                                  supp[var_names.subscription_fee])
            total_fee = round(billable_email_count * supp_fee, 2)
            bill_items.append(BillItem(org_id, sub_id, supp[var_names.description],
                                       billable_email_count, bill_currency, supp_fee, total_fee))

        sub_id = configs.external_sms_supplement_id
        for iso_code in sms_breakdown:
            item_count = sms_breakdown[iso_code]
            supp = self.subscriptions[sub_id]
            item_desc = supp[var_names.description] + ' - ' + iso_code
            iso_sms_rate = self.live_call_country_rates[iso_code][var_names.text_rate]
            if iso_sms_rate is None:
                # NOTE(review): the fallback fee is converted from USD below rather than from
                # the supplement's own currency - confirm the supplement is priced in USD.
                iso_sms_rate = supp[var_names.subscription_fee]

            supp_fee = self.get_fx_adjusted_value(bill_currency, constants.usd_curr, iso_sms_rate)
            total_fee = round(item_count * supp_fee, 2)
            bill_items.append(BillItem(org_id, sub_id, item_desc, item_count, bill_currency, supp_fee, total_fee))

        return bill_items

    def create_monitoring_checks_billing_items(self, org_id):
        '''
        Create billing items for monitoring checks. Checks are billed in blocks of 10,000 and only
        the checks above the free allowance (configs.free_monitor_checks_count) are charged.
        The quantity shown on the item is the number of blocks for ALL the checks that were run
        (free ones included), whereas the total only covers the billable blocks.
        :param org_id: ID of the organization getting billed
        :return: (list of BillItem) -> [BillItem 1, BillItem 2, ...]
        '''
        bill_items = []
        if org_id in self.monitoring_checks:
            checks_run = self.monitoring_checks[org_id]
            bill_currency = self.billing_info[org_id][var_names.billing_currency]
            # Clamp at zero: usage below the free allowance used to give a negative block
            # count, which could turn the total into -0.0 after currency conversion.
            billable_checks_count = max(checks_run - configs.free_monitor_checks_count, 0)
            sub_id = configs.checks_supplement_id
            supp = self.subscriptions[sub_id]
            supp_fee = self.get_fx_adjusted_value(
                bill_currency, supp[var_names.subscription_currency],
                supp[var_names.subscription_fee] if billable_checks_count > 0 else 0
            )
            # Checks are billed for every 10,000 requests (a started block counts as a full one)
            billable_checks_multiplier = math.ceil(billable_checks_count/10000)
            total_checks_multiplier = math.ceil(checks_run/10000)
            total_fee = round(billable_checks_multiplier * supp_fee, 2)

            bill_items.append(BillItem(org_id, sub_id, supp[var_names.description],
                                       total_checks_multiplier, bill_currency, supp_fee, total_fee))

        return bill_items

    def create_sso_add_on_billing_items(self, org_id):
        '''
        Builds the charge for the SSO add-on. The add-on is charged on the highest user count
        found among the organization's SSO add-on subscription entries; nothing is charged
        when there are none or the count is not positive.
        :param org_id: ID of the organization getting billed
        :return: (list of BillItem) -> [BillItem 1] or an empty list
        '''
        currency = self.billing_info[org_id][var_names.billing_currency]
        sso_id = configs.sso_add_on_id

        seat_counts = [
            sub[var_names.count] for sub in self.billable_accounts[org_id][var_names.subscriptions]
            if sub[var_names.subscription_id] == sso_id
        ]
        sso_users = max(seat_counts, default=0)
        if not sso_users > 0:
            return []

        add_on = self.subscriptions[sso_id]
        fee = self.get_fx_adjusted_value(
            currency, add_on[var_names.subscription_currency], add_on[var_names.subscription_fee])
        return [
            BillItem(org_id, sso_id, add_on[var_names.description],
                     sso_users, currency, fee, round(sso_users * fee, 2))
        ]

    def create_bills(self, skip_organizations=None):
        '''
        Creates the bills for all the organizations that are billable in this BillManager.
        Every organization will only have 1 Bill for a billing period.
        :param skip_organizations: (optional) collection of organization IDs that should not be billed
        :return: (list) of Bill objects
        :errors: RuntimeError when an organization without a payment card has a bill total above zero
        '''
        bills = []
        for org_id in self.billable_accounts:
            if skip_organizations is not None and org_id in skip_organizations:
                continue

            lcr_bill_items, lcr_disc_items = self.create_live_call_routing_billing_items(org_id)

            # User plan items go first (every account must have an alerting subscription plan),
            # followed by the items from the other subscriptions if there are any.
            bill_items = self.create_user_plan_billing_items(org_id)
            bill_items = bill_items + self.create_stakeholder_billing_items(org_id)
            bill_items = bill_items + lcr_bill_items
            bill_items = bill_items + self.create_sso_add_on_billing_items(org_id)
            bill_items = bill_items + self.create_external_communication_billing_items(org_id)
            bill_items = bill_items + self.create_monitoring_checks_billing_items(org_id)

            org_billing_info = self.billing_info[org_id]
            org_country = org_billing_info[var_names.iso_country_code]
            # Countries without a VAT rate on record are charged no VAT
            org_country_vat = 0
            if org_country in self.vat_rates:
                org_country_vat = self.vat_rates[org_country]
            billing_currency = org_billing_info[var_names.billing_currency]

            has_card = org_id in self.organization_cards
            card_id = self.organization_cards[org_id] if has_card else None

            discount_details = None
            if org_id in self.discounts:
                discount_details = self.discounts[org_id]

            credits_available = None
            if org_id in self.active_credits:
                credits_available = self.active_credits[org_id]

            is_tax_exempted = self.tax_exemptions is not None and org_id in self.tax_exemptions

            org_bill = Bill(
                org_id, self.billing_year, self.billing_month, self.timestamp.date(), self.billing_start,
                self.billing_end, bill_items, billing_currency, card_id, org_country, org_country_vat,
                discount_details=discount_details,
                additional_discount_items=lcr_disc_items,
                credits_available=credits_available,
                is_tax_exempted=is_tax_exempted,
                exchange_rates=self.exchange_rates
            )

            # A bill that has to be paid cannot be issued without a card to charge
            amount_is_due = org_bill.total() > 0
            if amount_is_due and not has_card:
                raise RuntimeError('Non zero bill found for organization without a payment card - ' + str(org_id))

            bills.append(org_bill)

        return bills
