#!/usr/bin/env python3
"""
Project 8 (v3b): Sidereal Feature Importance
============================================
Analyze which planets in the sidereal zodiac contributed most to its
26.7% classification accuracy (beating the tropical zodiac's 20.9%).

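Method:
1. Train a Random Forest on the sidereal dataset (full chart: the sidereal
   sign of each body plus the H4/H5/H7 harmonic-strength features).
2. Extract the (impurity-based) feature importances.
3. Aggregate the one-hot importances back to their parent planet; signs are
   encoded as integers 0-11, so the one-hot columns are named Moon_Sid_0 ... Moon_Sid_11.
   (e.g., Importance(Moon) = Sum(Importance(Moon_Sid_0), Importance(Moon_Sid_1), ...))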
"""

import numpy as np
import pandas as pd
import swisseph as swe
from pathlib import Path
import sys
from datetime import datetime
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import OneHotEncoder
from sklearn.compose import ColumnTransformer

# Import the shared celebrity dataset from the sibling Project 6 directory.
_DATA_DIR = Path(__file__).parent.parent / "06-harmonic-analysis-aspects"
sys.path.append(str(_DATA_DIR))
try:
    from celebrity_data import CELEBRITY_DATA
except ImportError as exc:
    # sys.exit(<str>) writes the message to stderr and exits with status 1,
    # keeping diagnostics off stdout and reporting where we looked and why it failed.
    sys.exit(f"Error: Could not import celebrity_data.py from {_DATA_DIR} ({exc})")

# None -> use the Swiss Ephemeris library's default ephemeris search path.
swe.set_ephe_path(None)

# Bodies whose sidereal sign becomes a feature: eight of the nine traditional
# grahas. The ninth, Ketu, has no entry here -- process_chart derives it as the
# point 180 degrees opposite Rahu (the mean lunar node).
SIGN_BODIES = {
    swe.SUN: 'Sun',
    swe.MOON: 'Moon',
    swe.MERCURY: 'Mercury',
    swe.VENUS: 'Venus',
    swe.MARS: 'Mars',
    swe.JUPITER: 'Jupiter',
    swe.SATURN: 'Saturn',
    swe.MEAN_NODE: 'Rahu',
}

# --- Copied Utility Functions from v3 ---
def get_sign(lon):
    """Return the zodiac sign index (0-11) for an ecliptic longitude in degrees.

    Each sign spans 30 degrees and index 0 begins at 0 degrees. The quotient
    is truncated toward zero (``int``) and then wrapped modulo 12, so a
    longitude of exactly 360 maps back to 0.
    """
    sign_index = int(lon / 30)
    return sign_index % 12

def process_chart(entry):
    """Convert one celebrity record into a feature row.

    Parameters
    ----------
    entry : dict
        Record with ``'date'`` (``YYYY-MM-DD``), ``'time'`` (``HH:MM``) and
        ``'category'`` keys. A missing key raises ``KeyError``.

    Returns
    -------
    dict or None
        ``{'Category', 'H4', 'H5', 'H7', '<Body>_Sid', ...}`` where each
        ``*_Sid`` value is an int sign index 0-11 in the Lahiri sidereal
        zodiac (including a derived ``Ketu_Sid``), or ``None`` when the
        date/time string does not match the expected format.
    """
    s_date = f"{entry['date']} {entry['time']}"
    try:
        dt = datetime.strptime(s_date, "%Y-%m-%d %H:%M")
    except ValueError:
        return None

    # NOTE(review): the clock time is passed straight to julday/calc_ut, i.e. it is
    # treated as Universal Time -- confirm CELEBRITY_DATA stores UT, not local time.
    hour = dt.hour + dt.minute / 60.0
    jd = swe.julday(dt.year, dt.month, dt.day, hour)

    # Sidereal signs using the Lahiri ayanamsa.
    swe.set_sid_mode(swe.SIDM_LAHIRI)
    sid_features = {}
    for pid, name in SIGN_BODIES.items():
        lon = swe.calc_ut(jd, pid, swe.FLG_SIDEREAL)[0][0]
        sid_features[f"{name}_Sid"] = get_sign(lon)
        if name == 'Rahu':
            # Ketu is the point exactly opposite Rahu.
            sid_features['Ketu_Sid'] = get_sign((lon + 180) % 360)

    # Harmonic features use tropical longitudes (flag 0, no FLG_SIDEREAL) of
    # Sun through Pluto, matching the earlier tropical model. set_sid_mode(0)
    # selects the Fagan/Bradley ayanamsa; it has no effect on these tropical
    # calculations and only leaves the library's global sidereal mode at 0.
    swe.set_sid_mode(0)
    trop_bodies = (swe.SUN, swe.MOON, swe.MERCURY, swe.VENUS, swe.MARS,
                   swe.JUPITER, swe.SATURN, swe.URANUS, swe.NEPTUNE, swe.PLUTO)
    trop_positions = [swe.calc_ut(jd, pid, 0)[0][0] for pid in trop_bodies]

    # Pairwise separations folded into [0, 180]. They do not depend on the
    # harmonic number, so compute them once instead of once per harmonic.
    separations = []
    for i, first in enumerate(trop_positions):
        for second in trop_positions[i + 1:]:
            diff = abs(first - second)
            if diff > 180:
                diff = 360 - diff
            separations.append(diff)
    separations = np.array(separations)

    def h_strength(h):
        # Mean resultant length of the separations multiplied by h: 1.0 when all
        # pairs coincide under harmonic h, near 0 when they are spread evenly.
        rads = np.deg2rad(separations * h)
        return np.sqrt(np.mean(np.cos(rads))**2 + np.mean(np.sin(rads))**2)

    return {
        'Category': entry['category'],
        'H4': h_strength(4),
        'H5': h_strength(5),
        'H7': h_strength(7),
        **sid_features
    }

def analyze_importance():
    """Fit a Random Forest on sidereal chart features and print per-body importances.

    Builds one row per ``CELEBRITY_DATA`` entry that has a ``'category'`` key
    (entries whose date/time cannot be parsed are skipped), one-hot encodes
    the ``*_Sid`` sign columns, passes the H4/H5/H7 harmonic columns through
    unchanged, fits the classifier on the full dataset, and prints the
    feature importances with each body's one-hot columns summed into a single
    score, ranked from most to least important.

    NOTE(review): these are impurity-based (MDI) importances computed on the
    training data; scikit-learn documents that they can favour features with
    many unique values (here the continuous harmonics) over binary one-hot
    columns. Permutation importance on held-out data is a useful cross-check.
    """
    print("Generating Sidereal Dataset...")
    data = []
    for entry in CELEBRITY_DATA:
        if 'category' not in entry:
            continue
        row = process_chart(entry)
        if row:
            data.append(row)

    if not data:
        # An empty DataFrame has no 'Category' column; report clearly instead of
        # failing below with a KeyError.
        print("No usable charts found; nothing to analyze.")
        return

    df = pd.DataFrame(data)
    X = df.drop('Category', axis=1)
    y = df['Category']

    # Features
    harm_cols = ['H4', 'H5', 'H7']
    sid_cols = [c for c in df.columns if "_Sid" in c]

    # Harmonics pass through untouched; sign indices are one-hot encoded.
    preprocessor = ColumnTransformer(
        transformers=[
            ('num', 'passthrough', harm_cols),
            ('cat', OneHotEncoder(handle_unknown='ignore', sparse_output=False), sid_cols),
        ])

    print("Training Model to extract feature importance...")
    clf = RandomForestClassifier(n_estimators=500, random_state=42)
    X_processed = preprocessor.fit_transform(X, y)
    clf.fit(X_processed, y)

    importances = clf.feature_importances_

    # Attribute every transformed column to its parent feature structurally rather
    # than by parsing generated names. ColumnTransformer emits transformer outputs
    # in listed order (passthrough harmonics, then the one-hot block), and the
    # encoder emits one column per learned category, column by column.
    ohe = preprocessor.named_transformers_['cat']
    owners = list(harm_cols)
    for col, categories in zip(sid_cols, ohe.categories_):
        owners.extend([col.replace('_Sid', '')] * len(categories))
    if len(owners) != len(importances):
        raise RuntimeError(
            f"Feature attribution mismatch: {len(owners)} owners vs "
            f"{len(importances)} importances")

    # Sum one-hot importances per body, e.g. Moon_Sid_0..Moon_Sid_11 -> Moon.
    aggregated = {}
    for owner, imp in zip(owners, importances):
        aggregated[owner] = aggregated.get(owner, 0.0) + imp

    print("\n--- SIDEREAL FEATURE IMPORTANCE ---")
    print("(Contribution to Profession Classification)")
    sorted_feats = sorted(aggregated.items(), key=lambda x: x[1], reverse=True)
    for rank, (name, imp) in enumerate(sorted_feats, start=1):
        print(f"{rank}. {name:<10}: {imp:.4f}")

if __name__ == "__main__":
    # Script entry point: build the dataset, fit the model, print the ranking.
    analyze_importance()