#!/usr/bin/env python3
"""
Autopkgtest to compare binary packages built from this source against
NVIDIA's official archive packages.

This script validates that our repacked CUDA packages match NVIDIA's official
packages in terms of dependencies, file contents, and structure.
"""

import hashlib
import platform
import argparse
import glob
import gzip
import logging
import re
import shutil
import subprocess
import sys
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple
from urllib.parse import urljoin

import requests
from debian import deb822  # type: ignore[attr-defined]


def normalize_unicode(text: str) -> str:
    """Replace common typographic characters with ASCII equivalents.

    Used before comparing license texts so that curly quotes, dashes,
    ellipses and non-breaking spaces do not show up as differences.

    Args:
        text: The string to normalize.

    Returns:
        A copy of ``text`` with the characters below replaced; every other
        character (including other non-ASCII ones) is left untouched.
    """
    replacements = {
        "\u00a0": " ",  # non-breaking space
        "\u2018": "'",  # left single quotation mark
        "\u2019": "'",  # right single quotation mark
        "\u201c": '"',  # left double quotation mark
        "\u201d": '"',  # right double quotation mark
        "\u2013": "-",  # en dash
        "\u2014": "--",  # em dash
        "\u2026": "...",  # ellipsis
    }
    # One pass over the text; maketrans accepts multi-character replacements.
    return text.translate(str.maketrans(replacements))


@dataclass
class ControlField:
    """One entry of a comma-separated control field such as Depends or Provides.

    An entry is a package name, optionally followed by a version
    restriction in parentheses, e.g. ``package (>= 1.2.3)``.
    """

    name: str
    version: Optional[str] = None
    operator: Optional[str] = None  # e.g., ">=", "=", "<<"

    @classmethod
    def parse(cls, dep_string: str) -> "ControlField":
        """Build a ControlField from text like 'package (>= 1.2.3)'.

        Only the leading package name and an optional ``(operator version)``
        group are recognised; anything after them is ignored. Text that does
        not start with a package name is kept verbatim as the name.
        """
        cleaned = dep_string.strip()
        parsed = re.match(r"([^\s(]+)\s*(?:\(([><=]+)\s*([^)]+)\))?", cleaned)
        if parsed is None:
            return cls(name=cleaned)
        return cls(
            name=parsed.group(1),
            operator=parsed.group(2),
            version=parsed.group(3),
        )

    def __str__(self) -> str:
        # The restriction is only rendered when both parts are present.
        if not (self.version and self.operator):
            return self.name
        return f"{self.name} ({self.operator} {self.version})"


@dataclass
class FieldComparison:
    """Outcome of comparing one control field between two packages."""

    # Names present in the reference package but absent from ours.
    missing_items: Set[str] = field(default_factory=set)
    # Names present in ours but absent from the reference package.
    extra_items: Set[str] = field(default_factory=set)
    # Name -> pair of differing version strings.
    version_mismatches: Dict[str, Tuple[str, str]] = field(default_factory=dict)

    def has_failures(self) -> bool:
        """Return True when something is missing or a version differs.

        Extra items alone are not counted as a failure.
        """
        return any((self.missing_items, self.version_mismatches))

    def has_differences(self) -> bool:
        """Return True when any of the three collections is non-empty."""
        return any((self.missing_items, self.extra_items, self.version_mismatches))


@dataclass
class PackageInfo:
    """Name, version, .deb location and dependencies of one package."""

    name: str
    version: str
    deb_path: Path
    dependencies: List[ControlField] = field(default_factory=list)

    @property
    def upstream_version(self) -> str:
        """Version with everything after the last '-' (the Debian revision) removed.

        A version without any '-' is returned unchanged.
        """
        head, dash, _revision = self.version.rpartition("-")
        if not dash:
            return self.version
        return head


@dataclass
class Config:
    """Settings and working-directory layout for one comparison run."""

    ubuntu_version: str = "24.04"
    # NOTE: the default is computed once, when the class is defined.
    work_dir: Path = Path.cwd() / "nvidia_comparison"
    nvidia_repo_base: str = "https://developer.download.nvidia.com/compute/cuda/repos"

    def __post_init__(self):
        # Every sub-directory lives directly below the working directory.
        root = self.work_dir
        self.repacked_dir = root / "repacked"
        self.nvidia_dir = root / "nvidia"
        self.extract_dir = root / "extracted"
        self.diff_dir = root / "diffs"

    def setup_directories(self):
        """Create the four working sub-directories; existing ones are kept."""
        wanted = (
            self.repacked_dir,
            self.nvidia_dir,
            self.extract_dir,
            self.diff_dir,
        )
        for directory in wanted:
            directory.mkdir(parents=True, exist_ok=True)

    def cleanup(self):
        """Delete the whole working directory tree if it exists."""
        if not self.work_dir.exists():
            return
        shutil.rmtree(self.work_dir)


class NvidiaRepository:
    """Client for NVIDIA's CUDA apt repository for one Ubuntu release and CPU architecture.

    The repository URL is built from ``config.nvidia_repo_base``, the Ubuntu
    release in ``config.ubuntu_version`` and the architecture reported by
    ``platform.machine()``.
    """

    # Ubuntu release number -> release directory name in the repository URL.
    UBUNTU_REPO_MAP = {
        "24.04": "ubuntu2404",
        "22.04": "ubuntu2204",
        "20.04": "ubuntu2004",
    }

    # ``platform.machine()`` value -> architecture directory name in the repository URL.
    ARCH_MAP = {
        "x86_64": "x86_64",
        "arm64": "sbsa",
        "aarch64": "sbsa",
    }

    def __init__(self, config: Config, logger: logging.Logger):
        """Resolve the repository URL for the configured release and this host.

        Raises:
            ValueError: If the Ubuntu version or the host architecture has no
                entry in ``UBUNTU_REPO_MAP`` / ``ARCH_MAP``.
        """
        self.config = config
        self.logger = logger

        # package index contains `[pkg-name: [(package_version, package_filename), ...], ...]`
        self.package_index: Dict[str, List[Tuple[str, str]]] = {}

        repo_name = self.UBUNTU_REPO_MAP.get(self.config.ubuntu_version)
        if not repo_name:
            raise ValueError(
                f"Unsupported Ubuntu version: {self.config.ubuntu_version}"
            )

        nvidia_arch = self.ARCH_MAP.get(platform.machine())
        if not nvidia_arch:
            raise ValueError(f"Unsupported architecture: {platform.machine()}")

        self.repo_url = f"{self.config.nvidia_repo_base}/{repo_name}/{nvidia_arch}"

    def download_package_index(self) -> None:
        """Download ``Packages.gz`` and rebuild ``self.package_index`` from it.

        Every (version, filename) pair of every package is kept; paragraphs
        lacking a Package, Version or Filename field are skipped.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        self.logger.info(f"Downloading package index from: {self.repo_url}")

        packages_url = urljoin(self.repo_url + "/", "Packages.gz")
        response = requests.get(packages_url, timeout=30)
        response.raise_for_status()

        # Decompress and parse while keeping all versions per package.
        packages_content = gzip.decompress(response.content).decode("utf-8")
        self.package_index.clear()
        for paragraph in deb822.Packages.iter_paragraphs(
            packages_content, use_apt_pkg=False
        ):
            package_name = paragraph.get("Package")
            package_version = paragraph.get("Version")
            package_filename = paragraph.get("Filename")

            if not package_name or not package_version or not package_filename:
                continue

            self.package_index.setdefault(package_name, []).append(
                (package_version, package_filename)
            )

    def find_package(self, name: str, upstream_version: str) -> Optional[str]:
        """Return the index filename of ``name`` at ``upstream_version``.

        A candidate matches when its version equals ``upstream_version`` or
        starts with ``upstream_version`` followed by ``-``. When several
        candidates match, the alphabetically first filename is returned and a
        warning is logged.

        Returns:
            The ``Filename`` value from the index, or None when the package
            is unknown or no version matches.
        """
        candidates = self.package_index.get(name, [])
        if not candidates:
            self.logger.info(f"Package {name} not found in NVIDIA index")
            return None

        # Prefer exact upstream version match (e.g. 13.2.1-*), not first/last duplicate.
        version_prefix = f"{upstream_version}-"
        version_matches = sorted(
            filename
            for version, filename in candidates
            if version == upstream_version or version.startswith(version_prefix)
        )
        if version_matches:
            if len(version_matches) > 1:
                self.logger.warning(
                    f"Multiple exact version matches found for {name} {upstream_version}; using first: {version_matches[0]}"
                )
            return version_matches[0]

        # if it wasn't found, log a few debugging helpers
        self.logger.info(
            f"Unable to find {name} with upstream version {upstream_version} in NVIDIA package index."
            "\nWill fail: no exact match found, checking for partial matches to help debugging:"
        )
        self.logger.info(
            f"Looking for package key: {name}. If it's nsight-*, you may need to update the actual package name with debian/nsight-*-version"
        )
        for version, filename in candidates:
            if name in filename:
                self.logger.info(
                    f"Found package with matching name/version candidate: {name} {version} -> key: {filename}"
                )

        return None

    def download_package(self, filename: str) -> Path:
        """Download ``filename`` into ``config.nvidia_dir`` unless already present.

        The body is streamed into a ``.part`` file that is renamed to the
        final name only after the transfer completed. An interrupted download
        therefore never leaves a truncated .deb that a later call would
        mistake for a finished one.

        Args:
            filename: Path of the package relative to the repository URL, as
                found in the package index.

        Returns:
            Local path of the downloaded (or previously downloaded) package.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        output_path = self.config.nvidia_dir / Path(filename).name

        if output_path.exists():
            self.logger.info(f"Package already downloaded: {output_path.name}")
            return output_path

        package_url = urljoin(self.repo_url + "/", filename)
        self.logger.info(f"Downloading: {output_path.name}")

        partial_path = output_path.with_name(output_path.name + ".part")
        try:
            # The context manager releases the streamed connection even on error.
            with requests.get(package_url, timeout=60, stream=True) as response:
                response.raise_for_status()
                with open(partial_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
            partial_path.replace(output_path)
        finally:
            # Only still present when the download or the rename failed.
            if partial_path.exists():
                partial_path.unlink()

        return output_path


class PackageManager:
    """Thin wrappers around the dpkg tool family for local packages."""

    def __init__(self, logger: logging.Logger):
        self.logger = logger

    def get_installed_packages(self, control_file: Path) -> List[str]:
        """Return the Package field values that grep-dctrl prints for ``control_file``.

        Raises:
            subprocess.CalledProcessError: If grep-dctrl exits non-zero.
        """
        command = ["grep-dctrl", "-n", "-s", "Package", "", str(control_file)]
        completed = subprocess.run(
            command,
            capture_output=True,
            text=True,
            check=True,
        )
        names = completed.stdout.strip().split("\n")
        return [entry for entry in names if entry]

    def get_package_info(self, package_name: str) -> PackageInfo:
        """Collect version and dependencies of an installed package.

        ``deb_path`` is left as an empty path; no .deb exists yet at this point.
        """
        return PackageInfo(
            name=package_name,
            version=self.get_package_field(package_name, "Version"),
            deb_path=Path(),  # Will be set after repacking
            dependencies=self.get_package_dependencies(package_name),
        )

    def repack_package(self, package_name: str, output_dir: Path) -> Path:
        """Rebuild a .deb from an installed package with dpkg-repack.

        Returns:
            The most recently modified ``<package_name>_*.deb`` in ``output_dir``.

        Raises:
            ValueError: If ``dpkg -l`` does not know the package.
            subprocess.CalledProcessError: If dpkg-repack fails.
            FileNotFoundError: If no matching .deb exists afterwards.
        """
        self.logger.info(f"Repacking installed package: {package_name}")

        # Check if package is installed
        probe = subprocess.run(["dpkg", "-l", package_name], capture_output=True)
        if probe.returncode != 0:
            raise ValueError(f"Package {package_name} is not installed")

        # dpkg-repack writes the .deb into its current directory.
        subprocess.run(
            ["fakeroot", "dpkg-repack", package_name],
            cwd=output_dir,
            capture_output=True,
            check=True,
        )

        def modified_at(path: Path) -> float:
            return path.stat().st_mtime

        produced = list(output_dir.glob(f"{package_name}_*.deb"))
        if len(produced) == 0:
            raise FileNotFoundError(f"Repacked .deb not found for {package_name}")

        return max(produced, key=modified_at)

    def get_package_field(self, package_name: str, field: str) -> str:
        """Return one dpkg-query field of an installed package, stripped.

        Raises:
            subprocess.CalledProcessError: If dpkg-query exits non-zero.
        """
        show_format = "-f=${" + field + "}"
        completed = subprocess.run(
            ["dpkg-query", "-W", show_format, package_name],
            capture_output=True,
            text=True,
            check=True,
        )
        return completed.stdout.strip()

    def get_package_dependencies(self, package_name: str) -> List[ControlField]:
        """Parse the Depends field of an installed package into ControlField objects."""
        raw = self.get_package_field(package_name, "Depends")
        if not raw:
            return []

        entries = (chunk.strip() for chunk in raw.split(","))
        return [ControlField.parse(entry) for entry in entries if entry]

    def get_package_provides(self, package_name: str) -> Set[str]:
        """Return the non-empty, stripped entries of the Provides field as a set."""
        raw = self.get_package_field(package_name, "Provides")
        return {chunk.strip() for chunk in raw.split(",") if chunk.strip()}

    @staticmethod
    def get_deb_field(deb_path: Path, field: str) -> str:
        """Return one control field of a .deb file, or "" when dpkg-deb fails."""
        completed = subprocess.run(
            ["dpkg-deb", "-f", str(deb_path), field], capture_output=True, text=True
        )
        if completed.returncode != 0:
            return ""
        return completed.stdout.strip()

    @staticmethod
    def extract_deb(deb_path: Path, output_dir: Path):
        """Unpack a .deb into ``output_dir`` with ``dpkg-deb -R``.

        Raises:
            subprocess.CalledProcessError: If dpkg-deb exits non-zero.
        """
        output_dir.mkdir(parents=True, exist_ok=True)
        command = ["dpkg-deb", "-R", str(deb_path), str(output_dir)]
        subprocess.run(command, check=True, capture_output=True)


class PackageComparator:
    """Main class for comparing packages."""

    # Files that are expected to differ
    IGNORE_PATTERNS = [
        r"/usr/share/doc/.*/changelog\.Debian\.gz",
        r"/usr/share/doc/.*/copyright",
        r"/usr/share/lintian/overrides/",
        r"/opt/nvidia/nsight-.*/.*/documentation/",
        r"/opt/nvidia/nsight-.*/.*/docs/",
        r"/opt/nvidia/nsight-.*/.*/EULA",
        r"/opt/nvidia/nsight-.*/.*/LICENSE",
        r"/opt/nvidia/nsight-.*/.*/.*/python/lib/.*",
        r"nsight-systems-2025.5.2.*DEBIAN/postinst",
        r"nsight-systems-2025.5.2.*DEBIAN/prerm",
    ]

    def __init__(self, config: Config, logger: logging.Logger):
        """Store the configuration and logger and create the two collaborators."""
        self.logger = logger
        self.config = config
        self.package_manager = PackageManager(logger)
        self.nvidia_repo = NvidiaRepository(config, logger)
        # self.report_generator = ReportGenerator(logger)

    def parse_control_field(self, deps_str: str) -> List[ControlField]:
        """Split a comma-separated control field value into ControlField objects.

        Entries that are empty after stripping are dropped; an empty value
        yields an empty list.
        """
        if not deps_str:
            return []

        entries = (chunk.strip() for chunk in deps_str.split(","))
        return [ControlField.parse(entry) for entry in entries if entry]

    def is_extracted_debian_matching(self, nvidia_dir: Path, our_dir: Path) -> bool:
        """Compare control files of extracted packages."""
        # the paths are already 'DEBIAN' folders, so we can directly compare the files

        if not nvidia_dir.exists() or not our_dir.exists():
            logging.warning(
                f"DEBIAN folder not found for {nvidia_dir} or {our_dir}, failing comparison"
            )
            raise RuntimeError(f"DEBIAN folder not found for {nvidia_dir.name}")

        # List of maintainer script files to compare
        maintainer_scripts = ["preinst", "postinst", "prerm", "postrm"]

        for script in maintainer_scripts:
            nvidia_script = nvidia_dir / script
            our_script = our_dir / script

            _ = self.are_extracted_packages_matching(nvidia_script, our_script)

        return True

    def are_extracted_packages_matching(
        self,
        nvidia_dir: Path,
        our_dir: Path,
    ) -> bool:
        """Compare file contents and permissions of extracted packages."""
        import os

        # Run diff for content/structure
        result = subprocess.run(
            ["diff", "-rqN", "--no-dereference", str(nvidia_dir), str(our_dir)],
            capture_output=True,
            text=True,
        )
        try:
            result_large = subprocess.run(
                ["diff", "-ruN", "--no-dereference", str(nvidia_dir), str(our_dir)],
                capture_output=True,
                text=True,
            )
        except:
            self.logger.warning(
                "Large diff failed, likely due to too many differences, will proceed with short diff for filtering"
            )
            result_large = result  # if the diff fails due to too many differences, we can still use the short output for filtering

        diff_lines = result.stdout.split("\n")
        filtered_lines = self.filter_expected_diffs(diff_lines)

        filtered_lines = self.filter_cudla_exception(filtered_lines)

        filter_candidate, errors = (
            self.get_filter_candidates_from_expanded_diff_if_compressed(filtered_lines)
        )

        self.logger.debug("===> Filtered diff lines: ")
        self.logger.debug("\n".join(filtered_lines))
        self.logger.debug("===> Filter candidates from expanded diff: ")
        self.logger.debug("\n".join(filter_candidate))
        self.logger.debug("Errors from expanded diff: ")
        self.logger.debug("\n".join(errors))

        for line in filter_candidate:
            self.logger.info(f"Filtering candidate line: {line}")
            if line in filtered_lines:
                filtered_lines.remove(line)

        filtered_lines.extend(errors)

        permission_mismatches = []

        # Walk both trees and compare permissions for all files present in both
        for dirpath, _, filenames in os.walk(nvidia_dir):
            rel_dir = os.path.relpath(dirpath, nvidia_dir)
            our_dirpath = os.path.join(our_dir, rel_dir)
            for fname in filenames:
                nvidia_f = os.path.join(dirpath, fname)
                our_f = os.path.join(our_dirpath, fname)
                if os.path.exists(our_f):
                    nvidia_mode = (
                        os.stat(nvidia_f, follow_symlinks=False).st_mode & 0o7777
                    )
                    our_mode = os.stat(our_f, follow_symlinks=False).st_mode & 0o7777
                    if nvidia_mode != our_mode:
                        permission_mismatches.append(
                            f"Permission mismatch: {os.path.relpath(nvidia_f, nvidia_dir)} NVIDIA: {oct(nvidia_mode)} Ours: {oct(our_mode)}"
                        )

        if filtered_lines or permission_mismatches:
            logging.warning(
                f"!!!!!!!! File or permission differences found !!!!!!!!!!!!!!!"
            )
            for line in filtered_lines:
                logging.warning(f"  {line}")
            for line in permission_mismatches:
                logging.warning(f"  {line}")
            logging.warning(f"========== unfiltered diff ===============")
            logging.warning(result_large.stdout)
            logging.warning(f"==========================================")
            raise RuntimeError(f"File or permission differences found, fatal error.")

        return True

    def filter_cudla_exception(self, diff_lines: List[str]) -> List[str]:
        """
        Filter out cudla expected differences from diff output.

        cudla is packaged differently between the redistrib source and the actual
        debs from NVIDIA's repo. Hence the dates are different, hence the diff fails.
        We mitigate that by comparing checksums.
        """

        filtered = []
        cudla_pattern = r".*/libcudla.so.1.0.0"
        for line in diff_lines:
            if re.search(cudla_pattern, line):
                self.logger.info(f"Filtering cudla-related difference: {line}")
                # compare the md5sums of the two files, if they match, we can ignore this difference, as it's likely just a timestamp or metadata difference in the archive. if they don't match, we should keep the difference, as it might indicate a real issue with cudla.
                match = re.match(r"^Files\s+(.+?)\s+and\s+(.+?)\s+differ$", line)
                if not match:
                    self.logger.info(
                        f"Error, unable to extract file paths from diff line {line}"
                    )
                    raise RuntimeError(
                        f"Unable to extract file paths from diff line {line}"
                    )

                file_1 = match.group(1)
                file_2 = match.group(2)
                self.logger.debug(f"Extracted files: \n - {file_1}\n - {file_2}")
                md5_1 = hashlib.md5(open(file_1, "rb").read()).hexdigest()
                md5_2 = hashlib.md5(open(file_2, "rb").read()).hexdigest()
                if md5_1 == md5_2:
                    self.logger.info(
                        f"MD5 checksums match for {file_1} and {file_2}, ignoring this difference"
                    )
                else:
                    self.logger.warning(
                        f"MD5 checksums do NOT match for {file_1} and {file_2}, keeping this difference"
                    )
                    filtered.append(line)

            filtered.append(line)

        return filtered

    def filter_expected_diffs(self, diff_lines: List[str]) -> List[str]:
        """Filter out expected differences from diff output."""
        filtered = []

        for line in diff_lines:
            if not line.strip():
                continue

            # Check if line matches any ignore pattern
            if any(re.search(pattern, line) for pattern in self.IGNORE_PATTERNS):
                continue

            filtered.append(line)

        return filtered

    def uncompress_file_to_dir(
        self, file_path: Path, output_dir: Path, compression_type: str
    ):
        """Unpack an archive into ``output_dir`` using ``tar`` or ``unzip``.

        The output directory is created first, even when the compression
        type then turns out to be unsupported.

        Raises:
            ValueError: If ``compression_type`` is neither "tar.gz" nor "zip".
            subprocess.CalledProcessError: If the extraction tool fails.
        """
        output_dir.mkdir(parents=True, exist_ok=True)

        # Extraction command per supported archive kind.
        commands = {
            "tar.gz": ["tar", "xf", str(file_path), "-C", str(output_dir)],
            "zip": ["unzip", str(file_path), "-d", str(output_dir)],
        }
        command = commands.get(compression_type)
        if command is None:
            raise ValueError(f"Unsupported compression type: {compression_type}")

        subprocess.run(command, check=True, capture_output=True)

    def get_filter_candidates_from_expanded_diff_if_compressed(
        self, diff_lines: List[str]
    ) -> Tuple[List[str], List[str]]:
        """
        Filter out expected differences from diff output.
        This function compares compressed files to see if they are containing the same stuff
        """
        errors = []
        candidates = []

        for idx, line in enumerate(diff_lines):
            self.logger.info(f"Trying to filter even MORE, line: \n{line}")
            # structure is 'Files /path/to/file1 and /path/to/file2 differ'
            match = re.match(r"^Files\s+(.+?)\s+and\s+(.+?)\s+differ$", line)
            if not match:
                self.logger.info("Not a candidate, keeping line")
                continue

            file_1 = match.group(1)
            file_2 = match.group(2)
            self.logger.info(f"Extracted files: \n - {file_1}\n - {file_2}")

            if file_1.endswith(".tar.gz") or file_1.endswith(".zip"):
                self.logger.info("...")
                output_dir_1 = self.config.diff_dir / f"{Path(file_1).name}.{idx}.file1"
                output_dir_2 = self.config.diff_dir / f"{Path(file_2).name}.{idx}.file2"
                if file_1.endswith(".tar.gz"):
                    # uncompress both files and compare their contents, ignoring timestamps
                    # as usual put them in a specific folder
                    self.uncompress_file_to_dir(Path(file_1), output_dir_1, "tar.gz")
                    self.uncompress_file_to_dir(Path(file_2), output_dir_2, "tar.gz")

                elif file_1.endswith(".zip"):
                    self.uncompress_file_to_dir(Path(file_1), output_dir_1, "zip")
                    self.uncompress_file_to_dir(Path(file_2), output_dir_2, "zip")

                # only then run the diff
                result = subprocess.run(
                    [
                        "diff",
                        "-rqN",
                        "--no-dereference",
                        str(output_dir_1),
                        str(output_dir_2),
                    ],
                    capture_output=True,
                    text=True,
                )
                archive_diff_lines = result.stdout.splitlines()
                rc = result.returncode
                suberrors = []
                if archive_diff_lines:
                    self.logger.info(
                        f"Number of lines in tar diff: {len(archive_diff_lines)}"
                    )
                    for diff_line in archive_diff_lines:
                        if diff_line != "":
                            logging.info(f"Diff line: {diff_line}")
                            suberrors.append(diff_line)
                errors.extend(suberrors)
                if len(suberrors) == 0 and rc == 0:
                    candidates.append(line)
                # in case there is no diff lines but the diff still fails
                elif rc == 2:
                    self.logger.warning(
                        f"Diff of archive contents returned non-zero exit code: {rc}\n"
                        f"For {file_1} vs {file_2}, stderr:\n{result.stderr or ''}"
                    )
                    continue

            else:
                self.logger.info("Not a candidate, keeping line")

        return candidates, errors

    def prepare_package_for_comparison(self, package_name: str):
        """Compare a single package against NVIDIA's version."""
        self.logger.info(f"Processing: {package_name}")

        # Get package info
        pkg_info = self.package_manager.get_package_info(package_name)
        self.logger.info(f"  Version: {pkg_info.version}")
        self.logger.info(f"  Upstream: {pkg_info.upstream_version}")

        # Repack our package
        our_deb = self.package_manager.repack_package(
            package_name, self.config.repacked_dir
        )

        # Find NVIDIA package
        nvidia_filename = self.nvidia_repo.find_package(
            package_name, pkg_info.upstream_version
        )

        if not nvidia_filename:
            self.logger.warning(f"  NVIDIA package not found, skipping")
            raise RuntimeError(
                f"NVIDIA package not found for {package_name} version {pkg_info.upstream_version}"
            )

        # Download NVIDIA package
        nvidia_deb = self.nvidia_repo.download_package(nvidia_filename)

        self.logger.info(f"  NVIDIA: {nvidia_deb.name}")
        self.logger.info(f"  Ours:   {our_deb.name}")

        return our_deb, nvidia_deb

    def compare_package(
        self, package_name: str, our_deb: Path, nvidia_deb: Path
    ) -> bool:

        # Extract packages
        nvidia_extract = self.config.extract_dir / f"{package_name}_nvidia"
        our_extract = self.config.extract_dir / f"{package_name}_ours"

        self.package_manager.extract_deb(nvidia_deb, nvidia_extract)
        self.package_manager.extract_deb(our_deb, our_extract)

        nv_folder_list = glob.glob(str(nvidia_extract / "*"))
        ca_folder_list = glob.glob(str(our_extract / "*"))

        # extract basenames
        nv_folder_list = [Path(f).name for f in nv_folder_list]
        ca_folder_list = [Path(f).name for f in ca_folder_list]

        folders = set(nv_folder_list) | set(ca_folder_list)

        folders = folders - set(
            ["DEBIAN"]
        )  # ignore control folder, as it can contain expected differences

        for folder in folders:
            # Compare files
            _ = self.are_extracted_packages_matching(
                nvidia_extract / folder,
                our_extract / folder,
            )

        # Compare DEBIAN maintainer scripts specifically
        # _ = self.is_extracted_debian_matching(
        #     nvidia_extract / "DEBIAN",
        #     our_extract / "DEBIAN",
        # )
        # simply return true, because all false paths raise exceptions
        return True

    def compare_control_field(
        self,
        field: str,
        our_deb: Path,
        nvidia_deb: Path,
    ) -> FieldComparison:
        """Compare the Depends or Provides field of two .deb files.

        Entries are matched by package name. Names only NVIDIA lists are
        reported as missing, names only we list as extra, and names listed by
        both with two different version strings as version mismatches. A
        missing ``libgcc1`` is accepted when ours lists ``libgcc-s1``.

        Raises:
            ValueError: If ``field`` is neither "Depends" nor "Provides".
        """
        if field not in ("Depends", "Provides"):
            raise ValueError(f"Unsupported field for comparison: {field}")

        nvidia_raw = PackageManager.get_deb_field(nvidia_deb, field)
        our_raw = PackageManager.get_deb_field(our_deb, field)

        logging.info(f"NVIDIA dependencies: {nvidia_raw}")
        logging.info(f"Our dependencies: {our_raw}")

        # Index the parsed entries by package name.
        nvidia_by_name = {
            entry.name: entry for entry in self.parse_control_field(nvidia_raw)
        }
        ours_by_name = {
            entry.name: entry for entry in self.parse_control_field(our_raw)
        }
        nvidia_names = set(nvidia_by_name)
        our_names = set(ours_by_name)

        comparison = FieldComparison()
        comparison.missing_items = nvidia_names - our_names

        # special case for libgcc1. if it's missing but we have libgcc-s1, it's fine, as libgcc-s1 provides libgcc1
        if "libgcc1" in comparison.missing_items and "libgcc-s1" in our_names:
            logging.warning(
                f"libgcc1 is missing but libgcc-s1 is present, assuming compatibility"
            )
            comparison.missing_items.remove("libgcc1")

        comparison.extra_items = our_names - nvidia_names

        # A mismatch is only recorded when both sides carry a version.
        for shared in nvidia_names & our_names:
            theirs = nvidia_by_name[shared].version
            mine = ours_by_name[shared].version
            if theirs and mine and theirs != mine:
                comparison.version_mismatches[shared] = (theirs, mine)

        return comparison

    def compare_license(self, package_name: str) -> bool:
        """Compare license files between NVIDIA and our package."""
        # Extract packages
        nvidia_dir = self.config.extract_dir / f"{package_name}_nvidia"
        our_dir = self.config.extract_dir / f"{package_name}_ours"

        if not nvidia_dir.exists() or not our_dir.exists():
            logging.warning(
                f"Extracted directories not found for {package_name}, failing license comparison"
            )
            return False

        # If NVIDIA's deb doesn't have a copyright file, we consider it a pass, as we can't compare
        nv_copyright = nvidia_dir / f"usr/share/doc/{package_name}/copyright"
        if not nv_copyright.exists():
            logging.warning(
                f"No license found for {package_name}, assuming this is on purpose"
            )
            return True
        ca_copyright = our_dir / f"usr/share/doc/{package_name}/copyright"

        self.logger.debug(
            f"Opening {nv_copyright} and {ca_copyright} for comparison..."
        )
        try:
            with open(nv_copyright) as f:
                nv_content = f.readlines()
        except UnicodeDecodeError as e:
            logging.warning(f"Failed to read NVIDIA copyright file: {e}")
            # using cp1252 because nsight-compute has weird characters in its copyright file.
            with open(nv_copyright, encoding="cp1252") as f:
                nv_content = f.readlines()

        with open(ca_copyright) as f:
            ca_content = f.readlines()

        # extract nv license stanza from ca_content
        idx = 0
        start_idx = 0
        end_idx = 0
        for idx, line in enumerate(ca_content):
            if line.startswith("License: NVIDIA-CUDA-Proprietary"):
                start_idx = idx
            if start_idx != 0 and line.strip() == "":
                end_idx = idx
                break

            # if end of file, set end_idx to length of content
            if start_idx != 0 and idx == len(ca_content) - 1:
                end_idx = len(ca_content)

        if start_idx == 0 or end_idx == 0:
            logging.warning(
                f"Unable to find the stanza in Canonical's copyright file, start_idx: {start_idx}, end_idx: {end_idx}"
            )
        ca_content = ca_content[start_idx + 1 : end_idx]

        # compare contents, ignoring whitespace, empty lines, and unicode variations
        nv_lines = [normalize_unicode(l.strip()) for l in nv_content if l.strip()]
        ca_lines = [normalize_unicode(l.strip()) for l in ca_content if l.strip()]
        # remove lines with just a dot, as they are used as separators in Debian copyright files
        ca_lines = [l for l in ca_lines if l != "."]

        diff = False
        if len(nv_lines) != len(ca_lines):
            logging.warning(
                f"License content line count mismatch for {package_name}, NVIDIA has {len(nv_lines)} lines, ours has {len(ca_lines)} lines"
            )
            diff = True
        else:
            for i in range(len(nv_lines)):
                if nv_lines[i] != ca_lines[i]:
                    logging.warning(
                        f"License content mismatch for {package_name} at line {i+1}:\n  NVIDIA: {nv_lines[i]}\n  Ours:   {ca_lines[i]}"
                    )
                    logging.warning(f"Diff L{i+1}: {nv_lines[i]} - {ca_lines[i]}")
                    diff = True

        if diff:
            logging.warning(
                f"License content differences found for {package_name}, please review the logs and fix the issues before proceeding with packaging."
            )
            # print both list
            logging.warning(f"NVIDIA license content:\n" + "\n".join(nv_lines))
            logging.warning(f"Our license content:\n" + "\n".join(ca_lines))
            return False

        return True


def main() -> int:
    """Compare our binary packages against NVIDIA's and report differences.

    Parses the command line, prepares the working directories, asks the
    comparator's NVIDIA repository helper to download its package index, and
    then, for every package returned by
    ``package_manager.get_installed_packages`` for ``debian/control``, checks
    file contents, license text, ``Depends`` and ``Provides`` against
    NVIDIA's package.

    Severity rules applied here:
      * ``Depends``: only ``has_failures()`` is critical; other differences
        are logged as minor.
      * ``Provides``: any difference is critical.
      * License: a mismatch is critical.

    Returns:
        0 when all packages were processed without critical differences,
        1 when no binary packages were found.

    Raises:
        RuntimeError: if critical ``Depends``, ``Provides`` or license
            differences were recorded for at least one package.

    The working directory is cleaned up in all cases unless
    ``--keep-work-dir`` was given.
    """
    parser = argparse.ArgumentParser(
        description="Compare our CUDA packages against NVIDIA's official versions"
    )
    parser.add_argument(
        "--ubuntu-version",
        default="24.04",
        help="Ubuntu version to use for comparison (default: 24.04)",
    )
    parser.add_argument(
        "--keep-work-dir",
        action="store_true",
        help="Don't delete working directory after completion",
    )
    args = parser.parse_args()

    logging.basicConfig(
        level=logging.INFO, format="[%(asctime)s] %(message)s", stream=sys.stdout
    )
    logger = logging.getLogger(__name__)
    logger.info("Starting NVIDIA package comparison test")

    # Setup configuration
    config = Config(ubuntu_version=args.ubuntu_version)

    try:
        config.setup_directories()

        # Initialize comparator
        comparator = PackageComparator(config, logger)

        # Download NVIDIA package index
        comparator.nvidia_repo.download_package_index()

        # Get list of binary packages
        control_file = Path("debian/control")
        packages = comparator.package_manager.get_installed_packages(control_file)

        if not packages:
            logger.error("No binary packages found")
            return 1

        logger.info("Found binary packages:")
        for pkg in packages:
            logger.info(f"  - {pkg}")

        # One flag per category; any flag set makes the whole run fail.
        issues = {
            "depends": False,
            "provides": False,
            "license": False,
        }

        for package in packages:
            our_deb, nvidia_deb = comparator.prepare_package_for_comparison(package)

            # NOTE(review): a falsy result from compare_package is not recorded
            # in `issues` and is not logged here -- presumably compare_package
            # reports or raises on its own; confirm this is intentional.
            if comparator.compare_package(package, our_deb, nvidia_deb):
                logger.info(
                    f"✓ {package} matches NVIDIA's version in terms of files integrity"
                )

            if comparator.compare_license(package):
                logger.info(f"✓ {package} license content matches NVIDIA's version")
            else:
                issues["license"] = True
                logger.warning(
                    f"✗ {package} license content does not match NVIDIA's version"
                )

            result = comparator.compare_control_field("Depends", our_deb, nvidia_deb)
            if result.has_failures():
                logger.warning(f"✗ {package} has critical dependency differences")
                logger.warning("✗ check the dependencies and versions again.")
                logger.warning(f"✗ version mismatchs: {result.version_mismatches}")
                logger.warning(f"✗ missing items: {result.missing_items}")
                issues["depends"] = True
            elif result.has_differences():
                logger.warning(f"⚠ {package} has some minor differences")
            else:
                logger.info(f"✓ {package} dependencies match NVIDIA's version")

            # Provides is stricter than Depends: any difference is critical.
            result = comparator.compare_control_field("Provides", our_deb, nvidia_deb)
            if result.has_differences():
                logger.warning(f"✗ {package} has critical provides differences")
                logger.warning("✗ check the provides and versions again.")
                logger.warning(f"✗ provides issue: {result.version_mismatches}")
                logger.warning(f"✗ missing items: {result.missing_items}")
                logger.warning(f"✗ extra items: {result.extra_items}")
                issues["provides"] = True
            else:
                logger.info(f"✓ {package} provides match NVIDIA's version")

        if any(issues.values()):
            err = "Critical differences found, please review the logs and fix the issues before proceeding with packaging."
            if issues["depends"]:
                err += "\n- Dependency differences found, check the logs for details."
            if issues["provides"]:
                err += "\n- Provides differences found, check the logs for details."
            if issues["license"]:
                err += (
                    "\n- License content differences found, check the logs for details."
                )

            raise RuntimeError(err)

    finally:
        # Runs on success, on the early `return 1`, and on any exception.
        if not args.keep_work_dir:
            config.cleanup()

    # Explicit success status so every return path yields an int exit code.
    return 0

# Script entry point: run main() only when executed directly (not on import)
# and hand its return value to the interpreter as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
