#!/usr/bin/env python3
"""
Generate Kaplan-Meier Survival Curves for Clinical Decision Support Documents

This script creates publication-quality survival curves with:
- Kaplan-Meier survival estimates
- 95% confidence intervals
- Log-rank test statistics
- Hazard ratios with confidence intervals
- Number at risk tables
- Median survival annotations

Dependencies: lifelines, matplotlib, pandas, numpy
"""

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from lifelines import KaplanMeierFitter
from lifelines.statistics import logrank_test, multivariate_logrank_test
from lifelines import CoxPHFitter
import argparse
from pathlib import Path


def load_survival_data(filepath):
    """
    Load survival data from a CSV file and validate its structure.

    Expected columns:
    - patient_id: Unique patient identifier
    - time: Survival time (months or days)
    - event: Event indicator (1=event occurred, 0=censored)
    - group: Stratification variable (e.g., 'Biomarker+', 'Biomarker-')
    - Optional: Additional covariates for Cox regression

    Parameters:
        filepath: Path of the CSV file to read

    Returns:
        pandas.DataFrame with the 'event' column converted to bool

    Raises:
        ValueError: If a required column is absent, or if the 'event' column
            contains missing or non-numeric values.
    """
    df = pd.read_csv(filepath)

    # Validate required columns
    required_cols = ['patient_id', 'time', 'event', 'group']
    missing = [col for col in required_cols if col not in df.columns]
    if missing:
        raise ValueError(f"Missing required columns: {missing}")

    # astype(bool) maps NaN to True, which would silently count patients with
    # an unknown outcome as events. Refuse to guess.
    n_missing = int(df['event'].isna().sum())
    if n_missing:
        raise ValueError(
            f"Column 'event' has {n_missing} missing value(s); "
            "expected 1 (event) or 0 (censored) for every patient"
        )

    # Text values (e.g. 'yes'/'no') are all truthy, so every row would become
    # an event. Only numeric or boolean indicators can be converted safely.
    is_bool = pd.api.types.is_bool_dtype(df['event'])
    is_numeric = pd.api.types.is_numeric_dtype(df['event'])
    if not (is_bool or is_numeric):
        raise ValueError(
            "Column 'event' must be numeric or boolean "
            "(1=event occurred, 0=censored)"
        )

    # Convert event to boolean (any non-zero number counts as an event)
    df['event'] = df['event'].astype(bool)

    return df


def _first_time_at_or_below(curve, level):
    """Return the first index value at which *curve* is <= *level*, or inf if never."""
    reached = curve.to_numpy() <= level
    if not reached.any():
        return np.inf
    return float(curve.index[int(reached.argmax())])


def calculate_median_survival(kmf):
    """
    Calculate median survival time with its 95% CI.

    The interval is expressed in time units, not as survival probabilities:
    the lower limit is the first time at which the lower confidence band of
    the survival function reaches 0.5, and the upper limit is the first time
    at which the upper band reaches 0.5.

    Parameters:
        kmf: Fitted Kaplan-Meier model exposing ``median_survival_time_`` and
            ``confidence_interval_survival_function_`` (a two-column frame of
            lower/upper bounds indexed by time)

    Returns:
        Tuple ``(median, lower, upper)``. All three are None when the median
        is not reached. A limit is ``np.inf`` when the corresponding band
        never drops to 0.5 within follow-up.
    """
    median = kmf.median_survival_time_

    # Survival never crosses 0.5: median (and its CI) is not estimable
    if np.isinf(median):
        return None, None, None

    band = kmf.confidence_interval_survival_function_

    # The band's column names embed the label the model was fitted with, so
    # they cannot be hard-coded. Row-wise min/max identifies the lower and
    # upper bound regardless of label or column order.
    lower_band = band.min(axis=1)
    upper_band = band.max(axis=1)

    lower_ci = _first_time_at_or_below(lower_band, 0.5)
    upper_ci = _first_time_at_or_below(upper_band, 0.5)

    return median, lower_ci, upper_ci


def generate_kaplan_meier_plot(data, time_col='time', event_col='event',
                               group_col='group', output_path='survival_curve.pdf',
                               title='Kaplan-Meier Survival Curve',
                               xlabel='Time (months)', ylabel='Survival Probability',
                               time_unit='months'):
    """
    Generate Kaplan-Meier survival curves comparing groups and save them.

    One curve with its confidence band is drawn per distinct value of
    ``group_col``. The figure is annotated with the log-rank p-value and each
    group's median survival, then written to ``output_path`` and to a PNG
    file with the same stem.

    Parameters:
        data: DataFrame with survival data
        time_col: Column name for survival time
        event_col: Column name for event indicator
        group_col: Column name for stratification
        output_path: Path to save figure
        title: Plot title
        xlabel: X-axis label (specify units)
        ylabel: Y-axis label
        time_unit: Unit name printed in the median survival annotations

    Returns:
        Tuple ``(kmf_models, p_value)``: a dict mapping each group value to
        its fitted KaplanMeierFitter, and the log-rank p-value. The p-value
        is NaN when there are fewer than two groups, because no comparison
        is possible.
    """

    def _fmt_limit(value):
        # A confidence limit that is never reached reads better as "NR" than "inf"
        return f'{value:.1f}' if np.isfinite(value) else 'NR'

    # Colors for groups (colorblind-friendly); reused cyclically beyond five groups
    colors = ['#0173B2', '#DE8F05', '#029E73', '#CC78BC', '#CA9161']

    fig, ax = plt.subplots(figsize=(10, 6))

    # Everything after figure creation runs under try/finally so the figure
    # is released even if fitting, testing or saving raises.
    try:
        groups = data[group_col].unique()

        kmf_models = {}
        median_survivals = {}

        # Fit and plot each group
        for i, group in enumerate(groups):
            group_data = data[data[group_col] == group]

            kmf = KaplanMeierFitter()
            kmf.fit(group_data[time_col], group_data[event_col], label=str(group))

            kmf.plot_survival_function(ax=ax, ci_show=True,
                                       color=colors[i % len(colors)],
                                       linewidth=2, alpha=0.8)

            kmf_models[group] = kmf
            median_survivals[group] = calculate_median_survival(kmf)

        # Log-rank test: pairwise for two groups, multivariate for more,
        # skipped when there is nothing to compare.
        if len(groups) < 2:
            p_value = float('nan')
            logrank_text = None
        elif len(groups) == 2:
            first = data[data[group_col] == groups[0]]
            second = data[data[group_col] == groups[1]]
            results = logrank_test(
                first[time_col], second[time_col],
                first[event_col], second[event_col]
            )
            p_value = results.p_value
            logrank_text = f'Log-rank test:\np = {p_value:.4f}'
        else:
            results = multivariate_logrank_test(
                data[time_col], data[group_col], data[event_col]
            )
            p_value = results.p_value
            logrank_text = f'Log-rank test:\np = {p_value:.4f}\n({len(groups)} groups)'

        if logrank_text is not None:
            ax.text(0.02, 0.15, logrank_text,
                    transform=ax.transAxes, fontsize=10,
                    verticalalignment='top',
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

        # Median survival annotations, stacked downward from the top-right corner
        y_pos = 0.95
        for group, (median, lower, upper) in median_survivals.items():
            if median is not None:
                annotation = (f'{group}: {median:.1f} {time_unit} '
                              f'(95% CI {_fmt_limit(lower)}-{_fmt_limit(upper)})')
            else:
                annotation = f'{group}: Not reached'
            ax.text(0.98, y_pos, annotation,
                    transform=ax.transAxes, fontsize=9, ha='right',
                    verticalalignment='top')
            y_pos -= 0.05

        # Formatting
        ax.set_xlabel(xlabel, fontsize=12, fontweight='bold')
        ax.set_ylabel(ylabel, fontsize=12, fontweight='bold')
        ax.set_title(title, fontsize=14, fontweight='bold', pad=15)
        ax.legend(loc='lower left', frameon=True, fontsize=10)
        ax.grid(True, alpha=0.3, linestyle='--')
        ax.set_ylim([0, 1.05])

        fig.tight_layout()

        # Save through the figure object so an unrelated "current figure"
        # can never be written or closed by mistake.
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Survival curve saved to: {output_path}")

        # Also save as PNG for easy viewing (unless the main output already is one)
        png_path = Path(output_path).with_suffix('.png')
        if png_path != Path(output_path):
            fig.savefig(png_path, dpi=300, bbox_inches='tight')
            print(f"PNG version saved to: {png_path}")
    finally:
        plt.close(fig)

    return kmf_models, p_value


def generate_number_at_risk_table(data, time_col='time', event_col='event',
                                  group_col='group', time_points=None):
    """
    Generate number at risk table for survival analysis.

    A patient is at risk at time ``t`` when their recorded time is >= ``t``,
    i.e. they have neither had the event nor been censored before ``t``.

    Parameters:
        data: DataFrame with survival data
        time_col: Column name for survival time
        event_col: Column name for event indicator (accepted for a uniform
            signature; the count depends only on follow-up time)
        group_col: Column name for stratification
        time_points: List of time points for risk table. If None, points are
            generated every 6 time units from 0 until the maximum observed
            time is covered; empty data yields the single time point 0.

    Returns:
        Integer DataFrame indexed by time point with one column per group,
        holding the number of patients at risk.
    """

    if time_points is None:
        max_time = data[time_col].max()
        if pd.isna(max_time):
            # No observed times at all: np.arange cannot take a NaN stop
            time_points = np.array([0])
        else:
            # Auto-generate time points (every 6 months up to max time)
            time_points = np.arange(0, max_time + 6, 6)

    groups = data[group_col].unique()

    # Fill a plain integer matrix so the resulting table has a numeric dtype
    counts = np.zeros((len(time_points), len(groups)), dtype=int)
    for col_idx, group in enumerate(groups):
        group_times = data.loc[data[group_col] == group, time_col].to_numpy()
        counts[:, col_idx] = [int((group_times >= t).sum()) for t in time_points]

    return pd.DataFrame(counts, index=time_points, columns=groups)


def calculate_hazard_ratio(data, time_col='time', event_col='event', group_col='group',
                          reference_group=None):
    """
    Calculate hazard ratio using Cox proportional hazards regression.

    The model has a single binary covariate (1 = comparison group,
    0 = reference group), so the result is the hazard of the comparison
    group relative to the reference group.

    Parameters:
        data: DataFrame with survival data
        time_col: Column name for survival time
        event_col: Column name for event indicator
        group_col: Column name for stratification
        reference_group: Reference group for comparison (if None, uses first
            group in order of appearance)

    Returns:
        Tuple ``(hazard_ratio, ci_lower, ci_upper, p_value)``

    Raises:
        ValueError: If fewer than two groups are present, or if
            ``reference_group`` does not occur in ``group_col``.
    """

    groups = list(data[group_col].unique())

    # A single group leaves the covariate constant, which the Cox model
    # cannot estimate; fail early with a readable message.
    if len(groups) < 2:
        raise ValueError(
            f"Hazard ratio requires at least 2 groups in '{group_col}'; "
            f"found {len(groups)}"
        )

    if reference_group is None:
        reference_group = groups[0]
    elif reference_group not in groups:
        raise ValueError(
            f"Reference group {reference_group!r} not found in '{group_col}'; "
            f"available groups: {groups}"
        )

    # Comparison group = first non-reference group in order of appearance
    comparison_group = next(g for g in groups if g != reference_group)

    if len(groups) != 2:
        print("Warning: Cox HR calculation assumes 2 groups. "
              f"Comparing {comparison_group} vs {reference_group} only.")

    # Keep only the two groups being compared; otherwise every remaining
    # group would be pooled into the comparison arm.
    pair = data[data[group_col].isin([reference_group, comparison_group])]

    # Create binary indicator (1 for comparison group, 0 for reference)
    data_cox = pair[[time_col, event_col]].copy()
    data_cox['group_binary'] = (pair[group_col] != reference_group).astype(int)

    # Fit Cox model
    cph = CoxPHFitter()
    cph.fit(data_cox, duration_col=time_col, event_col=event_col)

    # Coefficients and their interval are on the log-hazard scale;
    # exponentiate to obtain the hazard ratio and its CI.
    hr = np.exp(cph.params_['group_binary'])
    ci = np.exp(cph.confidence_intervals_.loc['group_binary'].values)
    p_value = cph.summary.loc['group_binary', 'p']

    return hr, ci[0], ci[1], p_value


# Characters with special meaning in LaTeX text mode and their escaped forms
_LATEX_SPECIAL_CHARS = {
    '\\': r'\textbackslash{}',
    '&': r'\&',
    '%': r'\%',
    '$': r'\$',
    '#': r'\#',
    '_': r'\_',
    '{': r'\{',
    '}': r'\}',
    '~': r'\textasciitilde{}',
    '^': r'\textasciicircum{}',
}


def _latex_escape(text):
    """Escape LaTeX special characters so arbitrary labels can be typeset."""
    return ''.join(_LATEX_SPECIAL_CHARS.get(ch, ch) for ch in str(text))


def _survival_rate_at(kmf, time_point):
    """Return the estimated survival probability at *time_point*, or None on failure."""
    try:
        return float(kmf.survival_function_at_times(time_point).values[0])
    except Exception:
        # Best-effort: a rate that cannot be computed is omitted from the
        # report instead of aborting it (but Ctrl-C is no longer swallowed).
        return None


def _format_p_value(p_value, digits):
    """Format a p-value with *digits* decimals, or 'n/a' when it is missing."""
    if p_value is None or pd.isna(p_value):
        return 'n/a'
    return f"{p_value:.{digits}f}"


def generate_report(data, output_dir, prefix='survival',
                    title='Survival Analysis by Group',
                    xlabel='Time (months)', ylabel='Survival Probability',
                    time_col='time', event_col='event', group_col='group'):
    """
    Generate comprehensive survival analysis report.

    Creates in ``output_dir`` (created if necessary):
    - Kaplan-Meier curves (PDF and PNG)
    - Number at risk table (CSV)
    - Statistical summary (TXT)
    - LaTeX table code (TEX)

    Parameters:
        data: DataFrame with survival data
        output_dir: Directory that receives all output files
        prefix: File name prefix for every output file
        title: Plot title
        xlabel: X-axis label of the plot
        ylabel: Y-axis label of the plot
        time_col: Column name for survival time
        event_col: Column name for event indicator
        group_col: Column name for stratification

    The hazard ratio is only computed and reported when exactly two groups
    are present; otherwise the corresponding LaTeX cells contain "--".
    """

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Generate survival curve
    kmf_models, logrank_p = generate_kaplan_meier_plot(
        data,
        time_col=time_col,
        event_col=event_col,
        group_col=group_col,
        output_path=output_dir / f'{prefix}_kaplan_meier.pdf',
        title=title,
        xlabel=xlabel,
        ylabel=ylabel
    )

    # Number at risk table
    risk_table = generate_number_at_risk_table(
        data, time_col=time_col, event_col=event_col, group_col=group_col
    )
    risk_table.to_csv(output_dir / f'{prefix}_number_at_risk.csv')

    # Same order of appearance as the keys of kmf_models
    groups = list(data[group_col].unique())

    # A single hazard ratio is only meaningful for a two-group comparison
    if len(groups) == 2:
        hr, ci_lower, ci_upper, hr_p = calculate_hazard_ratio(
            data, time_col=time_col, event_col=event_col, group_col=group_col
        )
        hr_cell = f"{hr:.2f} ({ci_lower:.2f}-{ci_upper:.2f})"
        hr_p_cell = f"{hr_p:.3f}"
    else:
        hr = None
        hr_cell = "--"
        hr_p_cell = "--"

    # 24-month rates are only reported when follow-up extends that far
    has_24m_followup = data[time_col].max() >= 24

    # Generate statistical summary
    with open(output_dir / f'{prefix}_statistics.txt', 'w', encoding='utf-8') as f:
        f.write("SURVIVAL ANALYSIS STATISTICAL SUMMARY\n")
        f.write("=" * 60 + "\n\n")

        for group in groups:
            kmf = kmf_models[group]
            group_rows = data[data[group_col] == group]
            median = kmf.median_survival_time_

            # Survival rates at common time points
            surv_12m = _survival_rate_at(kmf, 12)
            surv_24m = _survival_rate_at(kmf, 24) if has_24m_followup else None

            f.write(f"Group: {group}\n")
            f.write(f"  N = {len(group_rows)}\n")
            f.write(f"  Events = {group_rows[event_col].sum()}\n")
            if np.isinf(median):
                f.write("  Median survival: Not reached\n")
            else:
                f.write(f"  Median survival: {median:.1f} months\n")
            if surv_12m is not None:
                f.write(f"  12-month survival rate: {surv_12m*100:.1f}%\n")
            if surv_24m is not None:
                f.write(f"  24-month survival rate: {surv_24m*100:.1f}%\n")
            f.write("\n")

        f.write("Log-Rank Test:\n")
        if logrank_p is None or pd.isna(logrank_p):
            f.write("  p-value = n/a (requires at least 2 groups)\n\n")
        else:
            verdict = 'Significant' if logrank_p < 0.05 else 'Not significant'
            f.write(f"  p-value = {logrank_p:.4f}\n")
            f.write(f"  Interpretation: {verdict} difference in survival\n\n")

        if hr is not None:
            direction = 'reduction' if hr < 1 else 'increase'
            f.write(f"Hazard Ratio ({groups[1]} vs {groups[0]}):\n")
            f.write(f"  HR = {hr:.2f} (95% CI {ci_lower:.2f}-{ci_upper:.2f})\n")
            f.write(f"  p-value = {hr_p:.4f}\n")
            # abs() keeps the magnitude positive; the direction word carries the sign
            f.write(f"  Interpretation: {groups[1]} has {abs(1 - hr) * 100:.0f}% "
                    f"{direction} in risk\n")

    # Generate LaTeX table code: one column per group plus HR and p-value
    column_spec = 'l' + 'c' * (len(groups) + 2)
    header_cells = (
        ["\\textbf{Endpoint}"]
        + [f"\\textbf{{{_latex_escape(group)}}}" for group in groups]
        + ["\\textbf{HR (95\\% CI)}", "\\textbf{p-value}"]
    )

    median_cells = []
    rate_cells = []
    for group in groups:
        kmf = kmf_models[group]
        median = kmf.median_survival_time_
        median_cells.append("NR" if np.isinf(median) else f"{median:.1f}")

        surv_12m = _survival_rate_at(kmf, 12)
        rate_cells.append("--" if surv_12m is None else f"{surv_12m*100:.0f}\\%")

    # Only the point estimate is tabulated, so the row label does not promise a CI
    median_row = ["Median survival, months"] + median_cells + [hr_cell, hr_p_cell]
    rate_row = ["12-month survival rate (\\%)"] + rate_cells + ["--", "--"]

    with open(output_dir / f'{prefix}_latex_table.tex', 'w', encoding='utf-8') as f:
        f.write("% LaTeX table code for survival outcomes\n")
        f.write("\\begin{table}[H]\n")
        f.write("\\centering\n")
        f.write("\\small\n")
        f.write(f"\\begin{{tabular}}{{{column_spec}}}\n")
        f.write("\\toprule\n")
        f.write(" & ".join(header_cells) + " \\\\\n")
        f.write("\\midrule\n")
        f.write(" & ".join(median_row) + " \\\\\n")
        f.write(" & ".join(rate_row) + " \\\\\n")
        f.write("\\bottomrule\n")
        f.write("\\end{tabular}\n")
        f.write("\\caption{Survival outcomes by group "
                f"(log-rank p={_format_p_value(logrank_p, 3)})}}\n")
        f.write("\\end{table}\n")

    print(f"\nAnalysis complete! Files saved to {output_dir}/")
    print(f"  - Survival curves: {prefix}_kaplan_meier.pdf/png")
    print(f"  - Statistics: {prefix}_statistics.txt")
    print(f"  - LaTeX table: {prefix}_latex_table.tex")
    print(f"  - Risk table: {prefix}_number_at_risk.csv")


def main():
    """
    Command-line entry point: load a CSV and generate the full report.

    All plot and column options given on the command line are forwarded to
    ``generate_report``. The loader always requires the columns
    patient_id/time/event/group; ``--time-col``, ``--event-col`` and
    ``--group-col`` select which columns of the loaded file are analysed.
    """
    parser = argparse.ArgumentParser(description='Generate Kaplan-Meier survival curves')
    parser.add_argument('input_file', type=str, help='CSV file with survival data')
    parser.add_argument('-o', '--output', type=str, default='survival_output',
                       help='Output directory (default: survival_output)')
    parser.add_argument('-t', '--title', type=str, default='Kaplan-Meier Survival Curve',
                       help='Plot title')
    parser.add_argument('-x', '--xlabel', type=str, default='Time (months)',
                       help='X-axis label')
    parser.add_argument('-y', '--ylabel', type=str, default='Survival Probability',
                       help='Y-axis label')
    parser.add_argument('--time-col', type=str, default='time',
                       help='Column name for time variable')
    parser.add_argument('--event-col', type=str, default='event',
                       help='Column name for event indicator')
    parser.add_argument('--group-col', type=str, default='group',
                       help='Column name for grouping variable')

    args = parser.parse_args()

    # Load data
    print(f"Loading data from {args.input_file}...")
    data = load_survival_data(args.input_file)

    # Fail with a usage error rather than a KeyError deep inside the analysis
    selected_columns = (
        ('--time-col', args.time_col),
        ('--event-col', args.event_col),
        ('--group-col', args.group_col),
    )
    for option, column in selected_columns:
        if column not in data.columns:
            parser.error(f"{option}: column '{column}' not found in {args.input_file}")

    print(f"Loaded {len(data)} patients")
    print(f"Groups: {data[args.group_col].value_counts().to_dict()}")

    # Generate analysis
    generate_report(
        data,
        output_dir=args.output,
        prefix='survival',
        title=args.title,
        xlabel=args.xlabel,
        ylabel=args.ylabel,
        time_col=args.time_col,
        event_col=args.event_col,
        group_col=args.group_col
    )


# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()


# Example usage:
# python generate_survival_analysis.py survival_data.csv -o figures/ -t "PFS by PD-L1 Status"
#
# Input CSV format:
# patient_id,time,event,group
# PT001,12.3,1,PD-L1+
# PT002,8.5,1,PD-L1-
# PT003,18.2,0,PD-L1+
# ...

