diff --git a/sync/SyncTalk/.DS_Store b/sync/SyncTalk/.DS_Store
new file mode 100644
index 00000000..8950e42c
Binary files /dev/null and b/sync/SyncTalk/.DS_Store differ
diff --git a/sync/SyncTalk/.ipynb_checkpoints/Dockerfile-checkpoint b/sync/SyncTalk/.ipynb_checkpoints/Dockerfile-checkpoint
new file mode 100644
index 00000000..19f3fbec
--- /dev/null
+++ b/sync/SyncTalk/.ipynb_checkpoints/Dockerfile-checkpoint
@@ -0,0 +1,61 @@
+# 使用 Ubuntu 18.04 作为基础镜像
+FROM ubuntu:18.04
+
+# 设置环境变量,防止交互式安装时出现提示
+ENV DEBIAN_FRONTEND=noninteractive
+ENV PATH="/opt/conda/bin:$PATH"
+ENV CUDA_HOME="/usr/local/cuda-11.3"
+ENV LD_LIBRARY_PATH="/usr/local/cuda-11.3/lib64:$LD_LIBRARY_PATH"
+
+# 更新系统并安装必要工具
+RUN apt-get update && apt-get install -y \
+ wget \
+ curl \
+ git \
+ build-essential \
+ software-properties-common \
+ portaudio19-dev \
+ && apt-get clean
+
+# 安装 CUDA 11.3 和 cuDNN 8.2
+RUN wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/cuda-keyring_1.0-1_all.deb && \
+ dpkg -i cuda-keyring_1.0-1_all.deb && \
+ apt-get update && apt-get install -y \
+ cuda-11-3 \
+ libcudnn8=8.2.1.*-1+cuda11.3 \
+ libcudnn8-dev=8.2.1.*-1+cuda11.3 \
+ && apt-get clean
+
+# 安装 Miniconda
+RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O /tmp/miniconda.sh && \
+ bash /tmp/miniconda.sh -b -p /opt/conda && \
+ rm /tmp/miniconda.sh && \
+ /opt/conda/bin/conda clean -tipsy
+
+# 创建 Conda 环境并安装 Python 3.8.8
+RUN conda create -n synctalk python=3.8.8 -y && \
+ conda clean -a -y
+
+# 激活环境并安装 PyTorch 和依赖
+RUN /bin/bash -c "source activate synctalk && \
+ pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113 && \
+ pip install -r /app/requirements.txt && \
+ pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py38_cu113_pyt1121/download.html && \
+ pip install tensorflow-gpu==2.8.1"
+
+# 复制项目文件到容器
+WORKDIR /app
+COPY . /app
+
+# 激活环境并安装本地包
+RUN /bin/bash -c "source activate synctalk && \
+ pip install ./freqencoder && \
+ pip install ./shencoder && \
+ pip install ./gridencoder && \
+ pip install ./raymarching"
+
+# 确保脚本可执行
+RUN chmod +x install.sh inference.sh main.sh
+
+# 默认执行 main.sh
+CMD ["/bin/bash"]
diff --git a/sync/SyncTalk/.ipynb_checkpoints/inference-checkpoint.sh b/sync/SyncTalk/.ipynb_checkpoints/inference-checkpoint.sh
new file mode 100644
index 00000000..80390460
--- /dev/null
+++ b/sync/SyncTalk/.ipynb_checkpoints/inference-checkpoint.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+# Ensure the script exits on any error
+set -e
+
+# Define variables for the arguments
+DATA_PATH="data/May"
+WORKSPACE="model/trial_May"
+ASR_MODEL="ave"
+AUDIO_PATH="data/May/aud.wav"
+
+# Run the Python script with the specified arguments
+python main.py "$DATA_PATH" \
+ --workspace "$WORKSPACE" \
+ -O \
+ --test \
+ --test_train \
+ --asr_model "$ASR_MODEL" \
+ --portrait \
+ --aud "$AUDIO_PATH"
diff --git a/sync/SyncTalk/.ipynb_checkpoints/requirements-checkpoint.txt b/sync/SyncTalk/.ipynb_checkpoints/requirements-checkpoint.txt
new file mode 100644
index 00000000..0e92ee12
--- /dev/null
+++ b/sync/SyncTalk/.ipynb_checkpoints/requirements-checkpoint.txt
@@ -0,0 +1,31 @@
+torch-ema
+ninja
+trimesh
+opencv-python
+tensorboardX
+numpy==1.24.4
+pandas==2.0.3
+tqdm
+matplotlib
+PyMCubes==0.1.4
+rich
+dearpygui
+packaging
+scipy
+scikit-learn
+transformers==4.36.0
+face_alignment==1.4.1
+python_speech_features
+numba
+resampy
+pyaudio
+soundfile
+einops
+configargparse
+mediapipe
+lpips
+imageio-ffmpeg
+onnxruntime-gpu
+librosa
+fvcore
+iopath
\ No newline at end of file
diff --git a/sync/SyncTalk/Dockerfile b/sync/SyncTalk/Dockerfile
new file mode 100644
index 00000000..19f3fbec
--- /dev/null
+++ b/sync/SyncTalk/Dockerfile
@@ -0,0 +1,61 @@
+# 使用 Ubuntu 18.04 作为基础镜像
+FROM ubuntu:18.04
+
+# 设置环境变量,防止交互式安装时出现提示
+ENV DEBIAN_FRONTEND=noninteractive
+ENV PATH="/opt/conda/bin:$PATH"
+ENV CUDA_HOME="/usr/local/cuda-11.3"
+ENV LD_LIBRARY_PATH="/usr/local/cuda-11.3/lib64:$LD_LIBRARY_PATH"
+
+# 更新系统并安装必要工具
+RUN apt-get update && apt-get install -y \
+ wget \
+ curl \
+ git \
+ build-essential \
+ software-properties-common \
+ portaudio19-dev \
+ && apt-get clean
+
+# 安装 CUDA 11.3 和 cuDNN 8.2
+RUN wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/cuda-keyring_1.0-1_all.deb && \
+ dpkg -i cuda-keyring_1.0-1_all.deb && \
+ apt-get update && apt-get install -y \
+ cuda-11-3 \
+ libcudnn8=8.2.1.*-1+cuda11.3 \
+ libcudnn8-dev=8.2.1.*-1+cuda11.3 \
+ && apt-get clean
+
+# 安装 Miniconda
+RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O /tmp/miniconda.sh && \
+ bash /tmp/miniconda.sh -b -p /opt/conda && \
+ rm /tmp/miniconda.sh && \
+ /opt/conda/bin/conda clean -tipsy
+
+# 创建 Conda 环境并安装 Python 3.8.8
+RUN conda create -n synctalk python=3.8.8 -y && \
+ conda clean -a -y
+
+# 激活环境并安装 PyTorch 和依赖
+RUN /bin/bash -c "source activate synctalk && \
+ pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113 && \
+ pip install -r /app/requirements.txt && \
+ pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py38_cu113_pyt1121/download.html && \
+ pip install tensorflow-gpu==2.8.1"
+
+# 复制项目文件到容器
+WORKDIR /app
+COPY . /app
+
+# 激活环境并安装本地包
+RUN /bin/bash -c "source activate synctalk && \
+ pip install ./freqencoder && \
+ pip install ./shencoder && \
+ pip install ./gridencoder && \
+ pip install ./raymarching"
+
+# 确保脚本可执行
+RUN chmod +x install.sh inference.sh main.sh
+
+# 默认执行 main.sh
+CMD ["/bin/bash"]
diff --git a/sync/SyncTalk/LICENSE b/sync/SyncTalk/LICENSE
new file mode 100644
index 00000000..94f88094
--- /dev/null
+++ b/sync/SyncTalk/LICENSE
@@ -0,0 +1,13 @@
+Copyright (c) 2024 Peng Ziqiao
+
+This work is licensed under the Creative Commons Attribution-NonCommercial 4.0 International License (CC BY-NC 4.0). To view a copy of this license, visit http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, and distribute the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+1. Attribution — You must give appropriate credit, provide a link to the license, and indicate if changes were made. You may do so in any reasonable manner, but not in any way that suggests the licensor endorses you or your use.
+
+2. NonCommercial — You may not use the material for commercial purposes.
+
+3. No additional restrictions — You may not apply legal terms or technological measures that legally restrict others from doing anything the license permits.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
diff --git a/sync/SyncTalk/README.md b/sync/SyncTalk/README.md
new file mode 100644
index 00000000..6bbe0be8
--- /dev/null
+++ b/sync/SyncTalk/README.md
@@ -0,0 +1,115 @@
+# SyncTalk: The Devil😈 is in the Synchronization for Talking Head Synthesis [CVPR 2024]
+The official repository of the paper [SyncTalk: The Devil is in the Synchronization for Talking Head Synthesis](https://arxiv.org/abs/2311.17590)
+
+
+
+ Paper
+ |
+ Project Page
+ |
+ Code
+
+
+
+Colab notebook demonstration: [](https://colab.research.google.com/drive/1Egq0_ZK5sJAAawShxC0y4JRZQuVS2X-Z?usp=sharing)
+
+
+
+
+
+ The proposed **SyncTalk** synthesizes synchronized talking head videos, employing tri-plane hash representations to maintain subject identity. It can generate synchronized lip movements, facial expressions, and stable head poses, and restores hair details to create high-resolution videos.
+
+## Installation
+
+Tested on Ubuntu 18.04, Pytorch 1.12.1 and CUDA 11.3.
+```bash
+git clone https://github.com/ZiqiaoPeng/SyncTalk.git
+cd SyncTalk
+```
+### Install dependency
+
+```bash
+conda create -n synctalk python==3.8.8
+conda activate synctalk
+pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
+pip install -r requirements.txt
+pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py38_cu113_pyt1121/download.html
+pip install ./freqencoder
+pip install ./shencoder
+pip install ./gridencoder
+pip install ./raymarching
+```
+If you encounter problems installing PyTorch3D, you can use the following command to install it:
+```bash
+python ./scripts/install_pytorch3d.py
+```
+
+## Data Preparation
+Please place the [May.zip](https://drive.google.com/file/d/18Q2H612CAReFxBd9kxr-i1dD8U1AUfsV/view?usp=sharing) in the **data** folder, the [trial_may.zip](https://drive.google.com/file/d/1C2639qi9jvhRygYHwPZDGs8pun3po3W7/view?usp=sharing) in the **model** folder, and then unzip them.
+
+## Quick Start
+### Run the evaluation code
+```bash
+python main.py data/May --workspace model/trial_may -O --test --asr_model ave
+
+python main.py data/May --workspace model/trial_may -O --test --asr_model ave --portrait
+```
+“ave” refers to our Audio Visual Encoder, “portrait” signifies pasting the generated face back onto the original image, representing higher quality.
+If it runs correctly, you will get the following results.
+
+| Setting | PSNR | LPIPS | LMD |
+|--------------------------|--------|--------|-------|
+| SyncTalk (w/o Portrait) | 32.201 | 0.0394 | 2.822 |
+| SyncTalk (Portrait) | 37.644 | 0.0117 | 2.825 |
+
+This is for a single subject; the paper reports the average results for multiple subjects.
+
+### Inference with target audio
+```bash
+python main.py data/May --workspace model/trial_may -O --test --test_train --asr_model ave --portrait --aud ./demo/test.wav
+```
+Please use files with the “.wav” extension for inference, and the inference results will be saved in “model/trial_may/results/”.
+## Train
+```bash
+# by default, we load data from disk on the fly.
+# we can also preload all data to CPU/GPU for faster training, but this is very memory-hungry for large datasets.
+# `--preload 0`: load from disk (default, slower).
+# `--preload 1`: load to CPU (slightly slower)
+# `--preload 2`: load to GPU (fast)
+python main.py data/May --workspace model/trial_may -O --iters 60000 --asr_model ave
+python main.py data/May --workspace model/trial_may -O --iters 100000 --finetune_lips --patch_size 64 --asr_model ave
+
+# or you can use the script to train
+sh ./scripts/train_may.sh
+```
+
+## Test
+```bash
+python main.py data/May --workspace model/trial_may -O --test --asr_model ave --portrait
+```
+
+
+## TODO
+- [x] **Release Training Code.**
+- [x] **Release Pre-trained Model.**
+- [x] **Release Google Colab.**
+- [ ] Release Preprocessing Code.
+
+
+
+## Citation
+
+```
+@InProceedings{peng2023synctalk,
+ title = {SyncTalk: The Devil is in the Synchronization for Talking Head Synthesis},
+ author = {Ziqiao Peng and Wentao Hu and Yue Shi and Xiangyu Zhu and Xiaomei Zhang and Jun He and Hongyan Liu and Zhaoxin Fan},
+ booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
+ month = {June},
+ year = {2024},
+}
+```
+
+## Acknowledgement
+This code is developed heavily relying on [ER-NeRF](https://github.com/Fictionarry/ER-NeRF), and also [RAD-NeRF](https://github.com/ashawkey/RAD-NeRF), [GeneFace](https://github.com/yerfor/GeneFace), [DFRF](https://github.com/sstzal/DFRF), [AD-NeRF](https://github.com/YudongGuo/AD-NeRF), and [Deep3DFaceRecon_pytorch](https://github.com/sicxu/Deep3DFaceRecon_pytorch).
+
+Thanks for these great projects.
diff --git a/sync/SyncTalk/assets/image/synctalk.png b/sync/SyncTalk/assets/image/synctalk.png
new file mode 100644
index 00000000..7a67737c
Binary files /dev/null and b/sync/SyncTalk/assets/image/synctalk.png differ
diff --git a/sync/SyncTalk/demo/.DS_Store b/sync/SyncTalk/demo/.DS_Store
new file mode 100644
index 00000000..5008ddfc
Binary files /dev/null and b/sync/SyncTalk/demo/.DS_Store differ
diff --git a/sync/SyncTalk/demo/test.wav b/sync/SyncTalk/demo/test.wav
new file mode 100644
index 00000000..24d48941
Binary files /dev/null and b/sync/SyncTalk/demo/test.wav differ
diff --git a/sync/SyncTalk/freqencoder/__init__.py b/sync/SyncTalk/freqencoder/__init__.py
new file mode 100644
index 00000000..69ec49cf
--- /dev/null
+++ b/sync/SyncTalk/freqencoder/__init__.py
@@ -0,0 +1 @@
+from .freq import FreqEncoder
\ No newline at end of file
diff --git a/sync/SyncTalk/freqencoder/backend.py b/sync/SyncTalk/freqencoder/backend.py
new file mode 100644
index 00000000..3bd9131a
--- /dev/null
+++ b/sync/SyncTalk/freqencoder/backend.py
@@ -0,0 +1,41 @@
+import os
+from torch.utils.cpp_extension import load
+
+_src_path = os.path.dirname(os.path.abspath(__file__))
+
+nvcc_flags = [
+ '-O3', '-std=c++14',
+ '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
+ '-use_fast_math'
+]
+
+if os.name == "posix":
+ c_flags = ['-O3', '-std=c++14']
+elif os.name == "nt":
+ c_flags = ['/O2', '/std:c++17']
+
+ # find cl.exe
+ def find_cl_path():
+ import glob
+ for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
+ paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
+ if paths:
+ return paths[0]
+
+ # If cl.exe is not on path, try to find it.
+ if os.system("where cl.exe >nul 2>nul") != 0:
+ cl_path = find_cl_path()
+ if cl_path is None:
+ raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
+ os.environ["PATH"] += ";" + cl_path
+
+_backend = load(name='_freqencoder',
+ extra_cflags=c_flags,
+ extra_cuda_cflags=nvcc_flags,
+ sources=[os.path.join(_src_path, 'src', f) for f in [
+ 'freqencoder.cu',
+ 'bindings.cpp',
+ ]],
+ )
+
+__all__ = ['_backend']
\ No newline at end of file
diff --git a/sync/SyncTalk/freqencoder/build/lib.linux-x86_64-cpython-38/_freqencoder.cpython-38-x86_64-linux-gnu.so b/sync/SyncTalk/freqencoder/build/lib.linux-x86_64-cpython-38/_freqencoder.cpython-38-x86_64-linux-gnu.so
new file mode 100644
index 00000000..bba41f65
Binary files /dev/null and b/sync/SyncTalk/freqencoder/build/lib.linux-x86_64-cpython-38/_freqencoder.cpython-38-x86_64-linux-gnu.so differ
diff --git a/sync/SyncTalk/freqencoder/build/temp.linux-x86_64-cpython-38/.ninja_deps b/sync/SyncTalk/freqencoder/build/temp.linux-x86_64-cpython-38/.ninja_deps
new file mode 100644
index 00000000..a4698d23
Binary files /dev/null and b/sync/SyncTalk/freqencoder/build/temp.linux-x86_64-cpython-38/.ninja_deps differ
diff --git a/sync/SyncTalk/freqencoder/build/temp.linux-x86_64-cpython-38/.ninja_log b/sync/SyncTalk/freqencoder/build/temp.linux-x86_64-cpython-38/.ninja_log
new file mode 100644
index 00000000..abef13da
--- /dev/null
+++ b/sync/SyncTalk/freqencoder/build/temp.linux-x86_64-cpython-38/.ninja_log
@@ -0,0 +1,3 @@
+# ninja log v5
+0 26782 1734802515297543445 /home/pod/shared-nvme/SyncTalk/freqencoder/build/temp.linux-x86_64-cpython-38/home/pod/shared-nvme/SyncTalk/freqencoder/src/bindings.o 44f86bd32c0d33f9
+1 50546 1734802539069452910 /home/pod/shared-nvme/SyncTalk/freqencoder/build/temp.linux-x86_64-cpython-38/home/pod/shared-nvme/SyncTalk/freqencoder/src/freqencoder.o 5409e7e776086ceb
diff --git a/sync/SyncTalk/freqencoder/build/temp.linux-x86_64-cpython-38/build.ninja b/sync/SyncTalk/freqencoder/build/temp.linux-x86_64-cpython-38/build.ninja
new file mode 100644
index 00000000..835ec1b0
--- /dev/null
+++ b/sync/SyncTalk/freqencoder/build/temp.linux-x86_64-cpython-38/build.ninja
@@ -0,0 +1,29 @@
+ninja_required_version = 1.3
+cxx = c++
+nvcc = /usr/local/cuda/bin/nvcc
+
+cflags = -pthread -B /home/pod/shared-nvme/conda/envs/synctalk/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -I/home/pod/shared-nvme/conda/envs/synctalk/lib/python3.8/site-packages/torch/include -I/home/pod/shared-nvme/conda/envs/synctalk/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/home/pod/shared-nvme/conda/envs/synctalk/lib/python3.8/site-packages/torch/include/TH -I/home/pod/shared-nvme/conda/envs/synctalk/lib/python3.8/site-packages/torch/include/THC -I/usr/local/cuda/include -I/home/pod/shared-nvme/conda/envs/synctalk/include/python3.8 -c
+post_cflags = -O3 -std=c++14 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=_freqencoder -D_GLIBCXX_USE_CXX11_ABI=0
+cuda_cflags = -I/home/pod/shared-nvme/conda/envs/synctalk/lib/python3.8/site-packages/torch/include -I/home/pod/shared-nvme/conda/envs/synctalk/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/home/pod/shared-nvme/conda/envs/synctalk/lib/python3.8/site-packages/torch/include/TH -I/home/pod/shared-nvme/conda/envs/synctalk/lib/python3.8/site-packages/torch/include/THC -I/usr/local/cuda/include -I/home/pod/shared-nvme/conda/envs/synctalk/include/python3.8 -c
+cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -std=c++14 -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ -U__CUDA_NO_HALF2_OPERATORS__ -use_fast_math -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=_freqencoder -D_GLIBCXX_USE_CXX11_ABI=0 -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86
+ldflags =
+
+rule compile
+ command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags
+ depfile = $out.d
+ deps = gcc
+
+rule cuda_compile
+ depfile = $out.d
+ deps = gcc
+ command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags
+
+
+
+build /home/pod/shared-nvme/SyncTalk/freqencoder/build/temp.linux-x86_64-cpython-38/home/pod/shared-nvme/SyncTalk/freqencoder/src/bindings.o: compile /home/pod/shared-nvme/SyncTalk/freqencoder/src/bindings.cpp
+build /home/pod/shared-nvme/SyncTalk/freqencoder/build/temp.linux-x86_64-cpython-38/home/pod/shared-nvme/SyncTalk/freqencoder/src/freqencoder.o: cuda_compile /home/pod/shared-nvme/SyncTalk/freqencoder/src/freqencoder.cu
+
+
+
+
+
diff --git a/sync/SyncTalk/freqencoder/build/temp.linux-x86_64-cpython-38/home/pod/shared-nvme/SyncTalk/freqencoder/src/bindings.o b/sync/SyncTalk/freqencoder/build/temp.linux-x86_64-cpython-38/home/pod/shared-nvme/SyncTalk/freqencoder/src/bindings.o
new file mode 100644
index 00000000..0eec69b1
Binary files /dev/null and b/sync/SyncTalk/freqencoder/build/temp.linux-x86_64-cpython-38/home/pod/shared-nvme/SyncTalk/freqencoder/src/bindings.o differ
diff --git a/sync/SyncTalk/freqencoder/build/temp.linux-x86_64-cpython-38/home/pod/shared-nvme/SyncTalk/freqencoder/src/freqencoder.o b/sync/SyncTalk/freqencoder/build/temp.linux-x86_64-cpython-38/home/pod/shared-nvme/SyncTalk/freqencoder/src/freqencoder.o
new file mode 100644
index 00000000..8bfe4352
Binary files /dev/null and b/sync/SyncTalk/freqencoder/build/temp.linux-x86_64-cpython-38/home/pod/shared-nvme/SyncTalk/freqencoder/src/freqencoder.o differ
diff --git a/sync/SyncTalk/freqencoder/freq.py b/sync/SyncTalk/freqencoder/freq.py
new file mode 100644
index 00000000..5cba1e66
--- /dev/null
+++ b/sync/SyncTalk/freqencoder/freq.py
@@ -0,0 +1,77 @@
+import numpy as np
+
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.cuda.amp import custom_bwd, custom_fwd
+
+try:
+ import _freqencoder as _backend
+except ImportError:
+ from .backend import _backend
+
+
+class _freq_encoder(Function):
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision
+ def forward(ctx, inputs, degree, output_dim):
+ # inputs: [B, input_dim], float
+ # RETURN: [B, F], float
+
+ if not inputs.is_cuda: inputs = inputs.cuda()
+ inputs = inputs.contiguous()
+
+ B, input_dim = inputs.shape # batch size, coord dim
+
+ outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device)
+
+ _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs)
+
+ ctx.save_for_backward(inputs, outputs)
+ ctx.dims = [B, input_dim, degree, output_dim]
+
+ return outputs
+
+ @staticmethod
+ #@once_differentiable
+ @custom_bwd
+ def backward(ctx, grad):
+ # grad: [B, C * C]
+
+ grad = grad.contiguous()
+ inputs, outputs = ctx.saved_tensors
+ B, input_dim, degree, output_dim = ctx.dims
+
+ grad_inputs = torch.zeros_like(inputs)
+ _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs)
+
+ return grad_inputs, None, None
+
+
+freq_encode = _freq_encoder.apply
+
+
+class FreqEncoder(nn.Module):
+ def __init__(self, input_dim=3, degree=4):
+ super().__init__()
+
+ self.input_dim = input_dim
+ self.degree = degree
+ self.output_dim = input_dim + input_dim * 2 * degree
+
+ def __repr__(self):
+ return f"FreqEncoder: input_dim={self.input_dim} degree={self.degree} output_dim={self.output_dim}"
+
+ def forward(self, inputs, **kwargs):
+ # inputs: [..., input_dim]
+ # return: [..., ]
+
+ prefix_shape = list(inputs.shape[:-1])
+ inputs = inputs.reshape(-1, self.input_dim)
+
+ outputs = freq_encode(inputs, self.degree, self.output_dim)
+
+ outputs = outputs.reshape(prefix_shape + [self.output_dim])
+
+ return outputs
\ No newline at end of file
diff --git a/sync/SyncTalk/freqencoder/freqencoder.egg-info/PKG-INFO b/sync/SyncTalk/freqencoder/freqencoder.egg-info/PKG-INFO
new file mode 100644
index 00000000..c280ab19
--- /dev/null
+++ b/sync/SyncTalk/freqencoder/freqencoder.egg-info/PKG-INFO
@@ -0,0 +1,3 @@
+Metadata-Version: 2.1
+Name: freqencoder
+Version: 0.0.0
diff --git a/sync/SyncTalk/freqencoder/freqencoder.egg-info/SOURCES.txt b/sync/SyncTalk/freqencoder/freqencoder.egg-info/SOURCES.txt
new file mode 100644
index 00000000..33340296
--- /dev/null
+++ b/sync/SyncTalk/freqencoder/freqencoder.egg-info/SOURCES.txt
@@ -0,0 +1,7 @@
+setup.py
+/home/pod/shared-nvme/SyncTalk/freqencoder/src/bindings.cpp
+/home/pod/shared-nvme/SyncTalk/freqencoder/src/freqencoder.cu
+freqencoder.egg-info/PKG-INFO
+freqencoder.egg-info/SOURCES.txt
+freqencoder.egg-info/dependency_links.txt
+freqencoder.egg-info/top_level.txt
\ No newline at end of file
diff --git a/sync/SyncTalk/freqencoder/freqencoder.egg-info/dependency_links.txt b/sync/SyncTalk/freqencoder/freqencoder.egg-info/dependency_links.txt
new file mode 100644
index 00000000..8b137891
--- /dev/null
+++ b/sync/SyncTalk/freqencoder/freqencoder.egg-info/dependency_links.txt
@@ -0,0 +1 @@
+
diff --git a/sync/SyncTalk/freqencoder/freqencoder.egg-info/top_level.txt b/sync/SyncTalk/freqencoder/freqencoder.egg-info/top_level.txt
new file mode 100644
index 00000000..85f88eba
--- /dev/null
+++ b/sync/SyncTalk/freqencoder/freqencoder.egg-info/top_level.txt
@@ -0,0 +1 @@
+_freqencoder
diff --git a/sync/SyncTalk/freqencoder/setup.py b/sync/SyncTalk/freqencoder/setup.py
new file mode 100644
index 00000000..3eb4af77
--- /dev/null
+++ b/sync/SyncTalk/freqencoder/setup.py
@@ -0,0 +1,51 @@
+import os
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+_src_path = os.path.dirname(os.path.abspath(__file__))
+
+nvcc_flags = [
+ '-O3', '-std=c++14',
+ '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
+ '-use_fast_math'
+]
+
+if os.name == "posix":
+ c_flags = ['-O3', '-std=c++14']
+elif os.name == "nt":
+ c_flags = ['/O2', '/std:c++17']
+
+ # find cl.exe
+ def find_cl_path():
+ import glob
+ for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
+ paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
+ if paths:
+ return paths[0]
+
+ # If cl.exe is not on path, try to find it.
+ if os.system("where cl.exe >nul 2>nul") != 0:
+ cl_path = find_cl_path()
+ if cl_path is None:
+ raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
+ os.environ["PATH"] += ";" + cl_path
+
+setup(
+ name='freqencoder', # package name, import this to use python API
+ ext_modules=[
+ CUDAExtension(
+ name='_freqencoder', # extension name, import this to use CUDA API
+ sources=[os.path.join(_src_path, 'src', f) for f in [
+ 'freqencoder.cu',
+ 'bindings.cpp',
+ ]],
+ extra_compile_args={
+ 'cxx': c_flags,
+ 'nvcc': nvcc_flags,
+ }
+ ),
+ ],
+ cmdclass={
+ 'build_ext': BuildExtension,
+ }
+)
\ No newline at end of file
diff --git a/sync/SyncTalk/freqencoder/src/bindings.cpp b/sync/SyncTalk/freqencoder/src/bindings.cpp
new file mode 100644
index 00000000..bb5f285a
--- /dev/null
+++ b/sync/SyncTalk/freqencoder/src/bindings.cpp
@@ -0,0 +1,8 @@
+#include
+
+#include "freqencoder.h"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("freq_encode_forward", &freq_encode_forward, "freq encode forward (CUDA)");
+ m.def("freq_encode_backward", &freq_encode_backward, "freq encode backward (CUDA)");
+}
\ No newline at end of file
diff --git a/sync/SyncTalk/freqencoder/src/freqencoder.cu b/sync/SyncTalk/freqencoder/src/freqencoder.cu
new file mode 100644
index 00000000..de378840
--- /dev/null
+++ b/sync/SyncTalk/freqencoder/src/freqencoder.cu
@@ -0,0 +1,129 @@
+#include
+
+#include
+#include
+#include
+
+#include
+#include
+
+#include
+#include
+
+#include
+
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
+#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
+#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
+
+inline constexpr __device__ float PI() { return 3.141592653589793f; }
+
+template
+__host__ __device__ T div_round_up(T val, T divisor) {
+ return (val + divisor - 1) / divisor;
+}
+
+// inputs: [B, D]
+// outputs: [B, C], C = D + D * deg * 2
+__global__ void kernel_freq(
+ const float * __restrict__ inputs,
+ uint32_t B, uint32_t D, uint32_t deg, uint32_t C,
+ float * outputs
+) {
+ // parallel on per-element
+ const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
+ if (t >= B * C) return;
+
+ // get index
+ const uint32_t b = t / C;
+ const uint32_t c = t - b * C; // t % C;
+
+ // locate
+ inputs += b * D;
+ outputs += t;
+
+ // write self
+ if (c < D) {
+ outputs[0] = inputs[c];
+ // write freq
+ } else {
+ const uint32_t col = c / D - 1;
+ const uint32_t d = c % D;
+ const uint32_t freq = col / 2;
+ const float phase_shift = (col % 2) * (PI() / 2);
+ outputs[0] = __sinf(scalbnf(inputs[d], freq) + phase_shift);
+ }
+}
+
+// grad: [B, C], C = D + D * deg * 2
+// outputs: [B, C]
+// grad_inputs: [B, D]
+__global__ void kernel_freq_backward(
+ const float * __restrict__ grad,
+ const float * __restrict__ outputs,
+ uint32_t B, uint32_t D, uint32_t deg, uint32_t C,
+ float * grad_inputs
+) {
+ // parallel on per-element
+ const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
+ if (t >= B * D) return;
+
+ const uint32_t b = t / D;
+ const uint32_t d = t - b * D; // t % D;
+
+ // locate
+ grad += b * C;
+ outputs += b * C;
+ grad_inputs += t;
+
+ // register
+ float result = grad[d];
+ grad += D;
+ outputs += D;
+
+ for (uint32_t f = 0; f < deg; f++) {
+ result += scalbnf(1.0f, f) * (grad[d] * outputs[D + d] - grad[D + d] * outputs[d]);
+ grad += 2 * D;
+ outputs += 2 * D;
+ }
+
+ // write
+ grad_inputs[0] = result;
+}
+
+
+void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs) {
+ CHECK_CUDA(inputs);
+ CHECK_CUDA(outputs);
+
+ CHECK_CONTIGUOUS(inputs);
+ CHECK_CONTIGUOUS(outputs);
+
+ CHECK_IS_FLOATING(inputs);
+ CHECK_IS_FLOATING(outputs);
+
+ static constexpr uint32_t N_THREADS = 128;
+
+ kernel_freq<<>>(inputs.data_ptr(), B, D, deg, C, outputs.data_ptr());
+}
+
+
+void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs) {
+ CHECK_CUDA(grad);
+ CHECK_CUDA(outputs);
+ CHECK_CUDA(grad_inputs);
+
+ CHECK_CONTIGUOUS(grad);
+ CHECK_CONTIGUOUS(outputs);
+ CHECK_CONTIGUOUS(grad_inputs);
+
+ CHECK_IS_FLOATING(grad);
+ CHECK_IS_FLOATING(outputs);
+ CHECK_IS_FLOATING(grad_inputs);
+
+ static constexpr uint32_t N_THREADS = 128;
+
+ kernel_freq_backward<<>>(grad.data_ptr(), outputs.data_ptr(), B, D, deg, C, grad_inputs.data_ptr());
+}
\ No newline at end of file
diff --git a/sync/SyncTalk/freqencoder/src/freqencoder.h b/sync/SyncTalk/freqencoder/src/freqencoder.h
new file mode 100644
index 00000000..34f28c79
--- /dev/null
+++ b/sync/SyncTalk/freqencoder/src/freqencoder.h
@@ -0,0 +1,10 @@
+# pragma once
+
+#include
+#include
+
+// _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs)
+void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs);
+
+// _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs)
+void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs);
\ No newline at end of file
diff --git a/sync/SyncTalk/gridencoder/__init__.py b/sync/SyncTalk/gridencoder/__init__.py
new file mode 100644
index 00000000..f1476cef
--- /dev/null
+++ b/sync/SyncTalk/gridencoder/__init__.py
@@ -0,0 +1 @@
+from .grid import GridEncoder
\ No newline at end of file
diff --git a/sync/SyncTalk/gridencoder/__pycache__/__init__.cpython-38.pyc b/sync/SyncTalk/gridencoder/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 00000000..3b204bd8
Binary files /dev/null and b/sync/SyncTalk/gridencoder/__pycache__/__init__.cpython-38.pyc differ
diff --git a/sync/SyncTalk/gridencoder/__pycache__/__init__.cpython-39.pyc b/sync/SyncTalk/gridencoder/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 00000000..05c12bc6
Binary files /dev/null and b/sync/SyncTalk/gridencoder/__pycache__/__init__.cpython-39.pyc differ
diff --git a/sync/SyncTalk/gridencoder/__pycache__/grid.cpython-38.pyc b/sync/SyncTalk/gridencoder/__pycache__/grid.cpython-38.pyc
new file mode 100644
index 00000000..2e263bfc
Binary files /dev/null and b/sync/SyncTalk/gridencoder/__pycache__/grid.cpython-38.pyc differ
diff --git a/sync/SyncTalk/gridencoder/__pycache__/grid.cpython-39.pyc b/sync/SyncTalk/gridencoder/__pycache__/grid.cpython-39.pyc
new file mode 100644
index 00000000..08e111b2
Binary files /dev/null and b/sync/SyncTalk/gridencoder/__pycache__/grid.cpython-39.pyc differ
diff --git a/sync/SyncTalk/gridencoder/backend.py b/sync/SyncTalk/gridencoder/backend.py
new file mode 100644
index 00000000..d4aa494c
--- /dev/null
+++ b/sync/SyncTalk/gridencoder/backend.py
@@ -0,0 +1,40 @@
+import os
+from torch.utils.cpp_extension import load
+
+_src_path = os.path.dirname(os.path.abspath(__file__))
+
+nvcc_flags = [
+ '-O3', '-std=c++14',
+ '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
+]
+
+if os.name == "posix":
+ c_flags = ['-O3', '-std=c++14', '-finput-charset=UTF-8']
+elif os.name == "nt":
+ c_flags = ['/O2', '/std:c++17', '/finput-charset=UTF-8']
+
+ # find cl.exe
+ def find_cl_path():
+ import glob
+ for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
+ paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
+ if paths:
+ return paths[0]
+
+ # If cl.exe is not on path, try to find it.
+ if os.system("where cl.exe >nul 2>nul") != 0:
+ cl_path = find_cl_path()
+ if cl_path is None:
+ raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
+ os.environ["PATH"] += ";" + cl_path
+
+_backend = load(name='_grid_encoder',
+ extra_cflags=c_flags,
+ extra_cuda_cflags=nvcc_flags,
+ sources=[os.path.join(_src_path, 'src', f) for f in [
+ 'gridencoder.cu',
+ 'bindings.cpp',
+ ]],
+ )
+
+__all__ = ['_backend']
\ No newline at end of file
diff --git a/sync/SyncTalk/gridencoder/build/lib.linux-x86_64-cpython-38/_gridencoder.cpython-38-x86_64-linux-gnu.so b/sync/SyncTalk/gridencoder/build/lib.linux-x86_64-cpython-38/_gridencoder.cpython-38-x86_64-linux-gnu.so
new file mode 100644
index 00000000..aaeaa15a
Binary files /dev/null and b/sync/SyncTalk/gridencoder/build/lib.linux-x86_64-cpython-38/_gridencoder.cpython-38-x86_64-linux-gnu.so differ
diff --git a/sync/SyncTalk/gridencoder/build/temp.linux-x86_64-cpython-38/.ninja_deps b/sync/SyncTalk/gridencoder/build/temp.linux-x86_64-cpython-38/.ninja_deps
new file mode 100644
index 00000000..74e0686f
Binary files /dev/null and b/sync/SyncTalk/gridencoder/build/temp.linux-x86_64-cpython-38/.ninja_deps differ
diff --git a/sync/SyncTalk/gridencoder/build/temp.linux-x86_64-cpython-38/.ninja_log b/sync/SyncTalk/gridencoder/build/temp.linux-x86_64-cpython-38/.ninja_log
new file mode 100644
index 00000000..74fa510e
--- /dev/null
+++ b/sync/SyncTalk/gridencoder/build/temp.linux-x86_64-cpython-38/.ninja_log
@@ -0,0 +1,3 @@
+# ninja log v5
+2 18144 1734802621217143428 /home/pod/shared-nvme/SyncTalk/gridencoder/build/temp.linux-x86_64-cpython-38/home/pod/shared-nvme/SyncTalk/gridencoder/src/bindings.o 1c462e532a2ac6e4
+2 90419 1734802693496874982 /home/pod/shared-nvme/SyncTalk/gridencoder/build/temp.linux-x86_64-cpython-38/home/pod/shared-nvme/SyncTalk/gridencoder/src/gridencoder.o d0d103819090c37b
diff --git a/sync/SyncTalk/gridencoder/build/temp.linux-x86_64-cpython-38/build.ninja b/sync/SyncTalk/gridencoder/build/temp.linux-x86_64-cpython-38/build.ninja
new file mode 100644
index 00000000..70d74e9d
--- /dev/null
+++ b/sync/SyncTalk/gridencoder/build/temp.linux-x86_64-cpython-38/build.ninja
@@ -0,0 +1,29 @@
+ninja_required_version = 1.3
+cxx = c++
+nvcc = /usr/local/cuda/bin/nvcc
+
+cflags = -pthread -B /home/pod/shared-nvme/conda/envs/synctalk/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -I/home/pod/shared-nvme/conda/envs/synctalk/lib/python3.8/site-packages/torch/include -I/home/pod/shared-nvme/conda/envs/synctalk/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/home/pod/shared-nvme/conda/envs/synctalk/lib/python3.8/site-packages/torch/include/TH -I/home/pod/shared-nvme/conda/envs/synctalk/lib/python3.8/site-packages/torch/include/THC -I/usr/local/cuda/include -I/home/pod/shared-nvme/conda/envs/synctalk/include/python3.8 -c
+post_cflags = -O3 -std=c++14 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=_gridencoder -D_GLIBCXX_USE_CXX11_ABI=0
+cuda_cflags = -I/home/pod/shared-nvme/conda/envs/synctalk/lib/python3.8/site-packages/torch/include -I/home/pod/shared-nvme/conda/envs/synctalk/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/home/pod/shared-nvme/conda/envs/synctalk/lib/python3.8/site-packages/torch/include/TH -I/home/pod/shared-nvme/conda/envs/synctalk/lib/python3.8/site-packages/torch/include/THC -I/usr/local/cuda/include -I/home/pod/shared-nvme/conda/envs/synctalk/include/python3.8 -c
+cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -std=c++14 -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ -U__CUDA_NO_HALF2_OPERATORS__ -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=_gridencoder -D_GLIBCXX_USE_CXX11_ABI=0 -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86
+ldflags =
+
+rule compile
+ command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags
+ depfile = $out.d
+ deps = gcc
+
+rule cuda_compile
+ depfile = $out.d
+ deps = gcc
+ command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags
+
+
+
+build /home/pod/shared-nvme/SyncTalk/gridencoder/build/temp.linux-x86_64-cpython-38/home/pod/shared-nvme/SyncTalk/gridencoder/src/bindings.o: compile /home/pod/shared-nvme/SyncTalk/gridencoder/src/bindings.cpp
+build /home/pod/shared-nvme/SyncTalk/gridencoder/build/temp.linux-x86_64-cpython-38/home/pod/shared-nvme/SyncTalk/gridencoder/src/gridencoder.o: cuda_compile /home/pod/shared-nvme/SyncTalk/gridencoder/src/gridencoder.cu
+
+
+
+
+
diff --git a/sync/SyncTalk/gridencoder/build/temp.linux-x86_64-cpython-38/home/pod/shared-nvme/SyncTalk/gridencoder/src/bindings.o b/sync/SyncTalk/gridencoder/build/temp.linux-x86_64-cpython-38/home/pod/shared-nvme/SyncTalk/gridencoder/src/bindings.o
new file mode 100644
index 00000000..7dcc220e
Binary files /dev/null and b/sync/SyncTalk/gridencoder/build/temp.linux-x86_64-cpython-38/home/pod/shared-nvme/SyncTalk/gridencoder/src/bindings.o differ
diff --git a/sync/SyncTalk/gridencoder/build/temp.linux-x86_64-cpython-38/home/pod/shared-nvme/SyncTalk/gridencoder/src/gridencoder.o b/sync/SyncTalk/gridencoder/build/temp.linux-x86_64-cpython-38/home/pod/shared-nvme/SyncTalk/gridencoder/src/gridencoder.o
new file mode 100644
index 00000000..f68a21a4
Binary files /dev/null and b/sync/SyncTalk/gridencoder/build/temp.linux-x86_64-cpython-38/home/pod/shared-nvme/SyncTalk/gridencoder/src/gridencoder.o differ
diff --git a/sync/SyncTalk/gridencoder/grid.py b/sync/SyncTalk/gridencoder/grid.py
new file mode 100644
index 00000000..3bb4d80a
--- /dev/null
+++ b/sync/SyncTalk/gridencoder/grid.py
@@ -0,0 +1,155 @@
+import numpy as np
+
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.cuda.amp import custom_bwd, custom_fwd
+
+try:
+ import _gridencoder as _backend
+except ImportError:
+ from .backend import _backend
+
+_gridtype_to_id = {
+ 'hash': 0,
+ 'tiled': 1,
+}
+
+class _grid_encode(Function):
+ @staticmethod
+ @custom_fwd
+ def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False, gridtype=0, align_corners=False):
+ # inputs: [B, D], float in [0, 1]
+ # embeddings: [sO, C], float
+ # offsets: [L + 1], int
+ # RETURN: [B, F], float
+
+ inputs = inputs.float().contiguous()
+
+ B, D = inputs.shape # batch size, coord dim
+ L = offsets.shape[0] - 1 # level
+ C = embeddings.shape[1] # embedding dim for each level
+ S = np.log2(per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f
+ H = base_resolution # base resolution
+
+ # manually handle autocast (only use half precision embeddings, inputs must be float for enough precision)
+ # if C % 2 != 0, force float, since half for atomicAdd is very slow.
+ if torch.is_autocast_enabled() and C % 2 == 0:
+ embeddings = embeddings.to(torch.half)
+
+ # L first, optimize cache for cuda kernel, but needs an extra permute later
+ outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype)
+
+ if calc_grad_inputs:
+ dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=embeddings.dtype)
+ else:
+ dy_dx = None
+
+ _backend.grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, S, H, dy_dx, gridtype, align_corners)
+
+ # permute back to [B, L * C]
+ outputs = outputs.permute(1, 0, 2).reshape(B, L * C)
+
+ ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)
+ ctx.dims = [B, D, C, L, S, H, gridtype]
+ ctx.align_corners = align_corners
+
+ return outputs
+
+ @staticmethod
+ #@once_differentiable
+ @custom_bwd
+ def backward(ctx, grad):
+
+ inputs, embeddings, offsets, dy_dx = ctx.saved_tensors
+ B, D, C, L, S, H, gridtype = ctx.dims
+ align_corners = ctx.align_corners
+
+ # grad: [B, L * C] --> [L, B, C]
+ grad = grad.view(B, L, C).permute(1, 0, 2).contiguous()
+
+ grad_embeddings = torch.zeros_like(embeddings)
+
+ if dy_dx is not None:
+ grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype)
+ else:
+ grad_inputs = None
+
+ _backend.grid_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners)
+
+ if dy_dx is not None:
+ grad_inputs = grad_inputs.to(inputs.dtype)
+
+ return grad_inputs, grad_embeddings, None, None, None, None, None, None
+
+
+
+grid_encode = _grid_encode.apply
+
+
+class GridEncoder(nn.Module):
+ def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash', align_corners=False):
+ super().__init__()
+
+ # the finest resolution desired at the last level, if provided, overridee per_level_scale
+ if desired_resolution is not None:
+ per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1))
+
+ self.input_dim = input_dim # coord dims, 2 or 3
+ self.num_levels = num_levels # num levels, each level multiply resolution by 2
+ self.level_dim = level_dim # encode channels per level
+ self.per_level_scale = per_level_scale # multiply resolution by this scale at each level.
+ self.log2_hashmap_size = log2_hashmap_size
+ self.base_resolution = base_resolution
+ self.output_dim = num_levels * level_dim
+ self.gridtype = gridtype
+ self.gridtype_id = _gridtype_to_id[gridtype] # "tiled" or "hash"
+ self.align_corners = align_corners
+
+ # allocate parameters
+ offsets = []
+ offset = 0
+ self.max_params = 2 ** log2_hashmap_size
+ for i in range(num_levels):
+ resolution = int(np.ceil(base_resolution * per_level_scale ** i))
+ params_in_level = min(self.max_params, (resolution if align_corners else resolution + 1) ** input_dim) # limit max number
+ params_in_level = int(np.ceil(params_in_level / 8) * 8) # make divisible
+ offsets.append(offset)
+ offset += params_in_level
+ # print(resolution, params_in_level)
+ offsets.append(offset)
+ offsets = torch.from_numpy(np.array(offsets, dtype=np.int32))
+ self.register_buffer('offsets', offsets)
+
+ self.n_params = offsets[-1] * level_dim
+
+ # parameters
+ self.embeddings = nn.Parameter(torch.empty(offset, level_dim))
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ std = 1e-4
+ self.embeddings.data.uniform_(-std, std)
+
+ def __repr__(self):
+ return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners}"
+
+ def forward(self, inputs, bound=1):
+ # inputs: [..., input_dim], normalized real world positions in [-bound, bound]
+ # return: [..., num_levels * level_dim]
+
+ inputs = (inputs + bound) / (2 * bound) # map to [0, 1]
+
+ #print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item())
+
+ prefix_shape = list(inputs.shape[:-1])
+ inputs = inputs.view(-1, self.input_dim)
+
+ outputs = grid_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners)
+ outputs = outputs.view(prefix_shape + [self.output_dim])
+
+ #print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item())
+
+ return outputs
\ No newline at end of file
diff --git a/sync/SyncTalk/gridencoder/gridencoder.egg-info/PKG-INFO b/sync/SyncTalk/gridencoder/gridencoder.egg-info/PKG-INFO
new file mode 100644
index 00000000..c89f13a5
--- /dev/null
+++ b/sync/SyncTalk/gridencoder/gridencoder.egg-info/PKG-INFO
@@ -0,0 +1,3 @@
+Metadata-Version: 2.1
+Name: gridencoder
+Version: 0.0.0
diff --git a/sync/SyncTalk/gridencoder/gridencoder.egg-info/SOURCES.txt b/sync/SyncTalk/gridencoder/gridencoder.egg-info/SOURCES.txt
new file mode 100644
index 00000000..458c615a
--- /dev/null
+++ b/sync/SyncTalk/gridencoder/gridencoder.egg-info/SOURCES.txt
@@ -0,0 +1,7 @@
+setup.py
+/home/pod/shared-nvme/SyncTalk/gridencoder/src/bindings.cpp
+/home/pod/shared-nvme/SyncTalk/gridencoder/src/gridencoder.cu
+gridencoder.egg-info/PKG-INFO
+gridencoder.egg-info/SOURCES.txt
+gridencoder.egg-info/dependency_links.txt
+gridencoder.egg-info/top_level.txt
\ No newline at end of file
diff --git a/sync/SyncTalk/gridencoder/gridencoder.egg-info/dependency_links.txt b/sync/SyncTalk/gridencoder/gridencoder.egg-info/dependency_links.txt
new file mode 100644
index 00000000..8b137891
--- /dev/null
+++ b/sync/SyncTalk/gridencoder/gridencoder.egg-info/dependency_links.txt
@@ -0,0 +1 @@
+
diff --git a/sync/SyncTalk/gridencoder/gridencoder.egg-info/top_level.txt b/sync/SyncTalk/gridencoder/gridencoder.egg-info/top_level.txt
new file mode 100644
index 00000000..0ab3d303
--- /dev/null
+++ b/sync/SyncTalk/gridencoder/gridencoder.egg-info/top_level.txt
@@ -0,0 +1 @@
+_gridencoder
diff --git a/sync/SyncTalk/gridencoder/setup.py b/sync/SyncTalk/gridencoder/setup.py
new file mode 100644
index 00000000..714bf1ca
--- /dev/null
+++ b/sync/SyncTalk/gridencoder/setup.py
@@ -0,0 +1,50 @@
+import os
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+_src_path = os.path.dirname(os.path.abspath(__file__))
+
+nvcc_flags = [
+ '-O3', '-std=c++14',
+ '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
+]
+
+if os.name == "posix":
+ c_flags = ['-O3', '-std=c++14']
+elif os.name == "nt":
+ c_flags = ['/O2', '/std:c++17']
+
+ # find cl.exe
+ def find_cl_path():
+ import glob
+ for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
+ paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
+ if paths:
+ return paths[0]
+
+ # If cl.exe is not on path, try to find it.
+ if os.system("where cl.exe >nul 2>nul") != 0:
+ cl_path = find_cl_path()
+ if cl_path is None:
+ raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
+ os.environ["PATH"] += ";" + cl_path
+
+setup(
+ name='gridencoder', # package name, import this to use python API
+ ext_modules=[
+ CUDAExtension(
+ name='_gridencoder', # extension name, import this to use CUDA API
+ sources=[os.path.join(_src_path, 'src', f) for f in [
+ 'gridencoder.cu',
+ 'bindings.cpp',
+ ]],
+ extra_compile_args={
+ 'cxx': c_flags,
+ 'nvcc': nvcc_flags,
+ }
+ ),
+ ],
+ cmdclass={
+ 'build_ext': BuildExtension,
+ }
+)
\ No newline at end of file
diff --git a/sync/SyncTalk/gridencoder/src/bindings.cpp b/sync/SyncTalk/gridencoder/src/bindings.cpp
new file mode 100644
index 00000000..afa6f64f
--- /dev/null
+++ b/sync/SyncTalk/gridencoder/src/bindings.cpp
@@ -0,0 +1,8 @@
+#include
+
+#include "gridencoder.h"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)");
+ m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)");
+}
\ No newline at end of file
diff --git a/sync/SyncTalk/gridencoder/src/gridencoder.cu b/sync/SyncTalk/gridencoder/src/gridencoder.cu
new file mode 100644
index 00000000..d410d7f0
--- /dev/null
+++ b/sync/SyncTalk/gridencoder/src/gridencoder.cu
@@ -0,0 +1,479 @@
+#include
+#include
+#include
+
+#include
+#include
+
+#include
+#include
+
+#include
+#include
+
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
+#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
+#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
+
+
+// just for compatability of half precision in AT_DISPATCH_FLOATING_TYPES_AND_HALF...
+static inline __device__ at::Half atomicAdd(at::Half *address, at::Half val) {
+ // requires CUDA >= 10 and ARCH >= 70
+ // this is very slow compared to float or __half2, and never used.
+ //return atomicAdd(reinterpret_cast<__half*>(address), val);
+}
+
+
+template
+static inline __host__ __device__ T div_round_up(T val, T divisor) {
+ return (val + divisor - 1) / divisor;
+}
+
+
+template
+__device__ uint32_t fast_hash(const uint32_t pos_grid[D]) {
+ static_assert(D <= 7, "fast_hash can only hash up to 7 dimensions.");
+
+ // While 1 is technically not a good prime for hashing (or a prime at all), it helps memory coherence
+ // and is sufficient for our use case of obtaining a uniformly colliding index from high-dimensional
+ // coordinates.
+ constexpr uint32_t primes[7] = { 1, 2654435761, 805459861, 3674653429, 2097192037, 1434869437, 2165219737 };
+
+ uint32_t result = 0;
+ #pragma unroll
+ for (uint32_t i = 0; i < D; ++i) {
+ result ^= pos_grid[i] * primes[i];
+ }
+
+ return result;
+}
+
+
+template
+__device__ uint32_t get_grid_index(const uint32_t gridtype, const bool align_corners, const uint32_t ch, const uint32_t hashmap_size, const uint32_t resolution, const uint32_t pos_grid[D]) {
+ uint32_t stride = 1;
+ uint32_t index = 0;
+
+ #pragma unroll
+ for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) {
+ index += pos_grid[d] * stride;
+ stride *= align_corners ? resolution: (resolution + 1);
+ }
+
+ // NOTE: for NeRF, the hash is in fact not necessary. Check https://github.com/NVlabs/instant-ngp/issues/97.
+ // gridtype: 0 == hash, 1 == tiled
+ if (gridtype == 0 && stride > hashmap_size) {
+ index = fast_hash(pos_grid);
+ }
+
+ return (index % hashmap_size) * C + ch;
+}
+
+
+template
+__global__ void kernel_grid(
+ const float * __restrict__ inputs,
+ const scalar_t * __restrict__ grid,
+ const int * __restrict__ offsets,
+ scalar_t * __restrict__ outputs,
+ const uint32_t B, const uint32_t L, const float S, const uint32_t H,
+ scalar_t * __restrict__ dy_dx,
+ const uint32_t gridtype,
+ const bool align_corners
+) {
+ const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;
+
+ if (b >= B) return;
+
+ const uint32_t level = blockIdx.y;
+
+ // locate
+ grid += (uint32_t)offsets[level] * C;
+ inputs += b * D;
+ outputs += level * B * C + b * C;
+
+ // check input range (should be in [0, 1])
+ bool flag_oob = false;
+ #pragma unroll
+ for (uint32_t d = 0; d < D; d++) {
+ if (inputs[d] < 0 || inputs[d] > 1) {
+ flag_oob = true;
+ }
+ }
+ // if input out of bound, just set output to 0
+ if (flag_oob) {
+ #pragma unroll
+ for (uint32_t ch = 0; ch < C; ch++) {
+ outputs[ch] = 0;
+ }
+ if (dy_dx) {
+ dy_dx += b * D * L * C + level * D * C; // B L D C
+ #pragma unroll
+ for (uint32_t d = 0; d < D; d++) {
+ #pragma unroll
+ for (uint32_t ch = 0; ch < C; ch++) {
+ dy_dx[d * C + ch] = 0;
+ }
+ }
+ }
+ return;
+ }
+
+ const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
+ const float scale = exp2f(level * S) * H - 1.0f;
+ const uint32_t resolution = (uint32_t)ceil(scale) + 1;
+
+ // calculate coordinate
+ float pos[D];
+ uint32_t pos_grid[D];
+
+ #pragma unroll
+ for (uint32_t d = 0; d < D; d++) {
+ pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
+ pos_grid[d] = floorf(pos[d]);
+ pos[d] -= (float)pos_grid[d];
+ }
+
+ //printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]);
+
+ // interpolate
+ scalar_t results[C] = {0}; // temp results in register
+
+ #pragma unroll
+ for (uint32_t idx = 0; idx < (1 << D); idx++) {
+ float w = 1;
+ uint32_t pos_grid_local[D];
+
+ #pragma unroll
+ for (uint32_t d = 0; d < D; d++) {
+ if ((idx & (1 << d)) == 0) {
+ w *= 1 - pos[d];
+ pos_grid_local[d] = pos_grid[d];
+ } else {
+ w *= pos[d];
+ pos_grid_local[d] = pos_grid[d] + 1;
+ }
+ }
+
+ uint32_t index = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
+
+ // writing to register (fast)
+ #pragma unroll
+ for (uint32_t ch = 0; ch < C; ch++) {
+ results[ch] += w * grid[index + ch];
+ }
+
+ //printf("[b=%d, l=%d] int %d, idx %d, w %f, val %f\n", b, level, idx, index, w, grid[index]);
+ }
+
+ // writing to global memory (slow)
+ #pragma unroll
+ for (uint32_t ch = 0; ch < C; ch++) {
+ outputs[ch] = results[ch];
+ }
+
+ // prepare dy_dx
+ // differentiable (soft) indexing: https://discuss.pytorch.org/t/differentiable-indexing/17647/9
+ if (dy_dx) {
+
+ dy_dx += b * D * L * C + level * D * C; // B L D C
+
+ #pragma unroll
+ for (uint32_t gd = 0; gd < D; gd++) {
+
+ scalar_t results_grad[C] = {0};
+
+ #pragma unroll
+ for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) {
+ float w = scale;
+ uint32_t pos_grid_local[D];
+
+ #pragma unroll
+ for (uint32_t nd = 0; nd < D - 1; nd++) {
+ const uint32_t d = (nd >= gd) ? (nd + 1) : nd;
+
+ if ((idx & (1 << nd)) == 0) {
+ w *= 1 - pos[d];
+ pos_grid_local[d] = pos_grid[d];
+ } else {
+ w *= pos[d];
+ pos_grid_local[d] = pos_grid[d] + 1;
+ }
+ }
+
+ pos_grid_local[gd] = pos_grid[gd];
+ uint32_t index_left = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
+ pos_grid_local[gd] = pos_grid[gd] + 1;
+ uint32_t index_right = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
+
+ #pragma unroll
+ for (uint32_t ch = 0; ch < C; ch++) {
+ results_grad[ch] += w * (grid[index_right + ch] - grid[index_left + ch]);
+ }
+ }
+
+ #pragma unroll
+ for (uint32_t ch = 0; ch < C; ch++) {
+ dy_dx[gd * C + ch] = results_grad[ch];
+ }
+ }
+ }
+}
+
+
+template
+__global__ void kernel_grid_backward(
+ const scalar_t * __restrict__ grad,
+ const float * __restrict__ inputs,
+ const scalar_t * __restrict__ grid,
+ const int * __restrict__ offsets,
+ scalar_t * __restrict__ grad_grid,
+ const uint32_t B, const uint32_t L, const float S, const uint32_t H,
+ const uint32_t gridtype,
+ const bool align_corners
+) {
+ const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C;
+ if (b >= B) return;
+
+ const uint32_t level = blockIdx.y;
+ const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C;
+
+ // locate
+ grad_grid += offsets[level] * C;
+ inputs += b * D;
+ grad += level * B * C + b * C + ch; // L, B, C
+
+ const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
+ const float scale = exp2f(level * S) * H - 1.0f;
+ const uint32_t resolution = (uint32_t)ceil(scale) + 1;
+
+ // check input range (should be in [0, 1])
+ #pragma unroll
+ for (uint32_t d = 0; d < D; d++) {
+ if (inputs[d] < 0 || inputs[d] > 1) {
+ return; // grad is init as 0, so we simply return.
+ }
+ }
+
+ // calculate coordinate
+ float pos[D];
+ uint32_t pos_grid[D];
+
+ #pragma unroll
+ for (uint32_t d = 0; d < D; d++) {
+ pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
+ pos_grid[d] = floorf(pos[d]);
+ pos[d] -= (float)pos_grid[d];
+ }
+
+ scalar_t grad_cur[N_C] = {0}; // fetch to register
+ #pragma unroll
+ for (uint32_t c = 0; c < N_C; c++) {
+ grad_cur[c] = grad[c];
+ }
+
+ // interpolate
+ #pragma unroll
+ for (uint32_t idx = 0; idx < (1 << D); idx++) {
+ float w = 1;
+ uint32_t pos_grid_local[D];
+
+ #pragma unroll
+ for (uint32_t d = 0; d < D; d++) {
+ if ((idx & (1 << d)) == 0) {
+ w *= 1 - pos[d];
+ pos_grid_local[d] = pos_grid[d];
+ } else {
+ w *= pos[d];
+ pos_grid_local[d] = pos_grid[d] + 1;
+ }
+ }
+
+ uint32_t index = get_grid_index(gridtype, align_corners, ch, hashmap_size, resolution, pos_grid_local);
+
+ // atomicAdd for __half is slow (especially for large values), so we use __half2 if N_C % 2 == 0
+ // TODO: use float which is better than __half, if N_C % 2 != 0
+ if (std::is_same::value && N_C % 2 == 0) {
+ #pragma unroll
+ for (uint32_t c = 0; c < N_C; c += 2) {
+ // process two __half at once (by interpreting as a __half2)
+ __half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])};
+ atomicAdd((__half2*)&grad_grid[index + c], v);
+ }
+ // float, or __half when N_C % 2 != 0 (which means C == 1)
+ } else {
+ #pragma unroll
+ for (uint32_t c = 0; c < N_C; c++) {
+ atomicAdd(&grad_grid[index + c], w * grad_cur[c]);
+ }
+ }
+ }
+}
+
+
+template
+__global__ void kernel_input_backward(
+ const scalar_t * __restrict__ grad,
+ const scalar_t * __restrict__ dy_dx,
+ scalar_t * __restrict__ grad_inputs,
+ uint32_t B, uint32_t L
+) {
+ const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
+ if (t >= B * D) return;
+
+ const uint32_t b = t / D;
+ const uint32_t d = t - b * D;
+
+ dy_dx += b * L * D * C;
+
+ scalar_t result = 0;
+
+ # pragma unroll
+ for (int l = 0; l < L; l++) {
+ # pragma unroll
+ for (int ch = 0; ch < C; ch++) {
+ result += grad[l * B * C + b * C + ch] * dy_dx[l * D * C + d * C + ch];
+ }
+ }
+
+ grad_inputs[t] = result;
+}
+
+
+template
+void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners) {
+ static constexpr uint32_t N_THREAD = 512;
+ const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 };
+ switch (C) {
+ case 1: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners); break;
+ case 2: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners); break;
+ case 4: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners); break;
+ case 8: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners); break;
+ default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
+ }
+}
+
+// inputs: [B, D], float, in [0, 1]
+// embeddings: [sO, C], float
+// offsets: [L + 1], uint32_t
+// outputs: [L, B, C], float (L first, so only one level of hashmap needs to fit into cache at a time.)
+// H: base resolution
+// dy_dx: [B, L * D * C]
+template
+void grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners) {
+ switch (D) {
+ case 1: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
+ case 2: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
+ case 3: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
+ case 4: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
+ case 5: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
+ default: throw std::runtime_error{"GridEncoding: D must be 1, 2, 3, 4, or 5"};
+ }
+
+}
+
+template
+void kernel_grid_backward_wrapper(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) {
+ static constexpr uint32_t N_THREAD = 256;
+ const uint32_t N_C = std::min(2u, C); // n_features_per_thread
+ const dim3 blocks_hashgrid = { div_round_up(B * C / N_C, N_THREAD), L, 1 };
+ switch (C) {
+ case 1:
+ kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
+ if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L);
+ break;
+ case 2:
+ kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
+ if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L);
+ break;
+ case 4:
+ kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
+ if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L);
+ break;
+ case 8:
+ kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
+ if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L);
+ break;
+ default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
+ }
+}
+
+
+// grad: [L, B, C], float
+// inputs: [B, D], float, in [0, 1]
+// embeddings: [sO, C], float
+// offsets: [L + 1], uint32_t
+// grad_embeddings: [sO, C]
+// H: base resolution
+template
+void grid_encode_backward_cuda(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) {
+ switch (D) {
+ case 1: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
+ case 2: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
+ case 3: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
+ case 4: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
+ case 5: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
+ default: throw std::runtime_error{"GridEncoding: D must be 1, 2, 3, 4, or 5"};
+ }
+}
+
+
+
+void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional dy_dx, const uint32_t gridtype, const bool align_corners) {
+ CHECK_CUDA(inputs);
+ CHECK_CUDA(embeddings);
+ CHECK_CUDA(offsets);
+ CHECK_CUDA(outputs);
+ // CHECK_CUDA(dy_dx);
+
+ CHECK_CONTIGUOUS(inputs);
+ CHECK_CONTIGUOUS(embeddings);
+ CHECK_CONTIGUOUS(offsets);
+ CHECK_CONTIGUOUS(outputs);
+ // CHECK_CONTIGUOUS(dy_dx);
+
+ CHECK_IS_FLOATING(inputs);
+ CHECK_IS_FLOATING(embeddings);
+ CHECK_IS_INT(offsets);
+ CHECK_IS_FLOATING(outputs);
+ // CHECK_IS_FLOATING(dy_dx);
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ embeddings.scalar_type(), "grid_encode_forward", ([&] {
+ grid_encode_forward_cuda(inputs.data_ptr(), embeddings.data_ptr(), offsets.data_ptr(), outputs.data_ptr(), B, D, C, L, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr() : nullptr, gridtype, align_corners);
+ }));
+}
+
+void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional dy_dx, at::optional grad_inputs, const uint32_t gridtype, const bool align_corners) {
+ CHECK_CUDA(grad);
+ CHECK_CUDA(inputs);
+ CHECK_CUDA(embeddings);
+ CHECK_CUDA(offsets);
+ CHECK_CUDA(grad_embeddings);
+ // CHECK_CUDA(dy_dx);
+ // CHECK_CUDA(grad_inputs);
+
+ CHECK_CONTIGUOUS(grad);
+ CHECK_CONTIGUOUS(inputs);
+ CHECK_CONTIGUOUS(embeddings);
+ CHECK_CONTIGUOUS(offsets);
+ CHECK_CONTIGUOUS(grad_embeddings);
+ // CHECK_CONTIGUOUS(dy_dx);
+ // CHECK_CONTIGUOUS(grad_inputs);
+
+ CHECK_IS_FLOATING(grad);
+ CHECK_IS_FLOATING(inputs);
+ CHECK_IS_FLOATING(embeddings);
+ CHECK_IS_INT(offsets);
+ CHECK_IS_FLOATING(grad_embeddings);
+ // CHECK_IS_FLOATING(dy_dx);
+ // CHECK_IS_FLOATING(grad_inputs);
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ grad.scalar_type(), "grid_encode_backward", ([&] {
+ grid_encode_backward_cuda(grad.data_ptr(), inputs.data_ptr(), embeddings.data_ptr(), offsets.data_ptr(), grad_embeddings.data_ptr(), B, D, C, L, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr() : nullptr, grad_inputs.has_value() ? grad_inputs.value().data_ptr() : nullptr, gridtype, align_corners);
+ }));
+
+}
diff --git a/sync/SyncTalk/gridencoder/src/gridencoder.h b/sync/SyncTalk/gridencoder/src/gridencoder.h
new file mode 100644
index 00000000..0415b59f
--- /dev/null
+++ b/sync/SyncTalk/gridencoder/src/gridencoder.h
@@ -0,0 +1,15 @@
+#ifndef _HASH_ENCODE_H
+#define _HASH_ENCODE_H
+
+#include
+#include
+
+// inputs: [B, D], float, in [0, 1]
+// embeddings: [sO, C], float
+// offsets: [L + 1], uint32_t
+// outputs: [B, L * C], float
+// H: base resolution
+void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional dy_dx, const uint32_t gridtype, const bool align_corners);
+void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional dy_dx, at::optional grad_inputs, const uint32_t gridtype, const bool align_corners);
+
+#endif
\ No newline at end of file
diff --git a/sync/SyncTalk/inference.sh b/sync/SyncTalk/inference.sh
new file mode 100644
index 00000000..80390460
--- /dev/null
+++ b/sync/SyncTalk/inference.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+# Ensure the script exits on any error
+set -e
+
+# Define variables for the arguments
+DATA_PATH="data/May"
+WORKSPACE="model/trial_May"
+ASR_MODEL="ave"
+AUDIO_PATH="data/May/aud.wav"
+
+# Run the Python script with the specified arguments
+python main.py "$DATA_PATH" \
+ --workspace "$WORKSPACE" \
+ -O \
+ --test \
+ --test_train \
+ --asr_model "$ASR_MODEL" \
+ --portrait \
+ --aud "$AUDIO_PATH"
diff --git a/sync/SyncTalk/main.py b/sync/SyncTalk/main.py
new file mode 100644
index 00000000..5428ad13
--- /dev/null
+++ b/sync/SyncTalk/main.py
@@ -0,0 +1,261 @@
+import argparse
+
+from nerf_triplane.provider import NeRFDataset
+from nerf_triplane.utils import *
+from nerf_triplane.network import NeRFNetwork
+
+# torch.autograd.set_detect_anomaly(True)
+# Close tf32 features. Fix low numerical accuracy on rtx30xx gpu.
+try:
+ torch.backends.cuda.matmul.allow_tf32 = False
+ torch.backends.cudnn.allow_tf32 = False
+except AttributeError as e:
+ print('Info. This pytorch version is not support with tf32.')
+
+if __name__ == '__main__':
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('path', type=str)
+ parser.add_argument('-O', action='store_true', help="equals --fp16 --cuda_ray --exp_eye")
+ parser.add_argument('--test', action='store_true', help="test mode (load model and test dataset)")
+ parser.add_argument('--test_train', action='store_true', help="test mode (load model and train dataset)")
+ parser.add_argument('--data_range', type=int, nargs='*', default=[0, -1], help="data range to use")
+ parser.add_argument('--workspace', type=str, default='workspace')
+ parser.add_argument('--seed', type=int, default=0)
+
+ ### training options
+ parser.add_argument('--iters', type=int, default=200000, help="training iters")
+ parser.add_argument('--lr', type=float, default=1e-2, help="initial learning rate")
+ parser.add_argument('--lr_net', type=float, default=1e-3, help="initial learning rate")
+ parser.add_argument('--ckpt', type=str, default='latest')
+ parser.add_argument('--num_rays', type=int, default=4096 * 16, help="num rays sampled per image for each training step")
+ parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch")
+ parser.add_argument('--max_steps', type=int, default=16, help="max num steps sampled per ray (only valid when using --cuda_ray)")
+ parser.add_argument('--num_steps', type=int, default=16, help="num steps sampled per ray (only valid when NOT using --cuda_ray)")
+ parser.add_argument('--upsample_steps', type=int, default=0, help="num steps up-sampled per ray (only valid when NOT using --cuda_ray)")
+ parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)")
+ parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)")
+
+ ### loss set
+ parser.add_argument('--warmup_step', type=int, default=10000, help="warm up steps")
+ parser.add_argument('--amb_aud_loss', type=int, default=1, help="use ambient aud loss")
+ parser.add_argument('--amb_eye_loss', type=int, default=1, help="use ambient eye loss")
+ parser.add_argument('--unc_loss', type=int, default=1, help="use uncertainty loss")
+ parser.add_argument('--lambda_amb', type=float, default=1e-1, help="lambda for ambient loss")
+ parser.add_argument('--pyramid_loss', type=int, default=0, help="use perceptual loss")
+
+ ### network backbone options
+ parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")
+
+ parser.add_argument('--bg_img', type=str, default='', help="background image")
+ parser.add_argument('--fbg', action='store_true', help="frame-wise bg")
+ parser.add_argument('--exp_eye', action='store_true', help="explicitly control the eyes")
+ parser.add_argument('--fix_eye', type=float, default=-1, help="fixed eye area, negative to disable, set to 0-0.3 for a reasonable eye")
+ parser.add_argument('--smooth_eye', action='store_true', help="smooth the eye area sequence")
+ parser.add_argument('--bs_area', type=str, default="upper", help="upper or eye")
+
+ parser.add_argument('--torso_shrink', type=float, default=0.8, help="shrink bg coords to allow more flexibility in deform")
+
+ ### dataset options
+ parser.add_argument('--color_space', type=str, default='srgb', help="Color space, supports (linear, srgb)")
+ parser.add_argument('--preload', type=int, default=0, help="0 means load data from disk on-the-fly, 1 means preload to CPU, 2 means GPU.")
+ # (the default value is for the fox dataset)
+ parser.add_argument('--bound', type=float, default=1, help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.")
+ parser.add_argument('--scale', type=float, default=4, help="scale camera location into box[-bound, bound]^3")
+ parser.add_argument('--offset', type=float, nargs='*', default=[0, 0, 0], help="offset of camera location")
+ parser.add_argument('--dt_gamma', type=float, default=1/256, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)")
+ parser.add_argument('--min_near', type=float, default=0.05, help="minimum near distance for camera")
+ parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied (sigma)")
+ parser.add_argument('--density_thresh_torso', type=float, default=0.01, help="threshold for density grid to be occupied (alpha)")
+ parser.add_argument('--patch_size', type=int, default=1, help="[experimental] render patches in training, so as to apply LPIPS loss. 1 means disabled, use [64, 32, 16] to enable")
+
+ parser.add_argument('--init_lips', action='store_true', help="init lips region")
+ parser.add_argument('--finetune_lips', action='store_true', help="use LPIPS and landmarks to fine tune lips region")
+ parser.add_argument('--smooth_lips', action='store_true', help="smooth the enc_a in a exponential decay way...")
+
+ parser.add_argument('--torso', action='store_true', help="fix head and train torso")
+ parser.add_argument('--head_ckpt', type=str, default='', help="head model")
+
+ ### GUI options
+ parser.add_argument('--gui', action='store_true', help="start a GUI")
+ parser.add_argument('--W', type=int, default=450, help="GUI width")
+ parser.add_argument('--H', type=int, default=450, help="GUI height")
+ parser.add_argument('--radius', type=float, default=3.35, help="default GUI camera radius from center")
+ parser.add_argument('--fovy', type=float, default=21.24, help="default GUI camera fovy")
+ parser.add_argument('--max_spp', type=int, default=1, help="GUI rendering max sample per pixel")
+
+ ### else
+ parser.add_argument('--att', type=int, default=2, help="audio attention mode (0 = turn off, 1 = left-direction, 2 = bi-direction)")
+ parser.add_argument('--aud', type=str, default='', help="audio source (empty will load the default, else should be a path to a npy file)")
+ parser.add_argument('--emb', action='store_true', help="use audio class + embedding instead of logits")
+ parser.add_argument('--portrait', action='store_true', help="only render face")
+ parser.add_argument('--ind_dim', type=int, default=4, help="individual code dim, 0 to turn off")
+ parser.add_argument('--ind_num', type=int, default=20000, help="number of individual codes, should be larger than training dataset size")
+
+ parser.add_argument('--ind_dim_torso', type=int, default=8, help="individual code dim, 0 to turn off")
+
+ parser.add_argument('--amb_dim', type=int, default=2, help="ambient dimension")
+ parser.add_argument('--part', action='store_true', help="use partial training data (1/10)")
+ parser.add_argument('--part2', action='store_true', help="use partial training data (first 15s)")
+
+ parser.add_argument('--train_camera', action='store_true', help="optimize camera pose")
+ parser.add_argument('--smooth_path', action='store_true', help="brute-force smooth camera pose trajectory with a window size")
+ parser.add_argument('--smooth_path_window', type=int, default=7, help="smoothing window size")
+
+ # asr
+ parser.add_argument('--asr', action='store_true', help="load asr for real-time app")
+ parser.add_argument('--asr_wav', type=str, default='', help="load the wav and use as input")
+ parser.add_argument('--asr_play', action='store_true', help="play out the audio")
+
+ parser.add_argument('--asr_model', type=str, default='deepspeech')
+
+ parser.add_argument('--asr_save_feats', action='store_true')
+ # audio FPS
+ parser.add_argument('--fps', type=int, default=50)
+ # sliding window left-middle-right length (unit: 20ms)
+ parser.add_argument('-l', type=int, default=10)
+ parser.add_argument('-m', type=int, default=50)
+ parser.add_argument('-r', type=int, default=10)
+
+ opt = parser.parse_args()
+
+ if opt.O:
+ opt.fp16 = True
+ opt.exp_eye = True
+
+ if opt.test and False:
+ opt.smooth_path = True
+ opt.smooth_eye = True
+ opt.smooth_lips = True
+
+ opt.cuda_ray = True
+ # assert opt.cuda_ray, "Only support CUDA ray mode."
+
+ if opt.patch_size > 1:
+ # assert opt.patch_size > 16, "patch_size should > 16 to run LPIPS loss."
+ assert opt.num_rays % (opt.patch_size ** 2) == 0, "patch_size ** 2 should be dividable by num_rays."
+
+ # if opt.finetune_lips:
+ # # do not update density grid in finetune stage
+ # opt.update_extra_interval = 1e9
+
+ print(opt)
+
+ seed_everything(opt.seed)
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ model = NeRFNetwork(opt)
+
+ # manually load state dict for head
+ if opt.torso and opt.head_ckpt != '':
+
+ model_dict = torch.load(opt.head_ckpt, map_location='cpu')['model']
+
+ missing_keys, unexpected_keys = model.load_state_dict(model_dict, strict=False)
+
+ if len(missing_keys) > 0:
+ print(f"[WARN] missing keys: {missing_keys}")
+ if len(unexpected_keys) > 0:
+ print(f"[WARN] unexpected keys: {unexpected_keys}")
+
+ # freeze these keys
+ for k, v in model.named_parameters():
+ if k in model_dict:
+ print(f'[INFO] freeze {k}, {v.shape}')
+ v.requires_grad = False
+
+
+ # print(model)
+
+ # criterion = torch.nn.MSELoss(reduction='none')
+ criterion = torch.nn.L1Loss(reduction='none')
+
+
+ if opt.test:
+
+ if opt.gui:
+ metrics = [] # use no metric in GUI for faster initialization...
+ else:
+ # metrics = [PSNRMeter(), LPIPSMeter(device=device)]
+ metrics = [PSNRMeter(), LPIPSMeter(device=device), LMDMeter(backend='fan')]
+
+ trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, metrics=metrics, use_checkpoint=opt.ckpt)
+
+ if opt.test_train:
+ test_set = NeRFDataset(opt, device=device, type='train')
+ # a manual fix to test on the training dataset
+ test_set.training = False
+ test_set.num_rays = -1
+ test_loader = test_set.dataloader()
+ else:
+ test_loader = NeRFDataset(opt, device=device, type='test').dataloader()
+
+
+ # temp fix: for update_extra_states
+ model.aud_features = test_loader._data.auds
+ model.eye_areas = test_loader._data.eye_area
+
+ if opt.gui:
+ from nerf_triplane.gui import NeRFGUI
+ # we still need test_loader to provide audio features for testing.
+ with NeRFGUI(opt, trainer, test_loader) as gui:
+ gui.render()
+
+ else:
+ ### test and save video (fast)
+ trainer.test(test_loader)
+
+ ### evaluate metrics (slow)
+ if test_loader.has_gt:
+ trainer.evaluate(test_loader)
+
+
+
+ else:
+
+ optimizer = lambda model: torch.optim.AdamW(model.get_params(opt.lr, opt.lr_net), betas=(0, 0.99), eps=1e-8)
+
+ train_loader = NeRFDataset(opt, device=device, type='train').dataloader()
+
+ assert len(train_loader) < opt.ind_num, f"[ERROR] dataset too many frames: {len(train_loader)}, please increase --ind_num to this number!"
+
+ # temp fix: for update_extra_states
+ model.aud_features = train_loader._data.auds
+ model.eye_area = train_loader._data.eye_area
+ model.poses = train_loader._data.poses
+
+ # decay to 0.1 * init_lr at last iter step
+ if opt.finetune_lips:
+ scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.05 ** (iter / opt.iters))
+ else:
+ scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.5 ** (iter / opt.iters))
+
+ metrics = [PSNRMeter(), LPIPSMeter(device=device),LMDMeter(backend='fan')]
+
+ eval_interval = max(1, int(5000 / len(train_loader)))
+ trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, optimizer=optimizer, criterion=criterion, ema_decay=0.95, fp16=opt.fp16, lr_scheduler=scheduler, scheduler_update_every_step=True, metrics=metrics, use_checkpoint=opt.ckpt, eval_interval=eval_interval)
+ with open(os.path.join(opt.workspace, 'opt.txt'), 'a') as f:
+ f.write(str(opt))
+ if opt.gui:
+ with NeRFGUI(opt, trainer, train_loader) as gui:
+ gui.render()
+
+ else:
+ valid_loader = NeRFDataset(opt, device=device, type='val', downscale=1).dataloader()
+
+ max_epochs = np.ceil(opt.iters / len(train_loader)).astype(np.int32)
+ print(f'[INFO] max_epoch = {max_epochs}')
+ trainer.train(train_loader, valid_loader, max_epochs)
+
+ # free some mem
+ del train_loader, valid_loader
+ torch.cuda.empty_cache()
+
+ # also test
+ test_loader = NeRFDataset(opt, device=device, type='test').dataloader()
+
+ if test_loader.has_gt:
+ trainer.evaluate(test_loader) # blender has gt, so evaluate it.
+
+ trainer.test(test_loader)
\ No newline at end of file
diff --git a/sync/SyncTalk/nerf_triplane/__pycache__/encoding.cpython-38.pyc b/sync/SyncTalk/nerf_triplane/__pycache__/encoding.cpython-38.pyc
new file mode 100644
index 00000000..98ab5c79
Binary files /dev/null and b/sync/SyncTalk/nerf_triplane/__pycache__/encoding.cpython-38.pyc differ
diff --git a/sync/SyncTalk/nerf_triplane/__pycache__/encoding.cpython-39.pyc b/sync/SyncTalk/nerf_triplane/__pycache__/encoding.cpython-39.pyc
new file mode 100644
index 00000000..3d74d663
Binary files /dev/null and b/sync/SyncTalk/nerf_triplane/__pycache__/encoding.cpython-39.pyc differ
diff --git a/sync/SyncTalk/nerf_triplane/__pycache__/network.cpython-38.pyc b/sync/SyncTalk/nerf_triplane/__pycache__/network.cpython-38.pyc
new file mode 100644
index 00000000..75817f03
Binary files /dev/null and b/sync/SyncTalk/nerf_triplane/__pycache__/network.cpython-38.pyc differ
diff --git a/sync/SyncTalk/nerf_triplane/__pycache__/network.cpython-39.pyc b/sync/SyncTalk/nerf_triplane/__pycache__/network.cpython-39.pyc
new file mode 100644
index 00000000..b1204b1d
Binary files /dev/null and b/sync/SyncTalk/nerf_triplane/__pycache__/network.cpython-39.pyc differ
diff --git a/sync/SyncTalk/nerf_triplane/__pycache__/provider.cpython-38.pyc b/sync/SyncTalk/nerf_triplane/__pycache__/provider.cpython-38.pyc
new file mode 100644
index 00000000..b37a7f71
Binary files /dev/null and b/sync/SyncTalk/nerf_triplane/__pycache__/provider.cpython-38.pyc differ
diff --git a/sync/SyncTalk/nerf_triplane/__pycache__/provider.cpython-39.pyc b/sync/SyncTalk/nerf_triplane/__pycache__/provider.cpython-39.pyc
new file mode 100644
index 00000000..1667286a
Binary files /dev/null and b/sync/SyncTalk/nerf_triplane/__pycache__/provider.cpython-39.pyc differ
diff --git a/sync/SyncTalk/nerf_triplane/__pycache__/renderer.cpython-38.pyc b/sync/SyncTalk/nerf_triplane/__pycache__/renderer.cpython-38.pyc
new file mode 100644
index 00000000..513961a5
Binary files /dev/null and b/sync/SyncTalk/nerf_triplane/__pycache__/renderer.cpython-38.pyc differ
diff --git a/sync/SyncTalk/nerf_triplane/__pycache__/renderer.cpython-39.pyc b/sync/SyncTalk/nerf_triplane/__pycache__/renderer.cpython-39.pyc
new file mode 100644
index 00000000..3ad90012
Binary files /dev/null and b/sync/SyncTalk/nerf_triplane/__pycache__/renderer.cpython-39.pyc differ
diff --git a/sync/SyncTalk/nerf_triplane/__pycache__/utils.cpython-38.pyc b/sync/SyncTalk/nerf_triplane/__pycache__/utils.cpython-38.pyc
new file mode 100644
index 00000000..2c67af84
Binary files /dev/null and b/sync/SyncTalk/nerf_triplane/__pycache__/utils.cpython-38.pyc differ
diff --git a/sync/SyncTalk/nerf_triplane/__pycache__/utils.cpython-39.pyc b/sync/SyncTalk/nerf_triplane/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 00000000..7dcd69a6
Binary files /dev/null and b/sync/SyncTalk/nerf_triplane/__pycache__/utils.cpython-39.pyc differ
diff --git a/sync/SyncTalk/nerf_triplane/asr.py b/sync/SyncTalk/nerf_triplane/asr.py
new file mode 100644
index 00000000..b0bd0476
--- /dev/null
+++ b/sync/SyncTalk/nerf_triplane/asr.py
@@ -0,0 +1,419 @@
+import time
+import numpy as np
+import torch
+import torch.nn.functional as F
+from transformers import AutoModelForCTC, AutoProcessor
+
+import pyaudio
+import soundfile as sf
+import resampy
+
+from queue import Queue
+from threading import Thread, Event
+
+
+def _read_frame(stream, exit_event, queue, chunk):
+
+ while True:
+ if exit_event.is_set():
+ print(f'[INFO] read frame thread ends')
+ break
+ frame = stream.read(chunk, exception_on_overflow=False)
+ frame = np.frombuffer(frame, dtype=np.int16).astype(np.float32) / 32767 # [chunk]
+ queue.put(frame)
+
+def _play_frame(stream, exit_event, queue, chunk):
+
+ while True:
+ if exit_event.is_set():
+ print(f'[INFO] play frame thread ends')
+ break
+ frame = queue.get()
+ frame = (frame * 32767).astype(np.int16).tobytes()
+ stream.write(frame, chunk)
+
+class ASR:
+ def __init__(self, opt):
+
+ self.opt = opt
+
+ self.play = opt.asr_play
+
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ self.fps = opt.fps # 20 ms per frame
+ self.sample_rate = 16000
+ self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
+ self.mode = 'live' if opt.asr_wav == '' else 'file'
+
+ if 'esperanto' in self.opt.asr_model:
+ self.audio_dim = 44
+ elif 'deepspeech' in self.opt.asr_model:
+ self.audio_dim = 29
+ else:
+ self.audio_dim = 32
+
+ # prepare context cache
+ # each segment is (stride_left + ctx + stride_right) * 20ms, latency should be (ctx + stride_right) * 20ms
+ self.context_size = opt.m
+ self.stride_left_size = opt.l
+ self.stride_right_size = opt.r
+ self.text = '[START]\n'
+ self.terminated = False
+ self.frames = []
+
+ # pad left frames
+ if self.stride_left_size > 0:
+ self.frames.extend([np.zeros(self.chunk, dtype=np.float32)] * self.stride_left_size)
+
+
+ self.exit_event = Event()
+ self.audio_instance = pyaudio.PyAudio()
+
+ # create input stream
+ if self.mode == 'file':
+ self.file_stream = self.create_file_stream()
+ else:
+ # start a background process to read frames
+ self.input_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=True, output=False, frames_per_buffer=self.chunk)
+ self.queue = Queue()
+ self.process_read_frame = Thread(target=_read_frame, args=(self.input_stream, self.exit_event, self.queue, self.chunk))
+
+ # play out the audio too...?
+ if self.play:
+ self.output_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=False, output=True, frames_per_buffer=self.chunk)
+ self.output_queue = Queue()
+ self.process_play_frame = Thread(target=_play_frame, args=(self.output_stream, self.exit_event, self.output_queue, self.chunk))
+
+ # current location of audio
+ self.idx = 0
+
+ # create wav2vec model
+ print(f'[INFO] loading ASR model {self.opt.asr_model}...')
+ self.processor = AutoProcessor.from_pretrained(opt.asr_model)
+ self.model = AutoModelForCTC.from_pretrained(opt.asr_model).to(self.device)
+
+ # prepare to save logits
+ if self.opt.asr_save_feats:
+ self.all_feats = []
+
+ # the extracted features
+ # use a loop queue to efficiently record endless features: [f--t---][-------][-------]
+ self.feat_buffer_size = 4
+ self.feat_buffer_idx = 0
+ self.feat_queue = torch.zeros(self.feat_buffer_size * self.context_size, self.audio_dim, dtype=torch.float32, device=self.device)
+
+ # TODO: hard coded 16 and 8 window size...
+ self.front = self.feat_buffer_size * self.context_size - 8 # fake padding
+ self.tail = 8
+ # attention window...
+ self.att_feats = [torch.zeros(self.audio_dim, 16, dtype=torch.float32, device=self.device)] * 4 # 4 zero padding...
+
+ # warm up steps needed: mid + right + window_size + attention_size
+ self.warm_up_steps = self.context_size + self.stride_right_size + 8 + 2 * 3
+
+ self.listening = False
+ self.playing = False
+
+ def listen(self):
+ # start
+ if self.mode == 'live' and not self.listening:
+ print(f'[INFO] starting read frame thread...')
+ self.process_read_frame.start()
+ self.listening = True
+
+ if self.play and not self.playing:
+ print(f'[INFO] starting play frame thread...')
+ self.process_play_frame.start()
+ self.playing = True
+
+ def stop(self):
+
+ self.exit_event.set()
+
+ if self.play:
+ self.output_stream.stop_stream()
+ self.output_stream.close()
+ if self.playing:
+ self.process_play_frame.join()
+ self.playing = False
+
+ if self.mode == 'live':
+ self.input_stream.stop_stream()
+ self.input_stream.close()
+ if self.listening:
+ self.process_read_frame.join()
+ self.listening = False
+
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+
+ self.stop()
+
+ if self.mode == 'live':
+ # live mode: also print the result text.
+ self.text += '\n[END]'
+ print(self.text)
+
+ def get_next_feat(self):
+ # return a [1/8, 16] window, for the next input to nerf side.
+
+ while len(self.att_feats) < 8:
+ # [------f+++t-----]
+ if self.front < self.tail:
+ feat = self.feat_queue[self.front:self.tail]
+ # [++t-----------f+]
+ else:
+ feat = torch.cat([self.feat_queue[self.front:], self.feat_queue[:self.tail]], dim=0)
+
+ self.front = (self.front + 2) % self.feat_queue.shape[0]
+ self.tail = (self.tail + 2) % self.feat_queue.shape[0]
+
+ # print(self.front, self.tail, feat.shape)
+
+ self.att_feats.append(feat.permute(1, 0))
+
+ att_feat = torch.stack(self.att_feats, dim=0) # [8, 44, 16]
+
+ # discard old
+ self.att_feats = self.att_feats[1:]
+
+ return att_feat
+
+ def run_step(self):
+
+ if self.terminated:
+ return
+
+ # get a frame of audio
+ frame = self.get_audio_frame()
+
+ # the last frame
+ if frame is None:
+ # terminate, but always run the network for the left frames
+ self.terminated = True
+ else:
+ self.frames.append(frame)
+ # put to output
+ if self.play:
+ self.output_queue.put(frame)
+ # context not enough, do not run network.
+ if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size:
+ return
+
+ inputs = np.concatenate(self.frames) # [N * chunk]
+
+ # discard the old part to save memory
+ if not self.terminated:
+ self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
+
+ logits, labels, text = self.frame_to_text(inputs)
+ feats = logits # better lips-sync than labels
+
+ # save feats
+ if self.opt.asr_save_feats:
+ self.all_feats.append(feats)
+
+ # record the feats efficiently.. (no concat, constant memory)
+ start = self.feat_buffer_idx * self.context_size
+ end = start + feats.shape[0]
+ self.feat_queue[start:end] = feats
+ self.feat_buffer_idx = (self.feat_buffer_idx + 1) % self.feat_buffer_size
+
+ # very naive, just concat the text output.
+ if text != '':
+ self.text = self.text + ' ' + text
+
+ # will only run once at ternimation
+ if self.terminated:
+ self.text += '\n[END]'
+ print(self.text)
+ if self.opt.asr_save_feats:
+ print(f'[INFO] save all feats for training purpose... ')
+ feats = torch.cat(self.all_feats, dim=0) # [N, C]
+ # print('[INFO] before unfold', feats.shape)
+ window_size = 16
+ padding = window_size // 2
+ feats = feats.view(-1, self.audio_dim).permute(1, 0).contiguous() # [C, M]
+ feats = feats.view(1, self.audio_dim, -1, 1) # [1, C, M, 1]
+ unfold_feats = F.unfold(feats, kernel_size=(window_size, 1), padding=(padding, 0), stride=(2, 1)) # [1, C * window_size, M / 2 + 1]
+ unfold_feats = unfold_feats.view(self.audio_dim, window_size, -1).permute(2, 1, 0).contiguous() # [C, window_size, M / 2 + 1] --> [M / 2 + 1, window_size, C]
+ # print('[INFO] after unfold', unfold_feats.shape)
+ # save to a npy file
+ if 'esperanto' in self.opt.asr_model:
+ output_path = self.opt.asr_wav.replace('.wav', '_eo.npy')
+ else:
+ output_path = self.opt.asr_wav.replace('.wav', '.npy')
+ np.save(output_path, unfold_feats.cpu().numpy())
+ print(f"[INFO] saved logits to {output_path}")
+
+ def create_file_stream(self):
+
+ stream, sample_rate = sf.read(self.opt.asr_wav) # [T*sample_rate,] float64
+ stream = stream.astype(np.float32)
+
+ if stream.ndim > 1:
+ print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
+ stream = stream[:, 0]
+
+ if sample_rate != self.sample_rate:
+ print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
+ stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
+
+ print(f'[INFO] loaded audio stream {self.opt.asr_wav}: {stream.shape}')
+
+ return stream
+
+
+ def create_pyaudio_stream(self):
+
+ import pyaudio
+
+ print(f'[INFO] creating live audio stream ...')
+
+ audio = pyaudio.PyAudio()
+
+ # get devices
+ info = audio.get_host_api_info_by_index(0)
+ n_devices = info.get('deviceCount')
+
+ for i in range(0, n_devices):
+ if (audio.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0:
+ name = audio.get_device_info_by_host_api_device_index(0, i).get('name')
+ print(f'[INFO] choose audio device {name}, id {i}')
+ break
+
+ # get stream
+ stream = audio.open(input_device_index=i,
+ format=pyaudio.paInt16,
+ channels=1,
+ rate=self.sample_rate,
+ input=True,
+ frames_per_buffer=self.chunk)
+
+ return audio, stream
+
+
+ def get_audio_frame(self):
+
+ if self.mode == 'file':
+
+ if self.idx < self.file_stream.shape[0]:
+ frame = self.file_stream[self.idx: self.idx + self.chunk]
+ self.idx = self.idx + self.chunk
+ return frame
+ else:
+ return None
+
+ else:
+
+ frame = self.queue.get()
+ # print(f'[INFO] get frame {frame.shape}')
+
+ self.idx = self.idx + self.chunk
+
+ return frame
+
+
+ def frame_to_text(self, frame):
+ # frame: [N * 320], N = (context_size + 2 * stride_size)
+
+ inputs = self.processor(frame, sampling_rate=self.sample_rate, return_tensors="pt", padding=True)
+
+ with torch.no_grad():
+ result = self.model(inputs.input_values.to(self.device))
+ logits = result.logits # [1, N - 1, 32]
+
+ # cut off stride
+ left = max(0, self.stride_left_size)
+ right = min(logits.shape[1], logits.shape[1] - self.stride_right_size + 1) # +1 to make sure output is the same length as input.
+
+ # do not cut right if terminated.
+ if self.terminated:
+ right = logits.shape[1]
+
+ logits = logits[:, left:right]
+
+ # print(frame.shape, inputs.input_values.shape, logits.shape)
+
+ predicted_ids = torch.argmax(logits, dim=-1)
+ transcription = self.processor.batch_decode(predicted_ids)[0].lower()
+
+
+ # for esperanto
+ # labels = np.array(['ŭ', '»', 'c', 'ĵ', 'ñ', '”', '„', '“', 'ǔ', 'o', 'ĝ', 'm', 'k', 'd', 'a', 'ŝ', 'z', 'i', '«', '—', '‘', 'ĥ', 'f', 'y', 'h', 'j', '|', 'r', 'u', 'ĉ', 's', '–', 'fi', 'l', 'p', '’', 'g', 'v', 't', 'b', 'n', 'e', '[UNK]', '[PAD]'])
+
+ # labels = np.array([' ', ' ', ' ', '-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z'])
+ # print(''.join(labels[predicted_ids[0].detach().cpu().long().numpy()]))
+ # print(predicted_ids[0])
+ # print(transcription)
+
+ return logits[0], predicted_ids[0], transcription # [N,]
+
+
+ def run(self):
+
+ self.listen()
+
+ while not self.terminated:
+ self.run_step()
+
+ def clear_queue(self):
+ # clear the queue, to reduce potential latency...
+ print(f'[INFO] clear queue')
+ if self.mode == 'live':
+ self.queue.queue.clear()
+ if self.play:
+ self.output_queue.queue.clear()
+
+ def warm_up(self):
+
+ self.listen()
+
+ print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s')
+ t = time.time()
+ for _ in range(self.warm_up_steps):
+ self.run_step()
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+ t = time.time() - t
+ print(f'[INFO] warm-up done, actual latency = {t:.6f}s')
+
+ self.clear_queue()
+
+
+
+
+if __name__ == '__main__':
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--wav', type=str, default='')
+ parser.add_argument('--play', action='store_true', help="play out the audio")
+
+ parser.add_argument('--model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto')
+ # parser.add_argument('--model', type=str, default='facebook/wav2vec2-large-960h-lv60-self')
+
+ parser.add_argument('--save_feats', action='store_true')
+ # audio FPS
+ parser.add_argument('--fps', type=int, default=50)
+ # sliding window left-middle-right length.
+ parser.add_argument('-l', type=int, default=10)
+ parser.add_argument('-m', type=int, default=50)
+ parser.add_argument('-r', type=int, default=10)
+
+ opt = parser.parse_args()
+
+ # fix
+ opt.asr_wav = opt.wav
+ opt.asr_play = opt.play
+ opt.asr_model = opt.model
+ opt.asr_save_feats = opt.save_feats
+
+ if 'deepspeech' in opt.asr_model:
+ raise ValueError("DeepSpeech features should not use this code to extract...")
+
+ with ASR(opt) as asr:
+ asr.run()
\ No newline at end of file
diff --git a/sync/SyncTalk/nerf_triplane/checkpoints/audio_visual_encoder.pth b/sync/SyncTalk/nerf_triplane/checkpoints/audio_visual_encoder.pth
new file mode 100644
index 00000000..9fe9b655
Binary files /dev/null and b/sync/SyncTalk/nerf_triplane/checkpoints/audio_visual_encoder.pth differ
diff --git a/sync/SyncTalk/nerf_triplane/encoding.py b/sync/SyncTalk/nerf_triplane/encoding.py
new file mode 100644
index 00000000..a4c69d3d
--- /dev/null
+++ b/sync/SyncTalk/nerf_triplane/encoding.py
@@ -0,0 +1,33 @@
+def get_encoder(encoding, input_dim=3,
+ multires=6,
+ degree=4,
+ num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=2048, align_corners=False,
+ **kwargs):
+
+ if encoding == 'None':
+ return lambda x, **kwargs: x, input_dim
+
+ elif encoding == 'frequency':
+ from freqencoder import FreqEncoder
+ encoder = FreqEncoder(input_dim=input_dim, degree=multires)
+
+ elif encoding == 'spherical_harmonics':
+ from shencoder import SHEncoder
+ encoder = SHEncoder(input_dim=input_dim, degree=degree)
+
+ elif encoding == 'hashgrid':
+ from gridencoder import GridEncoder
+ encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='hash', align_corners=align_corners)
+
+ elif encoding == 'tiledgrid':
+ from gridencoder import GridEncoder
+ encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners)
+
+ elif encoding == 'ash':
+ from ashencoder import AshEncoder
+ encoder = AshEncoder(input_dim=input_dim, output_dim=16, log2_hashmap_size=log2_hashmap_size, resolution=desired_resolution)
+
+ else:
+ raise NotImplementedError('Unknown encoding mode, choose from [None, frequency, spherical_harmonics, hashgrid, tiledgrid]')
+
+ return encoder, encoder.output_dim
\ No newline at end of file
diff --git a/sync/SyncTalk/nerf_triplane/gui.py b/sync/SyncTalk/nerf_triplane/gui.py
new file mode 100644
index 00000000..b59fd004
--- /dev/null
+++ b/sync/SyncTalk/nerf_triplane/gui.py
@@ -0,0 +1,562 @@
+import dearpygui.dearpygui as dpg
+from scipy.spatial.transform import Rotation as R
+
+from .utils import *
+
+from .asr import ASR
+
+
+class OrbitCamera:
+ def __init__(self, W, H, r=2, fovy=60):
+ self.W = W
+ self.H = H
+ self.radius = r # camera distance from center
+ self.fovy = fovy # in degree
+ self.center = np.array([0, 0, 0], dtype=np.float32) # look at this point
+ self.rot = R.from_matrix([[0, -1, 0], [0, 0, -1], [1, 0, 0]]) # init camera matrix: [[1, 0, 0], [0, -1, 0], [0, 0, 1]] (to suit ngp convention)
+ self.up = np.array([1, 0, 0], dtype=np.float32) # need to be normalized!
+
+ # pose
+ @property
+ def pose(self):
+ # first move camera to radius
+ res = np.eye(4, dtype=np.float32)
+ res[2, 3] -= self.radius
+ # rotate
+ rot = np.eye(4, dtype=np.float32)
+ rot[:3, :3] = self.rot.as_matrix()
+ res = rot @ res
+ # translate
+ res[:3, 3] -= self.center
+ return res
+
+ def update_pose(self, pose):
+ # pose: [4, 4] numpy array
+ # assert self.center is 0
+ self.radius = np.linalg.norm(pose[:3, 3])
+ T = np.eye(4)
+ T[2, 3] = -self.radius
+ rot = pose @ np.linalg.inv(T)
+ self.rot = R.from_matrix(rot[:3, :3])
+
+ def update_intrinsics(self, intrinsics):
+ fl_x, fl_y, cx, cy = intrinsics
+ self.W = int(cx * 2)
+ self.H = int(cy * 2)
+ self.fovy = np.rad2deg(2 * np.arctan2(self.H, 2 * fl_y))
+
+ # intrinsics
+ @property
+ def intrinsics(self):
+ focal = self.H / (2 * np.tan(np.deg2rad(self.fovy) / 2))
+ return np.array([focal, focal, self.W // 2, self.H // 2])
+
+ def orbit(self, dx, dy):
+ # rotate along camera up/side axis!
+ side = self.rot.as_matrix()[:3, 0] # why this is side --> ? # already normalized.
+ rotvec_x = self.up * np.radians(-0.01 * dx)
+ rotvec_y = side * np.radians(-0.01 * dy)
+ self.rot = R.from_rotvec(rotvec_x) * R.from_rotvec(rotvec_y) * self.rot
+
+ def scale(self, delta):
+ self.radius *= 1.1 ** (-delta)
+
+ def pan(self, dx, dy, dz=0):
+ # pan in camera coordinate system (careful on the sensitivity!)
+ self.center += 0.0001 * self.rot.as_matrix()[:3, :3] @ np.array([dx, dy, dz])
+
+
+class NeRFGUI:
+ def __init__(self, opt, trainer, data_loader, debug=True):
+ self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
+ self.W = opt.W
+ self.H = opt.H
+ self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy)
+ self.debug = debug
+ self.training = False
+ self.step = 0 # training step
+
+ self.trainer = trainer
+ self.data_loader = data_loader
+
+ # override with dataloader's intrinsics
+ self.W = data_loader._data.W
+ self.H = data_loader._data.H
+ self.cam.update_intrinsics(data_loader._data.intrinsics)
+
+ # use dataloader's pose
+ pose_init = data_loader._data.poses[0]
+ self.cam.update_pose(pose_init.detach().cpu().numpy())
+
+ # use dataloader's bg
+ bg_img = data_loader._data.bg_img #.view(1, -1, 3)
+ if self.H != bg_img.shape[0] or self.W != bg_img.shape[1]:
+ bg_img = F.interpolate(bg_img.permute(2, 0, 1).unsqueeze(0).contiguous(), (self.H, self.W), mode='bilinear').squeeze(0).permute(1, 2, 0).contiguous()
+ self.bg_color = bg_img.view(1, -1, 3)
+
+ # audio features (from dataloader, only used in non-playing mode)
+ self.audio_features = data_loader._data.auds # [N, 29, 16]
+ self.audio_idx = 0
+
+ # control eye
+ self.eye_area = None if not self.opt.exp_eye else data_loader._data.eye_area.mean().item()
+
+ # playing seq from dataloader, or pause.
+ self.playing = False
+ self.loader = iter(data_loader)
+
+ self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32)
+ self.need_update = True # camera moved, should reset accumulation
+ self.spp = 1 # sample per pixel
+ self.mode = 'image' # choose from ['image', 'depth']
+
+ self.dynamic_resolution = False # assert False!
+ self.downscale = 1
+ self.train_steps = 16
+
+ self.ind_index = 0
+ self.ind_num = trainer.model.individual_codes.shape[0]
+
+ # build asr
+ if self.opt.asr:
+ self.asr = ASR(opt)
+
+ dpg.create_context()
+ self.register_dpg()
+ self.test_step()
+
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ if self.opt.asr:
+ self.asr.stop()
+ dpg.destroy_context()
+
+ def train_step(self):
+
+ starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
+ starter.record()
+
+ outputs = self.trainer.train_gui(self.data_loader, step=self.train_steps)
+
+ ender.record()
+ torch.cuda.synchronize()
+ t = starter.elapsed_time(ender)
+
+ self.step += self.train_steps
+ self.need_update = True
+
+ dpg.set_value("_log_train_time", f'{t:.4f}ms ({int(1000/t)} FPS)')
+ dpg.set_value("_log_train_log", f'step = {self.step: 5d} (+{self.train_steps: 2d}), loss = {outputs["loss"]:.4f}, lr = {outputs["lr"]:.5f}')
+
+ # dynamic train steps
+ # max allowed train time per-frame is 500 ms
+ full_t = t / self.train_steps * 16
+ train_steps = min(16, max(4, int(16 * 500 / full_t)))
+ if train_steps > self.train_steps * 1.2 or train_steps < self.train_steps * 0.8:
+ self.train_steps = train_steps
+
+ def prepare_buffer(self, outputs):
+ if self.mode == 'image':
+ return outputs['image']
+ else:
+ return np.expand_dims(outputs['depth'], -1).repeat(3, -1)
+
+ def test_step(self):
+
+ if self.need_update or self.spp < self.opt.max_spp:
+
+ starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
+ starter.record()
+
+ if self.playing:
+ try:
+ data = next(self.loader)
+ except StopIteration:
+ self.loader = iter(self.data_loader)
+ data = next(self.loader)
+
+ if self.opt.asr:
+ # use the live audio stream
+ data['auds'] = self.asr.get_next_feat()
+
+ outputs = self.trainer.test_gui_with_data(data, self.W, self.H)
+
+ # sync local camera pose
+ self.cam.update_pose(data['poses_matrix'][0].detach().cpu().numpy())
+
+ else:
+ if self.audio_features is not None:
+ auds = get_audio_features(self.audio_features, self.opt.att, self.audio_idx)
+ else:
+ auds = None
+ outputs = self.trainer.test_gui(self.cam.pose, self.cam.intrinsics, self.W, self.H, auds, self.eye_area, self.ind_index, self.bg_color, self.spp, self.downscale)
+
+ ender.record()
+ torch.cuda.synchronize()
+ t = starter.elapsed_time(ender)
+
+ # update dynamic resolution
+ if self.dynamic_resolution:
+ # max allowed infer time per-frame is 200 ms
+ full_t = t / (self.downscale ** 2)
+ downscale = min(1, max(1/4, math.sqrt(200 / full_t)))
+ if downscale > self.downscale * 1.2 or downscale < self.downscale * 0.8:
+ self.downscale = downscale
+
+ if self.need_update:
+ self.render_buffer = self.prepare_buffer(outputs)
+ self.spp = 1
+ self.need_update = False
+ else:
+ self.render_buffer = (self.render_buffer * self.spp + self.prepare_buffer(outputs)) / (self.spp + 1)
+ self.spp += 1
+
+ if self.playing:
+ self.need_update = True
+
+ dpg.set_value("_log_infer_time", f'{t:.4f}ms ({int(1000/t)} FPS)')
+ dpg.set_value("_log_resolution", f'{int(self.downscale * self.W)}x{int(self.downscale * self.H)}')
+ dpg.set_value("_log_spp", self.spp)
+ dpg.set_value("_texture", self.render_buffer)
+
+
+ def register_dpg(self):
+
+ ### register texture
+
+ with dpg.texture_registry(show=False):
+ dpg.add_raw_texture(self.W, self.H, self.render_buffer, format=dpg.mvFormat_Float_rgb, tag="_texture")
+
+ ### register window
+
+ # the rendered image, as the primary window
+ with dpg.window(tag="_primary_window", width=self.W, height=self.H):
+
+ # add the texture
+ dpg.add_image("_texture")
+
+ # dpg.set_primary_window("_primary_window", True)
+
+ dpg.show_tool(dpg.mvTool_Metrics)
+
+ # control window
+ with dpg.window(label="Control", tag="_control_window", width=400, height=300):
+
+ # button theme
+ with dpg.theme() as theme_button:
+ with dpg.theme_component(dpg.mvButton):
+ dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18))
+ dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47))
+ dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83))
+ dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5)
+ dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3)
+
+ # time
+ if not self.opt.test:
+ with dpg.group(horizontal=True):
+ dpg.add_text("Train time: ")
+ dpg.add_text("no data", tag="_log_train_time")
+
+ with dpg.group(horizontal=True):
+ dpg.add_text("Infer time: ")
+ dpg.add_text("no data", tag="_log_infer_time")
+
+ with dpg.group(horizontal=True):
+ dpg.add_text("SPP: ")
+ dpg.add_text("1", tag="_log_spp")
+
+ # train button
+ if not self.opt.test:
+ with dpg.collapsing_header(label="Train", default_open=True):
+
+ # train / stop
+ with dpg.group(horizontal=True):
+ dpg.add_text("Train: ")
+
+ def callback_train(sender, app_data):
+ if self.training:
+ self.training = False
+ dpg.configure_item("_button_train", label="start")
+ else:
+ self.training = True
+ dpg.configure_item("_button_train", label="stop")
+
+ dpg.add_button(label="start", tag="_button_train", callback=callback_train)
+ dpg.bind_item_theme("_button_train", theme_button)
+
+ def callback_reset(sender, app_data):
+ @torch.no_grad()
+ def weight_reset(m: nn.Module):
+ reset_parameters = getattr(m, "reset_parameters", None)
+ if callable(reset_parameters):
+ m.reset_parameters()
+ self.trainer.model.apply(fn=weight_reset)
+ self.trainer.model.reset_extra_state() # for cuda_ray density_grid and step_counter
+ self.need_update = True
+
+ dpg.add_button(label="reset", tag="_button_reset", callback=callback_reset)
+ dpg.bind_item_theme("_button_reset", theme_button)
+
+ # save ckpt
+ with dpg.group(horizontal=True):
+ dpg.add_text("Checkpoint: ")
+
+ def callback_save(sender, app_data):
+ self.trainer.save_checkpoint(full=True, best=False)
+ dpg.set_value("_log_ckpt", "saved " + os.path.basename(self.trainer.stats["checkpoints"][-1]))
+ self.trainer.epoch += 1 # use epoch to indicate different calls.
+
+ dpg.add_button(label="save", tag="_button_save", callback=callback_save)
+ dpg.bind_item_theme("_button_save", theme_button)
+
+ dpg.add_text("", tag="_log_ckpt")
+
+ # save mesh
+ with dpg.group(horizontal=True):
+ dpg.add_text("Marching Cubes: ")
+
+ def callback_mesh(sender, app_data):
+ self.trainer.save_mesh(resolution=256, threshold=10)
+ dpg.set_value("_log_mesh", "saved " + f'{self.trainer.name}_{self.trainer.epoch}.ply')
+ self.trainer.epoch += 1 # use epoch to indicate different calls.
+
+ dpg.add_button(label="mesh", tag="_button_mesh", callback=callback_mesh)
+ dpg.bind_item_theme("_button_mesh", theme_button)
+
+ dpg.add_text("", tag="_log_mesh")
+
+ with dpg.group(horizontal=True):
+ dpg.add_text("", tag="_log_train_log")
+
+
+ # rendering options
+ with dpg.collapsing_header(label="Options", default_open=True):
+
+ # playing
+ with dpg.group(horizontal=True):
+ dpg.add_text("Play: ")
+
+ def callback_play(sender, app_data):
+
+ if self.playing:
+ self.playing = False
+ dpg.configure_item("_button_play", label="start")
+ else:
+ self.playing = True
+ dpg.configure_item("_button_play", label="stop")
+ if self.opt.asr:
+ self.asr.warm_up()
+ self.need_update = True
+
+ dpg.add_button(label="start", tag="_button_play", callback=callback_play)
+ dpg.bind_item_theme("_button_play", theme_button)
+
+ # set asr
+ if self.opt.asr:
+
+ # clear queue button
+ def callback_clear_queue(sender, app_data):
+
+ self.asr.clear_queue()
+ self.need_update = True
+
+ dpg.add_button(label="clear", tag="_button_clear_queue", callback=callback_clear_queue)
+ dpg.bind_item_theme("_button_clear_queue", theme_button)
+
+ # dynamic rendering resolution
+ with dpg.group(horizontal=True):
+
+ def callback_set_dynamic_resolution(sender, app_data):
+ if self.dynamic_resolution:
+ self.dynamic_resolution = False
+ self.downscale = 1
+ else:
+ self.dynamic_resolution = True
+ self.need_update = True
+
+ # Disable dynamic resolution for face.
+ # dpg.add_checkbox(label="dynamic resolution", default_value=self.dynamic_resolution, callback=callback_set_dynamic_resolution)
+ dpg.add_text(f"{self.W}x{self.H}", tag="_log_resolution")
+
+ # mode combo
+ def callback_change_mode(sender, app_data):
+ self.mode = app_data
+ self.need_update = True
+
+ dpg.add_combo(('image', 'depth'), label='mode', default_value=self.mode, callback=callback_change_mode)
+
+
+ # bg_color picker
+ def callback_change_bg(sender, app_data):
+ self.bg_color = torch.tensor(app_data[:3], dtype=torch.float32) # only need RGB in [0, 1]
+ self.need_update = True
+
+ dpg.add_color_edit((255, 255, 255), label="Background Color", width=200, tag="_color_editor", no_alpha=True, callback=callback_change_bg)
+
+ # audio index slider
+ if not self.opt.asr:
+ def callback_set_audio_index(sender, app_data):
+ self.audio_idx = app_data
+ self.need_update = True
+
+ dpg.add_slider_int(label="Audio", min_value=0, max_value=self.audio_features.shape[0] - 1, format="%d", default_value=self.audio_idx, callback=callback_set_audio_index)
+
+ # ind code index slider
+ if self.opt.ind_dim > 0:
+ def callback_set_individual_code(sender, app_data):
+ self.ind_index = app_data
+ self.need_update = True
+
+ dpg.add_slider_int(label="Individual", min_value=0, max_value=self.ind_num - 1, format="%d", default_value=self.ind_index, callback=callback_set_individual_code)
+
+ # eye area slider
+ if self.opt.exp_eye:
+ def callback_set_eye(sender, app_data):
+ self.eye_area = app_data
+ self.need_update = True
+
+ dpg.add_slider_float(label="eye area", min_value=0, max_value=0.5, format="%.2f percent", default_value=self.eye_area, callback=callback_set_eye)
+
+ # fov slider
+ def callback_set_fovy(sender, app_data):
+ self.cam.fovy = app_data
+ self.need_update = True
+
+ dpg.add_slider_int(label="FoV (vertical)", min_value=1, max_value=120, format="%d deg", default_value=self.cam.fovy, callback=callback_set_fovy)
+
+ # dt_gamma slider
+ def callback_set_dt_gamma(sender, app_data):
+ self.opt.dt_gamma = app_data
+ self.need_update = True
+
+ dpg.add_slider_float(label="dt_gamma", min_value=0, max_value=0.1, format="%.5f", default_value=self.opt.dt_gamma, callback=callback_set_dt_gamma)
+
+ # max_steps slider
+ def callback_set_max_steps(sender, app_data):
+ self.opt.max_steps = app_data
+ self.need_update = True
+
+ dpg.add_slider_int(label="max steps", min_value=1, max_value=1024, format="%d", default_value=self.opt.max_steps, callback=callback_set_max_steps)
+
+ # aabb slider
+ def callback_set_aabb(sender, app_data, user_data):
+ # user_data is the dimension for aabb (xmin, ymin, zmin, xmax, ymax, zmax)
+ self.trainer.model.aabb_infer[user_data] = app_data
+
+ # also change train aabb ? [better not...]
+ #self.trainer.model.aabb_train[user_data] = app_data
+
+ self.need_update = True
+
+ dpg.add_separator()
+ dpg.add_text("Axis-aligned bounding box:")
+
+ with dpg.group(horizontal=True):
+ dpg.add_slider_float(label="x", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=0)
+ dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=3)
+
+ with dpg.group(horizontal=True):
+ dpg.add_slider_float(label="y", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=1)
+ dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=4)
+
+ with dpg.group(horizontal=True):
+ dpg.add_slider_float(label="z", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=2)
+ dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=5)
+
+
+ # debug info
+ if self.debug:
+ with dpg.collapsing_header(label="Debug"):
+ # pose
+ dpg.add_separator()
+ dpg.add_text("Camera Pose:")
+ dpg.add_text(str(self.cam.pose), tag="_log_pose")
+
+
+ ### register camera handler
+
+ def callback_camera_drag_rotate(sender, app_data):
+
+ if not dpg.is_item_focused("_primary_window"):
+ return
+
+ dx = app_data[1]
+ dy = app_data[2]
+
+ self.cam.orbit(dx, dy)
+ self.need_update = True
+
+ if self.debug:
+ dpg.set_value("_log_pose", str(self.cam.pose))
+
+
+ def callback_camera_wheel_scale(sender, app_data):
+
+ if not dpg.is_item_focused("_primary_window"):
+ return
+
+ delta = app_data
+
+ self.cam.scale(delta)
+ self.need_update = True
+
+ if self.debug:
+ dpg.set_value("_log_pose", str(self.cam.pose))
+
+
+ def callback_camera_drag_pan(sender, app_data):
+
+ if not dpg.is_item_focused("_primary_window"):
+ return
+
+ dx = app_data[1]
+ dy = app_data[2]
+
+ self.cam.pan(dx, dy)
+ self.need_update = True
+
+ if self.debug:
+ dpg.set_value("_log_pose", str(self.cam.pose))
+
+
+ with dpg.handler_registry():
+ dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Left, callback=callback_camera_drag_rotate)
+ dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale)
+ dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Middle, callback=callback_camera_drag_pan)
+
+
+ dpg.create_viewport(title='SyncTalk', width=1080, height=720, resizable=True)
+
+ ### global theme
+ with dpg.theme() as theme_no_padding:
+ with dpg.theme_component(dpg.mvAll):
+ # set all padding to 0 to avoid scroll bar
+ dpg.add_theme_style(dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core)
+ dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core)
+ dpg.add_theme_style(dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core)
+
+ dpg.bind_item_theme("_primary_window", theme_no_padding)
+
+ dpg.setup_dearpygui()
+
+ #dpg.show_metrics()
+
+ dpg.show_viewport()
+
+
+ def render(self):
+
+ while dpg.is_dearpygui_running():
+ # update texture every frame
+ if self.training:
+ self.train_step()
+ # audio stream thread...
+ if self.opt.asr and self.playing:
+ # run 2 ASR steps (audio is at 50FPS, video is at 25FPS)
+ for _ in range(2):
+ self.asr.run_step()
+ self.test_step()
+ dpg.render_dearpygui_frame()
\ No newline at end of file
diff --git a/sync/SyncTalk/nerf_triplane/network.py b/sync/SyncTalk/nerf_triplane/network.py
new file mode 100644
index 00000000..4fe7fe57
--- /dev/null
+++ b/sync/SyncTalk/nerf_triplane/network.py
@@ -0,0 +1,436 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .encoding import get_encoder
+from .renderer import NeRFRenderer
+
+
+class Conv2d(nn.Module):
+ def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, leakyReLU=False, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.conv_block = nn.Sequential(
+ nn.Conv2d(cin, cout, kernel_size, stride, padding),
+ nn.BatchNorm2d(cout)
+ )
+ if leakyReLU:
+ self.act = nn.LeakyReLU(0.02)
+ else:
+ self.act = nn.ReLU()
+ self.residual = residual
+
+ def forward(self, x):
+ out = self.conv_block(x)
+ if self.residual:
+ out += x
+ return self.act(out)
+
+
+# Audio feature extractor
+class AudioAttNet(nn.Module):
+ def __init__(self, dim_aud=64, seq_len=8):
+ super(AudioAttNet, self).__init__()
+ self.seq_len = seq_len
+ self.dim_aud = dim_aud
+ self.attentionConvNet = nn.Sequential( # b x subspace_dim x seq_len
+ nn.Conv1d(self.dim_aud, 16, kernel_size=3, stride=1, padding=1, bias=True),
+ nn.LeakyReLU(0.02, True),
+ nn.Conv1d(16, 8, kernel_size=3, stride=1, padding=1, bias=True),
+ nn.LeakyReLU(0.02, True),
+ nn.Conv1d(8, 4, kernel_size=3, stride=1, padding=1, bias=True),
+ nn.LeakyReLU(0.02, True),
+ nn.Conv1d(4, 2, kernel_size=3, stride=1, padding=1, bias=True),
+ nn.LeakyReLU(0.02, True),
+ nn.Conv1d(2, 1, kernel_size=3, stride=1, padding=1, bias=True),
+ nn.LeakyReLU(0.02, True)
+ )
+ self.attentionNet = nn.Sequential(
+ nn.Linear(in_features=self.seq_len, out_features=self.seq_len, bias=True),
+ nn.Softmax(dim=1)
+ )
+
+ def forward(self, x):
+ # x: [1, seq_len, dim_aud]
+ y = x.permute(0, 2, 1) # [1, dim_aud, seq_len]
+ y = self.attentionConvNet(y)
+ y = self.attentionNet(y.view(1, self.seq_len)).view(1, self.seq_len, 1)
+ return torch.sum(y * x, dim=1) # [1, dim_aud]
+
+
+class AudioEncoder(nn.Module):
+ def __init__(self):
+ super(AudioEncoder, self).__init__()
+
+ self.audio_encoder = nn.Sequential(
+ Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0), )
+
+ def forward(self, x):
+ out = self.audio_encoder(x)
+ out = out.squeeze(2).squeeze(2)
+
+ return out
+
+# Audio feature extractor
+class AudioNet(nn.Module):
+ def __init__(self, dim_in=29, dim_aud=64, win_size=16):
+ super(AudioNet, self).__init__()
+ self.win_size = win_size
+ self.dim_aud = dim_aud
+ self.encoder_conv = nn.Sequential( # n x 29 x 16
+ nn.Conv1d(dim_in, 32, kernel_size=3, stride=2, padding=1, bias=True), # n x 32 x 8
+ nn.LeakyReLU(0.02, True),
+ nn.Conv1d(32, 32, kernel_size=3, stride=2, padding=1, bias=True), # n x 32 x 4
+ nn.LeakyReLU(0.02, True),
+ nn.Conv1d(32, 64, kernel_size=3, stride=2, padding=1, bias=True), # n x 64 x 2
+ nn.LeakyReLU(0.02, True),
+ nn.Conv1d(64, 64, kernel_size=3, stride=2, padding=1, bias=True), # n x 64 x 1
+ nn.LeakyReLU(0.02, True),
+ )
+ self.encoder_fc1 = nn.Sequential(
+ nn.Linear(64, 64),
+ nn.LeakyReLU(0.02, True),
+ nn.Linear(64, dim_aud),
+ )
+
+ def forward(self, x):
+ half_w = int(self.win_size/2)
+ x = x[:, :, 8-half_w:8+half_w]
+ x = self.encoder_conv(x).squeeze(-1)
+ x = self.encoder_fc1(x)
+ return x
+
+
+# Audio feature extractor
+class AudioNet_ave(nn.Module):
+ def __init__(self, dim_in=29, dim_aud=64, win_size=16):
+ super(AudioNet_ave, self).__init__()
+ self.win_size = win_size
+ self.dim_aud = dim_aud
+ self.encoder_fc1 = nn.Sequential(
+ nn.Linear(512, 256),
+ nn.LeakyReLU(0.02, True),
+ nn.Linear(256, 128),
+ nn.LeakyReLU(0.02, True),
+ nn.Linear(128, dim_aud),
+ )
+ def forward(self, x):
+ # half_w = int(self.win_size/2)
+ # x = x[:, :, 8-half_w:8+half_w]
+ # x = self.encoder_conv(x).squeeze(-1)
+ x = self.encoder_fc1(x).permute(1,0,2).squeeze(0)
+ return x
+
+class MLP(nn.Module):
+ def __init__(self, dim_in, dim_out, dim_hidden, num_layers):
+ super().__init__()
+ self.dim_in = dim_in
+ self.dim_out = dim_out
+ self.dim_hidden = dim_hidden
+ self.num_layers = num_layers
+
+ net = []
+ for l in range(num_layers):
+ net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=False))
+
+ self.net = nn.ModuleList(net)
+
+ def forward(self, x):
+ for l in range(self.num_layers):
+ x = self.net[l](x)
+ if l != self.num_layers - 1:
+ x = F.relu(x, inplace=True)
+ # x = F.dropout(x, p=0.1, training=self.training)
+
+ return x
+
+
+class NeRFNetwork(NeRFRenderer):
+ def __init__(self,
+ opt,
+ audio_dim = 32,
+ # torso net (hard coded for now)
+ ):
+ super().__init__(opt)
+
+ # audio embedding
+ self.emb = self.opt.emb
+
+ if 'esperanto' in self.opt.asr_model:
+ self.audio_in_dim = 44
+ elif 'deepspeech' in self.opt.asr_model:
+ self.audio_in_dim = 29
+ elif 'hubert' in self.opt.asr_model:
+ self.audio_in_dim = 1024
+ else:
+ self.audio_in_dim = 32
+
+ if self.emb:
+ self.embedding = nn.Embedding(self.audio_in_dim, self.audio_in_dim)
+
+ # audio network
+ self.audio_dim = audio_dim
+ if self.opt.asr_model == 'ave':
+ self.audio_net = AudioNet_ave(self.audio_in_dim, self.audio_dim)
+ else:
+ self.audio_net = AudioNet(self.audio_in_dim, self.audio_dim)
+
+ self.att = self.opt.att
+ if self.att > 0:
+ self.audio_att_net = AudioAttNet(self.audio_dim)
+
+ # DYNAMIC PART
+ self.num_levels = 12
+ self.level_dim = 1
+ self.encoder_xy, self.in_dim_xy = get_encoder('hashgrid', input_dim=2, num_levels=self.num_levels, level_dim=self.level_dim, base_resolution=64, log2_hashmap_size=14, desired_resolution=512 * self.bound)
+ self.encoder_yz, self.in_dim_yz = get_encoder('hashgrid', input_dim=2, num_levels=self.num_levels, level_dim=self.level_dim, base_resolution=64, log2_hashmap_size=14, desired_resolution=512 * self.bound)
+ self.encoder_xz, self.in_dim_xz = get_encoder('hashgrid', input_dim=2, num_levels=self.num_levels, level_dim=self.level_dim, base_resolution=64, log2_hashmap_size=14, desired_resolution=512 * self.bound)
+
+ self.in_dim = self.in_dim_xy + self.in_dim_yz + self.in_dim_xz
+
+ ## sigma network
+ self.num_layers = 3
+ self.hidden_dim = 64
+ self.geo_feat_dim = 64
+ if self.opt.bs_area == "upper":
+ self.eye_att_net = MLP(self.in_dim, 7, 64, 2)
+ self.eye_dim = 7 if self.exp_eye else 0
+ elif self.opt.bs_area == "single":
+ self.eye_att_net = MLP(self.in_dim, 4, 64, 2)
+ self.eye_dim = 4 if self.exp_eye else 0
+ elif self.opt.bs_area == "eye":
+ self.eye_att_net = MLP(self.in_dim, 2, 64, 2)
+ self.eye_dim = 2 if self.exp_eye else 0
+ self.sigma_net = MLP(self.in_dim + self.audio_dim + self.eye_dim, 1 + self.geo_feat_dim, self.hidden_dim, self.num_layers)
+ ## color network
+ self.num_layers_color = 2
+ self.hidden_dim_color = 64
+ self.encoder_dir, self.in_dim_dir = get_encoder('spherical_harmonics')
+ self.color_net = MLP(self.in_dim_dir + self.geo_feat_dim + self.individual_dim, 3, self.hidden_dim_color, self.num_layers_color)
+
+ self.unc_net = MLP(self.in_dim, 1, 32, 2)
+
+ self.aud_ch_att_net = MLP(self.in_dim, self.audio_dim, 64, 2)
+
+ self.testing = False
+
+ if self.torso:
+ # torso deform network
+ self.register_parameter('anchor_points',
+ nn.Parameter(torch.tensor([[0.01, 0.01, 0.1, 1], [-0.1, -0.1, 0.1, 1], [0.1, -0.1, 0.1, 1]])))
+ self.torso_deform_encoder, self.torso_deform_in_dim = get_encoder('frequency', input_dim=2, multires=8)
+ # self.torso_deform_encoder, self.torso_deform_in_dim = get_encoder('tiledgrid', input_dim=2, num_levels=16, level_dim=1, base_resolution=16, log2_hashmap_size=16, desired_resolution=512)
+ self.anchor_encoder, self.anchor_in_dim = get_encoder('frequency', input_dim=6, multires=3)
+ self.torso_deform_net = MLP(self.torso_deform_in_dim + self.anchor_in_dim + self.individual_dim_torso, 2, 32, 3)
+
+ # torso color network
+ self.torso_encoder, self.torso_in_dim = get_encoder('tiledgrid', input_dim=2, num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=16, desired_resolution=2048)
+ self.torso_net = MLP(self.torso_in_dim + self.torso_deform_in_dim + self.anchor_in_dim + self.individual_dim_torso, 4, 32, 3)
+
+
+ def forward_torso(self, x, poses, c=None):
+ # x: [N, 2] in [-1, 1]
+ # head poses: [1, 4, 4]
+ # c: [1, ind_dim], individual code
+
+ # test: shrink x
+ x = x * self.opt.torso_shrink
+
+ # deformation-based
+ wrapped_anchor = self.anchor_points[None, ...] @ poses.permute(0, 2, 1).inverse()
+ wrapped_anchor = (wrapped_anchor[:, :, :2] / wrapped_anchor[:, :, 3, None] / wrapped_anchor[:, :, 2, None]).view(1, -1)
+ # print(wrapped_anchor)
+ # enc_pose = self.pose_encoder(poses)
+ enc_anchor = self.anchor_encoder(wrapped_anchor)
+ enc_x = self.torso_deform_encoder(x)
+
+ if c is not None:
+ h = torch.cat([enc_x, enc_anchor.repeat(x.shape[0], 1), c.repeat(x.shape[0], 1)], dim=-1)
+ else:
+ h = torch.cat([enc_x, enc_anchor.repeat(x.shape[0], 1)], dim=-1)
+
+ dx = self.torso_deform_net(h)
+
+ x = (x + dx).clamp(-1, 1)
+
+ x = self.torso_encoder(x, bound=1)
+
+ # h = torch.cat([x, h, enc_a.repeat(x.shape[0], 1)], dim=-1)
+ h = torch.cat([x, h], dim=-1)
+
+ h = self.torso_net(h)
+
+ alpha = torch.sigmoid(h[..., :1])*(1 + 2*0.001) - 0.001
+ color = torch.sigmoid(h[..., 1:])*(1 + 2*0.001) - 0.001
+
+ return alpha, color, dx
+
+
+ @staticmethod
+ @torch.jit.script
+ def split_xyz(x):
+ xy, yz, xz = x[:, :-1], x[:, 1:], torch.cat([x[:,:1], x[:,-1:]], dim=-1)
+ return xy, yz, xz
+
+
+ def encode_x(self, xyz, bound):
+ # x: [N, 3], in [-bound, bound]
+ N, M = xyz.shape
+ xy, yz, xz = self.split_xyz(xyz)
+ feat_xy = self.encoder_xy(xy, bound=bound)
+ feat_yz = self.encoder_yz(yz, bound=bound)
+ feat_xz = self.encoder_xz(xz, bound=bound)
+
+ return torch.cat([feat_xy, feat_yz, feat_xz], dim=-1)
+
+
+ def encode_audio(self, a):
+ # a: [1, 29, 16] or [8, 29, 16], audio features from deepspeech
+ # if emb, a should be: [1, 16] or [8, 16]
+
+ # fix audio traininig
+ if a is None: return None
+
+ if self.emb:
+ a = self.embedding(a).transpose(-1, -2).contiguous() # [1/8, 29, 16]
+
+ enc_a = self.audio_net(a) # [8,32]
+
+ if self.att > 0:
+ enc_a = self.audio_att_net(enc_a.unsqueeze(0)) # [1, 32]
+
+ return enc_a
+
+
+ def predict_uncertainty(self, unc_inp):
+ if self.testing or not self.opt.unc_loss:
+ unc = torch.zeros_like(unc_inp)
+ else:
+ unc = self.unc_net(unc_inp.detach())
+
+ return unc
+
+
+ def forward(self, x, d, enc_a, c, e=None):
+ # x: [N, 3], in [-bound, bound]
+ # d: [N, 3], nomalized in [-1, 1]
+ # enc_a: [1, aud_dim]
+ # c: [1, ind_dim], individual code
+ # e: [1, 1], eye feature
+ enc_x = self.encode_x(x, bound=self.bound)
+
+ sigma_result = self.density(x, enc_a, e, enc_x)
+ sigma = sigma_result['sigma']
+ geo_feat = sigma_result['geo_feat']
+ aud_ch_att = sigma_result['ambient_aud']
+ eye_att = sigma_result['ambient_eye']
+
+ # color
+ enc_d = self.encoder_dir(d)
+
+ if c is not None:
+ h = torch.cat([enc_d, geo_feat, c.repeat(x.shape[0], 1)], dim=-1)
+ else:
+ h = torch.cat([enc_d, geo_feat], dim=-1)
+
+ h_color = self.color_net(h)
+ color = torch.sigmoid(h_color)*(1 + 2*0.001) - 0.001
+
+ uncertainty = self.predict_uncertainty(enc_x)
+ uncertainty = torch.log(1 + torch.exp(uncertainty))
+
+ return sigma, color, aud_ch_att, eye_att, uncertainty[..., None]
+
+
+ def density(self, x, enc_a, e=None, enc_x=None):
+ # x: [N, 3], in [-bound, bound]
+ if enc_x is None:
+ enc_x = self.encode_x(x, bound=self.bound)
+
+ enc_a = enc_a.repeat(enc_x.shape[0], 1)
+ aud_ch_att = self.aud_ch_att_net(enc_x)
+ enc_w = enc_a * aud_ch_att
+
+ if e is not None:
+ # e = self.encoder_eye(e)
+ # eye_att = torch.sigmoid(self.eye_att_net(enc_x))
+ e = e.repeat(enc_x.shape[0], 1)
+ eye_att = self.eye_att_net(enc_x)
+ e = e * eye_att
+ # e = e.repeat(enc_x.shape[0], 1)
+ h = torch.cat([enc_x, enc_w, e], dim=-1)
+ else:
+ h = torch.cat([enc_x, enc_w], dim=-1)
+
+ h = self.sigma_net(h)
+
+ sigma = torch.exp(h[..., 0])
+ geo_feat = h[..., 1:]
+
+ return {
+ 'sigma': sigma,
+ 'geo_feat': geo_feat,
+ 'ambient_aud' : aud_ch_att.norm(dim=-1, keepdim=True),
+ 'ambient_eye' : eye_att.norm(dim=-1, keepdim=True),
+ }
+
+
+ # optimizer utils
+ def get_params(self, lr, lr_net, wd=0):
+
+ # ONLY train torso
+ if self.torso:
+ params = [
+ {'params': self.torso_encoder.parameters(), 'lr': lr},
+ {'params': self.torso_deform_encoder.parameters(), 'lr': lr, 'weight_decay': wd},
+ {'params': self.torso_net.parameters(), 'lr': lr_net, 'weight_decay': wd},
+ {'params': self.torso_deform_net.parameters(), 'lr': lr_net, 'weight_decay': wd},
+ {'params': self.anchor_points, 'lr': lr_net, 'weight_decay': wd}
+ ]
+
+ if self.individual_dim_torso > 0:
+ params.append({'params': self.individual_codes_torso, 'lr': lr_net, 'weight_decay': wd})
+
+ return params
+
+ params = [
+ {'params': self.audio_net.parameters(), 'lr': lr_net, 'weight_decay': wd},
+
+ {'params': self.encoder_xy.parameters(), 'lr': lr},
+ {'params': self.encoder_yz.parameters(), 'lr': lr},
+ {'params': self.encoder_xz.parameters(), 'lr': lr},
+ # {'params': self.encoder_xyz.parameters(), 'lr': lr},
+
+ {'params': self.sigma_net.parameters(), 'lr': lr_net, 'weight_decay': wd},
+ {'params': self.color_net.parameters(), 'lr': lr_net, 'weight_decay': wd},
+ ]
+ if self.att > 0:
+ params.append({'params': self.audio_att_net.parameters(), 'lr': lr_net * 5, 'weight_decay': 0.0001})
+ if self.emb:
+ params.append({'params': self.embedding.parameters(), 'lr': lr})
+ if self.individual_dim > 0:
+ params.append({'params': self.individual_codes, 'lr': lr_net, 'weight_decay': wd})
+ if self.train_camera:
+ params.append({'params': self.camera_dT, 'lr': 1e-5, 'weight_decay': 0})
+ params.append({'params': self.camera_dR, 'lr': 1e-5, 'weight_decay': 0})
+
+ params.append({'params': self.aud_ch_att_net.parameters(), 'lr': lr_net, 'weight_decay': wd})
+ params.append({'params': self.unc_net.parameters(), 'lr': lr_net, 'weight_decay': wd})
+ params.append({'params': self.eye_att_net.parameters(), 'lr': lr_net, 'weight_decay': wd})
+
+ return params
diff --git a/sync/SyncTalk/nerf_triplane/provider.py b/sync/SyncTalk/nerf_triplane/provider.py
new file mode 100644
index 00000000..6f4a67f9
--- /dev/null
+++ b/sync/SyncTalk/nerf_triplane/provider.py
@@ -0,0 +1,647 @@
+import os
+import cv2
+import glob
+import json
+import tqdm
+import numpy as np
+from scipy.spatial.transform import Rotation
+from .network import AudioEncoder
+import trimesh
+
+import torch
+from torch.utils.data import DataLoader
+
+from .utils import get_audio_features, get_rays, get_bg_coords, AudDataset
+
+# ref: https://github.com/NVlabs/instant-ngp/blob/b76004c8cf478880227401ae763be4c02f80b62f/include/neural-graphics-primitives/nerf_loader.h#L50
+def nerf_matrix_to_ngp(pose, scale=0.33, offset=[0, 0, 0]):
+ new_pose = np.array([
+ [pose[1, 0], -pose[1, 1], -pose[1, 2], pose[1, 3] * scale + offset[0]],
+ [pose[2, 0], -pose[2, 1], -pose[2, 2], pose[2, 3] * scale + offset[1]],
+ [pose[0, 0], -pose[0, 1], -pose[0, 2], pose[0, 3] * scale + offset[2]],
+ [0, 0, 0, 1],
+ ], dtype=np.float32)
+ return new_pose
+
+
+def smooth_camera_path(poses, kernel_size=5):
+ # smooth the camera trajectory...
+ # poses: [N, 4, 4], numpy array
+
+ N = poses.shape[0]
+ K = kernel_size // 2
+
+ trans = poses[:, :3, 3].copy() # [N, 3]
+ rots = poses[:, :3, :3].copy() # [N, 3, 3]
+
+ for i in range(N):
+ start = max(0, i - K)
+ end = min(N, i + K + 1)
+ poses[i, :3, 3] = trans[start:end].mean(0)
+ poses[i, :3, :3] = Rotation.from_matrix(rots[start:end]).mean().as_matrix()
+
+ return poses
+
+def polygon_area(x, y):
+ x_ = x - x.mean()
+ y_ = y - y.mean()
+ correction = x_[-1] * y_[0] - y_[-1]* x_[0]
+ main_area = np.dot(x_[:-1], y_[1:]) - np.dot(y_[:-1], x_[1:])
+ return 0.5 * np.abs(main_area + correction)
+
+
+def visualize_poses(poses, size=0.1):
+ # poses: [B, 4, 4]
+
+ print(f'[INFO] visualize poses: {poses.shape}')
+
+ axes = trimesh.creation.axis(axis_length=4)
+ box = trimesh.primitives.Box(extents=(2, 2, 2)).as_outline()
+ box.colors = np.array([[128, 128, 128]] * len(box.entities))
+ objects = [axes, box]
+
+ for pose in poses:
+ # a camera is visualized with 8 line segments.
+ pos = pose[:3, 3]
+ a = pos + size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2]
+ b = pos - size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2]
+ c = pos - size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2]
+ d = pos + size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2]
+
+ dir = (a + b + c + d) / 4 - pos
+ dir = dir / (np.linalg.norm(dir) + 1e-8)
+ o = pos + dir * 3
+
+ segs = np.array([[pos, a], [pos, b], [pos, c], [pos, d], [a, b], [b, c], [c, d], [d, a], [pos, o]])
+ segs = trimesh.load_path(segs)
+ objects.append(segs)
+
+ trimesh.Scene(objects).show()
+
+
+class NeRFDataset:
+ def __init__(self, opt, device, type='train', downscale=1):
+ super().__init__()
+
+ self.opt = opt
+ self.device = device
+ self.type = type # train, val, test
+ self.downscale = downscale
+ self.root_path = opt.path
+ self.preload = opt.preload # 0 = disk, 1 = cpu, 2 = gpu
+ self.scale = opt.scale # camera radius scale to make sure camera are inside the bounding box.
+ self.offset = opt.offset # camera offset
+ self.bound = opt.bound # bounding box half length, also used as the radius to random sample poses.
+ self.fp16 = opt.fp16
+
+ self.start_index = opt.data_range[0]
+ self.end_index = opt.data_range[1]
+
+ self.training = self.type in ['train', 'all', 'trainval']
+ self.num_rays = self.opt.num_rays if self.training else -1
+
+ # load nerf-compatible format data.
+
+ # load all splits (train/valid/test)
+ if type == 'all':
+ transform_paths = glob.glob(os.path.join(self.root_path, '*.json'))
+ transform = None
+ for transform_path in transform_paths:
+ with open(transform_path, 'r') as f:
+ tmp_transform = json.load(f)
+ if transform is None:
+ transform = tmp_transform
+ else:
+ transform['frames'].extend(tmp_transform['frames'])
+ # load train and val split
+ elif type == 'trainval':
+ with open(os.path.join(self.root_path, f'transforms_train.json'), 'r') as f:
+ transform = json.load(f)
+ with open(os.path.join(self.root_path, f'transforms_val.json'), 'r') as f:
+ transform_val = json.load(f)
+ transform['frames'].extend(transform_val['frames'])
+ # only load one specified split
+ else:
+ # no test, use val as test
+ _split = 'val' if type == 'test' else type
+ with open(os.path.join(self.root_path, f'transforms_{_split}.json'), 'r') as f:
+ transform = json.load(f)
+
+ # load image size
+ if 'h' in transform and 'w' in transform:
+ self.H = int(transform['h']) // downscale
+ self.W = int(transform['w']) // downscale
+ else:
+ self.H = int(transform['cy']) * 2 // downscale
+ self.W = int(transform['cx']) * 2 // downscale
+
+ # read images
+ frames = transform["frames"]
+
+ # use a slice of the dataset
+ if self.end_index == -1: # abuse...
+ self.end_index = len(frames)
+
+ frames = frames[self.start_index:self.end_index]
+
+ # use a subset of dataset.
+ if type == 'train':
+ if self.opt.part:
+ frames = frames[::10] # 1/10 frames
+ elif self.opt.part2:
+ frames = frames[:375] # first 15s
+ elif type == 'val':
+ frames = frames[:100] # first 100 frames for val
+
+ print(f'[INFO] load {len(frames)} {type} frames.')
+
+ # only load pre-calculated aud features when not live-streaming
+ if not self.opt.asr:
+
+ # empty means the default self-driven extracted features.
+ if self.opt.aud == '':
+ if 'esperanto' in self.opt.asr_model:
+ aud_features = np.load(os.path.join(self.root_path, 'aud_eo.npy'))
+ elif 'deepspeech' in self.opt.asr_model:
+ aud_features = np.load(os.path.join(self.root_path, 'aud_ds.npy'))
+ # elif 'hubert_cn' in self.opt.asr_model:
+ # aud_features = np.load(os.path.join(self.root_path, 'aud_hu_cn.npy'))
+ elif 'hubert' in self.opt.asr_model:
+ aud_features = np.load(os.path.join(self.root_path, 'aud_hu.npy'))
+ elif self.opt.asr_model == 'ave':
+ aud_features = np.load(os.path.join(self.root_path, 'aud_ave.npy'))
+ else:
+ aud_features = np.load(os.path.join(self.root_path, 'aud.npy'))
+ # cross-driven extracted features.
+ else:
+ if self.opt.asr_model == 'ave':
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ model = AudioEncoder().to(device).eval()
+ ckpt = torch.load('./nerf_triplane/checkpoints/audio_visual_encoder.pth')
+ model.load_state_dict({f'audio_encoder.{k}': v for k, v in ckpt.items()})
+ dataset = AudDataset(self.opt.aud)
+ data_loader = DataLoader(dataset, batch_size=64, shuffle=False)
+ outputs = []
+ for mel in data_loader:
+ mel = mel.to(device)
+ with torch.no_grad():
+ out = model(mel)
+ outputs.append(out)
+ outputs = torch.cat(outputs, dim=0).cpu()
+ first_frame, last_frame = outputs[:1], outputs[-1:]
+ aud_features = torch.cat([first_frame.repeat(2, 1), outputs, last_frame.repeat(2, 1)], dim=0).numpy()
+ else:
+ try:
+ aud_features = np.load(self.opt.aud)
+ except:
+ print(f'[ERROR] If do not use Audio Visual Encoder, replace it with the npy file path')
+
+ if self.opt.asr_model == 'ave':
+ aud_features = torch.from_numpy(aud_features).unsqueeze(0)
+
+ # support both [N, 16] labels and [N, 16, K] logits
+ if len(aud_features.shape) == 3:
+ aud_features = aud_features.float().permute(1, 0, 2) # [N, 16, 29] --> [N, 29, 16]
+
+ if self.opt.emb:
+ print(f'[INFO] argmax to aud features {aud_features.shape} for --emb mode')
+ aud_features = aud_features.argmax(1) # [N, 16]
+
+ else:
+ assert self.opt.emb, "aud only provide labels, must use --emb"
+ aud_features = aud_features.long()
+
+ print(f'[INFO] load {self.opt.aud} aud_features: {aud_features.shape}')
+ else:
+ aud_features = torch.from_numpy(aud_features)
+
+ # support both [N, 16] labels and [N, 16, K] logits
+ if len(aud_features.shape) == 3:
+ aud_features = aud_features.float().permute(0, 2, 1) # [N, 16, 29] --> [N, 29, 16]
+
+ if self.opt.emb:
+ print(f'[INFO] argmax to aud features {aud_features.shape} for --emb mode')
+ aud_features = aud_features.argmax(1) # [N, 16]
+
+ else:
+ assert self.opt.emb, "aud only provide labels, must use --emb"
+ aud_features = aud_features.long()
+
+ print(f'[INFO] load {self.opt.aud} aud_features: {aud_features.shape}')
+
+
+ bs = np.load(os.path.join(self.root_path, 'bs.npy'))
+ if self.opt.bs_area == "upper":
+ bs = np.hstack((bs[:, 0:5], bs[:, 8:10]))
+ elif self.opt.bs_area == "single":
+ bs = np.hstack((bs[:, 0].reshape(-1, 1),bs[:, 2].reshape(-1, 1),bs[:, 3].reshape(-1, 1), bs[:, 8].reshape(-1, 1)))
+ elif self.opt.bs_area == "eye":
+ bs = bs[:,8:10]
+
+
+ self.torso_img = []
+ self.images = []
+ self.gt_images = []
+ self.face_mask_imgs = []
+
+ self.poses = []
+ self.exps = []
+
+ self.auds = []
+ self.face_rect = []
+ self.lhalf_rect = []
+ self.upface_rect = []
+ self.lowface_rect = []
+ self.lips_rect = []
+ self.eye_area = []
+ self.eye_rect = []
+
+ for f in tqdm.tqdm(frames, desc=f'Loading {type} data'):
+
+ f_path = os.path.join(self.root_path, 'gt_imgs', str(f['img_id']) + '.jpg')
+
+ if not os.path.exists(f_path):
+ print('[WARN]', f_path, 'NOT FOUND!')
+ continue
+
+ pose = np.array(f['transform_matrix'], dtype=np.float32) # [4, 4]
+ pose = nerf_matrix_to_ngp(pose, scale=self.scale, offset=self.offset)
+ self.poses.append(pose)
+
+ if self.preload > 0:
+ image = cv2.imread(f_path, cv2.IMREAD_UNCHANGED) # [H, W, 3] o [H, W, 4]
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+ image = image.astype(np.float32) / 255 # [H, W, 3/4]
+
+ self.images.append(image)
+ else:
+ self.images.append(f_path)
+
+ if self.opt.portrait:
+ gt_path = os.path.join(self.root_path, 'ori_imgs', str(f['img_id']) + '.jpg')
+ # gt_path = os.path.join(self.root_path, 'torso_imgs', str(f['img_id']) + '_no_face.png')
+ if not os.path.exists(f_path):
+ print('[WARN]', f_path, 'NOT FOUND!')
+ continue
+ if self.preload > 0:
+ gt_image = cv2.imread(gt_path, cv2.IMREAD_UNCHANGED) # [H, W, 3] o [H, W, 4]
+ gt_image = cv2.cvtColor(gt_image, cv2.COLOR_BGR2RGB)
+ gt_image = gt_image.astype(np.float32) / 255 # [H, W, 3/4]
+
+ self.gt_images.append(gt_image)
+ else:
+ self.gt_images.append(gt_path)
+
+ face_mask_path = os.path.join(self.root_path, 'parsing', str(f['img_id']) + '_face.png')
+ if not os.path.exists(face_mask_path):
+ print('[WARN]', face_mask_path, 'NOT FOUND!')
+ continue
+ if self.preload > 0:
+ face_mask_img = (255 - cv2.imread(face_mask_path)[:, :, 1]) / 255.0
+ self.face_mask_imgs.append(face_mask_img)
+ else:
+ self.face_mask_imgs.append(face_mask_path)
+
+ # load frame-wise bg
+
+ torso_img_path = os.path.join(self.root_path, 'torso_imgs', str(f['img_id']) + '.png')
+
+ if self.preload > 0:
+ torso_img = cv2.imread(torso_img_path, cv2.IMREAD_UNCHANGED) # [H, W, 4]
+ torso_img = cv2.cvtColor(torso_img, cv2.COLOR_BGRA2RGBA)
+ torso_img = torso_img.astype(np.float32) / 255 # [H, W, 3/4]
+
+ self.torso_img.append(torso_img)
+ else:
+ self.torso_img.append(torso_img_path)
+
+ # find the corresponding audio to the image frame
+ if not self.opt.asr and self.opt.aud == '':
+ aud = aud_features[min(f['aud_id'], aud_features.shape[0] - 1)] # careful for the last frame...
+ self.auds.append(aud)
+
+ # load lms and extract face
+ lms = np.loadtxt(os.path.join(self.root_path, 'ori_imgs', str(f['img_id']) + '.lms')) # [68, 2]
+
+ lh_xmin, lh_xmax = int(lms[31:36, 1].min()), int(lms[:, 1].max()) # actually lower half area
+ upface_xmin, upface_xmax = int(lms[:, 1].min()),int(lms[30,1])
+ lowface_xmin, lowface_xmax = int(lms[30,1]), int(lms[:, 1].max())
+ xmin, xmax = int(lms[:, 1].min()), int(lms[:, 1].max())
+ ymin, ymax = int(lms[:, 0].min()), int(lms[:, 0].max())
+ self.face_rect.append([xmin, xmax, ymin, ymax])
+ self.lhalf_rect.append([lh_xmin, lh_xmax, ymin, ymax])
+ self.upface_rect.append([upface_xmin, upface_xmax, ymin, ymax])
+ self.lowface_rect.append([lowface_xmin, lowface_xmax, ymin, ymax])
+
+
+ if self.opt.exp_eye:
+ area = bs[f['img_id']]
+ self.eye_area.append(area)
+
+ xmin, xmax = int(lms[36:48, 1].min()), int(lms[36:48, 1].max())
+ ymin, ymax = int(lms[36:48, 0].min()), int(lms[36:48, 0].max())
+ self.eye_rect.append([xmin, xmax, ymin, ymax])
+
+ if self.opt.finetune_lips:
+ lips = slice(48, 60)
+ xmin, xmax = int(lms[lips, 1].min()), int(lms[lips, 1].max())
+ ymin, ymax = int(lms[lips, 0].min()), int(lms[lips, 0].max())
+
+ # padding to H == W
+ cx = (xmin + xmax) // 2
+ cy = (ymin + ymax) // 2
+
+ l = max(xmax - xmin, ymax - ymin) // 2
+ xmin = max(0, cx - l)
+ xmax = min(self.H, cx + l)
+ ymin = max(0, cy - l)
+ ymax = min(self.W, cy + l)
+
+ self.lips_rect.append([xmin, xmax, ymin, ymax])
+
+ # load pre-extracted background image (should be the same size as training image...)
+
+ if self.opt.bg_img == 'white': # special
+ bg_img = np.ones((self.H, self.W, 3), dtype=np.float32)
+ elif self.opt.bg_img == 'black': # special
+ bg_img = np.zeros((self.H, self.W, 3), dtype=np.float32)
+ else: # load from file
+ # default bg
+ if self.opt.bg_img == '':
+ self.opt.bg_img = os.path.join(self.root_path, 'bc.jpg')
+ bg_img = cv2.imread(self.opt.bg_img, cv2.IMREAD_UNCHANGED) # [H, W, 3]
+ if bg_img.shape[0] != self.H or bg_img.shape[1] != self.W:
+ bg_img = cv2.resize(bg_img, (self.W, self.H), interpolation=cv2.INTER_AREA)
+ bg_img = cv2.cvtColor(bg_img, cv2.COLOR_BGR2RGB)
+ bg_img = bg_img.astype(np.float32) / 255 # [H, W, 3/4]
+
+ self.bg_img = bg_img
+
+ self.poses = np.stack(self.poses, axis=0)
+
+ # smooth camera path...
+ if self.opt.smooth_path:
+ self.poses = smooth_camera_path(self.poses, self.opt.smooth_path_window)
+
+ self.poses = torch.from_numpy(self.poses) # [N, 4, 4]
+
+ if self.preload > 0:
+ self.images = torch.from_numpy(np.stack(self.images, axis=0)) # [N, H, W, C]
+ self.torso_img = torch.from_numpy(np.stack(self.torso_img, axis=0)) # [N, H, W, C]
+ if self.opt.portrait:
+ self.gt_images = torch.from_numpy(np.stack(self.gt_images, axis=0)) # [N, H, W, C]
+ self.face_mask_imgs = torch.from_numpy(np.stack(self.face_mask_imgs, axis=0)) # [N, H, W, C]
+
+ else:
+ self.images = np.array(self.images)
+ self.torso_img = np.array(self.torso_img)
+ if self.opt.portrait:
+ self.gt_images = np.array(self.gt_images)
+ self.face_mask_imgs = np.array(self.face_mask_imgs)
+
+
+ if self.opt.asr:
+ # live streaming, no pre-calculated auds
+ self.auds = None
+ else:
+ # auds corresponding to images
+ if self.opt.aud == '':
+ self.auds = torch.stack(self.auds, dim=0) # [N, 32, 16]
+ # auds is novel, may have a different length with images
+ else:
+ self.auds = aud_features
+
+ self.bg_img = torch.from_numpy(self.bg_img)
+
+ if self.opt.exp_eye:
+ self.eye_area = np.array(self.eye_area, dtype=np.float32) # [N]
+ print(f'[INFO] eye_area: {self.eye_area.min()} - {self.eye_area.max()}')
+
+ if self.opt.smooth_eye:
+
+ # naive 5 window average
+ ori_eye = self.eye_area.copy()
+ for i in range(ori_eye.shape[0]):
+ start = max(0, i - 1)
+ end = min(ori_eye.shape[0], i + 2)
+ self.eye_area[i] = ori_eye[start:end].mean()
+ if self.opt.bs_area == "upper":
+ self.eye_area = torch.from_numpy(self.eye_area).view(-1, 7) # [N, 7]
+ elif self.opt.bs_area == "single":
+ self.eye_area = torch.from_numpy(self.eye_area).view(-1, 4) # [N, 7]
+ else:
+ self.eye_area = torch.from_numpy(self.eye_area).view(-1, 2)
+
+
+ # calculate mean radius of all camera poses
+ self.radius = self.poses[:, :3, 3].norm(dim=-1).mean(0).item()
+ #print(f'[INFO] dataset camera poses: radius = {self.radius:.4f}, bound = {self.bound}')
+
+
+ # [debug] uncomment to view all training poses.
+ # visualize_poses(self.poses.numpy())
+
+ # [debug] uncomment to view examples of randomly generated poses.
+ # visualize_poses(rand_poses(100, self.device, radius=self.radius).cpu().numpy())
+
+ if self.preload > 1:
+ self.poses = self.poses.to(self.device)
+
+ if self.auds is not None:
+ self.auds = self.auds.to(self.device)
+
+ self.bg_img = self.bg_img.to(torch.half).to(self.device)
+
+ self.torso_img = self.torso_img.to(torch.half).to(self.device)
+ self.images = self.images.to(torch.half).to(self.device)
+ if self.opt.portrait:
+ self.gt_images = self.gt_images.to(torch.half).to(self.device)
+ self.face_mask_imgs = self.face_mask_imgs.to(torch.half).to(self.device)
+
+ if self.opt.exp_eye:
+ self.eye_area = self.eye_area.to(self.device)
+
+ # load intrinsics
+ if 'focal_len' in transform:
+ fl_x = fl_y = transform['focal_len']
+ elif 'fl_x' in transform or 'fl_y' in transform:
+ fl_x = (transform['fl_x'] if 'fl_x' in transform else transform['fl_y']) / downscale
+ fl_y = (transform['fl_y'] if 'fl_y' in transform else transform['fl_x']) / downscale
+ elif 'camera_angle_x' in transform or 'camera_angle_y' in transform:
+ # blender, assert in radians. already downscaled since we use H/W
+ fl_x = self.W / (2 * np.tan(transform['camera_angle_x'] / 2)) if 'camera_angle_x' in transform else None
+ fl_y = self.H / (2 * np.tan(transform['camera_angle_y'] / 2)) if 'camera_angle_y' in transform else None
+ if fl_x is None: fl_x = fl_y
+ if fl_y is None: fl_y = fl_x
+ else:
+ raise RuntimeError('Failed to load focal length, please check the transforms.json!')
+
+ cx = (transform['cx'] / downscale) if 'cx' in transform else (self.W / 2)
+ cy = (transform['cy'] / downscale) if 'cy' in transform else (self.H / 2)
+
+ self.intrinsics = np.array([fl_x, fl_y, cx, cy])
+
+ # directly build the coordinate meshgrid in [-1, 1]^2
+ self.bg_coords = get_bg_coords(self.H, self.W, self.device) # [1, H*W, 2] in [-1, 1]
+
+
+ def mirror_index(self, index):
+ size = self.poses.shape[0]
+ turn = index // size
+ res = index % size
+ if turn % 2 == 0:
+ return res
+ else:
+ return size - res - 1
+
+
+ def collate(self, index):
+
+ B = len(index) # a list of length 1
+ # assert B == 1
+
+ results = {}
+
+ # audio use the original index
+ if self.auds is not None:
+ auds = get_audio_features(self.auds, self.opt.att, index[0]).to(self.device)
+ results['auds'] = auds
+
+ # head pose and bg image may mirror (replay --> <-- --> <--).
+ index[0] = self.mirror_index(index[0])
+
+ poses = self.poses[index].to(self.device) # [B, 4, 4]
+
+ if self.training and self.opt.finetune_lips:
+ rect = self.lips_rect[index[0]]
+ results['rect'] = rect
+ rays = get_rays(poses, self.intrinsics, self.H, self.W, -1, rect=rect)
+ else:
+ rays = get_rays(poses, self.intrinsics, self.H, self.W, self.num_rays, self.opt.patch_size)
+ results['up_rect'] = self.upface_rect[index[0]]
+ results['low_rect'] = self.lowface_rect[index[0]]
+ results['index'] = index # for ind. code
+ results['H'] = self.H
+ results['W'] = self.W
+ results['rays_o'] = rays['rays_o']
+ results['rays_d'] = rays['rays_d']
+
+ # get a mask for rays inside rect_face
+ if self.training:
+ xmin, xmax, ymin, ymax = self.face_rect[index[0]]
+ face_mask = (rays['j'] >= xmin) & (rays['j'] < xmax) & (rays['i'] >= ymin) & (rays['i'] < ymax) # [B, N]
+ results['face_mask'] = face_mask
+
+ xmin, xmax, ymin, ymax = self.lhalf_rect[index[0]]
+ lhalf_mask = (rays['j'] >= xmin) & (rays['j'] < xmax) & (rays['i'] >= ymin) & (rays['i'] < ymax) # [B, N]
+ results['lhalf_mask'] = lhalf_mask
+
+ xmin, xmax, ymin, ymax = self.upface_rect[index[0]]
+ upface_mask = (rays['j'] >= xmin) & (rays['j'] < xmax) & (rays['i'] >= ymin) & (rays['i'] < ymax) # [B, N]
+ results['upface_mask'] = upface_mask
+
+ xmin, xmax, ymin, ymax = self.lowface_rect[index[0]]
+ lowface_mask = (rays['j'] >= xmin) & (rays['j'] < xmax) & (rays['i'] >= ymin) & (rays['i'] < ymax) # [B, N]
+ results['lowface_mask'] = lowface_mask
+
+
+ if self.opt.exp_eye:
+ results['eye'] = self.eye_area[index].to(self.device) # [1]
+ if self.training:
+ #results['eye'] += (np.random.rand()-0.5) / 10
+ xmin, xmax, ymin, ymax = self.eye_rect[index[0]]
+ eye_mask = (rays['j'] >= xmin) & (rays['j'] < xmax) & (rays['i'] >= ymin) & (rays['i'] < ymax) # [B, N]
+ results['eye_mask'] = eye_mask
+
+ else:
+ results['eye'] = None
+
+ # load bg
+ bg_torso_img = self.torso_img[index]
+ if self.preload == 0: # on the fly loading
+ bg_torso_img = cv2.imread(bg_torso_img[0], cv2.IMREAD_UNCHANGED) # [H, W, 4]
+ bg_torso_img = cv2.cvtColor(bg_torso_img, cv2.COLOR_BGRA2RGBA)
+ bg_torso_img = bg_torso_img.astype(np.float32) / 255 # [H, W, 3/4]
+ bg_torso_img = torch.from_numpy(bg_torso_img).unsqueeze(0)
+ bg_torso_img = bg_torso_img[..., :3] * bg_torso_img[..., 3:] + self.bg_img * (1 - bg_torso_img[..., 3:])
+ bg_torso_img = bg_torso_img.view(B, -1, 3).to(self.device)
+
+ if not self.opt.torso:
+ bg_img = bg_torso_img
+ else:
+ bg_img = self.bg_img.view(1, -1, 3).repeat(B, 1, 1).to(self.device)
+
+ if self.training:
+ bg_img = torch.gather(bg_img, 1, torch.stack(3 * [rays['inds']], -1)) # [B, N, 3]
+
+ results['bg_color'] = bg_img
+
+ if self.opt.torso and self.training:
+ bg_torso_img = torch.gather(bg_torso_img, 1, torch.stack(3 * [rays['inds']], -1)) # [B, N, 3]
+ results['bg_torso_color'] = bg_torso_img
+
+ if self.opt.portrait:
+ bg_gt_images = self.gt_images[index]
+ if self.preload == 0:
+ bg_gt_images = cv2.imread(bg_gt_images[0], cv2.IMREAD_UNCHANGED)
+ bg_gt_images = cv2.cvtColor(bg_gt_images, cv2.COLOR_BGR2RGB)
+ bg_gt_images = bg_gt_images.astype(np.float32) / 255
+ bg_gt_images = torch.from_numpy(bg_gt_images).unsqueeze(0)
+ bg_gt_images = bg_gt_images.to(self.device)
+ results['bg_gt_images'] = bg_gt_images
+
+ bg_face_mask = self.face_mask_imgs[index]
+ if self.preload == 0:
+ # bg_face_mask = np.all(cv2.imread(bg_face_mask[0]) == [255, 0, 0], axis=-1).astype(np.uint8)
+ bg_face_mask = (255 - cv2.imread(bg_face_mask[0])[:, :, 1]) / 255.0
+ bg_face_mask = torch.from_numpy(bg_face_mask).unsqueeze(0)
+ bg_face_mask = bg_face_mask.to(self.device)
+ results['bg_face_mask'] = bg_face_mask
+
+
+ images = self.images[index] # [B, H, W, 3/4]
+ if self.preload == 0:
+ images = cv2.imread(images[0], cv2.IMREAD_UNCHANGED) # [H, W, 3]
+ images = cv2.cvtColor(images, cv2.COLOR_BGR2RGB)
+ images = images.astype(np.float32) / 255 # [H, W, 3]
+ images = torch.from_numpy(images).unsqueeze(0)
+ images = images.to(self.device)
+
+ if self.training:
+ C = images.shape[-1]
+ images = torch.gather(images.view(B, -1, C), 1, torch.stack(C * [rays['inds']], -1)) # [B, N, 3/4]
+ results['images'] = images
+
+ if self.training:
+ bg_coords = torch.gather(self.bg_coords, 1, torch.stack(2 * [rays['inds']], -1)) # [1, N, 2]
+ else:
+ bg_coords = self.bg_coords # [1, N, 2]
+
+ results['bg_coords'] = bg_coords
+
+ # results['poses'] = convert_poses(poses) # [B, 6]
+ # results['poses_matrix'] = poses # [B, 4, 4]
+ results['poses'] = poses # [B, 4, 4]
+
+ return results
+
+ def dataloader(self):
+
+ if self.training:
+ # training len(poses) == len(auds)
+ size = self.poses.shape[0]
+ else:
+ # test with novel auds, then use its length
+ if self.auds is not None:
+ size = self.auds.shape[0]
+ # live stream test, use 2 * len(poses), so it naturally mirrors.
+ else:
+ size = 2 * self.poses.shape[0]
+
+ loader = DataLoader(list(range(size)), batch_size=1, collate_fn=self.collate, shuffle=self.training, num_workers=0)
+ loader._data = self # an ugly fix... we need poses in trainer.
+
+ # do evaluate if has gt images and use self-driven setting
+ loader.has_gt = (self.opt.aud == '')
+
+ return loader
\ No newline at end of file
diff --git a/sync/SyncTalk/nerf_triplane/renderer.py b/sync/SyncTalk/nerf_triplane/renderer.py
new file mode 100644
index 00000000..a736791a
--- /dev/null
+++ b/sync/SyncTalk/nerf_triplane/renderer.py
@@ -0,0 +1,698 @@
+import math
+import trimesh
+import numpy as np
+import random
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import raymarching
+from .utils import custom_meshgrid, get_audio_features, euler_angles_to_matrix, convert_poses
+
+def sample_pdf(bins, weights, n_samples, det=False):
+ # This implementation is from NeRF
+ # bins: [B, T], old_z_vals
+ # weights: [B, T - 1], bin weights.
+ # return: [B, n_samples], new_z_vals
+
+ # Get pdf
+ weights = weights + 1e-5 # prevent nans
+ pdf = weights / torch.sum(weights, -1, keepdim=True)
+ cdf = torch.cumsum(pdf, -1)
+ cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
+ # Take uniform samples
+ if det:
+ u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples).to(weights.device)
+ u = u.expand(list(cdf.shape[:-1]) + [n_samples])
+ else:
+ u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(weights.device)
+
+ # Invert CDF
+ u = u.contiguous()
+ inds = torch.searchsorted(cdf, u, right=True)
+ below = torch.max(torch.zeros_like(inds - 1), inds - 1)
+ above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
+ inds_g = torch.stack([below, above], -1) # (B, n_samples, 2)
+
+ matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
+ cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
+ bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
+
+ denom = (cdf_g[..., 1] - cdf_g[..., 0])
+ denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
+ t = (u - cdf_g[..., 0]) / denom
+ samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
+
+ return samples
+
+
+def plot_pointcloud(pc, color=None):
+ # pc: [N, 3]
+ # color: [N, 3/4]
+ print('[visualize points]', pc.shape, pc.dtype, pc.min(0), pc.max(0))
+ pc = trimesh.PointCloud(pc, color)
+ # axis
+ axes = trimesh.creation.axis(axis_length=4)
+ # sphere
+ sphere = trimesh.creation.icosphere(radius=1)
+ trimesh.Scene([pc, axes, sphere]).show()
+
+
+class NeRFRenderer(nn.Module):
+ def __init__(self, opt):
+
+ super().__init__()
+
+ self.opt = opt
+ self.bound = opt.bound
+ self.cascade = 1 + math.ceil(math.log2(opt.bound))
+ self.grid_size = 128
+ self.density_scale = 1
+
+ self.min_near = opt.min_near
+ self.density_thresh = opt.density_thresh
+ self.density_thresh_torso = opt.density_thresh_torso
+
+ self.exp_eye = opt.exp_eye
+ self.test_train = opt.test_train
+ self.smooth_lips = opt.smooth_lips
+
+ self.torso = opt.torso
+ self.cuda_ray = opt.cuda_ray
+
+ # prepare aabb with a 6D tensor (xmin, ymin, zmin, xmax, ymax, zmax)
+ # NOTE: aabb (can be rectangular) is only used to generate points, we still rely on bound (always cubic) to calculate density grid and hashing.
+ aabb_train = torch.FloatTensor([-opt.bound, -opt.bound/2, -opt.bound, opt.bound, opt.bound/2, opt.bound])
+ aabb_infer = aabb_train.clone()
+ self.register_buffer('aabb_train', aabb_train)
+ self.register_buffer('aabb_infer', aabb_infer)
+
+ # individual codes
+ self.individual_num = opt.ind_num
+
+ self.individual_dim = opt.ind_dim
+ if self.individual_dim > 0:
+ self.individual_codes = nn.Parameter(torch.randn(self.individual_num, self.individual_dim) * 0.1)
+
+ if self.torso:
+ self.individual_dim_torso = opt.ind_dim_torso
+ if self.individual_dim_torso > 0:
+ self.individual_codes_torso = nn.Parameter(torch.randn(self.individual_num, self.individual_dim_torso) * 0.1)
+
+ # optimize camera pose
+ self.train_camera = self.opt.train_camera
+ if self.train_camera:
+ self.camera_dR = nn.Parameter(torch.zeros(self.individual_num, 3)) # euler angle
+ self.camera_dT = nn.Parameter(torch.zeros(self.individual_num, 3)) # xyz offset
+
+ # extra state for cuda raymarching
+
+ # 3D head density grid
+ density_grid = torch.zeros([self.cascade, self.grid_size ** 3]) # [CAS, H * H * H]
+ density_bitfield = torch.zeros(self.cascade * self.grid_size ** 3 // 8, dtype=torch.uint8) # [CAS * H * H * H // 8]
+ self.register_buffer('density_grid', density_grid)
+ self.register_buffer('density_bitfield', density_bitfield)
+ self.mean_density = 0
+ self.iter_density = 0
+
+ # 2D torso density grid
+ if self.torso:
+ density_grid_torso = torch.zeros([self.grid_size ** 2]) # [H * H]
+ self.register_buffer('density_grid_torso', density_grid_torso)
+ self.mean_density_torso = 0
+
+ # step counter
+ step_counter = torch.zeros(16, 2, dtype=torch.int32) # 16 is hardcoded for averaging...
+ self.register_buffer('step_counter', step_counter)
+ self.mean_count = 0
+ self.local_step = 0
+
+ # decay for enc_a
+ if self.smooth_lips:
+ self.enc_a = None
+
+ def forward(self, x, d):
+ raise NotImplementedError()
+
+ # separated density and color query (can accelerate non-cuda-ray mode.)
+ def density(self, x):
+ raise NotImplementedError()
+
+ def color(self, x, d, mask=None, **kwargs):
+ raise NotImplementedError()
+
+ def reset_extra_state(self):
+ if not self.cuda_ray:
+ return
+ # density grid
+ self.density_grid.zero_()
+ self.mean_density = 0
+ self.iter_density = 0
+ # step counter
+ self.step_counter.zero_()
+ self.mean_count = 0
+ self.local_step = 0
+
+
+ def run_cuda(self, rays_o, rays_d, auds, bg_coords, poses, eye=None, index=0, dt_gamma=0, bg_color=None, perturb=False, force_all_rays=False, max_steps=1024, T_thresh=1e-4, **kwargs):
+ # rays_o, rays_d: [B, N, 3], assumes B == 1
+ # auds: [B, 16]
+ # index: [B]
+ # return: image: [B, N, 3], depth: [B, N]
+
+ prefix = rays_o.shape[:-1]
+ rays_o = rays_o.contiguous().view(-1, 3)
+ rays_d = rays_d.contiguous().view(-1, 3)
+ bg_coords = bg_coords.contiguous().view(-1, 2)
+
+ # only add camera offset at training!
+ if self.train_camera and (self.training or self.test_train):
+ dT = self.camera_dT[index] # [1, 3]
+ dR = euler_angles_to_matrix(self.camera_dR[index] / 180 * np.pi + 1e-8).squeeze(0) # [1, 3] --> [3, 3]
+
+ rays_o = rays_o + dT
+ rays_d = rays_d @ dR
+
+ N = rays_o.shape[0] # N = B * N, in fact
+ device = rays_o.device
+
+ results = {}
+
+ # pre-calculate near far
+ nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, self.aabb_train if self.training else self.aabb_infer, self.min_near)
+ nears = nears.detach()
+ fars = fars.detach()
+
+ # encode audio
+ enc_a = self.encode_audio(auds) # [1, 32]
+
+ if enc_a is not None and self.smooth_lips:
+ if self.enc_a is not None:
+ _lambda = 0.35
+ enc_a = _lambda * self.enc_a + (1 - _lambda) * enc_a
+ self.enc_a = enc_a
+
+
+ if self.individual_dim > 0:
+ if self.training:
+ ind_code = self.individual_codes[index]
+ # use a fixed ind code for the unknown test data.
+ else:
+ ind_code = self.individual_codes[0]
+ else:
+ ind_code = None
+
+ if self.training:
+ # setup counter
+ counter = self.step_counter[self.local_step % 16]
+ counter.zero_() # set to 0
+ self.local_step += 1
+
+ xyzs, dirs, deltas, rays = raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, counter, self.mean_count, perturb, 128, force_all_rays, dt_gamma, max_steps)
+ sigmas, rgbs, amb_aud, amb_eye, uncertainty = self(xyzs, dirs, enc_a, ind_code, eye)
+ sigmas = self.density_scale * sigmas
+
+ #print(f'valid RGB query ratio: {mask.sum().item() / mask.shape[0]} (total = {mask.sum().item()})')
+
+ # weights_sum, ambient_sum, uncertainty_sum, depth, image = raymarching.composite_rays_train_uncertainty(sigmas, rgbs, ambient.abs().sum(-1), uncertainty, deltas, rays)
+ weights_sum, amb_aud_sum, amb_eye_sum, uncertainty_sum, depth, image = raymarching.composite_rays_train_triplane(sigmas, rgbs, amb_aud.abs().sum(-1), amb_eye.abs().sum(-1), uncertainty, deltas, rays)
+
+ results['weights_sum'] = weights_sum
+ results['ambient_aud'] = amb_aud_sum
+ results['ambient_eye'] = amb_eye_sum
+ results['uncertainty'] = uncertainty_sum
+
+ results['rays'] = xyzs, dirs, enc_a, ind_code, eye
+
+ else:
+
+ dtype = torch.float32
+
+ weights_sum = torch.zeros(N, dtype=dtype, device=device)
+ depth = torch.zeros(N, dtype=dtype, device=device)
+ image = torch.zeros(N, 3, dtype=dtype, device=device)
+ amb_aud_sum = torch.zeros(N, dtype=dtype, device=device)
+ amb_eye_sum = torch.zeros(N, dtype=dtype, device=device)
+ uncertainty_sum = torch.zeros(N, dtype=dtype, device=device)
+
+ n_alive = N
+ rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device) # [N]
+ rays_t = nears.clone() # [N]
+
+ step = 0
+
+ while step < max_steps:
+
+ # count alive rays
+ n_alive = rays_alive.shape[0]
+
+ # exit loop
+ if n_alive <= 0:
+ break
+
+ # decide compact_steps
+ n_step = max(min(N // n_alive, 8), 1)
+
+ xyzs, dirs, deltas = raymarching.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, 128, perturb if step == 0 else False, dt_gamma, max_steps)
+
+ sigmas, rgbs, ambients_aud, ambients_eye, uncertainties = self(xyzs, dirs, enc_a, ind_code, eye)
+ sigmas = self.density_scale * sigmas
+
+ # raymarching.composite_rays_uncertainty(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, ambients, uncertainties, weights_sum, depth, image, ambient_sum, uncertainty_sum, T_thresh)
+ raymarching.composite_rays_triplane(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, ambients_aud, ambients_eye, uncertainties, weights_sum, depth, image, amb_aud_sum, amb_eye_sum, uncertainty_sum, T_thresh)
+
+ rays_alive = rays_alive[rays_alive >= 0]
+
+ # print(f'step = {step}, n_step = {n_step}, n_alive = {n_alive}, xyzs: {xyzs.shape}')
+
+ step += n_step
+
+ torso_results = self.run_torso(rays_o, bg_coords, poses, index, bg_color)
+ bg_color = torso_results['bg_color']
+ image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
+ image = image.view(*prefix, 3)
+ image = image.clamp(0, 1)
+
+ depth = torch.clamp(depth - nears, min=0) / (fars - nears)
+ depth = depth.view(*prefix)
+
+ amb_aud_sum = amb_aud_sum.view(*prefix)
+ amb_eye_sum = amb_eye_sum.view(*prefix)
+
+ results['depth'] = depth
+ results['image'] = image # head_image if train, else com_image
+ results['ambient_aud'] = amb_aud_sum
+ results['ambient_eye'] = amb_eye_sum
+ results['uncertainty'] = uncertainty_sum
+
+ return results
+
+
+ def run_torso(self, rays_o, bg_coords, poses, index=0, bg_color=None, **kwargs):
+ # rays_o, rays_d: [B, N, 3], assumes B == 1
+ # auds: [B, 16]
+ # index: [B]
+ # return: image: [B, N, 3], depth: [B, N]
+
+ rays_o = rays_o.contiguous().view(-1, 3)
+ bg_coords = bg_coords.contiguous().view(-1, 2)
+
+ N = rays_o.shape[0] # N = B * N, in fact
+ device = rays_o.device
+
+ results = {}
+
+ # background
+ if bg_color is None:
+ bg_color = 1
+
+ # first mix torso with background
+ if self.torso:
+ # torso ind code
+ if self.individual_dim_torso > 0:
+ if self.training:
+ ind_code_torso = self.individual_codes_torso[index]
+ # use a fixed ind code for the unknown test data.
+ else:
+ ind_code_torso = self.individual_codes_torso[0]
+ else:
+ ind_code_torso = None
+
+ # 2D density grid for acceleration...
+ density_thresh_torso = min(self.density_thresh_torso, self.mean_density_torso)
+ occupancy = F.grid_sample(self.density_grid_torso.view(1, 1, self.grid_size, self.grid_size), bg_coords.view(1, -1, 1, 2), align_corners=True).view(-1)
+ mask = occupancy > density_thresh_torso
+
+ # masked query of torso
+ torso_alpha = torch.zeros([N, 1], device=device)
+ torso_color = torch.zeros([N, 3], device=device)
+
+ if mask.any():
+ torso_alpha_mask, torso_color_mask, deform = self.forward_torso(bg_coords[mask], poses, ind_code_torso)
+
+ torso_alpha[mask] = torso_alpha_mask.float()
+ torso_color[mask] = torso_color_mask.float()
+
+ results['deform'] = deform
+
+ # first mix torso with background
+
+ bg_color = torso_color * torso_alpha + bg_color * (1 - torso_alpha)
+
+ results['torso_alpha'] = torso_alpha
+ results['torso_color'] = bg_color
+
+ # print(torso_alpha.shape, torso_alpha.max().item(), torso_alpha.min().item())
+
+ results['bg_color'] = bg_color
+
+ return results
+
+
+ @torch.no_grad()
+ def mark_untrained_grid(self, poses, intrinsic, S=64):
+ # poses: [B, 4, 4]
+ # intrinsic: [3, 3]
+
+ if not self.cuda_ray:
+ return
+
+ if isinstance(poses, np.ndarray):
+ poses = torch.from_numpy(poses)
+
+ B = poses.shape[0]
+
+ fx, fy, cx, cy = intrinsic
+
+ X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
+ Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
+ Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
+
+ count = torch.zeros_like(self.density_grid)
+ poses = poses.to(count.device)
+
+ # 5-level loop, forgive me...
+
+ for xs in X:
+ for ys in Y:
+ for zs in Z:
+
+ # construct points
+ xx, yy, zz = custom_meshgrid(xs, ys, zs)
+ coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128)
+ indices = raymarching.morton3D(coords).long() # [N]
+ world_xyzs = (2 * coords.float() / (self.grid_size - 1) - 1).unsqueeze(0) # [1, N, 3] in [-1, 1]
+
+ # cascading
+ for cas in range(self.cascade):
+ bound = min(2 ** cas, self.bound)
+ half_grid_size = bound / self.grid_size
+ # scale to current cascade's resolution
+ cas_world_xyzs = world_xyzs * (bound - half_grid_size)
+
+ # split batch to avoid OOM
+ head = 0
+ while head < B:
+ tail = min(head + S, B)
+
+ # world2cam transform (poses is c2w, so we need to transpose it. Another transpose is needed for batched matmul, so the final form is without transpose.)
+ cam_xyzs = cas_world_xyzs - poses[head:tail, :3, 3].unsqueeze(1)
+ cam_xyzs = cam_xyzs @ poses[head:tail, :3, :3] # [S, N, 3]
+
+ # query if point is covered by any camera
+ mask_z = cam_xyzs[:, :, 2] > 0 # [S, N]
+ mask_x = torch.abs(cam_xyzs[:, :, 0]) < cx / fx * cam_xyzs[:, :, 2] + half_grid_size * 2
+ mask_y = torch.abs(cam_xyzs[:, :, 1]) < cy / fy * cam_xyzs[:, :, 2] + half_grid_size * 2
+ mask = (mask_z & mask_x & mask_y).sum(0).reshape(-1) # [N]
+
+ # update count
+ count[cas, indices] += mask
+ head += S
+
+ # mark untrained grid as -1
+ self.density_grid[count == 0] = -1
+
+ #print(f'[mark untrained grid] {(count == 0).sum()} from {resolution ** 3 * self.cascade}')
+
+ @torch.no_grad()
+ def update_extra_state(self, decay=0.95, S=128):
+ # call before each epoch to update extra states.
+
+ if not self.cuda_ray:
+ return
+
+ # use random auds (different expressions should have similar density grid...)
+ rand_idx = random.randint(0, self.aud_features.shape[0] - 1)
+ auds = get_audio_features(self.aud_features, self.att, rand_idx).to(self.density_bitfield.device)
+
+ # encode audio
+ enc_a = self.encode_audio(auds)
+
+ ### update density grid
+ if not self.torso: # forbid updating head if is training torso...
+
+ tmp_grid = torch.zeros_like(self.density_grid)
+
+ # use a random eye area based on training dataset's statistics...
+ if self.exp_eye:
+ eye = self.eye_area[[rand_idx]].to(self.density_bitfield.device) # [1, 1]
+ else:
+ eye = None
+
+ # full update
+ X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
+ Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
+ Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
+
+ for xs in X:
+ for ys in Y:
+ for zs in Z:
+
+ # construct points
+ xx, yy, zz = custom_meshgrid(xs, ys, zs)
+ coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128)
+ indices = raymarching.morton3D(coords).long() # [N]
+ xyzs = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 3] in [-1, 1]
+
+ # cascading
+ for cas in range(self.cascade):
+ bound = min(2 ** cas, self.bound)
+ half_grid_size = bound / self.grid_size
+ # scale to current cascade's resolution
+ cas_xyzs = xyzs * (bound - half_grid_size)
+ # add noise in [-hgs, hgs]
+ cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size
+ # query density
+ sigmas = self.density(cas_xyzs, enc_a, eye)['sigma'].reshape(-1).detach().to(tmp_grid.dtype)
+ sigmas *= self.density_scale
+ # assign
+ tmp_grid[cas, indices] = sigmas
+
+ # dilate the density_grid (less aggressive culling)
+ tmp_grid = raymarching.morton3D_dilation(tmp_grid)
+
+ # ema update
+ valid_mask = (self.density_grid >= 0) & (tmp_grid >= 0)
+ self.density_grid[valid_mask] = torch.maximum(self.density_grid[valid_mask] * decay, tmp_grid[valid_mask])
+ self.mean_density = torch.mean(self.density_grid.clamp(min=0)).item() # -1 non-training regions are viewed as 0 density.
+ self.iter_density += 1
+
+ # convert to bitfield
+ density_thresh = min(self.mean_density, self.density_thresh)
+ self.density_bitfield = raymarching.packbits(self.density_grid, density_thresh, self.density_bitfield)
+
+ ### update torso density grid
+ if self.torso:
+ tmp_grid_torso = torch.zeros_like(self.density_grid_torso)
+
+ # random pose, random ind_code
+ rand_idx = random.randint(0, self.poses.shape[0] - 1)
+ # pose = convert_poses(self.poses[[rand_idx]]).to(self.density_bitfield.device)
+ pose = self.poses[[rand_idx]].to(self.density_bitfield.device)
+
+ if self.opt.ind_dim_torso > 0:
+ ind_code = self.individual_codes_torso[[rand_idx]]
+ else:
+ ind_code = None
+
+ X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
+ Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
+
+ half_grid_size = 1 / self.grid_size
+
+ for xs in X:
+ for ys in Y:
+ xx, yy = custom_meshgrid(xs, ys)
+ coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1)], dim=-1) # [N, 2], in [0, 128)
+ indices = (coords[:, 1] * self.grid_size + coords[:, 0]).long() # NOTE: xy transposed!
+ xys = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 2] in [-1, 1]
+ xys = xys * (1 - half_grid_size)
+ # add noise in [-hgs, hgs]
+ xys += (torch.rand_like(xys) * 2 - 1) * half_grid_size
+ # query density
+ alphas, _, _ = self.forward_torso(xys, pose, ind_code) # [N, 1]
+
+ # assign
+ tmp_grid_torso[indices] = alphas.squeeze(1).float()
+
+ # dilate
+ tmp_grid_torso = tmp_grid_torso.view(1, 1, self.grid_size, self.grid_size)
+ # tmp_grid_torso = F.max_pool2d(tmp_grid_torso, kernel_size=3, stride=1, padding=1)
+ tmp_grid_torso = F.max_pool2d(tmp_grid_torso, kernel_size=5, stride=1, padding=2)
+ tmp_grid_torso = tmp_grid_torso.view(-1)
+
+ self.density_grid_torso = torch.maximum(self.density_grid_torso * decay, tmp_grid_torso)
+ self.mean_density_torso = torch.mean(self.density_grid_torso).item()
+
+ # density_thresh_torso = min(self.density_thresh_torso, self.mean_density_torso)
+ # print(f'[density grid torso] min={self.density_grid_torso.min().item():.4f}, max={self.density_grid_torso.max().item():.4f}, mean={self.mean_density_torso:.4f}, occ_rate={(self.density_grid_torso > density_thresh_torso).sum() / (128**2):.3f}')
+
+ ### update step counter
+ total_step = min(16, self.local_step)
+ if total_step > 0:
+ self.mean_count = int(self.step_counter[:total_step, 0].sum().item() / total_step)
+ self.local_step = 0
+
+ #print(f'[density grid] min={self.density_grid.min().item():.4f}, max={self.density_grid.max().item():.4f}, mean={self.mean_density:.4f}, occ_rate={(self.density_grid > 0.01).sum() / (128**3 * self.cascade):.3f} | [step counter] mean={self.mean_count}')
+
+
+ @torch.no_grad()
+ def get_audio_grid(self, S=128):
+ # call before each epoch to update extra states.
+
+ if not self.cuda_ray:
+ return
+
+ # use random auds (different expressions should have similar density grid...)
+ rand_idx = random.randint(0, self.aud_features.shape[0] - 1)
+ auds = get_audio_features(self.aud_features, self.att, rand_idx).to(self.density_bitfield.device)
+
+ # encode audio
+ enc_a = self.encode_audio(auds)
+ tmp_grid = torch.zeros_like(self.density_grid)
+
+ # use a random eye area based on training dataset's statistics...
+ if self.exp_eye:
+ eye = self.eye_area[[rand_idx]].to(self.density_bitfield.device) # [1, 1]
+ else:
+ eye = None
+
+ # full update
+ X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
+ Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
+ Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
+
+ for xs in X:
+ for ys in Y:
+ for zs in Z:
+
+ # construct points
+ xx, yy, zz = custom_meshgrid(xs, ys, zs)
+ coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128)
+ indices = raymarching.morton3D(coords).long() # [N]
+ xyzs = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 3] in [-1, 1]
+
+ # cascading
+ for cas in range(self.cascade):
+ bound = min(2 ** cas, self.bound)
+ half_grid_size = bound / self.grid_size
+ # scale to current cascade's resolution
+ cas_xyzs = xyzs * (bound - half_grid_size)
+ # add noise in [-hgs, hgs]
+ cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size
+ # query density
+ aud_norms = self.density(cas_xyzs.to(tmp_grid.dtype), enc_a, eye)['ambient_aud'].reshape(-1).detach().to(tmp_grid.dtype)
+ # assign
+ tmp_grid[cas, indices] = aud_norms
+
+ # dilate the density_grid (less aggressive culling)
+ tmp_grid = raymarching.morton3D_dilation(tmp_grid)
+ return tmp_grid
+ # # ema update
+ # valid_mask = (self.density_grid >= 0) & (tmp_grid >= 0)
+ # self.density_grid[valid_mask] = torch.maximum(self.density_grid[valid_mask] * decay, tmp_grid[valid_mask])
+
+
+ @torch.no_grad()
+ def get_eye_grid(self, S=128):
+ # call before each epoch to update extra states.
+
+ if not self.cuda_ray:
+ return
+
+ # use random auds (different expressions should have similar density grid...)
+ rand_idx = random.randint(0, self.aud_features.shape[0] - 1)
+ auds = get_audio_features(self.aud_features, self.att, rand_idx).to(self.density_bitfield.device)
+
+ # encode audio
+ enc_a = self.encode_audio(auds)
+ tmp_grid = torch.zeros_like(self.density_grid)
+
+ # use a random eye area based on training dataset's statistics...
+ if self.exp_eye:
+ eye = self.eye_area[[rand_idx]].to(self.density_bitfield.device) # [1, 1]
+ else:
+ eye = None
+
+ # full update
+ X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
+ Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
+ Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
+
+ for xs in X:
+ for ys in Y:
+ for zs in Z:
+
+ # construct points
+ xx, yy, zz = custom_meshgrid(xs, ys, zs)
+ coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128)
+ indices = raymarching.morton3D(coords).long() # [N]
+ xyzs = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 3] in [-1, 1]
+
+ # cascading
+ for cas in range(self.cascade):
+ bound = min(2 ** cas, self.bound)
+ half_grid_size = bound / self.grid_size
+ # scale to current cascade's resolution
+ cas_xyzs = xyzs * (bound - half_grid_size)
+ # add noise in [-hgs, hgs]
+ cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size
+ # query density
+ eye_norms = self.density(cas_xyzs.to(tmp_grid.dtype), enc_a, eye)['ambient_eye'].reshape(-1).detach().to(tmp_grid.dtype)
+ # assign
+ tmp_grid[cas, indices] = eye_norms
+
+ # dilate the density_grid (less aggressive culling)
+ tmp_grid = raymarching.morton3D_dilation(tmp_grid)
+ return tmp_grid
+ # # ema update
+ # valid_mask = (self.density_grid >= 0) & (tmp_grid >= 0)
+ # self.density_grid[valid_mask] = torch.maximum(self.density_grid[valid_mask] * decay, tmp_grid[valid_mask])
+
+
+
+ def render(self, rays_o, rays_d, auds, bg_coords, poses, staged=False, max_ray_batch=4096, **kwargs):
+ # rays_o, rays_d: [B, N, 3], assumes B == 1
+ # auds: [B, 29, 16]
+ # eye: [B, 1]
+ # bg_coords: [1, N, 2]
+ # return: pred_rgb: [B, N, 3]
+
+ _run = self.run_cuda
+
+ B, N = rays_o.shape[:2]
+ device = rays_o.device
+
+ # never stage when cuda_ray
+ if staged and not self.cuda_ray:
+ # not used
+ raise NotImplementedError
+
+ else:
+ results = _run(rays_o, rays_d, auds, bg_coords, poses, **kwargs)
+
+ return results
+
+
+ def render_torso(self, rays_o, rays_d, auds, bg_coords, poses, staged=False, max_ray_batch=4096, **kwargs):
+ # rays_o, rays_d: [B, N, 3], assumes B == 1
+ # auds: [B, 29, 16]
+ # eye: [B, 1]
+ # bg_coords: [1, N, 2]
+ # return: pred_rgb: [B, N, 3]
+
+ _run = self.run_torso
+
+ B, N = rays_o.shape[:2]
+ device = rays_o.device
+
+ # never stage when cuda_ray
+ if staged and not self.cuda_ray:
+ # not used
+ raise NotImplementedError
+
+ else:
+ results = _run(rays_o, bg_coords, poses, **kwargs)
+
+ return results
\ No newline at end of file
diff --git a/sync/SyncTalk/nerf_triplane/utils.py b/sync/SyncTalk/nerf_triplane/utils.py
new file mode 100644
index 00000000..eee07f9b
--- /dev/null
+++ b/sync/SyncTalk/nerf_triplane/utils.py
@@ -0,0 +1,1649 @@
+import os
+import glob
+import tqdm
+import random
+import tensorboardX
+import librosa
+import librosa.filters
+from scipy import signal
+from os.path import basename
+import numpy as np
+import time
+import cv2
+import matplotlib.pyplot as plt
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
+
+import trimesh
+import mcubes
+from rich.console import Console
+from torch_ema import ExponentialMovingAverage
+
+from packaging import version as pver
+import imageio
+import lpips
+
+def custom_meshgrid(*args):
+ # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
+ if pver.parse(torch.__version__) < pver.parse('1.10'):
+ return torch.meshgrid(*args)
+ else:
+ return torch.meshgrid(*args, indexing='ij')
+
+def blend_with_mask_cuda(src, dst, mask):
+ src = src.permute(2, 0, 1)
+ dst = dst.permute(2, 0, 1)
+ mask = mask.unsqueeze(0)
+
+ # Blending
+ blended = src * mask + dst * (1 - mask)
+
+ # Convert back to numpy and return
+ return blended.permute(1, 2, 0).detach().cpu().numpy()
+
+
+def get_audio_features(features, att_mode, index):
+ if att_mode == 0:
+ return features[[index]]
+ elif att_mode == 1:
+ left = index - 8
+ pad_left = 0
+ if left < 0:
+ pad_left = -left
+ left = 0
+ auds = features[left:index]
+ if pad_left > 0:
+ # pad may be longer than auds, so do not use zeros_like
+ auds = torch.cat([torch.zeros(pad_left, *auds.shape[1:], device=auds.device, dtype=auds.dtype), auds], dim=0)
+ return auds
+ elif att_mode == 2:
+ left = index - 4
+ right = index + 4
+ pad_left = 0
+ pad_right = 0
+ if left < 0:
+ pad_left = -left
+ left = 0
+ if right > features.shape[0]:
+ pad_right = right - features.shape[0]
+ right = features.shape[0]
+ auds = features[left:right]
+ if pad_left > 0:
+ auds = torch.cat([torch.zeros_like(auds[:pad_left]), auds], dim=0)
+ if pad_right > 0:
+ auds = torch.cat([auds, torch.zeros_like(auds[:pad_right])], dim=0) # [8, 16]
+ return auds
+ else:
+ raise NotImplementedError(f'wrong att_mode: {att_mode}')
+
+
+@torch.jit.script
+def linear_to_srgb(x):
+ return torch.where(x < 0.0031308, 12.92 * x, 1.055 * x ** 0.41666 - 0.055)
+
+
+@torch.jit.script
+def srgb_to_linear(x):
+ return torch.where(x < 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4)
+
+# copied from pytorch3d
+def _angle_from_tan(
+ axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
+) -> torch.Tensor:
+ """
+ Extract the first or third Euler angle from the two members of
+ the matrix which are positive constant times its sine and cosine.
+
+ Args:
+ axis: Axis label "X" or "Y or "Z" for the angle we are finding.
+ other_axis: Axis label "X" or "Y or "Z" for the middle axis in the
+ convention.
+ data: Rotation matrices as tensor of shape (..., 3, 3).
+ horizontal: Whether we are looking for the angle for the third axis,
+ which means the relevant entries are in the same row of the
+ rotation matrix. If not, they are in the same column.
+ tait_bryan: Whether the first and third axes in the convention differ.
+
+ Returns:
+ Euler Angles in radians for each matrix in data as a tensor
+ of shape (...).
+ """
+
+ i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
+ if horizontal:
+ i2, i1 = i1, i2
+ even = (axis + other_axis) in ["XY", "YZ", "ZX"]
+ if horizontal == even:
+ return torch.atan2(data[..., i1], data[..., i2])
+ if tait_bryan:
+ return torch.atan2(-data[..., i2], data[..., i1])
+ return torch.atan2(data[..., i2], -data[..., i1])
+
+
+def _index_from_letter(letter: str) -> int:
+ if letter == "X":
+ return 0
+ if letter == "Y":
+ return 1
+ if letter == "Z":
+ return 2
+ raise ValueError("letter must be either X, Y or Z.")
+
+
+def matrix_to_euler_angles(matrix: torch.Tensor, convention: str = 'XYZ') -> torch.Tensor:
+ """
+ Convert rotations given as rotation matrices to Euler angles in radians.
+
+ Args:
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
+ convention: Convention string of three uppercase letters.
+
+ Returns:
+ Euler angles in radians as tensor of shape (..., 3).
+ """
+ # if len(convention) != 3:
+ # raise ValueError("Convention must have 3 letters.")
+ # if convention[1] in (convention[0], convention[2]):
+ # raise ValueError(f"Invalid convention {convention}.")
+ # for letter in convention:
+ # if letter not in ("X", "Y", "Z"):
+ # raise ValueError(f"Invalid letter {letter} in convention string.")
+ # if matrix.size(-1) != 3 or matrix.size(-2) != 3:
+ # raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
+ i0 = _index_from_letter(convention[0])
+ i2 = _index_from_letter(convention[2])
+ tait_bryan = i0 != i2
+ if tait_bryan:
+ central_angle = torch.asin(
+ matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
+ )
+ else:
+ central_angle = torch.acos(matrix[..., i0, i0])
+
+ o = (
+ _angle_from_tan(
+ convention[0], convention[1], matrix[..., i2], False, tait_bryan
+ ),
+ central_angle,
+ _angle_from_tan(
+ convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
+ ),
+ )
+ return torch.stack(o, -1)
+
+@torch.cuda.amp.autocast(enabled=False)
+def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
+ """
+ Return the rotation matrices for one of the rotations about an axis
+ of which Euler angles describe, for each value of the angle given.
+ Args:
+ axis: Axis label "X" or "Y or "Z".
+ angle: any shape tensor of Euler angles in radians
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ """
+
+ cos = torch.cos(angle)
+ sin = torch.sin(angle)
+ one = torch.ones_like(angle)
+ zero = torch.zeros_like(angle)
+
+ if axis == "X":
+ R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
+ elif axis == "Y":
+ R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
+ elif axis == "Z":
+ R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
+ else:
+ raise ValueError("letter must be either X, Y or Z.")
+
+ return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
+
+@torch.cuda.amp.autocast(enabled=False)
+def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str='XYZ') -> torch.Tensor:
+ """
+ Convert rotations given as Euler angles in radians to rotation matrices.
+ Args:
+ euler_angles: Euler angles in radians as tensor of shape (..., 3).
+ convention: Convention string of three uppercase letters from
+ {"X", "Y", and "Z"}.
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ """
+
+ # print(euler_angles, euler_angles.dtype)
+
+ if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
+ raise ValueError("Invalid input euler angles.")
+ if len(convention) != 3:
+ raise ValueError("Convention must have 3 letters.")
+ if convention[1] in (convention[0], convention[2]):
+ raise ValueError(f"Invalid convention {convention}.")
+ for letter in convention:
+ if letter not in ("X", "Y", "Z"):
+ raise ValueError(f"Invalid letter {letter} in convention string.")
+ matrices = [
+ _axis_angle_rotation(c, e)
+ for c, e in zip(convention, torch.unbind(euler_angles, -1))
+ ]
+
+ return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2])
+
+
+@torch.cuda.amp.autocast(enabled=False)
+def convert_poses(poses):
+ # poses: [B, 4, 4]
+ # return [B, 3], 4 rot, 3 trans
+ out = torch.empty(poses.shape[0], 6, dtype=torch.float32, device=poses.device)
+ out[:, :3] = matrix_to_euler_angles(poses[:, :3, :3])
+ out[:, 3:] = poses[:, :3, 3]
+ return out
+
+@torch.cuda.amp.autocast(enabled=False)
+def get_bg_coords(H, W, device):
+ X = torch.arange(H, device=device) / (H - 1) * 2 - 1 # in [-1, 1]
+ Y = torch.arange(W, device=device) / (W - 1) * 2 - 1 # in [-1, 1]
+ xs, ys = custom_meshgrid(X, Y)
+ bg_coords = torch.cat([xs.reshape(-1, 1), ys.reshape(-1, 1)], dim=-1).unsqueeze(0) # [1, H*W, 2], in [-1, 1]
+ return bg_coords
+
+
+@torch.cuda.amp.autocast(enabled=False)
+def get_rays(poses, intrinsics, H, W, N=-1, patch_size=1, rect=None):
+ ''' get rays
+ Args:
+ poses: [B, 4, 4], cam2world
+ intrinsics: [4]
+ H, W, N: int
+ Returns:
+ rays_o, rays_d: [B, N, 3]
+ inds: [B, N]
+ '''
+
+ device = poses.device
+ B = poses.shape[0]
+ fx, fy, cx, cy = intrinsics
+
+ if rect is not None:
+ xmin, xmax, ymin, ymax = rect
+ N = (xmax - xmin) * (ymax - ymin)
+
+ i, j = custom_meshgrid(torch.linspace(0, W-1, W, device=device), torch.linspace(0, H-1, H, device=device)) # float
+ i = i.t().reshape([1, H*W]).expand([B, H*W]) + 0.5
+ j = j.t().reshape([1, H*W]).expand([B, H*W]) + 0.5
+
+ results = {}
+
+ if N > 0:
+ N = min(N, H*W)
+
+ if patch_size > 1:
+
+ # random sample left-top cores.
+ # NOTE: this impl will lead to less sampling on the image corner pixels... but I don't have other ideas.
+ num_patch = N // (patch_size ** 2)
+ inds_x = torch.randint(0, H - patch_size, size=[num_patch], device=device)
+ inds_y = torch.randint(0, W - patch_size, size=[num_patch], device=device)
+ inds = torch.stack([inds_x, inds_y], dim=-1) # [np, 2]
+ # all_inds = torch.randperm((H - patch_size + 1) * (W - patch_size + 1), device=device)[:num_patch]
+ # all_inds, _ = torch.sort(all_inds)
+ #
+ # inds_x = all_inds // (W - patch_size)
+ # inds_y = all_inds % (W - patch_size)
+ # inds = torch.stack([inds_x, inds_y], dim=-1) # [np, 2]
+
+ # create meshgrid for each patch
+ pi, pj = custom_meshgrid(torch.arange(patch_size, device=device), torch.arange(patch_size, device=device))
+ offsets = torch.stack([pi.reshape(-1), pj.reshape(-1)], dim=-1) # [p^2, 2]
+
+ inds = inds.unsqueeze(1) + offsets.unsqueeze(0) # [np, p^2, 2]
+ inds = inds.view(-1, 2) # [N, 2]
+ inds = inds[:, 0] * W + inds[:, 1] # [N], flatten
+
+ inds = inds.expand([B, N])
+
+ # only get rays in the specified rect
+ elif rect is not None:
+ # assert B == 1
+ mask = torch.zeros(H, W, dtype=torch.bool, device=device)
+ xmin, xmax, ymin, ymax = rect
+ mask[xmin:xmax, ymin:ymax] = 1
+ inds = torch.where(mask.view(-1))[0] # [nzn]
+ inds = inds.unsqueeze(0) # [1, N]
+
+ else:
+ inds = torch.randint(0, H*W, size=[N], device=device) # may duplicate
+ inds = inds.expand([B, N])
+
+ # inds = torch.randperm(H * W, device=device)[:N]
+ # inds, _ = torch.sort(inds)
+ # inds = inds.expand([B, N])
+
+ i = torch.gather(i, -1, inds)
+ j = torch.gather(j, -1, inds)
+
+
+ else:
+ inds = torch.arange(H*W, device=device).expand([B, H*W])
+
+ results['i'] = i
+ results['j'] = j
+ results['inds'] = inds
+
+ zs = torch.ones_like(i)
+ xs = (i - cx) / fx * zs
+ ys = (j - cy) / fy * zs
+ directions = torch.stack((xs, ys, zs), dim=-1)
+ directions = directions / torch.norm(directions, dim=-1, keepdim=True)
+
+ rays_d = directions @ poses[:, :3, :3].transpose(-1, -2) # (B, N, 3)
+
+ rays_o = poses[..., :3, 3] # [B, 3]
+ rays_o = rays_o[..., None, :].expand_as(rays_d) # [B, N, 3]
+
+ results['rays_o'] = rays_o
+ results['rays_d'] = rays_d
+
+ return results
+
+
+def seed_everything(seed):
+ random.seed(seed)
+ os.environ['PYTHONHASHSEED'] = str(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ #torch.backends.cudnn.deterministic = True
+ #torch.backends.cudnn.benchmark = True
+
+
+def torch_vis_2d(x, renormalize=False):
+ # x: [3, H, W] or [1, H, W] or [H, W]
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import torch
+
+ if isinstance(x, torch.Tensor):
+ if len(x.shape) == 3:
+ x = x.permute(1,2,0).squeeze()
+ x = x.detach().cpu().numpy()
+
+ print(f'[torch_vis_2d] {x.shape}, {x.dtype}, {x.min()} ~ {x.max()}')
+
+ x = x.astype(np.float32)
+
+ # renormalize
+ if renormalize:
+ x = (x - x.min(axis=0, keepdims=True)) / (x.max(axis=0, keepdims=True) - x.min(axis=0, keepdims=True) + 1e-8)
+
+ plt.imshow(x)
+ plt.show()
+
+
+def extract_fields(bound_min, bound_max, resolution, query_func, S=128):
+
+ X = torch.linspace(bound_min[0], bound_max[0], resolution).split(S)
+ Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(S)
+ Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(S)
+
+ u = np.zeros([resolution, resolution, resolution], dtype=np.float32)
+ with torch.no_grad():
+ for xi, xs in enumerate(X):
+ for yi, ys in enumerate(Y):
+ for zi, zs in enumerate(Z):
+ xx, yy, zz = custom_meshgrid(xs, ys, zs)
+ pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [S, 3]
+ val = query_func(pts).reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy() # [S, 1] --> [x, y, z]
+ u[xi * S: xi * S + len(xs), yi * S: yi * S + len(ys), zi * S: zi * S + len(zs)] = val
+ return u
+
+
+def extract_geometry(bound_min, bound_max, resolution, threshold, query_func):
+ #print('threshold: {}'.format(threshold))
+ u = extract_fields(bound_min, bound_max, resolution, query_func)
+
+ #print(u.shape, u.max(), u.min(), np.percentile(u, 50))
+
+ vertices, triangles = mcubes.marching_cubes(u, threshold)
+
+ b_max_np = bound_max.detach().cpu().numpy()
+ b_min_np = bound_min.detach().cpu().numpy()
+
+ vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
+ return vertices, triangles
+
+def ssim_1d_loss(pred, true, C1=1e-4, C2=9e-4):
+ """
+ Compute 1D SSIM loss between two signals.
+ Args:
+ pred: predicted signal, [1, 512*512, 3]
+ true: ground truth signal, [1, 512*512, 3]
+ Returns:
+ ssim_val: ssim index of two input signals
+ """
+ if pred.size() != true.size():
+ raise ValueError(f'Expected input size ({pred.size()}) to match target size ({true.size()}).')
+
+ mu1 = pred.mean(dim=1, keepdim=True)
+ mu2 = true.mean(dim=1, keepdim=True)
+
+ mu1_sq = mu1.pow(2)
+ mu2_sq = mu2.pow(2)
+ mu1_mu2 = mu1 * mu2
+
+ sigma1_sq = (pred * pred).mean(dim=1, keepdim=True) - mu1_sq
+ sigma2_sq = (true * true).mean(dim=1, keepdim=True) - mu2_sq
+ sigma12 = (pred * true).mean(dim=1, keepdim=True) - mu1_mu2
+
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
+ ssim_val = ssim_map.mean()
+
+ return ssim_val
+
+class PSNRMeter:
+ def __init__(self):
+ self.V = 0
+ self.N = 0
+
+ def clear(self):
+ self.V = 0
+ self.N = 0
+
+ def prepare_inputs(self, *inputs):
+ outputs = []
+ for i, inp in enumerate(inputs):
+ if torch.is_tensor(inp):
+ inp = inp.detach().cpu().numpy()
+ outputs.append(inp)
+
+ return outputs
+
+ def update(self, preds, truths):
+ preds, truths = self.prepare_inputs(preds, truths) # [B, N, 3] or [B, H, W, 3], range in [0, 1]
+
+ # simplified since max_pixel_value is 1 here.
+ psnr = -10 * np.log10(np.mean((preds - truths) ** 2))
+
+ self.V += psnr
+ self.N += 1
+
+ def measure(self):
+ return self.V / self.N
+
+ def write(self, writer, global_step, prefix=""):
+ writer.add_scalar(os.path.join(prefix, "PSNR"), self.measure(), global_step)
+
+ def report(self):
+ return f'PSNR = {self.measure():.6f}'
+
+class LPIPSMeter:
+ def __init__(self, net='alex', device=None):
+ self.V = 0
+ self.N = 0
+ self.net = net
+
+ self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ self.fn = lpips.LPIPS(net=net).eval().to(self.device)
+
+ def clear(self):
+ self.V = 0
+ self.N = 0
+
+ def prepare_inputs(self, *inputs):
+ outputs = []
+ for i, inp in enumerate(inputs):
+ inp = inp.permute(0, 3, 1, 2).contiguous() # [B, 3, H, W]
+ inp = inp.to(self.device)
+ outputs.append(inp)
+ return outputs
+
+ def update(self, preds, truths):
+ preds, truths = self.prepare_inputs(preds, truths) # [B, H, W, 3] --> [B, 3, H, W], range in [0, 1]
+ v = self.fn(truths, preds, normalize=True).item() # normalize=True: [0, 1] to [-1, 1]
+ self.V += v
+ self.N += 1
+
+ def measure(self):
+ return self.V / self.N
+
+ def write(self, writer, global_step, prefix=""):
+ writer.add_scalar(os.path.join(prefix, f"LPIPS ({self.net})"), self.measure(), global_step)
+
+ def report(self):
+ return f'LPIPS ({self.net}) = {self.measure():.6f}'
+
+
+class LMDMeter:
+ def __init__(self, backend='dlib', region='mouth'):
+ self.backend = backend
+ self.region = region # mouth or face
+
+ if self.backend == 'dlib':
+ import dlib
+
+ # load checkpoint manually
+ self.predictor_path = './shape_predictor_68_face_landmarks.dat'
+ if not os.path.exists(self.predictor_path):
+ raise FileNotFoundError('Please download dlib checkpoint from http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2')
+
+ self.detector = dlib.get_frontal_face_detector()
+ self.predictor = dlib.shape_predictor(self.predictor_path)
+
+ else:
+
+ import face_alignment
+ try:
+ self.predictor = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False)
+ except:
+ self.predictor = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, flip_input=False)
+
+ self.V = 0
+ self.N = 0
+
+ def get_landmarks(self, img):
+
+ if self.backend == 'dlib':
+ dets = self.detector(img, 1)
+ for det in dets:
+ shape = self.predictor(img, det)
+ # ref: https://github.com/PyImageSearch/imutils/blob/c12f15391fcc945d0d644b85194b8c044a392e0a/imutils/face_utils/helpers.py
+ lms = np.zeros((68, 2), dtype=np.int32)
+ for i in range(0, 68):
+ lms[i, 0] = shape.part(i).x
+ lms[i, 1] = shape.part(i).y
+ break
+
+ else:
+ lms = self.predictor.get_landmarks(img)[-1]
+
+ # self.vis_landmarks(img, lms)
+ lms = lms.astype(np.float32)
+
+ return lms
+
+ def vis_landmarks(self, img, lms):
+ plt.imshow(img)
+ plt.plot(lms[48:68, 0], lms[48:68, 1], marker='o', markersize=1, linestyle='-', lw=2)
+ plt.show()
+
+ def clear(self):
+ self.V = 0
+ self.N = 0
+
+ def prepare_inputs(self, *inputs):
+ outputs = []
+ for i, inp in enumerate(inputs):
+ inp = inp.detach().cpu().numpy()
+ inp = (inp * 255).astype(np.uint8)
+ outputs.append(inp)
+ return outputs
+
+ def update(self, preds, truths):
+ # assert B == 1
+ preds, truths = self.prepare_inputs(preds[0], truths[0]) # [H, W, 3] numpy array
+
+ # get lms
+ lms_pred = self.get_landmarks(preds)
+ lms_truth = self.get_landmarks(truths)
+
+ if self.region == 'mouth':
+ lms_pred = lms_pred[48:68]
+ lms_truth = lms_truth[48:68]
+
+ # avarage
+ lms_pred = lms_pred - lms_pred.mean(0)
+ lms_truth = lms_truth - lms_truth.mean(0)
+
+ # distance
+ dist = np.sqrt(((lms_pred - lms_truth) ** 2).sum(1)).mean(0)
+
+ self.V += dist
+ self.N += 1
+
+ def measure(self):
+ return self.V / self.N
+
+ def write(self, writer, global_step, prefix=""):
+ writer.add_scalar(os.path.join(prefix, f"LMD ({self.backend})"), self.measure(), global_step)
+
+ def report(self):
+ return f'LMD ({self.backend}) = {self.measure():.6f}'
+
+
+class Trainer(object):
+ def __init__(self,
+ name, # name of this experiment
+ opt, # extra conf
+ model, # network
+ criterion=None, # loss function, if None, assume inline implementation in train_step
+ optimizer=None, # optimizer
+ ema_decay=None, # if use EMA, set the decay
+ ema_update_interval=1000, # update ema per $ training steps.
+ lr_scheduler=None, # scheduler
+ metrics=[], # metrics for evaluation, if None, use val_loss to measure performance, else use the first metric.
+ local_rank=0, # which GPU am I
+ world_size=1, # total num of GPUs
+ device=None, # device to use, usually setting to None is OK. (auto choose device)
+ mute=False, # whether to mute all print
+ fp16=False, # amp optimize level
+ eval_interval=1, # eval once every $ epoch
+ max_keep_ckpt=50, # max num of saved ckpts in disk
+ workspace='workspace', # workspace to save logs & ckpts
+ best_mode='min', # the smaller/larger result, the better
+ use_loss_as_metric=True, # use loss as the first metric
+ report_metric_at_train=False, # also report metrics at training
+ use_checkpoint="latest", # which ckpt to use at init time
+ use_tensorboardX=True, # whether to use tensorboard for logging
+ scheduler_update_every_step=False, # whether to call scheduler.step() after every train step
+ ):
+
+ self.name = name
+ self.opt = opt
+ self.mute = mute
+ self.metrics = metrics
+ self.local_rank = local_rank
+ self.world_size = world_size
+ self.workspace = workspace
+ self.ema_decay = ema_decay
+ self.ema_update_interval = ema_update_interval
+ self.fp16 = fp16
+ self.best_mode = best_mode
+ self.use_loss_as_metric = use_loss_as_metric
+ self.report_metric_at_train = report_metric_at_train
+ self.max_keep_ckpt = max_keep_ckpt
+ self.eval_interval = eval_interval
+ self.use_checkpoint = use_checkpoint
+ self.use_tensorboardX = use_tensorboardX
+ self.flip_finetune_lips = self.opt.finetune_lips
+ self.flip_init_lips = self.opt.init_lips
+ self.time_stamp = time.strftime("%Y-%m-%d_%H-%M-%S")
+ self.scheduler_update_every_step = scheduler_update_every_step
+ self.device = device if device is not None else torch.device(f'cuda:{local_rank}' if torch.cuda.is_available() else 'cpu')
+ self.console = Console()
+
+ model.to(self.device)
+ if self.world_size > 1:
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
+ self.model = model
+
+ if isinstance(criterion, nn.Module):
+ criterion.to(self.device)
+ self.criterion = criterion
+ self.criterionL1 = nn.L1Loss().to(self.device)
+ if optimizer is None:
+ self.optimizer = optim.Adam(self.model.parameters(), lr=0.001, weight_decay=5e-4) # naive adam
+ else:
+ self.optimizer = optimizer(self.model)
+
+ if lr_scheduler is None:
+ self.lr_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda epoch: 1) # fake scheduler
+ else:
+ self.lr_scheduler = lr_scheduler(self.optimizer)
+
+ if ema_decay is not None:
+ self.ema = ExponentialMovingAverage(self.model.parameters(), decay=ema_decay)
+ else:
+ self.ema = None
+
+ self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16)
+
+ # optionally use LPIPS loss for patch-based training
+ if self.opt.patch_size > 1 or self.opt.finetune_lips or True:
+ import lpips
+ # self.criterion_lpips_vgg = lpips.LPIPS(net='vgg').to(self.device)
+ self.criterion_lpips_alex = lpips.LPIPS(net='alex').to(self.device)
+
+ # variable init
+ self.epoch = 0
+ self.global_step = 0
+ self.local_step = 0
+ self.stats = {
+ "loss": [],
+ "valid_loss": [],
+ "results": [], # metrics[0], or valid_loss
+ "checkpoints": [], # record path of saved ckpt, to automatically remove old ckpt
+ "best_result": None,
+ }
+
+ # auto fix
+ if len(metrics) == 0 or self.use_loss_as_metric:
+ self.best_mode = 'min'
+
+ # workspace prepare
+ self.log_ptr = None
+ if self.workspace is not None:
+ os.makedirs(self.workspace, exist_ok=True)
+ self.log_path = os.path.join(workspace, f"log_{self.name}.txt")
+ self.log_ptr = open(self.log_path, "a+")
+
+ self.ckpt_path = os.path.join(self.workspace, 'checkpoints')
+ self.best_path = f"{self.ckpt_path}/{self.name}.pth"
+ os.makedirs(self.ckpt_path, exist_ok=True)
+
+ self.log(f'[INFO] Trainer: {self.name} | {self.time_stamp} | {self.device} | {"fp16" if self.fp16 else "fp32"} | {self.workspace}')
+ self.log(f'[INFO] #parameters: {sum([p.numel() for p in model.parameters() if p.requires_grad])}')
+
+ if self.workspace is not None:
+ if self.use_checkpoint == "scratch":
+ self.log("[INFO] Training from scratch ...")
+ elif self.use_checkpoint == "latest":
+ self.log("[INFO] Loading latest checkpoint ...")
+ self.load_checkpoint()
+ elif self.use_checkpoint == "latest_model":
+ self.log("[INFO] Loading latest checkpoint (model only)...")
+ self.load_checkpoint(model_only=True)
+ elif self.use_checkpoint == "best":
+ if os.path.exists(self.best_path):
+ self.log("[INFO] Loading best checkpoint ...")
+ self.load_checkpoint(self.best_path)
+ else:
+ self.log(f"[INFO] {self.best_path} not found, loading latest ...")
+ self.load_checkpoint()
+ else: # path to ckpt
+ self.log(f"[INFO] Loading {self.use_checkpoint} ...")
+ self.load_checkpoint(self.use_checkpoint)
+
+ def __del__(self):
+ if self.log_ptr:
+ self.log_ptr.close()
+
+
+ def log(self, *args, **kwargs):
+ if self.local_rank == 0:
+ if not self.mute:
+ #print(*args)
+ self.console.print(*args, **kwargs)
+ if self.log_ptr:
+ print(*args, file=self.log_ptr)
+ self.log_ptr.flush() # write immediately to file
+
+ ### ------------------------------
+
+ def train_step(self, data):
+
+ rays_o = data['rays_o'] # [B, N, 3]
+ rays_d = data['rays_d'] # [B, N, 3]
+ bg_coords = data['bg_coords'] # [1, N, 2]
+ poses = data['poses'] # [B, 6]
+ face_mask = data['face_mask'] # [B, N]
+ upface_mask = data['upface_mask'] # [B, N]
+ lowface_mask = data['lowface_mask'] # [B, N]
+ eye_mask = data['eye_mask'] # [B, N]
+ lhalf_mask = data['lhalf_mask']
+ eye = data['eye'] # [B, 1]
+ auds = data['auds'] # [B, 29, 16]
+ index = data['index'] # [B]
+ loss_perception =0
+
+ if not self.opt.torso:
+ rgb = data['images'] # [B, N, 3]
+ else:
+ rgb = data['bg_torso_color']
+
+ B, N, C = rgb.shape
+
+ if self.opt.color_space == 'linear':
+ rgb[..., :3] = srgb_to_linear(rgb[..., :3])
+
+ bg_color = data['bg_color']
+
+ if not self.opt.torso:
+ outputs = self.model.render(rays_o, rays_d, auds, bg_coords, poses, eye=eye, index=index, staged=False, bg_color=bg_color, perturb=True, force_all_rays=False if (self.opt.patch_size <= 1 and not self.opt.train_camera) else True, **vars(self.opt))
+ else:
+ outputs = self.model.render_torso(rays_o, rays_d, auds, bg_coords, poses, eye=eye, index=index, staged=False, bg_color=bg_color, perturb=True, force_all_rays=False if (self.opt.patch_size <= 1 and not self.opt.train_camera) else True, **vars(self.opt))
+
+ if not self.opt.torso:
+ pred_rgb = outputs['image']
+ else:
+ pred_rgb = outputs['torso_color']
+
+
+ # loss factor
+ step_factor = min(self.global_step / self.opt.iters, 1.0)
+ # MSE loss
+ loss = self.criterion(pred_rgb, rgb).mean(-1) # [B, N, 3] --> [B, N]
+
+ if self.opt.torso:
+ loss = loss.mean()
+ loss += ((1 - self.model.anchor_points[:, 3])**2).mean()
+ return pred_rgb, rgb, loss
+
+
+ if self.opt.unc_loss and not self.flip_finetune_lips:
+ alpha = 0.2
+ uncertainty = outputs['uncertainty'] # [N], abs sum
+ beta = uncertainty + 1
+
+ unc_weight = F.softmax(uncertainty, dim=-1) * N
+ loss *= alpha + (1-alpha)*((1 - step_factor) + step_factor * unc_weight.detach()).clamp(0, 10)
+
+ beta = uncertainty + 1
+ norm_rgb = torch.norm((pred_rgb - rgb), dim=-1).detach()
+ loss_u = norm_rgb / (2*beta**2) + (torch.log(beta)**2) / 2
+ loss_u *= face_mask.view(-1)
+
+ loss += 0.01 * step_factor * loss_u
+
+
+ loss_static_uncertainty = (uncertainty * (~face_mask.view(-1)))
+ loss += 0.01 * step_factor * loss_static_uncertainty
+
+ # patch-based rendering
+ if self.opt.patch_size > 1 and not self.opt.finetune_lips:
+ rgb = rgb.view(-1, self.opt.patch_size, self.opt.patch_size, 3).permute(0, 3, 1, 2).contiguous() * 2 - 1
+ pred_rgb = pred_rgb.view(-1, self.opt.patch_size, self.opt.patch_size, 3).permute(0, 3, 1, 2).contiguous() * 2 - 1
+
+
+ loss_lpips = self.criterion_lpips_alex(pred_rgb, rgb)
+
+ loss = loss + 0.1 * loss_lpips
+
+ # lips finetune
+ if self.opt.finetune_lips:
+ xmin, xmax, ymin, ymax = data['rect']
+ rgb = rgb.view(-1, xmax - xmin, ymax - ymin, 3).permute(0, 3, 1, 2).contiguous() * 2 - 1
+ pred_rgb = pred_rgb.view(-1, xmax - xmin, ymax - ymin, 3).permute(0, 3, 1, 2).contiguous() * 2 - 1
+
+ padding_h = max(0, (32 - rgb.shape[-2] + 1) // 2)
+ padding_w = max(0, (32 - rgb.shape[-1] + 1) // 2)
+
+ if padding_w or padding_h:
+ rgb = torch.nn.functional.pad(rgb, (padding_w, padding_w, padding_h, padding_h))
+ pred_rgb = torch.nn.functional.pad(pred_rgb, (padding_w, padding_w, padding_h, padding_h))
+
+ loss = loss + 0.01 * self.criterion_lpips_alex(pred_rgb, rgb)
+ # flip every step... if finetune lips
+ if self.flip_finetune_lips:
+ self.opt.finetune_lips = not self.opt.finetune_lips
+
+
+ loss = loss.mean()
+
+ if self.opt.patch_size > 1 and not self.opt.finetune_lips:
+ if self.opt.pyramid_loss:
+ loss = loss + 0.1 * loss_perception
+ # print('loss', loss.item())
+
+ # weights_sum loss
+ # entropy to encourage weights_sum to be 0 or 1.
+ if self.opt.torso:
+ alphas = outputs['torso_alpha'].clamp(1e-5, 1 - 1e-5)
+ # alphas = alphas ** 2 # skewed entropy, favors 0 over 1
+ loss_ws = - alphas * torch.log2(alphas) - (1 - alphas) * torch.log2(1 - alphas)
+ loss = loss + 1e-4 * loss_ws.mean()
+
+ else:
+ alphas = outputs['weights_sum'].clamp(1e-5, 1 - 1e-5)
+ loss_ws = - alphas * torch.log2(alphas) - (1 - alphas) * torch.log2(1 - alphas)
+ loss = loss + 1e-4 * loss_ws.mean()
+
+ # aud att loss (regions out of face should be static)
+ if self.opt.amb_aud_loss and not self.opt.torso:
+ ambient_aud = outputs['ambient_aud']
+ loss_amb_aud = (ambient_aud * (~lowface_mask.view(-1))).mean()
+ # gradually increase it
+ lambda_amb = step_factor * self.opt.lambda_amb
+ loss += lambda_amb * loss_amb_aud
+
+ # eye att loss
+ if self.opt.amb_eye_loss and not self.opt.torso:
+ ambient_eye = outputs['ambient_eye']
+ loss_cross = ((ambient_eye)*(lowface_mask.view(-1))).mean()
+ lambda_amb = step_factor * self.opt.lambda_amb
+ loss += lambda_amb * loss_cross
+
+ # regularize
+ if self.global_step % 16 == 0 and not self.flip_finetune_lips:
+ xyzs, dirs, enc_a, ind_code, eye = outputs['rays']
+ xyz_delta = (torch.rand(size=xyzs.shape, dtype=xyzs.dtype, device=xyzs.device) * 2 - 1) * 1e-3
+ with torch.no_grad():
+ sigmas_raw, rgbs_raw, ambient_aud_raw, ambient_eye_raw, unc_raw = self.model(xyzs, dirs, enc_a.detach(), ind_code.detach(), eye)
+ sigmas_reg, rgbs_reg, ambient_aud_reg, ambient_eye_reg, unc_reg = self.model(xyzs+xyz_delta, dirs, enc_a.detach(), ind_code.detach(), eye)
+
+ lambda_reg = step_factor * 1e-5
+ reg_loss = 0
+ if self.opt.unc_loss:
+ reg_loss += self.criterion(unc_raw, unc_reg).mean()
+ if self.opt.amb_aud_loss:
+ reg_loss += self.criterion(ambient_aud_raw, ambient_aud_reg).mean()
+ if self.opt.amb_eye_loss:
+ reg_loss += self.criterion(ambient_eye_raw, ambient_eye_reg).mean()
+
+ loss += reg_loss * lambda_reg
+
+ return pred_rgb, rgb, loss
+
+
+ def eval_step(self, data):
+
+ rays_o = data['rays_o'] # [B, N, 3]
+ rays_d = data['rays_d'] # [B, N, 3]
+ bg_coords = data['bg_coords'] # [1, N, 2]
+ poses = data['poses'] # [B, 7]
+
+ images = data['images'] # [B, H, W, 3/4]
+ if self.opt.portrait:
+ images = data['bg_gt_images']
+ auds = data['auds']
+ index = data['index'] # [B]
+ eye = data['eye'] # [B, 1]
+
+ B, H, W, C = images.shape
+
+ if self.opt.color_space == 'linear':
+ images[..., :3] = srgb_to_linear(images[..., :3])
+
+ # eval with fixed background color
+ # bg_color = 1
+ bg_color = data['bg_color']
+
+ outputs = self.model.render(rays_o, rays_d, auds, bg_coords, poses, eye=eye, index=index, staged=True, bg_color=bg_color, perturb=False, **vars(self.opt))
+
+ pred_rgb = outputs['image'].reshape(B, H, W, 3)
+ pred_depth = outputs['depth'].reshape(B, H, W)
+ pred_ambient_aud = outputs['ambient_aud'].reshape(B, H, W)
+ pred_ambient_eye = outputs['ambient_eye'].reshape(B, H, W)
+ pred_uncertainty = outputs['uncertainty'].reshape(B, H, W)
+
+ loss_raw = self.criterion(pred_rgb, images)
+ loss = loss_raw.mean()
+
+ return pred_rgb, pred_depth, pred_ambient_aud, pred_ambient_eye, pred_uncertainty, images, loss, loss_raw
+
+ # moved out bg_color and perturb for more flexible control...
+ def test_step(self, data, bg_color=None, perturb=False):
+
+ rays_o = data['rays_o'] # [B, N, 3]
+ rays_d = data['rays_d'] # [B, N, 3]
+ bg_coords = data['bg_coords'] # [1, N, 2]
+ poses = data['poses'] # [B, 7]
+
+ auds = data['auds'] # [B, 29, 16]
+ index = data['index']
+ H, W = data['H'], data['W']
+
+ # allow using a fixed eye area (avoid eye blink) at test
+ if self.opt.exp_eye and self.opt.fix_eye >= 0:
+ eye = torch.FloatTensor([self.opt.fix_eye]).view(1, 1).to(self.device)
+ else:
+ eye = data['eye'] # [B, 1]
+
+ if bg_color is not None:
+ bg_color = bg_color.to(self.device)
+ else:
+ bg_color = data['bg_color']
+
+ self.model.testing = True
+ outputs = self.model.render(rays_o, rays_d, auds, bg_coords, poses, eye=eye, index=index, staged=True, bg_color=bg_color, perturb=perturb, **vars(self.opt))
+ self.model.testing = False
+
+ pred_rgb = outputs['image'].reshape(-1, H, W, 3)
+ pred_depth = outputs['depth'].reshape(-1, H, W)
+
+ return pred_rgb, pred_depth
+
+
+ def save_mesh(self, save_path=None, resolution=256, threshold=10):
+
+ if save_path is None:
+ save_path = os.path.join(self.workspace, 'meshes', f'{self.name}_{self.epoch}.ply')
+
+ self.log(f"==> Saving mesh to {save_path}")
+
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
+
+ def query_func(pts):
+ with torch.no_grad():
+ with torch.cuda.amp.autocast(enabled=self.fp16):
+ sigma = self.model.density(pts.to(self.device))['sigma']
+ return sigma
+
+ vertices, triangles = extract_geometry(self.model.aabb_infer[:3], self.model.aabb_infer[3:], resolution=resolution, threshold=threshold, query_func=query_func)
+
+ mesh = trimesh.Trimesh(vertices, triangles, process=False) # important, process=True leads to seg fault...
+ mesh.export(save_path)
+
+ self.log(f"==> Finished saving mesh.")
+
+ ### ------------------------------
+
+ def train(self, train_loader, valid_loader, max_epochs):
+ if self.use_tensorboardX and self.local_rank == 0:
+ self.writer = tensorboardX.SummaryWriter(os.path.join(self.workspace, "run", self.name))
+
+ # mark untrained region (i.e., not covered by any camera from the training dataset)
+ if self.model.cuda_ray:
+ self.model.mark_untrained_grid(train_loader._data.poses, train_loader._data.intrinsics)
+
+ for epoch in range(self.epoch + 1, max_epochs + 1):
+ self.epoch = epoch
+
+ self.train_one_epoch(train_loader)
+
+ if self.workspace is not None and self.local_rank == 0:
+ self.save_checkpoint(full=True, best=False)
+
+ if self.epoch % self.eval_interval == 0:
+ self.evaluate_one_epoch(valid_loader)
+ self.save_checkpoint(full=False, best=True)
+
+ if self.use_tensorboardX and self.local_rank == 0:
+ self.writer.close()
+
+ def evaluate(self, loader, name=None):
+ self.use_tensorboardX, use_tensorboardX = False, self.use_tensorboardX
+ self.evaluate_one_epoch(loader, name)
+ self.use_tensorboardX = use_tensorboardX
+
+ # Function to blend two images with a mask
+
+ def test(self, loader, save_path=None, name=None, write_image=False):
+
+ if save_path is None:
+ save_path = os.path.join(self.workspace, 'results')
+
+ if name is None:
+ name = f'{self.name}_ep{self.epoch:04d}'
+
+ os.makedirs(save_path, exist_ok=True)
+
+ self.log(f"==> Start Test, save results to {save_path}")
+
+ pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')
+ self.model.eval()
+
+ all_preds = []
+ all_preds_depth = []
+
+ with torch.no_grad():
+
+ for i, data in enumerate(loader):
+
+ with torch.cuda.amp.autocast(enabled=self.fp16):
+ preds, preds_depth = self.test_step(data)
+
+ path = os.path.join(save_path, f'{name}_{i:04d}_rgb.png')
+ path_depth = os.path.join(save_path, f'{name}_{i:04d}_depth.png')
+
+ #self.log(f"[INFO] saving test image to {path}")
+
+ if self.opt.color_space == 'linear':
+ preds = linear_to_srgb(preds)
+ if self.opt.portrait:
+ pred = blend_with_mask_cuda(preds[0], data["bg_gt_images"].squeeze(0), data["bg_face_mask"].squeeze(0))
+ pred = (pred * 255).astype(np.uint8)
+ else:
+ pred = preds[0].detach().cpu().numpy()
+ pred = (pred * 255).astype(np.uint8)
+
+ pred_depth = preds_depth[0].detach().cpu().numpy()
+ pred_depth = (pred_depth * 255).astype(np.uint8)
+
+ if write_image:
+ imageio.imwrite(path, pred)
+ imageio.imwrite(path_depth, pred_depth)
+
+ all_preds.append(pred)
+ all_preds_depth.append(pred_depth)
+
+ pbar.update(loader.batch_size)
+
+ # write video
+ all_preds = np.stack(all_preds, axis=0)
+ all_preds_depth = np.stack(all_preds_depth, axis=0)
+ imageio.mimwrite(os.path.join(save_path, f'{name}.mp4'), all_preds, fps=25, quality=8, macro_block_size=1)
+ imageio.mimwrite(os.path.join(save_path, f'{name}_depth.mp4'), all_preds_depth, fps=25, quality=8, macro_block_size=1)
+ if self.opt.aud != '':
+ os.system(f'ffmpeg -i {os.path.join(save_path, f"{name}.mp4")} -i {self.opt.aud} -strict -2 {os.path.join(save_path, f"{name}_audio.mp4")} -y')
+
+ self.log(f"==> Finished Test.")
+
+ # [GUI] just train for 16 steps, without any other overhead that may slow down rendering.
+ def train_gui(self, train_loader, step=16):
+
+ self.model.train()
+
+ total_loss = torch.tensor([0], dtype=torch.float32, device=self.device)
+
+ loader = iter(train_loader)
+
+ # mark untrained grid
+ if self.global_step == 0:
+ self.model.mark_untrained_grid(train_loader._data.poses, train_loader._data.intrinsics)
+
+ for _ in range(step):
+
+ # mimic an infinite loop dataloader (in case the total dataset is smaller than step)
+ try:
+ data = next(loader)
+ except StopIteration:
+ loader = iter(train_loader)
+ data = next(loader)
+
+ # update grid every 16 steps
+ if self.model.cuda_ray and self.global_step % self.opt.update_extra_interval == 0:
+ with torch.cuda.amp.autocast(enabled=self.fp16):
+ self.model.update_extra_state()
+
+ self.global_step += 1
+
+ self.optimizer.zero_grad()
+
+ with torch.cuda.amp.autocast(enabled=self.fp16):
+ preds, truths, loss = self.train_step(data)
+
+ self.scaler.scale(loss).backward()
+ self.scaler.step(self.optimizer)
+ self.scaler.update()
+
+ if self.scheduler_update_every_step:
+ self.lr_scheduler.step()
+
+ total_loss += loss.detach()
+
+ if self.ema is not None and self.global_step % self.ema_update_interval == 0:
+ self.ema.update()
+
+ average_loss = total_loss.item() / step
+
+ if not self.scheduler_update_every_step:
+ if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
+ self.lr_scheduler.step(average_loss)
+ else:
+ self.lr_scheduler.step()
+
+ outputs = {
+ 'loss': average_loss,
+ 'lr': self.optimizer.param_groups[0]['lr'],
+ }
+
+ return outputs
+
+ # [GUI] test on a single image
+ def test_gui(self, pose, intrinsics, W, H, auds, eye=None, index=0, bg_color=None, spp=1, downscale=1):
+
+ # render resolution (may need downscale to for better frame rate)
+ rH = int(H * downscale)
+ rW = int(W * downscale)
+ intrinsics = intrinsics * downscale
+
+ if auds is not None:
+ auds = auds.to(self.device)
+
+ pose = torch.from_numpy(pose).unsqueeze(0).to(self.device)
+ rays = get_rays(pose, intrinsics, rH, rW, -1)
+
+ bg_coords = get_bg_coords(rH, rW, self.device)
+
+ if eye is not None:
+ eye = torch.FloatTensor([eye]).view(1, 1).to(self.device)
+
+ data = {
+ 'rays_o': rays['rays_o'],
+ 'rays_d': rays['rays_d'],
+ 'H': rH,
+ 'W': rW,
+ 'auds': auds,
+ 'index': [index], # support choosing index for individual codes
+ 'eye': eye,
+ 'poses': pose,
+ 'bg_coords': bg_coords,
+ }
+
+ self.model.eval()
+
+ if self.ema is not None:
+ self.ema.store()
+ self.ema.copy_to()
+
+ with torch.no_grad():
+ with torch.cuda.amp.autocast(enabled=self.fp16):
+ # here spp is used as perturb random seed!
+ # face: do not perturb for the first spp, else lead to scatters.
+ preds, preds_depth = self.test_step(data, bg_color=bg_color, perturb=False if spp == 1 else spp)
+
+ if self.ema is not None:
+ self.ema.restore()
+
+ # interpolation to the original resolution
+ if downscale != 1:
+ # TODO: have to permute twice with torch...
+ preds = F.interpolate(preds.permute(0, 3, 1, 2), size=(H, W), mode='bilinear').permute(0, 2, 3, 1).contiguous()
+ preds_depth = F.interpolate(preds_depth.unsqueeze(1), size=(H, W), mode='nearest').squeeze(1)
+
+ if self.opt.color_space == 'linear':
+ preds = linear_to_srgb(preds)
+
+ pred = preds[0].detach().cpu().numpy()
+ pred_depth = preds_depth[0].detach().cpu().numpy()
+
+ outputs = {
+ 'image': pred,
+ 'depth': pred_depth,
+ }
+
+ return outputs
+
+ # [GUI] test with provided data
+ def test_gui_with_data(self, data, W, H):
+
+ self.model.eval()
+
+ if self.ema is not None:
+ self.ema.store()
+ self.ema.copy_to()
+
+ with torch.no_grad():
+ with torch.cuda.amp.autocast(enabled=self.fp16):
+ # here spp is used as perturb random seed!
+ # face: do not perturb for the first spp, else lead to scatters.
+ preds, preds_depth = self.test_step(data, perturb=False)
+
+ if self.ema is not None:
+ self.ema.restore()
+
+ if self.opt.color_space == 'linear':
+ preds = linear_to_srgb(preds)
+
+ # the H/W in data may be differnt to GUI, so we still need to resize...
+ preds = F.interpolate(preds.permute(0, 3, 1, 2), size=(H, W), mode='bilinear').permute(0, 2, 3, 1).contiguous()
+ preds_depth = F.interpolate(preds_depth.unsqueeze(1), size=(H, W), mode='nearest').squeeze(1)
+
+ pred = preds[0].detach().cpu().numpy()
+ pred_depth = preds_depth[0].detach().cpu().numpy()
+
+ outputs = {
+ 'image': pred,
+ 'depth': pred_depth,
+ }
+
+ return outputs
+
+ def train_one_epoch(self, loader):
+ self.log(f"==> Start Training Epoch {self.epoch}, lr={self.optimizer.param_groups[0]['lr']:.6f} ...")
+
+ total_loss = 0
+ if self.local_rank == 0 and self.report_metric_at_train:
+ for metric in self.metrics:
+ metric.clear()
+
+ self.model.train()
+
+ # distributedSampler: must call set_epoch() to shuffle indices across multiple epochs
+ # ref: https://pytorch.org/docs/stable/data.html
+ if self.world_size > 1:
+ loader.sampler.set_epoch(self.epoch)
+
+ if self.local_rank == 0:
+ pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, mininterval=1, bar_format='{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')
+
+ self.local_step = 0
+
+ for data in loader:
+ # update grid every 16 steps
+ if self.model.cuda_ray and self.global_step % self.opt.update_extra_interval == 0:
+ with torch.cuda.amp.autocast(enabled=self.fp16):
+ self.model.update_extra_state()
+
+ self.local_step += 1
+ self.global_step += 1
+
+ self.optimizer.zero_grad()
+
+ with torch.cuda.amp.autocast(enabled=self.fp16):
+ preds, truths, loss = self.train_step(data)
+
+ self.scaler.scale(loss).backward()
+ self.scaler.step(self.optimizer)
+ self.scaler.update()
+
+ if self.scheduler_update_every_step:
+ self.lr_scheduler.step()
+
+ loss_val = loss.item()
+ total_loss += loss_val
+
+ if self.ema is not None and self.global_step % self.ema_update_interval == 0:
+ self.ema.update()
+
+ if self.local_rank == 0:
+ if self.report_metric_at_train:
+ for metric in self.metrics:
+ metric.update(preds, truths)
+
+ if self.use_tensorboardX:
+ self.writer.add_scalar("train/loss", loss_val, self.global_step)
+ self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]['lr'], self.global_step)
+
+ if self.scheduler_update_every_step:
+ pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f}), lr={self.optimizer.param_groups[0]['lr']:.6f}")
+ else:
+ pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})")
+ pbar.update(loader.batch_size)
+
+ average_loss = total_loss / self.local_step
+ self.stats["loss"].append(average_loss)
+
+ if self.local_rank == 0:
+ pbar.close()
+ if self.report_metric_at_train:
+ for metric in self.metrics:
+ self.log(metric.report(), style="red")
+ if self.use_tensorboardX:
+ metric.write(self.writer, self.epoch, prefix="train")
+ metric.clear()
+
+ if not self.scheduler_update_every_step:
+ if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
+ self.lr_scheduler.step(average_loss)
+ else:
+ self.lr_scheduler.step()
+ self.log(f"loss={average_loss:.4f}")
+ self.log(f"==> Finished Epoch {self.epoch}.")
+
+
+ def evaluate_one_epoch(self, loader, name=None):
+ self.log(f"++> Evaluate at epoch {self.epoch} ...")
+
+ if name is None:
+ name = f'{self.name}_ep{self.epoch:04d}'
+
+ total_loss = 0
+ if self.local_rank == 0:
+ for metric in self.metrics:
+ metric.clear()
+
+ self.model.eval()
+
+ if self.ema is not None:
+ self.ema.store()
+ self.ema.copy_to()
+
+ if self.local_rank == 0:
+ pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')
+
+ with torch.no_grad():
+ self.local_step = 0
+
+ for data in loader:
+ self.local_step += 1
+
+ with torch.cuda.amp.autocast(enabled=self.fp16):
+ preds, preds_depth, pred_ambient_aud, pred_ambient_eye, pred_uncertainty, truths, loss, loss_raw = self.eval_step(data)
+ loss_val = loss.item()
+ total_loss += loss_val
+
+ # only rank = 0 will perform evaluation.
+ if self.local_rank == 0:
+
+ # save image
+ save_path = os.path.join(self.workspace, 'validation', f'{name}_{self.local_step:04d}_rgb.png')
+ save_path_depth = os.path.join(self.workspace, 'validation', f'{name}_{self.local_step:04d}_depth.png')
+ save_path_ambient_aud = os.path.join(self.workspace, 'validation', f'{name}_{self.local_step:04d}_aud.png')
+ save_path_ambient_eye = os.path.join(self.workspace, 'validation', f'{name}_{self.local_step:04d}_eye.png')
+ save_path_uncertainty = os.path.join(self.workspace, 'validation', f'{name}_{self.local_step:04d}_uncertainty.png')
+
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
+
+ if self.opt.color_space == 'linear':
+ preds = linear_to_srgb(preds)
+
+ if self.opt.portrait:
+ pred = blend_with_mask_cuda(preds[0], data["bg_gt_images"].squeeze(0),data["bg_face_mask"].squeeze(0))
+ preds = torch.from_numpy(pred).unsqueeze(0).to(self.device).float()
+ else:
+ pred = preds[0].detach().cpu().numpy()
+ pred_depth = preds_depth[0].detach().cpu().numpy()
+
+ for metric in self.metrics:
+ metric.update(preds, truths)
+ # loss_raw = loss_raw[0].mean(-1).detach().cpu().numpy()
+ # loss_raw = (loss_raw - np.min(loss_raw)) / (np.max(loss_raw) - np.min(loss_raw))
+ pred_ambient_aud = pred_ambient_aud[0].detach().cpu().numpy()
+ pred_ambient_aud /= np.max(pred_ambient_aud)
+ pred_ambient_eye = pred_ambient_eye[0].detach().cpu().numpy()
+ pred_ambient_eye /= np.max(pred_ambient_eye)
+ # pred_ambient = pred_ambient / 16
+ # print(pred_ambient.shape)
+ pred_uncertainty = pred_uncertainty[0].detach().cpu().numpy()
+ # pred_uncertainty = (pred_uncertainty - np.min(pred_uncertainty)) / (np.max(pred_uncertainty) - np.min(pred_uncertainty))
+ pred_uncertainty /= np.max(pred_uncertainty)
+
+ cv2.imwrite(save_path, cv2.cvtColor((pred * 255).astype(np.uint8), cv2.COLOR_RGB2BGR))
+
+ if not self.opt.torso:
+ cv2.imwrite(save_path_depth, (pred_depth * 255).astype(np.uint8))
+ # cv2.imwrite(save_path_error, (loss_raw * 255).astype(np.uint8))
+ cv2.imwrite(save_path_ambient_aud, (pred_ambient_aud * 255).astype(np.uint8))
+ cv2.imwrite(save_path_ambient_eye, (pred_ambient_eye * 255).astype(np.uint8))
+ cv2.imwrite(save_path_uncertainty, (pred_uncertainty * 255).astype(np.uint8))
+ #cv2.imwrite(save_path_gt, cv2.cvtColor((linear_to_srgb(truths[0].detach().cpu().numpy()) * 255).astype(np.uint8), cv2.COLOR_RGB2BGR))
+
+ pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})")
+ pbar.update(loader.batch_size)
+
+
+ average_loss = total_loss / self.local_step
+ self.stats["valid_loss"].append(average_loss)
+
+ if self.local_rank == 0:
+ pbar.close()
+ if not self.use_loss_as_metric and len(self.metrics) > 0:
+ result = self.metrics[0].measure()
+ self.stats["results"].append(result if self.best_mode == 'min' else - result) # if max mode, use -result
+ else:
+ self.stats["results"].append(average_loss) # if no metric, choose best by min loss
+
+ for metric in self.metrics:
+ self.log(metric.report(), style="blue")
+ if self.use_tensorboardX:
+ metric.write(self.writer, self.epoch, prefix="evaluate")
+ metric.clear()
+
+ if self.ema is not None:
+ self.ema.restore()
+
+ self.log(f"++> Evaluate epoch {self.epoch} Finished.")
+
+ def save_checkpoint(self, name=None, full=False, best=False, remove_old=True):
+
+ if name is None:
+ name = f'{self.name}_ep{self.epoch:04d}'
+
+ state = {
+ 'epoch': self.epoch,
+ 'global_step': self.global_step,
+ 'stats': self.stats,
+ }
+
+
+ state['mean_count'] = self.model.mean_count
+ state['mean_density'] = self.model.mean_density
+ state['mean_density_torso'] = self.model.mean_density_torso
+
+ if full:
+ state['optimizer'] = self.optimizer.state_dict()
+ state['lr_scheduler'] = self.lr_scheduler.state_dict()
+ state['scaler'] = self.scaler.state_dict()
+ if self.ema is not None:
+ state['ema'] = self.ema.state_dict()
+
+ if not best:
+
+ state['model'] = self.model.state_dict()
+
+ file_path = f"{self.ckpt_path}/{name}.pth"
+
+ if remove_old:
+ self.stats["checkpoints"].append(file_path)
+
+ if len(self.stats["checkpoints"]) > self.max_keep_ckpt:
+ old_ckpt = self.stats["checkpoints"].pop(0)
+ if os.path.exists(old_ckpt):
+ os.remove(old_ckpt)
+
+
+ torch.save(state, file_path)
+
+ else:
+ if len(self.stats["results"]) > 0:
+ # always save new as best... (since metric cannot really reflect performance...)
+ if True:
+
+ # save ema results
+ if self.ema is not None:
+ self.ema.store()
+ self.ema.copy_to()
+
+ state['model'] = self.model.state_dict()
+
+ # we don't consider continued training from the best ckpt, so we discard the unneeded density_grid to save some storage (especially important for dnerf)
+ if 'density_grid' in state['model']:
+ del state['model']['density_grid']
+
+ if self.ema is not None:
+ self.ema.restore()
+
+ torch.save(state, self.best_path)
+ else:
+ self.log(f"[WARN] no evaluated results found, skip saving best checkpoint.")
+
+ def load_checkpoint(self, checkpoint=None, model_only=False):
+ if checkpoint is None:
+ checkpoint_list = sorted(glob.glob(f'{self.ckpt_path}/{self.name}_ep*.pth'))
+ if checkpoint_list:
+ checkpoint = checkpoint_list[-1]
+ self.log(f"[INFO] Latest checkpoint is {checkpoint}")
+ else:
+ self.log("[WARN] No checkpoint found, model randomly initialized.")
+ return
+
+ checkpoint_dict = torch.load(checkpoint, map_location=self.device)
+
+ if 'model' not in checkpoint_dict:
+ self.model.load_state_dict(checkpoint_dict)
+ self.log("[INFO] loaded bare model.")
+ return
+
+ missing_keys, unexpected_keys = self.model.load_state_dict(checkpoint_dict['model'], strict=False)
+ self.log("[INFO] loaded model.")
+ if len(missing_keys) > 0:
+ self.log(f"[WARN] missing keys: {missing_keys}")
+ if len(unexpected_keys) > 0:
+ self.log(f"[WARN] unexpected keys: {unexpected_keys}")
+
+ if self.ema is not None and 'ema' in checkpoint_dict:
+ self.ema.load_state_dict(checkpoint_dict['ema'])
+
+
+ if 'mean_count' in checkpoint_dict:
+ self.model.mean_count = checkpoint_dict['mean_count']
+ if 'mean_density' in checkpoint_dict:
+ self.model.mean_density = checkpoint_dict['mean_density']
+ if 'mean_density_torso' in checkpoint_dict:
+ self.model.mean_density_torso = checkpoint_dict['mean_density_torso']
+
+ if model_only:
+ return
+
+ self.stats = checkpoint_dict['stats']
+ self.epoch = checkpoint_dict['epoch']
+ self.global_step = checkpoint_dict['global_step']
+ self.log(f"[INFO] load at epoch {self.epoch}, global step {self.global_step}")
+
+ if self.optimizer and 'optimizer' in checkpoint_dict:
+ try:
+ self.optimizer.load_state_dict(checkpoint_dict['optimizer'])
+ self.log("[INFO] loaded optimizer.")
+ except:
+ self.log("[WARN] Failed to load optimizer.")
+
+ if self.lr_scheduler and 'lr_scheduler' in checkpoint_dict:
+ try:
+ self.lr_scheduler.load_state_dict(checkpoint_dict['lr_scheduler'])
+ self.log("[INFO] loaded scheduler.")
+ except:
+ self.log("[WARN] Failed to load scheduler.")
+
+ if self.scaler and 'scaler' in checkpoint_dict:
+ try:
+ self.scaler.load_state_dict(checkpoint_dict['scaler'])
+ self.log("[INFO] loaded scaler.")
+ except:
+ self.log("[WARN] Failed to load scaler.")
+
+
+def load_wav(path, sr):
+ return librosa.core.load(path, sr=sr)[0]
+
+
+def preemphasis(wav, k):
+ return signal.lfilter([1, -k], [1], wav)
+
+
+def melspectrogram(wav):
+ D = _stft(preemphasis(wav, 0.97))
+ S = _amp_to_db(_linear_to_mel(np.abs(D))) - 20
+
+ return _normalize(S)
+
+
+def _stft(y):
+ return librosa.stft(y=y, n_fft=800, hop_length=200, win_length=800)
+
+
+def _linear_to_mel(spectogram):
+ global _mel_basis
+ _mel_basis = _build_mel_basis()
+ return np.dot(_mel_basis, spectogram)
+
+
+def _build_mel_basis():
+ return librosa.filters.mel(sr=16000, n_fft=800, n_mels=80, fmin=55, fmax=7600)
+
+
+def _amp_to_db(x):
+ min_level = np.exp(-5 * np.log(10))
+ return 20 * np.log10(np.maximum(min_level, x))
+
+
+def _normalize(S):
+ return np.clip((2 * 4.) * ((S - -100) / (--100)) - 4., -4., 4.)
+
+
+class AudDataset(object):
+ def __init__(self, wavpath):
+ wav = load_wav(wavpath, 16000)
+
+ self.orig_mel = melspectrogram(wav).T
+ self.data_len = int((self.orig_mel.shape[0] - 16) / 80. * float(25))
+
+ def get_frame_id(self, frame):
+ return int(basename(frame).split('.')[0])
+
+ def crop_audio_window(self, spec, start_frame):
+ if type(start_frame) == int:
+ start_frame_num = start_frame
+ else:
+ start_frame_num = self.get_frame_id(start_frame)
+ start_idx = int(80. * (start_frame_num / float(25)))
+
+ end_idx = start_idx + 16
+
+ return spec[start_idx: end_idx, :]
+
+ def __len__(self):
+ return self.data_len
+
+ def __getitem__(self, idx):
+
+ mel = self.crop_audio_window(self.orig_mel.copy(), idx)
+ if (mel.shape[0] != 16):
+ raise Exception('mel.shape[0] != 16')
+ mel = torch.FloatTensor(mel.T).unsqueeze(0)
+
+ return mel
diff --git a/sync/SyncTalk/raymarching/__init__.py b/sync/SyncTalk/raymarching/__init__.py
new file mode 100644
index 00000000..26d3cc6d
--- /dev/null
+++ b/sync/SyncTalk/raymarching/__init__.py
@@ -0,0 +1 @@
+from .raymarching import *
\ No newline at end of file
diff --git a/sync/SyncTalk/raymarching/__pycache__/__init__.cpython-38.pyc b/sync/SyncTalk/raymarching/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 00000000..02222d07
Binary files /dev/null and b/sync/SyncTalk/raymarching/__pycache__/__init__.cpython-38.pyc differ
diff --git a/sync/SyncTalk/raymarching/__pycache__/__init__.cpython-39.pyc b/sync/SyncTalk/raymarching/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 00000000..bf5672f4
Binary files /dev/null and b/sync/SyncTalk/raymarching/__pycache__/__init__.cpython-39.pyc differ
diff --git a/sync/SyncTalk/raymarching/__pycache__/raymarching.cpython-38.pyc b/sync/SyncTalk/raymarching/__pycache__/raymarching.cpython-38.pyc
new file mode 100644
index 00000000..442955de
Binary files /dev/null and b/sync/SyncTalk/raymarching/__pycache__/raymarching.cpython-38.pyc differ
diff --git a/sync/SyncTalk/raymarching/__pycache__/raymarching.cpython-39.pyc b/sync/SyncTalk/raymarching/__pycache__/raymarching.cpython-39.pyc
new file mode 100644
index 00000000..128a391f
Binary files /dev/null and b/sync/SyncTalk/raymarching/__pycache__/raymarching.cpython-39.pyc differ
diff --git a/sync/SyncTalk/raymarching/backend.py b/sync/SyncTalk/raymarching/backend.py
new file mode 100644
index 00000000..d8f65d6f
--- /dev/null
+++ b/sync/SyncTalk/raymarching/backend.py
@@ -0,0 +1,40 @@
+import os
+from torch.utils.cpp_extension import load
+
+_src_path = os.path.dirname(os.path.abspath(__file__))
+
+nvcc_flags = [
+ '-O3', '-std=c++14',
+ '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
+]
+
+if os.name == "posix":
+ c_flags = ['-O3', '-std=c++14']
+elif os.name == "nt":
+ c_flags = ['/O2', '/std:c++17']
+
+ # find cl.exe
+ def find_cl_path():
+ import glob
+ for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
+ paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
+ if paths:
+ return paths[0]
+
+ # If cl.exe is not on path, try to find it.
+ if os.system("where cl.exe >nul 2>nul") != 0:
+ cl_path = find_cl_path()
+ if cl_path is None:
+ raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
+ os.environ["PATH"] += ";" + cl_path
+
+_backend = load(name='_raymarching_face',
+ extra_cflags=c_flags,
+ extra_cuda_cflags=nvcc_flags,
+ sources=[os.path.join(_src_path, 'src', f) for f in [
+ 'raymarching.cu',
+ 'bindings.cpp',
+ ]],
+ )
+
+__all__ = ['_backend']
\ No newline at end of file
diff --git a/sync/SyncTalk/raymarching/raymarching.py b/sync/SyncTalk/raymarching/raymarching.py
new file mode 100644
index 00000000..8acc894c
--- /dev/null
+++ b/sync/SyncTalk/raymarching/raymarching.py
@@ -0,0 +1,671 @@
+import numpy as np
+import time
+
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.cuda.amp import custom_bwd, custom_fwd
+
+try:
+ import _raymarching_face as _backend
+except ImportError:
+ from .backend import _backend
+
+# ----------------------------------------
+# utils
+# ----------------------------------------
+
+class _near_far_from_aabb(Function):
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float32)
+ def forward(ctx, rays_o, rays_d, aabb, min_near=0.2):
+ ''' near_far_from_aabb, CUDA implementation
+ Calculate rays' intersection time (near and far) with aabb
+ Args:
+ rays_o: float, [N, 3]
+ rays_d: float, [N, 3]
+ aabb: float, [6], (xmin, ymin, zmin, xmax, ymax, zmax)
+ min_near: float, scalar
+ Returns:
+ nears: float, [N]
+ fars: float, [N]
+ '''
+ if not rays_o.is_cuda: rays_o = rays_o.cuda()
+ if not rays_d.is_cuda: rays_d = rays_d.cuda()
+
+ rays_o = rays_o.contiguous().view(-1, 3)
+ rays_d = rays_d.contiguous().view(-1, 3)
+
+ N = rays_o.shape[0] # num rays
+
+ nears = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device)
+ fars = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device)
+
+ _backend.near_far_from_aabb(rays_o, rays_d, aabb, N, min_near, nears, fars)
+
+ return nears, fars
+
+near_far_from_aabb = _near_far_from_aabb.apply
+
+
+class _sph_from_ray(Function):
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float32)
+ def forward(ctx, rays_o, rays_d, radius):
+ ''' sph_from_ray, CUDA implementation
+ get spherical coordinate on the background sphere from rays.
+ Assume rays_o are inside the Sphere(radius).
+ Args:
+ rays_o: [N, 3]
+ rays_d: [N, 3]
+ radius: scalar, float
+ Return:
+ coords: [N, 2], in [-1, 1], theta and phi on a sphere. (further-surface)
+ '''
+ if not rays_o.is_cuda: rays_o = rays_o.cuda()
+ if not rays_d.is_cuda: rays_d = rays_d.cuda()
+
+ rays_o = rays_o.contiguous().view(-1, 3)
+ rays_d = rays_d.contiguous().view(-1, 3)
+
+ N = rays_o.shape[0] # num rays
+
+ coords = torch.empty(N, 2, dtype=rays_o.dtype, device=rays_o.device)
+
+ _backend.sph_from_ray(rays_o, rays_d, radius, N, coords)
+
+ return coords
+
+sph_from_ray = _sph_from_ray.apply
+
+
+class _morton3D(Function):
+ @staticmethod
+ def forward(ctx, coords):
+ ''' morton3D, CUDA implementation
+ Args:
+ coords: [N, 3], int32, in [0, 128) (for some reason there is no uint32 tensor in torch...)
+ TODO: check if the coord range is valid! (current 128 is safe)
+ Returns:
+ indices: [N], int32, in [0, 128^3)
+
+ '''
+ if not coords.is_cuda: coords = coords.cuda()
+
+ N = coords.shape[0]
+
+ indices = torch.empty(N, dtype=torch.int32, device=coords.device)
+
+ _backend.morton3D(coords.int(), N, indices)
+
+ return indices
+
+morton3D = _morton3D.apply
+
+class _morton3D_invert(Function):
+ @staticmethod
+ def forward(ctx, indices):
+ ''' morton3D_invert, CUDA implementation
+ Args:
+ indices: [N], int32, in [0, 128^3)
+ Returns:
+ coords: [N, 3], int32, in [0, 128)
+
+ '''
+ if not indices.is_cuda: indices = indices.cuda()
+
+ N = indices.shape[0]
+
+ coords = torch.empty(N, 3, dtype=torch.int32, device=indices.device)
+
+ _backend.morton3D_invert(indices.int(), N, coords)
+
+ return coords
+
+morton3D_invert = _morton3D_invert.apply
+
+
+class _packbits(Function):
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float32)
+ def forward(ctx, grid, thresh, bitfield=None):
+ ''' packbits, CUDA implementation
+ Pack up the density grid into a bit field to accelerate ray marching.
+ Args:
+ grid: float, [C, H * H * H], assume H % 2 == 0
+ thresh: float, threshold
+ Returns:
+ bitfield: uint8, [C, H * H * H / 8]
+ '''
+ if not grid.is_cuda: grid = grid.cuda()
+ grid = grid.contiguous()
+
+ C = grid.shape[0]
+ H3 = grid.shape[1]
+ N = C * H3 // 8
+
+ if bitfield is None:
+ bitfield = torch.empty(N, dtype=torch.uint8, device=grid.device)
+
+ _backend.packbits(grid, N, thresh, bitfield)
+
+ return bitfield
+
+packbits = _packbits.apply
+
+
+class _morton3D_dilation(Function):
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float32)
+ def forward(ctx, grid):
+ ''' max pooling with morton coord, CUDA implementation
+ or maybe call it dilation... we don't support adjust kernel size.
+ Args:
+ grid: float, [C, H * H * H], assume H % 2 == 0
+ Returns:
+ grid_dilate: float, [C, H * H * H], assume H % 2 == 0bitfield: uint8, [C, H * H * H / 8]
+ '''
+ if not grid.is_cuda: grid = grid.cuda()
+ grid = grid.contiguous()
+
+ C = grid.shape[0]
+ H3 = grid.shape[1]
+ H = int(np.cbrt(H3))
+ grid_dilation = torch.empty_like(grid)
+
+ _backend.morton3D_dilation(grid, C, H, grid_dilation)
+
+ return grid_dilation
+
+morton3D_dilation = _morton3D_dilation.apply
+
+# ----------------------------------------
+# train functions
+# ----------------------------------------
+
+class _march_rays_train(Function):
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float32)
+ def forward(ctx, rays_o, rays_d, bound, density_bitfield, C, H, nears, fars, step_counter=None, mean_count=-1, perturb=False, align=-1, force_all_rays=False, dt_gamma=0, max_steps=1024):
+ ''' march rays to generate points (forward only)
+ Args:
+ rays_o/d: float, [N, 3]
+ bound: float, scalar
+ density_bitfield: uint8: [CHHH // 8]
+ C: int
+ H: int
+ nears/fars: float, [N]
+ step_counter: int32, (2), used to count the actual number of generated points.
+ mean_count: int32, estimated mean steps to accelerate training. (but will randomly drop rays if the actual point count exceeded this threshold.)
+ perturb: bool
+ align: int, pad output so its size is dividable by align, set to -1 to disable.
+ force_all_rays: bool, ignore step_counter and mean_count, always calculate all rays. Useful if rendering the whole image, instead of some rays.
+ dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance)
+ max_steps: int, max number of sampled points along each ray, also affect min_stepsize.
+ Returns:
+ xyzs: float, [M, 3], all generated points' coords. (all rays concated, need to use `rays` to extract points belonging to each ray)
+ dirs: float, [M, 3], all generated points' view dirs.
+ deltas: float, [M, 2], first is delta_t, second is rays_t
+ rays: int32, [N, 3], all rays' (index, point_offset, point_count), e.g., xyzs[rays[i, 1]:rays[i, 1] + rays[i, 2]] --> points belonging to rays[i, 0]
+ '''
+
+ if not rays_o.is_cuda: rays_o = rays_o.cuda()
+ if not rays_d.is_cuda: rays_d = rays_d.cuda()
+ if not density_bitfield.is_cuda: density_bitfield = density_bitfield.cuda()
+
+ rays_o = rays_o.contiguous().view(-1, 3)
+ rays_d = rays_d.contiguous().view(-1, 3)
+ density_bitfield = density_bitfield.contiguous()
+
+ N = rays_o.shape[0] # num rays
+ M = N * max_steps # init max points number in total
+
+ # running average based on previous epoch (mimic `measured_batch_size_before_compaction` in instant-ngp)
+ # It estimate the max points number to enable faster training, but will lead to random ignored rays if underestimated.
+ if not force_all_rays and mean_count > 0:
+ if align > 0:
+ mean_count += align - mean_count % align
+ M = mean_count
+
+ xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
+ dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
+ deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device)
+ rays = torch.empty(N, 3, dtype=torch.int32, device=rays_o.device) # id, offset, num_steps
+
+ if step_counter is None:
+ step_counter = torch.zeros(2, dtype=torch.int32, device=rays_o.device) # point counter, ray counter
+
+ if perturb:
+ noises = torch.rand(N, dtype=rays_o.dtype, device=rays_o.device)
+ else:
+ noises = torch.zeros(N, dtype=rays_o.dtype, device=rays_o.device)
+
+ _backend.march_rays_train(rays_o, rays_d, density_bitfield, bound, dt_gamma, max_steps, N, C, H, M, nears, fars, xyzs, dirs, deltas, rays, step_counter, noises) # m is the actually used points number
+
+ #print(step_counter, M)
+
+ # only used at the first (few) epochs.
+ if force_all_rays or mean_count <= 0:
+ m = step_counter[0].item() # D2H copy
+ if align > 0:
+ m += align - m % align
+ xyzs = xyzs[:m]
+ dirs = dirs[:m]
+ deltas = deltas[:m]
+
+ torch.cuda.empty_cache()
+
+ ctx.save_for_backward(rays, deltas)
+
+ return xyzs, dirs, deltas, rays
+
+ # to support optimizing camera poses.
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, grad_xyzs, grad_dirs, grad_deltas, grad_rays):
+ # grad_xyzs/dirs: [M, 3]
+
+ rays, deltas = ctx.saved_tensors
+
+ N = rays.shape[0]
+ M = grad_xyzs.shape[0]
+
+ grad_rays_o = torch.zeros(N, 3, device=rays.device)
+ grad_rays_d = torch.zeros(N, 3, device=rays.device)
+
+ _backend.march_rays_train_backward(grad_xyzs, grad_dirs, rays, deltas, N, M, grad_rays_o, grad_rays_d)
+
+ return grad_rays_o, grad_rays_d, None, None, None, None, None, None, None, None, None, None, None, None, None
+
+march_rays_train = _march_rays_train.apply
+
+
+class _composite_rays_train(Function):
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float32)
+ def forward(ctx, sigmas, rgbs, ambient, deltas, rays, T_thresh=1e-4):
+ ''' composite rays' rgbs, according to the ray marching formula.
+ Args:
+ rgbs: float, [M, 3]
+ sigmas: float, [M,]
+ ambient: float, [M,] (after summing up the last dimension)
+ deltas: float, [M, 2]
+ rays: int32, [N, 3]
+ Returns:
+ weights_sum: float, [N,], the alpha channel
+ depth: float, [N, ], the Depth
+ image: float, [N, 3], the RGB channel (after multiplying alpha!)
+ '''
+
+ sigmas = sigmas.contiguous()
+ rgbs = rgbs.contiguous()
+ ambient = ambient.contiguous()
+
+ M = sigmas.shape[0]
+ N = rays.shape[0]
+
+ weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
+ ambient_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
+ depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
+ image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device)
+
+ _backend.composite_rays_train_forward(sigmas, rgbs, ambient, deltas, rays, M, N, T_thresh, weights_sum, ambient_sum, depth, image)
+
+ ctx.save_for_backward(sigmas, rgbs, ambient, deltas, rays, weights_sum, ambient_sum, depth, image)
+ ctx.dims = [M, N, T_thresh]
+
+ return weights_sum, ambient_sum, depth, image
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, grad_weights_sum, grad_ambient_sum, grad_depth, grad_image):
+
+ # NOTE: grad_depth is not used now! It won't be propagated to sigmas.
+
+ grad_weights_sum = grad_weights_sum.contiguous()
+ grad_ambient_sum = grad_ambient_sum.contiguous()
+ grad_image = grad_image.contiguous()
+
+ sigmas, rgbs, ambient, deltas, rays, weights_sum, ambient_sum, depth, image = ctx.saved_tensors
+ M, N, T_thresh = ctx.dims
+
+ grad_sigmas = torch.zeros_like(sigmas)
+ grad_rgbs = torch.zeros_like(rgbs)
+ grad_ambient = torch.zeros_like(ambient)
+
+ _backend.composite_rays_train_backward(grad_weights_sum, grad_ambient_sum, grad_image, sigmas, rgbs, ambient, deltas, rays, weights_sum, ambient_sum, image, M, N, T_thresh, grad_sigmas, grad_rgbs, grad_ambient)
+
+ return grad_sigmas, grad_rgbs, grad_ambient, None, None, None
+
+
+composite_rays_train = _composite_rays_train.apply
+
+# ----------------------------------------
+# infer functions
+# ----------------------------------------
+
+class _march_rays(Function):
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float32)
+ def forward(ctx, n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, density_bitfield, C, H, near, far, align=-1, perturb=False, dt_gamma=0, max_steps=1024):
+ ''' march rays to generate points (forward only, for inference)
+ Args:
+ n_alive: int, number of alive rays
+ n_step: int, how many steps we march
+ rays_alive: int, [N], the alive rays' IDs in N (N >= n_alive, but we only use first n_alive)
+ rays_t: float, [N], the alive rays' time, we only use the first n_alive.
+ rays_o/d: float, [N, 3]
+ bound: float, scalar
+ density_bitfield: uint8: [CHHH // 8]
+ C: int
+ H: int
+ nears/fars: float, [N]
+ align: int, pad output so its size is dividable by align, set to -1 to disable.
+ perturb: bool/int, int > 0 is used as the random seed.
+ dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance)
+ max_steps: int, max number of sampled points along each ray, also affect min_stepsize.
+ Returns:
+ xyzs: float, [n_alive * n_step, 3], all generated points' coords
+ dirs: float, [n_alive * n_step, 3], all generated points' view dirs.
+ deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth).
+ '''
+
+ if not rays_o.is_cuda: rays_o = rays_o.cuda()
+ if not rays_d.is_cuda: rays_d = rays_d.cuda()
+
+ rays_o = rays_o.contiguous().view(-1, 3)
+ rays_d = rays_d.contiguous().view(-1, 3)
+
+ M = n_alive * n_step
+
+ if align > 0:
+ M += align - (M % align)
+
+ xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
+ dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
+ deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device) # 2 vals, one for rgb, one for depth
+
+ if perturb:
+ # torch.manual_seed(perturb) # test_gui uses spp index as seed
+ noises = torch.rand(n_alive, dtype=rays_o.dtype, device=rays_o.device)
+ else:
+ noises = torch.zeros(n_alive, dtype=rays_o.dtype, device=rays_o.device)
+
+ _backend.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, dt_gamma, max_steps, C, H, density_bitfield, near, far, xyzs, dirs, deltas, noises)
+
+ return xyzs, dirs, deltas
+
+march_rays = _march_rays.apply
+
+
+class _composite_rays(Function):
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float
+ def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh=1e-2):
+ ''' composite rays' rgbs, according to the ray marching formula. (for inference)
+ Args:
+ n_alive: int, number of alive rays
+ n_step: int, how many steps we march
+ rays_alive: int, [n_alive], the alive rays' IDs in N (N >= n_alive)
+ rays_t: float, [N], the alive rays' time
+ sigmas: float, [n_alive * n_step,]
+ rgbs: float, [n_alive * n_step, 3]
+ deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth).
+ In-place Outputs:
+ weights_sum: float, [N,], the alpha channel
+ depth: float, [N,], the depth value
+ image: float, [N, 3], the RGB channel (after multiplying alpha!)
+ '''
+ _backend.composite_rays(n_alive, n_step, T_thresh, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image)
+ return tuple()
+
+
+composite_rays = _composite_rays.apply
+
+
+class _composite_rays_ambient(Function):
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float
+ def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, ambients, weights_sum, depth, image, ambient_sum, T_thresh=1e-2):
+ _backend.composite_rays_ambient(n_alive, n_step, T_thresh, rays_alive, rays_t, sigmas, rgbs, deltas, ambients, weights_sum, depth, image, ambient_sum)
+ return tuple()
+
+
+composite_rays_ambient = _composite_rays_ambient.apply
+
+
+
+
+
+# custom
+
+class _composite_rays_train_sigma(Function):
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float32)
+ def forward(ctx, sigmas, rgbs, ambient, deltas, rays, T_thresh=1e-4):
+ ''' composite rays' rgbs, according to the ray marching formula.
+ Args:
+ rgbs: float, [M, 3]
+ sigmas: float, [M,]
+ ambient: float, [M,] (after summing up the last dimension)
+ deltas: float, [M, 2]
+ rays: int32, [N, 3]
+ Returns:
+ weights_sum: float, [N,], the alpha channel
+ depth: float, [N, ], the Depth
+ image: float, [N, 3], the RGB channel (after multiplying alpha!)
+ '''
+
+ sigmas = sigmas.contiguous()
+ rgbs = rgbs.contiguous()
+ ambient = ambient.contiguous()
+
+ M = sigmas.shape[0]
+ N = rays.shape[0]
+
+ weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
+ ambient_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
+ depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
+ image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device)
+
+ _backend.composite_rays_train_sigma_forward(sigmas, rgbs, ambient, deltas, rays, M, N, T_thresh, weights_sum, ambient_sum, depth, image)
+
+ ctx.save_for_backward(sigmas, rgbs, ambient, deltas, rays, weights_sum, ambient_sum, depth, image)
+ ctx.dims = [M, N, T_thresh]
+
+ return weights_sum, ambient_sum, depth, image
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, grad_weights_sum, grad_ambient_sum, grad_depth, grad_image):
+
+ # NOTE: grad_depth is not used now! It won't be propagated to sigmas.
+
+ grad_weights_sum = grad_weights_sum.contiguous()
+ grad_ambient_sum = grad_ambient_sum.contiguous()
+ grad_image = grad_image.contiguous()
+
+ sigmas, rgbs, ambient, deltas, rays, weights_sum, ambient_sum, depth, image = ctx.saved_tensors
+ M, N, T_thresh = ctx.dims
+
+ grad_sigmas = torch.zeros_like(sigmas)
+ grad_rgbs = torch.zeros_like(rgbs)
+ grad_ambient = torch.zeros_like(ambient)
+
+ _backend.composite_rays_train_sigma_backward(grad_weights_sum, grad_ambient_sum, grad_image, sigmas, rgbs, ambient, deltas, rays, weights_sum, ambient_sum, image, M, N, T_thresh, grad_sigmas, grad_rgbs, grad_ambient)
+
+ return grad_sigmas, grad_rgbs, grad_ambient, None, None, None
+
+
+composite_rays_train_sigma = _composite_rays_train_sigma.apply
+
+
+class _composite_rays_ambient_sigma(Function):
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float
+ def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, ambients, weights_sum, depth, image, ambient_sum, T_thresh=1e-2):
+ _backend.composite_rays_ambient_sigma(n_alive, n_step, T_thresh, rays_alive, rays_t, sigmas, rgbs, deltas, ambients, weights_sum, depth, image, ambient_sum)
+ return tuple()
+
+
+composite_rays_ambient_sigma = _composite_rays_ambient_sigma.apply
+
+
+
+# uncertainty
+class _composite_rays_train_uncertainty(Function):
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float32)
+ def forward(ctx, sigmas, rgbs, ambient, uncertainty, deltas, rays, T_thresh=1e-4):
+ ''' composite rays' rgbs, according to the ray marching formula.
+ Args:
+ rgbs: float, [M, 3]
+ sigmas: float, [M,]
+ ambient: float, [M,] (after summing up the last dimension)
+ deltas: float, [M, 2]
+ rays: int32, [N, 3]
+ Returns:
+ weights_sum: float, [N,], the alpha channel
+ depth: float, [N, ], the Depth
+ image: float, [N, 3], the RGB channel (after multiplying alpha!)
+ '''
+
+ sigmas = sigmas.contiguous()
+ rgbs = rgbs.contiguous()
+ ambient = ambient.contiguous()
+ uncertainty = uncertainty.contiguous()
+
+ M = sigmas.shape[0]
+ N = rays.shape[0]
+
+ weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
+ ambient_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
+ uncertainty_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
+ depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
+ image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device)
+
+ _backend.composite_rays_train_uncertainty_forward(sigmas, rgbs, ambient, uncertainty, deltas, rays, M, N, T_thresh, weights_sum, ambient_sum, uncertainty_sum, depth, image)
+
+ ctx.save_for_backward(sigmas, rgbs, ambient, uncertainty, deltas, rays, weights_sum, ambient_sum, uncertainty_sum, depth, image)
+ ctx.dims = [M, N, T_thresh]
+
+ return weights_sum, ambient_sum, uncertainty_sum, depth, image
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, grad_weights_sum, grad_ambient_sum, grad_uncertainty_sum, grad_depth, grad_image):
+
+ # NOTE: grad_depth is not used now! It won't be propagated to sigmas.
+
+ grad_weights_sum = grad_weights_sum.contiguous()
+ grad_ambient_sum = grad_ambient_sum.contiguous()
+ grad_uncertainty_sum = grad_uncertainty_sum.contiguous()
+ grad_image = grad_image.contiguous()
+
+ sigmas, rgbs, ambient, uncertainty, deltas, rays, weights_sum, ambient_sum, uncertainty_sum, depth, image = ctx.saved_tensors
+ M, N, T_thresh = ctx.dims
+
+ grad_sigmas = torch.zeros_like(sigmas)
+ grad_rgbs = torch.zeros_like(rgbs)
+ grad_ambient = torch.zeros_like(ambient)
+ grad_uncertainty = torch.zeros_like(uncertainty)
+
+ _backend.composite_rays_train_uncertainty_backward(grad_weights_sum, grad_ambient_sum, grad_uncertainty_sum, grad_image, sigmas, rgbs, ambient, uncertainty, deltas, rays, weights_sum, ambient_sum, uncertainty_sum, image, M, N, T_thresh, grad_sigmas, grad_rgbs, grad_ambient, grad_uncertainty)
+
+ return grad_sigmas, grad_rgbs, grad_ambient, grad_uncertainty, None, None, None
+
+
+composite_rays_train_uncertainty = _composite_rays_train_uncertainty.apply
+
+
+class _composite_rays_uncertainty(Function):
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float
+ def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, ambients, uncertainties, weights_sum, depth, image, ambient_sum, uncertainty_sum, T_thresh=1e-2):
+ _backend.composite_rays_uncertainty(n_alive, n_step, T_thresh, rays_alive, rays_t, sigmas, rgbs, deltas, ambients, uncertainties, weights_sum, depth, image, ambient_sum, uncertainty_sum)
+ return tuple()
+
+
+composite_rays_uncertainty = _composite_rays_uncertainty.apply
+
+
+
+# triplane(eye)
+class _composite_rays_train_triplane(Function):
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float32)
+ def forward(ctx, sigmas, rgbs, amb_aud, amb_eye, uncertainty, deltas, rays, T_thresh=1e-4):
+ ''' composite rays' rgbs, according to the ray marching formula.
+ Args:
+ rgbs: float, [M, 3]
+ sigmas: float, [M,]
+ ambient: float, [M,] (after summing up the last dimension)
+ deltas: float, [M, 2]
+ rays: int32, [N, 3]
+ Returns:
+ weights_sum: float, [N,], the alpha channel
+ depth: float, [N, ], the Depth
+ image: float, [N, 3], the RGB channel (after multiplying alpha!)
+ '''
+
+ sigmas = sigmas.contiguous()
+ rgbs = rgbs.contiguous()
+ amb_aud = amb_aud.contiguous()
+ amb_eye = amb_eye.contiguous()
+ uncertainty = uncertainty.contiguous()
+
+ M = sigmas.shape[0]
+ N = rays.shape[0]
+
+ weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
+ amb_aud_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
+ amb_eye_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
+ uncertainty_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
+ depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
+ image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device)
+
+ _backend.composite_rays_train_triplane_forward(sigmas, rgbs, amb_aud, amb_eye, uncertainty, deltas, rays, M, N, T_thresh, weights_sum, amb_aud_sum, amb_eye_sum, uncertainty_sum, depth, image)
+
+ ctx.save_for_backward(sigmas, rgbs, amb_aud, amb_eye, uncertainty, deltas, rays, weights_sum, amb_aud_sum, amb_eye_sum, uncertainty_sum, depth, image)
+ ctx.dims = [M, N, T_thresh]
+
+ return weights_sum, amb_aud_sum, amb_eye_sum, uncertainty_sum, depth, image
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, grad_weights_sum, grad_amb_aud_sum, grad_amb_eye_sum, grad_uncertainty_sum, grad_depth, grad_image):
+
+ # NOTE: grad_depth is not used now! It won't be propagated to sigmas.
+
+ grad_weights_sum = grad_weights_sum.contiguous()
+ grad_amb_aud_sum = grad_amb_aud_sum.contiguous()
+ grad_amb_eye_sum = grad_amb_eye_sum.contiguous()
+ grad_uncertainty_sum = grad_uncertainty_sum.contiguous()
+ grad_image = grad_image.contiguous()
+
+ sigmas, rgbs, amb_aud, amb_eye, uncertainty, deltas, rays, weights_sum, amb_aud_sum, amb_eye_sum, uncertainty_sum, depth, image = ctx.saved_tensors
+ M, N, T_thresh = ctx.dims
+
+ grad_sigmas = torch.zeros_like(sigmas)
+ grad_rgbs = torch.zeros_like(rgbs)
+ grad_amb_aud = torch.zeros_like(amb_aud)
+ grad_amb_eye = torch.zeros_like(amb_eye)
+ grad_uncertainty = torch.zeros_like(uncertainty)
+
+ _backend.composite_rays_train_triplane_backward(grad_weights_sum, grad_amb_aud_sum, grad_amb_eye_sum, grad_uncertainty_sum, grad_image, sigmas, rgbs, amb_aud, amb_eye, uncertainty, deltas, rays, weights_sum, amb_aud_sum, amb_eye_sum, uncertainty_sum, image, M, N, T_thresh, grad_sigmas, grad_rgbs, grad_amb_aud, grad_amb_eye, grad_uncertainty)
+
+ return grad_sigmas, grad_rgbs, grad_amb_aud, grad_amb_eye, grad_uncertainty, None, None, None
+
+
+composite_rays_train_triplane = _composite_rays_train_triplane.apply
+
+
+class _composite_rays_triplane(Function):
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float
+ def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, ambs_aud, ambs_eye, uncertainties, weights_sum, depth, image, amb_aud_sum, amb_eye_sum, uncertainty_sum, T_thresh=1e-2):
+ _backend.composite_rays_triplane(n_alive, n_step, T_thresh, rays_alive, rays_t, sigmas, rgbs, deltas, ambs_aud, ambs_eye, uncertainties, weights_sum, depth, image, amb_aud_sum, amb_eye_sum, uncertainty_sum)
+ return tuple()
+
+
+composite_rays_triplane = _composite_rays_triplane.apply
\ No newline at end of file
diff --git a/sync/SyncTalk/raymarching/setup.py b/sync/SyncTalk/raymarching/setup.py
new file mode 100644
index 00000000..6a7e62f7
--- /dev/null
+++ b/sync/SyncTalk/raymarching/setup.py
@@ -0,0 +1,63 @@
+import os
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+_src_path = os.path.dirname(os.path.abspath(__file__))
+
+nvcc_flags = [
+ '-O3', '-std=c++14',
+ # '-lineinfo', # to debug illegal memory access
+ '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
+]
+
+if os.name == "posix":
+ c_flags = ['-O3', '-std=c++14']
+elif os.name == "nt":
+ c_flags = ['/O2', '/std:c++17']
+
+ # find cl.exe
+ def find_cl_path():
+ import glob
+ for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
+ paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
+ if paths:
+ return paths[0]
+
+ # If cl.exe is not on path, try to find it.
+ if os.system("where cl.exe >nul 2>nul") != 0:
+ cl_path = find_cl_path()
+ if cl_path is None:
+ raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
+ os.environ["PATH"] += ";" + cl_path
+
+'''
+Usage:
+
+python setup.py build_ext --inplace # build extensions locally, do not install (only can be used from the parent directory)
+
+python setup.py install # build extensions and install (copy) to PATH.
+pip install . # ditto but better (e.g., dependency & metadata handling)
+
+python setup.py develop # build extensions and install (symbolic) to PATH.
+pip install -e . # ditto but better (e.g., dependency & metadata handling)
+
+'''
+setup(
+ name='raymarching_face', # package name, import this to use python API
+ ext_modules=[
+ CUDAExtension(
+ name='_raymarching_face', # extension name, import this to use CUDA API
+ sources=[os.path.join(_src_path, 'src', f) for f in [
+ 'raymarching.cu',
+ 'bindings.cpp',
+ ]],
+ extra_compile_args={
+ 'cxx': c_flags,
+ 'nvcc': nvcc_flags,
+ }
+ ),
+ ],
+ cmdclass={
+ 'build_ext': BuildExtension,
+ }
+)
\ No newline at end of file
diff --git a/sync/SyncTalk/raymarching/src/bindings.cpp b/sync/SyncTalk/raymarching/src/bindings.cpp
new file mode 100644
index 00000000..a9622b24
--- /dev/null
+++ b/sync/SyncTalk/raymarching/src/bindings.cpp
@@ -0,0 +1,39 @@
+#include
+
+#include "raymarching.h"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ // utils
+ m.def("packbits", &packbits, "packbits (CUDA)");
+ m.def("near_far_from_aabb", &near_far_from_aabb, "near_far_from_aabb (CUDA)");
+ m.def("sph_from_ray", &sph_from_ray, "sph_from_ray (CUDA)");
+ m.def("morton3D", &morton3D, "morton3D (CUDA)");
+ m.def("morton3D_invert", &morton3D_invert, "morton3D_invert (CUDA)");
+ m.def("morton3D_dilation", &morton3D_dilation, "morton3D_dilation (CUDA)");
+ // train
+ m.def("march_rays_train", &march_rays_train, "march_rays_train (CUDA)");
+ m.def("march_rays_train_backward", &march_rays_train_backward, "march_rays_train_backward (CUDA)");
+ m.def("composite_rays_train_forward", &composite_rays_train_forward, "composite_rays_train_forward (CUDA)");
+ m.def("composite_rays_train_backward", &composite_rays_train_backward, "composite_rays_train_backward (CUDA)");
+ // infer
+ m.def("march_rays", &march_rays, "march rays (CUDA)");
+ m.def("composite_rays", &composite_rays, "composite rays (CUDA)");
+ m.def("composite_rays_ambient", &composite_rays_ambient, "composite rays with ambient (CUDA)");
+
+ // train
+ m.def("composite_rays_train_sigma_forward", &composite_rays_train_sigma_forward, "composite_rays_train_forward (CUDA)");
+ m.def("composite_rays_train_sigma_backward", &composite_rays_train_sigma_backward, "composite_rays_train_backward (CUDA)");
+ // infer
+ m.def("composite_rays_ambient_sigma", &composite_rays_ambient_sigma, "composite rays with ambient (CUDA)");
+
+ // uncertainty train
+ m.def("composite_rays_train_uncertainty_forward", &composite_rays_train_uncertainty_forward, "composite_rays_train_forward (CUDA)");
+ m.def("composite_rays_train_uncertainty_backward", &composite_rays_train_uncertainty_backward, "composite_rays_train_backward (CUDA)");
+ m.def("composite_rays_uncertainty", &composite_rays_uncertainty, "composite rays with ambient (CUDA)");
+
+ // triplane
+ m.def("composite_rays_train_triplane_forward", &composite_rays_train_triplane_forward, "composite_rays_train_forward (CUDA)");
+ m.def("composite_rays_train_triplane_backward", &composite_rays_train_triplane_backward, "composite_rays_train_backward (CUDA)");
+ m.def("composite_rays_triplane", &composite_rays_triplane, "composite rays with ambient (CUDA)");
+
+}
\ No newline at end of file
diff --git a/sync/SyncTalk/raymarching/src/raymarching.cu b/sync/SyncTalk/raymarching/src/raymarching.cu
new file mode 100644
index 00000000..95462877
--- /dev/null
+++ b/sync/SyncTalk/raymarching/src/raymarching.cu
@@ -0,0 +1,2258 @@
+#include
+#include
+#include
+
+#include
+#include
+
+#include
+#include
+#include
+#include
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
+#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
+#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
+
+
+inline constexpr __device__ float SQRT3() { return 1.7320508075688772f; }
+inline constexpr __device__ float RSQRT3() { return 0.5773502691896258f; }
+inline constexpr __device__ float PI() { return 3.141592653589793f; }
+inline constexpr __device__ float RPI() { return 0.3183098861837907f; }
+
+
+template
+inline __host__ __device__ T div_round_up(T val, T divisor) {
+ return (val + divisor - 1) / divisor;
+}
+
+inline __host__ __device__ float signf(const float x) {
+ return copysignf(1.0, x);
+}
+
+inline __host__ __device__ float clamp(const float x, const float min, const float max) {
+ return fminf(max, fmaxf(min, x));
+}
+
+inline __host__ __device__ void swapf(float& a, float& b) {
+ float c = a; a = b; b = c;
+}
+
+inline __device__ int mip_from_pos(const float x, const float y, const float z, const float max_cascade) {
+ const float mx = fmaxf(fabsf(x), fmaxf(fabs(y), fabs(z)));
+ int exponent;
+ frexpf(mx, &exponent); // [0, 0.5) --> -1, [0.5, 1) --> 0, [1, 2) --> 1, [2, 4) --> 2, ...
+ return fminf(max_cascade - 1, fmaxf(0, exponent));
+}
+
+inline __device__ int mip_from_dt(const float dt, const float H, const float max_cascade) {
+ const float mx = dt * H * 0.5;
+ int exponent;
+ frexpf(mx, &exponent);
+ return fminf(max_cascade - 1, fmaxf(0, exponent));
+}
+
+inline __host__ __device__ uint32_t __expand_bits(uint32_t v)
+{
+ v = (v * 0x00010001u) & 0xFF0000FFu;
+ v = (v * 0x00000101u) & 0x0F00F00Fu;
+ v = (v * 0x00000011u) & 0xC30C30C3u;
+ v = (v * 0x00000005u) & 0x49249249u;
+ return v;
+}
+
+inline __host__ __device__ uint32_t __morton3D(uint32_t x, uint32_t y, uint32_t z)
+{
+ uint32_t xx = __expand_bits(x);
+ uint32_t yy = __expand_bits(y);
+ uint32_t zz = __expand_bits(z);
+ return xx | (yy << 1) | (zz << 2);
+}
+
+inline __host__ __device__ uint32_t __morton3D_invert(uint32_t x)
+{
+ x = x & 0x49249249;
+ x = (x | (x >> 2)) & 0xc30c30c3;
+ x = (x | (x >> 4)) & 0x0f00f00f;
+ x = (x | (x >> 8)) & 0xff0000ff;
+ x = (x | (x >> 16)) & 0x0000ffff;
+ return x;
+}
+
+
+////////////////////////////////////////////////////
+///////////// utils /////////////
+////////////////////////////////////////////////////
+
+// rays_o/d: [N, 3]
+// nears/fars: [N]
+// scalar_t should always be float in use.
+template
+__global__ void kernel_near_far_from_aabb(
+ const scalar_t * __restrict__ rays_o,
+ const scalar_t * __restrict__ rays_d,
+ const scalar_t * __restrict__ aabb,
+ const uint32_t N,
+ const float min_near,
+ scalar_t * nears, scalar_t * fars
+) {
+ // parallel per ray
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
+ if (n >= N) return;
+
+ // locate
+ rays_o += n * 3;
+ rays_d += n * 3;
+
+ const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
+ const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
+ const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
+
+ // get near far (assume cube scene)
+ float near = (aabb[0] - ox) * rdx;
+ float far = (aabb[3] - ox) * rdx;
+ if (near > far) swapf(near, far);
+
+ float near_y = (aabb[1] - oy) * rdy;
+ float far_y = (aabb[4] - oy) * rdy;
+ if (near_y > far_y) swapf(near_y, far_y);
+
+ if (near > far_y || near_y > far) {
+ nears[n] = fars[n] = std::numeric_limits::max();
+ return;
+ }
+
+ if (near_y > near) near = near_y;
+ if (far_y < far) far = far_y;
+
+ float near_z = (aabb[2] - oz) * rdz;
+ float far_z = (aabb[5] - oz) * rdz;
+ if (near_z > far_z) swapf(near_z, far_z);
+
+ if (near > far_z || near_z > far) {
+ nears[n] = fars[n] = std::numeric_limits::max();
+ return;
+ }
+
+ if (near_z > near) near = near_z;
+ if (far_z < far) far = far_z;
+
+ if (near < min_near) near = min_near;
+
+ nears[n] = near;
+ fars[n] = far;
+}
+
+
+void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars) {
+
+ static constexpr uint32_t N_THREAD = 128;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ rays_o.scalar_type(), "near_far_from_aabb", ([&] {
+ kernel_near_far_from_aabb<<>>(rays_o.data_ptr(), rays_d.data_ptr(), aabb.data_ptr(), N, min_near, nears.data_ptr(), fars.data_ptr());
+ }));
+}
+
+
+// rays_o/d: [N, 3]
+// radius: float
+// coords: [N, 2]
+template
+__global__ void kernel_sph_from_ray(
+ const scalar_t * __restrict__ rays_o,
+ const scalar_t * __restrict__ rays_d,
+ const float radius,
+ const uint32_t N,
+ scalar_t * coords
+) {
+ // parallel per ray
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
+ if (n >= N) return;
+
+ // locate
+ rays_o += n * 3;
+ rays_d += n * 3;
+ coords += n * 2;
+
+ const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
+ const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
+ const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
+
+ // solve t from || o + td || = radius
+ const float A = dx * dx + dy * dy + dz * dz;
+ const float B = ox * dx + oy * dy + oz * dz; // in fact B / 2
+ const float C = ox * ox + oy * oy + oz * oz - radius * radius;
+
+ const float t = (- B + sqrtf(B * B - A * C)) / A; // always use the larger solution (positive)
+
+ // solve theta, phi (assume y is the up axis)
+ const float x = ox + t * dx, y = oy + t * dy, z = oz + t * dz;
+ const float theta = atan2(sqrtf(x * x + z * z), y); // [0, PI)
+ const float phi = atan2(z, x); // [-PI, PI)
+
+ // normalize to [-1, 1]
+ coords[0] = 2 * theta * RPI() - 1;
+ coords[1] = phi * RPI();
+}
+
+
+void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords) {
+
+ static constexpr uint32_t N_THREAD = 128;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ rays_o.scalar_type(), "sph_from_ray", ([&] {
+ kernel_sph_from_ray<<>>(rays_o.data_ptr(), rays_d.data_ptr(), radius, N, coords.data_ptr());
+ }));
+}
+
+
+// coords: int32, [N, 3]
+// indices: int32, [N]
+__global__ void kernel_morton3D(
+ const int * __restrict__ coords,
+ const uint32_t N,
+ int * indices
+) {
+ // parallel
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
+ if (n >= N) return;
+
+ // locate
+ coords += n * 3;
+ indices[n] = __morton3D(coords[0], coords[1], coords[2]);
+}
+
+
+void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices) {
+ static constexpr uint32_t N_THREAD = 128;
+ kernel_morton3D<<>>(coords.data_ptr(), N, indices.data_ptr());
+}
+
+
+// indices: int32, [N]
+// coords: int32, [N, 3]
+__global__ void kernel_morton3D_invert(
+ const int * __restrict__ indices,
+ const uint32_t N,
+ int * coords
+) {
+ // parallel
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
+ if (n >= N) return;
+
+ // locate
+ coords += n * 3;
+
+ const int ind = indices[n];
+
+ coords[0] = __morton3D_invert(ind >> 0);
+ coords[1] = __morton3D_invert(ind >> 1);
+ coords[2] = __morton3D_invert(ind >> 2);
+}
+
+
+void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords) {
+ static constexpr uint32_t N_THREAD = 128;
+ kernel_morton3D_invert<<>>(indices.data_ptr(), N, coords.data_ptr());
+}
+
+
+// grid: float, [C, H, H, H]
+// N: int, C * H * H * H / 8
+// density_thresh: float
+// bitfield: uint8, [N]
+template
+__global__ void kernel_packbits(
+ const scalar_t * __restrict__ grid,
+ const uint32_t N,
+ const float density_thresh,
+ uint8_t * bitfield
+) {
+ // parallel per byte
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
+ if (n >= N) return;
+
+ // locate
+ grid += n * 8;
+
+ uint8_t bits = 0;
+
+ #pragma unroll
+ for (uint8_t i = 0; i < 8; i++) {
+ bits |= (grid[i] > density_thresh) ? ((uint8_t)1 << i) : 0;
+ }
+
+ bitfield[n] = bits;
+}
+
+
+void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield) {
+
+ static constexpr uint32_t N_THREAD = 128;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ grid.scalar_type(), "packbits", ([&] {
+ kernel_packbits<<>>(grid.data_ptr(), N, density_thresh, bitfield.data_ptr());
+ }));
+}
+
+
+// grid: float, [C, H, H, H]
+__global__ void kernel_morton3D_dilation(
+ const float * __restrict__ grid,
+ const uint32_t C,
+ const uint32_t H,
+ float * __restrict__ grid_dilation
+) {
+ // parallel per byte
+ const uint32_t H3 = H * H * H;
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
+ if (n >= C * H3) return;
+
+ // locate
+ const uint32_t c = n / H3;
+ const uint32_t ind = n - c * H3;
+
+ const uint32_t x = __morton3D_invert(ind >> 0);
+ const uint32_t y = __morton3D_invert(ind >> 1);
+ const uint32_t z = __morton3D_invert(ind >> 2);
+
+ // manual max pool
+ float res = grid[n];
+
+ if (x + 1 < H) res = fmaxf(res, grid[c * H3 + __morton3D(x + 1, y, z)]);
+ if (x > 0) res = fmaxf(res, grid[c * H3 + __morton3D(x - 1, y, z)]);
+ if (y + 1 < H) res = fmaxf(res, grid[c * H3 + __morton3D(x, y + 1, z)]);
+ if (y > 0) res = fmaxf(res, grid[c * H3 + __morton3D(x, y - 1, z)]);
+ if (z + 1 < H) res = fmaxf(res, grid[c * H3 + __morton3D(x, y, z + 1)]);
+ if (z > 0) res = fmaxf(res, grid[c * H3 + __morton3D(x, y, z - 1)]);
+
+ // write
+ grid_dilation[n] = res;
+}
+
+void morton3D_dilation(const at::Tensor grid, const uint32_t C, const uint32_t H, at::Tensor grid_dilation) {
+ static constexpr uint32_t N_THREAD = 128;
+
+ kernel_morton3D_dilation<<>>(grid.data_ptr(), C, H, grid_dilation.data_ptr());
+}
+
+////////////////////////////////////////////////////
+///////////// training /////////////
+////////////////////////////////////////////////////
+
+// rays_o/d: [N, 3]
+// grid: [CHHH / 8]
+// xyzs, dirs, deltas: [M, 3], [M, 3], [M, 2]
+// dirs: [M, 3]
+// rays: [N, 3], idx, offset, num_steps
+template
+__global__ void kernel_march_rays_train(
+ const scalar_t * __restrict__ rays_o,
+ const scalar_t * __restrict__ rays_d,
+ const uint8_t * __restrict__ grid,
+ const float bound,
+ const float dt_gamma, const uint32_t max_steps,
+ const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M,
+ const scalar_t* __restrict__ nears,
+ const scalar_t* __restrict__ fars,
+ scalar_t * xyzs, scalar_t * dirs, scalar_t * deltas,
+ int * rays,
+ int * counter,
+ const scalar_t* __restrict__ noises
+) {
+ // parallel per ray
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
+ if (n >= N) return;
+
+ // locate
+ rays_o += n * 3;
+ rays_d += n * 3;
+
+ // ray marching
+ const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
+ const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
+ const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
+ const float rH = 1 / (float)H;
+ const float H3 = H * H * H;
+
+ const float near = nears[n];
+ const float far = fars[n];
+ const float noise = noises[n];
+
+ const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H;
+ const float dt_min = fminf(dt_max, 2 * SQRT3() / max_steps);
+
+ float t0 = near;
+
+ // perturb
+ t0 += clamp(t0 * dt_gamma, dt_min, dt_max) * noise;
+
+ // first pass: estimation of num_steps
+ float t = t0;
+ uint32_t num_steps = 0;
+
+ //if (t < far) printf("valid ray %d t=%f near=%f far=%f \n", n, t, near, far);
+
+ while (t < far && num_steps < max_steps) {
+ // current point
+ const float x = clamp(ox + t * dx, -bound, bound);
+ const float y = clamp(oy + t * dy, -bound, bound);
+ const float z = clamp(oz + t * dz, -bound, bound);
+
+ const float dt = clamp(t * dt_gamma, dt_min, dt_max);
+
+ // get mip level
+ const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]
+
+ const float mip_bound = fminf(scalbnf(1.0f, level), bound);
+ const float mip_rbound = 1 / mip_bound;
+
+ // convert to nearest grid position
+ const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
+ const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
+ const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
+
+ const uint32_t index = level * H3 + __morton3D(nx, ny, nz);
+ const bool occ = grid[index / 8] & (1 << (index % 8));
+
+ // if occpuied, advance a small step, and write to output
+ //if (n == 0) printf("t=%f density=%f vs thresh=%f step=%d\n", t, density, density_thresh, num_steps);
+
+ if (occ) {
+ num_steps++;
+ t += dt;
+ // else, skip a large step (basically skip a voxel grid)
+ } else {
+ // calc distance to next voxel
+ const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx;
+ const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy;
+ const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz;
+
+ const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
+ // step until next voxel
+ do {
+ t += clamp(t * dt_gamma, dt_min, dt_max);
+ } while (t < tt);
+ }
+ }
+
+ //printf("[n=%d] num_steps=%d, near=%f, far=%f, dt=%f, max_steps=%f\n", n, num_steps, near, far, dt_min, (far - near) / dt_min);
+
+ // second pass: really locate and write points & dirs
+ uint32_t point_index = atomicAdd(counter, num_steps);
+ uint32_t ray_index = atomicAdd(counter + 1, 1);
+
+ //printf("[n=%d] num_steps=%d, point_index=%d, ray_index=%d\n", n, num_steps, point_index, ray_index);
+
+ // write rays
+ rays[ray_index * 3] = n;
+ rays[ray_index * 3 + 1] = point_index;
+ rays[ray_index * 3 + 2] = num_steps;
+
+ if (num_steps == 0) return;
+ if (point_index + num_steps > M) return;
+
+ xyzs += point_index * 3;
+ dirs += point_index * 3;
+ deltas += point_index * 2;
+
+ t = t0;
+ uint32_t step = 0;
+
+ while (t < far && step < num_steps) {
+ // current point
+ const float x = clamp(ox + t * dx, -bound, bound);
+ const float y = clamp(oy + t * dy, -bound, bound);
+ const float z = clamp(oz + t * dz, -bound, bound);
+
+ const float dt = clamp(t * dt_gamma, dt_min, dt_max);
+
+ // get mip level
+ const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]
+
+ const float mip_bound = fminf(scalbnf(1.0f, level), bound);
+ const float mip_rbound = 1 / mip_bound;
+
+ // convert to nearest grid position
+ const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
+ const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
+ const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
+
+ // query grid
+ const uint32_t index = level * H3 + __morton3D(nx, ny, nz);
+ const bool occ = grid[index / 8] & (1 << (index % 8));
+
+ // if occpuied, advance a small step, and write to output
+ if (occ) {
+ // write step
+ xyzs[0] = x;
+ xyzs[1] = y;
+ xyzs[2] = z;
+ dirs[0] = dx;
+ dirs[1] = dy;
+ dirs[2] = dz;
+ t += dt;
+ deltas[0] = dt;
+ deltas[1] = t; // used to calc depth
+ xyzs += 3;
+ dirs += 3;
+ deltas += 2;
+ step++;
+ // else, skip a large step (basically skip a voxel grid)
+ } else {
+ // calc distance to next voxel
+ const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx;
+ const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy;
+ const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz;
+ const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
+ // step until next voxel
+ do {
+ t += clamp(t * dt_gamma, dt_min, dt_max);
+ } while (t < tt);
+ }
+ }
+}
+
+void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, at::Tensor noises) {
+
+ static constexpr uint32_t N_THREAD = 128;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ rays_o.scalar_type(), "march_rays_train", ([&] {
+ kernel_march_rays_train<<>>(rays_o.data_ptr(), rays_d.data_ptr(), grid.data_ptr(), bound, dt_gamma, max_steps, N, C, H, M, nears.data_ptr(), fars.data_ptr(), xyzs.data_ptr(), dirs.data_ptr(), deltas.data_ptr(), rays.data_ptr(), counter.data_ptr(), noises.data_ptr());
+ }));
+}
+
+
+// grad_xyzs/dirs: [M, 3]
+// rays: [N, 3]
+// deltas: [M, 2]
+// grad_rays_o/d: [N, 3]
+template
+__global__ void kernel_march_rays_train_backward(
+ const scalar_t * __restrict__ grad_xyzs,
+ const scalar_t * __restrict__ grad_dirs,
+ const int * __restrict__ rays,
+ const scalar_t * __restrict__ deltas,
+ const uint32_t N, const uint32_t M,
+ scalar_t * grad_rays_o,
+ scalar_t * grad_rays_d
+) {
+ // parallel per ray
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
+ if (n >= N) return;
+
+ // locate
+ grad_rays_o += n * 3;
+ grad_rays_d += n * 3;
+
+ uint32_t index = rays[n * 3];
+ uint32_t offset = rays[n * 3 + 1];
+ uint32_t num_steps = rays[n * 3 + 2];
+
+ // empty ray, or ray that exceed max step count.
+ if (num_steps == 0 || offset + num_steps > M) return;
+
+ grad_xyzs += offset * 3;
+ grad_dirs += offset * 3;
+ deltas += offset * 2;
+
+ // accumulate
+ uint32_t step = 0;
+ while (step < num_steps) {
+
+ grad_rays_o[0] += grad_xyzs[0];
+ grad_rays_o[1] += grad_xyzs[1];
+ grad_rays_o[2] += grad_xyzs[2];
+
+ grad_rays_d[0] += grad_xyzs[0] * deltas[1] + grad_dirs[0];
+ grad_rays_d[1] += grad_xyzs[1] * deltas[1] + grad_dirs[1];
+ grad_rays_d[2] += grad_xyzs[2] * deltas[1] + grad_dirs[2];
+
+ // locate
+ grad_xyzs += 3;
+ grad_dirs += 3;
+ deltas += 2;
+
+ step++;
+ }
+}
+
+void march_rays_train_backward(const at::Tensor grad_xyzs, const at::Tensor grad_dirs, const at::Tensor rays, const at::Tensor deltas, const uint32_t N, const uint32_t M, at::Tensor grad_rays_o, at::Tensor grad_rays_d) {
+
+ static constexpr uint32_t N_THREAD = 128;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ grad_xyzs.scalar_type(), "march_rays_train_backward", ([&] {
+ kernel_march_rays_train_backward<<>>(grad_xyzs.data_ptr(), grad_dirs.data_ptr(), rays.data_ptr(), deltas.data_ptr(), N, M, grad_rays_o.data_ptr(), grad_rays_d.data_ptr());
+ }));
+}
+
+
+// sigmas: [M]
+// rgbs: [M, 3]
+// deltas: [M, 2]
+// rays: [N, 3], idx, offset, num_steps
+// weights_sum: [N], final pixel alpha
+// depth: [N,]
+// image: [N, 3]
+template
+__global__ void kernel_composite_rays_train_forward(
+ const scalar_t * __restrict__ sigmas,
+ const scalar_t * __restrict__ rgbs,
+ const scalar_t * __restrict__ ambient,
+ const scalar_t * __restrict__ deltas,
+ const int * __restrict__ rays,
+ const uint32_t M, const uint32_t N, const float T_thresh,
+ scalar_t * weights_sum,
+ scalar_t * ambient_sum,
+ scalar_t * depth,
+ scalar_t * image
+) {
+ // parallel per ray
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
+ if (n >= N) return;
+
+ // locate
+ uint32_t index = rays[n * 3];
+ uint32_t offset = rays[n * 3 + 1];
+ uint32_t num_steps = rays[n * 3 + 2];
+
+ // empty ray, or ray that exceed max step count.
+ if (num_steps == 0 || offset + num_steps > M) {
+ weights_sum[index] = 0;
+ ambient_sum[index] = 0;
+ depth[index] = 0;
+ image[index * 3] = 0;
+ image[index * 3 + 1] = 0;
+ image[index * 3 + 2] = 0;
+ return;
+ }
+
+ sigmas += offset;
+ rgbs += offset * 3;
+ ambient += offset;
+ deltas += offset * 2;
+
+ // accumulate
+ uint32_t step = 0;
+
+ scalar_t T = 1.0f;
+ scalar_t r = 0, g = 0, b = 0, ws = 0, d = 0, amb = 0;
+
+ while (step < num_steps) {
+
+ const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);
+ const scalar_t weight = alpha * T;
+
+ r += weight * rgbs[0];
+ g += weight * rgbs[1];
+ b += weight * rgbs[2];
+
+ d += weight * deltas[1];
+
+ ws += weight;
+
+ amb += ambient[0];
+
+ T *= 1.0f - alpha;
+
+ // minimal remained transmittence
+ if (T < T_thresh) break;
+
+ //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d);
+
+ // locate
+ sigmas++;
+ rgbs += 3;
+ ambient++;
+ deltas += 2;
+
+ step++;
+ }
+
+ //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d);
+
+ // write
+ weights_sum[index] = ws; // weights_sum
+ ambient_sum[index] = amb;
+ depth[index] = d;
+ image[index * 3] = r;
+ image[index * 3 + 1] = g;
+ image[index * 3 + 2] = b;
+}
+
+
+void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor ambient_sum, at::Tensor depth, at::Tensor image) {
+
+ static constexpr uint32_t N_THREAD = 128;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ sigmas.scalar_type(), "composite_rays_train_forward", ([&] {
+ kernel_composite_rays_train_forward<<>>(sigmas.data_ptr(), rgbs.data_ptr(), ambient.data_ptr(), deltas.data_ptr(), rays.data_ptr(), M, N, T_thresh, weights_sum.data_ptr(), ambient_sum.data_ptr(), depth.data_ptr(), image.data_ptr());
+ }));
+}
+
+
+// grad_weights_sum: [N,]
+// grad: [N, 3]
+// sigmas: [M]
+// rgbs: [M, 3]
+// deltas: [M, 2]
+// rays: [N, 3], idx, offset, num_steps
+// weights_sum: [N,], weights_sum here
+// image: [N, 3]
+// grad_sigmas: [M]
+// grad_rgbs: [M, 3]
+template
+__global__ void kernel_composite_rays_train_backward(
+ const scalar_t * __restrict__ grad_weights_sum,
+ const scalar_t * __restrict__ grad_ambient_sum,
+ const scalar_t * __restrict__ grad_image,
+ const scalar_t * __restrict__ sigmas,
+ const scalar_t * __restrict__ rgbs,
+ const scalar_t * __restrict__ ambient,
+ const scalar_t * __restrict__ deltas,
+ const int * __restrict__ rays,
+ const scalar_t * __restrict__ weights_sum,
+ const scalar_t * __restrict__ ambient_sum,
+ const scalar_t * __restrict__ image,
+ const uint32_t M, const uint32_t N, const float T_thresh,
+ scalar_t * grad_sigmas,
+ scalar_t * grad_rgbs,
+ scalar_t * grad_ambient
+) {
+ // parallel per ray
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
+ if (n >= N) return;
+
+ // locate
+ uint32_t index = rays[n * 3];
+ uint32_t offset = rays[n * 3 + 1];
+ uint32_t num_steps = rays[n * 3 + 2];
+
+ if (num_steps == 0 || offset + num_steps > M) return;
+
+ grad_weights_sum += index;
+ grad_ambient_sum += index;
+ grad_image += index * 3;
+ weights_sum += index;
+ ambient_sum += index;
+ image += index * 3;
+
+ sigmas += offset;
+ rgbs += offset * 3;
+ ambient += offset;
+ deltas += offset * 2;
+
+ grad_sigmas += offset;
+ grad_rgbs += offset * 3;
+ grad_ambient += offset;
+
+ // accumulate
+ uint32_t step = 0;
+
+ scalar_t T = 1.0f;
+ const scalar_t r_final = image[0], g_final = image[1], b_final = image[2], ws_final = weights_sum[0];
+ scalar_t r = 0, g = 0, b = 0, ws = 0;
+
+ while (step < num_steps) {
+
+ const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);
+ const scalar_t weight = alpha * T;
+
+ r += weight * rgbs[0];
+ g += weight * rgbs[1];
+ b += weight * rgbs[2];
+ // amb += weight * ambient[0];
+ ws += weight;
+
+ T *= 1.0f - alpha;
+
+ // check https://note.kiui.moe/others/nerf_gradient/ for the gradient calculation.
+ // write grad_rgbs
+ grad_rgbs[0] = grad_image[0] * weight;
+ grad_rgbs[1] = grad_image[1] * weight;
+ grad_rgbs[2] = grad_image[2] * weight;
+
+ // write grad_ambient
+ grad_ambient[0] = grad_ambient_sum[0];
+
+ // write grad_sigmas
+ grad_sigmas[0] = deltas[0] * (
+ grad_image[0] * (T * rgbs[0] - (r_final - r)) +
+ grad_image[1] * (T * rgbs[1] - (g_final - g)) +
+ grad_image[2] * (T * rgbs[2] - (b_final - b)) +
+ // grad_ambient_sum[0] * (T * ambient[0] - (amb_final - amb)) +
+ grad_weights_sum[0] * (1 - ws_final)
+ );
+
+ //printf("[n=%d] num_steps=%d, T=%f, grad_sigmas=%f, r_final=%f, r=%f\n", n, step, T, grad_sigmas[0], r_final, r);
+ // minimal remained transmittence
+ if (T < T_thresh) break;
+
+ // locate
+ sigmas++;
+ rgbs += 3;
+ // ambient++;
+ deltas += 2;
+ grad_sigmas++;
+ grad_rgbs += 3;
+ grad_ambient++;
+
+ step++;
+ }
+}
+
+
+void composite_rays_train_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_ambient_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor ambient_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs, at::Tensor grad_ambient) {
+
+ static constexpr uint32_t N_THREAD = 128;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ grad_image.scalar_type(), "composite_rays_train_backward", ([&] {
+ kernel_composite_rays_train_backward<<>>(grad_weights_sum.data_ptr(), grad_ambient_sum.data_ptr(), grad_image.data_ptr(), sigmas.data_ptr(), rgbs.data_ptr(), ambient.data_ptr(), deltas.data_ptr(), rays.data_ptr(), weights_sum.data_ptr(), ambient_sum.data_ptr(), image.data_ptr(), M, N, T_thresh, grad_sigmas.data_ptr(), grad_rgbs.data_ptr(), grad_ambient.data_ptr());
+ }));
+}
+
+
+////////////////////////////////////////////////////
+///////////// infernce /////////////
+////////////////////////////////////////////////////
+
+template
+__global__ void kernel_march_rays(
+ const uint32_t n_alive,
+ const uint32_t n_step,
+ const int* __restrict__ rays_alive,
+ const scalar_t* __restrict__ rays_t,
+ const scalar_t* __restrict__ rays_o,
+ const scalar_t* __restrict__ rays_d,
+ const float bound,
+ const float dt_gamma, const uint32_t max_steps,
+ const uint32_t C, const uint32_t H,
+ const uint8_t * __restrict__ grid,
+ const scalar_t* __restrict__ nears,
+ const scalar_t* __restrict__ fars,
+ scalar_t* xyzs, scalar_t* dirs, scalar_t* deltas,
+ const scalar_t* __restrict__ noises
+) {
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
+ if (n >= n_alive) return;
+
+ const int index = rays_alive[n]; // ray id
+ const float noise = noises[n];
+
+ // locate
+ rays_o += index * 3;
+ rays_d += index * 3;
+ xyzs += n * n_step * 3;
+ dirs += n * n_step * 3;
+ deltas += n * n_step * 2;
+
+ const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
+ const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
+ const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
+ const float rH = 1 / (float)H;
+ const float H3 = H * H * H;
+
+ float t = rays_t[index]; // current ray's t
+ const float near = nears[index], far = fars[index];
+
+ const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H;
+ const float dt_min = fminf(dt_max, 2 * SQRT3() / max_steps);
+
+ // march for n_step steps, record points
+ uint32_t step = 0;
+
+ // introduce some randomness
+ t += clamp(t * dt_gamma, dt_min, dt_max) * noise;
+
+ while (t < far && step < n_step) {
+ // current point
+ const float x = clamp(ox + t * dx, -bound, bound);
+ const float y = clamp(oy + t * dy, -bound, bound);
+ const float z = clamp(oz + t * dz, -bound, bound);
+
+ const float dt = clamp(t * dt_gamma, dt_min, dt_max);
+
+ // get mip level
+ const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]
+
+ const float mip_bound = fminf(scalbnf(1, level), bound);
+ const float mip_rbound = 1 / mip_bound;
+
+ // convert to nearest grid position
+ const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
+ const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
+ const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
+
+ const uint32_t index = level * H3 + __morton3D(nx, ny, nz);
+ const bool occ = grid[index / 8] & (1 << (index % 8));
+
+ // if occpuied, advance a small step, and write to output
+ if (occ) {
+ // write step
+ xyzs[0] = x;
+ xyzs[1] = y;
+ xyzs[2] = z;
+ dirs[0] = dx;
+ dirs[1] = dy;
+ dirs[2] = dz;
+ // calc dt
+ t += dt;
+ deltas[0] = dt;
+ deltas[1] = t; // used to calc depth
+ // step
+ xyzs += 3;
+ dirs += 3;
+ deltas += 2;
+ step++;
+
+ // else, skip a large step (basically skip a voxel grid)
+ } else {
+ // calc distance to next voxel
+ const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx;
+ const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy;
+ const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz;
+ const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
+ // step until next voxel
+ do {
+ t += clamp(t * dt_gamma, dt_min, dt_max);
+ } while (t < tt);
+ }
+ }
+}
+
+
+void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor near, const at::Tensor far, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor noises) {
+ static constexpr uint32_t N_THREAD = 128;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ rays_o.scalar_type(), "march_rays", ([&] {
+ kernel_march_rays<<>>(n_alive, n_step, rays_alive.data_ptr(), rays_t.data_ptr(), rays_o.data_ptr(), rays_d.data_ptr(), bound, dt_gamma, max_steps, C, H, grid.data_ptr(), near.data_ptr(), far.data_ptr(), xyzs.data_ptr(), dirs.data_ptr(), deltas.data_ptr(), noises.data_ptr());
+ }));
+}
+
+
+template
+__global__ void kernel_composite_rays(
+ const uint32_t n_alive,
+ const uint32_t n_step,
+ const float T_thresh,
+ int* rays_alive,
+ scalar_t* rays_t,
+ const scalar_t* __restrict__ sigmas,
+ const scalar_t* __restrict__ rgbs,
+ const scalar_t* __restrict__ deltas,
+ scalar_t* weights_sum, scalar_t* depth, scalar_t* image
+) {
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
+ if (n >= n_alive) return;
+
+ const int index = rays_alive[n]; // ray id
+
+ // locate
+ sigmas += n * n_step;
+ rgbs += n * n_step * 3;
+ deltas += n * n_step * 2;
+
+ rays_t += index;
+ weights_sum += index;
+ depth += index;
+ image += index * 3;
+
+ scalar_t t = rays_t[0]; // current ray's t
+
+ scalar_t weight_sum = weights_sum[0];
+ scalar_t d = depth[0];
+ scalar_t r = image[0];
+ scalar_t g = image[1];
+ scalar_t b = image[2];
+
+ // accumulate
+ uint32_t step = 0;
+ while (step < n_step) {
+
+ // ray is terminated if delta == 0
+ if (deltas[0] == 0) break;
+
+ const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);
+
+ /*
+ T_0 = 1; T_i = \prod_{j=0}^{i-1} (1 - alpha_j)
+ w_i = alpha_i * T_i
+ -->
+ T_i = 1 - \sum_{j=0}^{i-1} w_j
+ */
+ const scalar_t T = 1 - weight_sum;
+ const scalar_t weight = alpha * T;
+ weight_sum += weight;
+
+ t = deltas[1];
+ d += weight * t;
+ r += weight * rgbs[0];
+ g += weight * rgbs[1];
+ b += weight * rgbs[2];
+
+ //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d);
+
+ // ray is terminated if T is too small
+ // use a larger bound to further accelerate inference
+ if (T < T_thresh) break;
+
+ // locate
+ sigmas++;
+ rgbs += 3;
+ deltas += 2;
+ step++;
+ }
+
+ //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d);
+
+ // rays_alive = -1 means ray is terminated early.
+ if (step < n_step) {
+ rays_alive[n] = -1;
+ } else {
+ rays_t[0] = t;
+ }
+
+ weights_sum[0] = weight_sum; // this is the thing I needed!
+ depth[0] = d;
+ image[0] = r;
+ image[1] = g;
+ image[2] = b;
+}
+
+
+void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights, at::Tensor depth, at::Tensor image) {
+ static constexpr uint32_t N_THREAD = 128;
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ image.scalar_type(), "composite_rays", ([&] {
+ kernel_composite_rays<<>>(n_alive, n_step, T_thresh, rays_alive.data_ptr(), rays_t.data_ptr(), sigmas.data_ptr(), rgbs.data_ptr(), deltas.data_ptr(), weights.data_ptr(), depth.data_ptr(), image.data_ptr());
+ }));
+}
+
+
+
+template
+__global__ void kernel_composite_rays_ambient(
+ const uint32_t n_alive,
+ const uint32_t n_step,
+ const float T_thresh,
+ int* rays_alive,
+ scalar_t* rays_t,
+ const scalar_t* __restrict__ sigmas,
+ const scalar_t* __restrict__ rgbs,
+ const scalar_t* __restrict__ deltas,
+ const scalar_t* __restrict__ ambients,
+ scalar_t* weights_sum, scalar_t* depth, scalar_t* image, scalar_t* ambient_sum
+) {
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
+ if (n >= n_alive) return;
+
+ const int index = rays_alive[n]; // ray id
+
+ // locate
+ sigmas += n * n_step;
+ rgbs += n * n_step * 3;
+ deltas += n * n_step * 2;
+ ambients += n * n_step;
+
+ rays_t += index;
+ weights_sum += index;
+ depth += index;
+ image += index * 3;
+ ambient_sum += index;
+
+ scalar_t t = rays_t[0]; // current ray's t
+
+ scalar_t weight_sum = weights_sum[0];
+ scalar_t d = depth[0];
+ scalar_t r = image[0];
+ scalar_t g = image[1];
+ scalar_t b = image[2];
+ scalar_t a = ambient_sum[0];
+
+ // accumulate
+ uint32_t step = 0;
+ while (step < n_step) {
+
+ // ray is terminated if delta == 0
+ if (deltas[0] == 0) break;
+
+ const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);
+
+ /*
+ T_0 = 1; T_i = \prod_{j=0}^{i-1} (1 - alpha_j)
+ w_i = alpha_i * T_i
+ -->
+ T_i = 1 - \sum_{j=0}^{i-1} w_j
+ */
+ const scalar_t T = 1 - weight_sum;
+ const scalar_t weight = alpha * T;
+ weight_sum += weight;
+
+ t = deltas[1];
+ d += weight * t;
+ r += weight * rgbs[0];
+ g += weight * rgbs[1];
+ b += weight * rgbs[2];
+ a += ambients[0];
+
+ //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d);
+
+ // ray is terminated if T is too small
+ // use a larger bound to further accelerate inference
+ if (T < T_thresh) break;
+
+ // locate
+ sigmas++;
+ rgbs += 3;
+ deltas += 2;
+ step++;
+ ambients++;
+ }
+
+ //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d);
+
+ // rays_alive = -1 means ray is terminated early.
+ if (step < n_step) {
+ rays_alive[n] = -1;
+ } else {
+ rays_t[0] = t;
+ }
+
+ weights_sum[0] = weight_sum; // this is the thing I needed!
+ depth[0] = d;
+ image[0] = r;
+ image[1] = g;
+ image[2] = b;
+ ambient_sum[0] = a;
+}
+
+
+void composite_rays_ambient(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor ambients, at::Tensor weights, at::Tensor depth, at::Tensor image, at::Tensor ambient_sum) {
+ static constexpr uint32_t N_THREAD = 128;
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ image.scalar_type(), "composite_rays_ambient", ([&] {
+ kernel_composite_rays_ambient<<>>(n_alive, n_step, T_thresh, rays_alive.data_ptr(), rays_t.data_ptr(), sigmas.data_ptr(), rgbs.data_ptr(), deltas.data_ptr(), ambients.data_ptr(), weights.data_ptr(), depth.data_ptr(), image.data_ptr(), ambient_sum.data_ptr());
+ }));
+}
+
+
+
+
+
+
+// -------------------------------- sigma ambient -----------------------------
+
+// sigmas: [M]
+// rgbs: [M, 3]
+// deltas: [M, 2]
+// rays: [N, 3], idx, offset, num_steps
+// weights_sum: [N], final pixel alpha
+// depth: [N,]
+// image: [N, 3]
+template
+__global__ void kernel_composite_rays_train_sigma_forward(
+ const scalar_t * __restrict__ sigmas,
+ const scalar_t * __restrict__ rgbs,
+ const scalar_t * __restrict__ ambient,
+ const scalar_t * __restrict__ deltas,
+ const int * __restrict__ rays,
+ const uint32_t M, const uint32_t N, const float T_thresh,
+ scalar_t * weights_sum,
+ scalar_t * ambient_sum,
+ scalar_t * depth,
+ scalar_t * image
+) {
+ // parallel per ray
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
+ if (n >= N) return;
+
+ // locate
+ uint32_t index = rays[n * 3];
+ uint32_t offset = rays[n * 3 + 1];
+ uint32_t num_steps = rays[n * 3 + 2];
+
+ // empty ray, or ray that exceed max step count.
+ if (num_steps == 0 || offset + num_steps > M) {
+ weights_sum[index] = 0;
+ ambient_sum[index] = 0;
+ depth[index] = 0;
+ image[index * 3] = 0;
+ image[index * 3 + 1] = 0;
+ image[index * 3 + 2] = 0;
+ return;
+ }
+
+ sigmas += offset;
+ rgbs += offset * 3;
+ ambient += offset;
+ deltas += offset * 2;
+
+ // accumulate
+ uint32_t step = 0;
+
+ scalar_t T = 1.0f;
+ scalar_t r = 0, g = 0, b = 0, ws = 0, d = 0, amb = 0;
+
+ while (step < num_steps) {
+
+ const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);
+ const scalar_t weight = alpha * T;
+
+ r += weight * rgbs[0];
+ g += weight * rgbs[1];
+ b += weight * rgbs[2];
+
+ d += weight * deltas[1];
+
+ ws += weight;
+
+ amb += weight * ambient[0];
+
+ T *= 1.0f - alpha;
+
+ // minimal remained transmittence
+ if (T < T_thresh) break;
+
+ //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d);
+
+ // locate
+ sigmas++;
+ rgbs += 3;
+ ambient++;
+ deltas += 2;
+
+ step++;
+ }
+
+ //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d);
+
+ // write
+ weights_sum[index] = ws; // weights_sum
+ ambient_sum[index] = amb;
+ depth[index] = d;
+ image[index * 3] = r;
+ image[index * 3 + 1] = g;
+ image[index * 3 + 2] = b;
+}
+
+
+void composite_rays_train_sigma_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor ambient_sum, at::Tensor depth, at::Tensor image) {
+
+ static constexpr uint32_t N_THREAD = 128;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ sigmas.scalar_type(), "composite_rays_train_sigma_forward", ([&] {
+ kernel_composite_rays_train_sigma_forward<<>>(sigmas.data_ptr(), rgbs.data_ptr(), ambient.data_ptr(), deltas.data_ptr(), rays.data_ptr(), M, N, T_thresh, weights_sum.data_ptr(), ambient_sum.data_ptr(), depth.data_ptr(), image.data_ptr());
+ }));
+}
+
+
+// grad_weights_sum: [N,]
+// grad: [N, 3]
+// sigmas: [M]
+// rgbs: [M, 3]
+// deltas: [M, 2]
+// rays: [N, 3], idx, offset, num_steps
+// weights_sum: [N,], weights_sum here
+// image: [N, 3]
+// grad_sigmas: [M]
+// grad_rgbs: [M, 3]
+template
+__global__ void kernel_composite_rays_train_sigma_backward(
+ const scalar_t * __restrict__ grad_weights_sum,
+ const scalar_t * __restrict__ grad_ambient_sum,
+ const scalar_t * __restrict__ grad_image,
+ const scalar_t * __restrict__ sigmas,
+ const scalar_t * __restrict__ rgbs,
+ const scalar_t * __restrict__ ambient,
+ const scalar_t * __restrict__ deltas,
+ const int * __restrict__ rays,
+ const scalar_t * __restrict__ weights_sum,
+ const scalar_t * __restrict__ ambient_sum,
+ const scalar_t * __restrict__ image,
+ const uint32_t M, const uint32_t N, const float T_thresh,
+ scalar_t * grad_sigmas,
+ scalar_t * grad_rgbs,
+ scalar_t * grad_ambient
+) {
+ // parallel per ray
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
+ if (n >= N) return;
+
+ // locate
+ uint32_t index = rays[n * 3];
+ uint32_t offset = rays[n * 3 + 1];
+ uint32_t num_steps = rays[n * 3 + 2];
+
+ if (num_steps == 0 || offset + num_steps > M) return;
+
+ grad_weights_sum += index;
+ grad_ambient_sum += index;
+ grad_image += index * 3;
+ weights_sum += index;
+ ambient_sum += index;
+ image += index * 3;
+
+ sigmas += offset;
+ rgbs += offset * 3;
+ ambient += offset;
+ deltas += offset * 2;
+
+ grad_sigmas += offset;
+ grad_rgbs += offset * 3;
+ grad_ambient += offset;
+
+ // accumulate
+ uint32_t step = 0;
+
+ scalar_t T = 1.0f;
+ const scalar_t r_final = image[0], g_final = image[1], b_final = image[2], ws_final = weights_sum[0], amb_final = ambient_sum[0];
+ scalar_t r = 0, g = 0, b = 0, ws = 0, amb = 0;
+
+ while (step < num_steps) {
+
+ const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);
+ const scalar_t weight = alpha * T;
+
+ r += weight * rgbs[0];
+ g += weight * rgbs[1];
+ b += weight * rgbs[2];
+ amb += weight * ambient[0];
+ ws += weight;
+
+ T *= 1.0f - alpha;
+
+ // check https://note.kiui.moe/others/nerf_gradient/ for the gradient calculation.
+ // write grad_rgbs
+ grad_rgbs[0] = grad_image[0] * weight;
+ grad_rgbs[1] = grad_image[1] * weight;
+ grad_rgbs[2] = grad_image[2] * weight;
+
+ // write grad_ambient
+ grad_ambient[0] = grad_ambient_sum[0] * weight;
+
+ // write grad_sigmas
+ grad_sigmas[0] = deltas[0] * (
+ grad_image[0] * (T * rgbs[0] - (r_final - r)) +
+ grad_image[1] * (T * rgbs[1] - (g_final - g)) +
+ grad_image[2] * (T * rgbs[2] - (b_final - b)) +
+ grad_ambient_sum[0] * (T * ambient[0] - (amb_final - amb)) +
+ grad_weights_sum[0] * (1 - ws_final)
+ );
+
+ //printf("[n=%d] num_steps=%d, T=%f, grad_sigmas=%f, r_final=%f, r=%f\n", n, step, T, grad_sigmas[0], r_final, r);
+ // minimal remained transmittence
+ if (T < T_thresh) break;
+
+ // locate
+ sigmas++;
+ rgbs += 3;
+ ambient++;
+ deltas += 2;
+ grad_sigmas++;
+ grad_rgbs += 3;
+ grad_ambient++;
+
+ step++;
+ }
+}
+
+
+void composite_rays_train_sigma_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_ambient_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor ambient_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs, at::Tensor grad_ambient) {
+
+ static constexpr uint32_t N_THREAD = 128;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ grad_image.scalar_type(), "composite_rays_train_sigma_backward", ([&] {
+ kernel_composite_rays_train_sigma_backward<<>>(grad_weights_sum.data_ptr(), grad_ambient_sum.data_ptr(), grad_image.data_ptr(), sigmas.data_ptr(), rgbs.data_ptr(), ambient.data_ptr(), deltas.data_ptr(), rays.data_ptr(), weights_sum.data_ptr(), ambient_sum.data_ptr(), image.data_ptr(), M, N, T_thresh, grad_sigmas.data_ptr(), grad_rgbs.data_ptr(), grad_ambient.data_ptr());
+ }));
+}
+
+
+////////////////////////////////////////////////////
+///////////// infernce /////////////
+////////////////////////////////////////////////////
+
+
+template
+__global__ void kernel_composite_rays_ambient_sigma(
+ const uint32_t n_alive,
+ const uint32_t n_step,
+ const float T_thresh,
+ int* rays_alive,
+ scalar_t* rays_t,
+ const scalar_t* __restrict__ sigmas,
+ const scalar_t* __restrict__ rgbs,
+ const scalar_t* __restrict__ deltas,
+ const scalar_t* __restrict__ ambients,
+ scalar_t* weights_sum, scalar_t* depth, scalar_t* image, scalar_t* ambient_sum
+) {
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
+ if (n >= n_alive) return;
+
+ const int index = rays_alive[n]; // ray id
+
+ // locate
+ sigmas += n * n_step;
+ rgbs += n * n_step * 3;
+ deltas += n * n_step * 2;
+ ambients += n * n_step;
+
+ rays_t += index;
+ weights_sum += index;
+ depth += index;
+ image += index * 3;
+ ambient_sum += index;
+
+ scalar_t t = rays_t[0]; // current ray's t
+
+ scalar_t weight_sum = weights_sum[0];
+ scalar_t d = depth[0];
+ scalar_t r = image[0];
+ scalar_t g = image[1];
+ scalar_t b = image[2];
+ scalar_t a = ambient_sum[0];
+
+ // accumulate
+ uint32_t step = 0;
+ while (step < n_step) {
+
+ // ray is terminated if delta == 0
+ if (deltas[0] == 0) break;
+
+ const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);
+
+ /*
+ T_0 = 1; T_i = \prod_{j=0}^{i-1} (1 - alpha_j)
+ w_i = alpha_i * T_i
+ -->
+ T_i = 1 - \sum_{j=0}^{i-1} w_j
+ */
+ const scalar_t T = 1 - weight_sum;
+ const scalar_t weight = alpha * T;
+ weight_sum += weight;
+
+ t = deltas[1];
+ d += weight * t;
+ r += weight * rgbs[0];
+ g += weight * rgbs[1];
+ b += weight * rgbs[2];
+ a += weight * ambients[0];
+
+ //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d);
+
+ // ray is terminated if T is too small
+ // use a larger bound to further accelerate inference
+ if (T < T_thresh) break;
+
+ // locate
+ sigmas++;
+ rgbs += 3;
+ deltas += 2;
+ step++;
+ ambients++;
+ }
+
+ //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d);
+
+ // rays_alive = -1 means ray is terminated early.
+ if (step < n_step) {
+ rays_alive[n] = -1;
+ } else {
+ rays_t[0] = t;
+ }
+
+ weights_sum[0] = weight_sum; // this is the thing I needed!
+ depth[0] = d;
+ image[0] = r;
+ image[1] = g;
+ image[2] = b;
+ ambient_sum[0] = a;
+}
+
+
+void composite_rays_ambient_sigma(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor ambients, at::Tensor weights, at::Tensor depth, at::Tensor image, at::Tensor ambient_sum) {
+ static constexpr uint32_t N_THREAD = 128;
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ image.scalar_type(), "composite_rays_ambient_sigma", ([&] {
+ kernel_composite_rays_ambient_sigma<<>>(n_alive, n_step, T_thresh, rays_alive.data_ptr(), rays_t.data_ptr