Load the COM database only once

Avoid loading the COM database multiple times by caching it after the first load.
This commit is contained in:
Aayush Goel
2023-10-17 19:48:06 +05:30
parent 412d296d6b
commit 884b714be2

View File

@@ -16,6 +16,7 @@ import logging
import binascii
import collections
from enum import Enum
from typing import Literal
from pathlib import Path
from capa.helpers import assert_never
@@ -202,23 +203,37 @@ VALID_COM_TYPES = {
"interface": {"db_path": "assets/interfaces.json.gz", "prefix": "IID_"},
}
# Cache of already-loaded COM databases, keyed by COM type; load_com_database
# stores each database here after reading it and returns the stored copy on
# later calls instead of re-reading the file.
com_db_cache: Dict[str, Dict[str, List[str]]] = {}
def translate_com_feature(com_name: str, com_type: str) -> ceng.Or:
def load_com_database(com_type: Literal["class", "interface"]) -> Dict[str, List[str]]:
    """Return the COM database for ``com_type``, loading it from disk on first use.

    The parsed database is stored in the module-level ``com_db_cache`` so that
    subsequent calls for the same ``com_type`` return the cached object without
    touching the filesystem. The returned dict is the cached object itself, not
    a copy.

    Args:
        com_type: key into ``VALID_COM_TYPES`` selecting which database to load.

    Returns:
        The decoded JSON content of the gzip-compressed database file.

    Raises:
        IOError: if the database file does not exist, or if reading,
            decompressing, decoding or parsing it fails.
    """
    # Check the cache first so a hit does no path resolution or file access.
    if com_type in com_db_cache:
        return com_db_cache[com_type]

    com_db_path = capa.main.get_default_root() / VALID_COM_TYPES[com_type]["db_path"]
    if not com_db_path.exists():
        raise IOError(f"COM database path '{com_db_path}' does not exist or cannot be accessed")

    # Keep the try body limited to the operations that can actually fail; any
    # failure while reading or parsing is reported to the caller as an IOError.
    try:
        with gzip.open(com_db_path, "rb") as gzfile:
            com_db: Dict[str, List[str]] = json.loads(gzfile.read().decode("utf-8"))
    except Exception as e:
        raise IOError(f"Error loading COM database from '{com_db_path}'") from e

    com_db_cache[com_type] = com_db
    return com_db
def translate_com_feature(com_name: str, com_type: Literal["class", "interface"]) -> ceng.Or:
if com_type not in VALID_COM_TYPES:
raise InvalidRule(f"Invalid COM type present {com_type}")
CD = Path(__file__).resolve().parent.parent.parent
com_db_path = CD / VALID_COM_TYPES[com_type]["db_path"]
if not com_db_path.exists():
logger.error("Using COM %s database '%s', but it doesn't exist", com_type, com_db_path)
raise IOError(f"COM database path '{com_db_path}' does not exist or cannot be accessed")
with gzip.open(com_db_path, "rb") as gzfile:
com_db: Dict[str, List[str]] = json.loads(gzfile.read().decode("utf-8"))
guid_strings: Optional[List[str]] = com_db.get(com_name)
if guid_strings is None or len(guid_strings) == 0:
logger.error(" %s doesn't exist in COM %s database", com_name, com_type)
raise InvalidRule(f"'{com_name}' doesn't exist in COM {com_type} database")
com_db = load_com_database(com_type)
guid_strings: Optional[List[str]] = com_db.get(com_name)
if guid_strings is None or len(guid_strings) == 0:
logger.error(" %s doesn't exist in COM %s database", com_name, com_type)
raise InvalidRule(f"'{com_name}' doesn't exist in COM {com_type} database")
com_features: List = []
for guid_string in guid_strings: