forked from dmu1981/mpt_tracking
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmain.py
More file actions
151 lines (117 loc) · 4.3 KB
/
main.py
File metadata and controls
151 lines (117 loc) · 4.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import pickle
import dummy
from matplotlib import pyplot as plt
from replay import replay
from evaluate import evaluate
from config import filters
from scorestatistics import print_score_statistics
import numpy as np
import argparse
# No need to touch anything below this line
# ---------------------------------------------------
def run_mode(mode):
    """Replay and score every team's filter for a single mode.

    Loads the pickled timeseries for *mode* (``<mode>.pk``), replays each
    team's filter over it and prints ranked RMSE statistics.

    Reads the module-level ``args`` namespace: ``args.all`` selects
    all-series aggregation vs. the single series at ``args.index``;
    ``args.debug`` triggers visualization in single-series mode.

    Returns:
        (scores, ordered): per-team score dict and the ranking produced
        by ``print_score_statistics``.
    """
    file = mode + ".pk"
    # Make sure index is an integer (argparse may deliver it as a string)
    args.index = int(args.index)
    # Load the timeseries data from a pickle file
    # NOTE(review): pickle on untrusted files is unsafe — assumed local data.
    with open(file, "rb") as f:
        timeseries = pickle.load(f)
    # args.all is a boolean; if True --> process all time series,
    # accumulate RMSE per team and compute summarized RMSE & RMSE per run.
    # If args.all is False, validate the index, process the given time
    # series and calculate its RMSE.
    if args.all:
        scores = {}
        for team in filters.keys():
            scores[team] = {"rmse": 0.0, "rmse_per_run": []}
        for series in timeseries:
            results = replay(filters, mode, series)
            intermediate = evaluate(series["targets"], results)
            for team in filters.keys():
                scores[team]["rmse_per_run"].append(intermediate[team]["rmse"])
                scores[team]["rmse"] += intermediate[team]["rmse"]
        # Normalize in a single pass: per-run scores as numpy array,
        # summed RMSE turned into the mean over all runs.
        for team in filters.keys():
            scores[team]["rmse_per_run"] = np.array(scores[team]["rmse_per_run"])
            scores[team]["rmse"] /= len(timeseries)
        ordered = print_score_statistics(scores, args, multiruns=True)
    else:
        if args.index >= len(timeseries) or args.index < 0:
            print("Invalid index")
            exit()
        results = replay(filters, mode, timeseries[args.index])
        scores = evaluate(timeseries[args.index]["targets"], results)
        # Single run: wrap the scalar RMSE so downstream code sees a list.
        for team in filters.keys():
            scores[team]["rmse_per_run"] = [scores[team]["rmse"]]
        ordered = print_score_statistics(scores, args, multiruns=False)
    if args.debug and not args.all:
        # Imported lazily so the plotting stack is only needed in debug mode.
        import visualize

        visualize.draw(timeseries, args, filters, results, scores)
    return scores, ordered
# Every replay mode the tool understands; "all" expands to each mode a
# team has configured. Values are unused placeholders (membership only).
list_of_modes = dict.fromkeys(
    [
        "constantposition",
        "constantvelocity",
        "constantvelocity2",
        "constantturn",
        "randomnoise",
        "angular",
        "all",
    ],
    0,
)
# Sanity check filters: every team must declare an RGB color, and any mode
# without a configured filter is silently backed by a DummyFilter so the
# replay can still run for that team.
for team in filters.keys():
    res = filters[team]
    if "color" not in res:
        print(f"Team {team}: You must specify a color in config.py")
        exit()
    # isinstance is the idiomatic type check (type(...) != list would also
    # reject list subclasses).
    if not isinstance(res["color"], list) or len(res["color"]) != 3:
        print(f"Team {team}: Color must be a 3-element list like [1.0, 0.0, 0.0]")
        exit()
    for mode in list_of_modes.keys():
        if mode not in res and mode != "all":
            print(
                f"Team {team}: You did not specify a filter for mode {mode}... replacing with Dummy Filter"
            )
            filters[team][mode] = dummy.DummyFilter(2)
# Create an argument parser.
# --index is parsed as int directly (type=int) so callers get a proper
# integer; downstream int() conversions remain harmless no-ops.
parser = argparse.ArgumentParser(description="MPT Replay Tool")
parser.add_argument("--mode", action="store", help="replay mode to run (or 'all')")
parser.add_argument(
    "--index", action="store", type=int, default=0, help="time series index to replay"
)
parser.add_argument("--debug", action="store_true", help="visualize the single-series run")
parser.add_argument("--all", action="store_true", help="process every time series")
# Parse the command line into the module-level `args` namespace that
# run_mode and the dispatch code below read.
args = parser.parse_args()
# Validate the requested mode before doing any work: it must be present
# and must be one of the known modes.
requested = args.mode
if requested is None:
    print("Must specify mode")
    exit()
if requested not in list_of_modes:
    print("Unknown mode: ", requested)
    exit()
if args.mode == "all":
    # Gather every mode any team has actually configured (all filter keys
    # except the reserved "color" entry).
    configured_modes = set()
    for team in filters.keys():
        for key in filters[team].keys():
            if key != "color":
                configured_modes.add(key)
    modes = list(configured_modes)
    # Overall score: each team accumulates its rank index per mode
    # (0 = best), so a lower total means a better overall placement.
    scores = {}
    for team in filters.keys():
        scores[team] = {"rmse": 0.0, "rmse_per_run": [0]}
    for mode in modes:
        print(mode)
        # Only the ranking matters here; the per-mode scores are discarded.
        # (Previously the inner loop's `score` shadowed this tuple element.)
        _, ordered = run_mode(mode)
        # Do a ranked statistic for the overall score
        for rank, (team, _team_rmse) in enumerate(ordered):
            scores[team]["rmse"] += rank
        print("")
    print("overall")
    print_score_statistics(scores, args, multiruns=False)
else:
    scores, ordered = run_mode(args.mode)