-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathagent.py
More file actions
84 lines (57 loc) · 2.33 KB
/
agent.py
File metadata and controls
84 lines (57 loc) · 2.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import numpy as np
import json
import os
class ReinforceAgent:
    """REINFORCE-style bandit agent that learns independent Bernoulli
    selection probabilities over a fixed set of policies.

    Probabilities are loaded from a JSON file (a flat list of floats) at
    construction time and updated online from scalar rewards using an
    exponential-moving-average reward baseline.
    """

    def __init__(self, num_policies, lr=0.2, baseline_decay=0.9,
                 policies_path='policy_probabilities.json'):
        """Load initial per-policy selection probabilities from *policies_path*.

        Args:
            num_policies: number of policies (should match the JSON list
                length — not validated here; TODO confirm with callers).
            lr: gradient-ascent step size.
            baseline_decay: EMA decay for the reward baseline.
            policies_path: path to a JSON file holding a list of floats.

        Raises:
            OSError / json.JSONDecodeError if the file is missing or
            malformed (propagated deliberately: a missing prior is fatal).
        """
        self.num_policies = num_policies
        self.policies_path = policies_path
        with open(self.policies_path, 'r') as f:
            self.probs = np.array(json.load(f))
        self.lr = lr
        self.baseline = 0.0
        self.baseline_decay = baseline_decay
        # First call to update() only records actions (no gradient step):
        # there is no previously recorded action vector to credit yet.
        self.first = True

    def select_actions(self, all_policies):
        """Sample an independent Bernoulli inclusion decision per policy.

        Returns:
            (selected, actions): the chosen policy objects and a parallel
            0/1 indicator list over *all_policies*. Guarantees at least one
            policy is selected by falling back to a uniform random pick.
        """
        actions = []
        selected = []
        for i, policy in enumerate(all_policies):
            if np.random.rand() < self.probs[i]:
                selected.append(policy)
                actions.append(1)
            else:
                actions.append(0)
        # Never return an empty selection: force one random policy in.
        if not selected:
            idx = np.random.randint(len(all_policies))
            selected = [all_policies[idx]]
            actions[idx] = 1
        return selected, actions

    def update(self, selected_policies, reward, selected_actions):
        """One REINFORCE step on the Bernoulli selection probabilities.

        NOTE(review): the gradient pairs the *previous* call's action vector
        (``self.last_actions``) with the *current* reward — presumably a
        deliberate one-step delay because rewards arrive one round late;
        confirm against the training loop. ``selected_policies`` is accepted
        for interface compatibility but is unused.
        """
        if self.first:
            # Bootstrap: just record the actions, no gradient step.
            self.last_actions = np.array(selected_actions)
            self.first = False
            return
        # EMA baseline reduces gradient variance.
        self.baseline = (self.baseline_decay * self.baseline
                         + (1 - self.baseline_decay) * reward)
        advantage = reward - self.baseline
        # d/dp log Bernoulli(a; p) is proportional to (a - p):
        # move probabilities toward actions that beat the baseline.
        grad = self.last_actions - self.probs
        self.probs += self.lr * advantage * grad
        # Keep probabilities strictly inside (0, 1) so exploration never dies.
        self.probs = np.clip(self.probs, 0.01, 0.99)
        self.last_actions = np.array(selected_actions)

    def save_policies(self, current_test):
        """Snapshot current probabilities to ``policies_probs_<current_test>.json``."""
        with open(f"policies_probs_{current_test}.json", "w") as f:
            json.dump(self.probs.tolist(), f)

    def save_history(self, reward, number_of_traces, path="historico.json"):
        """Append a ``{reward, number_of_traces}`` record to a JSON history file.

        Tolerates a missing, empty, or corrupt file by starting a fresh list:
        best-effort logging must not crash a training run.
        """
        historico = []
        if os.path.exists(path):
            with open(path, "r") as f:
                content = f.read().strip()
            if content:
                # try narrowed to the only line that can raise JSONDecodeError.
                try:
                    historico = json.loads(content)
                except json.JSONDecodeError:
                    # Corrupt history: start over rather than abort the run.
                    historico = []
        historico.append({
            "reward": reward,
            "number_of_traces": number_of_traces,
        })
        with open(path, "w") as f:
            json.dump(historico, f, indent=2)