import torch
import torch.nn as nn
import math
from models.backbone.common import *
from utils.general import make_divisible
from module.nanodet_utils import generate_anchors, dist2bbox
import torch.nn.functional as F


class YoloV6Detect(nn.Module):
    '''Efficient Decoupled Head

    With hardware-aware design, the decoupled head is optimized with
    hybrid-channel methods.

    Every detection level owns five modules taken from
    ``tal_build_effidehead_layer``: a stem, a classification conv + pred and
    a regression conv + pred. ``forward`` has three modes, selected by
    ``self.training`` and ``self.export`` (see ``forward``).
    '''
    def __init__(self, cfg):
        '''Build the head from the project config object.

        Reads ``cfg.Model`` (``width_multiple``, ``Neck.out_channels``,
        ``Neck.num_outs``, ``anchors``, ``Head.in_channels``, ``Head.strides``,
        ``inplace``), ``cfg.Dataset`` (``nc``, ``np``) and ``cfg.Loss``
        (``reg_max``, ``use_dfl``, ``grid_cell_offset``, ``grid_cell_size``).
        '''
        super().__init__()

        width_mul = cfg.Model.width_multiple
        channels_list_neck = cfg.Model.Neck.out_channels

        # ``anchors`` may be given either as a sequence or directly as a count.
        if isinstance(cfg.Model.anchors, (list, tuple)):
            num_anchors = len(cfg.Model.anchors)
        else:
            num_anchors = cfg.Model.anchors
        channels_list = [make_divisible(i * width_mul, 8) for i in channels_list_neck]

        self.nc = cfg.Dataset.nc
        self.no = self.nc + 5  # number of outputs per anchor
        self.nl = cfg.Model.Neck.num_outs  # number of detection layers
        self.num_keypoints = cfg.Dataset.np
        self.na = num_anchors

        self.prune = False
        self.use_l1 = False
        self.export = False

        # Per-level grid cache used by get_output_and_grid; the placeholder
        # entries are replaced (never mutated) on first use.
        self.grids = [torch.zeros(1)] * len(cfg.Model.Head.in_channels)

        self.grid = [torch.zeros(1)] * self.nl
        self.prior_prob = 1e-2
        self.inplace = cfg.Model.inplace
        self.reg_max = cfg.Loss.reg_max
        self.use_dfl = cfg.Loss.use_dfl
        # NOTE: plain tensor attribute, not a registered buffer, so it does
        # not follow the module across .to()/.cuda() calls.
        self.stride = torch.Tensor(cfg.Model.Head.strides)
        # 1x1 conv that turns the (reg_max + 1)-bin softmax into one value.
        # Its weights are only set to 0..reg_max in initialize_biases().
        self.proj_conv = nn.Conv2d(self.reg_max + 1, 1, 1, bias=False)
        self.grid_cell_offset = cfg.Loss.grid_cell_offset
        self.grid_cell_size = cfg.Loss.grid_cell_size

        # Init decouple head
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        self.cls_preds = nn.ModuleList()
        self.reg_preds = nn.ModuleList()
        self.stems = nn.ModuleList()
        head_layers = tal_build_effidehead_layer(channels_list, num_anchors, self.nc, reg_max=self.reg_max)  # detection layer
        assert head_layers is not None

        # Efficient decoupled head layers: the builder returns five modules
        # per level in the order stem, cls_conv, reg_conv, cls_pred, reg_pred.
        for i in range(self.nl):
            idx = i * 5
            self.stems.append(head_layers[idx])
            self.cls_convs.append(head_layers[idx + 1])
            self.reg_convs.append(head_layers[idx + 2])
            self.cls_preds.append(head_layers[idx + 3])
            self.reg_preds.append(head_layers[idx + 4])

    def initialize_biases(self):
        '''Reset the prediction convs and load the DFL projection weights.

        Classification preds get zero weights and a bias of
        ``-log((1 - prior_prob) / prior_prob)``; regression preds get zero
        weights and a bias of 1.0. Each weight/bias is rebound to a fresh
        ``nn.Parameter``. Also creates ``self.proj`` (``0..reg_max``) and
        copies it, frozen, into ``self.proj_conv.weight``.
        '''
        for conv in self.cls_preds:
            b = conv.bias.view(-1, )
            b.data.fill_(-math.log((1 - self.prior_prob) / self.prior_prob))
            conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
            w = conv.weight
            w.data.fill_(0.)
            conv.weight = torch.nn.Parameter(w, requires_grad=True)

        for conv in self.reg_preds:
            b = conv.bias.view(-1, )
            b.data.fill_(1.0)
            conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
            w = conv.weight
            w.data.fill_(0.)
            conv.weight = torch.nn.Parameter(w, requires_grad=True)

        self.proj = nn.Parameter(torch.linspace(0, self.reg_max, self.reg_max + 1), requires_grad=False)
        self.proj_conv.weight = nn.Parameter(self.proj.view([1, self.reg_max + 1, 1, 1]).clone().detach(),
                                             requires_grad=False)

    def get_output_and_grid(self, output, k, stride, dtype):
        '''Flatten one level's raw map and decode its first four channels.

        Args:
            output: tensor viewable as ``(batch, na, no, h, w)``.
            k: level index into the ``self.grids`` cache.
            stride: scale applied to the decoded xy and wh values.
            dtype: argument forwarded to ``Tensor.type`` for the grid
                (``post_process`` passes ``reg_output.type()``).

        Returns:
            ``(output, grid)`` where ``output`` is
            ``(batch, na * h * w, no)`` with channels 0:2 replaced by
            ``(xy + grid) * stride`` and channels 2:4 by ``exp(wh) * stride``,
            and ``grid`` is ``(1, h * w, 2)``.
        '''
        grid = self.grids[k]

        batch_size = output.shape[0]
        hsize, wsize = output.shape[-2:]
        # Rebuild the cached grid whenever the spatial size changes (the
        # initial 1-element placeholder never matches).
        if grid.shape[2:4] != output.shape[2:4]:
            yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)])
            grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype)
            self.grids[k] = grid

        output = output.view(batch_size, self.na, self.no, hsize, wsize)
        output = output.permute(0, 1, 3, 4, 2).reshape(
            batch_size, self.na * hsize * wsize, -1
        )
        grid = grid.view(1, -1, 2)
        output[..., :2] = (output[..., :2] + grid) * stride
        output[..., 2:4] = torch.exp(output[..., 2:4]) * stride

        return output, grid

    def _forward_level(self, x, i):
        '''Run level ``i`` and return its raw ``(cls_output, reg_output)``.

        The stem output is written back into ``x[i]``, so the caller's list
        is modified in place.
        '''
        x[i] = self.stems[i](x[i])
        feat = x[i]
        cls_output = self.cls_preds[i](self.cls_convs[i](feat))
        reg_output = self.reg_preds[i](self.reg_convs[i](feat))
        return cls_output, reg_output

    def _project_dfl(self, reg_output, num_cells):
        '''Collapse the per-side distribution to one value via ``proj_conv``.

        Reshapes to ``(-1, 4, reg_max + 1, num_cells)``, moves the bin axis to
        dim 1, applies softmax over the bins and feeds the result through
        ``self.proj_conv``; the result has a singleton channel dimension.
        '''
        reg_output = reg_output.reshape([-1, 4, self.reg_max + 1, num_cells]).permute(0, 2, 1, 3)
        return self.proj_conv(F.softmax(reg_output, dim=1))

    def forward(self, x):
        '''Run the head on a list of per-level feature maps.

        ``x`` must be an indexable, mutable sequence with at least ``nl``
        4-D tensors; each ``x[i]`` is overwritten with its stem output.

        Returns:
            * training: ``(x, cls_scores, reg_distri)`` where the last two are
              raw predictions flattened to ``(batch, cells, channels)`` and
              concatenated over levels.
            * export (``self.export`` and not training): a list with one
              ``(batch, 4 + 1 + cls_channels, h, w)`` tensor per level, laid
              out as box values, an all-ones objectness channel and sigmoid
              class scores.
            * otherwise: ``(predictions, (x, cls_scores, reg_distri))`` where
              ``predictions`` concatenates the decoded boxes, an all-ones
              objectness column and sigmoid class scores on the last axis.
        '''
        if self.training:
            cls_score_list = []
            reg_distri_list = []

            for i in range(self.nl):
                cls_output, reg_output = self._forward_level(x, i)
                cls_score_list.append(cls_output.flatten(2).permute((0, 2, 1)))
                reg_distri_list.append(reg_output.flatten(2).permute((0, 2, 1)))

            cls_score_list = torch.cat(cls_score_list, dim=1)
            reg_distri_list = torch.cat(reg_distri_list, dim=1)

            return x, cls_score_list, reg_distri_list

        if self.export:
            z = []

            for i in range(self.nl):
                b, _, h, w = x[i].shape
                cls_output, reg_output = self._forward_level(x, i)
                _, _, ori_h, ori_w = reg_output.shape
                if self.use_dfl:
                    reg_output = self._project_dfl(reg_output, h * w)
                reg_output = reg_output.reshape([-1, 4, ori_h, ori_w])

                cls_output = torch.sigmoid(cls_output)
                # This head has no objectness branch; a constant 1 keeps the
                # reg/obj/cls channel layout.
                obj_output = torch.ones((b, 1, ori_h, ori_w), device=cls_output.device, dtype=cls_output.dtype)
                z.append(torch.cat([reg_output, obj_output, cls_output], 1))

            return z

        cls_score_list = []
        reg_distri_list = []
        cls_list = []
        reg_bbox_list = []
        # Anchors are generated from the incoming features, before the stems
        # overwrite x[i].
        anchor_points, stride_tensor = generate_anchors(
            x, self.stride, self.grid_cell_size, self.grid_cell_offset, device=x[0].device, is_eval=True)

        for i in range(self.nl):
            b, _, h, w = x[i].shape
            num_cells = h * w
            cls_output, reg_output = self._forward_level(x, i)
            cls_score_list.append(cls_output.flatten(2).permute((0, 2, 1)))
            reg_distri_list.append(reg_output.flatten(2).permute((0, 2, 1)))

            if self.use_dfl:
                reg_output = self._project_dfl(reg_output, num_cells)

            cls_output = torch.sigmoid(cls_output)
            cls_list.append(cls_output.reshape([b, self.nc, num_cells]))
            reg_bbox_list.append(reg_output.reshape([b, 4, num_cells]))

        cls_score_list = torch.cat(cls_score_list, dim=1)
        reg_distri_list = torch.cat(reg_distri_list, dim=1)
        cls_list = torch.cat(cls_list, dim=-1).permute(0, 2, 1)
        reg_bbox_list = torch.cat(reg_bbox_list, dim=-1).permute(0, 2, 1)

        # NOTE(review): dist2bbox is a project helper; presumably it turns
        # per-side distances into xywh boxes in grid units - confirm.
        pred_bboxes = dist2bbox(reg_bbox_list, anchor_points, box_format='xywh')
        pred_bboxes *= stride_tensor
        feature = (x, cls_score_list, reg_distri_list)
        return (torch.cat(
            [
                pred_bboxes,
                torch.ones((b, pred_bboxes.shape[1], 1), device=pred_bboxes.device, dtype=pred_bboxes.dtype),
                cls_list
            ], dim=-1), feature)

    def decode_outputs(self, outputs, dtype):
        '''Decode flattened predictions in place using per-level grids.

        Requires ``self.hw`` (the per-level ``(h, w)`` list), which is set by
        the evaluation branch of ``post_process``. Channels 0:2 become
        ``(xy + grid) * stride`` and channels 2:4 become ``exp(wh) * stride``.
        ``dtype`` is forwarded to ``Tensor.type`` for the grids and strides.
        '''
        grids = []
        strides = []
        for (hsize, wsize), stride in zip(self.hw, self.stride):
            yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)])
            grid = torch.stack((xv, yv), 2).view(1, -1, 2)
            grids.append(grid)
            shape = grid.shape[:2]
            strides.append(torch.full((*shape, 1), stride))

        grids = torch.cat(grids, dim=1).type(dtype)
        strides = torch.cat(strides, dim=1).type(dtype)

        outputs[..., :2] = (outputs[..., :2] + grids) * strides
        outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides

        return outputs

    def post_process(self, outputs):
        '''Decode per-level ``(reg_output, obj_output, cls_output)`` triples.

        ``outputs[i]`` must unpack into those three tensors for every entry
        of ``self.stride``.

        Returns:
            * training: ``(outputs, None, x_shifts, y_shifts,
              expanded_strides, whwh)`` where ``outputs`` is the concatenation
              of ``get_output_and_grid`` results and ``whwh`` is built from
              the first level's spatial size times ``self.stride[0]``.
            * otherwise: a one-element tuple with the result of
              ``decode_outputs`` on the sigmoid-activated, flattened maps;
              also stores the per-level sizes in ``self.hw``.
        '''
        if self.training:
            h, w = outputs[0][0].shape[2:4]
            h *= self.stride[0]
            w *= self.stride[0]
            x_shifts = []
            y_shifts = []
            expanded_strides = []
            newouts = []
            for i, stride_this_level in enumerate(self.stride):
                reg_output, obj_output, cls_output = outputs[i]
                output = torch.cat([reg_output, obj_output, cls_output], 1)
                output, grid = self.get_output_and_grid(
                    output, i, stride_this_level, reg_output.type()
                )
                x_shifts.append(grid[:, :, 0])
                y_shifts.append(grid[:, :, 1])
                expanded_strides.append(
                    torch.zeros(1, grid.shape[1])
                        .fill_(stride_this_level)
                        .type_as(reg_output)
                )
                newouts.append(output)
            outputs = torch.cat(newouts, 1)
            x_shifts = torch.cat(x_shifts, 1)  # [1, n_anchors_all]
            y_shifts = torch.cat(y_shifts, 1)  # [1, n_anchors_all]
            expanded_strides = torch.cat(expanded_strides, 1)
            whwh = torch.Tensor([[w, h, w, h]]).type_as(outputs)
            return (outputs, None, x_shifts, y_shifts, expanded_strides, whwh)
        else:
            newouts = []
            for i, stride_this_level in enumerate(self.stride):
                reg_output, obj_output, cls_output = outputs[i]
                output = torch.cat(
                    [reg_output, obj_output.sigmoid(), cls_output.sigmoid()], 1
                )
                newouts.append(output)
            outputs = newouts
            # decode_outputs reads these sizes to rebuild the grids.
            self.hw = [out.shape[-2:] for out in outputs]
            outputs = torch.cat(
                [out.flatten(start_dim=2) for out in outputs], dim=2
            ).permute(0, 2, 1)
            outputs = self.decode_outputs(outputs, dtype=outputs.type())
            return (outputs,)

def tal_build_effidehead_layer(channels_list, num_anchors, num_classes, reg_max=16, num_layers=3):
    '''Build the flat layer sequence for the decoupled detection head.

    For each level ``i`` in ``range(num_layers)`` five modules are appended,
    in this order (callers index the result as ``5 * i + offset``):

        0. stem      - ``Conv(c, c, 1, 1)``
        1. cls_conv  - ``Conv(c, c, 3, 1)``
        2. reg_conv  - ``Conv(c, c, 3, 1)``
        3. cls_pred  - ``nn.Conv2d(c, num_classes * num_anchors, 1)``
        4. reg_pred  - ``nn.Conv2d(c, 4 * (reg_max + num_anchors), 1)``

    where ``c = channels_list[i]``.

    Args:
        channels_list: per-level channel counts; must have at least
            ``num_layers`` entries (an IndexError is raised otherwise).
        num_anchors: number of anchors per location.
        num_classes: number of object classes.
        reg_max: regression distribution parameter used for the reg_pred
            channel count.
        num_layers: number of detection levels to build. Defaults to 3,
            which reproduces the previously hard-coded layout.

    Returns:
        An ``nn.Sequential`` holding ``5 * num_layers`` modules.
    '''
    layers = []
    for level in range(num_layers):
        channels = channels_list[level]
        layers.extend([
            # stem
            # NOTE(review): Conv is a project module; the trailing positional
            # args are presumably kernel size and stride - confirm.
            Conv(channels, channels, 1, 1),
            # cls_conv
            Conv(channels, channels, 3, 1),
            # reg_conv
            Conv(channels, channels, 3, 1),
            # cls_pred
            nn.Conv2d(
                in_channels=channels,
                out_channels=num_classes * num_anchors,
                kernel_size=1
            ),
            # reg_pred: 4 * (reg_max + num_anchors) equals 4 * (reg_max + 1)
            # only when num_anchors == 1, which is what YoloV6Detect's
            # reshape to (-1, 4, reg_max + 1, l) expects.
            nn.Conv2d(
                in_channels=channels,
                out_channels=4 * (reg_max + num_anchors),
                kernel_size=1
            ),
        ])
    return nn.Sequential(*layers)