#!/usr/bin/env python3

import datetime as dt
import numpy as np
import pandas as pd
import pytz
import os

import dukascopy_python
from dukascopy_python import instruments as dk_instruments

# --- INSTRUMENT (run list_instruments.py to find the right constant) ---
DAX_INSTRUMENT = 'E_DAAX'

# Fallback used only when no constant is configured above (unreachable while
# DAX_INSTRUMENT is hard-coded): pick the alphabetically first INSTRUMENT_*
# name in dukascopy_python's instrument module whose upper-cased name contains
# "IDX" together with "DEU" or "GER".
if DAX_INSTRUMENT is None:
    _dax_attr = next(
        (
            name
            for name in sorted(dir(dk_instruments))
            if name.startswith("INSTRUMENT_")
            and "IDX" in name.upper()
            and any(tag in name.upper() for tag in ("DEU", "GER"))
        ),
        None,
    )
    if _dax_attr is not None:
        DAX_INSTRUMENT = getattr(dk_instruments, _dax_attr)
        print(f"[BLOCK 1] Auto-discovered: {_dax_attr}")

# Fail fast: nothing below can run without an instrument.
if DAX_INSTRUMENT is None:
    raise RuntimeError(
        "DAX instrument not found. Run list_instruments.py, find the "
        "DAX index constant, and set DAX_INSTRUMENT at the top of this file."
    )

# --- DATE RANGE ---
# History window: six 365-day years ending at 00:00 UTC of the current UTC date.
_today_utc = dt.datetime.now(dt.timezone.utc).date()
END_DATE   = dt.datetime.combine(_today_utc, dt.time(0, 0), tzinfo=dt.timezone.utc)
START_DATE = END_DATE - dt.timedelta(days=6 * 365)

# --- CACHE ---
# NOTE(review): the cache is reused whenever the file exists — it is not keyed
# by DAX_INSTRUMENT or the date range, so changing either silently reuses old
# data. Delete the file to force a refetch.
CACHE_FILE = "dax_m1_cache.pkl"

if not os.path.exists(CACHE_FILE):
    # Cold start: download M1 bid bars, normalise them, then cache to disk.
    print(f"[BLOCK 1] Requesting DAX M1: {START_DATE:%Y-%m-%d} → {END_DATE:%Y-%m-%d}")
    df_m1 = dukascopy_python.fetch(
        DAX_INSTRUMENT,
        dukascopy_python.INTERVAL_MIN_1,
        dukascopy_python.OFFER_SIDE_BID,
        START_DATE,
        END_DATE,
    )
    # Treat a naive index as UTC so later tz_convert calls are well-defined.
    if df_m1.index.tz is None:
        df_m1.index = df_m1.index.tz_localize("UTC")
    df_m1 = df_m1.sort_index().rename(columns={
        "open": "Open", "high": "High", "low": "Low",
        "close": "Close", "volume": "Volume",
    })
    # Unparseable values become NaN rather than raising.
    for col in ("Open", "High", "Low", "Close", "Volume"):
        df_m1[col] = pd.to_numeric(df_m1[col], errors="coerce")
    df_m1.to_pickle(CACHE_FILE)
    print(f"[BLOCK 1] Saved {len(df_m1):,} rows to {CACHE_FILE}")
else:
    # Warm start: the pickle was written by this script's fetch branch.
    print(f"[BLOCK 1] Loading cached data from {CACHE_FILE}...")
    df_m1 = pd.read_pickle(CACHE_FILE)
    print(f"[BLOCK 1] Loaded {len(df_m1):,} rows from cache.")

print(f"[BLOCK 1] M1 rows ingested : {len(df_m1):,}")
print(f"[BLOCK 1] Range            : {df_m1.index.min()} → {df_m1.index.max()}")

# Sanity check: the entry window defined later starts before 09:00 CET, so the
# feed must contain M1 bars earlier than 09:00 Europe/Berlin time.
CET = pytz.timezone("Europe/Berlin")
_m1_berlin = df_m1.index.tz_convert(CET)
cet_hours = _m1_berlin.hour
has_premarket = (cet_hours < 9).any()
earliest_cet = _m1_berlin.time.min()
print(f"[BLOCK 1] Earliest bar CET : {earliest_cet}")
if has_premarket:
    print(f"[BLOCK 1] Pre-market data present (earliest: {earliest_cet})\n")
else:
    print("[BLOCK 1] WARNING: No bars before 09:00 CET — check your instrument.")


OHLCV_AGG = {"Open": "first", "High": "max", "Low": "min", "Close": "last", "Volume": "sum"}

# Build M5 and M15 OHLCV bars from M1; bins without any M1 data have a NaN
# Open after aggregation and are discarded.
df_m5, df_m15 = (
    df_m1.resample(rule).agg(OHLCV_AGG).dropna(subset=["Open"])
    for rule in ("5min", "15min")
)

for _tag, _bars in (("M5 ", df_m5), ("M15", df_m15)):
    print(f"[BLOCK 2A] {_tag} bars : {len(_bars):,}")

# --- EMA PERIODS ---
EMA_FAST = 5
EMA_SLOW = 33

# Recursive EMAs (adjust=False) of the M5 close. pandas' ewm defaults to
# min_periods=0, which yields a value from the very first bar (seeded by the
# first close), so the dropna below never removed anything. Requiring `span`
# observations makes the warm-up bars NaN so they are actually discarded
# before any crossover is evaluated on a barely-initialised average.
df_m5["ema_fast"] = df_m5["Close"].ewm(span=EMA_FAST, adjust=False, min_periods=EMA_FAST).mean()
df_m5["ema_slow"] = df_m5["Close"].ewm(span=EMA_SLOW, adjust=False, min_periods=EMA_SLOW).mean()

rows_pre = len(df_m5)
df_m5.dropna(subset=["ema_fast", "ema_slow"], inplace=True)
print(f"[BLOCK 2B] EMA({EMA_FAST}/{EMA_SLOW}) — dropped {rows_pre - len(df_m5):,} warm-up rows")
print(f"[BLOCK 2B] Clean M5 rows : {len(df_m5):,}")

# Europe/Berlin wall-clock time of each bar's index timestamp, used below to
# test membership in the session windows.
for _frame in (df_m5, df_m15):
    _frame["cet_time"] = _frame.index.tz_convert(CET).time

# --- ENTRY / EXIT WINDOWS (CET) ---
ENTRY_START, ENTRY_END = dt.time(8, 30), dt.time(9, 0)   # entries: ENTRY_START <= t < ENTRY_END
EXIT_AFTER = dt.time(9, 0)                                # exits:   t >= EXIT_AFTER

print(f"[BLOCK 2C] Entry window : {ENTRY_START}–{ENTRY_END} CET")
print(f"[BLOCK 2C] Exit window  : >= {EXIT_AFTER} CET\n")


fast = df_m5["ema_fast"].values
slow = df_m5["ema_slow"].values

# Previous-bar EMA values. The first slot is NaN, so no crossover can fire on
# bar 0 (every comparison with NaN is False).
fast_prev = np.empty_like(fast); fast_prev[0] = np.nan; fast_prev[1:] = fast[:-1]
slow_prev = np.empty_like(slow); slow_prev[0] = np.nan; slow_prev[1:] = slow[:-1]

# Crossover on this bar's close: fast goes from <= slow to > slow (or the mirror).
cross_above = (fast_prev <= slow_prev) & (fast > slow)
cross_below = (fast_prev >= slow_prev) & (fast < slow)

in_entry_window = np.array(
    [(ENTRY_START <= t < ENTRY_END) for t in df_m5["cet_time"]], dtype=bool,
)

long_entries  = cross_above & in_entry_window
short_entries = cross_below & in_entry_window

print(f"[BLOCK 3A] Long  entry signals : {long_entries.sum():,}")
print(f"[BLOCK 3A] Short entry signals : {short_entries.sum():,}")

m15_bearish = (df_m15["Close"] < df_m15["Open"])
m15_bullish = (df_m15["Close"] > df_m15["Open"])

m15_after_exit = np.array([t >= EXIT_AFTER for t in df_m15["cet_time"]], dtype=bool)

m15_long_exit  = m15_bearish & m15_after_exit
m15_short_exit = m15_bullish & m15_after_exit

# With pandas' default resample labelling (label='left'), both M5 and M15 bars
# are stamped with their OPEN time. The colour of the M15 candle stamped T is
# only known at T + 15min, so mapping its flag onto the M5 bar stamped T (as a
# plain ffill does) would act on the 09:05 close using a 09:15 outcome.
# Re-stamp each flag at T + 10min instead: that is the last M5 bar inside the
# candle, and its close coincides with the M15 close.
M15_SIGNAL_LAG = pd.Timedelta("15min") - pd.Timedelta("5min")


def _m15_flag_on_m5(flag):
    """Forward-fill a per-M15-candle boolean flag onto the M5 index, causally.

    Each flag becomes visible on the M5 bar whose close equals the M15 close.
    M5 bars earlier than the first re-stamped flag get False. The result is a
    bool Series (a float detour avoids object dtype when leading NaNs appear).
    """
    restamped = pd.Series(flag.to_numpy(dtype=float), index=flag.index + M15_SIGNAL_LAG)
    return restamped.reindex(df_m5.index, method="ffill").fillna(0.0).astype(bool)


long_exits_ff  = _m15_flag_on_m5(m15_long_exit)
short_exits_ff = _m15_flag_on_m5(m15_short_exit)

# Keep only the first M5 bar of each run of True, so one contrarian streak
# produces a single exit signal instead of one per forward-filled bar.
long_exits  = (long_exits_ff & ~long_exits_ff.shift(1, fill_value=False)).values
short_exits = (short_exits_ff & ~short_exits_ff.shift(1, fill_value=False)).values

print(f"[BLOCK 3B] Long  exit signals  : {long_exits.sum():,}")
print(f"[BLOCK 3B] Short exit signals  : {short_exits.sum():,}")

# --- BACKTEST PARAMETERS ---
# Sizing: each new layer commits  mark-to-market equity * current_pct * LEVERAGE
# of notional. current_pct starts at RISK_PER_TRADE for a position's first
# layer and doubles for every further same-direction layer, capped at
# MAX_LEVERAGE.
INIT_CAPITAL   = 100_000.0
RISK_PER_TRADE = 0.1       # equity fraction of a position's first layer (before LEVERAGE)
LEVERAGE       = 20        # notional multiplier applied to every layer
MAX_DD_HALT    = 0.5       # flatten and stop trading once peak-to-trough drawdown reaches this
MAX_LEVERAGE   = 3         # cap on ONE layer's equity fraction (before LEVERAGE), not on the stacked total
FEES           = 0.0001    # per side, fraction of traded notional
SLIPPAGE       = 0.0001    # per side, fraction of price applied against the trade

# --- MUTABLE BACKTEST STATE ---
cash         = INIT_CAPITAL
position_dir = 0                  # +1 long, -1 short, 0 flat
layers       = []                 # open (entry_fill_price, units) tuples
current_pct  = RISK_PER_TRADE
peak_equity  = INIT_CAPITAL
halted, halt_time = False, None

# Diagnostics counters and equity low-water mark.
n_long_executed = n_short_executed = 0
n_long_skipped = n_short_skipped = 0
min_equity_seen, min_equity_time = INIT_CAPITAL, None

n_bars       = len(df_m5)
equity_curve = np.full(n_bars, np.nan)   # mark-to-market equity at each M5 close
trade_log    = []                        # one dict per closed position

close_arr = df_m5["Close"].values
cet_arr   = df_m5["cet_time"].values

print(f"\n[BLOCK 3C] Running forward-scan backtest on {n_bars:,} M5 bars...")


def _close_position(bar_idx, direction, open_layers, mark):
    """Flatten every open layer at M5 bar ``bar_idx`` and append the trade to trade_log.

    The exit fills at ``mark`` moved against the position by SLIPPAGE and pays
    FEES on the exit notional. Entry fees were already deducted from cash when
    each layer was opened, so they are excluded from the returned cash delta,
    but they are included in the logged ``fees`` and ``net_pnl`` so that
    ``net_pnl == gross_pnl - fees`` for every logged trade.

    Args:
        bar_idx: row of df_m5 at which the exit happens.
        direction: +1 for a long position, -1 for a short.
        open_layers: (entry_fill_price, units) tuples as appended at entry.
        mark: the bar's close price before slippage.

    Returns:
        The amount to add to cash: gross P&L minus the exit fee.
    """
    fill = mark * (1 - SLIPPAGE) if direction == 1 else mark * (1 + SLIPPAGE)
    gross_pnl = 0.0
    exit_notional = 0.0
    entry_fees = 0.0
    for entry_price, qty in open_layers:
        gross_pnl += direction * qty * (fill - entry_price)
        exit_notional += qty * fill
        entry_fees += qty * entry_price * FEES
    exit_fee = exit_notional * FEES
    trade_log.append({
        "exit_time": df_m5.index[bar_idx],
        "direction": "LONG" if direction == 1 else "SHORT",
        "layers":    len(open_layers),
        "gross_pnl": gross_pnl,
        "fees":      exit_fee + entry_fees,
        "net_pnl":   gross_pnl - exit_fee - entry_fees,
    })
    return gross_pnl - exit_fee


for i in range(n_bars):
    price = close_arr[i]
    cet_t = cet_arr[i]

    # --- Mark open layers to this bar's close ---
    unrealized = 0.0
    if position_dir == 1:
        unrealized = sum(units * (price - ep) for ep, units in layers)
    elif position_dir == -1:
        unrealized = sum(units * (ep - price) for ep, units in layers)

    mtm_equity = cash + unrealized
    equity_curve[i] = mtm_equity

    if mtm_equity > peak_equity:
        peak_equity = mtm_equity
    if mtm_equity < min_equity_seen:
        min_equity_seen = mtm_equity
        min_equity_time = df_m5.index[i]

    # --- Risk halt: equity <= 0, or peak-to-trough drawdown >= MAX_DD_HALT ---
    if not halted:
        halt_reason = None

        if mtm_equity <= 0:
            halt_reason = f"ACCOUNT WIPED OUT (equity hit {mtm_equity:,.2f})"
        elif peak_equity > 0:
            dd = (peak_equity - mtm_equity) / peak_equity
            if dd >= MAX_DD_HALT:
                halt_reason = (
                    f"{MAX_DD_HALT*100:.0f}% DRAWDOWN HALT "
                    f"(equity {mtm_equity:,.2f}, peak {peak_equity:,.2f}, DD {dd*100:.1f}%)"
                )

        if halt_reason is not None:
            if position_dir != 0 and len(layers) > 0:
                # Forced liquidation pays the same slippage and exit fee as a
                # signal exit, so a halt cannot flatter the results.
                cash += _close_position(i, position_dir, layers, price)
                position_dir = 0
                layers       = []
                mtm_equity   = cash
                equity_curve[i] = mtm_equity

            halted    = True
            halt_time = df_m5.index[i].tz_convert(CET)
            print(f"\n{'='*68}")
            print(f"[BLOCK 3C] HALT TRIGGERED at {halt_time}")
            print(f"           {halt_reason}")
            print(f"{'='*68}\n")

    # After a halt the account is flat; only the equity mark above is recorded.
    if halted:
        continue

    # --- Entries: EMA crossovers inside [ENTRY_START, ENTRY_END) ---
    if ENTRY_START <= cet_t < ENTRY_END:

        if long_entries[i]:
            if position_dir == -1:
                n_long_skipped += 1
            else:
                # Pyramiding: each extra same-direction layer doubles the
                # equity fraction, capped per layer at MAX_LEVERAGE.
                if position_dir == 1:
                    current_pct = min(current_pct * 2.0, MAX_LEVERAGE)
                else:
                    current_pct  = RISK_PER_TRADE
                    position_dir = 1

                fill_price = price * (1 + SLIPPAGE)
                notional   = mtm_equity * current_pct * LEVERAGE
                units      = notional / fill_price
                cash      -= notional * FEES
                layers.append((fill_price, units))
                n_long_executed += 1

        elif short_entries[i]:
            if position_dir == 1:
                n_short_skipped += 1
            else:
                if position_dir == -1:
                    current_pct = min(current_pct * 2.0, MAX_LEVERAGE)
                else:
                    current_pct  = RISK_PER_TRADE
                    position_dir = -1

                fill_price = price * (1 - SLIPPAGE)
                notional   = mtm_equity * current_pct * LEVERAGE
                units      = notional / fill_price
                cash      -= notional * FEES
                layers.append((fill_price, units))
                n_short_executed += 1

    # --- Exits: first contrarian M15 candle at/after EXIT_AFTER ---
    if cet_t >= EXIT_AFTER and position_dir != 0:
        exit_flags = long_exits if position_dir == 1 else short_exits
        if exit_flags[i]:
            cash += _close_position(i, position_dir, layers, price)
            position_dir = 0
            layers       = []
            current_pct  = RISK_PER_TRADE


_RULE = "─" * 68

print("\n" + _RULE)
print("  EXECUTION DIAGNOSTICS")
print(_RULE)
# (label, count, trailing note) rows for the signal/execution counters.
for _label, _count, _note in (
    ("Long  signals fired      ", long_entries.sum(), ""),
    ("Long  entries executed   ", n_long_executed, ""),
    ("Long  entries skipped    ", n_long_skipped, "  (blocked by open short)"),
    ("Short signals fired      ", short_entries.sum(), ""),
    ("Short entries executed   ", n_short_executed, ""),
    ("Short entries skipped    ", n_short_skipped, "  (blocked by open long)"),
    ("Closing trades logged    ", len(trade_log), ""),
):
    print(f"  {_label}: {_count:>10,}{_note}")
print(_RULE)
print(f"  Peak equity              : €{peak_equity:>14,.2f}")
print(f"  Min  equity              : €{min_equity_seen:>14,.2f}")
if min_equity_time is not None:
    print(f"  Min  equity timestamp    : {min_equity_time}")
print(f"  Final cash               : €{cash:>14,.2f}")
print(_RULE)
print(
    f"  HALT TRIGGERED at        : {halt_time}"
    if halted
    else f"  Halt NOT triggered (max DD stayed below {MAX_DD_HALT*100:.0f}%)"
)
print(_RULE)


equity_series = pd.Series(equity_curve, index=df_m5.index).dropna()
returns       = equity_series.pct_change().dropna()

# Annualisation factor = M5 bars per year actually present in the sample.
# The former constant 252 * 78 assumes a 6.5-hour session (78 five-minute
# bars per day). The M5 series here is resampled from every M1 bar the feed
# returns, pre-market included, so that constant mis-scaled both volatility
# and Sharpe. It is kept only as a fallback when the sample spans no time.
_SECONDS_PER_YEAR = 365.25 * 24 * 3600
_span_years = (equity_series.index[-1] - equity_series.index[0]).total_seconds() / _SECONDS_PER_YEAR
bars_per_year = len(returns) / _span_years if _span_years > 0 else 252 * 78

total_return  = (equity_series.iloc[-1] / INIT_CAPITAL) - 1
ret_std       = returns.std()
ann_vol       = ret_std * np.sqrt(bars_per_year)
running_max   = equity_series.cummax()
dd_series     = (equity_series - running_max) / running_max
max_drawdown  = abs(dd_series.min())
sharpe        = (returns.mean() / ret_std) * np.sqrt(bars_per_year) if ret_std > 0 else 0.0
final_equity  = equity_series.iloc[-1]

# Per-trade statistics over the logged net P&L of each closed position.
n_trades = len(trade_log)
if n_trades > 0:
    pnl_arr      = np.array([t["net_pnl"] for t in trade_log])
    win_rate     = (pnl_arr > 0).sum() / n_trades
    gross_wins   = pnl_arr[pnl_arr > 0].sum()
    gross_losses = abs(pnl_arr[pnl_arr < 0].sum())
    profit_factor = gross_wins / gross_losses if gross_losses > 0 else float("inf")
    avg_pnl      = pnl_arr.mean()
else:
    win_rate = profit_factor = avg_pnl = 0.0

if not halted:
    print(f"\n[BLOCK 3D] Max DD ({max_drawdown*100:.2f}%) within {MAX_DD_HALT*100:.0f}% tolerance.")

# Summary report. The window/exit lines are derived from ENTRY_START,
# ENTRY_END and EXIT_AFTER so the report cannot drift from the parameters
# actually used by the backtest.
print("\n" + "=" * 68)
print("                      STRATEGY TEAR SHEET")
print("=" * 68)
print("  Asset                : DAX Index")
print(f"  Period               : {df_m5.index.min()} → {df_m5.index.max()}")
print("  Timeframes           : M5 (entries) / M15 (exits)")
print(f"  Signal               : EMA({EMA_FAST}) x EMA({EMA_SLOW}) on M5")
print(f"  Entry Window         : {ENTRY_START:%H:%M}–{ENTRY_END:%H:%M} CET")
print(f"  Exit Logic           : 1st contrarian M15 candle >= {EXIT_AFTER:%H:%M} CET")
print("-" * 68)
print(f"  Initial Capital      : €{INIT_CAPITAL:>12,.0f}")
print(f"  Leverage             : {LEVERAGE:>12.1f}x")
print(f"  Final Equity         : €{final_equity:>12,.2f}")
print(f"  Fees (per side)      : {FEES * 10_000:.1f} bps")
print(f"  Slippage (per side)  : {SLIPPAGE * 10_000:.1f} bps")
print("-" * 68)
print(f"  Total Return         : {total_return * 100:>+10.2f}%")
print(f"  Annualised Volatility: {ann_vol * 100:>10.2f}%")
print(f"  Maximum Drawdown     : {max_drawdown * 100:>10.2f}%")
print(f"  Sharpe Ratio         : {sharpe:>10.3f}")
print("-" * 68)
print(f"  Total Trades         : {n_trades:>10,}")
print(f"  Win Rate             : {win_rate * 100:>10.1f}%")
print(f"  Profit Factor        : {profit_factor:>10.3f}")
print(f"  Avg Trade PnL        : €{avg_pnl:>10,.2f}")
print("=" * 68)

# Persist the per-trade ledger for offline inspection; nothing is written when
# no position was ever closed.
_TRADE_LOG_PATH = "trade_log.csv"
if trade_log:
    trade_df = pd.DataFrame(trade_log)
    trade_df.to_csv(_TRADE_LOG_PATH, index=False)
    print(f"\n[OUTPUT] Trade log saved to {_TRADE_LOG_PATH} ({n_trades} trades)")

print("\n[PIPELINE COMPLETE]")
