forked from eclipse-t2i/eclipse-inference
-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathutils.py
More file actions
80 lines (64 loc) · 2.77 KB
/
utils.py
File metadata and controls
80 lines (64 loc) · 2.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os
import torch
import math
from PIL import Image
from typing import List
from torch import FloatTensor, LongTensor, Tensor, Size, lerp, zeros_like
from torch.linalg import norm
def get_image_grid(images: List[Image.Image], rows=6, cols=6) -> Image.Image:
    """Paste ``images`` row-major into a single ``rows`` x ``cols`` RGB grid.

    Every grid cell is sized like the first image (all images are assumed to
    share that size). Images are placed left to right, then top to bottom.
    Images beyond ``rows * cols`` get coordinates outside the canvas and so do
    not appear in the result.

    Args:
        images: Images to lay out; must be non-empty.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGB image of size ``(cols * width, rows * height)``.

    Raises:
        ValueError: if ``images`` is empty.
    """
    if not images:
        # Fail with a clear message instead of an IndexError from images[0].
        raise ValueError("get_image_grid requires at least one image")
    width, height = images[0].size
    grid_image = Image.new("RGB", (cols * width, rows * height))
    for index, img in enumerate(images):
        row, col = divmod(index, cols)
        grid_image.paste(img, (col * width, row * height))
    return grid_image
def linear_interpolation(v1, v2, t=0.0):
    """Blend ``v1`` toward ``v2`` by the weight ``t``.

    Computes ``v1 * (1 - t) + v2 * t`` for any operands that support ``*``
    with a scalar and ``+`` with each other. ``t=0`` puts all weight on
    ``v1``, ``t=1`` puts all weight on ``v2``, and values outside ``[0, 1]``
    extrapolate.
    """
    weight_v1 = 1 - t
    return v1 * weight_v1 + v2 * t
def slerp(v0: FloatTensor, v1: FloatTensor, t: "float | Tensor", DOT_THRESHOLD=0.9995):
    """Spherically interpolate between ``v0`` and ``v1`` along the last dimension.

    Each slice along the last dimension is treated as one vector. Rows whose
    directions are nearly colinear (``|cos(angle)| > DOT_THRESHOLD``, which
    includes nearly *opposite* directions) fall back to plain linear
    interpolation. So do rows whose cosine is NaN (e.g. a zero-norm vector).
    In both cases the slerp weights are unstable or undefined. Note that for
    nearly opposite vectors the lerp midpoint passes close to the origin.

    Args:
        v0: Starting vector(s).
        v1: Final vector(s); must have exactly the same shape as ``v0``.
        t: Interpolation weight, normally in ``[0.0, 1.0]``. A Python float
            or a Tensor. A Tensor ``t`` is multiplied against a per-row angle
            of shape ``(..., 1)``, so it typically carries a trailing size-1
            dim. Leading dims of ``t`` beyond ``v0.dim()`` become extra batch
            dims of the result.
        DOT_THRESHOLD: Threshold on the absolute cosine above which two
            vectors are treated as colinear. Not recommended to alter this.

    Returns:
        Interpolated tensor, typically of shape ``(*t_batch_dims, *v0.shape)``.

    Raises:
        ValueError: if ``v0`` and ``v1`` have different shapes.
    """
    if v0.shape != v1.shape:
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        raise ValueError(
            f"shapes of v0 and v1 must match, got {tuple(v0.shape)} and {tuple(v1.shape)}"
        )

    # Normalize the vectors to get the directions and angles
    v0_norm: FloatTensor = norm(v0, dim=-1)
    v1_norm: FloatTensor = norm(v1, dim=-1)
    v0_normed: FloatTensor = v0 / v0_norm.unsqueeze(-1)
    v1_normed: FloatTensor = v1 / v1_norm.unsqueeze(-1)

    # Cosine of the angle between each pair of rows
    dot: FloatTensor = (v0_normed * v1_normed).sum(-1)
    dot_mag: FloatTensor = dot.abs()

    # Boolean per-row masks. A NaN dot arises e.g. when a row is all zeros
    # (0/0 during normalization above). |dot| near 1 means (anti)colinear rows,
    # where sin(theta_0) ~ 0 makes the slerp weights blow up. Both use lerp.
    gotta_lerp: Tensor = dot_mag.isnan() | (dot_mag > DOT_THRESHOLD)
    can_slerp: Tensor = ~gotta_lerp

    # Leading dims of a Tensor `t` beyond v0's rank become batch dims of the output.
    t_batch_dim_count: int = max(0, t.dim() - v0.dim()) if isinstance(t, Tensor) else 0
    t_batch_dims: Size = t.shape[:t_batch_dim_count] if isinstance(t, Tensor) else Size([])
    out: FloatTensor = zeros_like(v0.expand(*t_batch_dims, *[-1] * v0.dim()))

    # Each branch is skipped entirely when no row needs it.
    if gotta_lerp.any():
        lerped: FloatTensor = lerp(v0, v1, t)
        out = lerped.where(gotta_lerp.unsqueeze(-1), out)

    if can_slerp.any():
        # Angle between v0 and v1 per row; trailing size-1 dim broadcasts over the vector dim.
        theta_0: FloatTensor = dot.arccos().unsqueeze(-1)
        sin_theta_0: FloatTensor = theta_0.sin()
        # Angle at timestep t
        theta_t: FloatTensor = theta_0 * t
        sin_theta_t: FloatTensor = theta_t.sin()
        # Slerp weights: sin((1 - t) * theta_0) / sin(theta_0) and sin(t * theta_0) / sin(theta_0)
        s0: FloatTensor = (theta_0 - theta_t).sin() / sin_theta_0
        s1: FloatTensor = sin_theta_t / sin_theta_0
        slerped: FloatTensor = s0 * v0 + s1 * v1
        # Rows routed to lerp may hold NaN/inf here; `where` keeps `out` for them.
        out = slerped.where(can_slerp.unsqueeze(-1), out)

    return out