diff --git a/pysoot/lifter.py b/pysoot/lifter.py index 79c6352..e9de99a 100644 --- a/pysoot/lifter.py +++ b/pysoot/lifter.py @@ -88,11 +88,16 @@ def _get_ir(self): # delayed import from .soot_manager import SootManager - self.soot_wrapper = SootManager() + soot_wrapper = SootManager() log.info("Running Soot with the following config: " + repr(config)) - self.soot_wrapper.init(**config) - self.classes = self.soot_wrapper.get_classes() + soot_wrapper.init(**config) + self.classes = soot_wrapper.get_classes() + self._hierarchy = soot_wrapper.compute_hierarchy() + + def getSubclassesOf(self, class_name: str) -> list[str]: + """Return pre-computed subclasses of the given class name.""" + return self._hierarchy.get(class_name, []) def _get_java_home() -> str: diff --git a/pysoot/soot_manager.py b/pysoot/soot_manager.py index 441bd6f..6895502 100644 --- a/pysoot/soot_manager.py +++ b/pysoot/soot_manager.py @@ -89,8 +89,15 @@ def get_classes(self): classes[soot_class.name] = soot_class return classes - def getSubclassesOf(self, class_name: str) -> list[str]: - return [ - c.getName() - for c in self.hierarchy.getSubclassesOf(self.class_name_map[class_name]) - ] + def compute_hierarchy(self) -> dict[str, list[str]]: + """Pre-compute subclass relationships for all classes.""" + result = {} + for name, raw_class in self.class_name_map.items(): + try: + result[name] = [ + c.getName() for c in self.hierarchy.getSubclassesOf(raw_class) + ] + except Exception: + # Some classes (e.g. interfaces) may not support getSubclassesOf + pass + return result diff --git a/tests/test_pysoot.py b/tests/test_pysoot.py index 799d802..e8d39e0 100755 --- a/tests/test_pysoot.py +++ b/tests/test_pysoot.py @@ -76,7 +76,7 @@ def test_hierarchy(self): # Only check application classes — JDK classes (e.g. java.lang.System) # are phantom refs on modular JDKs (Java 9+) and won't appear in the hierarchy. test_subc = ["simple2.Class2", "simple2.Class1"] - subc = lifter.soot_wrapper.getSubclassesOf("java.lang.Object") + subc = lifter.getSubclassesOf("java.lang.Object") assert all([c in subc for c in test_subc]) def test_exceptions1(self): @@ -99,7 +99,7 @@ def test_exceptions1(self): def test_android1(self): apk = os.path.join(self.test_samples_folder, "android1.apk") lifter = Lifter(apk, input_format="apk", android_sdk=self.android_sdk_path) - subc = lifter.soot_wrapper.getSubclassesOf("java.lang.Object") + subc = lifter.getSubclassesOf("java.lang.Object") assert "com.example.antoniob.android1.MainActivity" in subc main_activity = lifter.classes["com.example.antoniob.android1.MainActivity"]