import os
import time
from typing import Optional

import dbt.exceptions
from dbt.adapters.setu.client import SetuClient
from dbt.events import AdapterLogger
from dbt.adapters.setu.constants import VALID_STATEMENT_KINDS
from dbt.adapters.setu.models import StatementKind, Output, StatementState, Statement
from dbt.adapters.setu.utils import (
    get_dataframe_from_json_output,
    polling_intervals,
    waiting_for_output,
    get_data_from_json_output,
)

logger = AdapterLogger("Spark")


class SetuStatementCursor:
    """
    Manage a SETU statement and high-level interactions with it.

    The cursor submits code to a SETU session, polls until the statement is no
    longer waiting for output, and exposes the resulting rows / columns.

    :param client: setu client for managing statements
    :param session_id: setu session ID
    """

    # Statement states in which the statement has not finished yet.
    _PENDING_STATES = (StatementState.WAITING, StatementState.RUNNING)

    def __init__(self, client: SetuClient, session_id: int):
        self.session_id: int = session_id
        self.client: SetuClient = client
        # Most recently created / refreshed statement; None until execute() runs.
        self.statement: Optional[Statement] = None

    def description(self):
        """
        Return the column names of the current statement's result.

        Calls ``fetchall`` first (which polls a still-pending statement), then
        builds a dataframe from the statement's JSON output and returns one
        single-element list per column, e.g. ``[["id"], ["name"]]``.

        :raises dbt.exceptions.RuntimeException: if no statement has been
            executed or the statement produced no output.
        """
        self.fetchall()
        pandas_df = get_dataframe_from_json_output(self.statement.output.json)
        return [[column] for column in pandas_df.columns]

    def execute(self, code: str) -> Output:
        """
        Submit ``code`` to the SETU session and block until it has output.

        The statement kind is inferred from the first meaningful line of
        ``code`` (see ``get_statement_kind``), and ``$$spark$$`` /
        ``$$pyspark$$`` markers are stripped before submission (see
        ``get_formatted_code``).

        :param code: macro / model code to run
        :return: the output of the statement
        :raises ValueError: if the inferred kind is not in VALID_STATEMENT_KINDS
            (code that is empty after filtering yields None, which is presumably
            not a valid kind — confirm against VALID_STATEMENT_KINDS)
        :raises dbt.exceptions.RuntimeException: if the statement produced no
            output or its output reports an unsuccessful execution.
            ``Output.raise_for_status`` may raise as well.
        """
        statement_kind: Optional[StatementKind] = self.get_statement_kind(code)
        logger.info(f"statement_kind = {statement_kind} ")
        formatted_code: str = self.get_formatted_code(code)
        logger.info(f"formatted_code = {formatted_code} ")
        if statement_kind not in VALID_STATEMENT_KINDS:
            raise ValueError(
                f"{statement_kind} is not a valid statement kind for a SETU server "
                f"(should be one of {VALID_STATEMENT_KINDS})"
            )
        self.statement = self.client.create_statement(
            self.session_id, formatted_code, statement_kind
        )
        self._wait_for_output()
        output = self._require_output()
        logger.info(
            "Setu Statement {} state is : {}".format(
                self.statement.statement_id, self.statement.state
            )
        )
        output.raise_for_status()
        if not output.execution_success:
            logger.error(
                "Setu Statement {} output Error : {}".format(
                    self.statement.statement_id, output
                )
            )
            raise dbt.exceptions.RuntimeException(
                f"Error during Setu Statement {self.statement.statement_id} "
                f"execution : {output.error}"
            )
        return output

    def close(self):
        """
        Best-effort cancellation of the current statement if it is still pending.

        Does nothing when there is no statement or it is no longer
        WAITING / RUNNING. Errors from the cancel call are logged, not raised.
        """
        if self.statement is None or self.statement.state not in self._PENDING_STATES:
            return
        statement_id = self.statement.statement_id
        try:
            logger.info("closing Setu Statement id : {} ".format(statement_id))
            self.client.cancel_statement(self.statement.session_id, statement_id)
            logger.info("Setu Statement closed")
        except Exception as exc:
            # Deliberately best-effort. The info message assumes a failed cancel
            # means the statement already finished; keep the real cause too.
            logger.info("Setu Statement already closed ")
            logger.debug(
                f"cancel_statement for Setu Statement {statement_id} failed: {exc}"
            )

    def fetchall(self):
        """
        Return the rows of the current statement's JSON output.

        If the statement is still WAITING / RUNNING, polls until it has output.

        :return: rows as parsed by ``get_data_from_json_output``
        :raises dbt.exceptions.RuntimeException: if no statement has been
            executed, the statement has no output, or a statement that had to be
            polled ended up without JSON output.
        """
        if self.statement is None:
            raise dbt.exceptions.RuntimeException(
                "Setu statement response : None - no Setu statement available to fetch from"
            )
        if self.statement.state in self._PENDING_STATES:
            self._wait_for_output()
            output = self._require_output()
            output.raise_for_status()
            if output.json is None:
                logger.error(f"Setu statement {self.statement.statement_id} had no JSON output")
                raise dbt.exceptions.RuntimeException(
                    f"Setu statement {self.statement.statement_id} had no JSON output"
                )
        else:
            # NOTE(review): for an already-finished statement the JSON payload is
            # passed to the parser unchecked (as before); it may legitimately be
            # None for statements without a result set — confirm before tightening.
            output = self._require_output()
            output.raise_for_status()
        return get_data_from_json_output(output.json)

    def get_formatted_code(self, code: str) -> str:
        """
        Prepare ``code`` for submission to SETU.

        Drops blank lines and ``-- depends_on:`` lines and strips each remaining
        line. A leading ``$$spark$$`` / ``$$pyspark$$`` marker on any line is
        replaced with a single space, and every kept line is prefixed with a
        space. Lines that become blank (a marker alone) are dropped, and the
        rest are joined with ``os.linesep``.
        """
        code_lines = []
        for line in self._meaningful_lines(code):
            kind = self._kind_marker(line)
            if kind is not None:
                line = " " + line[len(self._marker(kind)):]
            code_lines.append(" " + line)
        return os.linesep.join(s for s in code_lines if s.strip())

    def get_statement_kind(self, code: str) -> Optional[StatementKind]:
        """
        Infer the statement kind from the first meaningful line of ``code``.

        StatementKind inference logic (sql/scala/pyspark):
          - line starts with ``$$spark$$``   -> StatementKind.SPARK
          - line starts with ``$$pyspark$$`` -> StatementKind.PYSPARK
          - otherwise                        -> StatementKind.SQL

        Blank lines and ``-- depends_on:`` lines are skipped. Returns None when
        no meaningful line exists.
        """
        first_line = next(self._meaningful_lines(code), None)
        if first_line is None:
            return None
        kind = self._kind_marker(first_line)
        return StatementKind.SQL if kind is None else kind

    # ----------------------------------------------------------------- helpers

    def _wait_for_output(self) -> None:
        """Poll SETU, refreshing ``self.statement``, while ``waiting_for_output`` is true."""
        # Back-off schedule for polling_intervals; the meaning of the second
        # argument is defined in setu.utils (presumably a max interval — confirm).
        intervals = polling_intervals([1, 2, 3, 5], 10)
        while waiting_for_output(self.statement):
            logger.info(
                " Setu statement {} progress : {}".format(
                    self.statement.statement_id, self.statement.progress
                )
            )
            time.sleep(next(intervals))
            self.statement = self.client.get_statement(
                self.statement.session_id, self.statement.statement_id
            )

    def _require_output(self) -> Output:
        """Return the current statement's output, raising RuntimeException if it is None."""
        if self.statement.output is None:
            logger.error(f"Setu Statement {self.statement.statement_id} had no output")
            raise dbt.exceptions.RuntimeException(
                f"Setu Statement {self.statement.statement_id} had no output"
            )
        return self.statement.output

    @staticmethod
    def _meaningful_lines(code: str):
        """Yield stripped, non-empty lines of ``code``, skipping ``-- depends_on:`` lines."""
        for raw_line in code.splitlines():
            line = raw_line.strip()
            # Ignore depends_on statements in model files
            if line and not line.startswith("-- depends_on:"):
                yield line

    @staticmethod
    def _marker(kind: StatementKind) -> str:
        """Return the ``$$<kind>$$`` marker string for ``kind``."""
        return "$$" + kind.value + "$$"

    @staticmethod
    def _kind_marker(line: str) -> Optional[StatementKind]:
        """Return SPARK / PYSPARK if ``line`` starts with that kind's marker, else None."""
        for kind in (StatementKind.SPARK, StatementKind.PYSPARK):
            if line.startswith(SetuStatementCursor._marker(kind)):
                return kind
        return None
