Current File : //proc/self/root/opt/imunify360/venv/lib/python3.11/site-packages/imav/malwarelib/model.py
"""
This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License,
or (at your option) any later version.


This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the GNU General Public License for more details.


You should have received a copy of the GNU General Public License
 along with this program.  If not, see <https://www.gnu.org/licenses/>.

Copyright © 2019 Cloud Linux Software Inc.

This software is also available under ImunifyAV commercial license,
see <https://www.imunify360.com/legal/eula>
"""
from __future__ import annotations

import asyncio
import itertools
import os
from dataclasses import dataclass
from operator import attrgetter
from pathlib import Path
from time import time
from typing import Dict, Iterable, List, Set, cast

from peewee import (
    BooleanField,
    Case,
    CharField,
    Check,
    Expression,
    FloatField,
    ForeignKeyField,
    IntegerField,
    PrimaryKeyField,
    SQL,
    TextField,
    fn,
)
from playhouse.shortcuts import model_to_dict

from defence360agent.contracts.config import UserType
from defence360agent.model import Model, instance
from defence360agent.model.simplification import (
    FilenameField,
    ScanPathField,
    apply_order_by,
)
from defence360agent.utils import (
    execute_iterable_expression,
    get_abspath_from_user_dir,
    get_results_iterable_expression,
    split_for_chunk,
)
from imav.malwarelib.config import (
    FAILED_TO_CLEANUP,
    MalwareHitStatus,
    MalwareScanResourceType,
    MalwareScanType,
)
from imav.malwarelib.scan.crontab import get_crontab


class MalwareScan(Model):
    """Represents a batch of files scanned for malware

    Usually a single AI-BOLIT execution.
    See :class:`.MalwareScanType` for possible kinds of scans.
    """

    class Meta:
        database = instance.db
        db_table = "malware_scans"

    #: An id of a scan, unique per server.
    scanid = CharField(primary_key=True)
    #: Scan start timestamp.
    started = IntegerField(null=False)
    #: Scan completion timestamp.
    completed = IntegerField(null=True)
    #: Scan type - reflects how and why the files were scanned.
    #: Must be one of :class:`.MalwareScanType`.
    type = CharField(
        null=False,
        constraints=[
            Check(
                "type in {}".format(
                    (
                        MalwareScanType.ON_DEMAND,
                        MalwareScanType.REALTIME,
                        MalwareScanType.MALWARE_RESPONSE,
                        MalwareScanType.BACKGROUND,
                        MalwareScanType.RESCAN,
                        MalwareScanType.USER,
                        MalwareScanType.RESCAN_OUTDATED,
                    )
                )
            )
        ],
    )
    #: The number of resources scanned.
    total_resources = IntegerField(null=False, default=0)
    #: For some types of scan - the directory or a file that was scanned.
    path = ScanPathField(null=True, default="")
    #: If not `null`, the scan did not finish successfully.
    #: Can be one of :class:`.ExitDetachedScanType` if scan was aborted or
    #: stopped by user, or an arbitrary error message for other kinds
    #: of issues.
    error = TextField(null=True, default=None)
    #: The number of malicious files found
    total_malicious = IntegerField(null=False, default=0)
    #: What was scanned: one of the :class:`.MalwareScanResourceType`
    #: values (DB or FILE), enforced by a CHECK constraint.
    resource_type = CharField(
        null=False,
        constraints=[
            Check(
                "resource_type in {}".format(
                    (
                        MalwareScanResourceType.DB.value,
                        MalwareScanResourceType.FILE.value,
                    )
                )
            )
        ],
    )
    #: user who started the scan (None for root user)
    initiator = CharField(null=True)

    @classmethod
    def ondemand_list(
        cls,
        since,
        to,
        limit,
        offset,
        order_by=None,
        *,
        types=(
            MalwareScanType.ON_DEMAND,
            MalwareScanType.BACKGROUND,
            MalwareScanType.USER,
        ),
        paths=None,
    ):
        """List scans of the given *types* started within ``[since, to]``.

        :param since: inclusive lower bound on :attr:`started`.
        :param to: inclusive upper bound on :attr:`started`.
        :param limit: maximum number of rows returned.
        :param offset: number of rows skipped (pagination).
        :param order_by: optional ordering spec forwarded to
            ``apply_order_by``; when omitted rows are newest-first.
        :param types: scan types to include.
        :param paths: if truthy, only scans whose :attr:`path` is listed.
        :return: ``(total, rows)`` where *total* ignores limit/offset and
            *rows* is a list of dicts (the ``type`` column is exposed as
            ``scan_type``).
        """
        selected_columns = (
            cls.total_resources,
            cls.path,
            cls.scanid,
            cls.started,
            cls.completed,
            cls.error,
            cls.total_malicious,
            cls.type.alias("scan_type"),
            cls.resource_type,
        )
        filters = [
            cls.type.in_(types),
            cls.started >= since,
            cls.started <= to,
        ]
        if paths:
            filters.append(cls.path.in_(paths))

        # Successive .where() calls are AND-ed together by peewee.
        query = cls.select(*selected_columns)
        for condition in filters:
            query = query.where(condition)

        query = query.group_by(
            cls.total_resources, cls.path, cls.scanid, cls.started
        )
        query = query.order_by(MalwareScan.started.desc())
        query = query.limit(limit).offset(offset)

        if order_by is not None:
            query = apply_order_by(order_by, cls, query)

        total = query.count(clear_limit=True)
        rows = list(query.dicts())
        return total, rows


class MalwareHit(Model):
    """Represents a malicious or suspicious file.

    Each row references the :class:`MalwareScan` that produced it through
    :attr:`scanid`; hits are deleted together with their scan
    (``on_delete="CASCADE"``).
    """

    class Meta:
        database = instance.db
        db_table = "malware_hits"

    #: Auto-incremented id of the hit.
    id = PrimaryKeyField()
    #: A reference to :class:`MalwareScan`.
    scanid = ForeignKeyField(
        MalwareScan, null=False, related_name="hits", on_delete="CASCADE"
    )
    #: The owner of the file.
    owner = CharField(null=False)
    #: The user a file belongs to (is in user's home but owned by another user)
    user = CharField(null=False)
    #: The original path to the file.
    orig_file = FilenameField(null=False)
    #: The type of infection (signature).
    type = CharField(null=False)
    #: Whether the file is malicious or just suspicious.
    #: Suspicious files are not displayed in UI but sent for analysis to MRS.
    malicious = BooleanField(null=False, default=False)
    #: The hash of the files as provided by AI-BOLIT.
    hash = CharField(null=True)
    #: The size of the file.
    size = CharField(null=True)
    #: The exact timestamp when AI-BOLIT has detected the file.
    #:
    #: FIXME: unused? It looks like it was intended to resolve some possible
    #: race conditions with parallel scans, but we don't actually use it
    #: from the DB - we only compare the value in scan report
    #: with :attr:`cleaned_at`.
    timestamp = FloatField(null=True)

    #: The current status of the file.
    #: Must be one of :class:`.MalwareHitStatus`.
    status = CharField(default=MalwareHitStatus.FOUND)
    #: Timestamp when the file was last cleaned.
    cleaned_at = FloatField(null=True)
    #: One of the :class:`.MalwareScanResourceType` values (DB or FILE).
    resource_type = CharField(
        null=False,
        constraints=[
            Check(
                "resource_type in {}".format(
                    (
                        MalwareScanResourceType.DB.value,
                        MalwareScanResourceType.FILE.value,
                    )
                )
            )
        ],
    )
    #: Application name and database coordinates of a DB hit.
    app_name = CharField(null=True)
    db_host = CharField(null=True)
    db_port = CharField(null=True)
    db_name = CharField(null=True)
    #: Excerpt of the detected content.
    snippet = CharField(null=True)

    @property
    def orig_file_path(self):
        """Return :attr:`orig_file` as a :class:`~pathlib.Path`.

        ``os.fsdecode`` leaves ``str`` untouched and decodes ``bytes``
        filenames, mirroring ``MalwareHitAlternate.orig_file_path``.
        """
        orig_file = cast(str, self.orig_file)
        return Path(os.fsdecode(orig_file))

    class OrderBy:
        @staticmethod
        def status():
            """Sort key: hits queued/under cleanup first, then found,
            then cleaned/removed; any other status sorts last (100).

            Returned as a one-element tuple of ordering expressions.
            """
            return (
                Case(
                    MalwareHit.status,
                    (
                        (MalwareHitStatus.CLEANUP_PENDING, 0),
                        (MalwareHitStatus.CLEANUP_STARTED, 1),
                        (MalwareHitStatus.FOUND, 2),
                        (MalwareHitStatus.CLEANUP_DONE, 4),
                        (MalwareHitStatus.CLEANUP_REMOVED, 5),
                    ),
                    100,
                ),
            )

    @classmethod
    def _hits_list(
        cls,
        clauses,
        since=0,
        to=None,
        limit=None,
        offset=None,
        search=None,
        by_scan_id=None,
        user=None,
        order_by=None,
        by_status=None,
        ids=None,
        **kwargs,
    ):
        """Return ``(max_count, rows)`` for hits matching *clauses*.

        :param clauses: base peewee expression (e.g. malicious/suspicious).
        :param since: inclusive lower bound on ``MalwareScan.started``.
        :param to: inclusive upper bound; a falsy value means "now".
        :param limit: maximum number of rows returned.
        :param offset: number of rows skipped.
        :param search: substring matched against ``orig_file`` (LIKE) or
            :attr:`user` (peewee ``**``, i.e. ILIKE).
        :param by_scan_id: restrict to one scan.
        :param user: restrict to hits of this user.
        :param order_by: ordering spec forwarded to ``apply_order_by``.
        :param by_status: restrict to these statuses.
        :param ids: restrict to these hit ids (not reflected in
            *max_count*).
        :param kwargs: accepted and ignored.
        :return: the number of matching hits (ignoring *ids*, *limit* and
            *offset*) and a list of :meth:`as_dict` results.
        """
        hits = cls.select(cls, MalwareScan).join(MalwareScan)

        to = to or time()
        started = (MalwareScan.started >= since) & (MalwareScan.started <= to)
        full_clauses = clauses & started
        if search is not None:
            pattern = "%{}%".format(search)
            # NOTE(review): the CAST suggests FilenameField may not store
            # plain text; confirm against its definition.
            full_clauses &= SQL(
                "CAST(orig_file AS TEXT) LIKE ?", (pattern,)
            ) | (cls.user**pattern)
        if user is not None:
            full_clauses &= MalwareHit.user == user
        if by_scan_id is not None:
            full_clauses &= MalwareScan.scanid == by_scan_id
        if by_status is not None:
            full_clauses &= MalwareHit.status << by_status
        # `max_count` is used for pagination, must not include `ids`
        max_count_clauses = full_clauses
        if ids is not None:
            full_clauses &= MalwareHit.id.in_(ids)

        ordered = hits.where(full_clauses).limit(limit).offset(offset)

        if order_by is not None:
            ordered = apply_order_by(order_by, MalwareHit, ordered)

        max_count = cls._hits_num(max_count_clauses)
        result = [row.as_dict() for row in ordered]

        return max_count, result

    @classmethod
    def suspicious_list(cls, *args, **kwargs):
        """:meth:`_hits_list` restricted to non-malicious (suspicious) hits."""
        return cls._hits_list(cls.is_suspicious(), *args, **kwargs)

    @classmethod
    def _hits_num(
        cls, clauses=None, since=None, to=None, user=None, order_by=None
    ):
        """Count hits matching *clauses* (joined with their scan).

        :param clauses: optional peewee expression; ``None`` means no
            base restriction.
        :param since: if truthy, inclusive lower bound on
            ``MalwareScan.started``.
        :param to: if truthy, inclusive upper bound on
            ``MalwareScan.started``.
        :param user: restrict to hits of this user.
        :param order_by: ordering spec forwarded to ``apply_order_by``.
        :return: the scalar count.
        """
        conditions = [] if clauses is None else [clauses]
        # Each bound is applied on its own so that e.g. ``since=0`` (which
        # is a no-op lower bound) does not also discard the ``to`` bound.
        if since:
            conditions.append(MalwareScan.started >= since)
        if to:
            conditions.append(MalwareScan.started <= to)
        if user is not None:
            conditions.append(cls.user == user)
        q = cls.select(fn.COUNT(cls.id)).join(MalwareScan)
        if conditions:
            # Multiple expressions passed to .where() are AND-ed.
            q = q.where(*conditions)
        if order_by is not None:
            q = apply_order_by(order_by, MalwareHit, q)
        return q.scalar()

    @classmethod
    def malicious_num(cls, since, to, user=None):
        """Count malicious hits not in a cleanup status in the time range."""
        return cls._hits_num(
            (cls.status.not_in(MalwareHitStatus.CLEANUP) & cls.malicious),
            since,
            to,
            user,
        )

    @classmethod
    def malicious_list(cls, *args, ignore_cleaned=False, **kwargs):
        """:meth:`_hits_list` restricted to malicious hits.

        :param ignore_cleaned: also exclude hits in a cleanup status.
        """
        clauses = cls.malicious
        if ignore_cleaned:
            clauses &= cls.status.not_in(MalwareHitStatus.CLEANUP)
        return cls._hits_list(clauses, *args, **kwargs)

    @classmethod
    def set_status(cls, hits, status, cleaned_at=None):
        """Set *status* (and *cleaned_at*, when given) on the given hits.

        The update is issued through ``execute_iterable_expression`` over
        the hit ids (presumably in chunks -- see that helper).
        """
        hits = [row.id for row in hits]

        def expression(ids, cls, status, cleaned_at):
            fields_to_update = {
                "status": status,
            }
            if cleaned_at is not None:
                fields_to_update["cleaned_at"] = cleaned_at

            return cls.update(**fields_to_update).where(cls.id.in_(ids))

        return execute_iterable_expression(
            expression, hits, cls, status, cleaned_at
        )

    @classmethod
    def delete_instances(cls, to_delete: list):
        """Delete the given hit rows by id."""
        to_delete = [row.id for row in to_delete]

        def expression(ids):
            return cls.delete().where(cls.id.in_(ids))

        return execute_iterable_expression(expression, to_delete)

    @classmethod
    def update_instances(cls, to_update: list):
        """Apply field updates and save each hit.

        :param to_update: list of ``{hit: {field_name: value, ...}}`` dicts.
        """
        for data in to_update:
            # ``hit`` rather than ``instance`` so the module-level
            # ``instance`` (database holder) is not shadowed.
            for hit, new_fields_data in data.items():
                for field, value in new_fields_data.items():
                    setattr(hit, field, value)
                hit.save()

    @classmethod
    def is_infected(cls) -> Expression:
        """Expression matching malicious hits with FOUND status."""
        clauses = (
            cls.status.in_(
                [
                    MalwareHitStatus.FOUND,
                ]
            )
            & cls.malicious
        )
        return clauses

    @classmethod
    def is_suspicious(cls):
        """Expression matching non-malicious (suspicious) hits."""
        return ~cls.malicious

    @classmethod
    def malicious_select(
        cls, ids=None, user=None, cleanup=False, restore=False, **kwargs
    ):
        """Return a list of malicious hits.

        :param ids: when given, select exactly these ids; *cleanup* and
            *restore* are then not applied.
        :param user: a user name or an iterable of user names.
        :param cleanup: without *ids*, exclude hits in a cleanup status.
        :param restore: without *ids* and *cleanup*, only restorable hits.
        :param kwargs: accepted and ignored.
        """

        def expression(chunk_of_ids, cls, user):
            clauses = cls.malicious
            if chunk_of_ids is not None:
                clauses &= cls.id.in_(chunk_of_ids)
            elif cleanup:
                clauses &= cls.status.not_in(MalwareHitStatus.CLEANUP)
            elif restore:
                clauses &= cls.status.in_(MalwareHitStatus.RESTORABLE)
            if user is not None:
                if isinstance(user, str):
                    user = [user]
                clauses &= cls.user.in_(user)
            return cls.select().where(clauses)

        return list(
            get_results_iterable_expression(
                expression, ids, cls, user, exec_expr_with_empty_iter=True
            )
        )

    @classmethod
    def get_hits(cls, files, *, statuses=None):
        """Return hits whose :attr:`orig_file` is in *files*.

        :param statuses: if truthy, also restrict to these statuses.
        """

        def expression(files):
            clauses = cls.orig_file.in_(files)
            if statuses:
                clauses &= cls.status.in_(statuses)
            return cls.select().where(clauses)

        return get_results_iterable_expression(expression, files)

    @classmethod
    def get_db_hits(cls, hits_info: Set):
        """Return hits whose ``(orig_file, app_name)`` pair is in *hits_info*.

        :param hits_info: entries exposing ``path`` and ``app_name``.
        :return: list of matching :class:`MalwareHit` rows.
        """
        paths = {entry.path for entry in hits_info}
        apps = {entry.app_name for entry in hits_info}
        # A set makes the per-row pair check O(1) instead of a list scan.
        paths_apps = {(entry.path, entry.app_name) for entry in hits_info}
        # The SQL filter is a superset (path IN ... AND app IN ...); the
        # exact pair match is done in Python below.
        query = (
            MalwareHit.select()
            .where(MalwareHit.orig_file.in_(paths))
            .where(MalwareHit.app_name.in_(apps))
        )
        return [
            hit for hit in query if (hit.orig_file, hit.app_name) in paths_apps
        ]

    @classmethod
    def delete_hits(cls, files):
        """Delete all hits whose :attr:`orig_file` is in *files*."""

        def expression(files):
            return cls.delete().where(cls.orig_file.in_(files))

        return execute_iterable_expression(expression, files)

    def refresh(self):
        """Re-read this row; returns a new instance (``self`` unchanged)."""
        return type(self).get(self._pk_expr())

    @classmethod
    def refresh_hits(cls, hits: Iterable[MalwareHit], include_scan_info=False):
        """Re-read the given hits from the database.

        :param include_scan_info: join :class:`MalwareScan` so scan fields
            are loaded in the same query.
        """

        def expression(hits):
            query = cls.select()
            if include_scan_info:  # use a single query to get scan info
                query = cls.select(cls, MalwareScan).join(MalwareScan)
            return query.where(cls.id.in_([hit.id for hit in hits]))

        return list(get_results_iterable_expression(expression, hits))

    @classmethod
    def db_hits(cls):
        """Query of all hits whose resource type is DB."""
        return cls.select().where(
            cls.resource_type == MalwareScanResourceType.DB.value
        )

    @classmethod
    def db_hits_pending_cleanup(cls) -> Expression:
        """Return db hits that are in queue for cleanup"""
        return cls.db_hits().where(
            cls.status == MalwareHitStatus.CLEANUP_PENDING,
        )

    @classmethod
    def db_hits_under_cleanup(cls) -> Expression:
        """Return db hits for which the cleanup is in progress"""
        return cls.db_hits().where(
            cls.status == MalwareHitStatus.CLEANUP_STARTED
        )

    @classmethod
    def db_hits_under_restoration(cls) -> Expression:
        """Return db hits for which the restore is in progress"""
        return cls.db_hits().where(
            cls.status == MalwareHitStatus.CLEANUP_RESTORE_STARTED
        )

    @classmethod
    def db_hits_under_cleanup_in(cls, hit_info_set):
        """
        Return db hits for which the cleanup is in progress
        specified by the provided set of MalwareDatabaseHitInfo
        """
        # FIXME: Use peewee.ValuesList when peewee is updated
        # to obtain all hits using one query without additional processing
        path_set = {hit_info.path for hit_info in hit_info_set}
        app_name_set = {hit_info.app_name for hit_info in hit_info_set}
        path_app_name_set = {
            (hit_info.path, hit_info.app_name) for hit_info in hit_info_set
        }
        query = (
            cls.db_hits_under_cleanup()
            .where(cls.orig_file.in_(path_set))
            .where(cls.app_name.in_(app_name_set))
        )
        return [
            hit
            for hit in query
            if (hit.orig_file, hit.app_name) in path_app_name_set
        ]

    @classmethod
    def db_hits_pending_cleanup_restore(cls):
        """Return db hits queued for cleanup restore."""
        return cls.db_hits().where(
            cls.status == MalwareHitStatus.CLEANUP_RESTORE_PENDING
        )

    @classmethod
    def db_hits_under_cleanup_restore(cls):
        """Return db hits whose cleanup restore has started."""
        return cls.db_hits().where(
            cls.status == MalwareHitStatus.CLEANUP_RESTORE_STARTED
        )

    @staticmethod
    def group_by_attribute(
        *hit_list_list: List["MalwareHit"], attribute: str
    ) -> Dict[str, List["MalwareHit"]]:
        """Merge the given hit lists and group the hits by *attribute*.

        :return: mapping of attribute value to the hits having it.
        """
        key = attrgetter(attribute)
        # itertools.groupby only groups adjacent items, hence the sort.
        hit_list = sorted(itertools.chain.from_iterable(hit_list_list), key=key)
        return {
            attr_value: list(hits)
            for attr_value, hits in itertools.groupby(hit_list, key=key)
        }

    def as_dict(self):
        """Serialize the hit (and its scan's start time/type) to a dict.

        For DB hits, ``table_fields`` lists the matching
        :class:`MalwareHistory` table/field/row records of the same scan;
        this issues one extra query per DB hit.
        """
        return {
            "id": self.id,
            "username": self.user,
            "file": self.orig_file,
            "created": self.scanid.started,
            "scan_id": self.scanid_id,
            "scan_type": self.scanid.type,
            "resource_type": self.resource_type,
            "type": self.type,
            "hash": self.hash,
            "size": self.size,
            "malicious": self.malicious,
            "status": self.status,
            "cleaned_at": self.cleaned_at,
            "extra_data": {},
            "db_name": self.db_name,
            "app_name": self.app_name,
            "db_host": self.db_host,
            "db_port": self.db_port,
            "snippet": self.snippet,
            "table_fields": (
                list(
                    MalwareHistory.select(
                        MalwareHistory.table_name,
                        MalwareHistory.table_field,
                        MalwareHistory.table_row_inf,
                    )
                    .where(
                        MalwareHistory.app_name == self.app_name,
                        MalwareHistory.db_host == self.db_host,
                        MalwareHistory.db_port == self.db_port,
                        MalwareHistory.db_name == self.db_name,
                        MalwareHistory.path == self.orig_file,
                        MalwareHistory.resource_type == self.resource_type,
                        # Compare with the raw scan id rather than the
                        # related MalwareScan object.
                        MalwareHistory.scan_id == self.scanid_id,
                        MalwareHistory.table_name.is_null(False),
                        MalwareHistory.table_field.is_null(False),
                        MalwareHistory.table_row_inf.is_null(False),
                    )
                    .dicts()
                )
                if self.resource_type == MalwareScanResourceType.DB.value
                else []
            ),
        }

    def __repr__(self):
        if self.app_name:
            return "%s(orig_file=%r, app_name=%r)" % (
                self.__class__.__name__,
                self.orig_file,
                self.app_name,
            )
        return "%s(orig_file=%r)" % (self.__class__.__name__, self.orig_file)


@dataclass(frozen=True)
class MalwareHitAlternate:
    """
    Immutable, DB-free stand-in for MalwareHit that describes a file hit
    (database hits are not represented by this class).
    """

    scanid: str
    orig_file: str
    # app_name is always None for file hits
    app_name: None
    owner: str
    user: str
    size: int
    hash: str
    type: str
    timestamp: int
    malicious: bool

    @classmethod
    def create(cls, scanid, filename, data):
        """Build an instance from one file's entry of a scan report.

        :param scanid: id of the scan that produced the entry.
        :param filename: the file path (``str`` or ``bytes``).
        :param data: mapping with ``owner``, ``user``, ``size``, ``hash``
            and a non-empty ``hits`` list; only ``hits[0]`` is used
            (its ``matches``, ``timestamp`` and ``suspicious`` keys).
        """
        attrs = dict(
            scanid=scanid,
            orig_file=filename,
            app_name=None,
            owner=data["owner"],
            user=data["user"],
            size=data["size"],
            hash=data["hash"],
        )
        first_hit = data["hits"][0]
        attrs.update(
            type=first_hit["matches"],
            timestamp=first_hit["timestamp"],
            malicious=not first_hit["suspicious"],
        )
        return cls(**attrs)

    @property
    def orig_file_path(self):
        """The file as a ``Path``; ``bytes`` names go through os.fsdecode."""
        decoded_name = os.fsdecode(self.orig_file)
        return Path(decoded_name)


class MalwareIgnorePath(Model):
    """A path that must be excluded from all scans"""

    class Meta:
        database = instance.db
        db_table = "malware_ignore_path"
        indexes = ((("path", "resource_type"), True),)  # True refers to unique

    #: Class-level cache of every row (as dicts ordered by path) used by
    #: :meth:`is_path_ignored`; ``None`` means "reload on next use".
    CACHE = None

    id = PrimaryKeyField()
    #: The path itself. Wildcards or patterns are NOT supported.
    path = CharField()
    #: ``'file'`` or ``'db'`` (enforced by a CHECK constraint).
    resource_type = CharField(
        null=False, constraints=[Check("resource_type in ('file','db')")]
    )
    #: Timestamp when it was added.
    added_date = IntegerField(null=False, default=lambda: int(time()))

    @classmethod
    def _update_cache(cls):
        """Reload :attr:`CACHE` with all rows, ordered by path."""
        items = list(cls.select().order_by(cls.path).dicts())
        cls.CACHE = items

    @classmethod
    def create(cls, **kwargs):
        """Create a row and invalidate :attr:`CACHE`."""
        cls.CACHE = None
        return super(MalwareIgnorePath, cls).create(**kwargs)

    @classmethod
    def delete(cls):
        """Return a delete query and invalidate :attr:`CACHE`.

        NOTE(review): the cache is dropped when the query is built, not
        when it is executed; a reload in between would re-cache rows that
        are about to be deleted -- confirm callers execute immediately.
        """
        cls.CACHE = None
        return super(MalwareIgnorePath, cls).delete()

    @classmethod
    def paths_count_and_list(
        cls,
        limit=None,
        offset=None,
        search=None,
        resource_type: str | None = None,
        user=None,
        since=None,
        to=None,
        order_by=None,
    ):
        """Return ``(max_count, rows)`` of ignored paths matching filters.

        :param limit: maximum number of rows returned.
        :param offset: number of rows skipped.
        :param search: substring that :attr:`path` must contain.
        :param resource_type: restrict to ``'file'`` or ``'db'`` entries.
        :param user: restrict to paths equal to or under the user's home
            directory, or equal to the user's crontab path.
        :param since: inclusive lower bound on :attr:`added_date`.
        :param to: inclusive upper bound on :attr:`added_date`.
        :param order_by: ordering spec forwarded to ``apply_order_by``;
            rows are ordered by path otherwise.
        :return: the count ignoring limit/offset and a list of row dicts.
        """
        q = cls.select().order_by(cls.path)
        if since is not None:
            q = q.where(cls.added_date >= since)
        if to is not None:
            q = q.where(cls.added_date <= to)
        if search is not None:
            q = q.where(cls.path.contains(search))
        if resource_type is not None:
            q = q.where(cls.resource_type == resource_type)
        if offset is not None:
            q = q.offset(offset)
        if limit is not None:
            q = q.limit(limit)
        if order_by is not None:
            q = apply_order_by(order_by, cls, q)
        if user is not None:
            user_home = get_abspath_from_user_dir(user)
            # The trailing "/" keeps e.g. /home/bob2 from matching /home/bob.
            q = q.where(
                (cls.path.startswith(str(user_home) + "/"))
                | (cls.path == str(user_home))
                | (cls.path == str(get_crontab(user)))
            )

        max_count = q.count(clear_limit=True)
        return (
            max_count,
            [model_to_dict(row) for row in q],
        )

    @classmethod
    def path_list(cls, *args, **kwargs) -> List[str]:
        """Like :meth:`paths_count_and_list` but return only the paths."""
        _, path_list = cls.paths_count_and_list(*args, **kwargs)
        return [row["path"] for row in path_list]

    @classmethod
    async def is_path_ignored(cls, check_path):
        """Check whether *check_path* is, or lies under, an ignored path.

        The comparison is purely lexical on :class:`~pathlib.Path` objects
        (no wildcard/pattern matching, no symlink resolution). Cached
        entries of every ``resource_type`` are considered.
        NOTE(review): confirm that 'db' entries should also apply here.

        :param str check_path: path to check
        :return: bool: is ignored according MalwareIgnorePath
        """
        if cls.CACHE is None:
            cls._update_cache()
        path = Path(check_path)
        for p in cls.CACHE:
            # Yield to the event loop between entries so a long ignore
            # list does not block other coroutines.
            await asyncio.sleep(0)
            ignored_path = Path(p["path"])
            if (path == ignored_path) or (ignored_path in path.parents):
                return True
        return False


class MalwareHistory(Model):
    """Records every event related to :class:`MalwareHit` records"""

    class Meta:
        database = instance.db
        db_table = "malware_history"

    #: The path of the file.
    path = FilenameField(null=False)
    app_name = CharField(null=True)
    resource_type = CharField(
        null=False,
        constraints=[
            Check(
                "resource_type in {}".format(
                    (
                        MalwareScanResourceType.DB.value,
                        MalwareScanResourceType.FILE.value,
                    )
                )
            )
        ],
        default=MalwareScanResourceType.FILE.value,
    )
    #: What happened with the file. Should be one of :class:`.MalwareEvent`.
    event = CharField(null=False)
    #: What kind of scan has detected the file, or `manual` for manual actions.
    #: See :class:`.MalwareScanType`.
    cause = CharField(null=False)
    #: The name of the user who has triggered the event.
    initiator = CharField(null=False)
    #: A snapshot of :attr:`MalwareHit.owner`
    file_owner = CharField(null=False)
    #: A snapshot of :attr:`MalwareHit.user`
    file_user = CharField(null=False)
    #: Timestamp when the event took place.
    ctime = IntegerField(null=False, default=lambda: int(time()))
    #: Database host name (for db type scan).
    db_host = CharField(null=True)
    #: Database port (for db type scan).
    db_port = CharField(null=True)
    #: Database name (for db type scan).
    db_name = CharField(null=True)
    #: Infected table name (for db type scan)
    table_name = CharField(null=True)
    #: Infected field name (for db type scan)
    table_field = CharField(null=True)
    #: Infected table row id (for db type scan)
    table_row_inf = IntegerField(null=True)
    #: Scan ID reference (for generating `table_fields`)
    scan_id = CharField(null=True)

    @classmethod
    def get_history(
        cls, since, to, limit, offset, user=None, search=None, order_by=None
    ):
        """Return ``(count, rows)`` of events with ``since <= ctime <= to``.

        :param search: if truthy, keep events whose ``event`` contains it
            or whose ``path`` contains it (SQL ``INSTR``).
        :param user: if truthy, keep events of this :attr:`file_user`.
        :param order_by: ordering spec forwarded to ``apply_order_by``.
        :return: the count ignoring limit/offset and a list of row dicts.
        """
        conditions = (cls.ctime >= since) & (cls.ctime <= to)
        if search:
            in_event = cls.event.contains(search)
            in_path = SQL("(INSTR(path, ?))", (search,))
            conditions &= in_event | in_path
        if user:
            conditions &= cls.file_user == user

        query = cls.select().where(conditions)
        query = query.limit(limit).offset(offset).dicts()
        if order_by is not None:
            query = apply_order_by(order_by, MalwareHistory, query)

        rows = list(query)
        total = query.count(clear_limit=True)
        return total, rows

    @classmethod
    def save_event(cls, **kwargs):
        """Insert one event row; missing/empty ``initiator``, ``cause`` and
        ``resource_type`` fall back to root, manual and file respectively.
        """
        initiator = kwargs.pop("initiator", None) or UserType.ROOT
        cause = kwargs.pop("cause", None) or MalwareScanType.MANUAL
        resource_type = (
            kwargs.pop("resource_type", None)
            or MalwareScanResourceType.FILE.value
        )
        cls.insert(
            initiator=initiator,
            cause=cause,
            resource_type=resource_type,
            **kwargs,
        ).execute()

    @classmethod
    def save_events(cls, hits: List[dict]):
        """Bulk-insert event rows inside a single transaction."""
        # Each INSERT may bind at most SQLITE_LIMIT_VARIABLE_NUMBER
        # parameters (compile-time default 999), so rows are sent in chunks
        # of 999 // <number of columns>.
        rows_per_insert = 999 // len(cls._meta.columns)
        with instance.db.atomic():
            for chunk in split_for_chunk(hits, chunk_size=rows_per_insert):
                cls.insert_many(chunk).execute()

    @classmethod
    def get_failed_cleanup_events_count(cls, paths: list, *, since: int):
        """Query of ``(path, count)`` tuples of FAILED_TO_CLEANUP events
        for *paths* with ``ctime >= since``.
        """
        is_recent_failure = (
            cls.path.in_(paths)
            & (cls.event == FAILED_TO_CLEANUP)
            & (cls.ctime >= since)
        )
        per_path = (
            cls.select(cls.path, fn.COUNT())
            .where(is_recent_failure)
            .group_by(cls.path)
        )
        return per_path.tuples()