import pandas as pd
import numpy as np
import colorlover as cl
import dash_html_components as html


# TODO: Get these directly from the Google Sheets
#master_csv = 'assets/XJETS-master - Sheet1.csv'
master_csv = 'assets/new_master_list.csv'
col_attr = 'assets/XJETS-master - col_name_attributes_new.csv'
data_availability = 'assets/data_downloads/data_info.csv'
citations_file = 'assets/all_refs.csv'


def generate_colorscale(values, colors):
    """
    Generate a color scale for int type categories. Each category will be at the center of a bin.
    The will generate a discrete color bar but may have unequal bin widths
    """
    max_val = np.max(values)
    min_val = np.min(values)
    if len(values) > len(colors):
        # Not enough colors to cover every category; fail loudly rather than
        # returning None, which the caller could not unpack.
        raise ValueError('Length of values is greater than length of colors')

    nvals = len(values)
    if min_val != max_val:
        norm_values = [(i-min_val)/(max_val-min_val) for i in values]
    else:
        norm_values = [1]
    colorscale = []
    ticklocations = []

    # Each color band starts where the previous one ended (the first starts at
    # the lowest normalized value) and extends halfway towards the next
    # normalized value; the last band ends at 1. Two colorscale entries per
    # band give Plotly a step-wise, discrete-looking colorbar.
    for i in range(nvals):
        if i == 0:
            lb = norm_values[i]
        else:
            lb = prev_ub

        if i != nvals - 1:
            ub = (lb + norm_values[i + 1]) / 2
        else:
            ub = 1

        prev_ub = ub
        colorscale.append([lb, colors[i]])
        colorscale.append([ub, colors[i]])
        ticklocations.append((lb+ub)*(max_val-min_val)/2)

    return colorscale, ticklocations
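
# Illustrative example (not executed in this module): for values [0, 1, 2] and
# colors ['a', 'b', 'c'], the normalized values are [0, 0.5, 1] and
# generate_colorscale returns
#   colorscale    = [[0, 'a'], [0.25, 'a'], [0.25, 'b'], [0.625, 'b'],
#                    [0.625, 'c'], [1, 'c']]
#   ticklocations = [0.25, 0.875, 1.625]
# i.e. one color band per category, with each tick at the center of its band,
# rescaled to the data range (max - min).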


class datastore:
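    """Loads the master source list and its companion CSVs (column attributes,
    data availability, references) and exposes filtered views plus
    Plotly/Dash-ready styling helpers."""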
    def __init__(self):
        try:
            self.master_df = pd.read_csv(master_csv)
            self.orig_columns = self.master_df.columns
            self.total_sources = len(self.master_df.index)
            self.col_names = self.master_df.columns
            self.refs_df = pd.read_csv(citations_file).fillna('')
            #print(self.col_names)
            self.data_availability = pd.read_csv(data_availability)
            self.master_df = self.master_df.merge(
                self.data_availability, how='left', on='Name')
            self.col_attr_df = pd.read_csv(col_attr).set_index('attribute').T
            self.state = 'good'
            self.endpoint_df = self.master_df.copy(deep=True)
        except (FileNotFoundError, ValueError) as verr:
            print(verr)
            self.state = 'bad'

    def get_orig_columns(self):
        return self.orig_columns

    def get_master_data(self):
        return self.master_df

    def _filter_super_luminal_sources(self):
        self.endpoint_df = self.endpoint_df[self.endpoint_df['bapp3'] > 1]

    def _filter_xray_upstream_offset(self):
        self.endpoint_df = self.endpoint_df[self.endpoint_df['X-ray first'] > 0]

    def _filter_xray_downstream_offset(self):
        self.endpoint_df = self.endpoint_df[self.endpoint_df['Radio first'] > 0]

    def _filter_xray_upstream_downstream_offsets(self):
        self.endpoint_df = self.endpoint_df[
            (self.endpoint_df['Radio first'] > 0) & (
                self.endpoint_df['X-ray first'] > 0)
        ]

    def get_filtered_data(self):
        return self.endpoint_df

    def reset_filters(self):
        self.endpoint_df = self.master_df.copy(deep=True)

    def get_valid_scales(self):
        return self.col_attr_df[self.col_attr_df['use_as_scale'] == 1]

    def get_valid_categories(self):
        return self.col_attr_df[self.col_attr_df['use_as_cat'] == 1]

    def get_valid_axes(self):
        return self.col_attr_df[self.col_attr_df['valid_axis'].astype('int') == 1]

    def get_col_attributes(self):
        return self.col_attr_df

    def get_filters(self):
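        # Each 'value' below names one of the _filter_* methods above; the
        # consuming callback is presumably expected to apply a selected filter
        # via getattr(self, value)() (assumed usage, not exercised here).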
        return [
            {'label': 'Super luminal sources',
                'value': '_filter_super_luminal_sources'},
            # {'label': 'Sources with upstream X-ray peak',
            #  'value': '_filter_xray_upstream_offset'},
            # {'label': 'Sources with downstream X-ray peak',
            #  'value': '_filter_xray_downstream_offset'},
            # {'label': 'Sources with both kind of offsets',
            #  'value': '_filter_xray_upstream_downstream_offsets'}
        ]

    def get_scaling_vars(self):
        scales_df = self.col_attr_df[self.col_attr_df['use_as_scale'].astype(
            'int') == 1]
        options = [
            {
                'label': row['pretty_name'],
                'value': i,
                'primaryText': row['pretty_name']
            } for i, row in scales_df.iterrows()
        ]
        options.insert(0, {
            'label': 'No scaling',
            'value': None,
            'primaryText': 'No scaling'
        })

        return options

    def get_marker_scaling_options(self, scale_using):
        scales = self.endpoint_df[scale_using].fillna(0.1)
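        # Plotly's documented rule of thumb for 'area' size scaling:
        # sizeref = 2 * max(size values) / desired_maximum_marker_size ** 2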
        max_desired_marker_size = 40
        sizeref = 2 * scales.max() / max_desired_marker_size ** 2

        return {
            'size': scales.tolist(),
            'sizemode': 'area',
            'sizemin': 4,
            'sizeref': sizeref,
            'color': None
        }

    def get_categorical_vars(self):
        cat_df = self.col_attr_df[self.col_attr_df['use_as_cat'].astype(
            'int') == 1]
        options = [
            {
                'label': row['pretty_name'],
                'value': i,
                'primaryText': row['pretty_name']
            } for i, row in cat_df.iterrows()
        ]
        options.insert(0, {'label': 'No Category', 'value': None,
                           'primaryText': 'No Category'})

        return options

    def get_cat_color(self, cat_using):
        cat_values = self.endpoint_df[cat_using].fillna(0)
        # print(cat_values)

        # Use the qualitative Dark2 scale for up to 3 categories and the
        # divergent Spectral scale otherwise. Assumes the number of integer
        # categories is at most 11 (the largest Spectral scale in colorlover);
        # non-integer (continuous) variables are left uncolored below.

        if (cat_values % 1 == 0).all():
            # sequential cat
            unique_vals = np.sort(cat_values.unique())
            min_unique_val = np.min(unique_vals)
            n_unique_vals = len(unique_vals)

            if n_unique_vals < 1:
                return {
                    'color': None
                }
            elif n_unique_vals <= 3:
                colors = cl.scales['3']['qual']['Dark2']
            else:
                colors = cl.scales[str(n_unique_vals)]['div']['Spectral']

            colorscale, ticklocations = generate_colorscale(
                unique_vals, colors)

            # temporary hack
            # TODO: Fix the issue when there is only one unique value

            if n_unique_vals == 1:
                ticklocations = [0]

            return {
                'color': [i-min_unique_val for i in cat_values.tolist()],
                'colorscale': colorscale,
                'colorbar': {
                    'title': self.col_attr_df.loc[cat_using]['pretty_name'],
                    'titleside': 'top',
                    'tickmode': 'array',
                    'tickvals': ticklocations,
                    'ticktext': [str(int(i)) for i in unique_vals],
                }
            }
        else:
            return {
                'color': None
            }
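
    # Note: get_marker_scaling_options() and get_cat_color() each return a
    # partial Plotly 'marker' dict; a caller is expected to merge them, e.g.
    # (illustrative sketch only, with 'bapp3' standing in for any valid column):
    #   marker = {**ds.get_marker_scaling_options('bapp3'),
    #             **ds.get_cat_color('bapp3')}
    #   fig.add_trace(go.Scatter(x=..., y=..., mode='markers', marker=marker))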

    def get_source_table(self, source):
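        # Render a two-column HTML table: one row per catalogue column, pairing
        # the human-readable column name with this source's value.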
        row = self.endpoint_df[self.endpoint_df['Name']
                               == source].iloc[0].fillna('NA')
        return html.Table(
            # [
            #    html.Tr(
            #        [
            #            html.Th(self.col_attr_df.loc[col]['pretty_name']) for col in self.col_attr_df.index
            #        ]
            #    )
            # ]
            # +
            [
                html.Tr(
                    [
                        html.Td(self.col_attr_df.loc[col]['pretty_name']),
                        html.Td(row[col])

                    ]
                ) for col in self.col_names
            ], className=''
        )

    def get_summary_data(self):
        return {
            'total_resolved_xray_knots': self.master_df['#resolved X_ray knots'].sum(),
            'total_knots_with_offsets': self.master_df['#knots with offsets'].sum(),
            'total_knots_without_offsets': self.master_df['#knots without offsets'].sum(),
            'total_resolved_xray_HS': self.master_df['#HS  Xrays'].sum(),
            'total_HS_with_offsets': self.master_df['#HS with offsets'].sum(),
            'total_HS_without_offsets': self.master_df['#HS without offsets'].sum(),
            'xray_first_offsets': self.master_df['X-ray first'].sum(),
            'radio_first_offsets': self.master_df['Radio first'].sum(),
            'total_xray_jet_terminates_at_a_bend': int(self.master_df['X-ray jet terminates before radio after a bend?'].sum()),
            'total_jets': self.total_sources
        }

    def get_formatted_refs(self, source):
        print('src', source)
        refs = self.refs_df.query(f'Name2=="{source}"')
        # print(refs)
        if refs.shape[0] == 0:
            return ''
        ref_list = []

        def get_apa_cite(r):
            num = ''
            if r['number'] != '':
                num = f'({int(r["number"])})'
            return f"{r['author']} ({r['year']}), {r['journal2']}, {r['volume']}{num}, p.{r['pages']}"

        for index, row in refs.iterrows():
            # print(row)
            ref_list.append(html.Li(
                [
                    get_apa_cite(row), html.Br(), f'"{row["title"]}"', html.Br(),
                    html.A('[ADS]', href=row['link'], target="_blank")
                ]
            ))

        print('ref_l', ref_list)

        return html.Ol(ref_list)



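# Minimal smoke-test sketch: assumes the asset CSVs referenced at the top of
# this module exist relative to the working directory; the printed summary
# keys come straight from get_summary_data().
if __name__ == '__main__':
    ds = datastore()
    if ds.state == 'good':
        print('Loaded', ds.total_sources, 'sources')
        print(ds.get_summary_data())
    else:
        print('Failed to load the master CSVs; check the paths in assets/.')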