import numpy as np
from scipy.stats import qmc
from datetime import datetime
import os

class CarDatasetSampler:
    """
    Latin Hypercube Sampler for car dataset with continuous and binary variables.

    Every continuous variable and every binary component occupies one
    dimension of a single Latin hypercube design. Continuous dimensions are
    scaled linearly onto their ``(min, max)`` range. Binary dimensions are
    thresholded at 0.5. Because LHS stratifies each dimension, each component
    is switched on in about half of the samples.
    """

    def __init__(self):
        # Define continuous variables and their (min, max) ranges
        self.continuous_vars = {
            'ride_height_mm': (50, 200),      # Ride height in mm
            'yaw_angle_deg': (0, 15),         # Yaw angle in degrees
            'velocity_kmh': (100, 150),       # Velocity in km/h
        }

        # Define binary components (customizable list)
        self.binary_components = [
            'tire_deflectors',
            'mirrors',
            'spoilers',
        ]

        # Static metadata merged into every result of generate_samples()
        self.metadata = {
            'sampling_method': 'Latin Hypercube Sampling',
            'created_by': 'CarDatasetSampler',
            'version': '1.0',
        }

    def customize_components(self, component_list):
        """
        Replace the list of binary components to be sampled.

        Args:
            component_list (iterable of str): Component names to sample. Any
                iterable (list, tuple, ...) is accepted. A private list copy
                is stored, so later changes to the caller's object do not
                affect the sampler.
        """
        self.binary_components = list(component_list)

    def customize_continuous_vars(self, var_dict):
        """
        Add or override continuous variables and their ranges.

        Variables not mentioned in ``var_dict`` keep their current ranges.

        Args:
            var_dict (dict): Variable names mapped to ``(min, max)`` pairs.

        Raises:
            ValueError: If a range is not a two-element ``(min, max)`` pair,
                or if ``min > max``. Nothing is updated in that case.
        """
        validated = {}
        for name, bounds in var_dict.items():
            try:
                min_val, max_val = bounds
            except (TypeError, ValueError) as err:
                raise ValueError(
                    f"Range for '{name}' must be a (min, max) pair, got {bounds!r}"
                ) from err
            if min_val > max_val:
                raise ValueError(
                    f"Range for '{name}' has min > max: ({min_val}, {max_val})"
                )
            validated[name] = (min_val, max_val)
        # Apply only after every entry validated, so a bad entry leaves state untouched.
        self.continuous_vars.update(validated)

    def generate_samples(self, n_samples, seed=None):
        """
        Generate Latin hypercube samples for the car dataset.

        Args:
            n_samples (int): Number of samples to generate.
            seed (int, optional): Seed for the LHS engine, for
                reproducibility. The global NumPy RNG is not modified.

        Returns:
            dict: ``{'samples': [...], 'metadata': {...}}``.
                Each sample dict has these keys:
                - ``sample_id``: an int, counting from 1;
                - one key per continuous variable: a float rounded to 3 decimals;
                - one key per binary component: a plain Python ``bool``.
                The metadata holds copies of the variable/component
                definitions, so later customisation does not alter it.
        """
        # Snapshot the configuration so columns and metadata stay consistent.
        var_names = list(self.continuous_vars)
        components = list(self.binary_components)
        n_continuous = len(var_names)
        n_binary = len(components)
        n_dimensions = n_continuous + n_binary

        # The engine owns a private RNG seeded from `seed`; no global seeding.
        sampler = qmc.LatinHypercube(d=n_dimensions, seed=seed)
        lhs_samples = sampler.random(n=n_samples)

        # Scale all continuous columns at once from [0, 1) to [min, max].
        # reshape(-1, 2) keeps the shape valid when there are no continuous vars.
        bounds = np.array(
            [self.continuous_vars[name] for name in var_names], dtype=float
        ).reshape(-1, 2)
        lows, highs = bounds[:, 0], bounds[:, 1]
        continuous = lows + lhs_samples[:, :n_continuous] * (highs - lows)
        binary = lhs_samples[:, n_continuous:] > 0.5

        samples = []
        for sample_id, (cont_row, bin_row) in enumerate(zip(continuous, binary), start=1):
            sample = {'sample_id': sample_id}
            for name, value in zip(var_names, cont_row):
                # float() first so Python's correctly-rounded round() is used
                sample[name] = round(float(value), 3)
            for component, flag in zip(components, bin_row):
                # bool() turns np.bool_ into bool so isinstance(..., bool) checks work
                sample[component] = bool(flag)
            samples.append(sample)

        return {
            'samples': samples,
            'metadata': {
                **self.metadata,
                'n_samples': n_samples,
                'n_continuous_vars': n_continuous,
                'n_binary_vars': n_binary,
                # Copies, so later customize_* calls cannot rewrite old results
                'continuous_variables': dict(self.continuous_vars),
                'binary_components': components,
                # Naive local timestamp
                'generation_time': datetime.now().isoformat(),
                'seed': seed,
            },
        }

    @staticmethod
    def _format_value(value):
        """Render a sample value for output: booleans as ON/OFF, anything else via str()."""
        # np.bool_ is not a subclass of bool, so check both explicitly.
        if isinstance(value, (bool, np.bool_)):
            return "ON" if value else "OFF"
        return str(value)

    def save_to_file(self, results, filename='car_dataset_samples.txt'):
        """
        Save samples and metadata to a UTF-8 text file.

        Layout:
        - a banner;
        - a METADATA section with one ``key: value`` line per metadata entry;
        - a SAMPLES section containing a tab-separated table, with the
          columns taken from the first sample and booleans written as ON/OFF;
        - a footer with the sample count.

        Args:
            results (dict): Results from generate_samples().
            filename (str): Output filename; an existing file is overwritten.
        """
        rule = "=" * 80
        sub_rule = "-" * 40
        samples = results['samples']

        lines = [
            rule,
            "CAR DATASET LATIN HYPERCUBE SAMPLES",
            rule,
            "",
            "METADATA:",
            sub_rule,
        ]
        lines.extend(f"{key}: {value}" for key, value in results['metadata'].items())
        lines += ["", "SAMPLES:", sub_rule]

        if samples:
            sample_keys = list(samples[0])
            header = "\t".join(sample_keys)
            lines += [header, "-" * len(header)]
            lines.extend(
                "\t".join(self._format_value(sample[key]) for key in sample_keys)
                for sample in samples
            )

        lines += [
            "",
            rule,
            f"Total samples generated: {len(samples)}",
            f"File saved: {filename}",
            rule,
        ]

        with open(filename, 'w', encoding='utf-8') as f:
            f.write("\n".join(lines) + "\n")

    def print_sample_summary(self, results, n_preview=3):
        """
        Print a summary of the generated samples to stdout.

        Args:
            results (dict): Results from generate_samples().
            n_preview (int): Number of leading samples to print in full
                (default 3). Booleans are printed as ON/OFF.
        """
        meta = results['metadata']
        print("\n" + "=" * 60)
        print("SAMPLE GENERATION SUMMARY")
        print("=" * 60)
        print(f"Number of samples: {meta['n_samples']}")
        print(f"Continuous variables: {meta['n_continuous_vars']}")
        print(f"Binary variables: {meta['n_binary_vars']}")
        print(f"Generation time: {meta['generation_time']}")
        print(f"Seed: {meta['seed']}")

        print("\nContinuous variable ranges:")
        for var, (min_val, max_val) in meta['continuous_variables'].items():
            print(f"  {var}: [{min_val}, {max_val}]")

        print("\nBinary components:")
        for component in meta['binary_components']:
            print(f"  {component}")

        print(f"\nFirst {n_preview} samples:")
        preview = results['samples'][:max(n_preview, 0)]
        for number, sample in enumerate(preview, start=1):
            print(f"\nSample {number}:")
            for key, value in sample.items():
                if key == 'sample_id':
                    continue
                print(f"  {key}: {self._format_value(value)}")
        print("=" * 60)


def main():
    """
    Demonstrate CarDatasetSampler end to end.

    Steps:
    1. Build a sampler.
    2. Re-declare its binary component list.
    3. Override two continuous ranges.
    4. Draw 1500 samples with seed 42.
    5. Print a summary and write the samples to ``car_dataset_samples.txt``.
    """
    demo = CarDatasetSampler()

    # Optional: replace the binary component list
    demo.customize_components(['tire_deflectors', 'mirrors', 'spoilers'])

    # Optional: override selected continuous ranges (others keep their defaults)
    demo.customize_continuous_vars(
        {'ride_height_mm': (40, 180), 'yaw_angle_deg': (0, 20)}
    )

    out_path = 'car_dataset_samples.txt'
    generated = demo.generate_samples(n_samples=1500, seed=42)

    demo.print_sample_summary(generated)
    demo.save_to_file(generated, out_path)

    total = len(generated['samples'])
    print(f"\nSamples saved to: {out_path}")
    print(f"Total samples generated: {total}")


if __name__ == "__main__":
    main()