# Copyright 2020, Sophos Limited. All rights reserved.
#
# 'Sophos' and 'Sophos Anti-Virus' are registered trademarks of
# Sophos Limited and Sophos Group. All other product and company
# names mentioned are trademarks or registered trademarks of their
# respective owners.
import os

import numpy as np
import torch
from ember import PEFeatureExtractor
from secml.array import CArray
from secml.ml import CClassifier
from secml.settings import SECML_PYTORCH_USE_CUDA
from torch import nn

# Use the GPU only when CUDA is available AND secml's SECML_PYTORCH_USE_CUDA
# setting enables it. Read when loading the model and when moving inputs
# to the device in CClassifierSorel below.
use_cuda = torch.cuda.is_available() and SECML_PYTORCH_USE_CUDA


class SorelNet(nn.Module):
	"""
	Simple feed-forward network loosely based on the one used in ALOHA: Auxiliary Loss
	Optimization for Hypothesis Augmentation (https://arxiv.org/abs/1903.05700).

	It uses fewer (and smaller) layers than ALOHA, and a single shared head for all tag
	predictions, so performance will suffer accordingly.

	Args:
		use_malware: if True, ``forward`` returns the sigmoid malware score under 'malware'.
		use_counts: if True, ``forward`` returns the raw (un-squashed) count-head output
			under 'count'.
		use_tags: if True, ``forward`` returns per-tag sigmoid scores under 'tags'.
		n_tags: number of tag outputs. Required when ``use_tags`` is True. When it is None,
			the tag head is not built at all.
		feature_dimension: size of the input feature vector.
		layer_sizes: widths of the hidden layers of the shared base
			(default ``[512, 512, 128]``). Must contain at least one entry.

	Raises:
		ValueError: if ``use_tags`` is True and ``n_tags`` is None, or if ``layer_sizes``
			is empty.
	"""

	def __init__(self, use_malware=True, use_counts=True, use_tags=True, n_tags=None, feature_dimension=2381,
				 layer_sizes=None):
		# nn.Module's internal bookkeeping must exist before any attribute is assigned.
		super().__init__()
		if use_tags and n_tags is None:
			raise ValueError("n_tags was None but we're trying to predict tags. Please include n_tags")
		self.use_malware = use_malware
		self.use_counts = use_counts
		self.use_tags = use_tags
		self.n_tags = n_tags
		p = 0.05  # dropout probability of every hidden layer
		if layer_sizes is None:
			layer_sizes = [512, 512, 128]
		if len(layer_sizes) == 0:
			raise ValueError("layer_sizes must contain at least one hidden layer size")
		# Each hidden layer is Linear -> LayerNorm -> ELU -> Dropout. This order fixes the
		# state_dict keys (model_base.0, model_base.1, ...), so it must match saved checkpoints.
		in_sizes = [feature_dimension] + list(layer_sizes[:-1])
		layers = []
		for in_size, out_size in zip(in_sizes, layer_sizes):
			layers.extend([nn.Linear(in_size, out_size), nn.LayerNorm(out_size), nn.ELU(), nn.Dropout(p)])
		self.model_base = nn.Sequential(*layers)
		last = layer_sizes[-1]
		self.malware_head = nn.Sequential(nn.Linear(last, 1),
										  nn.Sigmoid())
		self.count_head = nn.Linear(last, 1)
		# Not used by forward(); kept so existing attribute access keeps working.
		self.sigmoid = nn.Sigmoid()
		if n_tags is not None:
			# Built whenever n_tags is known, even if use_tags is False, so a checkpoint
			# containing tag-head weights can still be loaded strictly.
			self.tag_head = nn.Sequential(nn.Linear(last, 64),
										  nn.ELU(),
										  nn.Linear(64, 64),
										  nn.ELU(),
										  nn.Linear(64, n_tags),
										  nn.Sigmoid())
		else:
			# Without n_tags the output size is undefined. use_tags is False here (checked above).
			self.tag_head = None
		# Start in inference mode (disables dropout) for this module and all submodules.
		self.eval()

	def forward(self, data):
		"""
		Run the shared base once and apply every enabled head.

		Args:
			data: input feature tensor whose last dimension is ``feature_dimension``.

		Returns:
			dict with the keys 'malware', 'count' and/or 'tags', depending on the
			use_* flags given at construction time.
		"""
		base_result = self.model_base(data)
		rv = {}
		if self.use_malware:
			rv['malware'] = self.malware_head(base_result)
		if self.use_counts:
			rv['count'] = self.count_head(base_result)
		if self.use_tags:
			rv['tags'] = self.tag_head(base_result)
		return rv


class CClassifierSorel(CClassifier):
	"""
	secml wrapper around a pre-trained SorelNet, usable for inference only.

	The scores returned by ``_forward`` have the columns
	[goodware, malware, count (if use_counts), tags... (if use_tags)].

	Args:
		model_path: path to a saved SorelNet ``state_dict``.
		use_counts: include the count-head output in the scores.
		use_tags: include the per-tag scores in the scores.
		n_tags: number of tags of the checkpoint. It is forwarded to SorelNet even when
			``use_tags`` is False, presumably so a checkpoint with tag-head weights still loads.
		feature_dimension: size of the feature vector the network expects.
		layer_sizes: hidden-layer widths, forwarded to SorelNet.

	Raises:
		FileNotFoundError: if ``model_path`` is not an existing file.
	"""

	def __init__(self, model_path, use_counts=True, use_tags=True, n_tags=11,
				 feature_dimension=2381, layer_sizes=None):
		super().__init__()
		# Fail fast, before building the network, when the checkpoint is missing.
		if not os.path.isfile(model_path):
			raise FileNotFoundError(f'{model_path} not exists')
		self._sorel = SorelNet(True, use_counts, use_tags, n_tags, feature_dimension, layer_sizes)
		self._use_tags = use_tags
		self._use_counts = use_counts
		self.load_model(model_path)

	def extract_features(self, x: CArray):
		"""
		Extract PE features from the raw bytes of one sample.

		Args:
			x: integer byte values (0-255) of a PE file, either a flat vector or a
				single row. Only the first row is used when x has several rows.

		Returns:
			a single-row CArray with the float32 feature vector computed by
			ember's PEFeatureExtractor (feature version 2).
		"""
		extractor = PEFeatureExtractor(2, print_feature_warning=False)
		# atleast_2d accepts both a flat vector and a (1, n) row. The builtin int is used
		# because the np.int alias was removed in NumPy 1.24.
		row = np.atleast_2d(x.tondarray())[0]
		x_bytes = bytes(row.astype(int).tolist())
		features = CArray([np.array(extractor.feature_vector(x_bytes), dtype=np.float32)])
		return features

	def load_model(self, model_path):
		"""
		Load the SorelNet weights and set the secml metadata.

		The network is moved to the GPU when ``use_cuda`` is set and switched to eval mode.

		Returns:
			self
		"""
		# NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
		state_dict = torch.load(model_path, map_location=None if use_cuda else 'cpu')
		self._sorel.load_state_dict(state_dict)
		if use_cuda:
			self._sorel.cuda()
		self._sorel.eval()
		# Number of score columns produced by _forward: goodware + malware, then optional count and tags.
		self._classes = 2
		if self._use_counts:
			self._classes += 1
		if self._use_tags:
			self._classes += self._sorel.n_tags
		# Read the input width from the first Linear layer instead of hard-coding 2381,
		# so a non-default feature_dimension is reported correctly.
		self._n_features = self._sorel.model_base[0].in_features
		return self

	def _fit(self, dataset, **kwargs):
		raise NotImplementedError("Fit is not implemented.")

	def _backward(self, w):
		raise NotImplementedError("Backward is not implemented.")

	def _forward(self, x: CArray):
		"""
		Score the feature vectors in x.

		Returns:
			CArray whose columns are [1 - malware, malware, count?, tags...]. The count
			column is the raw count-head output and is not a probability.
		"""
		x = torch.tensor(x.tondarray()).float()
		if use_cuda:
			x = x.cuda()
		# Inference only (_backward is not implemented), so no autograd graph is needed.
		with torch.no_grad():
			rv = self._sorel(x)
		malware_score = rv['malware']
		confidence = [1 - malware_score, malware_score]
		if 'count' in rv:
			confidence.append(rv['count'])
		if 'tags' in rv:
			confidence.append(rv['tags'])
		# .cpu() is required: .numpy() cannot convert a tensor that lives on the GPU.
		return CArray(torch.hstack(confidence).cpu().numpy())
