"""
Ride Service
Business logic for ride booking, driver matching, and fare calculation
"""
import secrets
import string
from datetime import datetime, timedelta
from typing import Optional, List, Tuple
import logging

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, or_, func
from sqlalchemy.orm import selectinload

from app.config import settings
from app.models.ride import Ride, RideStatus, RideType
from app.models.fare_bid import FareBid, BidStatus
from app.models.driver import Driver, DriverStatus
from app.models.vehicle import VehicleCategory
from app.models.promo import PromoCode
from app.schemas.ride import RideCreate, FareEstimateRequest, FareEstimateResponse, FareBidCreate
from app.services.map_provider import MapProvider

logger = logging.getLogger(__name__)


class RideService:
    """Ride booking and management service.

    Every method is a ``staticmethod``/``classmethod`` that works on the
    ``AsyncSession`` supplied by the caller. Methods that change state commit
    the session themselves before returning.
    """

    @staticmethod
    def generate_ride_code(length: int = 8) -> str:
        """Generate a random ride code.

        Characters are drawn with ``secrets`` so codes are not guessable.
        Uniqueness against existing rides is NOT checked here.

        Args:
            length: Number of characters in the code.

        Returns:
            A string of ``length`` characters from A-Z and 0-9.
        """
        chars = string.ascii_uppercase + string.digits
        return "".join(secrets.choice(chars) for _ in range(length))

    @staticmethod
    def _effective_passenger_count(passenger_count: int, children: int = 0) -> int:
        """Return adults + children, treating ``None`` as 0 and never below 1."""
        return max(1, int(passenger_count or 0) + int(children or 0))

    @staticmethod
    def _apply_commission(ride: Ride, fare_amount) -> None:
        """Set ``commission_amount`` and ``driver_earnings`` on *ride*.

        The commission is ``ride.commission_percentage`` percent of
        *fare_amount*, truncated to an integer; the driver keeps the rest.
        """
        ride.commission_amount = int(fare_amount * (ride.commission_percentage / 100))
        ride.driver_earnings = fare_amount - ride.commission_amount

    @classmethod
    async def estimate_fare(
        cls,
        db: AsyncSession,
        map_provider: MapProvider,
        data: FareEstimateRequest
    ) -> FareEstimateResponse:
        """Calculate a fare estimate for a ride.

        fare = (base + distance + time) * surge, minus any valid promo
        discount, floored at the category's ``minimum_fare``.

        Distance/duration come from ``map_provider.calculate_distance``. When
        that returns nothing, a straight-line (haversine) distance is used
        with a rough 3 minutes per km. For per-hour rides the requested
        duration (defaulting to 60 minutes) replaces the routed duration.

        Args:
            db: Session used to load the vehicle category and promo code.
            map_provider: Provider used for routed distance/duration.
            data: Pickup/dropoff, ride type, vehicle category, optional promo.

        Returns:
            A ``FareEstimateResponse`` carrying the full fare breakdown.

        Raises:
            ValueError: If the vehicle category is missing or inactive.
        """
        effective_duration_minutes = data.duration_minutes
        if data.ride_type == RideType.PER_HOUR and not effective_duration_minutes:
            # Keep per-hour booking usable when client omits duration.
            effective_duration_minutes = 60

        # Only active categories can be priced.
        query = select(VehicleCategory).where(
            VehicleCategory.id == data.vehicle_category_id,
            VehicleCategory.is_active == True
        )
        result = await db.execute(query)
        category = result.scalar_one_or_none()

        if not category:
            raise ValueError("Invalid vehicle category")

        distance_result = await map_provider.calculate_distance(
            data.pickup.latitude,
            data.pickup.longitude,
            data.dropoff.latitude,
            data.dropoff.longitude
        )

        if not distance_result:
            # Fallback to straight-line distance when routing is unavailable.
            distance_km = map_provider.haversine_distance(
                data.pickup.latitude,
                data.pickup.longitude,
                data.dropoff.latitude,
                data.dropoff.longitude
            )
            duration_minutes = int(distance_km * 3)  # Rough estimate: 3 min/km
        else:
            distance_km = distance_result.distance_km
            duration_minutes = distance_result.duration_minutes

        # Per-hour rides are billed on the booked duration, not the route.
        if data.ride_type == RideType.PER_HOUR and effective_duration_minutes:
            duration_minutes = effective_duration_minutes

        base_fare = category.base_fare
        distance_fare = int(distance_km * category.per_km_rate)
        time_fare = int(duration_minutes * category.per_minute_rate)

        surge_multiplier = category.surge_multiplier

        subtotal = base_fare + distance_fare + time_fare
        subtotal = int(subtotal * surge_multiplier)

        # Discount is computed on the surged subtotal.
        promo_discount = 0
        if data.promo_code:
            promo = await cls._validate_promo(db, data.promo_code)
            if promo:
                promo_discount = promo.calculate_discount(subtotal)

        # The minimum fare is applied after the discount.
        estimated_fare = max(subtotal - promo_discount, category.minimum_fare)

        return FareEstimateResponse(
            vehicle_category_id=category.id,
            vehicle_category_name=category.display_name,
            estimated_distance_km=round(distance_km, 2),
            estimated_duration_minutes=duration_minutes,
            base_fare=base_fare,
            distance_fare=distance_fare,
            time_fare=time_fare,
            surge_multiplier=surge_multiplier,
            promo_discount=promo_discount,
            estimated_fare=estimated_fare
        )

    @classmethod
    async def create_ride(
        cls,
        db: AsyncSession,
        map_provider: MapProvider,
        passenger_id: int,
        data: RideCreate
    ) -> Ride:
        """Create and persist a new ride booking in REQUESTED status.

        The fare is estimated via :meth:`estimate_fare`. The route from
        ``map_provider.get_route`` is stored alongside extra booking metadata
        (children, return/end times, duration) under ``route_data["booking_meta"]``.

        Args:
            db: Session the new ride is added to and committed on.
            map_provider: Provider used for fare estimation and routing.
            passenger_id: ID stored as the ride's passenger.
            data: Booking request payload.

        Returns:
            The committed and refreshed ``Ride``.

        Raises:
            ValueError: Propagated from :meth:`estimate_fare` for an invalid
                vehicle category.
        """
        effective_duration_minutes = data.duration_minutes
        if data.ride_type == RideType.PER_HOUR and not effective_duration_minutes:
            # Default for missing per-hour duration to keep booking flow resilient.
            effective_duration_minutes = 60

        effective_passenger_count = cls._effective_passenger_count(
            data.passenger_count,
            data.children,
        )

        estimate_request = FareEstimateRequest(
            pickup=data.pickup,
            dropoff=data.dropoff,
            ride_type=data.ride_type,
            vehicle_category_id=data.vehicle_category_id,
            passenger_count=data.passenger_count,
            children=data.children,
            promo_code=data.promo_code,
            duration_minutes=effective_duration_minutes,
        )
        estimate = await cls.estimate_fare(db, map_provider, estimate_request)

        route = await map_provider.get_route(
            data.pickup.latitude,
            data.pickup.longitude,
            data.dropoff.latitude,
            data.dropoff.longitude
        )

        # The estimate response does not carry the promo ID, so look it up again.
        promo_code_id = None
        if data.promo_code:
            promo = await cls._validate_promo(db, data.promo_code)
            if promo:
                promo_code_id = promo.id

        ride = Ride(
            ride_code=cls.generate_ride_code(),
            passenger_id=passenger_id,
            vehicle_category_id=data.vehicle_category_id,
            ride_type=data.ride_type,
            status=RideStatus.REQUESTED,

            # Pickup
            pickup_address=data.pickup.address or "",
            pickup_latitude=data.pickup.latitude,
            pickup_longitude=data.pickup.longitude,
            pickup_place_id=data.pickup.place_id,

            # Dropoff
            dropoff_address=data.dropoff.address or "",
            dropoff_latitude=data.dropoff.latitude,
            dropoff_longitude=data.dropoff.longitude,
            dropoff_place_id=data.dropoff.place_id,

            # Return (if round trip)
            return_address=data.return_location.address if data.return_location else None,
            return_latitude=data.return_location.latitude if data.return_location else None,
            return_longitude=data.return_location.longitude if data.return_location else None,

            # Scheduling
            is_scheduled=data.is_scheduled,
            scheduled_at=data.scheduled_at,

            # Distance & Duration
            estimated_distance_km=estimate.estimated_distance_km,
            estimated_duration_minutes=effective_duration_minutes or estimate.estimated_duration_minutes,

            # Fare
            estimated_fare=estimate.estimated_fare,
            base_fare=estimate.base_fare,
            distance_fare=estimate.distance_fare,
            time_fare=estimate.time_fare,
            surge_multiplier=estimate.surge_multiplier,
            promo_discount=estimate.promo_discount,
            promo_code_id=promo_code_id,

            # Payment
            payment_method=data.payment_method,

            # Commission
            commission_percentage=settings.default_commission_percentage,

            # Passenger
            passenger_count=effective_passenger_count,
            passenger_notes=data.passenger_notes,

            # Route
            route_polyline=route.polyline if route else None,
            route_data={
                **(route.to_dict() if route else {}),
                "booking_meta": {
                    "children": data.children,
                    "return_scheduled_at": data.return_scheduled_at.isoformat() if data.return_scheduled_at else None,
                    "duration_minutes": effective_duration_minutes,
                    "end_scheduled_at": data.end_scheduled_at.isoformat() if data.end_scheduled_at else None,
                },
            },
        )

        db.add(ride)
        await db.commit()
        await db.refresh(ride)

        return ride

    @classmethod
    async def find_nearby_drivers(
        cls,
        db: AsyncSession,
        latitude: float,
        longitude: float,
        category_id: int,
        radius_km: Optional[float] = None
    ) -> List[Driver]:
        """Find available drivers in a vehicle category near a location.

        Loads every approved, online, not-on-ride driver with a known position.
        It then keeps those whose current vehicle is in *category_id* and whose
        haversine distance is within *radius_km*. Each returned driver gets a
        ``_distance_km`` attribute, and the list is sorted nearest first.

        Args:
            db: Session used for the driver query.
            latitude: Search origin latitude.
            longitude: Search origin longitude.
            category_id: Required vehicle category of the driver's current vehicle.
            radius_km: Search radius; defaults to ``settings.driver_search_radius_km``.

        Returns:
            Matching drivers ordered by ascending distance.
        """
        # Local import kept (presumably to avoid an import cycle; unverified).
        # Done once here instead of on every loop iteration, and aliased so it
        # does not shadow the module-level ``MapProvider`` name.
        from app.services.map_provider.base_provider import MapProvider as BaseMapProvider

        if radius_km is None:
            radius_km = settings.driver_search_radius_km

        query = select(Driver).where(
            and_(
                Driver.status == DriverStatus.APPROVED,
                Driver.is_online == True,
                Driver.is_on_ride == False,
                Driver.current_latitude.isnot(None),
                Driver.current_longitude.isnot(None)
            )
        ).options(
            selectinload(Driver.user),
            selectinload(Driver.current_vehicle)
        )

        result = await db.execute(query)
        drivers = result.scalars().all()

        nearby_drivers = []
        for driver in drivers:
            vehicle = driver.current_vehicle
            if not vehicle or vehicle.category_id != category_id:
                continue

            distance = BaseMapProvider.haversine_distance(
                latitude, longitude,
                driver.current_latitude, driver.current_longitude
            )
            if distance <= radius_km:
                # Stash the distance on the instance; used for sorting below.
                driver._distance_km = distance
                nearby_drivers.append(driver)

        nearby_drivers.sort(key=lambda d: d._distance_km)
        return nearby_drivers

    @classmethod
    async def submit_bid(
        cls,
        db: AsyncSession,
        driver_id: int,
        data: FareBidCreate
    ) -> FareBid:
        """Submit a fare bid from a driver on an open ride.

        The ride must be REQUESTED, SEARCHING or BIDDING. A driver may hold at
        most one PENDING bid per ride. The bid expires after
        ``settings.bid_expiry_seconds``. After the bid is recorded, a
        REQUESTED or SEARCHING ride is moved to BIDDING.

        Args:
            db: Session the bid is added to and committed on.
            driver_id: ID of the bidding driver.
            data: Bid payload (ride ID, amount, message, ETA).

        Returns:
            The committed and refreshed ``FareBid``.

        Raises:
            ValueError: If the ride is not open for bidding, or the driver
                already has a pending bid on it.
        """
        query = select(Ride).where(
            Ride.id == data.ride_id,
            Ride.status.in_([RideStatus.REQUESTED, RideStatus.SEARCHING, RideStatus.BIDDING])
        )
        result = await db.execute(query)
        ride = result.scalar_one_or_none()

        if not ride:
            raise ValueError("Ride not available for bidding")

        existing_query = select(FareBid).where(
            FareBid.ride_id == data.ride_id,
            FareBid.driver_id == driver_id,
            FareBid.status == BidStatus.PENDING
        )
        existing_result = await db.execute(existing_query)
        if existing_result.scalar_one_or_none():
            raise ValueError("You already have a pending bid for this ride")

        bid = FareBid(
            ride_id=ride.id,
            driver_id=driver_id,
            bid_amount=data.bid_amount,
            original_fare=ride.estimated_fare,
            # Bid as a percentage of the estimate; 100 when there is no estimate.
            bid_percentage=(data.bid_amount / ride.estimated_fare) * 100 if ride.estimated_fare > 0 else 100,
            message=data.message,
            estimated_arrival_minutes=data.estimated_arrival_minutes,
            expires_at=datetime.utcnow() + timedelta(seconds=settings.bid_expiry_seconds)
        )

        db.add(bid)

        # Any ride that just received a bid is now in BIDDING. SEARCHING is
        # included so its bids do not get stranded in a state that cannot be
        # accepted.
        if ride.status in (RideStatus.REQUESTED, RideStatus.SEARCHING):
            ride.status = RideStatus.BIDDING

        await db.commit()
        await db.refresh(bid)

        return bid

    @classmethod
    async def accept_bid(
        cls,
        db: AsyncSession,
        ride_id: int,
        bid_id: int,
        passenger_id: int
    ) -> Ride:
        """Accept a pending fare bid on the passenger's ride.

        On success:
        - the bid becomes ACCEPTED;
        - the ride is assigned to the bidding driver at the bid amount and
          moved to ACCEPTED, with commission and driver earnings computed;
        - the driver is marked on-ride;
        - every other PENDING bid on the ride is REJECTED.

        An expired bid is marked EXPIRED (and committed) before raising.

        Args:
            db: Session used for all reads/writes; committed on success.
            ride_id: Ride the bid belongs to.
            bid_id: Bid being accepted.
            passenger_id: Must match the ride's passenger.

        Returns:
            The committed and refreshed ``Ride``.

        Raises:
            ValueError: If the ride is not found or not open, the bid is not
                found, pending or unexpired, or the driver is already on a ride.
        """
        query = select(Ride).where(
            Ride.id == ride_id,
            Ride.passenger_id == passenger_id,
            Ride.status.in_([RideStatus.BIDDING, RideStatus.REQUESTED, RideStatus.SEARCHING])
        )
        result = await db.execute(query)
        ride = result.scalar_one_or_none()

        if not ride:
            raise ValueError("Ride not found or not in bidding state")

        bid_query = select(FareBid).where(
            FareBid.id == bid_id,
            FareBid.ride_id == ride_id,
            FareBid.status == BidStatus.PENDING
        ).options(selectinload(FareBid.driver))
        bid_result = await db.execute(bid_query)
        bid = bid_result.scalar_one_or_none()

        if not bid:
            raise ValueError("Bid not found or expired")

        if bid.is_expired:
            bid.status = BidStatus.EXPIRED
            await db.commit()
            raise ValueError("Bid has expired")

        # A driver can hold pending bids on several rides at once. Refuse
        # before mutating anything if they were already assigned elsewhere,
        # so one driver is never put on two rides.
        driver = bid.driver
        if driver is None or driver.is_on_ride:
            raise ValueError("Driver is no longer available")

        now = datetime.utcnow()

        bid.status = BidStatus.ACCEPTED
        bid.responded_at = now

        ride.driver_id = bid.driver_id
        ride.accepted_bid_amount = bid.bid_amount
        ride.status = RideStatus.ACCEPTED
        ride.accepted_at = now
        cls._apply_commission(ride, bid.bid_amount)

        driver.is_on_ride = True

        # Reject every other bid still pending on this ride.
        reject_query = select(FareBid).where(
            FareBid.ride_id == ride_id,
            FareBid.id != bid_id,
            FareBid.status == BidStatus.PENDING
        )
        reject_result = await db.execute(reject_query)
        for other_bid in reject_result.scalars():
            other_bid.status = BidStatus.REJECTED
            other_bid.responded_at = now

        await db.commit()
        await db.refresh(ride)

        return ride

    @classmethod
    async def update_ride_status(
        cls,
        db: AsyncSession,
        ride_id: int,
        new_status: RideStatus,
        user_id: int,
        is_driver: bool = False
    ) -> Ride:
        """Move a ride along its lifecycle.

        Allowed transitions:
        - ACCEPTED -> DRIVER_ARRIVED | CANCELLED
        - DRIVER_ARRIVED -> STARTED | CANCELLED
        - STARTED -> COMPLETED | CANCELLED

        Each step records its timestamp. COMPLETED also sets ``final_fare``
        and credits the driver's counters/earnings. Both COMPLETED and
        CANCELLED release the driver (``is_on_ride = False``).

        Args:
            db: Session used for the update; committed on success.
            ride_id: Ride to update.
            new_status: Target status.
            user_id: Currently unused (see NOTE below).
            is_driver: Currently unused (see NOTE below).

        Returns:
            The committed and refreshed ``Ride``.

        Raises:
            ValueError: If the ride does not exist or the transition is not allowed.
        """
        # NOTE(review): user_id / is_driver are accepted but never checked, so
        # any caller can drive any ride through these transitions. Confirm
        # authorization is enforced upstream.
        query = select(Ride).where(Ride.id == ride_id).options(
            selectinload(Ride.driver)
        )
        result = await db.execute(query)
        ride = result.scalar_one_or_none()

        if not ride:
            raise ValueError("Ride not found")

        valid_transitions = {
            RideStatus.ACCEPTED: [RideStatus.DRIVER_ARRIVED, RideStatus.CANCELLED],
            RideStatus.DRIVER_ARRIVED: [RideStatus.STARTED, RideStatus.CANCELLED],
            RideStatus.STARTED: [RideStatus.COMPLETED, RideStatus.CANCELLED],
        }

        if new_status not in valid_transitions.get(ride.status, []):
            raise ValueError(f"Cannot transition from {ride.status} to {new_status}")

        now = datetime.utcnow()
        ride.status = new_status

        if new_status == RideStatus.DRIVER_ARRIVED:
            ride.driver_arrived_at = now
        elif new_status == RideStatus.STARTED:
            ride.started_at = now
        elif new_status == RideStatus.COMPLETED:
            ride.completed_at = now
            ride.final_fare = ride.accepted_bid_amount or ride.estimated_fare

            if ride.driver:
                # In this module driver_earnings is only set by accept_bid.
                # If the ride reached ACCEPTED some other way, derive it from
                # the final fare instead of failing on ``+= None``.
                if ride.driver_earnings is None:
                    cls._apply_commission(ride, ride.final_fare)

                ride.driver.is_on_ride = False
                ride.driver.total_rides += 1
                ride.driver.completed_rides += 1
                ride.driver.total_earnings += ride.driver_earnings
                ride.driver.current_balance += ride.driver_earnings
        elif new_status == RideStatus.CANCELLED:
            # Free the driver; otherwise is_on_ride stays True and they are
            # excluded from find_nearby_drivers indefinitely.
            if ride.driver:
                ride.driver.is_on_ride = False

        await db.commit()
        await db.refresh(ride)

        return ride

    @classmethod
    async def get_ride_by_id(
        cls,
        db: AsyncSession,
        ride_id: int,
        user_id: int
    ) -> Optional[Ride]:
        """Load a ride by ID with its related records eagerly loaded.

        Loads the passenger, the driver (with user and current vehicle), the
        vehicle category and the rating.

        Args:
            db: Session used for the query.
            ride_id: Ride to load.
            user_id: Currently unused (see NOTE below).

        Returns:
            The ``Ride``, or ``None`` if it does not exist.
        """
        query = select(Ride).where(Ride.id == ride_id).options(
            selectinload(Ride.passenger),
            selectinload(Ride.driver).selectinload(Driver.user),
            selectinload(Ride.driver).selectinload(Driver.current_vehicle),
            selectinload(Ride.vehicle_category),
            selectinload(Ride.rating)
        )
        result = await db.execute(query)
        ride = result.scalar_one_or_none()

        if not ride:
            return None

        # NOTE(review): no permission check is performed; any user_id can read
        # any ride. The intended rule (passenger, assigned driver, or admin)
        # must be enforced by the caller until it is implemented here.
        return ride

    @classmethod
    async def _validate_promo(
        cls,
        db: AsyncSession,
        code: str
    ) -> Optional[PromoCode]:
        """Return the active promo code matching *code* if it is currently valid.

        A promo is returned only when it has ``is_active`` set and its
        ``is_valid`` property is truthy.

        Args:
            db: Session used for the lookup.
            code: Promo code string, matched exactly.

        Returns:
            The ``PromoCode``, or ``None`` if not found, inactive or invalid.
        """
        query = select(PromoCode).where(
            PromoCode.code == code,
            PromoCode.is_active == True
        )
        result = await db.execute(query)
        promo = result.scalar_one_or_none()

        if promo and promo.is_valid:
            return promo
        return None
