#!/usr/bin/env python3
"""
filter_eu_networks.py

Reads two CSV files:
  1. A country reference file (countries.csv) mapping numeric IDs to country info
  2. An IP network file (networks.csv) where each row references country IDs

Outputs a new CSV listing only networks whose "assigned country" is in the EU,
along with the human-readable country names for both assigned and registration countries.

Usage:
    python3 filter_eu_networks.py countries.csv networks.csv output.csv
"""

import csv       # Built-in module for reading/writing CSV files
import sys       # Built-in module for reading command-line arguments and exiting


def load_countries(filepath):
    """
    Read the country reference CSV and build two data structures:

      1. country_by_id   : dict mapping numeric ID (str) -> country info dict
      2. eu_ids          : set of numeric IDs (str) whose is_in_eu flag is "1"

    The country CSV columns are (no header row):
        0: numeric_id
        1: language code   (e.g. "en")
        2: continent code  (e.g. "EU", "AS")
        3: continent name  (e.g. "Europe")
        4: country code    (e.g. "FI")
        5: country name    (e.g. "Finland")
        6: is_in_eu flag   (1 = EU member, 0 = not)

    EU membership is decided by column 6, NOT by the continent code: continent
    code "EU" means the continent Europe, which also contains countries that
    are not EU members.

    Rows with fewer than 6 columns (blank or truncated lines) are skipped.
    A row that has a country name but no column 6 is still recorded in
    country_by_id, but it is never treated as an EU member.

    Returns:
        country_by_id (dict), eu_ids (set)
        Each country info dict has the keys 'continent' (str), 'name' (str)
        and 'in_eu' (bool).
    """

    country_by_id = {}  # ALL countries keyed by their numeric ID
    eu_ids = set()      # only the IDs flagged as EU members

    # 'utf-8-sig' transparently drops a leading byte-order mark (common in
    # spreadsheet exports); without it the first row's ID would start with
    # '\ufeff' and never match. Files without a BOM decode exactly as utf-8.
    with open(filepath, newline='', encoding='utf-8-sig') as f:
        for row in csv.reader(f):
            # Skip blank lines and truncated rows instead of raising IndexError
            if len(row) < 6:
                continue

            numeric_id = row[0]      # e.g. "660013" (matches networks.csv IDs)
            continent_code = row[2]  # e.g. "EU" (continent, not membership)
            country_name = row[5]    # e.g. "Finland"
            # strip() tolerates stray whitespace around the flag, e.g. " 1"
            in_eu = len(row) > 6 and row[6].strip() == '1'

            country_by_id[numeric_id] = {
                'continent': continent_code,
                'name': country_name,
                'in_eu': in_eu,
            }

            if in_eu:
                eu_ids.add(numeric_id)

    return country_by_id, eu_ids


def process_networks(networks_filepath, country_by_id, eu_ids, output_filepath):
    """
    Read the networks CSV and write the rows whose assigned country is in eu_ids.

    The networks CSV columns are (no header row):
        0: IP network in CIDR notation  (e.g. "185.39.184.0/22")
        1: assigned country ID          (e.g. "660013")  <- main filter field
        2: registration country ID      (e.g. "660013")  <- may differ
        3+: other fields, ignored

    Rows with fewer than 3 columns are skipped. Any row whose column 1 is not
    in eu_ids is skipped as well. That includes a header row, provided its
    column-1 text is not itself an ID in eu_ids.

    Output CSV:
        A header row ``network,assigned_country,registration_country``,
        then one row per match: the network followed by the human-readable
        names of the assigned and registration countries. An ID missing
        from country_by_id is written as ``Unknown(<id>)``.

    Args:
        networks_filepath: path of the networks CSV to read.
        country_by_id: dict mapping ID -> info dict with at least a 'name' key.
        eu_ids: set of assigned-country IDs to keep.
        output_filepath: path of the CSV to create (overwritten if it exists).

    Returns:
        The number of network rows written, excluding the header. The same
        count is also reported on stderr.
    """

    def country_name(country_id):
        # Fall back to a descriptive placeholder instead of crashing when the
        # ID is not in the reference file.
        info = country_by_id.get(country_id)
        return info['name'] if info else f"Unknown({country_id})"

    found_count = 0

    # 'utf-8-sig' drops a leading byte-order mark so it cannot leak into the
    # first network string; BOM-less files decode exactly as plain utf-8.
    with open(networks_filepath, newline='', encoding='utf-8-sig') as infile, \
         open(output_filepath, 'w', newline='', encoding='utf-8') as outfile:

        reader = csv.reader(infile)
        # csv.writer quotes fields automatically (e.g. names containing commas)
        writer = csv.writer(outfile)
        writer.writerow(['network', 'assigned_country', 'registration_country'])

        for row in reader:
            # Guard against malformed / short rows
            if len(row) < 3:
                continue

            network, assigned_id, registration_id = row[0], row[1], row[2]

            # The key filter: set membership is O(1) per row
            if assigned_id not in eu_ids:
                continue

            writer.writerow([network, country_name(assigned_id), country_name(registration_id)])
            found_count += 1

    # Status messages go to stderr so stdout stays clean
    print(f"Done! Found {found_count} EU network(s). Output written to: {output_filepath}",
          file=sys.stderr)
    return found_count


def main():
    """
    Command-line entry point.

    Expects exactly three positional arguments: the countries CSV, the
    networks CSV and the output CSV path. With any other argument count it
    prints a usage line to stderr and exits with status 1. Otherwise it loads
    the country table, reports how many countries / EU IDs were loaded, and
    hands off to process_networks.
    """
    cli_args = sys.argv[1:]  # drop the script name itself

    if len(cli_args) != 3:
        print("Usage: python3 filter_eu_networks.py countries.csv networks.csv output.csv",
              file=sys.stderr)
        sys.exit(1)  # non-zero status signals misuse to the shell

    countries_path, networks_path, output_path = cli_args

    print(f"Loading country data from: {countries_path}", file=sys.stderr)
    countries, eu_set = load_countries(countries_path)
    print(f"  Loaded {len(countries)} countries, {len(eu_set)} are in EU", file=sys.stderr)

    print(f"Processing networks from: {networks_path}", file=sys.stderr)
    process_networks(networks_path, countries, eu_set, output_path)


# Run the CLI only when this file is executed as a script; importing it from
# another module (or a test) defines the functions without calling main().
if __name__ == "__main__":
    main()