from ipyleaflet import Map, basemaps, AntPath,  Heatmap, LayersControl, basemap_to_tiles, TileLayer, Marker, MarkerCluster, VectorTileLayer, SearchControl, AwesomeIcon, WKTLayer, WidgetControl, LayerGroup, SplitMapControl, DrawControl
from ipywidgets import HTML
from bs4 import BeautifulSoup
import requests

class CustomWKTLayer(WKTLayer):
    """WKT layer whose GeoJSON features carry caller-supplied properties.

    Modified from the original WKT class
    (https://ipyleaflet.readthedocs.io/en/latest/_modules/ipyleaflet/leaflet.html#WKTLayer):
    instead of empty properties, every generated feature gets a copy of
    ``annotation`` as its ``properties``, so mouse-event callbacks that read
    ``feature['properties']`` can see it.

    Parameters
    ----------
    annotation: dict, default None
      Properties attached to each generated feature. ``None`` means ``{}``.
    **kwargs
      Forwarded to ``WKTLayer`` (e.g. ``path``, ``wkt_string``, ``style``).

    Attributes
    ----------
    path: string, default ""
      File path of a local WKT file.
    wkt_string: string, default ""
      WKT string.
    """

    def __init__(self, annotation=None, **kwargs):
        super().__init__(**kwargs)
        # Replace the data computed by the base class with the annotated version.
        self.data = self._get_data1(annotation)

    def _get_data1(self, annotation):
        """Parse the WKT source and return it as GeoJSON annotated with ``annotation``.

        Returns a ``FeatureCollection`` (one feature per member geometry) for a
        ``GeometryCollection``, and a single ``Feature`` otherwise.

        Raises
        ------
        RuntimeError
            If shapely is not installed.
        ValueError
            If neither ``path`` nor ``wkt_string`` is set.
        """
        try:
            from shapely import geometry, wkt
        except ImportError as err:
            raise RuntimeError("The WKTLayer needs shapely to be installed, please run `pip install shapely`") from err

        if self.path:
            with open(self.path) as f:
                parsed_wkt = wkt.load(f)
        elif self.wkt_string:
            parsed_wkt = wkt.loads(self.wkt_string)
        else:
            raise ValueError("Please provide either WKT file path or WKT string")

        properties = {} if annotation is None else annotation

        def as_feature(geom):
            # Shallow copy so features do not share one properties object and
            # later mutations of the caller's dict do not leak into the layer.
            return {"geometry": geom, "properties": dict(properties), "type": "Feature"}

        geo = geometry.mapping(parsed_wkt)
        if geo["type"] == "GeometryCollection":
            features = [as_feature(g) for g in geo["geometries"]]
            return {"type": "FeatureCollection", "features": features}
        return as_feature(geo)

class Build_Map:
    """Build a flatmap in a Jupyter notebook from a SPARC flatmap server.

    The server's root URL returns an index of models; ``tag`` selects one of
    them by position (after keeping only entries that have both an ``'id'``
    and a ``'name'``).

    Attributes
    ----------
    server_url: str
      Base URL of the flatmap server.
    tag: int
      Position of the model in the server's filtered index. The mapping depends
      on the server; the author listed it as:
      0: 'vagus_test',
      1: 'whole-human',
      2: 'whole-rat',
      3: 'whole-pig',
      4: 'whole-rat',
      5: 'whole-rat',
      6: 'whole-rat',
      7: 'whole-rat',
      8: 'whole-mouse',
      9: 'whole-cat',
      10: 'whole-human',
      11: 'whole-rat'
    """

    def __init__(self, server_url, tag):
        self.server_url = server_url
        self.tag = tag

    @staticmethod
    def _named_models(url_dict):
        """Return the index entries that have both a ``'name'`` and an ``'id'``.

        Both the id list and the name list are derived from this single filter,
        so they stay index-aligned.
        """
        return [model for model in url_dict if 'name' in model and 'id' in model]

    def _selected_model(self):
        """Return the index entry selected by ``self.tag`` (one HTTP request)."""
        return self._named_models(self.get_url_dict())[self.tag]

    def get_url_dict(self):
        """Fetch and return the server's JSON index of models.

        Raises
        ------
        RuntimeError
            If ``server_url`` is empty or None.
        requests.HTTPError
            If the server answers with an HTTP error status.
        """
        # requests never raises NameError, so validate the URL up front to make
        # the "missing URL" error reachable.
        if not self.server_url:
            raise RuntimeError("Please provide a flatmap URL")
        req = requests.get(self.server_url)
        req.raise_for_status()
        return req.json()

    def get_model_id(self):
        """Return the list of model ids listed by the server."""
        return [model['id'] for model in self._named_models(self.get_url_dict())]

    def get_model_name(self):
        """Return the list of model names listed by the server."""
        return [model['name'] for model in self._named_models(self.get_url_dict())]

    def get_model_layer_url(self, model_ids, model_name):
        """Return the ``/layers`` endpoint URL for each model id.

        One URL is produced per (id, name) pair, as with ``zip``; the names
        themselves do not appear in the URLs.
        """
        return [f'{self.server_url}/flatmap/{model_id}/layers'
                for model_id, _ in zip(model_ids, model_name)]

    def get_model_image_layer(self, model_layer_url_list):
        """Return one ``{layer_id: image_layers}`` dict per layers URL.

        Only the first layer listed at each URL is used.
        """
        model_image_layers = []
        for url in model_layer_url_list:
            req = requests.get(url)
            req.raise_for_status()
            first_layer = req.json()[0]
            model_image_layers.append({first_layer['id']: first_layer['image-layers']})
        return model_image_layers

    def get_tile_urls(self):
        """Return the tile URLs and index metadata of the selected model.

        Returns
        -------
        tuple
            ``(tile_urls, vector_tile_url, index_json)``: one raster tile URL
            template per image layer, the vector-tile URL template, and the
            model's JSON index.
        """
        model = self._selected_model()
        model_id, model_name = model['id'], model['name']
        base_url = f'{self.server_url}/flatmap/{model_id}'

        # Only the selected model's layers are needed, so fetch just those
        # rather than one request per model on the server.
        layer_url = self.get_model_layer_url([model_id], [model_name])[0]
        image_layers = self.get_model_image_layer([layer_url])[0][model_name]

        tile_urls = [f'{base_url}/tiles/{layer}/{{z}}/{{x}}/{{y}}' for layer in image_layers]
        vector_tile_url = f'{base_url}/mvtiles/{{z}}/{{x}}/{{y}}'

        # 'application/json' is the registered JSON media type ('json' is not).
        req = requests.get(base_url, headers={'Accept': 'application/json'})
        req.raise_for_status()
        index_json = req.json()

        return tile_urls, vector_tile_url, index_json

    def get_annotations(self):
        """Return the selected model's annotations as decoded JSON."""
        model_id = self._selected_model()['id']
        req = requests.get(f'{self.server_url}/flatmap/{model_id}/annotations')
        req.raise_for_status()
        return req.json()

    def split_map(self, second_model_number):
        """Return a map split between this model and another one on the same server.

        Parameters
        ----------
        second_model_number: int
          ``tag`` of the model shown on the right-hand side.
        """
        left_layer_urls, _, _ = self.get_tile_urls()
        left_layers = [TileLayer(url=url) for url in left_layer_urls]

        second_model = Build_Map(self.server_url, second_model_number)
        right_layer_urls, _, _ = second_model.get_tile_urls()
        right_layers = [TileLayer(url=url) for url in right_layer_urls]

        control = SplitMapControl(left_layer=left_layers, right_layer=right_layers)

        split = Map(basemap=left_layers[0])
        for layer in left_layers[1:]:
            split.add(layer)
        split.add(control)
        return split

    def build_map_without_markers(self):
        """Build the flatmap by stacking its raster layers and its vector-tile layer.

        The first image layer is the basemap; the remaining image layers and a
        styled vector-tile layer are added on top, and the view is fitted to
        the model's bounds.
        """
        tile_urls, vector_tile_url, index_json = self.get_tile_urls()
        min_zoom = index_json['min-zoom']

        flatmap = Map(
            basemap=TileLayer(url=tile_urls[0], min_zoom=min_zoom),
            min_zoom=min_zoom,
            max_zoom=index_json['max-zoom'],
            zoom=5,
            scroll_wheel_zoom=True,
            dragging=True,
            attribution_control=False,
            zoom_snap=False,
        )

        vector_tiles_styles = dict(
            fill="true",
            weight=1,
            fillColor="#f2b648",
            color="#f2b648",
            fillOpacity=0.2,
            opacity=0.8,
        )

        # Remaining image layers go on top of the basemap.
        for tile_url in tile_urls[1:]:
            flatmap.add(TileLayer(url=tile_url, min_zoom=min_zoom))

        flatmap.add(VectorTileLayer(url=vector_tile_url,
                                    vector_tile_layer_styles=vector_tiles_styles))

        # Reorders the server's bounds into [[b1, b0], [b3, b2]]. This presumably
        # converts [west, south, east, north] into Leaflet's
        # [[south, west], [north, east]] — TODO confirm.
        bounds = index_json['bounds']
        flatmap.fit_bounds([[bounds[1], bounds[0]], [bounds[3], bounds[2]]])

        return flatmap

    
class leaflet_addons(Build_Map):

    """ Methods for interacting with flatmaps using leaflet add-ons  - Markers, 
    custom drawings, layer_control, antpath, heatmap, search bar and hover capabilities.

    Attributes
    ----------
    build_map_object: A Build_Map object 
      Object of the class Build_Map 
    const_map: Flatmap
      The flatmap to be interacted with   

    """

    def __init__(self, build_map_obj, const_map):
        """Attach add-on helpers to an existing flatmap.

        Parameters
        ----------
        build_map_obj: Build_Map
          Supplies the server URL, model tag and annotations.
        const_map: Map
          The map that the add-on layers and controls are added to.
        """
        # Initialise the Build_Map base from the wrapped object so that methods
        # inherited from Build_Map can find server_url/tag on this instance too.
        super().__init__(getattr(build_map_obj, 'server_url', None),
                         getattr(build_map_obj, 'tag', None))
        self.build_map_obj = build_map_obj
        self.const_map = const_map

    def add_markers(self):
        """Add one marker per annotation to the map, grouped in a MarkerCluster.

        Annotations that have a ``'label'`` get it as the marker's title.
        """
        cluster_members = []
        for annotation in self.build_map_obj.get_annotations().values():
            c0, c1 = annotation['centroid']
            # Centroid coordinates are reversed relative to the marker's
            # location order, so the two components are swapped.
            title_option = {'title': annotation['label']} if 'label' in annotation else {}
            cluster_members.append(Marker(location=[c1, c0], **title_option))
        self.const_map.add_layer(MarkerCluster(markers=cluster_members))

    def add_custom_drawings(self):
        """Add a DrawControl with custom colours for each drawable shape type."""

        def solid_fill(colour):
            # Fill and outline share one colour and are fully opaque.
            return {"shapeOptions": {"fillColor": colour, "color": colour, "fillOpacity": 1.0}}

        drawer = DrawControl()
        drawer.polyline = {"shapeOptions": {"color": "#6bc2e5", "weight": 8, "opacity": 1.0}}
        drawer.polygon = {
            **solid_fill("#6be5c3"),
            "drawError": {"color": "#dd253b", "message": "Oups!"},
            # Self-intersecting polygons are rejected with the error above.
            "allowIntersection": False,
        }
        drawer.circle = solid_fill("#efed69")
        drawer.rectangle = solid_fill("#fca45d")

        self.const_map.add_control(drawer)

    def add_layers_control(self):
        """Add a LayersControl in the top-right corner of the map."""
        self.const_map.add_control(LayersControl(position='topright'))

    def add_search_bar(self):
        """Add a SearchControl in the top-left corner of the map.

        NOTE(review): the author marked this as not working. The control queries
        OpenStreetMap Nominatim, whose results are real-world places and
        presumably do not correspond to flatmap coordinates — TODO confirm.
        """
        found_icon = AwesomeIcon(name="check", marker_color='green', icon_color='darkred')
        search = SearchControl(
            position="topleft",
            url='https://nominatim.openstreetmap.org/search?format=json&q={s}',
            zoom=4,
            property_name='name',
            marker=Marker(icon=found_icon),
        )
        self.const_map.add_control(search)

    def add_antpath(self, locations, use='polyline'):
        """Add an animated AntPath through ``locations`` to the map.

        ``use`` is forwarded to AntPath unchanged (default ``'polyline'``).
        """
        animation_style = dict(
            dash_array=[1, 10],
            delay=1000,
            color='#7590ba',
            pulse_color='#3f6fba',
        )
        self.const_map.add_layer(AntPath(locations=locations, use=use, **animation_style))

    def add_heat_map(self, locations, radius = 30):
        """Add a heat map built from ``locations`` to the map.

        Parameters
        ----------
        locations: list
          Points passed to ipyleaflet's Heatmap (presumably ``[lat, lng]`` or
          ``[lat, lng, intensity]`` entries — confirm against the Heatmap docs).
        radius: int or float, default 30
          Radius of each point's influence, forwarded to Heatmap.
        """
        # Forward the caller's radius instead of a hard-coded 30.
        heatmap = Heatmap(locations=locations, radius=radius)
        self.const_map.add_layer(heatmap)

    def hover(self):
        """Show the label of the region under the mouse in a bottom-right panel.

        An invisible rectangular CustomWKTLayer is created from the ``'bounds'``
        of every labelled annotation; hovering one writes that annotation's
        ``'label'`` into the panel.
        """
        banner = HTML('Hover over regions')
        banner.value = '<h3>Hover over regions</h3>'
        banner.layout.margin = '0px 20px 20px 20px'
        self.const_map.add_control(WidgetControl(widget=banner, position='bottomright'))

        # ipyleaflet passes the event data as keyword arguments, so the names
        # `event` and `feature` must be kept.
        def show_label(event, feature, **kwargs):
            # The feature's properties carry the annotation given to CustomWKTLayer.
            banner.value = f"<h3>{feature['properties']['label']}</h3>"

        region_layers = []
        for annotation in self.build_map_obj.get_annotations().values():
            if 'label' not in annotation:
                continue
            # WKT writes points as "x y", so bounds is read as [x0, y0, x1, y1];
            # presumably [west, south, east, north] — TODO confirm.
            x0, y0, x1, y1 = annotation['bounds']
            rectangle = f'POLYGON(({x0} {y0},{x0} {y1},{x1} {y1},{x1} {y0},{x0} {y0}))'

            region = CustomWKTLayer(
                annotation,
                wkt_string=rectangle,
                style={'opacity': 0, 'fillOpacity': 0},
            )
            region.on_hover(show_label)
            region_layers.append(region)

        self.const_map.add_layer(LayerGroup(layers=region_layers))

    def add_all_leaflet_functions(self):
        """Apply several add-ons to ``const_map`` in one call.

        Adds, in order, the annotation markers, the drawing toolbar, the layers
        control and the search bar.
        """
        self.add_markers()
        self.add_custom_drawings()
        self.add_layers_control()
        # NOTE(review): add_search_bar is marked "NOT WORKING" in its own body —
        # confirm whether it should run by default.
        self.add_search_bar()