inference.py
import time
import torch
import torch.nn as nn
import argparse
from gptq import *
from modelutils import *
from quant import *
import os
from utils.construct_tff import construct_real_tff
from utils.quant_utils import compute_tffs
from utils.eval_utils import llama_eval
from datetime import datetime
from transformers import LlamaConfig, LlamaForCausalLM, modeling_utils
from tqdm import tqdm
import math
from datautils import get_loaders


def load_llama_from_config(model, default_type=torch.half):
    """Build an uninitialized LlamaForCausalLM from a config, skipping weight init."""
    config = LlamaConfig.from_pretrained(model)

    # Disable PyTorch/transformers weight initialization so the model is
    # constructed quickly; the real weights are loaded from the checkpoint later.
    def noop(*args, **kwargs):
        pass
    torch.nn.init.kaiming_uniform_ = noop
    torch.nn.init.uniform_ = noop
    torch.nn.init.normal_ = noop
    modeling_utils._init_weights = False

    if default_type is not None:
        torch.set_default_dtype(default_type)
    model = LlamaForCausalLM(config)
    return model


def load_quant(model, checkpoint, l_den, tff_redundancy, wbits, eval=True):
    """Load a packed, TFF-quantized LLaMA checkpoint for evaluation."""
    model = load_llama_from_config(model, default_type=torch.half)
    torch.set_default_dtype(torch.float)
    if eval:
        model = model.eval()

    # Swap every linear layer except the output head for its quantized counterpart.
    layers = find_layers(model)
    for name in ['lm_head']:
        if name in layers:
            del layers[name]
    make_quant3(model, layers, l_den, tff_redundancy, wbits)
    del layers

    print('Loading model ...')
    model.load_state_dict(torch.load(checkpoint))
    model.seqlen = 2048
    print('Done.')
    return model


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'model', type=str,
        help='path to the directory where the model config file is stored'
    )
    parser.add_argument(
        'saved_model', type=str,
        help='path to the directory where the saved model is stored'
    )
    parser.add_argument(
        '--l_den', type=int, default=16,
        help='denominator to be used for L_tff'
    )
    parser.add_argument(
        '--tff_redundancy', type=float, default=1.0,
        help='redundancy in the TFF representations'
    )
    parser.add_argument(
        '--wbits', type=int, default=2,
        help='#bits to use for quantization; use 16 for evaluating the base model.'
    )
    parser.add_argument(
        '--seed', type=int, default=0,
        help='seed for sampling the calibration data.'
    )
    args = parser.parse_args()
    return args
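
# Example invocation (illustrative only; the paths below are placeholders and
# assume `saved_model` contains the `packed_model.ckpt` produced by the
# quantization script):
#
#   python inference.py /path/to/llama-7b-hf /path/to/quantized_model \
#       --l_den 16 --tff_redundancy 1.0 --wbits 2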

if __name__ == '__main__':
    args = parse_args()
    ckpt_path = os.path.join(args.saved_model, 'packed_model.ckpt')

    print('loading model ...')
    model = load_quant(args.model, ckpt_path, args.l_den, args.tff_redundancy, args.wbits, eval=True)

    # Evaluate perplexity on the test split of each dataset.
    datasets = ['wikitext2']
    for dataset in datasets:
        dataloader, testloader = get_loaders(
            dataset, nsamples=2, seed=args.seed, model=args.model, seqlen=model.seqlen
        )
        print(dataset)
        ppl, results = llama_eval(model, testloader, DEV, nsamples=2, verbose=True)