"""
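risk_management.py — pre-trade validation, daily PnL tracking, kill switch.

This module sits between strategy and execution. Every order flows through
RiskLayer.check(intent) before it is placed. In DRY_RUN mode, check() returns
an approved decision tagged reject_reason="dry_run" without evaluating the
other rules; the caller is responsible for logging the intent and not sending
it to the exchange.

Persistence: SQLite (default path logs/state.db) so daily PnL, per-market
notional and signal-dedup hashes survive restarts. The kill switch itself is
held in memory only.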
"""

from __future__ import annotations

import dataclasses
import math
import sqlite3
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional


@dataclass
class OrderIntent:
    """
    A single order the strategy would like to place.

    Instances are handed to RiskLayer.check(), which may approve the order,
    reject it, or return a copy whose size has been clamped.
    """

    token_id: str
    condition_id: str
    side: str  # "BUY" or "SELL"
    size_shares: float
    target_price: float
    order_type: str  # one of GTC, FOK, FAK, GTD
    reason: str  # free-text explanation, used for logs

    # Optional metadata consulted by the risk checks.
    signal_hash: str = field(default="")  # dedup key; empty string skips dedup
    market_volume_24h: float = field(default=0.0)  # compared to the liquidity floor
    market_resolves_at: Optional[int] = field(default=None)  # unix seconds


@dataclass
class RiskDecision:
    """Outcome of RiskLayer.check() for one OrderIntent."""

    approved: bool
    # The intent that was evaluated; on approval this may be a size-clamped copy.
    intent: Optional[OrderIntent]
    # Short machine-readable code explaining a rejection (or "dry_run").
    reject_reason: Optional[str] = field(default=None)
    # New share count, populated only when the order size was clamped.
    adjusted_size: Optional[float] = field(default=None)


class RiskLayer:
    """
    Centralised gate for all orders. Construct once and reuse.

    ``check(intent)`` evaluates the rules below in order; the first failing
    rule determines the returned (rejected) ``RiskDecision``:

      - kill switch (manual, or tripped by the daily-loss rule) -> reject
      - dry_run mode -> short-circuits with ``approved=True`` and
        ``reject_reason="dry_run"``; no further rules are evaluated and the
        caller must not send the order
      - max_daily_loss_usd kill switch (also trips the kill switch) -> reject
      - signal dedup window -> reject
      - min_market_liquidity_usd floor (compared to 24h volume) -> reject
      - min_time_to_resolution -> reject
      - non-finite / non-positive size or price -> reject
      - max_trade_size_usd cap -> clamps the size, does not reject
      - max_per_market_notional cap -> reject

    Daily PnL, per-market notional and recently seen signal hashes live in
    SQLite at ``state_db_path``. The kill switch is held in memory only, so
    it does not survive a restart; the daily-loss rule re-trips it on the
    next non-dry-run ``check`` if today's loss is still over the limit.
    """

    def __init__(
        self,
        *,
        dry_run: bool = True,
        max_trade_size_usd: float = 25.0,
        max_daily_loss_usd: float = 50.0,
        min_market_liquidity_usd: float = 5_000.0,
        min_time_to_resolution_s: int = 4 * 3600,
        max_per_market_notional: float = 100.0,
        dedup_window_s: int = 300,
        state_db_path: str = "logs/state.db",
    ):
        self.dry_run = dry_run
        self.max_trade_size_usd = max_trade_size_usd
        self.max_daily_loss_usd = max_daily_loss_usd
        self.min_market_liquidity_usd = min_market_liquidity_usd
        self.min_time_to_resolution_s = min_time_to_resolution_s
        self.max_per_market_notional = max_per_market_notional
        self.dedup_window_s = dedup_window_s

        Path(state_db_path).parent.mkdir(parents=True, exist_ok=True)
        # isolation_level=None puts the connection in autocommit mode; code
        # that must update several rows atomically opens its own transaction
        # (see record_fill).
        self.db = sqlite3.connect(state_db_path, isolation_level=None)
        self._init_schema()

        self._kill_switch = False
        self._kill_reason: Optional[str] = None

    def _init_schema(self):
        """Create the state tables if they do not exist yet."""
        self.db.executescript(
            """
            CREATE TABLE IF NOT EXISTS pnl_daily (
                day TEXT PRIMARY KEY,        -- YYYY-MM-DD UTC
                realized REAL DEFAULT 0,
                unrealized REAL DEFAULT 0,
                trades_count INTEGER DEFAULT 0,
                wins INTEGER DEFAULT 0,
                losses INTEGER DEFAULT 0
            );

            CREATE TABLE IF NOT EXISTS market_notional (
                condition_id TEXT PRIMARY KEY,
                notional_usd REAL DEFAULT 0,
                updated_ts INTEGER
            );

            CREATE TABLE IF NOT EXISTS signal_seen (
                signal_hash TEXT PRIMARY KEY,
                ts INTEGER
            );
            """
        )

    @staticmethod
    def _utc_day() -> str:
        """Return the current UTC date as ``YYYY-MM-DD`` (the pnl_daily key)."""
        return time.strftime("%Y-%m-%d", time.gmtime())

    # --- API ---

    def check(self, intent: OrderIntent) -> RiskDecision:
        """
        Run every pre-trade rule against ``intent``; the single entry point
        used by the orchestrator.

        Returns a RiskDecision. On approval, ``decision.intent`` is the order
        to place (a size-clamped copy when ``adjusted_size`` is set). The
        signal hash is recorded for dedup only when every rule passes.
        """
        if self._kill_switch:
            return RiskDecision(False, intent, "kill_switch_active")

        # 1. dry-run is implicit approval but no execution; the remaining
        #    rules are deliberately skipped.
        if self.dry_run:
            return RiskDecision(True, intent, reject_reason="dry_run")

        # 2. daily loss kill switch. Once tripped it stays on (even across a
        #    UTC day boundary) until reset_kill_switch() is called.
        if self.today_pnl() <= -self.max_daily_loss_usd:
            self.trip_kill_switch("daily_loss_exceeded")
            return RiskDecision(False, intent, "daily_loss_exceeded")

        # 3. signal dedup
        if intent.signal_hash and self._signal_recently_seen(intent.signal_hash):
            return RiskDecision(False, intent, "signal_already_seen")

        # 4. liquidity floor. Written as "not >=" so a NaN volume is rejected
        #    instead of slipping through a comparison that is always False.
        if not (intent.market_volume_24h >= self.min_market_liquidity_usd):
            return RiskDecision(False, intent, "below_liquidity_floor")

        # 5. time to resolution
        if intent.market_resolves_at is not None:
            ttr = intent.market_resolves_at - int(time.time())
            if ttr < self.min_time_to_resolution_s:
                return RiskDecision(False, intent, "too_close_to_resolution")

        # 6. size/price sanity: NaN, inf, zero or negative values would make
        #    the notional caps below meaningless (NaN > cap is always False,
        #    so such an order would otherwise be approved).
        if not (math.isfinite(intent.size_shares) and intent.size_shares > 0):
            return RiskDecision(False, intent, "invalid_size")
        if not (math.isfinite(intent.target_price) and intent.target_price > 0):
            return RiskDecision(False, intent, "invalid_price")

        # 7. per-trade size cap: clamp the size rather than reject
        notional_usd = intent.size_shares * intent.target_price
        adjusted = None
        if notional_usd > self.max_trade_size_usd:
            scale = self.max_trade_size_usd / notional_usd
            adjusted = intent.size_shares * scale
            intent = dataclasses.replace(intent, size_shares=adjusted)

        # 8. per-market notional cap, applied to the (possibly clamped) size
        cur = self.market_notional(intent.condition_id)
        new_notional = intent.size_shares * intent.target_price
        if cur + new_notional > self.max_per_market_notional:
            return RiskDecision(False, intent, "per_market_cap_exceeded")

        # All checks passed
        if intent.signal_hash:
            self._record_signal(intent.signal_hash)

        return RiskDecision(True, intent, adjusted_size=adjusted)

    # --- accounting ---

    def record_fill(self, condition_id: str, notional_usd: float, pnl_delta: float = 0.0):
        """
        Call after a fill is confirmed.

        Adds ``pnl_delta`` to today's realized PnL (counting it as a win when
        positive, a loss when negative), increments today's trade count, and
        adds ``notional_usd`` to the market's running notional. Both tables
        are updated in one transaction: either both change or neither does.
        """
        day = self._utc_day()
        now = int(time.time())
        # The connection is in autocommit mode, so without an explicit
        # transaction a failure on the second statement would leave PnL
        # updated but per-market notional not.
        self.db.execute("BEGIN")
        try:
            self.db.execute(
                """
                INSERT INTO pnl_daily (day, realized, trades_count, wins, losses)
                VALUES (?, ?, 1, ?, ?)
                ON CONFLICT(day) DO UPDATE SET
                  realized = realized + excluded.realized,
                  trades_count = trades_count + 1,
                  wins = wins + excluded.wins,
                  losses = losses + excluded.losses
                """,
                (day, pnl_delta, 1 if pnl_delta > 0 else 0, 1 if pnl_delta < 0 else 0),
            )
            self.db.execute(
                """
                INSERT INTO market_notional (condition_id, notional_usd, updated_ts)
                VALUES (?, ?, ?)
                ON CONFLICT(condition_id) DO UPDATE SET
                  notional_usd = notional_usd + excluded.notional_usd,
                  updated_ts = excluded.updated_ts
                """,
                (condition_id, notional_usd, now),
            )
        except BaseException:
            # SQLite may already have rolled back on some errors; only roll
            # back if a transaction is still open so the original error is
            # not masked.
            if self.db.in_transaction:
                self.db.execute("ROLLBACK")
            raise
        self.db.execute("COMMIT")

    def today_pnl(self) -> float:
        """Return today's (UTC) realized PnL, or 0.0 if nothing was recorded."""
        row = self.db.execute(
            "SELECT realized FROM pnl_daily WHERE day = ?", (self._utc_day(),)
        ).fetchone()
        return row[0] if row else 0.0

    def market_notional(self, condition_id: str) -> float:
        """Return the accumulated notional for a market, or 0.0 if unseen."""
        row = self.db.execute(
            "SELECT notional_usd FROM market_notional WHERE condition_id = ?",
            (condition_id,),
        ).fetchone()
        return row[0] if row else 0.0

    # --- dedup ---

    def _signal_recently_seen(self, signal_hash: str) -> bool:
        """True if ``signal_hash`` was recorded within the dedup window."""
        cutoff = int(time.time()) - self.dedup_window_s
        row = self.db.execute(
            "SELECT ts FROM signal_seen WHERE signal_hash = ? AND ts > ?",
            (signal_hash, cutoff),
        ).fetchone()
        return row is not None

    def _record_signal(self, signal_hash: str):
        """Store ``signal_hash`` with the current timestamp and prune expired hashes."""
        now = int(time.time())
        # Rows with ts <= now - window can never produce a dedup hit in
        # _signal_recently_seen, so drop them to keep the table bounded.
        self.db.execute(
            "DELETE FROM signal_seen WHERE ts <= ?", (now - self.dedup_window_s,)
        )
        self.db.execute(
            """INSERT INTO signal_seen (signal_hash, ts) VALUES (?, ?)
               ON CONFLICT(signal_hash) DO UPDATE SET ts = excluded.ts""",
            (signal_hash, now),
        )

    # --- kill switch (manual) ---

    def trip_kill_switch(self, reason: str = "manual"):
        """Block all further orders until reset; ``reason`` is kept for status()."""
        self._kill_switch = True
        self._kill_reason = reason

    def reset_kill_switch(self):
        """Re-enable trading. The daily-loss rule may trip the switch again."""
        self._kill_switch = False
        self._kill_reason = None

    @property
    def is_killed(self) -> bool:
        return self._kill_switch

    def close(self) -> None:
        """Close the SQLite connection; further database calls will fail."""
        self.db.close()

    # --- summary for the dashboard ---

    def status(self) -> dict:
        """Return a snapshot of today's PnL counters and switch state."""
        row = self.db.execute(
            "SELECT realized, trades_count, wins, losses FROM pnl_daily WHERE day=?",
            (self._utc_day(),)
        ).fetchone()
        if row:
            realized, trades, wins, losses = row
        else:
            realized, trades, wins, losses = 0.0, 0, 0, 0

        # Fills with pnl_delta == 0 (e.g. opening fills, the record_fill
        # default) count toward trades_count but are neither wins nor
        # losses, so dividing by all trades would understate the win rate.
        decided = wins + losses

        return {
            "dry_run": self.dry_run,
            "kill_switch": self._kill_switch,
            "kill_switch_reason": self._kill_reason,
            "today_pnl": realized,
            "today_trades": trades,
            "today_wins": wins,
            "today_losses": losses,
            "win_rate": wins / decided if decided else 0.0,
            "daily_loss_remaining": self.max_daily_loss_usd + realized,
        }
