# By: Riasat Ullah
# This file contains heartbeat check related database queries.

from objects.check import Check
from psycopg2 import errorcodes
from utils import constants, key_manager, times, var_names
from uuid import UUID
from validations import string_validator
import configuration as configs
import datetime
import psycopg2


def list_heartbeats(conn, timestamp, organization_id, user_id, is_healthy=None, service_ref_ids=None, search_words=None,
                    row_limit=None, row_offset=None, check_adv_perm=False):
    '''
    Gets the list of heartbeat checks of an organization given certain filters. Each entry also carries
    a day-by-day count of failed runs over the 30 days leading up to the request timestamp.
    :param conn: db connection
    :param timestamp: timestamp this request is being made on
    :param organization_id: ID of the organization to filter by
    :param user_id: ID of the user making the request
    :param is_healthy: True if only healthy heartbeat checks are requested; False if only unhealthy ones
        are requested; None to skip this filter
    :param service_ref_ids: (list) of concealed service reference IDs to filter by
    :param search_words: (str) keywords to filter by (matched against name, ping url, ping email, description)
    :param row_limit: (optional) maximum number of heartbeats to fetch
    :param row_offset: (optional) number of heartbeats to skip ahead
    :param check_adv_perm: (boolean) True if advanced team permissions should be checked
    :return: (list) of dict of heartbeat check details, ordered by check name
    :errors: AssertionError, DatabaseError
    '''
    assert isinstance(timestamp, datetime.datetime)
    assert isinstance(organization_id, int)
    assert isinstance(user_id, int)

    # Failed runs are aggregated over the 30 days that precede the request timestamp.
    check_start = timestamp - datetime.timedelta(days=30)
    query_params = {'timestamp': timestamp, 'org_id': organization_id, 'usr_id': user_id,
                    'comp_type_id': configs.service_component_type_id, 'chk_start': check_start,
                    'chk_type': constants.heartbeat}
    conditions = []

    if is_healthy is not None:
        conditions.append(" is_healthy " if is_healthy else " not is_healthy ")

    if service_ref_ids is not None:
        assert isinstance(service_ref_ids, list)
        conditions.append('''
            chk.serviceid in (
                select serviceid from services
                where start_timestamp <= %(timestamp)s
                    and end_timestamp > %(timestamp)s
                    and service_ref_id = any(%(srv_ref_list)s)
            )
        ''')
        query_params['srv_ref_list'] = [key_manager.unmask_reference_key(x) for x in service_ref_ids]

    if search_words is not None:
        assert isinstance(search_words, str)
        # NOTE(review): "||" yields NULL when any operand is NULL, so a check with a NULL ping_url or
        # ping_email would only be matched through its description - confirm whether these columns are nullable.
        conditions.append('''(
            (LOWER(chk.check_name || chk.ping_url || chk.ping_email) like '%%' || LOWER(%(search_words)s) || '%%')
            or (chk.description is not null and LOWER(chk.description) like '%%' || LOWER(%(search_words)s) || '%%')
        )''')
        query_params['search_words'] = search_words

    if check_adv_perm:
        conditions.append('''
            chk.serviceid not in (
                select component_id from components_user_cannot_view(
                    %(timestamp)s, %(org_id)s, %(usr_id)s, %(comp_type_id)s::smallint
                )
            )
        ''')

    # Pagination is bound as query parameters and applied to the final (ordered) select.
    limit_cond = ''
    if row_limit is not None:
        assert isinstance(row_limit, int)
        limit_cond += ' limit %(row_limit)s '
        query_params['row_limit'] = row_limit
    if row_offset is not None:
        assert isinstance(row_offset, int)
        limit_cond += ' offset %(row_offset)s '
        query_params['row_offset'] = row_offset

    # t1: the heartbeat checks that pass the filters; t2: teams of each check's service;
    # t3: failed runs in the look-back window dated in the organization's timezone;
    # t4: per check, a json object of {date: number of failed runs}.
    query = '''
            with t1 as (
                select check_id, check_ref_id, check_type, check_name, chk.description, chk.is_enabled, ping_type,
                    ping_url, ping_email, ping_interval, grace_period, srv.serviceid, srv.service_name,
                    srv.service_ref_id, is_healthy, last_run_timestamp, next_run_timestamp
                from monitor_checks as chk
                join services as srv
                    on srv.serviceid = chk.serviceid
                        and srv.start_timestamp <= %(timestamp)s
                        and srv.end_timestamp > %(timestamp)s
                where chk.organization_id = %(org_id)s
                    and chk.start_timestamp <= %(timestamp)s
                    and chk.end_timestamp > %(timestamp)s
                    and chk.check_type = %(chk_type)s
                    {0}
            )
            , t2 as (
                select component_id as serviceid, json_agg(json_build_object(
                    'team_ref_id', team_ref_id,
                    'team_name', team_name
                )) as service_teams
                from team_components as tco
                join teams using (team_id)
                where tco.start_timestamp <= %(timestamp)s
                    and tco.end_timestamp > %(timestamp)s
                    and tco.component_id in (select serviceid from t1)
                    and tco.component_type_id = %(comp_type_id)s
                    and teams.organization_id = %(org_id)s
                    and teams.start_timestamp <= %(timestamp)s
                    and teams.end_timestamp > %(timestamp)s
                group by serviceid
            )
            , t3 as (
                select check_id, passed,
                    timezone(org.organization_timezone, timezone('UTC', scheduled_timestamp))::date as check_date
                from (
                    select chk.organization_id, mcl.check_id, mcl.scheduled_timestamp, passed
                    from monitor_check_logs as mcl
                    join monitor_checks as chk
                        on chk.check_id = mcl.check_id
                            and chk.start_timestamp <= mcl.scheduled_timestamp
                            and chk.end_timestamp > mcl.scheduled_timestamp
                    where mcl.check_id in (select check_id from t1)
                        and mcl.scheduled_timestamp >= %(chk_start)s
                        and mcl.scheduled_timestamp <= %(timestamp)s
                        and not mcl.passed
                ) as sub_table
                join organizations as org
                    on org.organization_id = sub_table.organization_id
                        and org.start_timestamp <= %(timestamp)s
                        and org.end_timestamp > %(timestamp)s
            )
            , t4 as (
                select check_id, json_object_agg(check_date, check_count) as periodic_health
                from (
                    select check_id, check_date, count(passed) as check_count
                    from t3
                    group by check_id, check_date
                ) as daily_counts
                group by check_id
            )
            select check_ref_id, check_name, description, is_enabled, ping_type, ping_url, ping_email,
                ping_interval, grace_period, service_name, service_ref_id, service_teams, is_healthy,
                last_run_timestamp, next_run_timestamp, t4.periodic_health
            from t1
            left join t2 using(serviceid)
            left join t4 using(check_id)
            order by check_name asc
            {1};
            '''.format(' and ' + ' and '.join(conditions) if len(conditions) > 0 else '', limit_cond)
    try:
        result = conn.fetch(query, query_params)
    except psycopg2.DatabaseError:
        raise

    # The day range is the same for every row, so it is built once along with the string keys
    # used to look the counts up in the json object returned by the query.
    # NOTE(review): the query buckets failures by date in the organization's timezone while these
    # days are derived from "timestamp" - confirm both are expected to line up.
    day_range = [check_start + datetime.timedelta(days=i) for i in range(0, (timestamp - check_start).days + 1)]
    day_keys = [(dt_, dt_.strftime(constants.date_hyphen_format)) for dt_ in day_range]

    data = []
    for ref_id, chk_name, desc, is_enbl, png_type, png_url, png_email, png_interval, png_grace, \
            srv_name, srv_ref, srv_teams, is_up, last_run, next_run, periodic_health in result:

        # Checks without any failed run in the window have no entry in t4 (left join gives NULL).
        periodic_health = dict() if periodic_health is None else periodic_health
        formatted_ph = [{var_names.period: dt_, var_names.count: periodic_health.get(dt_str_, 0)}
                        for dt_, dt_str_ in day_keys]

        data.append({
            var_names.check_ref_id: key_manager.conceal_reference_key(ref_id),
            var_names.check_name: chk_name,
            var_names.description: desc,
            var_names.is_enabled: is_enbl,
            var_names.ping_type: png_type,
            var_names.url: png_url,
            var_names.email: png_email,
            var_names.interval: png_interval,
            var_names.grace_period: png_grace,
            var_names.is_healthy: is_up,
            var_names.last_run: last_run,
            var_names.next_run: next_run,
            var_names.service: [srv_name, key_manager.conceal_reference_key(srv_ref)],
            var_names.teams: [[item[var_names.team_name],
                               key_manager.conceal_reference_key(UUID(item[var_names.team_ref_id]))]
                              for item in srv_teams] if srv_teams is not None else [],
            var_names.last_30_days: formatted_ph
        })

    return data


def get_heartbeat_details(conn, timestamp, organization_id, user_id, check_ref_id, check_adv_perm=False):
    '''
    Fetch the full configuration and current state of a single heartbeat check.
    :param conn: db connection
    :param timestamp: timestamp when this request is being made
    :param organization_id: ID of the organization the check belongs to
    :param user_id: ID of the user who initiated the request
    :param check_ref_id: (concealed) reference ID of the check
    :param check_adv_perm: (boolean) True if advanced team permissions should be checked
    :return: (dict) of heartbeat check details; an empty dict if no matching check is found
    :errors: AssertionError, DatabaseError
    '''
    assert isinstance(timestamp, datetime.datetime)
    assert isinstance(organization_id, int)
    assert isinstance(user_id, int)
    unmasked_check_ref = key_manager.unmask_reference_key(check_ref_id)

    query_params = {'timestamp': timestamp, 'org_id': organization_id,
                    'chk_ref': unmasked_check_ref, 'chk_type': constants.heartbeat}

    # Optional extra filter that hides checks of services the user is not permitted to view.
    permission_filter = ''
    if check_adv_perm:
        query_params['usr_id'] = user_id
        query_params['comp_type_id'] = configs.service_component_type_id
        permission_filter = ' and ' + '''
            chk.serviceid not in (
                select component_id from components_user_cannot_view(
                    %(timestamp)s, %(org_id)s, %(usr_id)s, %(comp_type_id)s::smallint
                )
            )
        '''

    # t1: the check joined with its service; t2: the incident tags attached to the check.
    query = '''
            with t1 as (
                select check_id, check_ref_id, check_type, check_name, chk.description, chk.is_enabled, ping_type,
                    ping_url, ping_email, ping_interval, grace_period, incident_title, incident_description,
                    urgency_level, ip_whitelist, email_whitelist, srv.service_name, srv.service_ref_id,
                    is_healthy, last_run_timestamp, next_run_timestamp
                from monitor_checks as chk
                join services as srv
                    on srv.serviceid = chk.serviceid
                        and srv.start_timestamp <= %(timestamp)s
                        and srv.end_timestamp > %(timestamp)s
                where chk.start_timestamp <= %(timestamp)s
                    and chk.end_timestamp > %(timestamp)s
                    and chk.organization_id = %(org_id)s
                    and check_ref_id = %(chk_ref)s
                    and check_type = %(chk_type)s
                    {0}
            )
            , t2 as (
                select check_id, array_agg(tag) as check_tags
                from monitor_check_incident_tags
                where start_timestamp <= %(timestamp)s
                    and end_timestamp > %(timestamp)s
                    and check_id in (select check_id from t1)
                group by check_id
            )
            select t1.*, t2.check_tags
            from t1
            left join t2 using(check_id);
            '''.format(permission_filter)
    try:
        rows = conn.fetch(query, query_params)
    except psycopg2.DatabaseError:
        raise

    details = dict()
    for row in rows:
        # Column order follows "t1.*" followed by the aggregated tags from t2.
        (_check_id, ref_id, kind, name, about, enabled, ping_kind, ping_url, ping_email, interval, grace,
         title, message, urgency, allowed_ips, allowed_emails, service_name, service_ref, healthy,
         ran_at, runs_next_at, tags) = row
        details = {
            var_names.check_ref_id: key_manager.conceal_reference_key(ref_id),
            var_names.check_type: kind,
            var_names.check_name: name,
            var_names.description: about,
            var_names.is_enabled: enabled,
            var_names.ping_type: ping_kind,
            var_names.url: ping_url,
            var_names.email: ping_email,
            var_names.interval: interval,
            var_names.grace_period: grace,
            var_names.task_title: title,
            var_names.text_msg: message,
            var_names.urgency_level: urgency,
            var_names.tags: tags,
            var_names.service: [service_name, key_manager.conceal_reference_key(service_ref)],
            var_names.ip_address: allowed_ips,
            var_names.email_from: allowed_emails,
            var_names.is_healthy: healthy,
            var_names.last_run: ran_at,
            var_names.next_run: runs_next_at
        }
    return details


def get_heartbeat_metrics(conn, timestamp, organization_id, user_id, check_ref_id, start_time, end_time,
                          check_adv_perm=False):
    '''
    Fetch a heartbeat check together with the run logs recorded for it in a given period.
    :param conn: db connection
    :param timestamp: timestamp when this request is being made
    :param organization_id: ID of the organization the check belongs to
    :param user_id: ID of the user who initiated the request
    :param check_ref_id: (concealed) reference ID of the check
    :param start_time: (datetime.datetime) logs with a run timestamp on or after this are included
    :param end_time: (datetime.datetime) logs with a run timestamp before this are included
    :param check_adv_perm: (boolean) True if advanced team permissions should be checked
    :return: (dict) of heartbeat check details and its events; an empty dict if no matching check is found
    :errors: AssertionError, DatabaseError
    '''
    assert isinstance(timestamp, datetime.datetime)
    assert isinstance(organization_id, int)
    assert isinstance(user_id, int)
    assert isinstance(start_time, datetime.datetime)
    assert isinstance(end_time, datetime.datetime)
    unmasked_check_ref = key_manager.unmask_reference_key(check_ref_id)

    query_params = {'timestamp': timestamp, 'org_id': organization_id,
                    'chk_ref': unmasked_check_ref, 'chk_type': constants.heartbeat,
                    'start_time': start_time, 'end_time': end_time}

    # Optional extra filter that hides checks of services the user is not permitted to view.
    permission_filter = ''
    if check_adv_perm:
        query_params['usr_id'] = user_id
        query_params['comp_type_id'] = configs.service_component_type_id
        permission_filter = ' and ' + '''
            chk.serviceid not in (
                select component_id from components_user_cannot_view(
                    %(timestamp)s, %(org_id)s, %(usr_id)s, %(comp_type_id)s::smallint
                )
            )
        '''

    # t1: the check itself; t2: its logs in [start_time, end_time) aggregated as json, ordered by run time.
    query = '''
            with t1 as (
                select check_id, check_ref_id, check_type, check_name, description, is_enabled, ping_type, ping_url,
                    ping_email, ping_interval, grace_period, is_healthy, last_run_timestamp, next_run_timestamp
                from monitor_checks as chk
                where chk.start_timestamp <= %(timestamp)s
                    and chk.end_timestamp > %(timestamp)s
                    and chk.organization_id = %(org_id)s
                    and check_ref_id = %(chk_ref)s
                    and check_type = %(chk_type)s
                    {0}
            )
            , t2 as (
                select check_id, json_agg(json_build_object(
                    'scheduled_timestamp', scheduled_timestamp,
                    'run_timestamp', run_timestamp,
                    'passed', passed,
                    'instance_id', instanceid
                ) order by run_timestamp) as check_logs
                from monitor_check_logs
                where check_id in (select check_id from t1)
                    and run_timestamp >= %(start_time)s
                    and run_timestamp < %(end_time)s
                group by check_id
            )
            select t1.*, t2.check_logs
            from t1
            left join t2 using(check_id);
            '''.format(permission_filter)
    try:
        rows = conn.fetch(query, query_params)
    except psycopg2.DatabaseError:
        raise

    metrics = dict()
    for row in rows:
        # Column order follows "t1.*" followed by the aggregated logs from t2.
        (_check_id, ref_id, kind, name, about, enabled, ping_kind, ping_url, ping_email,
         interval, grace, healthy, ran_at, runs_next_at, raw_logs) = row

        # The left join leaves the logs NULL when nothing was recorded in the period.
        events = [] if raw_logs is None else [
            {
                var_names.scheduled_timestamp:
                    times.get_timestamp_from_string(entry[var_names.scheduled_timestamp]),
                var_names.run_timestamp: times.get_timestamp_from_string(entry[var_names.run_timestamp]),
                var_names.passed: entry[var_names.passed],
                var_names.instance_id: entry[var_names.instance_id]
            }
            for entry in raw_logs
        ]
        metrics = {
            var_names.check_ref_id: key_manager.conceal_reference_key(ref_id),
            var_names.check_type: kind,
            var_names.check_name: name,
            var_names.description: about,
            var_names.is_enabled: enabled,
            var_names.ping_type: ping_kind,
            var_names.url: ping_url,
            var_names.email: ping_email,
            var_names.interval: interval,
            var_names.grace_period: grace,
            var_names.is_healthy: healthy,
            var_names.last_run: ran_at,
            var_names.next_run: runs_next_at,
            var_names.events: events
        }
    return metrics


def get_heartbeat(conn, timestamp, check_ref_id):
    '''
    Look up a single heartbeat check through the "get_heartbeat_check" database function.
    This query is only used internally.
    :param conn: db connection
    :param timestamp: timestamp when this request is being made
    :param check_ref_id: (concealed) reference ID of the check
    :return: Check object built from the first row returned; None if no row is returned
    :errors: AssertionError, DatabaseError
    '''
    assert isinstance(timestamp, datetime.datetime)
    unmasked_check_ref = key_manager.unmask_reference_key(check_ref_id)

    query = "select * from get_heartbeat_check(%s, %s);"
    query_params = (timestamp, unmasked_check_ref,)
    try:
        rows = conn.fetch(query, query_params)
    except psycopg2.DatabaseError:
        raise

    for row in rows:
        (check_id, organization_id, ref_id, kind, name, about, enabled, ping_kind, ping_url, ping_email,
         interval, grace, service_id, title, message, urgency, allowed_ips, allowed_emails, healthy,
         ran_at, runs_next_at, tags, incidents) = row

        # Only the first row is of interest.
        return Check(check_id, organization_id, ref_id, kind, name, about, enabled, ping_kind,
                     ping_url, ping_email, interval, grace, service_id, title, message, urgency,
                     tags=tags, ip_address=allowed_ips, email_from=allowed_emails, is_healthy=healthy,
                     last_run=ran_at, next_run=runs_next_at, incidents=incidents)
    return None


def get_upcoming_heartbeats(conn, timestamp, forward_lookout):
    '''
    Collect the heartbeat checks returned by the "upcoming_heartbeat_checks" database function
    for the given timestamp and forward lookout.
    This query is only used internally to monitor checks.
    :param conn: db connection
    :param timestamp: timestamp when this request is being made
    :param forward_lookout: (int) number of minutes to get the data for
    :return: (dict) of Check objects keyed by check ID -> {check_id: Check, ...}
    :errors: AssertionError, DatabaseError
    '''
    assert isinstance(timestamp, datetime.datetime)
    assert isinstance(forward_lookout, int)

    query = "select * from upcoming_heartbeat_checks(%s, %s);"
    query_params = (timestamp, forward_lookout,)
    try:
        rows = conn.fetch(query, query_params)
    except psycopg2.DatabaseError:
        raise

    upcoming = dict()
    for row in rows:
        (check_id, organization_id, ref_id, kind, name, about, enabled, ping_kind, ping_url, ping_email,
         interval, grace, service_id, title, message, urgency, allowed_ips, allowed_emails, healthy,
         ran_at, runs_next_at, tags, incidents) = row

        upcoming[check_id] = Check(check_id, organization_id, ref_id, kind, name, about, enabled, ping_kind,
                                   ping_url, ping_email, interval, grace, service_id, title, message, urgency,
                                   tags=tags, ip_address=allowed_ips, email_from=allowed_emails,
                                   is_healthy=healthy, last_run=ran_at, next_run=runs_next_at,
                                   incidents=incidents)
    return upcoming


def log_heartbeat(conn, timestamp, check_ref_id, ip_addr=None, sender_email=None):
    '''
    Record an incoming heartbeat ping through the "log_heartbeat" database function.
    :param conn: db connection
    :param timestamp: timestamp when the ping was received
    :param check_ref_id: (concealed) reference ID of the check
    :param ip_addr: IP address where the incoming ping generated from
    :param sender_email: email address of the sender
    :return: (tuple) -> organization ID, (list) of instance IDs that should be resolved;
        (None, None) if the database function returns no rows
    :errors: AssertionError, DatabaseError, PermissionError (restrict violation reported by the database),
        LookupError (check violation reported by the database)
    '''
    assert isinstance(timestamp, datetime.datetime)
    unmasked_check_ref = key_manager.unmask_reference_key(check_ref_id)
    if ip_addr is not None:
        assert string_validator.is_valid_ip_address(ip_addr)
    if sender_email is not None:
        assert string_validator.is_email_address(sender_email)

    query = "select * from log_heartbeat(%s, %s, %s, %s);"
    query_params = (unmasked_check_ref, timestamp, ip_addr, sender_email,)
    try:
        rows = conn.fetch(query, query_params)
        if len(rows) == 0:
            return None, None
        first_row = rows[0]
        return first_row[0], first_row[1]
    except psycopg2.IntegrityError as err:
        # Translate the integrity error codes that carry a meaning for callers;
        # anything else is passed on untouched.
        translated = {
            errorcodes.RESTRICT_VIOLATION: PermissionError,
            errorcodes.CHECK_VIOLATION: LookupError
        }.get(err.pgcode)
        if translated is None:
            raise
        raise translated
    except psycopg2.DatabaseError:
        raise


def upload_heartbeat_logs(conn, query_params_list):
    '''
    Send stored heartbeat logs to the database in one batch, one "upload_heartbeat_logs" call per entry.
    :param conn: db connection
    :param query_params_list: (list of tuples) of query parameters; each tuple fills the six placeholders
    :return: (int) number of parameter tuples submitted
    :error: DatabaseError
    '''
    upload_statement = "select upload_heartbeat_logs(%s, %s, %s, %s, %s, %s);"
    try:
        conn.execute_batch(upload_statement, query_params_list)
    except psycopg2.DatabaseError:
        raise
    else:
        return len(query_params_list)
