py3: remove py2/3 branches

Remove `if-else`s with a condition like `sys.version_info >= (3, 0)`.
This commit is contained in:
Ana Maria Martinez Gomez
2021-03-11 17:59:40 +01:00
parent 407ecab162
commit aa4d6305af
8 changed files with 38 additions and 98 deletions

View File

@@ -27,10 +27,7 @@ VALID_ARCH = (ARCH_X32, ARCH_X64)
def bytes_to_str(b):
if sys.version_info[0] >= 3:
return str(codecs.encode(b, "hex").decode("utf-8"))
else:
return codecs.encode(b, "hex")
return str(codecs.encode(b, "hex").decode("utf-8"))
def hex_string(h):

View File

@@ -16,10 +16,7 @@ MIN_STACKSTRING_LEN = 8
def xor_static(data, i):
if sys.version_info >= (3, 0):
return bytes(c ^ i for c in data)
else:
return "".join(chr(ord(c) ^ i) for c in data)
return bytes(c ^ i for c in data)
def is_aw_function(symbol):

View File

@@ -34,10 +34,7 @@ def add_ea_int_cast(o):
this bit of skullduggery lets use cast viv-utils objects as ints.
the correct way of doing this is to update viv-utils (or subclass the objects here).
"""
if sys.version_info[0] >= 3:
setattr(o, "__int__", types.MethodType(get_ea, o))
else:
setattr(o, "__int__", types.MethodType(get_ea, o, type(o)))
setattr(o, "__int__", types.MethodType(get_ea, o))
return o

View File

@@ -39,18 +39,11 @@ def get_printable_len(op):
raise ValueError("Unhandled operand data type 0x%x." % op.dtype)
def is_printable_ascii(chars):
if sys.version_info[0] >= 3:
return all(c < 127 and chr(c) in string.printable for c in chars)
else:
return all(ord(c) < 127 and c in string.printable for c in chars)
return all(c < 127 and chr(c) in string.printable for c in chars)
def is_printable_utf16le(chars):
if sys.version_info[0] >= 3:
if all(c == 0x00 for c in chars[1::2]):
return is_printable_ascii(chars[::2])
else:
if all(c == "\x00" for c in chars[1::2]):
return is_printable_ascii(chars[::2])
if all(c == 0x00 for c in chars[1::2]):
return is_printable_ascii(chars[::2])
if is_printable_ascii(chars):
return idaapi.get_dtype_size(op.dtype)

View File

@@ -23,11 +23,7 @@ def find_byte_sequence(start, end, seq):
end: max virtual address
seq: bytes to search e.g. b"\x01\x03"
"""
if sys.version_info[0] >= 3:
seq = " ".join(["%02x" % b for b in seq])
else:
seq = " ".join(["%02x" % ord(b) for b in seq])
seq = " ".join(["%02x" % b for b in seq])
while True:
ea = idaapi.find_binary(start, end, seq, 0, idaapi.SEARCH_DOWN)
if ea == idaapi.BADADDR:

View File

@@ -264,15 +264,14 @@ def main(argv=None):
parser.add_argument(
"-f", "--format", choices=[f[0] for f in formats], default="auto", help="Select sample format, %s" % format_help
)
if sys.version_info >= (3, 0):
parser.add_argument(
"-b",
"--backend",
type=str,
help="select the backend to use",
choices=(capa.main.BACKEND_VIV, capa.main.BACKEND_SMDA),
default=capa.main.BACKEND_VIV,
)
parser.add_argument(
"-b",
"--backend",
type=str,
help="select the backend to use",
choices=(capa.main.BACKEND_VIV, capa.main.BACKEND_SMDA),
default=capa.main.BACKEND_VIV,
)
args = parser.parse_args(args=argv)
if args.quiet:
@@ -285,8 +284,7 @@ def main(argv=None):
logging.basicConfig(level=logging.INFO)
logging.getLogger().setLevel(logging.INFO)
backend = args.backend if sys.version_info > (3, 0) else capa.main.BACKEND_VIV
extractor = capa.main.get_extractor(args.sample, args.format, backend)
extractor = capa.main.get_extractor(args.sample, args.format, args.backend)
with open(args.output, "wb") as f:
f.write(dump(extractor))

View File

@@ -328,14 +328,10 @@ class CapaExplorerByteViewItem(CapaExplorerFeatureItem):
"""
byte_snap = idaapi.get_bytes(location, 32)
details = ""
if byte_snap:
byte_snap = codecs.encode(byte_snap, "hex").upper()
if sys.version_info >= (3, 0):
details = " ".join([byte_snap[i : i + 2].decode() for i in range(0, len(byte_snap), 2)])
else:
details = " ".join([byte_snap[i : i + 2] for i in range(0, len(byte_snap), 2)])
else:
details = ""
details = " ".join([byte_snap[i : i + 2].decode() for i in range(0, len(byte_snap), 2)])
super(CapaExplorerByteViewItem, self).__init__(parent, display, location=location, details=details)
self.ida_highlight = idc.get_color(location, idc.CIC_ITEM)

View File

@@ -288,26 +288,15 @@ def get_workspace(path, format, should_save=True):
return vw
def get_extractor_py2(path, format, disable_progress=False):
import capa.features.extractors.viv
with halo.Halo(text="analyzing program", spinner="simpleDots", stream=sys.stderr, enabled=not disable_progress):
vw = get_workspace(path, format, should_save=False)
try:
vw.saveWorkspace()
except IOError:
# see #168 for discussion around how to handle non-writable directories
logger.info("source directory is not writable, won't save intermediate workspace")
return capa.features.extractors.viv.VivisectFeatureExtractor(vw, path)
class UnsupportedRuntimeError(RuntimeError):
pass
def get_extractor_py3(path, format, backend, disable_progress=False):
def get_extractor(path, format, backend, disable_progress=False):
"""
raises:
UnsupportedFormatError:
"""
if backend == "smda":
from smda.SmdaConfig import SmdaConfig
from smda.Disassembler import Disassembler
@@ -337,17 +326,6 @@ def get_extractor_py3(path, format, backend, disable_progress=False):
return capa.features.extractors.viv.VivisectFeatureExtractor(vw, path)
def get_extractor(path, format, backend, disable_progress=False):
"""
raises:
UnsupportedFormatError:
"""
if sys.version_info >= (3, 0):
return get_extractor_py3(path, format, backend, disable_progress=disable_progress)
else:
return get_extractor_py2(path, format, disable_progress=disable_progress)
def is_nursery_rule_path(path):
"""
The nursery is a spot for rules that have not yet been fully polished.
@@ -498,22 +476,11 @@ def install_common_args(parser, wanted=None):
#
if "sample" in wanted:
if sys.version_info >= (3, 0):
parser.add_argument(
# Python 3 str handles non-ASCII arguments correctly
"sample",
type=str,
help="path to sample to analyze",
)
else:
parser.add_argument(
# in #328 we noticed that the sample path is not handled correctly if it contains non-ASCII characters
# https://stackoverflow.com/a/22947334/ offers a solution and decoding using getfilesystemencoding works
# in our testing, however other sources suggest `sys.stdin.encoding` (https://stackoverflow.com/q/4012571/)
"sample",
type=lambda s: s.decode(sys.getfilesystemencoding()),
help="path to sample to analyze",
)
parser.add_argument(
"sample",
type=str,
help="path to sample to analyze",
)
if "format" in wanted:
formats = [
@@ -532,15 +499,15 @@ def install_common_args(parser, wanted=None):
help="select sample format, %s" % format_help,
)
if "backend" in wanted and sys.version_info >= (3, 0):
parser.add_argument(
"-b",
"--backend",
type=str,
help="select the backend to use",
choices=(BACKEND_VIV, BACKEND_SMDA),
default=BACKEND_VIV,
)
if "backend" in wanted:
parser.add_argument(
"-b",
"--backend",
type=str,
help="select the backend to use",
choices=(BACKEND_VIV, BACKEND_SMDA),
default=BACKEND_VIV,
)
if "rules" in wanted:
parser.add_argument(
@@ -703,8 +670,7 @@ def main(argv=None):
else:
format = args.format
try:
backend = args.backend if sys.version_info > (3, 0) else BACKEND_VIV
extractor = get_extractor(args.sample, args.format, backend, disable_progress=args.quiet)
extractor = get_extractor(args.sample, args.format, args.backend, disable_progress=args.quiet)
except UnsupportedFormatError:
logger.error("-" * 80)
logger.error(" Input file does not appear to be a PE file.")