# By: Riasat Ullah
# This file contains the following classes - Bill, BillItem, Charge, DiscountItem.

from utils import constants, var_names
import configuration
import datetime
import json


class Bill(object):

    def __init__(self, organization_id, year, month, billing_date, billing_start, billing_end, billing_items,
                 currency, card_id, vat_country, vat_percent, discount_details=None, additional_discount_items=None,
                 credits_available=None, is_tax_exempted=False, exchange_rates=None):
        '''
        This class represents a single bill of an organization for one billing period.
        :param organization_id: (int) organization ID
        :param year: (int) billing year
        :param month: (int) billing month
        :param billing_date: (datetime.date) the date this bill is being created on
        :param billing_start: (datetime.date) the start of this bill's billing period
        :param billing_end: (datetime.date) the end of this bill's billing period
        :param billing_items: (list) of BillItem objects
        :param currency: (str) the currency this bill is to be charged in
        :param card_id: (int) ID of the card given by TaskCall
        :param vat_country: 2 letters ISO code of the country the VAT is for
        :param vat_percent: (float) the percentage of the VAT (0 where there is none)
        :param discount_details: (dict) -> {discount_id: , discount: , discount_percent: , discount_type: , reason: }
        :param additional_discount_items: (list) of DiscountItem objects (from volume discounts, promos, etc.)
        :param credits_available: (list) of dict ->
                [{credit_id: , credit_amount: , credit_used: , credit_currency: }, ...]
        :param is_tax_exempted: (boolean) True if organization is exempted from taxation; False otherwise
        :param exchange_rates: (dict) -> {(credit currency, bill currency): rate, ...} where
                amount in bill currency = amount in credit currency * rate
        :errors: AssertionError
        '''
        self.organization_id = organization_id
        self.year = year
        self.month = month
        self.billing_date = billing_date
        self.billing_start = billing_start
        self.billing_end = billing_end
        self.billing_items = billing_items
        self.currency = currency
        self.card_id = card_id
        self.vat_country = vat_country
        self.vat_percent = vat_percent
        self.discount_details = discount_details
        self.additional_discount_items = additional_discount_items
        self.credits_available = credits_available
        self.is_tax_exempted = is_tax_exempted
        self.exchange_rates = exchange_rates

        self.validate_billing_items()

    def validate_billing_items(self):
        '''
        Make sure each of the BillItem details match the Bill details. An item is rejected
        when either its organization or its currency differs from that of the bill.
        :errors: AssertionError
        '''
        for item in self.billing_items:
            # Raised explicitly (rather than with an assert statement) so that
            # the validation is not stripped when Python runs with -O.
            if not isinstance(item, BillItem):
                raise AssertionError('Billing item is not a BillItem object. Organization id - '
                                     + str(self.organization_id))
            if item.organization_id != self.organization_id:
                raise AssertionError('Billing item organization does not match that in the bill. Organization id - '
                                     + str(self.organization_id))
            # Item amounts are summed without any conversion, so they must be in the bill currency.
            if item.currency != self.currency:
                raise AssertionError('Billing item currency does not match that in the bill. Organization id - '
                                     + str(self.organization_id))

    def subtotal(self):
        '''
        Calculate the total before accounting for any deductions.
        :return: (float) subtotal
        '''
        return round(sum(item.amount for item in self.billing_items), 2)

    def has_account_wide_discount(self):
        '''
        Checks if an organization has account wide discount or not.
        :return: True if they do; False otherwise
        '''
        return self.discount_details is not None and \
            self.discount_details[var_names.discount_type] == constants.all_discount_type

    def general_discount_items(self):
        '''
        Get the DiscountItem objects of the general discount that will be applied to the bill.
        An account wide discount applies to every billing item; any other discount type only
        applies to items whose subscription is in configuration.discount_eligible_user_plan_subscription_ids.
        A fixed discount amount takes precedence over a percentage and is capped at the eligible subtotal.
        :return: (list) DiscountItem objects (empty when no discount applies)
        '''
        discount_items = []
        if self.discount_details is None:
            return discount_items

        account_wide = self.has_account_wide_discount()
        eligible_total = sum(
            item.amount for item in self.billing_items
            if account_wide or item.subscription_id in configuration.discount_eligible_user_plan_subscription_ids
        )
        if eligible_total <= 0:
            return discount_items

        amount = self.discount_details[var_names.discount]
        perc = self.discount_details[var_names.discount_percent]
        discount = 0

        # For fixed amounts: the discount is the amount itself, but never more than what is eligible.
        if amount is not None and amount > 0:
            discount = round(min(amount, eligible_total), 2)

        elif perc is not None and perc <= 100:
            discount = round(eligible_total * perc / 100, 2)

        if discount > 0:
            discount_items.append(
                DiscountItem(self.organization_id, self.discount_details[var_names.discount_id],
                             self.discount_details[var_names.reason], self.currency, discount)
            )

        return discount_items

    def general_discount(self):
        '''
        Calculate the applicable general discount.
        :return: (float) general discount
        '''
        return round(sum(item.discount for item in self.general_discount_items()), 2)

    def additional_discount(self):
        '''
        Get the sum of discounts that are non-general like volume discounts, promos, etc. Additional discounts only
        apply when an account does not already have an account wide (ALL) discount.
        :return: (float) additional discount
        '''
        if self.additional_discount_items is None or self.has_account_wide_discount():
            return 0
        return round(sum(item.discount for item in self.additional_discount_items), 2)

    def credits_applied_details(self):
        '''
        Work out how much of each available credit gets applied to this bill. Credits are consumed in the
        order they are provided until the discounted subtotal is covered. The amount recorded against each
        credit is expressed in that credit's own currency.
        :return: (dict) -> {credit_id: amount applied, ...}  |  None if no credits were applied
        :errors: RuntimeError
        '''
        if self.credits_available is None:
            return None

        applicable_credits = dict()

        # The outstanding payable is always tracked in the bill currency. It is converted to the
        # currency of each credit separately so that a conversion is never applied more than once.
        subtotal_left = self.subtotal_after_discounts()
        for item in self.credits_available:
            cred_remaining = item[var_names.credit_amount] - item[var_names.credit_used]
            if cred_remaining <= 0:
                continue

            cred_curr = item[var_names.credit_currency]
            ex_rate = 1
            payable_in_cred_curr = subtotal_left

            # Credits are only issued in USD, but here we are using the credit currency for cleanliness.
            if self.currency != cred_curr:
                if self.exchange_rates is None:
                    raise RuntimeError(
                        'Exchange rates not provided. Cannot calculate credits in billing. Organization ID - ' +
                        str(self.organization_id)
                    )
                ex_rate = self.exchange_rates[(cred_curr, self.currency)]
                reverse_ex_rate = 1/ex_rate

                # Convert the outstanding payable to the credit currency.
                payable_in_cred_curr = round(subtotal_left * reverse_ex_rate, 2)

            if cred_remaining >= payable_in_cred_curr:
                # This credit covers everything that is left.
                applicable_credits[item[var_names.credit_id]] = payable_in_cred_curr
                break

            # The credit is used up completely; reduce the payable by its value in the bill currency.
            applicable_credits[item[var_names.credit_id]] = cred_remaining
            subtotal_left = round(subtotal_left - cred_remaining * ex_rate, 2)

        if len(applicable_credits) > 0:
            return applicable_credits
        return None

    def credits(self):
        '''
        The total of all credits applied.
        NOTE(review): the amounts are summed as they are reported by credits_applied_details(), that is in the
        currency of each credit. They are not converted back to the bill currency - confirm that this is intended
        for bills that are not in the credit currency.
        :return: (float) total amount of credits applied
        '''
        applied_credits = self.credits_applied_details()
        if applied_credits is None:
            return 0
        return round(sum(applied_credits.values()), 2)

    def subtotal_after_discounts(self):
        '''
        Get the subtotal of the bill after only discounts are applied.
        :return: (float) discounted subtotal (never negative)
        '''
        disc_sub = round(self.subtotal() - self.general_discount(), 2)

        # Apply the additional discounts only if there is still a payable after the general discount.
        # The result is rounded to keep floating point residue out of the monetary values.
        if disc_sub > 0:
            disc_sub = round(disc_sub - self.additional_discount(), 2)

        # Subtotals can never be negative. If discounts are more than the subtotal, then just set it to zero.
        if disc_sub < 0:
            disc_sub = 0
        return disc_sub

    def total_discount(self):
        '''
        Total discount applied on the account. This is not the same as the sum of all the discounts.
        If discounts go over the subtotal available then they are reduced to match the subtotal.
        :return: (float) total discount
        '''
        return round(self.subtotal() - self.subtotal_after_discounts(), 2)

    def subtotal_after_deductions(self):
        '''
        Get the subtotal of the bill after both discounts and credits are applied.
        :return: (float) reduced subtotal (never negative)
        '''
        reduced_sub = round(self.subtotal_after_discounts() - self.credits(), 2)
        if reduced_sub < 0:
            reduced_sub = 0
        return reduced_sub

    def vat(self, sub_total):
        '''
        Calculate the VAT payable given a subtotal.
        :param sub_total: (float) subtotal
        :return: (float) VAT amount (0 when the organization is tax exempted)
        '''
        if self.is_tax_exempted:
            return 0
        return round(sub_total * (self.vat_percent / 100), 2)

    def total(self):
        '''
        Calculate the total to charge after accounting for all discounts, credits and VAT.
        :return: (float) total
        '''
        reduced_subtotal = self.subtotal_after_deductions()
        return round(reduced_subtotal + self.vat(reduced_subtotal), 2)

    def get_billing_items_json(self):
        '''
        Get the json of the BillItem objects of this Bill.
        :return: (json) list of the dict form of each BillItem object
        '''
        return json.dumps([item.to_dict() for item in self.billing_items])

    def db_format_discounts(self):
        '''
        Get the applied discount details in the format expected by the billings query.
        :return: (json) list of the dict form of each DiscountItem object  |  None if there are no discounts
        '''
        disc_items = list(self.general_discount_items())

        # Additional discount does not apply when there is an account wide discount.
        if self.additional_discount_items is not None and self.additional_discount() > 0:
            disc_items = disc_items + list(self.additional_discount_items)

        if len(disc_items) == 0:
            return None
        return json.dumps([dsc_obj.to_dict() for dsc_obj in disc_items])

    def db_format_credits(self):
        '''
        Get the applied credits details in the format expected by the billings query.
        :return: (json) [{credit_id: , credit_used: }, ...]  |  None if no credits were applied
        '''
        cred_det = self.credits_applied_details()
        if cred_det is None:
            return None
        return json.dumps([
            {var_names.credit_id: cr_id, var_names.credit_used: used} for cr_id, used in cred_det.items()
        ])

    def db_query_params(self, timestamp, payment_confirmation=None):
        '''
        Get the query parameters for creating Bill entries in the database
        :param timestamp: (datetime.datetime) timestamp when this query is being requested to be generated
        :param payment_confirmation: the payment confirmation ID from the vendor
        :return: (tuple) of query params (22 in total)
        :errors: AssertionError
        '''
        # Raised explicitly so that the check is not stripped when Python runs with -O.
        if not isinstance(timestamp, datetime.datetime):
            raise AssertionError('timestamp must be a datetime.datetime object')

        sub_total = self.subtotal()
        discount_applied = self.total_discount()
        credits_applied = self.credits()
        reduced_sub_total = self.subtotal_after_deductions()
        vat_payable = self.vat(reduced_sub_total)
        total = self.total()

        # A bill with nothing left to pay does not need to be charged, so it is closed straight away.
        bill_status = constants.open_state if total > 0 else constants.closed

        # The billing end that will be reported should be the day before the
        # billing end used for querying the database. The billing end used for
        # queries needed to be a day more than what will be reported as time
        # based queries will generate results exclusive of the date. So, by
        # increasing it by a day, we are able to retrieve the correct data.
        # That one day adjustment is currently disabled; the billing end is reported as it was given.
        # reportable_billing_end = self.billing_end - datetime.timedelta(days=1)
        reportable_billing_end = self.billing_end

        query_params = (bill_status, self.billing_start, reportable_billing_end, timestamp.date(),
                        self.year, self.month, self.organization_id, self.card_id,
                        self.currency, sub_total, discount_applied, credits_applied,
                        reduced_sub_total, self.vat_country, self.vat_percent, vat_payable,
                        total, self.get_billing_items_json(), self.db_format_discounts(), self.db_format_credits(),
                        constants.handler_stripe, payment_confirmation,)

        return query_params


class BillItem(object):
    '''
    The most granular line of a bill. 'Monthly' and 'Excess' usages of the same
    usage type have to be kept as separate BillItem objects.
    '''
    def __init__(self, organization_id, subscription_id, description, quantity, currency, unit_price, amount):
        # What is being billed and to whom.
        self.organization_id, self.subscription_id = organization_id, subscription_id
        self.description, self.quantity = description, quantity
        self.currency = currency

        # Monetary values are always held as floats, whatever numeric form they are provided in.
        self.unit_price, self.amount = float(unit_price), float(amount)

    def to_dict(self):
        '''
        Convert this BillItem to its dict form.
        :return: (dict) of the BillItem details
        '''
        details = dict()
        details[var_names.organization_id] = self.organization_id
        details[var_names.subscription_id] = self.subscription_id
        details[var_names.item_description] = self.description
        details[var_names.item_quantity] = self.quantity
        details[var_names.billing_currency] = self.currency
        details[var_names.subscription_fee] = self.unit_price
        details[var_names.item_total] = self.amount
        return details


class Charge(object):
    '''
    Holds the details of a single charge that is made against a card of an organization.
    '''

    def __init__(self, charge_id, organization_id, card_id, handler, card_token, currency, amount, side,
                 payment_confirmation=None):

        # Identifiers of the charge and of who is being charged.
        self.charge_id, self.organization_id, self.card_id = charge_id, organization_id, card_id

        # Payment handling details; card_token is a dictionary.
        self.handler, self.card_token = handler, card_token

        # What is being charged.
        self.currency, self.amount, self.side = currency, amount, side

        self.payment_confirmation = payment_confirmation

    def db_query_params(self, timestamp):
        '''
        Get the query parameters for creating a charge in the database.
        :param timestamp: timestamp when this query is being requested to be created on
        :return: (tuple) of query params -> (charge id, date, payment handler, payment confirmation)
        :errors: RuntimeError
        '''
        confirmation = self.payment_confirmation

        # A charge cannot be recorded before the vendor has confirmed the payment.
        if confirmation is None:
            raise RuntimeError('No payment confirmation was found')

        return self.charge_id, timestamp.date(), constants.handler_stripe, confirmation


class DiscountItem(object):
    '''
    A single discount that is given to an organization on a bill.
    '''
    def __init__(self, organization_id, discount_id, description, currency, discount):
        self.organization_id, self.discount_id = organization_id, discount_id
        self.description, self.currency = description, currency

        # The discount amount is always held as a float, whatever numeric form it is provided in.
        self.discount = float(discount)

    def to_dict(self):
        '''
        Convert this DiscountItem to its dict form.
        :return: (dict) of the DiscountItem details
        '''
        details = dict()
        details[var_names.organization_id] = self.organization_id
        details[var_names.discount_id] = self.discount_id
        details[var_names.description] = self.description
        details[var_names.billing_currency] = self.currency
        details[var_names.discount] = self.discount
        return details
