# Current File : //opt/imunify360/venv/lib64/python3.11/site-packages/imav/malwarelib/plugins/detached_scan.py
"""
This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License,
or (at your option) any later version.


This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the GNU General Public License for more details.


You should have received a copy of the GNU General Public License
 along with this program.  If not, see <https://www.gnu.org/licenses/>.

Copyright © 2019 Cloud Linux Software Inc.

This software is also available under ImunifyAV commercial license,
see <https://www.imunify360.com/legal/eula>
"""
import shutil
import time
from logging import getLogger
from typing import Dict, Optional, Union

from defence360agent.contracts.hook_events import HookEvent
from defence360agent.contracts.messages import MessageType
from defence360agent.contracts.plugins import (
    MessageSink,
    MessageSource,
    expect,
)
from defence360agent.utils import Scope
from imav.malwarelib.config import (
    MalwareScanResourceType,
    MalwareScanType,
)
from imav.malwarelib.model import MalwareScan as MalwareScanModel
from imav.malwarelib.scan import (
    ScanAlreadyCompleteError,
    ScanInfoError,
)
from imav.malwarelib.scan.ai_bolit.detached import (
    AiBolitDetachedScan,
)
from imav.malwarelib.scan.mds.detached import MDSDetachedScan
from imav.malwarelib.scan.queue_supervisor_sync import QueueSupervisorSync
from imav.malwarelib.scan.scan_result import aggregate_result
from imav.malwarelib.utils.user_list import fill_results_owner

# Logger for this module, named after the module's import path.
logger = getLogger(name=__name__)


class DetachedScanPlugin(MessageSink, MessageSource):
    """Finalise malware scans and emit ``MalwareScanningFinished`` hooks.

    Handles ``MalwareScan`` messages (as a pre-processing sink) and
    ``MalwareScanComplete`` messages.  For detached scans the summary and
    the results can arrive in separate ``MalwareScan`` messages: results
    that arrive before the summary is stored in the DB are cached in
    ``results_cache`` and attached when the summary message comes in.
    """

    PROCESSING_ORDER = MessageSink.ProcessingOrder.PRE_PROCESS_MESSAGE
    SCOPE = Scope.AV
    loop, sink = None, None
    # Results waiting for their summary, keyed by scan id.  Defined at class
    # level, so the same dict is shared by all instances and subclasses.
    results_cache = {}  # type: Dict[str, dict]

    # Performance-related summary keys copied into the hook's ``stats``.
    _PERF_STAT_KEYS = frozenset(
        (
            "scan_time",
            "scan_time_hs",
            "scan_time_preg",
            "smart_time_hs",
            "smart_time_preg",
            "finder_time",
            "cas_time",
            "deobfuscate_time",
            "mem_peak",
        )
    )

    async def create_source(self, loop, sink):
        """Remember the event loop and the sink used to emit messages."""
        self.loop = loop
        self.sink = sink

    async def create_sink(self, loop):
        """No sink-side setup is needed."""
        pass

    @expect(MessageType.MalwareScan, async_lock=True)
    async def complete_scan(self, message):
        """Route a ``MalwareScan`` message by scan kind.

        Non-detached scans only get ``summary["total_malicious"]`` filled in.
        For detached scans, a message without results (``results is None``)
        goes to the summary handler and any other message goes to the
        results handler.

        Returns the (possibly mutated) message, or ``None`` when the results
        were cached to wait for their summary.
        """
        message_type = MalwareScanMessageInfo(message)

        if not message_type.is_detached:
            total_malicious = await self._count_total_malicious(message)
            message["summary"]["total_malicious"] = total_malicious
            return message
        elif message_type.is_summary:
            return await self._handle_summary(message)

        # message_type.is_result
        return await self._handle_results(message)

    async def _handle_summary(self, message):
        """Complete a detached scan whose results were cached earlier.

        If results for ``summary["scanid"]`` are cached, they are moved into
        ``message["results"]``.  The method then stamps ``completed`` and
        ``total_malicious``, drops the scan from the queue if present, and
        fires the finished hook.  Otherwise the message is returned
        untouched.
        """
        summary = message["summary"]
        scan_id = summary["scanid"]
        # If summary arrives after results, results are read from cache
        if scan_id in self.results_cache:
            summary["completed"] = time.time()
            message["results"] = self.results_cache.pop(scan_id)
            summary["total_malicious"] = await self._count_total_malicious(
                message
            )
            queued_scan = QueueSupervisorSync.queue.find(scanid=scan_id)
            if queued_scan:
                QueueSupervisorSync.queue.remove(queued_scan)
            await self._call_scan_finished_hook(
                summary, queued_scan.args if queued_scan else {}
            )
        return message

    async def _handle_results(self, message):
        """Finalise a detached scan from a results message.

        Returns the completed message, or ``None`` if the scan's summary is
        not stored in the DB yet.  In that case the results are cached
        until the summary message arrives (see ``_handle_summary``).
        """
        message = await self.aggregate_result(message)
        message_type = MalwareScanMessageInfo(message)
        summary = message["summary"]
        logger.info("Scan stopped")
        queued_scan = QueueSupervisorSync.queue.find(scanid=summary["scanid"])

        scan = message_type.summary_from_db
        if scan is None:
            self._release_queued_scan(queued_scan, summary)
            if summary.get("path") or summary.get("error"):
                # Scan failed
                summary["total_malicious"] = 0
                # NOTE(review): queued scan args are not forwarded here,
                # unlike the success path below -- confirm this is intended.
                await self._call_scan_finished_hook(summary, scan_args={})
                return message

            # Summary is not in DB yet, save results to cache
            self.results_cache[summary["scanid"]] = message["results"]
            # Report an error to Sentry if cache grows
            cache_size = len(self.results_cache)
            if cache_size > 1:
                logger.error("MalwareScan cache size is %d", cache_size)
            return None

        # Fill the summary from the stored scan record.
        summary["scanid"] = scan.scanid
        summary["path"] = scan.path
        summary["started"] = scan.started
        summary["completed"] = time.time()
        if summary.get("total_files") is None:
            summary["total_files"] = scan.total_resources

        summary["type"] = scan.type
        summary["error"] = summary.get("error", None)

        summary["total_malicious"] = await self._count_total_malicious(
            message
        )
        self._release_queued_scan(queued_scan, summary)
        await self._call_scan_finished_hook(
            summary, queued_scan.args if queued_scan else {}
        )
        return message

    @staticmethod
    def _release_queued_scan(queued_scan, summary) -> None:
        """Copy the queued scan's file/exclude patterns into *summary* and
        remove it from the scan queue.

        Does nothing when *queued_scan* is falsy (scan not found in queue).
        """
        if queued_scan:
            summary["file_patterns"] = queued_scan.args["file_patterns"]
            summary["exclude_patterns"] = queued_scan.args["exclude_patterns"]
            QueueSupervisorSync.queue.remove(queued_scan)

    @staticmethod
    async def _count_total_malicious(message) -> int:
        """Count results whose first hit has ``suspicious`` set to ``False``.

        Entries with an empty (or ``None``) ``hits`` list are skipped
        instead of raising ``IndexError``.
        """
        return sum(
            1
            for file_info in message["results"].values()
            if file_info["hits"]
            and file_info["hits"][0]["suspicious"] is False
        )

    @classmethod
    def _perf_stats(cls, summary) -> dict:
        """Build the hook ``stats``: performance metrics present in
        *summary* (in summary order) followed by ``total_files``."""
        stats = {
            key: value
            for key, value in summary.items()
            if key in cls._PERF_STAT_KEYS
        }
        stats["total_files"] = summary["total_files"]
        return stats

    async def _call_scan_finished_hook(self, summary, scan_args) -> None:
        """Emit ``MalwareScanningFinished`` for *summary*, then ask for a
        scan-queue recheck.

        *summary* must contain ``scanid``, ``type``, ``path``, ``started``,
        ``total_files`` and ``total_malicious``.  The status is ``"failed"``
        when ``summary["error"]`` is truthy and ``"ok"`` otherwise.
        """
        scan_finished = HookEvent.MalwareScanningFinished(
            scan_id=summary["scanid"],
            scan_type=summary["type"],
            path=summary["path"],
            started=summary["started"],
            total_files=summary["total_files"],
            total_malicious=summary["total_malicious"],
            error=summary.get("error"),
            status="failed" if summary.get("error") else "ok",
            scan_params=scan_args,
            stats=self._perf_stats(summary),
        )
        await self.sink.process_message(scan_finished)
        await self._recheck_scan_queue()

    @staticmethod
    def _get_detached_scan(
        resource_type: Optional[Union[str, MalwareScanResourceType]], scan_id
    ):
        """Return the detached-scan handle for *scan_id*.

        In this (AV) scope it is always ``AiBolitDetachedScan``; the
        *resource_type* argument is accepted but ignored.
        """
        return AiBolitDetachedScan(scan_id)

    @expect(MessageType.MalwareScanComplete)
    async def complete_detached_scan(self, message):
        """Collect a finished detached scan and forward its scan message.

        The detached scan's directory is always removed, even on failure.
        ``ScanAlreadyCompleteError`` and ``ScanInfoError`` are logged and
        nothing is forwarded.
        """
        scan_id = message.get("scan_id")
        resource_type = message.get("resource_type")
        detached_scan = self._get_detached_scan(resource_type, scan_id)

        try:
            scan_message = await detached_scan.complete()
        except ScanAlreadyCompleteError as err:
            # This happens when AV is woken up by AiBolit. See DEF-11078.
            logger.warning(
                "Cannot complete scan %s, assuming it is already complete"
                ":\n%s",
                scan_id,
                err,
            )
            return
        except ScanInfoError as err:
            logger.error(
                "Cannot complete %s scan %s, assuming it was not started:\n%s",
                detached_scan.RESOURCE_TYPE.value,
                scan_id,
                err,
            )
            return
        finally:
            shutil.rmtree(str(detached_scan.detached_dir), ignore_errors=True)

        await self.sink.process_message(scan_message)

    @classmethod
    async def aggregate_result(cls, message):
        """Replace ``message["results"]`` with its aggregated form, pass the
        results to ``fill_results_owner``, and return the same message."""
        message["results"] = aggregate_result(message["results"])
        await fill_results_owner(message["results"])
        return message

    async def _recheck_scan_queue(self):
        """Send a ``MalwareScanQueueRecheck`` message to the sink."""
        await self.sink.process_message(MessageType.MalwareScanQueueRecheck())


class MalwareScanMessageInfo:
    """Inspect a ``MalwareScan`` message.

    Exposes which kind of message it is (detached or not, summary or
    results) and lazily looks up the stored scan record whose ``scanid``
    matches ``message["summary"]["scanid"]``.
    """

    def __init__(self, message):
        self.message = message
        self._summary_from_db = None
        self.scan_id = message["summary"]["scanid"]

    @property
    def is_detached(self):
        """True if the summary ``type`` is ON_DEMAND, BACKGROUND, USER, or
        absent/None."""
        scan_type = self.message["summary"].get("type")
        detached_types = (
            MalwareScanType.ON_DEMAND,
            MalwareScanType.BACKGROUND,
            MalwareScanType.USER,
            None,
        )
        return scan_type in detached_types

    @property
    def is_summary(self):
        """True if the message has no results (``results`` is None)."""
        return self.message["results"] is None

    @property
    def summary_from_db(self):
        """First ``MalwareScanModel`` row for this scan id, or None.

        A found row is memoised on the instance; a miss is not, so the next
        access queries again.
        """
        if self._summary_from_db:
            return self._summary_from_db
        matching_rows = (
            MalwareScanModel.select()
            .where(MalwareScanModel.scanid == self.scan_id)
            .limit(1)
        )
        if matching_rows:
            self._summary_from_db = matching_rows[0]
        return self._summary_from_db


class DetachedScanPluginIm360(DetachedScanPlugin):
    """Imunify360-scope variant that also handles database (MDS) scans."""

    SCOPE = Scope.IM360

    @staticmethod
    def _get_detached_scan(
        resource_type: Optional[Union[str, MalwareScanResourceType]], scan_id
    ):
        """Return ``MDSDetachedScan`` for the DB resource type and
        ``AiBolitDetachedScan`` for anything else (including None)."""
        targets_db = (
            resource_type is not None
            and MalwareScanResourceType(resource_type)
            is MalwareScanResourceType.DB
        )
        scan_class = MDSDetachedScan if targets_db else AiBolitDetachedScan
        return scan_class(scan_id)

    @expect(MessageType.MalwareDatabaseScan)
    async def complete_scan_db(self, message):
        """Dequeue a finished database scan and emit
        ``MalwareScanningFinished`` followed by a queue recheck.

        Messages whose scan id is not found in the queue are ignored.
        """
        scan_id = message["scan_id"]
        queued_scan = QueueSupervisorSync.queue.find(scanid=scan_id)
        if not queued_scan:
            return
        QueueSupervisorSync.queue.remove(queued_scan)

        await self.sink.process_message(
            HookEvent.MalwareScanningFinished(
                scan_id=scan_id,
                scan_type=message["type"],
                path=message["path"],
            )
        )
        await self._recheck_scan_queue()