# Copyright (c) Cloud Linux Software, Inc
# Licensed under CLOUD LINUX LICENSE AGREEMENT
# http://cloudlinux.com/docs/LICENCE.TXT

import io
import json
import os
import shlex
import tarfile
import time
import uuid
from tempfile import NamedTemporaryFile

from kcarectl import auth, config, http_utils, ipv6_support, log_utils, utils
from kcarectl.process_utils import run_command

if False:  # pragma: no cover
    from typing import Any, Dict, List, Optional  # noqa: F401

    from typing_extensions import Self  # noqa: F401


def format_size(num_bytes):
    # type: (int) -> str
    """Render a byte count as a short human-readable string for the logs.

    Uses binary units (B, KiB, MiB, GiB, TiB) with one decimal place,
    e.g. 1536 -> '1.5 KiB'. Values of 1024 GiB and above are shown in TiB.
    """
    size = float(num_bytes)
    for unit in ('B', 'KiB', 'MiB', 'GiB'):
        # decide on the value as it will actually be printed: e.g.
        # 1048575 bytes is 1023.999 KiB, which would otherwise render as
        # '1024.0 KiB' instead of being promoted to '1.0 MiB'
        rendered = '{0:.1f}'.format(size)
        if float(rendered) < 1024.0:
            return '{0} {1}'.format(rendered, unit)
        size /= 1024.0
    return '{0:.1f} TiB'.format(size)


class DataPackage(object):
    """Generic archive package for the patch server uploads pipeline.

    Based on DataPackage from eportal (delivery_kit.py). Subclasses
    supply the manifest `data_type`, the `upload_uri` to PUT the
    archive to and the `max_size` payload limit.

    Must be used as a context manager: ``__enter__`` creates a compressed
    temporary archive (xz, falling back to bz2, then gz) and writes
    manifest.json into it; ``__exit__`` appends errors.log when any
    error/warning was recorded, closes the archive and deletes it if the
    ``with`` block raised.
    """

    data_type = ''  # type: str
    upload_uri = ''  # type: str

    def __init__(self):
        # type: () -> None
        self._tar = None  # type: Optional[tarfile.TarFile]
        # error/warning lines that __exit__ writes into errors.log
        self._errors_buffer = []  # type: List[str]
        # bytes collected so far that count toward max_size
        self._total_payload_size = 0

    @property
    def max_size(self):
        # type: () -> int
        """Payload size budget in bytes; provided by subclasses."""
        raise NotImplementedError  # pragma: no cover

    @property
    def archive_path(self):
        # type: () -> str
        """Filesystem path of the archive being built."""
        tar = self._ensure_tar_created()
        return str(tar.name)

    def add_stdout(self, arcname, cmd):
        # type: (str, str) -> None
        """Run `cmd` and store its captured stdout as `arcname`.

        Any stderr output (or the exception raised while running the
        command) is recorded via log_error; stdout is stored whenever it
        was captured, even if stderr was non-empty.
        """
        stdout = None
        stderr = None

        try:
            _, stdout, stderr = run_command(shlex.split(cmd), catch_stdout=True, catch_stderr=True)
        except Exception as e:
            stderr = str(e)

        if stderr:
            self.log_error('failed to dump stdout of {0}:\n{1}'.format(cmd, stderr))

        if stdout is not None:
            self.add_file(arcname, data_bytes=utils.bstr(stdout, encoding='utf-8'))

    @staticmethod
    def _source_size(src_path):
        # type: (str) -> int
        """Estimate how many payload bytes archiving `src_path` adds.

        For a regular file (or a symlink to one) this is its size. For a
        directory it is the summed size of the files beneath it, because
        ``tar.add`` archives the whole tree; ``os.path.getsize`` on the
        directory would only report the directory inode size and let the
        tree bypass the payload budget.
        """
        if not os.path.isdir(src_path):
            return os.path.getsize(src_path)

        total = 0
        # followlinks=False avoids looping on symlink cycles; symlinked
        # subdirectories are therefore not counted (best-effort estimate)
        for dirpath, _, filenames in os.walk(src_path):
            for filename in filenames:
                try:
                    total += os.path.getsize(os.path.join(dirpath, filename))
                except OSError:
                    # e.g. a dangling symlink or a file removed meanwhile:
                    # it contributes nothing to the estimate
                    pass
        return total

    def add_file(self, arcname, src_path=None, data_bytes=None, skip_limit_check=False):
        # type: (str, Optional[str], Optional[bytes], bool) -> None
        """Store `src_path` (file or directory) or `data_bytes` as `arcname`.

        Per-entry problems are not raised: a missing source, insufficient
        disk space, an entry that would exceed `max_size` or a tar write
        failure is logged (and buffered for errors.log) and the entry is
        skipped.

        :param arcname: entry name inside the archive
        :param src_path: path to archive; takes precedence over `data_bytes`
        :param data_bytes: raw content to archive when `src_path` is None
        :param skip_limit_check: exempt the entry from `max_size` and do not
            count it toward the payload total
        :raises ValueError: if neither `src_path` nor `data_bytes` is given
        :raises RuntimeError: if called outside the context manager
        """
        if src_path is None and data_bytes is None:
            raise ValueError('No src_path or data_bytes provided')

        tar = self._ensure_tar_created()
        if src_path is not None:
            if not os.path.exists(src_path):
                self.log_error('file not found: {0}'.format(src_path))
                return

            entry_size = self._source_size(src_path)
        else:
            entry_size = len(data_bytes)  # type: ignore[arg-type]

        if not self._check_required_space(entry_size):
            self.log_error('no available space to store: {0}'.format(arcname))
            return

        if not skip_limit_check and not self._check_total_payload_limit(entry_size):
            # a single artifact that would push the report past the size
            # budget is dropped, not the whole report: warn (so the omission
            # is visible in kcarectl.log and recorded in errors.log) and skip
            self.log_warning(
                'skipping {0}: {1} would exceed the {2} report size limit (already collected {3})'.format(
                    arcname,
                    format_size(entry_size),
                    format_size(self.max_size),
                    format_size(self._total_payload_size),
                )
            )
            return

        try:
            if src_path is not None:
                tar.add(src_path, arcname=arcname)
            else:
                info = tarfile.TarInfo(arcname)
                info.size = entry_size
                tar.addfile(info, io.BytesIO(data_bytes))  # type: ignore[arg-type]

            if not skip_limit_check:
                # entries exempt from the limit (manifest.json, errors.log)
                # are not counted toward the payload size either
                self._total_payload_size += entry_size
                # per-artifact size accounting so an oversized report can be
                # traced to the item(s) that bloated it (kcarectl.log only)
                log_utils.loginfo(
                    'collected {0}: {1} (report total {2})'.format(
                        arcname,
                        format_size(entry_size),
                        format_size(self._total_payload_size),
                    ),
                    print_msg=False,
                )
        except Exception as e:
            self.log_error('failed to store {0}: {1}'.format(arcname, e))

    def add_json(self, arcname, data):
        # type: (str, Dict[str, Any]) -> None
        """Serialize `data` as indented JSON and store it as `arcname`.

        A non-serializable value is recorded via log_error and skipped.
        """
        try:
            data_bytes = utils.bstr(json.dumps(data, indent=4), encoding='utf-8')
        except TypeError as e:
            self.log_error('failed to dump {0}:\n{1}'.format(arcname, e))
            return

        self.add_file(arcname, data_bytes=data_bytes)

    def _check_required_space(self, entry_size):
        # type: (int) -> bool
        """Return True if the archive's filesystem has room for `entry_size` bytes."""
        # here we simplify the check and ignore that the compressed file size will be less
        statvfs = os.statvfs(self.archive_path)
        return statvfs.f_frsize * statvfs.f_bfree > entry_size

    def _check_total_payload_limit(self, entry_size):
        # type: (int) -> bool
        """Return True if adding `entry_size` bytes keeps the payload within `max_size`."""
        # here we simplify the check and ignore that the compressed file size will be less;
        # an entry landing exactly on the limit does not exceed it
        return self.max_size >= self._total_payload_size + entry_size

    def make_manifest(self):
        # type: () -> Dict[str, Any]
        """Build the manifest.json content (schema version, type, creation time)."""
        return {
            "schema_version": 1,
            "type": self.data_type,
            "time_created": int(time.time()),
        }

    def _add_manifest(self):
        # type: () -> None
        # manifest.json must always be present (the patch server dispatches
        # uploads by its `type`), so it bypasses the lenient add_file: write
        # failures propagate to __enter__ and abort the package creation
        # instead of being downgraded to errors.log; the entry is exempt
        # from the payload limit and not counted toward it
        data_bytes = utils.bstr(json.dumps(self.make_manifest(), indent=4), encoding='utf-8')
        tar = self._ensure_tar_created()
        info = tarfile.TarInfo('manifest.json')
        info.size = len(data_bytes)
        tar.addfile(info, io.BytesIO(data_bytes))

    def log_error(self, error_msg):
        # type: (str) -> None
        """Log an error to kcarectl.log (no console output) and buffer it for errors.log."""
        error_msg = error_msg.strip()
        log_utils.logerror(error_msg, print_msg=False)
        self._errors_buffer.append(error_msg)

    def log_warning(self, warning_msg):
        # type: (str) -> None
        # warn to kcarectl.log (print_msg=False: no console noise, like
        # log_error) and keep the note in the archived errors.log so the
        # uploaded report records what was dropped
        warning_msg = warning_msg.strip()
        log_utils.logwarn(warning_msg, print_msg=False)
        self._errors_buffer.append(warning_msg)

    def __enter__(self):
        # type: () -> Self
        """Create the temporary archive and write manifest.json into it.

        Compression modes are tried in order xz, bz2, gz; a mode that
        raises tarfile.CompressionError falls through to the next one.
        Any other failure removes the temporary file and is re-raised.
        """
        for compression_mode in ('w:xz', 'w:bz2', 'w:gz'):  # pragma: no branch
            tmpfile = NamedTemporaryFile(suffix='.tar.{0}'.format(compression_mode[2:]), delete=False)
            tmpfile.close()
            try:
                log_utils.loginfo('Creating DataPackage: {0}'.format(tmpfile.name), print_msg=False)
                # dereference=True mirrors the kcdoctor.sh `dump` (`cat "$1"`):
                # a symlinked source added via add_file(src_path=...) -- e.g.
                # /etc/yum.conf -> dnf/dnf.conf on EL8+, /boot/grub2/grub.cfg
                # on EFI -- is archived by content, not as a dangling link.
                self._tar = tarfile.open(name=tmpfile.name, mode=compression_mode, dereference=True)
                self._add_manifest()
                return self
            except Exception as err:
                if self._tar is not None:
                    # the manifest write may fail after a successful open;
                    # don't leak the open handle
                    try:
                        self._tar.close()
                    except Exception:  # pragma: no cover
                        pass
                    self._tar = None

                if os.path.exists(tmpfile.name):  # pragma: no branch
                    os.unlink(tmpfile.name)

                if not isinstance(err, tarfile.CompressionError):
                    raise

        raise tarfile.CompressionError('No supported compression method found')  # pragma: no cover

    def __exit__(self, exc_type, exc_val, exc_tb):
        # type: (Optional[type[BaseException]], Optional[BaseException], Any) -> bool
        """Append errors.log, close the archive and delete it on failure.

        Never suppresses an exception raised inside the ``with`` block.
        """
        if self._errors_buffer:  # pragma: no branch
            errors = '\n'.join(self._errors_buffer) + '\n'
            self.add_file('errors.log', data_bytes=utils.bstr(errors), skip_limit_check=True)

        if self._tar is not None:  # pragma: no branch
            self._tar.close()
            if exc_val is not None:
                # the with-block failed: don't leave a partial archive behind
                self.remove_archive()

        # a true return value would swallow the in-flight exception
        return False

    @utils.catch_errors(logger=log_utils.logwarn)
    def remove_archive(self):
        # type: () -> None
        """Delete the archive file if it exists.

        Failures are handled by utils.catch_errors (presumably logged via
        logwarn rather than raised -- see that helper).
        """
        if self._tar and os.path.exists(self.archive_path):  # pragma: no branch
            os.unlink(self.archive_path)

    def _ensure_tar_created(self):
        # type: () -> tarfile.TarFile
        """Return the open TarFile or raise RuntimeError outside the context manager."""
        if not self._tar:
            raise RuntimeError('DataPackage should be used as a context manager')
        return self._tar

    def send(self):
        # type: () -> str
        """Send the package archive to the patch server.

        Upload errors propagate to the caller (see the eportal
        precedent): wrap with utils.catch_errors where a silent
        failure is acceptable.

        :return: Upload name (package identifier)
        """

        # flush buffered tar data even when called inside the `with`
        # block (close() is a no-op on an already closed tar)
        self._ensure_tar_created().close()

        # Generate a unique package name
        # Use find('.') to get extension from first dot to preserve .tar.xz/.tar.bz2/.tar.gz
        basename = os.path.basename(self.archive_path)
        ext = basename[basename.find('.') :] if '.' in basename else ''
        upload_name = str(uuid.uuid4()) + ext
        upload_url = ipv6_support.get_patch_server() + self.upload_uri + upload_name
        http_utils.upload_file(
            self.archive_path,
            upload_url=upload_url,
            auth_string=auth.get_http_auth_string(),
        )

        return upload_name


class KernelAnomalyPackage(DataPackage):
    """DataPackage for kernel anomaly reports.

    Archives are tagged `kernel-anomaly` in manifest.json, uploaded under
    `/upload/kernel-anomaly/`, and limited by
    config.KERNEL_ANOMALY_REPORT_MAX_SIZE_BYTES.
    """

    upload_uri = '/upload/kernel-anomaly/'
    data_type = 'kernel-anomaly'

    @property
    def max_size(self):
        # type: () -> int
        # looked up on every access rather than captured at import time
        limit = config.KERNEL_ANOMALY_REPORT_MAX_SIZE_BYTES
        return limit
