-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbasic_multi_model.py
More file actions
22 lines (19 loc) · 870 Bytes
/
basic_multi_model.py
File metadata and controls
22 lines (19 loc) · 870 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import multiprocessing
from src.networks.simple_network import SimpleActorNetwork, SimpleCriticNetwork, SimpleDDQNNetwork
from src.user_interface.ui import GameUI
from src.games.space_invaders_large import SpaceInvadersLarge
from src.algorithms.PPO import PPO
from src.algorithms.DDQN import DDQN
from src.algorithms.MultiModelPPO import MultiModelPPO
def get_work_dir():
    """Return the base directory under which models and training stats go.

    Reads the ``PBS_O_WORKDIR`` environment variable (set for PBS batch jobs).
    When it is missing or empty -- e.g. the script is run directly rather than
    as a PBS job -- fall back to the current working directory instead of
    failing with ``KeyError``.
    """
    return os.environ.get('PBS_O_WORKDIR') or os.getcwd()


def main():
    """Train a MultiModelPPO agent on SpaceInvadersLarge with the simple networks.

    The game class and the actor/critic network classes (not instances) are
    passed to ``MultiModelPPO.train`` together with ``save_location`` =
    ``<work dir>/models/ppo_multi_simple`` and ``stats_location`` =
    ``<work dir>/models/ppo_multi_simple_stats``, where ``<work dir>`` comes
    from ``get_work_dir()``.
    """
    # 'spawn' starts every worker in a fresh interpreter; presumably chosen so
    # children don't inherit forked TensorFlow state -- TODO confirm.
    # set_start_method may be called at most once (RuntimeError otherwise), and
    # with 'spawn' children re-import this module, so it must only run from
    # behind the __main__ guard.
    multiprocessing.set_start_method('spawn')

    # Resolve once so both output paths are guaranteed to share the same root.
    work_dir = get_work_dir()

    ############ PPO Train
    print('Starting Training')
    ppo = MultiModelPPO()
    print('Training PPO')
    print('######################')
    ppo.train(
        SpaceInvadersLarge,
        SimpleActorNetwork,
        SimpleCriticNetwork,
        save_location=f'{work_dir}/models/ppo_multi_simple',
        stats_location=f'{work_dir}/models/ppo_multi_simple_stats',
    )


if __name__ == '__main__':
    main()