diff --git a/.github/actions/setup/action.yml b/.github/actions/setup/action.yml new file mode 100644 index 0000000..2ddc6a7 --- /dev/null +++ b/.github/actions/setup/action.yml @@ -0,0 +1,52 @@ +# .github/actions/shared-action/action.yml +name: Setup +description: Shared logic for workflows +runs: + using: "composite" + steps: + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + - uses: actions/setup-java@v4 + with: + distribution: 'zulu' # See 'Supported distributions' for available options + java-version: '21' + - uses: antoniovazquezblanco/setup-ghidra@v2.0.4 + + - name: Upgrade pip + shell: bash + run: python -m pip install --upgrade pip setuptools wheel + + - name: Install apt dependencies + shell: bash + run: sudo apt-get install -y pkg-config libsentencepiece-dev libprotobuf-dev + + - name: Install dependencies + shell: bash + run: pip install -r requirements.txt + + - name: Download Ghidrathon + uses: robinraju/release-downloader@v1 + with: + # The source repository path. 
+ # Expected format {owner}/{repo} + # Default: ${{ github.repository }} + repository: mandiant/Ghidrathon + tag: v4.0.0 + fileName: '*.zip' + + - name: Install Ghidrathon + shell: bash + run: | + mkdir ghidrathon-tmp + unzip Ghidrathon*.zip -d ghidrathon-tmp + pip install -r ghidrathon-tmp/requirements.txt + python ghidrathon-tmp/ghidrathon_configure.py $GHIDRA_INSTALL_DIR + unzip ghidrathon-tmp/Ghidrathon*.zip -d $GHIDRA_INSTALL_DIR/Ghidra/Extensions + #$GHIDRA_INSTALL_DIR/support/analyzeHeadless projects TmpProject -import /bin/ls + + - name: Make projects directory + shell: bash + run: mkdir -p projects + diff --git a/.github/workflows/install.yml b/.github/workflows/install.yml deleted file mode 100644 index c548d5e..0000000 --- a/.github/workflows/install.yml +++ /dev/null @@ -1,31 +0,0 @@ -name: Install - -on: - push: - branches: [ main ] - pull_request: - branches: [ main ] - - workflow_dispatch: - -jobs: - install: - # The type of runner that the job will run on - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - os: [ubuntu-20.04] - python-version: [3.6, 3.7, 3.8] - - steps: - # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it - - uses: actions/checkout@v2 - - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - - name: Upgrade pip - run: python -m pip install --upgrade pip setuptools wheel - - name: Install dependencies - run: pip install -r requirements.txt \ No newline at end of file diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 14eb4e9..390e78f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,48 +1,82 @@ -name: Test +name: Test DIRTY Ghidra on: push: - branches: [ main ] pull_request: - branches: [ main ] - workflow_dispatch: jobs: - install: - # The type of runner that the job will run on + test-train: runs-on: ${{ matrix.os }} strategy: fail-fast: false matrix: - os: [ubuntu-20.04] - python-version: [3.6, 
3.7, 3.8] + os: [ubuntu-22.04, ubuntu-24.04] + python-version: ["3.10", "3.11"] steps: - # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it - - uses: actions/checkout@v2 - - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - - name: Upgrade pip - run: python -m pip install --upgrade pip setuptools wheel - - name: Install dependencies - run: pip install -r requirements.txt - - name: Download data and model - working-directory: ./dirty + - uses: actions/checkout@v4 + - name: Setup + uses: ./.github/actions/setup + + - name: Generate dataset + run: | + set -ex + mkdir -p $DATASET_DIR + cd $DATASET_DIR + # Create some dummy programs + for n in $(seq 20); do echo -e "#include \nint main(int argc, const char** argv) { printf(\"%d %d\\\\n\", $n, argc); }" > s$n.c; gcc -g s$n.c -o s$n; rm s$n.c; done + cd $GITHUB_WORKSPACE/dataset-gen-ghidra + python generate.py --verbose --ghidra $GHIDRA_INSTALL_DIR/support/analyzeHeadless -t 1 -b $DATASET_DIR -o $DATA_DIR/unprocessed + cd $DATA_DIR/unprocessed && python $GITHUB_WORKSPACE/dataset-gen-ghidra/gen_names.py $DATA_DIR/unprocessed + env: + DATASET_DIR: ${{ runner.temp }}/dataset + DATA_DIR: ${{ runner.temp }}/data + + - name: Preprocess dataset run: | - wget -q cmu-itl.s3.amazonaws.com/dirty/dirt.tar.gz -O dirt.tar.gz - tar -xzf dirt.tar.gz - mkdir exp_runs/ - wget -q cmu-itl.s3.amazonaws.com/dirty/dirty_mt.ckpt -O exp_runs/dirty_mt.ckpt + set -ex + cd $GITHUB_WORKSPACE/dirty + python -m utils.preprocess $DATA_DIR/unprocessed $DATA_DIR/unprocessed/files.txt $DATA_DIR/processed + ln -s $DATA_DIR/processed $(pwd)/data1 + python -m utils.vocab --size=164 --use-bpe "$DATA_DIR/processed/"'train-*.tar' "$DATA_DIR/processed/typelib.json" data1/vocab.bpe10000 + env: + DATA_DIR: ${{ runner.temp }}/data + - name: Train on dataset + run: | + set -ex + cd $GITHUB_WORKSPACE/dirty wandb offline - - name: Infer and evaluate - working-directory: ./dirty + python exp.py train 
multitask_test_ci.xfmr.jsonnet + find . -name '*.ckpt' + test-inference: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-22.04, ubuntu-24.04] + python-version: ["3.10", "3.11"] + + steps: + - uses: actions/checkout@v4 + - name: Setup + uses: ./.github/actions/setup + + - name: Install huggingface-cli + run: pip install huggingface_hub[cli] + + - name: Cache model files + uses: actions/cache@v4 + with: + path: ${{ runner.temp }}/model-dl + key: hf-model-dl + + - name: Download model files + run: huggingface-cli download --repo-type model ejschwartz/dirty-ghidra --local-dir $MODEL_DL_DIR && cp -R $MODEL_DL_DIR/data1 $GITHUB_WORKSPACE/dirty/data1 + env: + MODEL_DL_DIR: ${{ runner.temp }}/model-dl + + - name: Run DIRTY inference run: | - python exp.py train --expname=eval_dirty_mt multitask_test_ci.xfmr.jsonnet --eval-ckpt exp_runs/dirty_mt.ckpt - cat test_result.json - cat test_result.json | jq ".test_retype_acc" - cat test_result.json | jq ".test_rename_acc" - cat test_result.json | jq ".test_retype_acc" | awk '{if ($1 < 0.6) exit 1}' - cat test_result.json | jq ".test_rename_acc" | awk '{if ($1 < 0.5) exit 1}' + $GHIDRA_INSTALL_DIR/support/analyzeHeadless projects MyProject -import /bin/ls -postScript $GITHUB_WORKSPACE/scripts/DIRTY_infer.py $(pwd)/infer_success.txt + test -f infer_success.txt \ No newline at end of file diff --git a/.gitignore b/.gitignore index b85ed1a..ae95496 100644 --- a/.gitignore +++ b/.gitignore @@ -141,3 +141,11 @@ cython_debug/ *.c exp_runs/ + +# wandb +dirty/wandb/ +dirty/multitask-greedy.xfmr.jsonnet +dirty/data1 +dirty/struct_files.txt +dirty/eval.pkl +dirty/forward.pkl diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..6008d51 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,30 @@ +# docker build -t dirty-ghidra . 
+# docker run -d --name dirty-ghidra --gpus '"device=3,4"' -it -v /path/to/data:/data dirty-ghidra + +FROM blacktop/ghidra:latest + +RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \ + --mount=type=cache,target=/var/lib/apt,sharing=locked \ + apt-get -y update && apt-get -y install -y python3-pip python-is-python3 \ + git pkg-config libsentencepiece-dev libprotobuf-dev nano sudo unzip + +# Install Ghidrathon + +WORKDIR /tmp/ + +RUN wget https://github.com/mandiant/Ghidrathon/releases/download/v4.0.0/Ghidrathon-v4.0.0.zip +RUN unzip Ghidrathon-v4.0.0.zip -d ghidrathon +RUN --mount=type=cache,target=/root/.cache pip install --break-system-packages -r ghidrathon/requirements.txt +RUN python ghidrathon/ghidrathon_configure.py /ghidra +RUN unzip ghidrathon/Ghidrathon-v4.0.0.zip -d /ghidra/Ghidra/Extensions + +# Install DIRTY Ghidra + +WORKDIR / + +COPY . /DIRTY + +RUN --mount=type=cache,target=/root/.cache pip install --break-system-packages --upgrade -r /DIRTY/requirements.txt + +ENTRYPOINT ["/bin/sh", "-c"] +CMD ["tail", "-f", "/dev/null"] diff --git a/README.md b/README.md index 6009fe0..11f0900 100644 --- a/README.md +++ b/README.md @@ -8,17 +8,43 @@ While most of the model code remains identical, we add support for generating a The original README provides clear instructions on how to download and run their pre-trained DIRTY model, but the README's instructions are slightly unclear when describing how to train your own model. This README explicitly covers all the steps necessary to train a DIRTY model from scratch. +This is @edmcman's fork of the original DIRTY-Ghidra repository. It features a number of improvements and bug fixes, and also includes the ability to perform inference on new examples. 
+ +## Getting Started with DIRTY-Ghidra Inference + +[![Test DIRTY Ghidra's inference ability](https://github.com/edmcman/DIRTY-Ghidra/actions/workflows/test.yml/badge.svg)](https://github.com/edmcman/DIRTY-Ghidra/actions/workflows/test.yml) + +Most people probably just want to use DIRTY-Ghidra to predict variable names and +types for their own binaries. If that is you, follow these instructions: + +1. Clone this repository to `DIRTY_DIR` +2. Optional but highly recommended: Create a virtual environment (venv) with `python -m venv /path/to/venv; source /path/to/venv/bin/activate`. This will prevent DIRTY from interfering with your system python packages. +3. Install the requirements via `pip install -r requirements.txt` +4. [Install Ghidra](https://ghidra-sre.org/InstallationGuide.html) +5. [Install Ghidrathon](https://github.com/mandiant/Ghidrathon/?tab=readme-ov-file#installing-ghidrathon). Make sure you configure Ghidrathon (`python + ghidrathon_configure.py`) using the venv from step 2. +6. Download the latest model from HF (`huggingface_hub[cli] && huggingface-cli download --repo-type model ejschwartz/dirty-ghidra --local-dir $DIRTY_DIR/dirty`) +7. Run `mkdir ~/ghidra_scripts && ln -s DIRTY_DIR/scripts/DIRTY_infer.py ~/ghidra_scripts/DIRTY_infer.py` if on Linux. +8. Open a function in Ghidra. Run the script `DIRTY_infer.py` in the script manager. +9. Optionally assign the script to a keyboard shortcut. + ## Requirements -- Linux with Python 3.6/3.7/3.8 +- Linux with Python 3.10+ - [PyTorch ≥ 1.5.1](https://pytorch.org/) -- [Ghidrathon 1.0.0](https://github.com/mandiant/Ghidrathon) +- [Ghidrathon >= 4.0.0](https://github.com/mandiant/Ghidrathon) - `pip install -r requirements.txt` +### Libraries + +A few libraries are required by the python packages. 
On ubuntu, you can install +these with: +- `apt install pkg-config libsentencepiece-dev libprotobuf-dev` + ## Training a DIRTY model ### Dataset Generation -The first step to train DIRTY is to obtain a unprocessed DIRT dataset. Instructions can be found in the `dataset-gen-ghidra` folder. +The first step to train DIRTY is to obtain a unprocessed DIRT dataset. Instructions can be found in the [dataset-gen-ghidra](dataset-gen-ghidra) folder. ### Dataset Preprocessing @@ -40,7 +66,7 @@ We also need to build a vocabulary of tokens that the model will understand ```bash # inside the `dirty` directory -python3 -m utils.vocab [-h] [options] TRAIN_FILES_TAR PATH_TO_TYPELIB_JSON TARGET_DIRECTORY/vocab.bpe10000 +python3 -m utils.vocab [-h] --use-bpe [options] TRAIN_FILES_TAR PATH_TO_TYPELIB_JSON TARGET_DIRECTORY/vocab.bpe10000 ``` This script generates vocabulary files located in `TARGET_DIRECTORY`. It is recommended to prefix the vocab files with `vocab.bpe10000` to match the expected vocabulary filenames in the model config files. 
diff --git a/binary/dire_types.py b/binary/dire_types.py index 6c29a63..7896b33 100644 --- a/binary/dire_types.py +++ b/binary/dire_types.py @@ -973,6 +973,29 @@ def __hash__(self) -> int: def __str__(self) -> str: return "void" +class TypeDef(TypeInfo): + + def __init__(self, name, size, other_type_name) -> None: + self.name = name + self.size = size + self.other_type_name = other_type_name + + @classmethod + def _from_json(cls, d: t.Dict[str, t.Any]) -> "TypeDef": + return cls(name=d["name"], size=d["size"], other_type_name=d["other_type_name"]) + + def _to_json(self) -> t.Dict[str, int]: + return {"T": 11, "name": self.name, "size": self.size, "other_type_name": self.other_type_name} + + def __eq__(self, other: t.Any) -> bool: + return isinstance(other, TypeDef) and self.name == other.name and self.size == other.size and self.other_type_name == other.other_type_name + + def __hash__(self) -> int: + return hash((self.name, self.size, self.other_type_name)) + + def __str__(self) -> str: + return self.name + class Disappear(TypeInfo): """Target type for variables that don't appear in the ground truth function""" size = 0 @@ -1061,6 +1084,7 @@ def read_metadata(d: t.Dict[str, t.Any]) -> "TypeLibCodec.CodecTypes": 8: Void, 9: FunctionPointer, 10: Disappear, + 11: TypeDef } return classes[d["T"]]._from_json(d) diff --git a/binary/ghidra_function.py b/binary/ghidra_function.py index a81524c..8a051ab 100644 --- a/binary/ghidra_function.py +++ b/binary/ghidra_function.py @@ -162,7 +162,7 @@ class CollectedFunction: """ def __init__(self, *, ea: int, debug: Function, decompiler: Function): - self.name: str = debug.name + self.name: str = debug.name if hasattr(debug, "name") else "unknown" self.ea = ea self.debug = debug self.decompiler = decompiler @@ -170,13 +170,13 @@ def __init__(self, *, ea: int, debug: Function, decompiler: Function): def to_json(self): return { "e": self.ea, - "b": self.debug.to_json(), + "b": self.debug.to_json() if hasattr(self.debug, "to_json") 
else None, "c": self.decompiler.to_json(), } @classmethod def from_json(cls, d): - debug = Function.from_json(d["b"]) + debug = Function.from_json(d["b"]) if d["b"] is not None else None decompiler = Function.from_json(d["c"]) return cls(ea=d["e"], debug=debug, decompiler=decompiler) diff --git a/binary/ghidra_types.py b/binary/ghidra_types.py index 2534251..b901775 100644 --- a/binary/ghidra_types.py +++ b/binary/ghidra_types.py @@ -17,14 +17,10 @@ try: import ghidra.program.model.listing as listing import ghidra.program.model.data as data - from ghidra.program.model.data import PointerDataType, ArrayDataType, StructureDataType, UnionDataType - from ghidra.program.database.data import PointerDB, ArrayDB, StructureDB, UnionDB + from ghidra.program.model.data import PointerDataType, ArrayDataType, StructureDataType, UnionDataType, TypedefDataType, FunctionDefinitionDataType + from ghidra.program.database.data import PointerDB, ArrayDB, StructureDB, UnionDB, TypedefDB, FunctionDefinitionDB except ImportError: pass -# try: -# import ida_typeinf # type: ignore -# except ImportError: -# print("Could not import ida_typeinf. 
Cannot parse IDA types.") class TypeLib: @@ -164,7 +160,7 @@ def __init__( self._data = data @staticmethod - def parse_ghidra_type(typ: "ghidra.program.model.data") -> "TypeInfo": + def parse_ghidra_type(typ: "ghidra.program.model.data.Datatype") -> "TypeInfo": """Parses an IDA tinfo_t object""" # if typ.is_void() # return Void() @@ -172,23 +168,28 @@ def parse_ghidra_type(typ: "ghidra.program.model.data") -> "TypeInfo": # return FunctionPointer(name=typ.dstr()) if typ is None: return - - if isinstance(typ, (PointerDataType.__pytype__, PointerDB.__pytype__)): - return Pointer(typ.getName()) - if isinstance(typ, (ArrayDataType.__pytype__, ArrayDB.__pytype__)): + + elif isinstance(typ, (FunctionDefinitionDataType.__pytype__, FunctionDefinitionDB.__pytype__)): + # Rather than actually encoding function definitions, we'll just + # typedef them to void, so function pointers will be of type void*. + return TypeDef(name=typ.getName(), size=typ.getLength(), other_type_name="void") + # TIL that PointerDataType do not have to have a subtype + elif isinstance(typ, (PointerDataType.__pytype__, PointerDB.__pytype__)) and typ.getDataType() is not None: + return Pointer(typ.getDataType().getName()) + elif isinstance(typ, (ArrayDataType.__pytype__, ArrayDB.__pytype__)): # To get array type info, first create an # array_type_data_t then call get_array_details to # populate it. Unions and structs follow a similar # pattern. 
nelements = typ.getNumElements() element_size = typ.getElementLength() - element_type = typ.getName() + element_type = typ.getDataType().getName() return Array( nelements=nelements, element_size=element_size, element_type=element_type, ) - if isinstance(typ, (StructureDataType.__pytype__, StructureDB.__pytype__)): + elif isinstance(typ, (StructureDataType.__pytype__, StructureDB.__pytype__)): name = typ.getName() size = typ.getLength() components = typ.getDefinedComponents() @@ -210,7 +211,7 @@ def parse_ghidra_type(typ: "ghidra.program.model.data") -> "TypeInfo": if end_padding > 0: layout.append(UDT.Padding(end_padding)) return Struct(name=name, layout=layout) - if isinstance(typ, (UnionDataType.__pytype__, UnionDB.__pytype__)): + elif isinstance(typ, (UnionDataType.__pytype__, UnionDB.__pytype__)): name = typ.getName() size = typ.getLength() components = typ.getDefinedComponents() @@ -227,7 +228,10 @@ def parse_ghidra_type(typ: "ghidra.program.model.data") -> "TypeInfo": return Union(name=name, members=members) return Union(name=name, members=members, padding=UDT.Padding(end_padding)) - return TypeInfo(name=typ.getName(), size=typ.getLength()) + elif isinstance(typ, (TypedefDataType.__pytype__, TypedefDB.__pytype__)): + return TypeDef(name=typ.getName(), size=typ.getLength(), other_type_name=typ.getDataType().getName()) + + return TypeInfo(name=typ.getName(), size=typ.getLength(), debug=typ.getDescription() + "\n" + typ.__class__.__name__) def add_ghidra_type( self, typ: "ghidra.program.model.data", worklist: t.Optional[t.Set[str]] = None @@ -241,12 +245,22 @@ def add_ghidra_type( worklist.add(typ.getName()) new_type: TypeInfo = self.parse_ghidra_type(typ) # If this type isn't a duplicate, break down the subtypes, if any exists - if not self._data[new_type.size].add(new_type) and isinstance(typ, (StructureDataType.__pytype__, UnionDataType.__pytype__, StructureDB.__pytype__, UnionDB.__pytype__)): - components = typ.getComponents() - for component in components: 
- self.add_ghidra_type(component.getDataType(), worklist) - # for i in range(num_components): - # self.add_ghidra_type(type.getComponent(i), worklist) + if self._data[new_type.size].add(new_type): + # Type already exists + pass + else: + subtypes = None + if isinstance(typ, (StructureDataType.__pytype__, StructureDB.__pytype__, UnionDataType.__pytype__, UnionDB.__pytype__)): + subtypes = [t.getDataType() for t in typ.getComponents()] + elif isinstance(typ, (PointerDataType.__pytype__, PointerDB.__pytype__, ArrayDataType.__pytype__, ArrayDB.__pytype__)): + if typ.getDataType() is not None: + subtypes = [typ.getDataType()] + elif isinstance(typ, (TypedefDataType.__pytype__, TypedefDB.__pytype__)): + subtypes = [typ.getDataType()] + + if subtypes is not None: + for subtype in subtypes: + self.add_ghidra_type(subtype, worklist) def add_entry_list(self, size: int, entries: "TypeLib.EntryList") -> None: """Add an entry list of items of size 'size'""" @@ -260,14 +274,18 @@ def add(self, typ): entry.add(typ) self.add_entry_list(typ.size, entry) - def add_json_file(self, json_file: str, *, threads: int = 1) -> None: - """Adds the info in a serialized (gzipped) JSON file to this TypeLib""" + def add_json_file(self, json_file: str, *, threads: int = 1, ungzip = False) -> None: + """Adds the info in a serialized JSON file to this TypeLib""" if not os.path.exists(json_file): return other: t.Optional[t.Any] = None - with open(json_file, "r") as other_file: - other = TypeLibCodec.decode(other_file.read()) + if ungzip: + with gzip.open(json_file, "rt") as other_file: + other = TypeLibCodec.decode(other_file.read()) + else: + with open(json_file, "r") as other_file: + other = TypeLibCodec.decode(other_file.read()) if other is not None and isinstance(other, TypeLib): for size, entries in other.items(): self.add_entry_list(size, entries) @@ -512,9 +530,10 @@ def prune(self, freq) -> None: class TypeInfo: """Stores information about a type""" - def __init__(self, *, name: 
t.Optional[str], size: int): + def __init__(self, *, name: t.Optional[str], size: int, debug: t.Optional[str] = None): self.name = name self.size = size + self.debug = debug def accessible_offsets(self) -> t.Tuple[int, ...]: """Offsets accessible in this type""" @@ -553,18 +572,18 @@ def displace(offsets: t.Tuple[int, ...]) -> t.Tuple[int, ...]: @classmethod def _from_json(cls, d: t.Dict[str, t.Any]) -> "TypeInfo": """Decodes from a dictionary""" - return cls(name=d["n"], size=d["s"]) + return cls(name=d["n"], size=d["s"], debug=d["debug"]) def _to_json(self) -> t.Dict[str, t.Any]: - return {"T": 1, "n": self.name, "s": self.size} + return {"T": 1, "n": self.name, "s": self.size, "debug": self.debug} def __eq__(self, other: t.Any) -> bool: if isinstance(other, TypeInfo): - return self.name == other.name and self.size == other.size + return self.name == other.name and self.size == other.size and self.debug == other.debug return False def __hash__(self) -> int: - return hash((self.name, self.size)) + return hash((self.name, self.size, self.debug)) def __str__(self) -> str: return f"{self.name}" @@ -846,10 +865,11 @@ def __hash__(self) -> int: return hash((self.name, self.layout)) def __str__(self) -> str: - if self.name is None: - ret = f"struct {{ " + + if self.name is not None: + return self.name else: - ret = f"struct {self.name} {{ " + ret = f"struct {{ " for l in self.layout: ret += f"{str(l)}; " ret += "}" @@ -927,10 +947,10 @@ def __hash__(self) -> int: return hash((self.name, self.members, self.padding)) def __str__(self) -> str: - if self.name is None: - ret = f"union {{ " + if self.name is not None: + return self.name else: - ret = f"union {self.name} {{ " + ret = f"union {{ " for m in self.members: ret += f"{str(m)}; " if self.padding is not None: @@ -963,6 +983,29 @@ def __hash__(self) -> int: def __str__(self) -> str: return "void" + +class TypeDef(TypeInfo): + + def __init__(self, name, size, other_type_name) -> None: + self.name = name + self.size = 
size + self.other_type_name = other_type_name + + @classmethod + def _from_json(cls, d: t.Dict[str, t.Any]) -> "TypeDef": + return cls(name=d["name"], size=d["size"], other_type_name=d["other_type_name"]) + + def _to_json(self) -> t.Dict[str, int]: + return {"T": 11, "name": self.name, "size": self.size, "other_type_name": self.other_type_name} + + def __eq__(self, other: t.Any) -> bool: + return isinstance(other, TypeDef) and self.name == other.name and self.size == other.size and self.other_type_name == other.other_type_name + + def __hash__(self) -> int: + return hash((self.name, self.size, self.other_type_name)) + + def __str__(self) -> str: + return self.name class Disappear(TypeInfo): """Target type for variables that don't appear in the ground truth function""" @@ -1031,15 +1074,7 @@ def decode(encoded: str) -> CodecTypes: @staticmethod def read_metadata(d: t.Dict[str, t.Any]) -> "TypeLibCodec.CodecTypes": - classes: t.Dict[ - t.Union[int, str], - t.Union[ - t.Type["TypeLib"], - t.Type["TypeLib.EntryList"], - t.Type["TypeInfo"], - t.Type["UDT.Member"], - ], - ] = { + classes = { "E": TypeLib.EntryList, 0: TypeLib, 1: TypeInfo, @@ -1052,6 +1087,7 @@ def read_metadata(d: t.Dict[str, t.Any]) -> "TypeLibCodec.CodecTypes": 8: Void, 9: FunctionPointer, 10: Disappear, + 11: TypeDef, } return classes[d["T"]]._from_json(d) diff --git a/binary/ghidra_variable.py b/binary/ghidra_variable.py index cf37fd5..06afe52 100644 --- a/binary/ghidra_variable.py +++ b/binary/ghidra_variable.py @@ -1,17 +1,13 @@ """Information about variables in a function""" from json import dumps -from typing import Any, Optional +from typing import Any # Huge hack to get importing to work with the decompiler try: from ghidra_types import TypeLibCodec, TypeInfo except ImportError: from .ghidra_types import TypeLibCodec, TypeInfo -# try: -# from dire_types import TypeLibCodec, TypeInfo -# except ImportError: -# from .dire_types import TypeLibCodec, TypeInfo class Location: """A variable 
location""" @@ -19,7 +15,6 @@ def json_key(self): """Returns a string suitable as a key in a JSON dict""" pass - class Register(Location): """A register @@ -30,13 +25,13 @@ def __init__(self, name: str): self.name = name def json_key(self): - return self.name + return f"r{self.name}" def __eq__(self, other: Any) -> bool: return isinstance(other, Register) and self.name == other.name def __hash__(self) -> int: - return hash(self.name) + return hash(self.json_key()) def __repr__(self) -> str: return f"Reg {self.name}" @@ -58,18 +53,42 @@ def __eq__(self, other: Any) -> bool: return isinstance(other, Stack) and self.offset == other.offset def __hash__(self) -> int: - return hash(self.offset) + return hash(self.json_key()) def __repr__(self) -> str: return f"Stk 0x{self.offset:x}" +class Unknown(Location): + """An unknown storage location + + str: optional string description of the location + """ + + def __init__(self, str: str = "unknown"): + self.str = str + + def json_key(self): + return f"u{self.str}" + + def __eq__(self, other: Any) -> bool: + return isinstance(other, Unknown) and self.str == other.str + + def __hash__(self) -> int: + return hash(self.json_key()) + + def __repr__(self) -> str: + return f"Unknown {self.str}" def location_from_json_key(key: str) -> "Location": """Hacky way to return a location from a JSON key""" if key.startswith("s"): return Stack(int(key[1:])) - else: + elif key.startswith("r"): return Register(key[1:]) + elif key.startswith("u"): + return Unknown(key[1:]) + else: + assert False class Variable: """A variable diff --git a/dataset-gen-ghidra/README.md b/dataset-gen-ghidra/README.md index 947587f..0c4aa3d 100644 --- a/dataset-gen-ghidra/README.md +++ b/dataset-gen-ghidra/README.md @@ -20,7 +20,7 @@ When writing our paper, the original DIRTY team were kind enough to provide the Use === To generate the dataset, run the [generate.py](generate.py) script: -`python3 generate.py --ghidra PATH_TO_GHIDRA -t NUM_THREADS -n [NUM_FILES|None] -b 
BINARIES_DIR -o OUTPUT_DIR` +`python3 generate.py --ghidra PATH_TO_GHIDRA_ANALYZEHEADLESS -t NUM_THREADS -n [NUM_FILES|None] -b BINARIES_DIR -o OUTPUT_DIR` This script creates a `bins/` and `types/` directory in `OUTPUT_DIR` and generates a `.jsonl` file in both directories for each binary in `BINARIES_DIR`. The file is in the [JSON Lines](http://jsonlines.org) format, and each entry corresponds to a diff --git a/dataset-gen-ghidra/decompiler/collect.py b/dataset-gen-ghidra/decompiler/collect.py index 7205f78..df8c031 100644 --- a/dataset-gen-ghidra/decompiler/collect.py +++ b/dataset-gen-ghidra/decompiler/collect.py @@ -1,14 +1,13 @@ import gzip -import pickle import os from collections import defaultdict -from typing import DefaultDict, Dict, Iterable, Optional, Set +from typing import DefaultDict, Iterable, Optional, Set -from ghidra_function import Function from ghidra_types import TypeInfo, TypeLib, TypeLibCodec -from ghidra_variable import Location, Stack, Register, Variable +from ghidra_variable import Location, Stack, Register, Unknown, Variable +import ghidra.program.model.symbol class Collector: """Generic class to collect information from a binary""" @@ -24,7 +23,7 @@ def __init__(self): with gzip.open(self.type_lib_file_name, "rt") as type_lib_file: self.type_lib = TypeLibCodec.decode(type_lib_file.read()) except Exception as e: - print(e) + # print(e) print("Could not find type library, creating a new one") self.type_lib = TypeLib() @@ -54,14 +53,23 @@ def collect_variables( loc: Optional[Location] = None storage = v.getStorage() + has_user_info = False + low_symbol = v.getSymbol() + if low_symbol is not None and low_symbol.getSource() in [ghidra.program.model.symbol.SourceType.IMPORTED, ghidra.program.model.symbol.SourceType.USER_DEFINED]: + has_user_info = True + if storage.isStackStorage(): loc = Stack(storage.getStackOffset()) - if storage.isRegisterStorage(): + elif storage.isRegisterStorage(): loc = Register(storage.getRegister().getName()) - if 
loc is not None: - collected_vars[loc].add( - Variable(typ=typ, name=v.getName(), user=False) - ) + else: + loc = Unknown(storage.toString()) + + assert loc is not None + + collected_vars[loc].add( + Variable(typ=typ, name=v.getName(), user=has_user_info) + ) return collected_vars def activate(self, ctx) -> int: diff --git a/dataset-gen-ghidra/decompiler/debug.py b/dataset-gen-ghidra/decompiler/debug.py index cb7a85c..0b2a53b 100644 --- a/dataset-gen-ghidra/decompiler/debug.py +++ b/dataset-gen-ghidra/decompiler/debug.py @@ -4,7 +4,9 @@ from ghidra.app.decompiler import DecompInterface from ghidra.program.model.data import Undefined -from ghidra.app.util.bin.format.dwarf4.next import DWARFImportOptions, DWARFProgram +#from ghidra.app.util.bin.format.dwarf4.next import DWARFImportOptions, DWARFProgram +# changed in ghidra 11.1 +from ghidra.app.util.bin.format.dwarf import DWARFImportOptions, DWARFProgram from ghidra.util.task import ConsoleTaskMonitor from collect import Collector @@ -16,6 +18,7 @@ class CollectDebug(Collector): """Class for collecting debug information""" def __init__(self): + print("Initializing CollectDebug") self.functions: Dict[int, Function] = dict() super().__init__() @@ -35,17 +38,21 @@ def activate(self, ctx) -> int: dwarf_options = DWARFImportOptions() dwarf_options.setOutputDIEInfo(True) monitor = ConsoleTaskMonitor() - dwarf_program = DWARFProgram(currentProgram, dwarf_options, monitor) + dwarf_program = DWARFProgram(currentProgram(), dwarf_options, monitor) decomp = DecompInterface() decomp.toggleSyntaxTree(False) decomp.openProgram(dwarf_program.getGhidraProgram()) # Ghidra separates Variables from their Data info, populate typelib first - # for data in currentProgram.getListing().getDefinedData(True): + # for data in currentProgram().getListing().getDefinedData(True): # self.type_lib.add_ghidra_type(data) - for f in currentProgram.getListing().getFunctions(True): + for f in currentProgram().getListing().getFunctions(True): + + if 
f.isThunk(): + continue + # Decompile decomp_results = decomp.decompileFunction(f, 30, None) @@ -66,8 +73,8 @@ def activate(self, ctx) -> int: self.type_lib.add_ghidra_type(func_return) return_type = TypeLib.parse_ghidra_type(func_return) - for symbol in symbols: - print(symbol.getDataType().getDescription()) + #for symbol in symbols: + # print(symbol.getDataType().getDescription()) arguments = self.collect_variables( f.getStackFrame().getFrameSize(), [v for v in symbols if v.isParameter()] # and v.getName() in all_var_names], diff --git a/dataset-gen-ghidra/decompiler/dump_trees.py b/dataset-gen-ghidra/decompiler/dump_trees.py index 5df595e..f12245d 100644 --- a/dataset-gen-ghidra/decompiler/dump_trees.py +++ b/dataset-gen-ghidra/decompiler/dump_trees.py @@ -7,7 +7,7 @@ from ghidra.app.decompiler import DecompInterface -from ghidra_ast import AST +#from ghidra_ast import AST from collect import Collector from ghidra_function import CollectedFunction, Function from ghidra_types import TypeLib @@ -22,7 +22,11 @@ def __init__(self): print("Loading functions") # Load the functions collected by CollectDebug with open(os.environ["FUNCTIONS"], "rb") as functions_fh: - self.debug_functions: Dict[int, Function] = pickle.load(functions_fh) + try: + self.debug_functions: Dict[int, Function] = pickle.load(functions_fh) + except: + print("Unable to load debug_functions") + self.debug_functions = dict() print("Done") self.functions: List[CollectedFunction] = list() self.output_file_name = os.path.join( @@ -45,9 +49,13 @@ def activate(self, ctx) -> int: decomp = DecompInterface() decomp.toggleSyntaxTree(False) - decomp.openProgram(currentProgram) + decomp.openProgram(currentProgram()) + + for f in currentProgram().getListing().getFunctions(True): + + if f.isThunk(): + continue - for f in currentProgram.getListing().getFunctions(True): # Decompile decomp_results = decomp.decompileFunction(f, 30, None) f = decomp_results.getFunction() @@ -89,7 +97,7 @@ def activate(self, ctx) -> 
int: self.functions.append( CollectedFunction( ea=f.getEntryPoint().toString(), - debug=self.debug_functions[f.getEntryPoint().toString()], + debug=self.debug_functions.get(f.getEntryPoint().toString(), None), decompiler=decompiler, ) ) diff --git a/dataset-gen-ghidra/gen_names.py b/dataset-gen-ghidra/gen_names.py index e3ef49e..37126f6 100644 --- a/dataset-gen-ghidra/gen_names.py +++ b/dataset-gen-ghidra/gen_names.py @@ -9,17 +9,19 @@ def main(argv): bins_set = set() for file in os.listdir(f"{argv[0]}/bins"): - bins_set.add(file[:file.index("_")]) + file = file.replace(".jsonl.gz", "") + bins_set.add(file) for file in os.listdir(f"{argv[0]}/types"): - types_set.add(file[:file.index("_")]) - + file = file.replace(".json.gz", "") + types_set.add(file) + bin_files = list(set.intersection(types_set, bins_set)) print(len(bin_files)) with open("files.txt", "w") as dataset_file: for file in bin_files: - dataset_file.write(f"{file}_{file}.jsonl.gz\n") + dataset_file.write(f"{file}.jsonl.gz\n") pass diff --git a/dataset-gen-ghidra/generate.py b/dataset-gen-ghidra/generate.py index 96958a2..af517a5 100644 --- a/dataset-gen-ghidra/generate.py +++ b/dataset-gen-ghidra/generate.py @@ -1,4 +1,4 @@ -# Specialzed version for ghidra compatibility +# Specialized version for ghidra compatibility import argparse import subprocess @@ -15,7 +15,6 @@ from typing import Iterable, Tuple from elftools.elf.elffile import ELFFile -from elftools.common.exceptions import ELFRelocationError class Runner(object): file_dir = os.path.dirname(os.path.abspath(__file__)) @@ -29,6 +28,7 @@ def __init__(self, args: argparse.Namespace): self._num_files = args.num_files self.verbose = args.verbose self.num_threads = args.num_threads + self.timeout = args.timeout self.env = os.environ.copy() @@ -107,22 +107,24 @@ def run_decompiler(self, env, path_to_dir, file_name, script, timeout=None): ghidracall = [self.ghidra, path_to_dir, temp_dir, '-import', file_name, '-postScript', script_name, file_name + ".p", 
"-scriptPath", script_dir, - '-max-cpu', "3", "-analysisTimeoutPerFile", str(timeout - 30), '-deleteProject'] - # idacall = [self.ida, "-B", f"-S{script}", file_name] - output = "" + '-max-cpu', "1", '-deleteProject'] + if self.verbose: + print(f"Running {ghidracall}") try: - p = subprocess.Popen(ghidracall, env=env, start_new_session=True) - p.wait(timeout=timeout) + p = subprocess.Popen(ghidracall, env=env, start_new_session=True, + stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = p.communicate(timeout=timeout) + if self.verbose: + print(stdout.decode()) + print(stderr.decode()) except subprocess.TimeoutExpired as e: + print(f"Timed out while running {ghidracall}") os.killpg(os.getpgid(p.pid), signal.SIGTERM) subprocess.run(f"rm -r {path_to_dir}/__*", shell=True) except subprocess.CalledProcessError as e: subprocess.run(f"rm -r {path_to_dir}/__*", shell=True) - pass - # output = e.output - # subprocess.call(["rm", "-f", f"{file_name}.i64"]) - # if self.verbose: - # print(output.decode("unicode_escape")) + if self.verbose: + print(e.output.decode()) def extract_dwarf_var_names(self, filepath:str) -> set: """ @@ -134,13 +136,16 @@ def extract_dwarf_var_names(self, filepath:str) -> set: with open(filepath, 'rb') as f: elffile = ELFFile(f) if not elffile.has_dwarf_info(): - return set() + if self.verbose: + print(f"No dwarf info in {filepath}") + return None # for some reason, this is throwing an exception, give it if it does so try: dwarfinfo = elffile.get_dwarf_info() except: - return set() + print(f"Error extracting dwarf info from {filepath}") + return None for CU in dwarfinfo.iter_CUs(): for DIE in CU.iter_DIEs(): @@ -150,10 +155,9 @@ def extract_dwarf_var_names(self, filepath:str) -> set: for attr in DIE.attributes.values(): if attr.name == "DW_AT_name": variable_names.add(attr.value.decode()) - print(variable_names) - #print(len(variable_names)) + if self.verbose: + print(f"Extracted variable names: {variable_names}") return variable_names - 
pass def run_one(self, args: Tuple[str, str]) -> None: path, binary = args @@ -174,7 +178,7 @@ def run_one(self, args: Tuple[str, str]) -> None: # Try stripping first, if it fails return subprocess.call(["cp", file_path, stripped.name]) try: - subprocess.call(["strip", "--strip-unneeded", stripped.name]) + subprocess.call(["strip", stripped.name]) except subprocess.CalledProcessError: if self.verbose: print(f"Could not strip {prefix}, skipping.") @@ -192,20 +196,26 @@ def run_one(self, args: Tuple[str, str]) -> None: print(f"{prefix} types already collected, skipping") else: # Collect from original + if self.verbose: + print(f"Collecting debug information") subprocess.check_output(["cp", file_path, orig.name]) - # Timeout after 30s for the collect run + var_set = self.extract_dwarf_var_names(os.path.join(path, orig.name)) - if not var_set: - return - pickle_file = os.path.join(path, orig.name) + ".p" - pickle.dump(var_set, open(pickle_file, 'wb')) - self.run_decompiler(new_env, path, os.path.join(path, orig.name), self.COLLECT, timeout=180) - os.remove(pickle_file) + if var_set: + pickle_file = os.path.join(path, orig.name) + ".p" + pickle.dump(var_set, open(pickle_file, 'wb')) + self.run_decompiler(new_env, path, os.path.join(path, orig.name), self.COLLECT, timeout=self.timeout) + os.remove(pickle_file) + else: + if self.verbose: + print(f"Unable to collect debug information for {prefix}") # Dump trees + if self.verbose: + print(f"Dumping trees") pickle_file = os.path.join(path, stripped.name) + ".p" pickle.dump(set(), open(pickle_file, 'wb')) self.run_decompiler( - new_env, path, os.path.join(path, stripped.name), self.DUMP_TREES, timeout=200 + new_env, path, os.path.join(path, stripped.name), self.DUMP_TREES, timeout=self.timeout ) os.remove(pickle_file) @@ -257,6 +267,13 @@ def main(): help='directory containing binaries', required=True ) + parser.add_argument( + "--timeout", + metavar="TIMEOUT", + help="timeout for each binary", + default=30*60, + type=int + ) 
parser.add_argument( "-o", "--output_dir", metavar="OUTPUT_DIR", help="output directory", required=True, ) diff --git a/dirty/exp.py b/dirty/exp.py index 4fe97b9..d870891 100644 --- a/dirty/exp.py +++ b/dirty/exp.py @@ -1,5 +1,5 @@ """ -Variable renaming +Experiment strip Usage: exp.py train [options] CONFIG_FILE @@ -7,7 +7,6 @@ Options: -h --help Show this screen - --cuda Use GPU --debug Debug mode --seed= Seed [default: 0] --expname= work dir [default: type] @@ -28,8 +27,10 @@ import torch import wandb from docopt import docopt +from pytorch_lightning import LightningDataModule from pytorch_lightning.loggers import WandbLogger -from pytorch_lightning.callbacks.early_stopping import EarlyStopping +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, BatchSizeFinder, LearningRateMonitor +from pytorch_lightning.tuner import Tuner from torch.utils.data import DataLoader from model.model import TypeReconstructionModel @@ -46,72 +47,117 @@ def train(args): config = util.update(config, extra_config) # dataloaders - batch_size = config["train"]["batch_size"] + batch_size = config["test"]["batch_size"] if args["--eval-ckpt"] else config["train"]["batch_size"] train_set = Dataset( config["data"]["train_file"], config["data"], percent=float(args["--percent"]) ) + test_set = Dataset(config["data"]["test_file"], config["data"]) dev_set = Dataset(config["data"]["dev_file"], config["data"]) - train_loader = DataLoader( - train_set, - batch_size=batch_size, - collate_fn=Dataset.collate_fn, - num_workers=16, - pin_memory=True, - ) - val_loader = DataLoader( - dev_set, - batch_size=batch_size, - collate_fn=Dataset.collate_fn, - num_workers=8, - pin_memory=True, - ) + + print(f"Length of training dataset is {len(train_set)} examples") + + # Define DataModule for batch finding. 
+ class LitDataModule(LightningDataModule): + def __init__(self, batch_size = batch_size): + super().__init__() + self.batch_size = batch_size + + def test_dataloader(self): + return DataLoader( + test_set, + batch_size=self.batch_size, + collate_fn=Dataset.collate_fn, + num_workers=8, + pin_memory=True, + ) + + def train_dataloader(self): + return DataLoader( + train_set, + batch_size=self.batch_size, + collate_fn=Dataset.collate_fn, + num_workers=16, + pin_memory=True, + ) + + def val_dataloader(self): + return DataLoader( + dev_set, + batch_size=self.batch_size, + collate_fn=Dataset.collate_fn, + num_workers=8, + pin_memory=True, + ) # model model = TypeReconstructionModel(config) - wandb_logger = WandbLogger(name=args["--expname"], project="dire", log_model=True) + if "torch_float32_matmul" in config["train"]: + torch.set_float32_matmul_precision(config["train"]["torch_float32_matmul"]) + + wandb_logger = WandbLogger(name=args["--expname"], project="dire", log_model="all") wandb_logger.log_hyperparams(config) + wandb_logger.watch(model, log="all", log_freq=10000) + monitor_var = "val_acc" resume_from_checkpoint = ( args["--eval-ckpt"] if args["--eval-ckpt"] else args["--resume"] ) if resume_from_checkpoint == "": resume_from_checkpoint = None + + # Adds a safety margin. 
For example, `safety_margin` of 0.1 indicates that + # the final batch_size will be reduced by 10% + class SafeBatchSizeFinder(BatchSizeFinder): + def __init__(self, safety_margin=0.1, *args, **kwargs): + super().__init__(*args, **kwargs) + assert safety_margin >= 0 and safety_margin <= 1.0 + self.safety_margin = safety_margin + + def scale_batch_size(self, trainer, *args, **kwargs): + super().scale_batch_size(trainer, *args, **kwargs) + original_batch_size = self.optimal_batch_size + new_batch_size = int(self.optimal_batch_size * (1.0 - self.safety_margin)) + print( + f"Found optimal batch size of {original_batch_size}, but with a safety margin of {self.safety_margin}, reducing it to {new_batch_size}" + ) + self.optimal_batch_size = new_batch_size + # This adjusts the data module batch_size. + pl.tuner.batch_size_scaling._adjust_batch_size(trainer, value=new_batch_size) + pl.tuner.batch_size_scaling._reset_dataloaders(trainer) + trainer._active_loop.reset() + trainer = pl.Trainer( + precision=config["train"].get("precision", 32), max_epochs=config["train"]["max_epoch"], logger=wandb_logger, - gpus=1 if args["--cuda"] else None, - auto_select_gpus=True, - gradient_clip_val=1, + gradient_clip_val=1.0, callbacks=[ EarlyStopping( - monitor="val_retype_acc" - if config["data"]["retype"] - else "val_rename_acc", + monitor=monitor_var, mode="max", patience=config["train"]["patience"], - ) + ), + # Save all checkpoints that improve accuracy + ModelCheckpoint( + monitor=monitor_var, + filename='{epoch}-{%s:.2f}' % monitor_var, + save_top_k=2, + mode="max"), + SafeBatchSizeFinder(safety_margin=0.1, init_val=batch_size, max_trials=30, steps_per_trial=3), + LearningRateMonitor(logging_interval='epoch') ], check_val_every_n_epoch=config["train"]["check_val_every_n_epoch"], - progress_bar_refresh_rate=10, accumulate_grad_batches=config["train"]["grad_accum_step"], - resume_from_checkpoint=resume_from_checkpoint, limit_test_batches=config["test"]["limit"] if "limit" in 
config["test"] else 1.0 ) + + datamodule = LitDataModule(batch_size=batch_size) + if args["--eval-ckpt"]: - # HACK: necessary to make pl test work for IterableDataset - Dataset.__len__ = lambda self: 1000000 - test_set = Dataset(config["data"]["test_file"], config["data"]) - test_loader = DataLoader( - test_set, - batch_size=config["test"]["batch_size"], - collate_fn=Dataset.collate_fn, - num_workers=8, - pin_memory=True, - ) - ret = trainer.test(model, test_dataloaders=test_loader, ckpt_path=args["--eval-ckpt"]) + ret = trainer.test(model, datamodule=datamodule, ckpt_path=args["--eval-ckpt"]) json.dump(ret[0], open("test_result.json", "w")) else: - trainer.fit(model, train_loader, val_loader) + trainer.fit(model, datamodule=datamodule, ckpt_path=resume_from_checkpoint) if __name__ == "__main__": @@ -123,9 +169,7 @@ def train(args): print(f"use random seed {seed}", file=sys.stderr) torch.manual_seed(seed) - use_cuda = cmd_args["--cuda"] - if use_cuda: - torch.cuda.manual_seed(seed) + torch.cuda.manual_seed(seed) np.random.seed(seed * 13 // 7) random.seed(seed * 17 // 7) diff --git a/dirty/model/model.py b/dirty/model/model.py index feedb28..c7482af 100644 --- a/dirty/model/model.py +++ b/dirty/model/model.py @@ -5,13 +5,18 @@ import pytorch_lightning as pl import torch import torch.nn.functional as F -from pytorch_lightning.metrics.functional import accuracy +import torchmetrics.functional.classification # for multiclass_accuracy from utils.vocab import Vocab from utils.ghidra_types import TypeInfo, TypeLibCodec from model.encoder import Encoder from model.decoder import Decoder +# Wow, macro is the default. That is crazy. 
+def accuracy(preds, targets, average="micro", **kwargs): + if "num_classes" not in kwargs and average == "micro": + kwargs["num_classes"] = len(targets) # doesn't matter for micro + return torchmetrics.functional.classification.multiclass_accuracy(preds, targets, average=average, **kwargs) class RenamingDecodeModule(pl.LightningModule): def __init__(self, config): @@ -28,12 +33,12 @@ def __init__(self, config): def training_step(self, input_dict, context_encoding, target_dict): variable_name_logits = self.decoder(context_encoding, target_dict) if self.soft_mem_mask: - variable_name_logits = variable_name_logits[target_dict["target_mask"]] + variable_name_logits = variable_name_logits[target_dict["target_type_mask"]] mem_encoding = self.mem_encoder(input_dict) mem_logits = self.mem_decoder(mem_encoding, target_dict) loss = F.cross_entropy( variable_name_logits + mem_logits, - target_dict["target_name_id"][target_dict["target_mask"]], + target_dict["target_name_id"][target_dict["target_type_mask"]], reduction="none", ) else: @@ -43,18 +48,18 @@ def training_step(self, input_dict, context_encoding, target_dict): target_dict["target_name_id"], reduction="none", ) - loss = loss[target_dict["target_mask"]] + loss = loss[target_dict["target_type_mask"]] return loss.mean() def shared_eval_step(self, context_encoding, input_dict, target_dict, test=False): variable_name_logits = self.decoder(context_encoding, target_dict) if self.soft_mem_mask: - variable_name_logits = variable_name_logits[input_dict["target_mask"]] + variable_name_logits = variable_name_logits[input_dict["src_type_mask"]] mem_encoding = self.mem_encoder(input_dict) mem_logits = self.mem_decoder(mem_encoding, target_dict) loss = F.cross_entropy( variable_name_logits + mem_logits, - target_dict["target_name_id"][input_dict["target_mask"]], + target_dict["target_name_id"][input_dict["src_type_mask"]], reduction="none", ) else: @@ -63,8 +68,8 @@ def shared_eval_step(self, context_encoding, input_dict, 
target_dict, test=False target_dict["target_name_id"], reduction="none", ) - loss = loss[input_dict["target_mask"]] - targets = target_dict["target_name_id"][input_dict["target_mask"]] + loss = loss[input_dict["src_type_mask"]] + targets = target_dict["target_name_id"][input_dict["src_type_mask"]] preds = self.decoder.predict( context_encoding, input_dict, None, self.beam_size if test else 0 ) @@ -92,12 +97,12 @@ def __init__(self, config): def training_step(self, input_dict, context_encoding, target_dict): variable_type_logits = self.decoder(context_encoding, target_dict) if self.soft_mem_mask: - variable_type_logits = variable_type_logits[target_dict["target_mask"]] + variable_type_logits = variable_type_logits[target_dict["target_type_mask"]] mem_encoding = self.mem_encoder(input_dict) mem_type_logits = self.mem_decoder(mem_encoding, target_dict) loss = F.cross_entropy( variable_type_logits + mem_type_logits, - target_dict["target_type_id"][target_dict["target_mask"]], + target_dict["target_type_id"][target_dict["target_type_mask"]], reduction="none", ) else: @@ -111,7 +116,7 @@ def training_step(self, input_dict, context_encoding, target_dict): loss = loss[ target_dict["target_submask"] if self.subtype - else target_dict["target_mask"] + else target_dict["target_type_mask"] ] return loss.mean() @@ -119,13 +124,13 @@ def training_step(self, input_dict, context_encoding, target_dict): def shared_eval_step(self, context_encoding, input_dict, target_dict, test=False): variable_type_logits = self.decoder(context_encoding, target_dict) if self.soft_mem_mask: - variable_type_logits = variable_type_logits[input_dict["target_mask"]] + variable_type_logits = variable_type_logits[input_dict["src_type_mask"]] mem_encoding = self.mem_encoder(input_dict) mem_type_logits = self.mem_decoder(mem_encoding, target_dict) loss = F.cross_entropy( # cross_entropy requires num_classes at the second dimension variable_type_logits + mem_type_logits, - 
target_dict["target_type_id"][input_dict["target_mask"]], + target_dict["target_type_id"][input_dict["src_type_mask"]], reduction="none", ) else: @@ -140,9 +145,9 @@ def shared_eval_step(self, context_encoding, input_dict, target_dict, test=False loss = loss[ target_dict["target_submask"] if self.subtype - else target_dict["target_mask"] + else target_dict["target_type_mask"] ] - targets = target_dict["target_type_id"][input_dict["target_mask"]] + targets = target_dict["target_type_id"][input_dict["src_type_mask"]] preds = self.decoder.predict( context_encoding, input_dict, None, self.beam_size if test else 0 ) @@ -172,12 +177,12 @@ def training_step(self, input_dict, context_encoding, target_dict): ) # Retype if self.soft_mem_mask: - variable_type_logits = variable_type_logits[target_dict["target_mask"]] + variable_type_logits = variable_type_logits[target_dict["target_type_mask"]] mem_encoding = self.mem_encoder(input_dict) mem_type_logits = self.mem_decoder(mem_encoding, target_dict) retype_loss = F.cross_entropy( variable_type_logits + mem_type_logits, - target_dict["target_type_id"][target_dict["target_mask"]], + target_dict["target_type_id"][target_dict["target_type_mask"]], reduction="none", ) else: @@ -186,7 +191,7 @@ def training_step(self, input_dict, context_encoding, target_dict): target_dict["target_type_id"], reduction="none", ) - retype_loss = retype_loss[target_dict["target_mask"]] + retype_loss = retype_loss[target_dict["target_type_mask"]] retype_loss = retype_loss.mean() rename_loss = F.cross_entropy( @@ -195,13 +200,16 @@ def training_step(self, input_dict, context_encoding, target_dict): target_dict["target_name_id"], reduction="none", ) - rename_loss = rename_loss[target_dict["target_mask"]].mean() + rename_loss = rename_loss[target_dict["target_type_mask"]].mean() return retype_loss, rename_loss + def forward(self, context_encoding, input_dict, **kwargs): + return self.decoder.predict(context_encoding, input_dict, self.beam_size, **kwargs) + 
def get_unmasked_logits(self, context_encoding, input_dict, target_dict): variable_type_logits, _ = self.decoder(context_encoding, target_dict) - variable_type_logits = variable_type_logits[target_dict["target_mask"]] + variable_type_logits = variable_type_logits[target_dict["target_type_mask"]] # mem_encoding = self.mem_encoder(input_dict) # return self.mem_decoder(mem_encoding, target_dict) + variable_type_logits return variable_type_logits.argmax(dim=1) @@ -211,12 +219,12 @@ def shared_eval_step(self, context_encoding, input_dict, target_dict, test=False context_encoding, target_dict ) if self.soft_mem_mask: - variable_type_logits = variable_type_logits[input_dict["target_mask"]] + variable_type_logits = variable_type_logits[input_dict["src_type_mask"]] mem_encoding = self.mem_encoder(input_dict) mem_type_logits = self.mem_decoder(mem_encoding, target_dict) retype_loss = F.cross_entropy( variable_type_logits + mem_type_logits, - target_dict["target_type_id"][input_dict["target_mask"]], + target_dict["target_type_id"][input_dict["src_type_mask"]], reduction="none", ) else: @@ -225,27 +233,27 @@ def shared_eval_step(self, context_encoding, input_dict, target_dict, test=False target_dict["target_type_id"], reduction="none", ) - retype_loss = retype_loss[target_dict["target_mask"]] + retype_loss = retype_loss[target_dict["target_type_mask"]] rename_loss = F.cross_entropy( variable_name_logits.transpose(1, 2), target_dict["target_name_id"], reduction="none", ) - rename_loss = rename_loss[input_dict["target_mask"]] + rename_loss = rename_loss[input_dict["src_type_mask"]] ret = self.decoder.predict( - context_encoding, input_dict, None, self.beam_size if test else 0 + context_encoding, input_dict, self.beam_size if test else 0 ) retype_preds, rename_preds = ret[0], ret[1] return dict( retype_loss=retype_loss.detach().cpu(), - retype_targets=target_dict["target_type_id"][input_dict["target_mask"]] + 
retype_targets=target_dict["target_type_id"][input_dict["src_type_mask"]] .detach() .cpu(), retype_preds=retype_preds.detach().cpu(), rename_loss=rename_loss.detach().cpu(), - rename_targets=target_dict["target_name_id"][input_dict["target_mask"]] + rename_targets=target_dict["target_name_id"][input_dict["src_type_mask"]] .detach() .cpu(), rename_preds=rename_preds.detach().cpu(), @@ -257,6 +265,9 @@ def __init__(self, config, config_load=None): super().__init__() if config_load is not None: config = config_load + # Lame, we need to save our outputs now! + # https://github.com/Lightning-AI/pytorch-lightning/pull/16520 + self.eval_outputs = [] self.encoder = Encoder.build(config["encoder"]) self.retype = config["data"].get("retype", False) self.rename = config["data"].get("rename", False) @@ -314,7 +325,7 @@ def training_step( ) self.log("train_rename_loss", loss) total_loss += loss - self.log("train_loss", total_loss) + self.log("train_loss", total_loss, prog_bar=True) return total_loss def validation_step(self, batch, batch_idx): @@ -353,18 +364,73 @@ def _shared_eval_step( ) ret_dict = {**ret, **ret_dict} - return dict( + d = dict( **ret_dict, - targets_nums=input_dict["target_mask"].sum(dim=1), + # this is the number of variables per example, which is same in src + # and tgt. 
+ targets_nums=input_dict["src_type_mask"].sum(dim=1), test_meta=target_dict["test_meta"], index=input_dict["index"], tgt_var_names=target_dict["tgt_var_names"], ) - def validation_epoch_end(self, outputs): + self.eval_outputs.append(d) + + return d + + def forward(self, batch, return_non_best=False): + input_dict = batch + context_encoding = self.encoder(input_dict) + if self.interleave: + ret = self.interleave_module(context_encoding, input_dict, return_non_best=return_non_best) + else: + if self.retype: + ret = self.retyping_module(context_encoding, input_dict) + elif self.rename: + ret = self.renaming_module(context_encoding, input_dict) + else: + assert False + + if return_non_best: + retype_preds, rename_preds, all_retype_preds, all_rename_preds = ret + else: + retype_preds, rename_preds = ret + + retype_preds_name = [self.vocab.types.id2word[x.item()] for x in retype_preds] + + rename_preds_name = [self.vocab.names.id2word[x.item()] for x in rename_preds] + + ret = { + "retype_preds": retype_preds_name, + "rename_preds": rename_preds_name, + } + + if return_non_best: + + all_retype_preds_names = [ + [self.vocab.types.id2word[prediction.item()] for prediction in predictions] + for predictions in all_retype_preds + ] + + all_rename_preds_names = [ + [self.vocab.names.id2word[prediction.item()] for prediction in predictions] + for predictions in all_rename_preds + ] + + ret.update({ + "all_retype_preds": all_retype_preds_names, + "all_rename_preds": all_rename_preds_names, + }) + + return ret + + def on_validation_epoch_end(self): + outputs = self.eval_outputs self._shared_epoch_end(outputs, "val") + self.eval_outputs.clear() - def test_epoch_end(self, outputs): + def on_test_epoch_end(self): + outputs = self.eval_outputs final_ret = self._shared_epoch_end(outputs, "test") if "pred_file" in self.config["test"]: results = {} @@ -386,16 +452,25 @@ def test_epoch_end(self, outputs): ) pred_file = self.config["test"]["pred_file"] json.dump(results, open(pred_file, 
"w")) + self.eval_outputs.clear() def _shared_epoch_end(self, outputs, prefix): final_ret = {} if self.retype: ret = self._shared_epoch_end_task(outputs, prefix, "retype") final_ret = {**final_ret, **ret} + acc = final_ret["retype_acc"] + loss = final_ret["retype_loss"] if self.rename: ret = self._shared_epoch_end_task(outputs, prefix, "rename") final_ret = {**final_ret, **ret} + acc = final_ret["rename_acc"] + loss = final_ret["rename_loss"] if self.retype and self.rename: + acc = final_ret["retype_acc"] + final_ret["rename_acc"] + loss = final_ret["retype_loss"] + final_ret["rename_loss"] + self.log(f"{prefix}_acc", acc, sync_dist=True) + self.log(f"{prefix}_loss", loss, sync_dist=True) # Evaluate rename accuracy on correctedly retyped samples retype_preds = torch.cat([x[f"retype_preds"] for x in outputs]) retype_targets = torch.cat([x[f"retype_targets"] for x in outputs]) @@ -406,8 +481,9 @@ def _shared_epoch_end(self, outputs, prefix): f"{prefix}_rename_on_correct_retype_acc", accuracy( rename_preds[retype_preds == retype_targets], - rename_targets[retype_preds == retype_targets], + rename_targets[retype_preds == retype_targets] ), + sync_dist=True ) return final_ret @@ -418,16 +494,18 @@ def _shared_epoch_end_task(self, outputs, prefix, task): preds = torch.cat([x[f"{task}_preds"] for x in outputs]) targets = torch.cat([x[f"{task}_targets"] for x in outputs]) loss = torch.cat([x[f"{task}_loss"] for x in outputs]).mean() - self.log(f"{prefix}_{task}_loss", loss) - self.log(f"{prefix}_{task}_acc", accuracy(preds, targets)) + self.log(f"{prefix}_{task}_loss", loss, sync_dist=True) + acc = accuracy(preds, targets) + self.log(f"{prefix}_{task}_acc", acc, sync_dist=True) self.log( f"{prefix}_{task}_acc_macro", accuracy( preds, targets, - num_classes=len(self.vocab.types), - class_reduction="macro", + num_classes=len(self.vocab.types if task == "retype" else self.vocab.names), + average="macro", ), + sync_dist=True ) # func acc num_correct, num_funcs, pos = 0, 0, 0 
@@ -439,7 +517,7 @@ def _shared_epoch_end_task(self, outputs, prefix, task): for num, test_meta in zip(target_num.tolist(), test_metas): num_correct += all(preds[pos : pos + num] == targets[pos : pos + num]) pos += num - body_in_train_mask += [test_meta["function_body_sin_train"]] * num + body_in_train_mask += [test_meta["function_body_in_train"]] * num name_in_train_mask += [test_meta["function_name_in_train"]] * num num_funcs += len(target_num) body_in_train_mask = torch.tensor(body_in_train_mask) @@ -448,20 +526,23 @@ def _shared_epoch_end_task(self, outputs, prefix, task): # HACK for data parallel body_in_train_mask = body_in_train_mask[:, 0] name_in_train_mask = name_in_train_mask[:, 0] - self.log( - f"{prefix}_{task}_body_in_train_acc", - accuracy(preds[body_in_train_mask], targets[body_in_train_mask]), - ) + if body_in_train_mask.sum() > 0: + self.log( + f"{prefix}_{task}_body_in_train_acc", + accuracy(preds[body_in_train_mask], targets[body_in_train_mask]), + sync_dist=True + ) if (~body_in_train_mask).sum() > 0: self.log( f"{prefix}_{task}_body_not_in_train_acc", accuracy(preds[~body_in_train_mask], targets[~body_in_train_mask]), + sync_dist=True ) assert pos == sum(x["targets_nums"].sum() for x in outputs), ( pos, sum(x["targets_nums"].sum() for x in outputs), ) - self.log(f"{prefix}_{task}_func_acc", num_correct / num_funcs) + self.log(f"{prefix}_{task}_func_acc", num_correct / num_funcs, sync_dist=True) struc_mask = torch.zeros(len(targets), dtype=torch.bool) for idx, target in enumerate(targets): @@ -472,6 +553,7 @@ def _shared_epoch_end_task(self, outputs, prefix, task): self.log( f"{prefix}{task_str}_struc_acc", accuracy(preds[struc_mask], targets[struc_mask]), + sync_dist=True ) # adjust for the number of classes self.log( @@ -479,27 +561,30 @@ def _shared_epoch_end_task(self, outputs, prefix, task): accuracy( preds[struc_mask], targets[struc_mask], - num_classes=len(self.vocab.types), - class_reduction="macro", + num_classes=len(self.vocab.types if 
task == "retype" else self.vocab.names), + average="macro", ) * len(self.vocab.types) / len(self.vocab.types.struct_set), + sync_dist=True ) if (struc_mask & body_in_train_mask).sum() > 0: self.log( f"{prefix}{task_str}_body_in_train_struc_acc", accuracy( preds[struc_mask & body_in_train_mask], - targets[struc_mask & body_in_train_mask], + targets[struc_mask & body_in_train_mask] ), + sync_dist=True ) if (~body_in_train_mask & struc_mask).sum() > 0: self.log( f"{prefix}{task_str}_body_not_in_train_struc_acc", accuracy( preds[~body_in_train_mask & struc_mask], - targets[~body_in_train_mask & struc_mask], + targets[~body_in_train_mask & struc_mask] ), + sync_dist=True ) return { "indexes": indexes, @@ -507,8 +592,19 @@ def _shared_epoch_end_task(self, outputs, prefix, task): f"{task}_preds": preds, f"{task}_targets": preds, "body_in_train_mask": body_in_train_mask, + f"{task}_acc": acc, + f"{task}_loss": loss } def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters(), lr=self.config["train"]["lr"]) - return optimizer + optimizer = torch.optim.Adam(self.parameters(), lr=self.config + ["train"]["lr"]) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=0) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "monitor": "val_loss", + "frequency": 1 # epoch + } + } \ No newline at end of file diff --git a/dirty/model/xfmr_decoder.py b/dirty/model/xfmr_decoder.py index eea06f5..621bdc5 100644 --- a/dirty/model/xfmr_decoder.py +++ b/dirty/model/xfmr_decoder.py @@ -118,6 +118,7 @@ def predict( input_dict: Dict[str, torch.Tensor], variable_type_logits: torch.Tensor, beam_size: int = 0, + return_non_best: bool = False, ): if beam_size == 0: return self.greedy_decode( @@ -125,7 +126,8 @@ def predict( ) else: return self.beam_decode( - context_encoding, input_dict, variable_type_logits, beam_size + context_encoding, input_dict, variable_type_logits, beam_size, + return_non_best=return_non_best ) def 
greedy_decode( @@ -420,7 +422,7 @@ def forward( tgt_mask = XfmrDecoder.generate_square_subsequent_mask(tgt.shape[1], tgt.device) # TransformerModels have batch_first=False tgt_padding_mask = XfmrInterleaveDecoder.interleave_2d( - target_dict["target_mask"], target_dict["target_mask"] + target_dict["target_type_mask"], target_dict["target_type_mask"] ) hidden = self.decoder( tgt.transpose(0, 1), @@ -440,23 +442,25 @@ def predict( self, context_encoding: Dict[str, torch.Tensor], input_dict: Dict[str, torch.Tensor], - variable_type_logits: torch.Tensor, + #variable_type_logits: torch.Tensor, beam_size: int = 0, + return_non_best: bool = False ): + if beam_size == 0: return self.greedy_decode( - context_encoding, input_dict, variable_type_logits + context_encoding, input_dict ) else: return self.beam_decode( - context_encoding, input_dict, variable_type_logits, beam_size + context_encoding, input_dict, beam_size, return_non_best=return_non_best ) def greedy_decode( self, context_encoding: Dict[str, torch.Tensor], input_dict: Dict[str, torch.Tensor], - variable_type_logits: torch.Tensor, + #variable_type_logits: torch.Tensor, ): """Greedy decoding""" @@ -464,7 +468,7 @@ def greedy_decode( context_encoding["variable_encoding"], context_encoding["variable_encoding"] ) tgt_padding_mask = XfmrInterleaveDecoder.interleave_2d( - input_dict["target_mask"], input_dict["target_mask"] + input_dict["src_type_mask"], input_dict["src_type_mask"] ) batch_size, max_time_step, _ = variable_encoding.shape tgt = torch.zeros(batch_size, 1, self.config["target_embedding_size"]).to( @@ -480,7 +484,7 @@ def greedy_decode( mem_logits_list = [] idx = 0 for b in range(batch_size): - nvar = input_dict["target_mask"][b].sum().item() + nvar = input_dict["src_type_mask"][b].sum().item() mem_logits_list.append(mem_logits[idx : idx + nvar]) idx += nvar assert idx == mem_logits.shape[0] @@ -550,25 +554,26 @@ def greedy_decode( type_preds = torch.stack(type_preds_list).transpose(0, 1) name_preds = 
torch.stack(name_preds_list).transpose(0, 1) return ( - type_preds[input_dict["target_mask"]], - name_preds[input_dict["target_mask"]], + type_preds[input_dict["src_type_mask"]], + name_preds[input_dict["src_type_mask"]], ) def beam_decode( self, context_encoding: Dict[str, torch.Tensor], input_dict: Dict[str, torch.Tensor], - variable_type_logits: torch.Tensor, + #variable_type_logits: torch.Tensor, beam_size: int = 5, length_norm: bool = True, + return_non_best: bool = False ): """Beam search decoding""" variable_encoding = XfmrInterleaveDecoder.interleave_3d( context_encoding["variable_encoding"], context_encoding["variable_encoding"] ) - tgt_padding_mask = XfmrInterleaveDecoder.interleave_2d( - input_dict["target_mask"], input_dict["target_mask"] + padding_mask = XfmrInterleaveDecoder.interleave_2d( + input_dict["src_type_mask"], input_dict["src_type_mask"] ) batch_size, max_time_step, _ = variable_encoding.shape tgt = torch.zeros(batch_size, 1, self.config["target_embedding_size"]).to( @@ -584,7 +589,7 @@ def beam_decode( mem_logits_list = [] idx = 0 for b in range(batch_size): - nvar = input_dict["target_mask"][b].sum().item() + nvar = input_dict["src_type_mask"][b].sum().item() mem_logits_list.append(mem_logits[idx : idx + nvar]) idx += nvar assert idx == mem_logits.shape[0] @@ -608,7 +613,7 @@ def beam_decode( # context_encoding["variable_encoding"]: batch, max_time, hidden tgt = tile(tgt, beam_size, dim=0) - tiled_target_mask = tile(tgt_padding_mask, beam_size, dim=0) + tiled_target_mask = tile(padding_mask, beam_size, dim=0) code_token_encoding = tile( context_encoding["code_token_encoding"], beam_size, dim=0 ) @@ -628,14 +633,14 @@ def beam_decode( scores = logits[:, 0].view(batch_size, beam_size, -1) select_indices_array = [] for b, bm in enumerate(beams): - if not tgt_padding_mask[b, idx]: + if not padding_mask[b, idx]: select_indices_array.append( torch.arange(beam_size).to(tgt.device) + b * beam_size ) continue if idx % 2 == 0: s = scores[b, :, : 
self.retype_vocab_size] - if self.mem_mask == "soft" and tgt_padding_mask[b, idx]: + if self.mem_mask == "soft" and padding_mask[b, idx]: s += mem_logits_list[b][idx // 2] else: s = scores[b, :, self.retype_vocab_size :] @@ -647,7 +652,7 @@ def beam_decode( torch.stack( [ bm.getCurrentState() - if tgt_padding_mask[b, idx] + if padding_mask[b, idx] else torch.zeros(beam_size, dtype=torch.long).to(tgt.device) for b, bm in enumerate(beams) ] @@ -674,14 +679,36 @@ def beam_decode( tgt = torch.cat([tgt, tgt_step], dim=1) all_type_hyps, all_name_hyps, all_scores = [], [], [] + # include non-best hypotheses + all_nonbest_type_hyps, all_nonbest_name_hyps = [], [] for j in range(batch_size): b = beams[j] scores, ks = b.sortFinished(minimum=beam_size) - times, k = ks[0] - hyp = b.getHyp(times, k) - hyp = torch.tensor(hyp).view(-1, 2).t() + + def get(i): + times, k = ks[i] + hyp = b.getHyp(times, k) + hyp = torch.tensor(hyp).view(-1, 2).t() + return hyp + + # The batch is by example, and the time steps of the hypothesis are + # variable/type predictions. 
+ hyp = get(0) + all_type_hyps.append(hyp[0]) all_name_hyps.append(hyp[1]) all_scores.append(scores[0]) - return torch.cat(all_type_hyps), torch.cat(all_name_hyps) + if return_non_best: + all_hyps = (get(i) for i in range(beam_size)) + all_hyps = ((tup[0], tup[1]) for tup in all_hyps) + all_hyps = zip(*all_hyps) + all_hyps = tuple(torch.stack(x) for x in all_hyps) + + all_nonbest_type_hyps.append(all_hyps[0]) + all_nonbest_name_hyps.append(all_hyps[1]) + + if return_non_best: + return torch.cat(all_type_hyps), torch.cat(all_name_hyps), torch.cat(all_nonbest_type_hyps), torch.cat(all_nonbest_name_hyps) + else: + return torch.cat(all_type_hyps), torch.cat(all_name_hyps) diff --git a/dirty/model/xfmr_mem_encoder.py b/dirty/model/xfmr_mem_encoder.py index 191089e..eac6a85 100644 --- a/dirty/model/xfmr_mem_encoder.py +++ b/dirty/model/xfmr_mem_encoder.py @@ -104,7 +104,7 @@ def forward( :rtype: Dict[str, torch.Tensor] """ mem_encoding, mem_mask = self.encode_sequence( - tensor_dict["target_type_src_mems"][tensor_dict["target_mask"]] + tensor_dict["src_var_locs"][tensor_dict["src_type_mask"]] ) # TODO: ignore the padding when averaging diff --git a/dirty/multitask.xfmr.jsonnet b/dirty/multitask.xfmr.jsonnet index 7a7b61d..e04c6ec 100644 --- a/dirty/multitask.xfmr.jsonnet +++ b/dirty/multitask.xfmr.jsonnet @@ -46,9 +46,12 @@ "hidden_size": $['mem_encoder'].hidden_size, }, "train": { + "torch_float32_matmul": "medium", # high, highest + # 16-mixed/AMP can lead to NaN errors + "precision": "32", #bit "batch_size": 16, "grad_accum_step": 4, - "max_epoch": 15, + "max_epoch": 25, "lr": 1e-4, "patience": 10, "check_val_every_n_epoch": 1, diff --git a/dirty/multitask_test_ci.xfmr.jsonnet b/dirty/multitask_test_ci.xfmr.jsonnet index bbb7fb4..cbfc8e1 100644 --- a/dirty/multitask_test_ci.xfmr.jsonnet +++ b/dirty/multitask_test_ci.xfmr.jsonnet @@ -1,7 +1,7 @@ { "data": { "train_file": "data1/train-shard-*.tar", - "dev_file": "data1/dev-*.tar", + "dev_file": "data1/dev*.tar", 
"test_file": "data1/test.tar", "vocab_file": "data1/vocab.bpe10000", "typelib_file": "data1/typelib.json", @@ -57,6 +57,5 @@ "pred_file": "pred_mt_ci.json", "batch_size": 16, "beam_size": 0, - "limit": 2, } } diff --git a/dirty/utils/case_study.py b/dirty/utils/case_study.py index 9878f07..849f5a7 100644 --- a/dirty/utils/case_study.py +++ b/dirty/utils/case_study.py @@ -5,7 +5,7 @@ import torch from tqdm import tqdm -from evaluate import add_options, load_data +from utils.evaluate import add_options, load_data def view(c, cc, dcc): diff --git a/dirty/utils/dataset.py b/dirty/utils/dataset.py index 103cd15..69674d4 100644 --- a/dirty/utils/dataset.py +++ b/dirty/utils/dataset.py @@ -1,20 +1,19 @@ import glob import json -from collections import defaultdict from typing import Dict, List, Mapping, Optional, Set, Tuple, Union +from collections import defaultdict import _jsonnet -import numpy as np import torch import webdataset as wds +from torch.utils.data import IterableDataset from torch.nn.utils.rnn import pad_sequence -from tqdm import tqdm from utils.code_processing import tokenize_raw_code -from utils.ghidra_function import CollectedFunction, Function -from utils.ghidra_variable import Location, Variable, location_from_json_key, Register, Stack -from utils.ghidra_types import Struct, TypeLibCodec, TypeLib, UDT, TypeInfo, Disappear - +from utils.ghidra_function import CollectedFunction +from utils.ghidra_variable import Location, Variable, Unknown, location_from_json_key, Register, Stack +from utils.ghidra_types import TypeLibCodec, Disappear +from utils.vocab import Vocab class Example: def __init__( @@ -28,6 +27,7 @@ def __init__( raw_code: str = "", test_meta: Dict[str, Dict[str, bool]] = None, binary: str = None, + other_info = None, ): self.name = name self.code_tokens = code_tokens @@ -38,17 +38,23 @@ def __init__( self.raw_code = raw_code self.test_meta = test_meta self.binary = binary + self.other_info = other_info @classmethod def from_json(cls, d: 
Dict): source = { - location_from_json_key(loc): Variable.from_json(var) - for loc, var in d["source"].items() + location_from_json_key(loc): [Variable.from_json(var) for var in varlist] + for loc, varlist in d["source"].items() } target = { - location_from_json_key(loc): Variable.from_json(var) - for loc, var in d["target"].items() + location_from_json_key(loc): [Variable.from_json(var) for var in varlist] + for loc, varlist in d["target"].items() } + + # It seems like other code assumes the number of source and target + # variables are the same. + assert len(source) == len(target), "Source and target have different lengths" + return cls( d["name"], d["code_tokens"], @@ -60,8 +66,8 @@ def from_json(cls, d: Dict): def to_json(self): assert self._is_valid - source = {loc.json_key(): var.to_json() for loc, var in self.source.items()} - target = {loc.json_key(): var.to_json() for loc, var in self.target.items()} + source = {loc.json_key(): [var.to_json() for var in varlist] for loc, varlist in self.source.items()} + target = {loc.json_key(): [var.to_json() for var in varlist] for loc, varlist in self.target.items()} return { "name": self.name, "code_tokens": self.code_tokens, @@ -70,44 +76,79 @@ def to_json(self): } @classmethod - def from_cf(cls, cf: CollectedFunction, **kwargs): - """Convert from a decoded CollectedFunction""" + def from_cf(cls, cf: CollectedFunction, prediction=False, **kwargs): + """Convert from a decoded CollectedFunction. + """ + use_disappear = prediction + filter_dups = not prediction name = cf.decompiler.name raw_code = cf.decompiler.raw_code code_tokens = tokenize_raw_code(raw_code) - source = {**cf.decompiler.local_vars, **cf.decompiler.arguments} - target = {**cf.debug.local_vars, **cf.debug.arguments} + source = {**cf.decompiler.local_vars} + + # Actually merge these correctly! 
+ for k, v in cf.decompiler.arguments.items(): + # v is a set + if k in source: + source[k].update(v) + else: + source[k] = v + if hasattr(cf.debug, "local_vars"): + target = {**cf.debug.local_vars} + for k, v in cf.debug.arguments.items(): + # v is a set + if k in target: + target[k].update(v) + else: + target[k] = v + else: + target = {} # Remove variables that overlap on memory or don't appear in the code tokens source_code_tokens_set = set(code_tokens[code_tokens.index("{"):]) - #target_code_tokens_set = set(tokenize_raw_code(cf.debug.raw_code)) - source = Example.filter(source, source_code_tokens_set) - # target = Example.filter(target, target_code_tokens_set, set(source.keys())) - target = Example.filter(target, None, set(source.keys())) + source, source_filtered_out = Example.filter(source, source_code_tokens_set, filter_out_duplicate_locations=filter_dups) + target, target_filtered_out = Example.filter(target, None, set(source.keys()), filter_non_user_names=True, filter_out_duplicate_locations=filter_dups) + + # Optionally assign type "Disappear" to variables not existing in the + # ground truth. EJS thinks this may be harmful since the model learns + # to overzealously predict disappear. + + # Note: Need to copy source.keys() so we don't change the list while + # iterating. 
+ for loc in list(source.keys()): + if use_disappear: + if loc not in target.keys(): + target[loc] = [Variable(Disappear(), "disappear", False)] * len(source[loc]) + else: + if loc in source.keys() and loc not in target.keys(): + del source[loc] - # Assign type "Disappear" to variables not existing in the ground truth varnames = set() - for loc in source.keys(): - if loc not in target.keys(): - target[loc] = Variable(Disappear(), "", False) - # Add special tokens to variables to prevnt being sub-tokenized in BPE - for var in source.values(): - varname = var.name - varnames.add(varname) + # Add special tokens to variable names + for varlist in source.values(): + for var in varlist: + varname = var.name + varnames.add(varname) for idx in range(len(code_tokens)): if code_tokens[idx] in varnames: code_tokens[idx] = f"@@{code_tokens[idx]}@@" + other_info = { + 'source_filtered': source_filtered_out, + 'target_filtered': target_filtered_out, + } + return cls( name, code_tokens, source, target, kwargs["binary_file"], - valid=name == cf.debug.name and source and "halt_baddata" not in source_code_tokens_set, + valid=source and "halt_baddata" not in source_code_tokens_set, raw_code=raw_code, + other_info=other_info ) @staticmethod @@ -115,24 +156,36 @@ def filter( mapping: Mapping[Location, Set[Variable]], code_tokens: Optional[Set[str]] = None, locations: Optional[Set[Location]] = None, - ) -> Mapping[Location, Variable]: + filter_non_user_names: bool = False, + filter_out_duplicate_locations: bool = True + ) -> Mapping[Location, Set[Variable]]: """Discard and leave these for future work: Multiple variables sharing a memory location (no way to determine ground truth); Variables not appearing in code (no way to get representation); Target variables not appearing in source (useless ground truth); """ - ret: Mapping[Location, Set[Variable]] = {} + ret: Mapping[Location, List[Variable]] = defaultdict(list) + + filtered = set() + for location, variable_set in mapping.items(): - 
if len(variable_set) > 1: - continue - var = list(variable_set)[0] - if code_tokens is not None and not var.name in code_tokens: - continue - if locations is not None and not location in locations: + for v in variable_set: + filtered.add((location, v)) + if len(variable_set) > 1 and filter_out_duplicate_locations: + print(f"Warning: Ignoring location {location} with multiple variables {variable_set}") continue - ret[location] = var - return ret + + for var in variable_set: + if code_tokens is not None and not var.name in code_tokens: + continue + if locations is not None and not location in locations: + continue + if filter_non_user_names and not var.user: + continue + filtered.remove((location, var)) + ret[location].append(var) + return ret, {x.name: loc.json_key() for loc, x in filtered} @property def is_valid_example(self): @@ -147,8 +200,7 @@ def identity(x): def get_src_len(e): return e.source_seq_length - -class Dataset(wds.Dataset): +class Dataset(IterableDataset): SHUFFLE_BUFFER = 5000 SORT_BUFFER = 512 @@ -157,14 +209,13 @@ def __init__(self, url: str, config: Optional[Dict] = None, percent: float = 1.0 # support wildcards urls = sorted(glob.glob(url)) urls = urls[: int(percent * len(urls))] - super().__init__(urls) + if config: # annotate example for training - from utils.vocab import Vocab - print(config["vocab_file"]) self.vocab = Vocab.load(config["vocab_file"]) - with open(config["typelib_file"]) as type_f: - self.typelib = TypeLibCodec.decode(type_f.read()) + # do we need this? 
+ #with open(config["typelib_file"]) as type_f: + # self.typelib = TypeLibCodec.decode(type_f.read()) self.max_src_tokens_len = config["max_src_tokens_len"] self.max_num_var = config["max_num_var"] annotate = self._annotate @@ -175,13 +226,51 @@ def __init__(self, url: str, config: Optional[Dict] = None, percent: float = 1.0 # for creating the vocab annotate = identity sort = identity - self = ( - self.pipe(Dataset._file_iter_to_line_iter) - .map(Example.from_json) - .map(annotate) - .shuffle(Dataset.SHUFFLE_BUFFER) - .pipe(sort) - ) + + if not urls: + + # Dummy dataset for utils.infer + + self.len = None + self.wds = None + + else: + + self.wds = ( + wds.WebDataset(urls, empty_check=False, shardshuffle=True) + .compose(Dataset._file_iter_to_line_iter) + .map(Example.from_json) + .map(annotate) + .shuffle(Dataset.SHUFFLE_BUFFER) + .compose(sort) + ) + + # Estimate size of dataset + # XXX: Limit number of files we read or use a timer? Right + # now we use all of them. + try: + line_dataset = ( + wds.WebDataset(urls, shardshuffle=False) + .compose(Dataset._file_iter_to_line_iter) + ) + #print(f"URLs: {urls} dataset: {line_dataset}") + #print("Estimating size of dataset...") + self.len = sum(1 for _line in line_dataset) + except Exception: + # This might fail if we create a dummy dataset, such as in + # utils.infer. + self.len = None + + def __len__(self): + return self.len + + # We need this for IterableDataset + def __iter__(self): + return iter(self.wds) + + # Should we forward to webdataset? 
+ #def __getattr__(self, name): + # return getattr(self.wds, name) @staticmethod def _sort(example_iter): @@ -241,60 +330,69 @@ def _annotate(self, example: Example): tgt_var_subtypes = [] tgt_var_type_sizes = [] tgt_var_type_objs = [] - tgt_var_src_mems = [] + src_var_locs_encoded = [] tgt_names = [] - # variables on registers first, followed by those on stack - locs = sorted( - example.source, - key=lambda x: sub_tokens.index(f"@@{example.source[x].name}@@") - if f"@@{example.source[x].name}@@" in sub_tokens - else self.max_src_tokens_len, - ) - stack_pos = [x.offset for x in example.source if isinstance(x, Stack)] + + locs = sorted(example.source.keys(), key=lambda loc: repr(loc)) + + stack_pos = [x.offset for x in example.source.keys() if isinstance(x, Stack)] stack_start_pos = max(stack_pos) if stack_pos else None - for loc in locs[: self.max_num_var]: - src_var = example.source[loc] - tgt_var = example.target[loc] + + def var_loc_in_func(loc): + # TODO: fix the magic number (1030) for computing vocabulary idx + # TODO: add vocabulary for unknown locations? 
+ if isinstance(loc, Register): + return 1030 + self.vocab.regs[loc.name] + elif isinstance(loc, Unknown): + return 2 # unknown + else: + from utils.vocab import VocabEntry + + return ( + 3 + stack_start_pos - loc.offset + if stack_start_pos - loc.offset < VocabEntry.MAX_STACK_SIZE + else 2 + ) + + def for_src_var(loc, src_var): + nonlocal src_var_names, src_var_types_id, src_var_types_str, src_var_locs_encoded src_var_names.append(f"@@{src_var.name}@@") - tgt_var_names.append(f"@@{tgt_var.name}@@") src_var_types_id.append(types_model.lookup_decomp(str(src_var.typ))) src_var_types_str.append(str(src_var.typ)) - tgt_var_types_id.append(types_model[str(tgt_var.typ)]) - tgt_var_types_str.append(str(tgt_var.typ)) - if types_model[str(tgt_var.typ)] == types_model.unk_id: - subtypes = [subtypes_model.unk_id, subtypes_model[""]] - else: - subtypes = [subtypes_model[subtyp] for subtyp in tgt_var.typ.tokenize()] - tgt_var_type_sizes.append(len(subtypes)) - tgt_var_subtypes += subtypes - tgt_var_type_objs.append(tgt_var.typ) # Memory # 0: absolute location of the variable in the function, e.g., # for registers: Reg 56 # for stack: relative position to the first variable # 1: size of the type # 2, 3, ...: start offset of fields in the type - def var_loc_in_func(loc): - # TODO: fix the magic number for computing vocabulary idx - if isinstance(loc, Register): - return 1030 + self.vocab.regs[loc.name] - else: - from utils.vocab import VocabEntry - return ( - 3 + stack_start_pos - loc.offset - if stack_start_pos - loc.offset < VocabEntry.MAX_STACK_SIZE - else 2 - ) - - tgt_var_src_mems.append( + src_var_locs_encoded.append( [var_loc_in_func(loc)] + types_model.encode_memory( (src_var.typ.size,) + src_var.typ.start_offsets() ) ) + + def for_tgt_var(loc, tgt_var): + nonlocal tgt_var_names, tgt_var_types_id, tgt_var_types_str, tgt_var_subtypes, tgt_var_type_sizes, tgt_var_type_objs, tgt_names + tgt_var_names.append(f"@@{tgt_var.name}@@") + 
tgt_var_types_id.append(types_model[str(tgt_var.typ)]) + tgt_var_types_str.append(str(tgt_var.typ)) + if types_model[str(tgt_var.typ)] == types_model.unk_id: + subtypes = [subtypes_model.unk_id, subtypes_model[""]] + else: + subtypes = [subtypes_model[subtyp] for subtyp in tgt_var.typ.tokenize()] + tgt_var_type_sizes.append(len(subtypes)) + tgt_var_subtypes += subtypes + tgt_var_type_objs.append(tgt_var.typ) tgt_names.append(tgt_var.name) + for loc in locs[: self.max_num_var]: + for src_var in example.source[loc]: + for_src_var(loc, src_var) + for tgt_var in example.target[loc]: + for_tgt_var(loc, tgt_var) + setattr(example, "src_var_names", src_var_names) setattr(example, "tgt_var_names", tgt_var_names) if self.rename: @@ -305,11 +403,11 @@ def var_loc_in_func(loc): ) setattr(example, "src_var_types", src_var_types_id) setattr(example, "src_var_types_str", src_var_types_str) + setattr(example, "src_var_locs", src_var_locs_encoded) setattr(example, "tgt_var_types", tgt_var_types_id) setattr(example, "tgt_var_types_str", tgt_var_types_str) setattr(example, "tgt_var_subtypes", tgt_var_subtypes) setattr(example, "tgt_var_type_sizes", tgt_var_type_sizes) - setattr(example, "tgt_var_src_mems", tgt_var_src_mems) return example @@ -348,9 +446,9 @@ def collate_fn( torch.tensor(e.src_var_types, dtype=torch.long) for e in examples ] src_type_id = pad_sequence(src_type_ids, batch_first=True) - type_ids = [torch.tensor(e.tgt_var_types, dtype=torch.long) for e in examples] - target_type_id = pad_sequence(type_ids, batch_first=True) - assert target_type_id.shape == variable_mention_num.shape + tgt_type_ids = [torch.tensor(e.tgt_var_types, dtype=torch.long) for e in examples] + target_type_id = pad_sequence(tgt_type_ids, batch_first=True) + assert target_type_id.shape == variable_mention_num.shape, f"{target_type_id.shape} != {variable_mention_num.shape}" subtype_ids = [ torch.tensor(e.tgt_var_subtypes, dtype=torch.long) for e in examples @@ -361,18 +459,20 @@ def collate_fn( ] 
target_type_sizes = pad_sequence(type_sizes, batch_first=True) - target_mask = src_type_id > 0 - target_type_src_mems = [ + src_type_mask = src_type_id > 0 + tgt_type_mask = target_type_id > 0 + + src_var_locs = [ torch.tensor(mems, dtype=torch.long) for e in examples - for mems in e.tgt_var_src_mems + for mems in e.src_var_locs ] - target_type_src_mems = pad_sequence(target_type_src_mems, batch_first=True) - target_type_src_mems_unflattened = torch.zeros( - *target_mask.shape, target_type_src_mems.size(-1), dtype=torch.long + src_var_locs = pad_sequence(src_var_locs, batch_first=True) + src_var_locs_unflattened = torch.zeros( + *src_type_mask.shape, src_var_locs.size(-1), dtype=torch.long ) - target_type_src_mems_unflattened[target_mask] = target_type_src_mems - target_type_src_mems = target_type_src_mems_unflattened + src_var_locs_unflattened[src_type_mask] = src_var_locs + src_var_locs = src_var_locs_unflattened # renaming task if hasattr(examples[0], "tgt_var_name_ids"): @@ -397,18 +497,19 @@ def collate_fn( variable_mention_mask=variable_mention_mask, variable_mention_num=variable_mention_num, variable_encoding_mask=variable_encoding_mask, - target_type_src_mems=target_type_src_mems, + #target_type_src_mems=target_type_src_mems, src_type_id=src_type_id, - target_mask=target_mask, - target_submask=target_subtype_id > 0, - target_type_sizes=target_type_sizes, + src_type_mask=src_type_mask, + src_var_locs=src_var_locs, + #target_submask=target_subtype_id > 0, + #target_type_sizes=target_type_sizes, ), dict( tgt_var_names=sum([e.tgt_var_names for e in examples], []), target_type_id=target_type_id, target_name_id=target_name_id, target_subtype_id=target_subtype_id, - target_mask=target_mask, + target_type_mask=tgt_type_mask, test_meta=[e.test_meta for e in examples], ), ) diff --git a/dirty/utils/evaluate.py b/dirty/utils/evaluate.py index c9f8fe1..86ce10c 100644 --- a/dirty/utils/evaluate.py +++ b/dirty/utils/evaluate.py @@ -18,6 +18,7 @@ def add_options(parser): 
def load_data(config_file): config = json.loads(_jsonnet.evaluate_file(config_file))["data"] + # This is a reason why accuracy using this module is slightly different config["max_num_var"] = 1 << 30 dataset = Dataset(config["test_file"], config) return dataset @@ -188,7 +189,7 @@ def func_no_disappear_body_not_in_train_acc(preds, results, test_metas): } NAME_METRICS = { - "accuracy": acc, + "acc": acc, "body_in_train_acc": body_in_train_acc, "body_not_in_train_acc": body_not_in_train_acc, "no_disappear_acc": no_disappear_acc, @@ -204,14 +205,15 @@ def evaluate(dataset, results, type_metrics, name_metrics): pred_names, ref_names, pred_types, ref_types = [], [], [], [] test_meta_types, test_meta_names = [], [] examples_w_structs = [] - num_functions, num_all_disappear, num_no_disappear = 0, 0, 0 + num_functions, num_all_disappear, num_no_disappear, num_filtered_variables = 0, 0, 0, 0 + for example in tqdm(dataset): # one example is one function: check if all variables are disappear or if all variables are actual variables all_disappear = True no_disappear = True - for tgt_type in example.tgt_var_types_str: - if dataset.dataset.vocab.types.id2word[dataset.dataset.vocab.types[tgt_type]] == "disappear": + for tgt_type_name in example.tgt_var_types_str: + if tgt_type_name == "disappear": no_disappear = False else: all_disappear = False @@ -222,31 +224,31 @@ def evaluate(dataset, results, type_metrics, name_metrics): if no_disappear: num_no_disappear += 1 - for src_name, tgt_name, tgt_type in zip( + for src_name, tgt_name, tgt_type_name in zip( example.src_var_names, example.tgt_var_names, example.tgt_var_types_str ): - pred_type, _ = ( + pred_type_name, _ = ( results.get(example.binary, {}) .get(example.name, {}) .get(src_name[2:-2], ("", "")) ) - pred_types.append(pred_type) - ref_types.append(tgt_type) + pred_types.append(pred_type_name) + ref_types.append(tgt_type_name) test_meta = example.test_meta.copy() - test_meta["is_struct"] = 
dataset.dataset.vocab.types.id2word[ - dataset.dataset.vocab.types[tgt_type] - ].startswith("struct ") + test_meta["is_struct"] = tgt_type_name.startswith("struct ") if test_meta["is_struct"]: examples_w_structs.append(example.binary) - test_meta["is_disappear"] = dataset.dataset.vocab.types.id2word[ - dataset.dataset.vocab.types[tgt_type] - ].startswith("disappear") + test_meta["is_disappear"] = tgt_type_name.startswith("disappear") test_meta["func_all_disappear"] = all_disappear test_meta["func_no_disappear"] = no_disappear test_meta_types.append(test_meta) - if src_name != tgt_name and tgt_name != "@@@@": + + # Note: This is why accuracy metrics slightly differ from the + # prediction step. + + if src_name != tgt_name and tgt_name != "@@@@" and tgt_name != "@@disappear@@": # only report need_rename _, pred_name = ( results.get(example.binary, {}) @@ -256,6 +258,9 @@ def evaluate(dataset, results, type_metrics, name_metrics): pred_names.append(pred_name) ref_names.append(tgt_name[2:-2]) test_meta_names.append(test_meta) + else: + num_filtered_variables += 1 + #print(f"Warning: Skipping {src_name} {tgt_name}") pred_types = np.array(pred_types, dtype=object) ref_types = np.array(ref_types, dtype=object) @@ -271,7 +276,9 @@ def evaluate(dataset, results, type_metrics, name_metrics): wandb.log( { - "total variables": len(test_meta_types), + "total variables with types": len(test_meta_types), + "total variables with names": len(test_meta_names), + "filtered variables names": num_filtered_variables, "num structs": struct_counter, "num disappear": disappear_counter } diff --git a/dirty/utils/gen-pred.py b/dirty/utils/gen-pred.py new file mode 100644 index 0000000..2e1215c --- /dev/null +++ b/dirty/utils/gen-pred.py @@ -0,0 +1,35 @@ +# This is a script to generate the pred-mt-ref.json file used by prepare_vis.py +# script for the DIRTY explorer web interface. 
+ +import argparse + +from collections import defaultdict +from utils.evaluate import load_data + +import json +import tqdm + +def add_options(parser): + parser.add_argument("--config-file", type=str, required=True) + parser.add_argument("--output-file", type=str, default="pred_mt_ref.json") + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + add_options(parser) + args = parser.parse_args() + + dataset = load_data(args.config_file) + + d = defaultdict(dict) + + for e in tqdm.tqdm(dataset): + binary = e.binary + func = e.name + body_in_train = e.test_meta['function_body_in_train'] + d[binary][func] = {} + + for srcname, tgtname, tgttyp in zip(e.src_var_names, e.tgt_var_names, e.tgt_var_types_str): + d[binary][func][srcname[2:-2]] = (tgtname, tgttyp, body_in_train) + + with open(args.output_file, "w") as f: f.write(json.dumps(d, indent=2)) + diff --git a/dirty/utils/infer.py b/dirty/utils/infer.py new file mode 100644 index 0000000..cfea919 --- /dev/null +++ b/dirty/utils/infer.py @@ -0,0 +1,186 @@ +""" +Usage: + infer.py [options] CONFIG_FILE INPUT_JSON MODEL_CHECKPOINT + +Options: + -h --help Show this screen. 
+""" + +from collections import defaultdict + +from typing import Optional + +from utils.ghidra_function import Function, CollectedFunction +from utils.ghidra_types import TypeLib, TypeInfo +from utils.ghidra_variable import Location, Stack, Register, Unknown, Variable +from utils.dataset import Example, Dataset +from utils.code_processing import canonicalize_code + +from model.model import TypeReconstructionModel + +import _jsonnet + +from docopt import docopt +import json + +import torch + +# Specialized version of dataset-gen-ghidra/decompiler/dump_trees.py + +def ghidra_obtain_cf(ghidra_func): + + from ghidra.app.decompiler import DecompInterface + + # Specialized version from collect.py + + def collect_variables(variables): + + collected_vars = defaultdict(set) + + for v in variables: + if v.getName() == "": + continue + + typ: TypeInfo = TypeLib.parse_ghidra_type(v.getDataType()) + + loc: Optional[Location] = None + storage = v.getStorage() + + if storage.isStackStorage(): + loc = Stack(storage.getStackOffset()) + elif storage.isRegisterStorage(): + loc = Register(storage.getRegister().getName()) + else: + loc = Unknown(storage.toString()) + + assert loc is not None + + collected_vars[loc].add(Variable(typ=typ, name=v.getName(), user=False)) + + return collected_vars + + decomp = DecompInterface() + decomp.toggleSyntaxTree(False) + decomp.openProgram(currentProgram()) + + decomp_results = decomp.decompileFunction(ghidra_func, 30, None) + + if not decomp_results.decompileCompleted(): + raise RuntimeError("Failed to decompile") + + if decomp_results.getErrorMessage() != "": + raise RuntimeError("Failed to decompile") + + high_func = decomp_results.getHighFunction() + lsm = high_func.getLocalSymbolMap() + symbols = [v for v in lsm.getSymbols()] + func_return = high_func.getFunctionPrototype().getReturnType() + + name: str = ghidra_func.getName() + + return_type = TypeLib.parse_ghidra_type(func_return) + + arguments = collect_variables( + [v for v in symbols if 
v.isParameter()], + ) + local_vars = collect_variables( + [v for v in symbols if not v.isParameter()], + ) + + raw_code = decomp_results.getCCodeMarkup().toString() + + decompiler = Function( + ast=None, + name=name, + return_type=return_type, + arguments=arguments, + local_vars=local_vars, + raw_code=raw_code, + ) + + cf = CollectedFunction( + ea=ghidra_func.getEntryPoint().toString(), + debug=None, + decompiler=decompiler, + ) + + return cf + +def infer(config, model, cf, binary_file=None): + + example = Example.from_cf( + cf, prediction=True, binary_file=binary_file, max_stack_length=1024, max_type_size=1024 + ) + #print(example) + + assert example.is_valid_example, "Not a valid example, it probably has no variables" + + canonical_code = canonicalize_code(example.raw_code) + example.canonical_code = canonical_code + #print(example.canonical_code) + + # Create a dummy Dataset so we can call .annotate + dataset = Dataset(config["data"]["test_file"], config["data"]) + + #print(f"example src: {example.source}") + #print(f"example target: {example.target}") + + example = dataset._annotate(example) + + collated_example = dataset.collate_fn([example]) + collated_example, _garbage = collated_example + #print(collated_example) + + #tensor = torch.tensor([collated_example]) + #print(tensor) + + #single_example_loader = DataLoader([collated_example], batch_size=1) + + #trainer = pl.Trainer() + #wat = trainer.predict(model, single_example_loader) + #print(wat) + + with torch.no_grad(): + output = model(collated_example, return_non_best=True) + + var_names = [x[2:-2] for x in example.src_var_names] + pred_names = output['rename_preds'] + pred_types = output['retype_preds'] + + all_pred_names = output['all_rename_preds'] + all_pred_types = output['all_retype_preds'] + + def make_model_output(var_names, pred_names, pred_types): + return {oldname: (newtype, newname) for (oldname, newname, newtype) in zip(var_names, pred_names, pred_types)} + + model_output = 
make_model_output(var_names, pred_names, pred_types) + #{oldname: (newtype, newname) for (oldname, newname, newtype) in zip(var_names, pred_names, pred_types)} + + # Multi-predictions from beam search + # This is currently a list of mappings. But maybe it should be a mapping to lists? + model_output_multi = [make_model_output(var_names, pred_names, pred_types) for pred_names, pred_types in zip(all_pred_names, all_pred_types)] + + other_outputs = {k:v for k,v in output.items() if k not in ["rename_preds", "retype_preds"]} + + return model_output, model_output_multi, example.other_info, other_outputs + +def main(args): + + config = json.loads(_jsonnet.evaluate_file(args["CONFIG_FILE"])) + + json_dict = json.load(open(args["INPUT_JSON"], "r")) + # print(json_dict) + cf = CollectedFunction.from_json(json_dict) + #print(cf) + + model = TypeReconstructionModel.load_from_checkpoint(checkpoint_path=args["MODEL_CHECKPOINT"], config=config) + model.eval() + + model_output = infer(config, model, cf, binary_file=args['INPUT_JSON']) + + print(f"The model output is: {model_output[0]}") + + +if __name__ == "__main__": + args = docopt(__doc__) + main(args) diff --git a/dirty/utils/prepare_vis.py b/dirty/utils/prepare_vis.py new file mode 100644 index 0000000..247b0d6 --- /dev/null +++ b/dirty/utils/prepare_vis.py @@ -0,0 +1,182 @@ +import os +import html +import argparse +import re +import json +import random +import gzip +import pickle as pkl +from subprocess import Popen, PIPE +from multiprocessing import Pool +from tqdm import tqdm +from lexer import Lexer, Token +from collections import defaultdict + +def add_options(parser): + parser.add_argument("--pred", type=str, required=True) + parser.add_argument("--ref", type=str, required=True) + parser.add_argument("--bin-mapping", type=str, required=True) + parser.add_argument("--bins-path", type=str, required=True) + parser.add_argument("--ida-output-path", type=str, required=True) + parser.add_argument("--preprocessed-path", 
type=str, required=True) + parser.add_argument("--output", type=str, required=True) + parser.add_argument("--not-train", action="store_true") + parser.add_argument("--struct", action="store_true") + +def get_all_funcs(pred, ref): + all_funcs = set() + #assert pred.keys() == ref.keys() + missing = set(pred.keys()) - set(ref.keys()) + assert len(missing) == 0 + for binary in pred: + missing = set(pred[binary].keys()) - set(ref[binary].keys()) + assert len(missing) == 0 + #assert pred[binary].keys() == ref[binary].keys() + for func_name in pred[binary]: + all_funcs.add((binary, func_name)) + + return all_funcs + +def eval(pred, ref, funcs): + count = 0 + correct = 0 + for binary, func_name in funcs: + for (src_name, src_type), (tgt_name, tgt_type, body_in_train) in zip(pred[binary][func_name], ref[binary][func_name]): + count += 1 + correct += src_type == tgt_type + return correct / count + +def get_binary_info(func, meta, bins_path): + binary, func_name = func + # get disassembler results + if "path" not in meta: + return meta + bin_path = os.path.join(bins_path, meta["path"]) + all_dump = os.popen(f"objdump -d {bin_path}").read().split("\n\n") + func_dump = [f for f in all_dump if f"<{func_name}>" in f.split("\n")[0]] + if len(func_dump) >= 1: + func_dump = func_dump[0] + meta["objdump"] = func_dump + else: + meta["objdump"] = "" + + return meta + +def format_code(code): + # Requires clang-format + p = Popen("clang-format-13", stdout=PIPE, stdin=PIPE) + ret = p.communicate(input=code.encode("utf-8"))[0] + return ret.decode() + +# var = re.compile() + +def prepare_highlight_var(code): + return re.sub(r"@@(\w+)@@", r"e6ff4de4\g<1>4ed4ff6e", code) + +# xvar = re.compile() +def highlight_var(code): + return re.sub(r"e6ff4de4(\w+)4ed4ff6e", lambda m: '' + m.group(1) + '', code) + +def tokenize_raw_code(raw_code): + lexer = Lexer(raw_code) + tokens = [] + for token_type, token in lexer.get_tokens(): + if token_type in Token.Literal: + token = str(token_type).split('.')[2] 
+ + tokens.append(token) + + return tokens + +def get_preprocessed_code(func, pred, ref, preprocessed_path, only_struct=False): + binary, func_name = func + with open(os.path.join(preprocessed_path, f"{binary}_{binary}.jsonl")) as f: + for line in f: + json_line = json.loads(line) + if json_line["name"] == func_name: + varnames = set(name for name in pred[binary][func_name] if not only_struct or ref[binary][func_name][name][1].startswith("struc")) + code = json_line["code_tokens"] + code = map(lambda x: x[2:-2] if x.startswith("@@") and x[2:-2] not in varnames else x, code) + code = " ".join(code) + return highlight_var(format_code(prepare_highlight_var(html.escape(code)))) + return "" + +def get_debug_code(func, ref, ida_output_path, only_struct=False): + binary, func_name = func + with gzip.open(os.path.join(ida_output_path, f"{binary}_{binary}.jsonl.gz"), "rt") as f: + for line in f: + json_line = json.loads(line) + if json_line["b"]["n"] == func_name: + code = tokenize_raw_code(json_line["b"]["c"]) + varnames = set(name[2:-2] for name, typ, _ in ref[binary][func_name].values() if not only_struct or typ.startswith("struc")) + code = map(lambda x: f"@@{x}@@" if x in varnames else x, code) + code = " ".join(code) + return highlight_var(format_code(prepare_highlight_var(html.escape(code)))) + return "" + + +def main(args): + func, meta, bins_path, ida_output_path, preprocessed_path, pred, ref, only_struct = args + info = get_binary_info(func, meta, bins_path) + info["code_s"] = get_preprocessed_code(func, pred, ref, preprocessed_path, only_struct) + info["code_t"] = get_debug_code(func, ref, ida_output_path, only_struct) + info["var"] = [] + binary, func_name = func + for src_name in pred[binary][func_name]: + print(f"{binary} {func_name} {src_name}") + assert src_name in pred[binary][func_name], "pred" + assert src_name in ref[binary][func_name], f"Unable to find src_name {src_name} in ref {binary} {func_name}. 
ref keys: {ref[binary][func_name].keys()} pred keys: {pred[binary][func_name].keys()}" + (src_type, src_name_pred), (tgt_name, tgt_type, body_in_train) = pred[binary][func_name][src_name], ref[binary][func_name][src_name] + info["body_in_train"] = body_in_train + if tgt_type.startswith("struc") or not only_struct: + + d = {"name": src_name.replace("@", ""), "type": src_type.replace("", "__unk__"), "pred_name": src_name_pred.replace("", "__unk__"), "ref_name": tgt_name.replace("@", ""), "ref_type": tgt_type.replace("", "__unk__")} + for k in d.keys(): + d[k] = html.escape(d[k]) + info["var"].append(d) + + return info + +def sample(all_funcs, num, pred, ref, only_not_in_train=False, only_struct=False): + if not only_not_in_train and not only_struct: + return random.sample(all_funcs, num) + ret = [] + while len(ret) < num: + binary, func_name = random.sample(all_funcs, 1)[0] + valid = True + has_struc = False + for src_name in pred[binary][func_name]: + (src_type, src_name_pred), (tgt_name, tgt_type, body_in_train) = pred[binary][func_name][src_name], ref[binary][func_name][src_name] + if only_not_in_train and body_in_train: + valid = False + if tgt_type.startswith("struc"): + has_struc = True + if only_struct and not has_struc: + valid = False + valid = os.path.exists(os.path.join("/home/jlacomis/direoutput-new/bins", f"{binary}_{binary}.jsonl.gz")) + if valid: + ret.append((binary, func_name)) + else: + print(f"missing {binary} {func_name}") + return ret + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + add_options(parser) + args = parser.parse_args() + + pred = json.load(open(args.pred)) + ref = json.load(open(args.ref)) + all_funcs = get_all_funcs(pred, ref) + sampled_funcs = sample(all_funcs, 100, pred, ref, args.not_train, args.struct) + + bin_mapping = pkl.load(open(args.bin_mapping, "rb")) + + with Pool(processes=1) as pool: + ret = pool.map( + main, + ((func, bin_mapping.get(func[0], defaultdict(str)), args.bins_path, 
args.ida_output_path, args.preprocessed_path, pred, ref, args.struct) for func in sampled_funcs), + chunksize=4, + ) + json.dump(ret, open(args.output, "w")) diff --git a/dirty/utils/preprocess.py b/dirty/utils/preprocess.py index ceac773..64b9b39 100644 --- a/dirty/utils/preprocess.py +++ b/dirty/utils/preprocess.py @@ -8,7 +8,6 @@ --max= max dataset size [default: -1] --shard-size= shard size [default: 5000] --test-file= test file - --no-filtering do not filter files """ import glob @@ -20,8 +19,6 @@ import sys import tarfile from json import dumps -from multiprocessing import Process -from typing import Tuple import numpy as np import ujson as json @@ -29,7 +26,7 @@ from tqdm import tqdm from utils.dataset import Example -from utils.ghidra_types import TypeInfo, TypeLib, TypeLibCodec +from utils.ghidra_types import TypeLib, TypeLibCodec from utils.ghidra_function import CollectedFunction from utils.code_processing import canonicalize_code @@ -48,9 +45,13 @@ def example_generator(json_str_list): except ValueError: print(json_str, file=sys.stderr) continue - + example = Example.from_cf( - cf, binary_file=meta, max_stack_length=1024, max_type_size=1024 + cf, + prediction=False, + binary_file=meta, + max_stack_length=1024, + max_type_size=1024, ) if example.is_valid_example: @@ -81,22 +82,6 @@ def json_line_reader(args): return func_json_list -def type_dumper(args): - tgt_folder, fname = args - typelib = TypeLib() - with open(fname, "r") as f: - for line in f: - e = Example.from_json(json.loads(line)) - for var in e.target.values(): - typelib.add(var.typ) - typelib.sort() - with open( - os.path.join(tgt_folder, "types", fname.split("/")[-1]), "w" - ) as type_lib_file: - encoded = TypeLibCodec.encode(typelib) - type_lib_file.write(encoded) - - def main(args): np.random.seed(1234) random.seed(1992) @@ -129,15 +114,15 @@ def main(args): print("loading examples") with multiprocessing.Pool(num_workers) as pool: - json_iter = pool.imap( + json_iter = pool.imap_unordered( 
json_line_reader, ((input_folder, fname) for fname in input_fnames), chunksize=64, ) - example_iter = pool.imap(example_generator, json_iter, chunksize=64) + example_iter = pool.imap_unordered(example_generator, json_iter, chunksize=64) - for examples in tqdm(example_iter): + for examples in tqdm(example_iter, desc="Writing output", total=len(input_fnames)): if not examples: continue json_file_name = examples[0].binary_file["file_name"].split("/")[-1] @@ -148,6 +133,15 @@ def main(args): example.name ] = example.canonical_code + # Symlink the type file from the unprocessed folder to the preprocessed folder. + + base_file_name = os.path.splitext(json_file_name)[0] + type_file_name = base_file_name + ".json.gz" + + input_type_file = os.path.join(input_folder, "types", type_file_name) + + os.symlink(input_type_file, os.path.join(tgt_folder, "types", type_file_name)) + valid_example_count += len(examples) print("valid examples: ", valid_example_count) @@ -181,34 +175,36 @@ def main(args): test_files_set = set(test_files) train_files = [fname for fname in all_files if fname not in test_files_set] - if dev_file_num == 0: + if dev_file_num == 0 and not test_file: dev_file_num = int(len(train_files) * 0.1) np.random.shuffle(train_files) - dev_files = train_files[-dev_file_num:] - train_files = train_files[:-dev_file_num] + dev_files = train_files[-dev_file_num:] if dev_file_num > 0 else [] + train_files = train_files[:-dev_file_num] if dev_file_num > 0 else train_files # Create types from filtered training set - with multiprocessing.Pool(num_workers) as pool: - pool.map( - type_dumper, - ((tgt_folder, fname) for fname in train_files), - chunksize=64, - ) print("reading typelib") typelib = TypeLib() for fname in tqdm(train_files): fname = os.path.basename(fname) - fname = fname[: fname.index(".")] + ".jsonl" - typelib.add_json_file(os.path.join(tgt_folder, "types", fname)) - typelib.prune(5) + fname = fname[: fname.index(".")] + ".json.gz" + 
typelib.add_json_file(os.path.join(tgt_folder, "types", fname), ungzip=True) + typelib.sort() print("dumping typelib") - with open(os.path.join(tgt_folder, "typelib.json"), "w") as type_lib_file: + with open(os.path.join(tgt_folder, "typelib_complete.json"), "w") as type_lib_file: encoded = TypeLibCodec.encode(typelib) type_lib_file.write(encoded) + # Prune the type library. This may remove subtypes of course. + if not test_file: typelib.prune(5) + + print("dumping pruned typelib") + with open(os.path.join(tgt_folder, "typelib.json"), "w") as pruned_type_lib_file: + encoded = TypeLibCodec.encode(typelib) + pruned_type_lib_file.write(encoded) + train_functions = dict() for train_file in train_files: file_name = train_file.split("/")[-1] diff --git a/dirty/utils/sample_test.py b/dirty/utils/sample_test.py new file mode 100644 index 0000000..1929b71 --- /dev/null +++ b/dirty/utils/sample_test.py @@ -0,0 +1,193 @@ +# Sample a few test cases from the prediction json file. + +import json +import hjson +import argparse +import multiprocessing +import random +import json +import tempfile +import os +import shutil +import subprocess +import traceback +import tqdm +from utils.vocab import Vocab, UNKNOWN_ID + +BIN_DIR = "/home/ed/Projects/DIRTY/dirt-binaries/all-binaries" + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description="Sample a few test cases from the dataset.") + parser.add_argument("-n", "--num_samples", type=int, default=50, help="Number of samples to be generated") + + parser.add_argument("json_file", help="Path to the JSON file") + parser.add_argument("-o", "--output_file", help="Path to the output JSON file") + parser.add_argument("--vocab", help="Path to the vocab file", default="data1/vocab.bpe10000") + args = parser.parse_args() + json_file = args.json_file + + with open(json_file, "r") as f: + data = json.load(f) + + sampled_bins = random.sample(list(data.keys()), args.num_samples) + + sampled_functions = 
list((bin,random.choice(list(data[bin].items()))) for bin in sampled_bins) + + vocab = Vocab.load(args.vocab) + + # if args.output_file: + # with open(args.output_file, "w") as f: + # json.dump(dict(sampled_functions), f, indent=4) + # else: + # print(json.dumps(dict(sampled_functions), indent=4), end="\n") + + def worker(input): + bin, (func, jsondata) = input + assert func.startswith("FUN_"), f"Function name {func} does not start with FUN_" + funcaddr = func.replace("FUN_", "0x") + #print(f"Bin: {bin}, Function: {func}") + + temp_dir = tempfile.mkdtemp() + #print(f"Temporary directory: {temp_dir}") + + temp_json_file = tempfile.NamedTemporaryFile(suffix=".json", delete=False) + temp_json_file_name = temp_json_file.name + temp_json_file.close() + + # XXX Capture logs + + symbol_log = subprocess.check_output([ + "%s/support/analyzeHeadless" % os.environ["GHIDRA_INSTALL_DIR"], + temp_dir, + "DummyProject", + "-readOnly", + "-import", + os.path.join(BIN_DIR, bin), + "-postScript", + "/home/ed/Projects/DIRTY/DIRTY-ghidra/scripts/DIRTY_eval.py", + "/home/ed/Projects/DIRTY/DIRTY-ghidra/dirty/pred_mt.json", + bin, + funcaddr, + temp_json_file_name + ], stderr=subprocess.STDOUT) + + try: + with open(temp_json_file_name, "r") as temp_json_file: + symbol_data = json.loads(temp_json_file.read()) + except Exception as e: + return {"exception": str(e), "log": symbol_log.decode("utf-8")} + + # Nothing to rewrite + del symbol_data['rewritten_decompilation'] + + #print("temp data: %s", str(temp_data)) + + # Copy the binary to a temporary file name + temp_bin_file = tempfile.NamedTemporaryFile(delete=False) + temp_bin_file_name = temp_bin_file.name + temp_bin_file.close() + shutil.copyfile(os.path.join(BIN_DIR, bin), temp_bin_file_name) + + # Strip the binary + subprocess.check_call(["strip", temp_bin_file_name]) + + # print(f"Stripped binary: {temp_bin_file_name}") + + # Delete the json so we don't accidentally use the old results + os.truncate(temp_json_file_name, 0) + + # Run 
Ghidra on the stripped version + # XXX Capture logs + strip_log = subprocess.check_output([ + "%s/support/analyzeHeadless" % os.environ["GHIDRA_INSTALL_DIR"], + temp_dir, + "DummyProject2", + "-readOnly", + "-import", + temp_bin_file_name, + "-postScript", + "/home/ed/Projects/DIRTY/DIRTY-ghidra/scripts/DIRTY_eval.py", + "/home/ed/Projects/DIRTY/DIRTY-ghidra/dirty/pred_mt.json", + bin, + funcaddr, + temp_json_file_name + ], stderr=subprocess.STDOUT) + + try: + with open(temp_json_file_name, "r") as temp_json_file: + strip_data = json.loads(temp_json_file.read()) + except Exception as e: + return {"exception": str(e), "log": strip_log.decode("utf-8")} + + # Delete the temporary binary file + os.remove(temp_bin_file_name) + + # Delete the temporary file + os.remove(temp_json_file_name) + + d = {"strip": strip_data, "symbol": symbol_data} + + d["bin"] = bin + d["func"] = func + + d["strip"]["aligned_vars"] = list(set(d["strip"]["vars"]) & set(jsondata.keys())) + + d["strip"]["aligned_frac"] = len(d["strip"]["aligned_vars"]) / len(d["strip"]["vars"]) + + d["symbol"]["unknown_var_names"] = [v for v in d["symbol"]["vars"] if vocab.names[v] == UNKNOWN_ID] + + # d["symbol"]["origvars"] = [v[1] for v in jsondata.values()] + + # d["symbol"]["aligned_vars"] = list(set(d["symbol"]["vars"]) & {v[1] for v in jsondata.values()}) + # d["symbol"]["aligned_frac"] = len(d["symbol"]["aligned_vars"]) / len(d["symbol"]["vars"]) + + if False: + d["strip"]["log"] = strip_log.decode("utf-8") + d["symbol"]["log"] = symbol_log.decode("utf-8") + + d["strip"]["aligned_frac"] = len(d["strip"]["aligned_vars"]) / len(d["strip"]["vars"]) + + d["predictions"] = jsondata + + d["name_predictions"] = {k: v for k, v in jsondata.items() if v[1] != "" and k != v[1]} + d["type_predictions"] = {k: v for k, v in jsondata.items() if v[0] != ""} + + return d + + def worker_catch(args): + try: + return worker(args) + except Exception as e: + traceback.print_exc() + + return json.dumps({"exception": str(e)}) + 
+ with multiprocessing.Pool(4) as pool: + results = list(tqdm.tqdm(pool.imap_unordered(worker_catch, list(sampled_functions)), total=args.num_samples)) + + total_variables_wo_symbols = sum(len(ex['strip']['vars']) for ex in results if "exception" not in ex) + 0.01 + total_variables_w_symbols = sum(len(ex['symbol']['vars']) for ex in results if "exception" not in ex) + 0.01 + #total_predictions = sum(len(ex['predictions']) for ex in results if "exception" not in ex) + total_name_predictions = sum(len(ex['name_predictions']) for ex in results if "exception" not in ex) + total_type_predictions = sum(len(ex['type_predictions']) for ex in results if "exception" not in ex) + + results = {"examples": results} + results["total_variables_wo_symbols"] = total_variables_wo_symbols + results["total_variables_w_symbols"] = total_variables_w_symbols + + #results["total_predictions"] = total_predictions + results["total_name_predictions"] = total_name_predictions + results["total_type_predictions"] = total_type_predictions + + #results["total_prediction_frac"] = total_predictions / total_variables_wo_symbols + results["total_name_prediction_frac"] = total_name_predictions / total_variables_wo_symbols + results["total_type_prediction_frac"] = total_type_predictions / total_variables_wo_symbols + + results["total_valid_ex"] = len([ex for ex in results["examples"] if "exception" not in ex]) + + if args.output_file: + with open(args.output_file, "w") as f: + hjson.dump(results, f, indent=2) + else: + print(hjson.dumps(results, indent=2), end="\n") diff --git a/dirty/utils/vocab.py b/dirty/utils/vocab.py index 13cf79c..1643b3f 100644 --- a/dirty/utils/vocab.py +++ b/dirty/utils/vocab.py @@ -4,19 +4,20 @@ vocab.py [options] TRAIN_FILE TYPE_FILE VOCAB_FILE Options: - -h --help Show this screen. - --use-bpe Use bpe - --size= vocab size [default: 10000] - --freq-cutoff= frequency cutoff [default: 5] + -h --help Show this screen. 
+ --use-bpe Use bpe + --size= vocab size [default: 15000] + --character-coverage= character coverage [default: 0.9995] + --freq-cutoff= frequency cutoff [default: 5] + --make-tokens-for-ids Make tokens for ids (only use for IDA) """ from collections import Counter from itertools import chain -import torch -import pickle from docopt import docopt import json +import os import sentencepiece as spm from tqdm import tqdm @@ -27,7 +28,7 @@ SAME_VARIABLE_TOKEN = "" END_OF_VARIABLE_TOKEN = "" PAD_ID = 0 -assert PAD_ID == 0 +UNKNOWN_ID = 3 class VocabEntry: @@ -36,7 +37,7 @@ def __init__(self, subtoken_model_path=None): self.subtoken_model_path = subtoken_model_path if subtoken_model_path: - print(subtoken_model_path) + # print(subtoken_model_path) self.subtoken_model = spm.SentencePieceProcessor() self.subtoken_model.load(subtoken_model_path) @@ -91,7 +92,11 @@ def params(self): params = dict( unk_id=self.unk_id, word2id=self.word2id, - subtoken_model_path=self.subtoken_model_path, + subtoken_model_path=( + os.path.basename(self.subtoken_model_path) + if self.subtoken_model_path is not None + else None + ), ) if hasattr(self, "word_freq"): params["word_freq"] = self.word_freq @@ -102,15 +107,20 @@ def save(self, path): json.dump(self.params, open(path, "w"), indent=2) @classmethod - def load(cls, path=None, params=None): + def load(cls, path=None, dir=None, params=None): if path: - print(path) params = json.load(open(path, "r")) else: assert params, "Params must be given when path is None!" - if "subtoken_model_path" in params: - subtoken_model_path = params["subtoken_model_path"] + if params.get("subtoken_model_path", None) is not None: + assert dir is not None + # temporary: for backward compatibility + # the submodel path is always in the same dir as the main vocab. 
+ # but older versions don't do the basename in advance, so we'll do + # it here + params["subtoken_model_path"] = os.path.basename(params["subtoken_model_path"]) + subtoken_model_path = os.path.join(dir, params["subtoken_model_path"]) else: subtoken_model_path = None @@ -213,13 +223,14 @@ def save(self, path): @classmethod def load(cls, path): + dir = os.path.dirname(os.path.realpath(path)) params = json.load(open(path, "r")) entries = dict() for key, val in params.items(): # if key in ('grammar', ): # entry = Grammar.load(val) # else: - entry = VocabEntry.load(params=val) + entry = VocabEntry.load(params=val, dir=dir) entries[key] = entry return cls(**entries) @@ -229,6 +240,8 @@ def load(cls, path): args = docopt(__doc__) vocab_size = int(args["--size"]) + character_coverage = float(args["--character-coverage"]) + make_tokens_for_ids = args["--make-tokens-for-ids"] vocab_file = args["VOCAB_FILE"] type_file = args["TYPE_FILE"] train_set = Dataset(args["TRAIN_FILE"]) @@ -257,7 +270,7 @@ def load(cls, path): ) src_code_tokens_file = vocab_file + ".src_code_tokens.txt" - preserved_tokens = set() + identifier_names = set() name_counter = Counter() reg_counter = Counter() with open(src_code_tokens_file, "w") as f_src_token: @@ -270,10 +283,11 @@ def load(cls, path): filter(lambda x: isinstance(x, Register), example.target.keys()), ) ) - name_counter.update(map(lambda x: x.name, example.target.values())) + for varlist in example.target.values(): + name_counter.update(map(lambda x: x.name, varlist)) for token in code_tokens: if token.startswith("@@") and token.endswith("@@"): - preserved_tokens.add(token) + identifier_names.add(token) f_src_token.write(" ".join(code_tokens) + "\n") name_vocab_entry = VocabEntry.from_counter( name_counter, size=len(name_counter), freq_cutoff=int(args["--freq-cutoff"]) @@ -282,16 +296,31 @@ def load(cls, path): reg_counter, size=len(reg_counter), freq_cutoff=int(args["--freq-cutoff"]) ) - assert args["--use-bpe"] - print("use bpe") + 
assert args["--use-bpe"], "Please use BPE for tokenization" print("building source code tokens vocabulary") # train subtoken models - preserved_tokens = ",".join(preserved_tokens) + + if args["--make-tokens-for-ids"]: + identifier_names = ",".join(identifier_names) + else: + identifier_names = "" + + # DIRTY for Hex-Rays made each source variable identifier a preserved token. + # Since the identifiers were mainly a1, a2, v1, v2, etc., this was fine. + # But Ghidra likes to preface addresses and offsets to automatic variable + # names, e.g., local_8110, so we SHOULD use subwords for these. If we have + # more than 1000 preserved tokens, its a sign that something has gone wrong. + print(f"There are {len(identifier_names)} preserved tokens (variable names)") + assert ( + len(identifier_names) < 1000 + ), "There are too many preserved tokens. If you used --make-tokens-for-ids, turn it off. If you did not, please file a bug report." + spm.SentencePieceTrainer.Train( f"--add_dummy_prefix=false --pad_id={PAD_ID} --bos_id=1 --eos_id=2 --unk_id=3 " - f"--user_defined_symbols={preserved_tokens} " + f"--user_defined_symbols={identifier_names} " f"--vocab_size={vocab_size} " + f"--character_coverage={character_coverage} " f"--model_prefix={vocab_file}.src_code_tokens --model_type=bpe " f"--input={src_code_tokens_file}" ) diff --git a/requirements.txt b/requirements.txt index c940cf6..42d021d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,14 @@ -docopt==0.6.2 -editdistance==0.5.3 -jsonnet==0.16.0 -pygments==2.6.1 -pytorch-lightning==1.0.8 -sentencepiece==0.1.91 -ujson==3.2.0 -wandb==0.10.12 -webdataset==0.1.40 -pyelftools==0.29 -pandas -openpyxl +docopt~=0.6.2 +# only needed for dire eval? 
+#editdistance~=0.5.3 +jsonnet~=0.16.0 +pygments~=2.6.1 +pytorch-lightning +torchmetrics +sentencepiece~=0.1.91 +ujson~=3.2.0 +wandb +webdataset~=0.2.100 +pyelftools~=0.29 +jsonlines +numpy>=1.20,<2.0 diff --git a/scripts/DIRTY_eval.py b/scripts/DIRTY_eval.py new file mode 100644 index 0000000..5135474 --- /dev/null +++ b/scripts/DIRTY_eval.py @@ -0,0 +1,289 @@ +# Evaluation helper script for DIRTY Ghidra +from functools import lru_cache +from ghidra.app.decompiler import DecompInterface +from ghidra.util.task import ConsoleTaskMonitor +from ghidra.program.model.pcode import HighFunctionDBUtil +from ghidra.program.model.symbol import SourceType +from ghidra.program.model.data import PointerDataType, ArrayDataType, StructureDataType, UnionDataType, TypedefDataType +from ghidra.program.model.data import BuiltInDataTypeManager +from ghidra.app.services import DataTypeManagerService +from ghidra.app.plugin.core.analysis import AutoAnalysisManager +import json +import random + +import ghidra_types + +debug = False + +def abort(s): + raise Exception(s) + +TYPELIB_PATH = "/home/ed/Downloads/typelib.json" + +codec = ghidra_types.TypeLibCodec() +typelib = codec.decode(open(TYPELIB_PATH, "r").read()) + +name_to_type = {} + +def all_typenames(): + for _size, typelist in typelib.items(): + for _freq, typeentry in typelist: + yield str(typeentry) + +def find_types_by_name(name): + for size, typelist in typelib.items(): + for _freq, typeentry in typelist: + typename = str(typeentry) + #if size == 160000: # or "double" in typename: + #if debug and name in typename: + # print(f"Hmm: {name} ==? {typename}") + if typename == name: + yield typeentry + return + +def find_type_by_name(name): + try: + return next(find_types_by_name(name)) + except StopIteration: + print(f"Unable to find type {name} in typelib. 
Hopefully it is a built-in!") + return ghidra_types.TypeInfo(name=name, size=0) + +def find_type_in_ghidra_typemanager(name, dtm): + if dtm is None: + dtm = currentProgram().getDataTypeManager() + + al = dtm.getAllDataTypes() + output_list = [dt for dt in al if dt.getName() == name] + + if len(output_list) > 0: + return output_list[0] + else: + return None + +def find_type_in_any_ghidra_typemanager(name): + #dtms = [AutoAnalysisManager.getAnalysisManager(currentProgram()).getDataTypeManagerService()] + dtms = [currentProgram().getDataTypeManager(), BuiltInDataTypeManager.getDataTypeManager()] + #dtms = state().getTool().getService(DataTypeManagerService).getDataTypeManagers() + for dtm in dtms: + output = find_type_in_ghidra_typemanager(name, dtm) + if output is not None: + return output + return None + +@lru_cache(maxsize=20000) +def build_ghidra_type(typelib_type): + #print(f"build_ghidra_type {typelib_type} {typelib_type.__dict__} {type(typelib_type)}") + + # First check our cache. This is important for self-referential types + # (e.g., linked lists). 
+ if str(typelib_type) in name_to_type: + return name_to_type[str(typelib_type)] + + out = find_type_in_any_ghidra_typemanager(str(typelib_type)) + if out is not None: + return out + + if type(typelib_type) == ghidra_types.TypeInfo: + + print(f"WARNING: {typelib_type.name} is a TypeInfo type: {typelib_type.debug}") + print(f"WARNING: Unable to find type {typelib_type.name} in Ghidra.") + + match typelib_type.size: + case 1: + return find_type_in_any_ghidra_typemanager("byte") + case 4: + return find_type_in_any_ghidra_typemanager("uint32_t") + case 8: + return find_type_in_any_ghidra_typemanager("uint64_t") + case _: + abort(f"Unknown type with unusual size: {typelib_type.size} {typelib_type.name}") + + elif type(typelib_type) == ghidra_types.Array: + element_type = build_ghidra_type(find_type_by_name(typelib_type.element_type)) + return ArrayDataType(element_type, typelib_type.nelements, typelib_type.element_size) + elif type(typelib_type) == ghidra_types.Pointer: + target_type = build_ghidra_type(find_type_by_name(typelib_type.target_type_name)) + return PointerDataType(target_type) + # Make type. + elif type(typelib_type) == ghidra_types.Struct or type(typelib_type) == ghidra_types.Union: + new_struct = StructureDataType(typelib_type.name, typelib_type.size) if type(typelib_type) == ghidra_types.Struct else UnionDataType(typelib_type.name) + # We need to immediately make this available in case we have a self-referential type. + name_to_type[str(typelib_type)] = new_struct + offset = 0 + for member in (typelib_type.layout if type(typelib_type) == ghidra_types.Struct else typelib_type.members): + if type(member) == ghidra_types.UDT.Padding: + # Don't do anything? 
+ pass + # new_struct.insertAtOffset(offset, VoidDataType(), member.size) + elif type(member) == ghidra_types.UDT.Field: + member_type = build_ghidra_type(find_type_by_name(member.type_name)) + if type(typelib_type) == ghidra_types.Struct: + new_struct.insertAtOffset(offset, member_type, member.size, member.name, "") + elif type(typelib_type) == ghidra_types.Union: + new_struct.add(member_type, member.size, member.name, "") + else: + abort("Unknown member type: " + str(type(typelib_type))) + else: + abort("Unknown member type: " + str(type(member))) + + offset = offset + member.size + + #field_type = build_ghidra_type(find_type_by_name(field.type)) + #new_struct.add(field_type, field.name, field.comment) + return new_struct + elif type(typelib_type) == ghidra_types.TypeDef: + other_type = build_ghidra_type(find_type_by_name(typelib_type.other_type_name)) + return TypedefDataType(typelib_type.name, other_type) + else: + abort(f"Unknown type: {type(typelib_type)} {typelib_type}") + +def test_types(): + total = 0 + succ = 0 + + l = list(all_typenames()) + random.shuffle(l) + + #l = ["longlong[20][10]"] + #l = ["xen_string_string_map"] + #debug = True + + for typename in l: + print(f"Trying to build type {typename}") + total += 1 + + if monitor().isCancelled(): + break + + monitor().setMessage(f"Building type {typename}") + + try: + ti = find_type_by_name(typename) + except Exception as e: + print(f"Failed to find type {typename} exception: {e}") + continue + + try: + gtype = build_ghidra_type(ti) + assert gtype is not None, "build_ghidra_type returned None." 
+ print(f"Successfully built type {typename} in Ghidra: {gtype}") + except Exception as e: + print(f"Failed to build ghidra type {typename} exception: {e}") + #break + continue + + succ += 1 + + print(f"Successfully built {succ}/{total} {float(succ)/total} types.") + + exit(0) + +print ([x for x in getScriptArgs()]) + +#exeName = currentProgram().getName() + +jsonFile = getScriptArgs()[0] +exeName = getScriptArgs()[1] +targetFunction = getScriptArgs()[2] +targetFunctionAddr = int(targetFunction, 16) +outputJsonFile = getScriptArgs()[3] + +#print("Using JSON file: " + jsonFile) +#print("Target function: " + targetFunctionAddr) + +#jsonFile = askFile("Select JSON file", "Open") +print(f"Parsing JSON for {exeName}") +jsonObj = json.load(open(jsonFile)) + +if exeName in jsonObj: + jsonObj = jsonObj[exeName] +elif len(jsonObj) == 1: + jsonObj = jsonObj[list(jsonObj.keys())[0]] +else: + abort(f"Unable to find the executable {exeName} in the JSON file.") + +#current_location = currentLocation() + +#for f in currentProgram().getFunctionManager().getFunctions(True): +# print(f.getName()) + +func = next(f for f in currentProgram().getFunctionManager().getFunctions(True) if f.getEntryPoint().getOffset() == targetFunctionAddr) + +if func is None: + abort("Unable to find function " + targetFunction) + +print("Found target function: " + func.getName()) + +assert func is not None + +funcName = f"FUN_%08x" % targetFunctionAddr +if funcName in jsonObj: + jsonObj = jsonObj[funcName] +else: + abort(f"Unable to find function %s in the JSON file." 
% funcName) + +dtm = currentProgram().getDataTypeManager() + + # Set up the decompiler +decompiler = DecompInterface() +decompiler.openProgram(func.getProgram()) + +# Decompile the current function +def decompile(func): + print("Decompiling function " + func.getName() + "...") + results = decompiler.decompileFunction(func, 0, ConsoleTaskMonitor()) + if not results.decompileCompleted(): + abort("Decompilation failed.") + + # Get the high-level representation of the function + high_function = results.getHighFunction() + if not high_function: + abort("Failed to get high-level function representation.") + + return results, high_function + +results, high_function = decompile(func) + +output = {} + +output['decompilation'] = results.getDecompiledFunction().getC() +output['vars'] = [v.getName() for v in high_function.getLocalSymbolMap().getSymbols()] +output['params'] = [v.getName() for v in high_function.getLocalSymbolMap().getSymbols() if v.isParameter()] + +# Rename variables (if no symbols) +for var in high_function.getLocalSymbolMap().getSymbols(): + + original_name = var.getName() + + if original_name in jsonObj: + new_type_name, new_name = jsonObj[original_name] + if new_type_name != "disappear" and new_name != "" and new_type_name != "": + print("Renaming " + original_name + " to " + new_name + ".") + + new_type = None + + print(f"Attempting to retype {original_name}/{new_name} to {new_type_name}") + + try: + ti = find_type_by_name(new_type_name) + new_type = build_ghidra_type(ti) + print(f"Changing type of {original_name}/{new_name} to {new_type_name}: {new_type}") + except Exception as e: + print(f"Failed to find or build type {new_type_name} exception: {e}") + + try: + HighFunctionDBUtil.updateDBVariable(var, new_name, new_type, SourceType.USER_DEFINED) + except Exception as e: + print(f"Failed to rename/retype {original_name} to {new_name}/{new_type_name} exception: {e}") + + else: + print("Skipping disappear/unknown variable " + original_name + ".") + else: + 
print("No new name for " + original_name + " in JSON file.") + +# Re-decompile +results, high_function = decompile(func) + +output['rewritten_decompilation'] = results.getDecompiledFunction().getC() + +json.dump(output, open(outputJsonFile, "w"), indent=4) diff --git a/scripts/DIRTY_import.py b/scripts/DIRTY_import.py new file mode 100644 index 0000000..30f7509 --- /dev/null +++ b/scripts/DIRTY_import.py @@ -0,0 +1,248 @@ +# Import results from DIRTY json file into Ghidra. +from functools import lru_cache +from ghidra.app.decompiler import DecompInterface +from ghidra.util.task import ConsoleTaskMonitor +from ghidra.program.model.pcode import HighFunctionDBUtil +from ghidra.program.model.symbol import SourceType +from ghidra.program.model.data import PointerDataType, ArrayDataType, StructureDataType, UnionDataType, TypedefDataType +from ghidra.app.services import DataTypeManagerService +import json +import random + +import ghidra_types + +debug = False + +def abort(s): + raise Exception(s) + +TYPELIB_PATH = "/home/ed/Downloads/typelib.json" + +codec = ghidra_types.TypeLibCodec() +typelib = codec.decode(open(TYPELIB_PATH, "r").read()) + +name_to_type = {} + +def all_typenames(): + for _size, typelist in typelib.items(): + for _freq, typeentry in typelist: + yield str(typeentry) + +def find_types_by_name(name): + for size, typelist in typelib.items(): + for _freq, typeentry in typelist: + typename = str(typeentry) + #if size == 160000: # or "double" in typename: + #if debug and name in typename: + # print(f"Hmm: {name} ==? {typename}") + if typename == name: + yield typeentry + return + +def find_type_by_name(name): + try: + return next(find_types_by_name(name)) + except StopIteration: + print(f"Unable to find type {name} in typelib. 
Hopefully it is a built-in!") + return ghidra_types.TypeInfo(name=name, size=0) + +def find_type_in_ghidra_typemanager(name, dtm): + if dtm is None: + dtm = currentProgram().getDataTypeManager() + + al = dtm.getAllDataTypes() + output_list = [dt for dt in al if dt.getName() == name] + + if len(output_list) > 0: + return output_list[0] + else: + return None + +def find_type_in_any_ghidra_typemanager(name): + dtms = state().getTool().getService(DataTypeManagerService).getDataTypeManagers() + for dtm in dtms: + output = find_type_in_ghidra_typemanager(name, dtm) + if output is not None: + return output + return None + +@lru_cache(maxsize=20000) +def build_ghidra_type(typelib_type): + #print(f"build_ghidra_type {typelib_type} {typelib_type.__dict__} {type(typelib_type)}") + + # First check our cache. This is important for self-referential types + # (e.g., linked lists). + if str(typelib_type) in name_to_type: + return name_to_type[str(typelib_type)] + + out = find_type_in_any_ghidra_typemanager(str(typelib_type)) + if out is not None: + return out + + if type(typelib_type) == ghidra_types.TypeInfo: + + print(f"WARNING: {typelib_type.name} is a TypeInfo type: {typelib_type.debug}") + print(f"WARNING: Unable to find type {typelib_type.name} in Ghidra.") + + match typelib_type.size: + case 1: + return find_type_in_any_ghidra_typemanager("byte") + case 4: + return find_type_in_any_ghidra_typemanager("uint32_t") + case 8: + return find_type_in_any_ghidra_typemanager("uint64_t") + case _: + abort(f"Unknown type with unusual size: {typelib_type.size} {typelib_type.name}") + + elif type(typelib_type) == ghidra_types.Array: + element_type = build_ghidra_type(find_type_by_name(typelib_type.element_type)) + return ArrayDataType(element_type, typelib_type.nelements, typelib_type.element_size) + elif type(typelib_type) == ghidra_types.Pointer: + target_type = build_ghidra_type(find_type_by_name(typelib_type.target_type_name)) + return PointerDataType(target_type) + # Make type. 
+ elif type(typelib_type) == ghidra_types.Struct or type(typelib_type) == ghidra_types.Union: + new_struct = StructureDataType(typelib_type.name, typelib_type.size) if type(typelib_type) == ghidra_types.Struct else UnionDataType(typelib_type.name) + # We need to immediately make this available in case we have a self-referential type. + name_to_type[str(typelib_type)] = new_struct + offset = 0 + for member in (typelib_type.layout if type(typelib_type) == ghidra_types.Struct else typelib_type.members): + if type(member) == ghidra_types.UDT.Padding: + # Don't do anything? + pass + # new_struct.insertAtOffset(offset, VoidDataType(), member.size) + elif type(member) == ghidra_types.UDT.Field: + member_type = build_ghidra_type(find_type_by_name(member.type_name)) + if type(typelib_type) == ghidra_types.Struct: + new_struct.insertAtOffset(offset, member_type, member.size, member.name, "") + elif type(typelib_type) == ghidra_types.Union: + new_struct.add(member_type, member.size, member.name, "") + else: + abort("Unknown member type: " + str(type(typelib_type))) + else: + abort("Unknown member type: " + str(type(member))) + + offset = offset + member.size + + #field_type = build_ghidra_type(find_type_by_name(field.type)) + #new_struct.add(field_type, field.name, field.comment) + return new_struct + elif type(typelib_type) == ghidra_types.TypeDef: + other_type = build_ghidra_type(find_type_by_name(typelib_type.other_type_name)) + return TypedefDataType(typelib_type.name, other_type) + else: + abort(f"Unknown type: {type(typelib_type)} {typelib_type}") + +def test_types(): + total = 0 + succ = 0 + + l = list(all_typenames()) + random.shuffle(l) + + #l = ["longlong[20][10]"] + #l = ["xen_string_string_map"] + #debug = True + + for typename in l: + print(f"Trying to build type {typename}") + total += 1 + + if monitor().isCancelled(): + break + + monitor().setMessage(f"Building type {typename}") + + try: + ti = find_type_by_name(typename) + except Exception as e: + 
print(f"Failed to find type {typename} exception: {e}") + continue + + try: + gtype = build_ghidra_type(ti) + assert gtype is not None, "build_ghidra_type returned None." + print(f"Successfully built type {typename} in Ghidra: {gtype}") + except Exception as e: + print(f"Failed to build ghidra type {typename} exception: {e}") + #break + continue + + succ += 1 + + print(f"Successfully built {succ}/{total} {float(succ)/total} types.") + + exit(0) + +exeName = currentProgram().getName() + +jsonFile = askFile("Select JSON file", "Open") +print("Parsing JSON") +jsonObj = json.load(open(jsonFile.getAbsolutePath())) + +if exeName in jsonObj: + jsonObj = jsonObj[exeName] +elif len(jsonObj) == 1: + jsonObj = jsonObj[list(jsonObj.keys())[0]] +else: + abort("Unable to find the executable in the JSON file.") + +current_location = currentLocation() + +# Get the function containing this location. +current_function = getFunctionContaining(current_location.getAddress()) + +assert current_function is not None + +funcName = current_function.getName() +if funcName in jsonObj: + jsonObj = jsonObj[funcName] +else: + abort("Unable to find function in the JSON file.") + +print(jsonObj) + + # Set up the decompiler +decompiler = DecompInterface() +decompiler.openProgram(current_function.getProgram()) + +# Decompile the current function +print("Decompiling function " + current_function.getName() + "...") +results = decompiler.decompileFunction(current_function, 0, ConsoleTaskMonitor()) +if not results.decompileCompleted(): + abort("Decompilation failed.") + +# Get the high-level representation of the function +high_function = results.getHighFunction() +if not high_function: + abort("Failed to get high-level function representation.") + +# Example: rename a specific variable (change the criteria as needed) +for var in high_function.getLocalSymbolMap().getSymbols(): + + original_name = var.getName() + + if original_name in jsonObj: + new_type_name, new_name = jsonObj[original_name] + if 
new_type_name != "disappear": + print("Renaming " + original_name + " to " + new_name + ".") + + new_type = None + + + print(f"Attempting to retype {original_name}/{new_name} to {new_type_name}") + + try: + ti = find_type_by_name(new_type_name) + new_type = build_ghidra_type(ti) + print(f"Changing type of {original_name}/{new_name} to {new_type_name}: {new_type}") + except Exception as e: + print(f"Failed to find or build type {new_type_name} exception: {e}") + + HighFunctionDBUtil.updateDBVariable(var, new_name, new_type, SourceType.USER_DEFINED) + + + else: + print("Skipping disappear variable " + original_name + ".") + else: + print("No new name for " + original_name + " in JSON file.") diff --git a/scripts/DIRTY_infer.py b/scripts/DIRTY_infer.py new file mode 100644 index 0000000..01cbc54 --- /dev/null +++ b/scripts/DIRTY_infer.py @@ -0,0 +1,375 @@ +# Run inference on current function using DIRTY +from functools import lru_cache +from ghidra.app.decompiler import DecompInterface +from ghidra.util.task import ConsoleTaskMonitor +from ghidra.program.model.pcode import HighFunctionDBUtil +from ghidra.program.model.symbol import SourceType +from ghidra.program.model.data import ( + PointerDataType, + ArrayDataType, + StructureDataType, + UnionDataType, + TypedefDataType, +) +from ghidra.app.services import DataTypeManagerService +from ghidra.program.model.address import Address +import json +import sys +import os +import _jsonnet +import pathlib +import tqdm +import traceback + +DIRTY_PATH = pathlib.Path(os.path.realpath(__file__)).parent.parent.resolve() + +TYPELIB_PATH = os.path.join(DIRTY_PATH, "dirty", "data1", "typelib_complete.json") + +DIRTY_CONFIG = os.path.join(DIRTY_PATH, "dirty", "multitask.xfmr.jsonnet") + +MODEL_CHECKPOINT = os.path.join(DIRTY_PATH, "dirty", "data1", "model.ckpt") + +# Allow loading from the dirty directories. 
+ +sys.path.append(os.path.join(DIRTY_PATH, "dirty")) + +# Load dirty modules + +import utils.ghidra_types +from utils.ghidra_function import Function, CollectedFunction +from utils.ghidra_types import TypeLib, TypeInfo +from utils.ghidra_variable import Location, Register, Stack, Variable + +from model.model import TypeReconstructionModel + +import utils.infer + +debug = False + + +def abort(s): + raise Exception(s) + +codec = utils.ghidra_types.TypeLibCodec() +typelib = codec.decode(open(TYPELIB_PATH, "r").read()) + +name_to_type = {} + + +def all_typenames(): + for _size, typelist in typelib.items(): + for _freq, typeentry in typelist: + yield str(typeentry) + + +def find_types_by_name(name): + for size, typelist in typelib.items(): + for _freq, typeentry in typelist: + typename = str(typeentry) + if typename == name: + yield typeentry + return + + +def find_type_by_name(name): + try: + return next(find_types_by_name(name)) + except StopIteration: + print(f"Unable to find type {name} in typelib. Hopefully it is a built-in!") + return utils.ghidra_types.TypeInfo(name=name, size=0) + + +def find_type_in_ghidra_typemanager(name, dtm): + if dtm is None: + dtm = currentProgram().getDataTypeManager() + + al = dtm.getAllDataTypes() + output_list = [dt for dt in al if dt.getName() == name] + + if len(output_list) > 0: + return output_list[0] + else: + return None + + +def find_type_in_any_ghidra_typemanager(name): + tool = state().getTool() + if tool is not None: + dtms = ( + state().getTool().getService(DataTypeManagerService).getDataTypeManagers() + ) + else: + dtms = [currentProgram().getDataTypeManager()] + for dtm in dtms: + output = find_type_in_ghidra_typemanager(name, dtm) + if output is not None: + return output + return None + + +@lru_cache(maxsize=20000) +def build_ghidra_type(typelib_type): + + # First check our cache. This is important for self-referential types + # (e.g., linked lists). 
+ if str(typelib_type) in name_to_type: + return name_to_type[str(typelib_type)] + + out = find_type_in_any_ghidra_typemanager(str(typelib_type)) + if out is not None: + return out + + if type(typelib_type) == utils.ghidra_types.TypeInfo: + + print(f"WARNING: {typelib_type.name} is a TypeInfo type: {typelib_type.debug}") + print(f"WARNING: Unable to find type {typelib_type.name} in Ghidra.") + + match typelib_type.size: + case 1: + return find_type_in_any_ghidra_typemanager("byte") + case 4: + return find_type_in_any_ghidra_typemanager("uint32_t") + case 8: + return find_type_in_any_ghidra_typemanager("uint64_t") + case _: + abort( + f"Unknown type with unusual size: {typelib_type.size} {typelib_type.name}" + ) + + elif type(typelib_type) == utils.ghidra_types.Array: + element_type = build_ghidra_type(find_type_by_name(typelib_type.element_type)) + return ArrayDataType( + element_type, typelib_type.nelements, typelib_type.element_size + ) + elif type(typelib_type) == utils.ghidra_types.Pointer: + target_type = build_ghidra_type( + find_type_by_name(typelib_type.target_type_name) + ) + return PointerDataType(target_type) + # Make type. + elif ( + type(typelib_type) == utils.ghidra_types.Struct + or type(typelib_type) == utils.ghidra_types.Union + ): + new_struct = ( + StructureDataType(typelib_type.name, typelib_type.size) + if type(typelib_type) == utils.ghidra_types.Struct + else UnionDataType(typelib_type.name) + ) + # We need to immediately make this available in case we have a self-referential type. + name_to_type[str(typelib_type)] = new_struct + offset = 0 + for member in ( + typelib_type.layout + if type(typelib_type) == utils.ghidra_types.Struct + else typelib_type.members + ): + if type(member) == utils.ghidra_types.UDT.Padding: + # Don't do anything? 
+ pass + # new_struct.insertAtOffset(offset, VoidDataType(), member.size) + elif type(member) == utils.ghidra_types.UDT.Field: + member_type = build_ghidra_type(find_type_by_name(member.type_name)) + if type(typelib_type) == utils.ghidra_types.Struct: + new_struct.insertAtOffset( + offset, member_type, member.size, member.name, "" + ) + elif type(typelib_type) == utils.ghidra_types.Union: + new_struct.add(member_type, member.size, member.name, "") + else: + abort("Unknown member type: " + str(type(typelib_type))) + else: + abort("Unknown member type: " + str(type(member))) + + offset = offset + member.size + + # field_type = build_ghidra_type(find_type_by_name(field.type)) + # new_struct.add(field_type, field.name, field.comment) + return new_struct + elif type(typelib_type) == utils.ghidra_types.TypeDef: + other_type = build_ghidra_type(find_type_by_name(typelib_type.other_type_name)) + return TypedefDataType(typelib_type.name, other_type) + else: + abort(f"Unknown type: {type(typelib_type)} {typelib_type}") + + +def do_infer(cf, ghidra_function, redecompile=False): + + output = {} + + config = json.loads(_jsonnet.evaluate_file(DIRTY_CONFIG)) + + # Set wd so the model can find data1/vocab.bpe10000 + + os.chdir(os.path.join(DIRTY_PATH, "dirty")) + + model = TypeReconstructionModel.load_from_checkpoint( + checkpoint_path=MODEL_CHECKPOINT, config=config + ) + model.eval() + + model_output, model_output_multi, example_info, other_outputs = utils.infer.infer(config, model, cf) + + output['model_output'] = model_output + output['model_output_multi'] = model_output_multi + output['other_info'] = {'example_info': example_info, 'other_outputs': other_outputs} + + print(f"Model output: {model_output}") + + # Set up the decompiler + decompiler = DecompInterface() + decompiler.openProgram(ghidra_function.getProgram()) + + # Decompile the current function + print("Decompiling function " + ghidra_function.getName() + "...") + results = 
decompiler.decompileFunction(ghidra_function, 0, ConsoleTaskMonitor()) + if not results.decompileCompleted(): + abort("Decompilation failed.") + + output["original_decompile"] = results.getDecompiledFunction().getC() + + # Get the high-level representation of the function + high_function = results.getHighFunction() + if not high_function: + abort("Failed to get high-level function representation.") + + # Example: rename a specific variable (change the criteria as needed) + for var in high_function.getLocalSymbolMap().getSymbols(): + + original_name = var.getName() + + if original_name in model_output: + new_type_name, new_name = model_output[original_name] + if new_type_name != "disappear": + + if new_name in ["", "", "disappear"]: + new_name = original_name + + if new_name != original_name: + print("Renaming " + original_name + " to " + new_name + ".") + + new_type = None + + if new_type_name != "": + print( + f"Attempting to retype {original_name}/{new_name} to {new_type_name}" + ) + + try: + ti = find_type_by_name(new_type_name) + new_type = build_ghidra_type(ti) + print( + f"Changing type of {original_name}/{new_name} to {new_type_name}: {new_type}" + ) + except Exception as e: + print( + f"Failed to find or build type {new_type_name} exception: {e}" + ) + + try: + HighFunctionDBUtil.updateDBVariable( + var, new_name, new_type, SourceType.USER_DEFINED + ) + except Exception as e: + print(f"Failed to update variable {original_name} exception: {e}") + + else: + print("Skipping disappear variable " + original_name + ".") + else: + print("No new name/type for " + original_name + " in prediction.") + + if redecompile: + + addrSet = ghidra_function.getBody() + codeUnits = currentProgram().getListing().getCodeUnits(addrSet, True) + asm = "" + for codeUnit in codeUnits: + asm += f"{hex(codeUnit.getAddress().getOffset())}: {codeUnit.toString()}\n" + output["disassembly"] = asm + + + results = decompiler.decompileFunction(ghidra_function, 0, ConsoleTaskMonitor()) + if 
not results.decompileCompleted(): + abort("Re-decompilation failed.") + output["decompile"] = results.getDecompiledFunction().getC() + + return output + + +if sys.version_info.major < 3: + abort( + "You are not running Python 3. This is probably a sign that you did not correctly configure Ghidrathon." + ) + +if not isRunningHeadless(): + + current_location = currentLocation() + + # Get the function containing this location. + ghidra_function = getFunctionContaining(current_location.getAddress()) + + assert ghidra_function is not None + + cf = utils.infer.ghidra_obtain_cf(ghidra_function) + do_infer(cf, ghidra_function) + +else: + + print("We are in headless mode.") + + args = getScriptArgs() + outfile = args[0] if len(args) > 0 else "infer_success.txt" + + # Argument 0 is the output file for infer_success.txt. This is used by the + # CI. Argument 1 is the target function to infer, if present. This is used + # by the huggingface space. + targetFunAddr = hex(int(args[1])) if len(args) >= 2 else None + + function_manager = currentProgram().getFunctionManager() + + if targetFunAddr is not None: # Huggingface space + + try: + print(f"HF mode: {targetFunAddr}") + addr = currentProgram().getAddressFactory().getAddress(targetFunAddr) + print(f"Address: {addr}") + fun = function_manager.getFunctionAt(addr) + assert fun is not None, f"Unable to find function {targetFunAddr}" + + cf = utils.infer.ghidra_obtain_cf(fun) + infer_out = do_infer(cf, fun, redecompile=True) + + json_output = {**infer_out} + + json.dump(json_output, open(outfile, "w")) + except Exception as e: + json_output = {"exception": str(e)} + print( + f"{targetFunAddr} failed because {e.__class__.__name__}: {str(e)}" + ) + json.dump(json_output, open(outfile, "w")) + + else: # CI mode + print("CI mode") + # Get all functions as an iterator + function_iter = function_manager.getFunctions(True) + + # Keep trying functions until we find one that works! This is needed + # because small/trivial functions will fail. 
+ for ghidra_function in tqdm.tqdm(function_iter): + if ghidra_function.isThunk() or ghidra_function.isExternal(): + continue + print(f"Trying {ghidra_function}") + try: + cf = utils.infer.ghidra_obtain_cf(ghidra_function) + do_infer(cf, ghidra_function) + print("Success!") + + open(outfile, "w").write("success") + # break + except Exception as e: + print( + f"{ghidra_function} failed because {e.__class__.__name__}: {str(e)}, trying next function" + ) + traceback.print_exc() + continue