"""
AWS Glue Feature Engineering Pipeline for Airline Delay Analysis
This script processes airline delay data to create engineered features for ML models.
It reads a curated source table, applies the transformations below, and writes the result to an Iceberg feature table in Amazon S3 Tables.
Features Generated:
- Delay rates (total_delay_rate)
- Cancellation and diversion rates
- One-hot encoded categorical features (carrier, airport)
Author: Feature Engineering Team
Version: 1.0
Last Updated: January 2025
"""
import sys
import logging
import boto3
from awsglue.transforms import *
from awsglue.utils import getResolvedOptions
from awsglue.context import GlueContext
from awsglue.job import Job
from pyspark.context import SparkContext
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when, lit, udf
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, OneHotEncoder
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Parse job parameters
## @params: [JOB_NAME]
args = getResolvedOptions(sys.argv, ['JOB_NAME'])
logger.info(f"Starting Glue job: {args['JOB_NAME']}")
# Initialize Spark and Glue contexts
logger.info("Initializing Spark and Glue contexts...")
sc = SparkContext()
glueContext = GlueContext(sc)
spark = glueContext.spark_session
job = Job(glueContext)
job.init(args['JOB_NAME'], args)
logger.info("Contexts initialized successfully")

def get_aws_context():
    """Get current AWS account ID and region"""
    sts = boto3.client('sts')
    session = boto3.Session()
    account_id = sts.get_caller_identity()['Account']
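    # region_name is resolved from the execution environment; inside a Glue
    # job this is the job's region (it can be None outside AWS environments)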
    region = session.region_name
    return account_id, region

def configure_spark_with_iceberg(s3_tables_arn, region="us-east-1", app_name="glue-s3-tables-rest"):
    """
    Configure Spark session with Iceberg and S3 Tables integration.
    
    Args:
        s3_tables_arn (str): ARN of the S3 Tables bucket
        region (str): AWS region (default: us-east-1)
        app_name (str): Spark application name
        
    Returns:
        SparkSession: Configured Spark session with Iceberg support

    Usage:
        s3_tables_arn = "arn:aws:s3tables:us-east-1:123456789012:bucket/airlines"
        region = "us-east-1"
        spark = configure_spark_with_iceberg(s3_tables_arn, region)

    """
    logger.info("Configuring Spark session with Iceberg support...")
    
    spark = SparkSession.builder.appName(app_name) \
        .config("spark.sql.extensions", "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions") \
        .config("spark.sql.defaultCatalog", "s3_rest_catalog") \
        .config("spark.sql.catalog.s3_rest_catalog", "org.apache.iceberg.spark.SparkCatalog") \
        .config("spark.sql.catalog.s3_rest_catalog.type", "rest") \
        .config("spark.sql.catalog.s3_rest_catalog.uri", f"https://s3tables.{region}.amazonaws.com/iceberg") \
        .config("spark.sql.catalog.s3_rest_catalog.warehouse", s3_tables_arn) \
        .config("spark.sql.catalog.s3_rest_catalog.rest.sigv4-enabled", "true") \
        .config("spark.sql.catalog.s3_rest_catalog.rest.signing-name", "s3tables") \
        .config("spark.sql.catalog.s3_rest_catalog.rest.signing-region", region) \
        .config("spark.sql.catalog.s3_rest_catalog.io-impl", "org.apache.iceberg.aws.s3.S3FileIO") \
        .config("spark.sql.catalog.s3_rest_catalog.rest-metrics-reporting-enabled", "false") \
        .config("spark.sql.parquet.enableVectorizedReader", "false") \
        .getOrCreate()
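    # NOTE: GlueContext above already created a SparkSession, so getOrCreate()
    # returns that existing session. Catalog settings (spark.sql.catalog.*)
    # apply at runtime, but static confs such as spark.sql.extensions only
    # take effect if set before the context starts -- e.g. via the job's
    # --conf parameters or --datalake-formats iceberg.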
    
    logger.info("Spark session configured with Iceberg catalog")
    return spark

def calculate_delay_rates(input_df):
    """
    Calculate delay rates for airline operations.
    
    This function computes the total delay rate as the ratio of delayed flights
    to total flights, handling division by zero cases.
    
    Args:
        input_df (DataFrame): Input DataFrame with arr_flights and arr_del15 columns
        
    Returns:
        DataFrame: DataFrame with added total_delay_rate column
        
    Business Logic:
        - total_delay_rate = arr_del15 / arr_flights
        - Returns 0 when arr_flights = 0 to avoid division by zero
        - Precision: decimal(10,4) for accurate rate calculations
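
    Example (illustrative, not part of the job flow; assumes a toy DataFrame):
        >>> df = spark.createDataFrame([(100, 25)], ['arr_flights', 'arr_del15'])
        >>> calculate_delay_rates(df).first()['total_delay_rate']
        Decimal('0.2500')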
    """
    logger.info("Calculating delay rates...")
    
    input_df = input_df.withColumn('total_delay_rate',
        when(col('arr_flights') == 0, lit(0))
        .otherwise((col('arr_del15') / col('arr_flights')).cast('decimal(10,4)')))
    
    logger.info("Delay rates calculated successfully")
    return input_df

def calculate_cancellation_diversion_rates(input_df):
    """
    Calculate cancellation and diversion rates for airline operations.
    
    This function computes operational disruption metrics as ratios of
    cancelled/diverted flights to total flights.
    
    Args:
        input_df (DataFrame): Input DataFrame with flight operation columns
        
    Returns:
        DataFrame: DataFrame with added cancellation_rate and diversion_rate columns
        
    Business Logic:
        - cancellation_rate = arr_cancelled / arr_flights (precision: 10,4)
        - diversion_rate = arr_diverted / arr_flights (precision: 10,6)
        - Returns 0 when arr_flights = 0, matching calculate_delay_rates
        - Higher precision for diversion_rate due to typically lower values
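
    Example (illustrative, not part of the job flow; assumes a toy DataFrame):
        >>> df = spark.createDataFrame([(200, 5, 1)],
        ...     ['arr_flights', 'arr_cancelled', 'arr_diverted'])
        >>> r = calculate_cancellation_diversion_rates(df).first()
        >>> (r['cancellation_rate'], r['diversion_rate'])
        (Decimal('0.0250'), Decimal('0.005000'))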
    """
    logger.info("Calculating cancellation and diversion rates...")
    
    input_df = input_df.withColumn('cancellation_rate',
        when(col('arr_flights') == 0, lit(0))
        .otherwise((col('arr_cancelled') / col('arr_flights')).cast('decimal(10,4)')))
    input_df = input_df.withColumn('diversion_rate',
        when(col('arr_flights') == 0, lit(0))
        .otherwise((col('arr_diverted') / col('arr_flights')).cast('decimal(10,6)')))
    
    logger.info("Cancellation and diversion rates calculated successfully")
    return input_df

def encode_categorical_features(input_df):
    """
    Encode categorical features using StringIndexer and OneHotEncoder.
    
    This function converts categorical variables (carrier, airport) into
    numerical representations suitable for ML algorithms.
    
    Args:
        input_df (DataFrame): Input DataFrame with carrier and airport columns
        
    Returns:
        DataFrame: DataFrame with encoded categorical features as vectors
        
    Process:
        1. StringIndexer: Maps string categories to numerical indices
        2. OneHotEncoder: Converts indices to binary vector representations
        3. Applied to both 'carrier' and 'airport' columns
        
    Output Columns:
        - carrier_indexed, carrier_vec
        - airport_indexed, airport_vec
    """
    logger.info("Starting categorical feature encoding...")
    
    # Encode carrier feature
    logger.info("Encoding carrier feature...")
    indexer_carrier = StringIndexer(inputCol="carrier", outputCol="carrier_indexed").fit(input_df)
    df_indexed = indexer_carrier.transform(input_df)
    encoder_carrier = OneHotEncoder(inputCol="carrier_indexed", outputCol="carrier_vec")
    df_encoded = encoder_carrier.fit(df_indexed).transform(df_indexed)
    
    # Encode airport feature
    logger.info("Encoding airport feature...")
    indexer_airport = StringIndexer(inputCol="airport", outputCol="airport_indexed").fit(df_encoded)
    df_indexed_encoded = indexer_airport.transform(df_encoded)
    encoder_airport = OneHotEncoder(inputCol="airport_indexed", outputCol="airport_vec")
    df_final = encoder_airport.fit(df_indexed_encoded).transform(df_indexed_encoded)
    
    logger.info("Categorical feature encoding completed successfully")
    return df_final

# ============================================================================
# MAIN FEATURE ENGINEERING PIPELINE
# ============================================================================

# Configuration parameters
# source
src_catalog = "spark_catalog"
src_database = "curated_db"
src_table = "airline_delay_cause"

# target
account_id, region = get_aws_context()
s3_tables_arn = f"arn:aws:s3tables:{region}:{account_id}:bucket/airlines"
tgt_catalog = "s3_rest_catalog"
tgt_database = s3_tables_arn.split(':')[-1].split('/')[-1]  # bucket name ("airlines") doubles as the namespace
tgt_table_name = "fg_airline_features"

# Derived configurations
target_table = f"{tgt_catalog}.{tgt_database}.{tgt_table_name}"

logger.info(f"Starting feature engineering pipeline with S3 Tables ARN: {s3_tables_arn}")
logger.info(f"Target table: {target_table}")

# Configure Spark and read source data
spark = configure_spark_with_iceberg(s3_tables_arn, region)
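# Note: `spark_catalog` is the session catalog, which in Glue jobs typically
# resolves to the AWS Glue Data Catalog, so the curated source is still read
# from there even though the default catalog is now the S3 Tables REST catalog.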
df = spark.sql(f"SELECT * FROM `{src_catalog}`.`{src_database}`.`{src_table}`")
logger.info(f"Source data loaded. Record count: {df.count()}")

# Apply transformations
logger.info("Applying feature transformations...")
df = calculate_delay_rates(df)
df = calculate_cancellation_diversion_rates(df)
df = encode_categorical_features(df)

# Convert vectors to arrays and cast to string for Iceberg compatibility
vector_to_string = udf(lambda v: str(v.toArray().tolist()) if v else "[]")
df_final = df.withColumn("carrier_vec_array", vector_to_string("carrier_vec")) \
             .withColumn("airport_vec_array", vector_to_string("airport_vec")) \
             .drop("carrier_vec", "airport_vec")

# Write to target table
df_final.createOrReplaceTempView("temp_features")
spark.sql(f"INSERT OVERWRITE {target_table} SELECT * FROM temp_features")

# Log final statistics
final_count = spark.sql(f"SELECT COUNT(*) as count FROM {target_table}").collect()[0]['count']
logger.info(f"Pipeline completed. Final record count: {final_count}")

# ============================================================================
# JOB COMPLETION
# ============================================================================
logger.info("Feature engineering pipeline completed successfully")
logger.info("Committing Glue job...")
job.commit()
logger.info("Glue job committed successfully")

