-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathpredict_unified.py
More file actions
84 lines (68 loc) · 2.78 KB
/
predict_unified.py
File metadata and controls
84 lines (68 loc) · 2.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import argparse
import os
from os.path import join
import pandas as pd
from training import predict
from feature_utils import conbine_features, combine_ssim, combine_msssim
from train_model import get_features as get_vmaf_features
from feature_utils import get_ssim_features, get_msssim_features
def remove_extensions(file_name):
    """
    Strip a single known suffix from a file name.

    Suffixes are checked in this order: '.yuv', '.mp4', '_4k'. Only the
    first one that matches the end of ``file_name`` is removed. At most one
    suffix is ever stripped, so 'clip_4k.mp4' becomes 'clip_4k', not 'clip'.
    Names that end in none of these suffixes are returned unchanged.
    """
    for suffix in ('.yuv', '.mp4', '_4k'):
        if file_name.endswith(suffix):
            return file_name[:-len(suffix)]
    return file_name
def main():
    """
    Command-line entry point: load features, predict quality scores, save a CSV.

    Positional arguments:
        feature_path: folder containing the precomputed features.
        output_name:  path of the CSV to write; optional, defaults to
                      'predict.csv'.

    Options:
        --model: VMAF, SSIM or MSSSIM (case-insensitive, default VMAF).

    Raises:
        ValueError: if --model names an unsupported model type.
    """
    argparser = argparse.ArgumentParser(
        description='Unified prediction tool for video quality assessment')
    argparser.add_argument('feature_path', type=str,
                           help='Path to the folder containing the features')
    # argparse only honours `default` on a positional argument when
    # nargs='?'. Without it, output_name was mandatory and the default
    # was never used.
    argparser.add_argument('output_name', type=str, nargs='?',
                           help='Output name', default='predict.csv')
    argparser.add_argument(
        '--model', type=str,
        help='Which model to use. Options: VMAF, SSIM, MSSSIM',
        default='VMAF')
    args = argparser.parse_args()
    feats_pth = args.feature_path
    output_name = args.output_name
    model_type = args.model.upper()

    # Model type -> (feature extractor, SVR model path, scaler path, message).
    # Each extractor returns a (features, nonlinear_features) pair of
    # DataFrames that are joined on their 'video' column below.
    # NOTE(review): model paths are relative to the current working
    # directory, so the script must be run from the repository root.
    models = {
        'VMAF': (get_vmaf_features,
                 'models/svr/model_svr_livehdr.pkl',
                 'models/scaler/model_scaler_livehdr.pkl',
                 "VMAF model used"),
        'SSIM': (get_ssim_features,
                 'models/svr/ssim_svr.pkl',
                 'models/scaler/ssim_scaler.pkl',
                 "SSIM model used"),
        'MSSSIM': (get_msssim_features,
                   'models/svr/msssim_svr.pkl',
                   'models/scaler/msssim_scaler.pkl',
                   "MS-SSIM model used"),
    }
    if model_type not in models:
        raise ValueError(f"Unknown model type: {model_type}. Supported types: VMAF, SSIM, MSSSIM")
    get_features, svr_path, scaler_path, message = models[model_type]

    # Extract the model-specific features and merge them with the
    # nonlinear features into one table keyed by 'video'.
    base_features, nonlinear_features = get_features(feats_pth)
    features = base_features.merge(nonlinear_features, on='video')
    print(message)

    # Convert column names to strings to avoid type issues
    features.columns = features.columns.astype(str)
    # Predict quality scores
    res = predict(features, svr_path, scaler_path)
    # Save results
    res.to_csv(output_name)
    print(f"Results saved to {output_name}")


if __name__ == "__main__":
    main()