Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions adastop/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,11 @@ def adastop(ctx):
@click.option("-N", "--size-group", default=5,type=int, show_default=True, help="Number of groups.")
@click.option("-B", "--n-permutations", default=10000,type=int, show_default=True, help="Number of random permutations.")
@click.option("--alpha", default=0.05,type=float, show_default=True, help="Type I error.")
@click.option("--beta", default=0.0,type=float, show_default=True, help="early accept parameter.")
@click.option("--seed", default=None, type=int, show_default=True, help="Random seed.")
@click.option("--compare-to-first", is_flag=True, show_default=True, default=False, help="Compare all algorithms to the first algorithm.")
@click.argument('input_file',required = True, type=str)
@click.pass_context
def compare(ctx, input_file, n_groups, size_group, n_permutations, alpha, beta, seed, compare_to_first):
def compare(ctx, input_file, n_groups, size_group, n_permutations, alpha, seed, compare_to_first):
"""
Perform one step of adaptive stopping algorithm using csv file intput_file.
The csv file must be of size `size_group`.
Expand Down Expand Up @@ -69,7 +68,7 @@ def compare(ctx, input_file, n_groups, size_group, n_permutations, alpha, beta,
else:
comparator = MultipleAgentsComparator(n_fits_per_group, n_groups,
n_permutations, comparisons,
alpha, beta, seed)
alpha, seed)
names = df.columns

Z = [df[agent].values for agent in names]
Expand Down
31 changes: 2 additions & 29 deletions adastop/compare_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,6 @@ class MultipleAgentsComparator:
alpha: float, default=0.01
level of the test

beta: float, default=0
power spent in early accept.

seed: int or None, default = None

Attributes
Expand Down Expand Up @@ -84,19 +81,16 @@ def __init__(
B=10000,
comparisons = None,
alpha=0.01,
beta=0,
seed=None,
):
self.n = n
self.K = K
self.B = B
self.alpha = alpha
self.beta = beta
self.comparisons = comparisons
self.boundary = []
self.k = 0
self.level_spent = 0
self.power_spent = 0
self.seed = seed
self.rng = np.random.RandomState(seed)
self.rejected_decision = []
Expand Down Expand Up @@ -145,7 +139,7 @@ def compute_mean_diffs(self, k, Z):
zval = []
for comp in self.mean_diffs:
zval.append(self.mean_diffs[str(comp)][i])
if np.max(np.abs(zval)) <= boundary[-1][1]:
if np.max(np.abs(zval)) <= boundary[-1]:
for comp in self.mean_diffs:
mean_diffs[str(comp)].append(self.mean_diffs[str(comp)][i])

Expand Down Expand Up @@ -235,7 +229,6 @@ def partial_compare(self, eval_values, verbose=True):
k = self.k

clevel = self.alpha*(k + 1) / self.K
dlevel = self.beta*(k + 1) / self.K

mean_diffs = self.compute_mean_diffs(k, Z)

Expand Down Expand Up @@ -270,20 +263,6 @@ def partial_compare(self, eval_values, verbose=True):
bk_sup = np.inf
level_to_add = 0

cumulative_probas = np.arange(len(values)) / self.normalization # corresponds to P(T < t)
admissible_values_inf = values[
self.power_spent + cumulative_probas <= dlevel
]

if len(admissible_values_inf) > 0:
bk_inf = admissible_values_inf[-1] # the maximum admissible value
power_to_add = cumulative_probas[
self.power_spent + cumulative_probas <= dlevel
][-1]
else:
bk_inf = -np.inf
power_to_add = 0

# Test statistic, step-down
Tmax = 0
Tmin = np.inf
Expand Down Expand Up @@ -323,19 +302,13 @@ def partial_compare(self, eval_values, verbose=True):
self.decisions[str(self.current_comparisons[id_reject])] = "smaller"
if verbose:
print("reject")
elif Tmin < bk_inf:
id_accept = np.arange(len(current_decisions))[current_decisions == "continue"][imin]
current_decisions[id_accept] = "accept"
self.decisions[str(self.current_comparisons[id_accept])] = "equal"
else:
break



self.boundary.append((bk_inf, bk_sup))

self.boundary.append(bk_sup)
self.level_spent += level_to_add # level effectively used at this point
self.power_spent += power_to_add

if k == self.K - 1:
for c in self.comparisons:
Expand Down
27 changes: 4 additions & 23 deletions tests/test_compare_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
def test_partial_compare():
rng = np.random.RandomState(seed)
idxs = []
comparator = MultipleAgentsComparator(n=3, K=3, B=B, alpha=alpha, seed=42, beta = 0.01)
comparator = MultipleAgentsComparator(n=3, K=3, B=B, alpha=alpha, seed=42)
evals = {"Agent "+str(k): rng.normal(size=3) for k in range(3)}
comparator.partial_compare(evals)


def test_partial_compare_not_enough_points():
comparator = MultipleAgentsComparator(n=3, K=3, B=5000, alpha=-1e-5, seed=42, beta = 0.01)
comparator = MultipleAgentsComparator(n=3, K=3, B=5000, alpha=-1e-5, seed=42)
evals = {"Agent 1":np.array([0,0,0]),"Agent 2":np.array([0,0,0]),"Agent 3":np.array([0,0,0])}
comparator.partial_compare(evals)

Expand All @@ -29,7 +29,7 @@ def test_type1(K,n):
idxs = []
n_agents = 3
for M in range(n_runs):
comparator = MultipleAgentsComparator(n=n, K=K, B=B, alpha=alpha, seed=M, beta = 0.01)
comparator = MultipleAgentsComparator(n=n, K=K, B=B, alpha=alpha, seed=M)
evals = {}
while not comparator.is_finished:
if len(evals) >0:
Expand All @@ -42,34 +42,15 @@ def test_type1(K,n):
print(comparator.get_results())
assert np.mean(idxs) < 2*alpha + 1/4/(np.sqrt(n_runs)), "type 1 error seems to be too large."

@pytest.mark.parametrize("K,n", [(5,3), (3, 5), (1, 15)])
def test_type1_large_beta(K,n):
rng = np.random.RandomState(seed)

idxs = []
n_agents = 3
for M in range(n_runs):
comparator = MultipleAgentsComparator(n=n, K=K, B=B, alpha=alpha, seed=M, beta = 0.1)
evals = {}
while not comparator.is_finished:
if len(evals) >0:
for k in range(n_agents):
evals["Agent "+str(k)] = np.hstack([evals["Agent "+str(k)] , rng.normal(size=n)])
else:
evals = {"Agent "+str(k): rng.normal(size=n) for k in range(n_agents)}
comparator.partial_compare(evals)
idxs.append(not("equal" in comparator.decisions.values()))
print(comparator.get_results())
assert np.mean(idxs) < 2*alpha + 1/4/(np.sqrt(n_runs)), "type 1 error seems to be too large."

@pytest.mark.parametrize("K,n", [(3, 5), (1, 15)])
def test_type2(K,n):
rng = np.random.RandomState(seed)

idxs = []
n_agents = 2
for M in range(n_runs):
comparator = MultipleAgentsComparator(n=n, K=K, B=B, alpha=alpha, seed=M, beta = 0.01)
comparator = MultipleAgentsComparator(n=n, K=K, B=B, alpha=alpha, seed=M)
evals = {}
while not comparator.is_finished:
if len(evals) >0:
Expand Down
48 changes: 30 additions & 18 deletions tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,19 @@
n_runs = 10
K = 5
n = 4
seed = 42

def test_plot():
n_agents = 3
comparator = MultipleAgentsComparator(n=n, K=K, B=B, alpha=alpha, seed=42, beta = 0.01)
comparator = MultipleAgentsComparator(n=n, K=K, B=B, alpha=alpha, seed=42)
evals = {}
rng = np.random.RandomState(seed)
while not comparator.is_finished:
if len(evals) >0:
for k in range(n_agents):
evals["Agent "+str(k)] = np.hstack([evals["Agent "+str(k)] ,np.random.normal(size=n)])
evals["Agent "+str(k)] = np.hstack([evals["Agent "+str(k)] ,rng.normal(size=n)])
else:
evals = {"Agent "+str(k): np.random.normal(size=n) for k in range(n_agents)}
evals = {"Agent "+str(k): rng.normal(size=n) for k in range(n_agents)}
comparator.partial_compare(evals)
comparator.plot_results()

Expand All @@ -28,16 +30,18 @@ def test_plot():

def test_plot_sota():
n_agents = 3
rng = np.random.RandomState(seed)

comparisons = np.array([(0,i) for i in [1,2]])
comparator = MultipleAgentsComparator(n=n, K=K, B=B, alpha=alpha,
comparisons=comparisons, seed=42, beta = 0.01)
comparisons=comparisons, seed=42)
evals = {}
while not comparator.is_finished:
if len(evals) >0:
for k in range(n_agents):
evals["Agent "+str(k)] = np.hstack([evals["Agent "+str(k)] ,np.random.normal(size=n)])
evals["Agent "+str(k)] = np.hstack([evals["Agent "+str(k)] ,rng.normal(size=n)])
else:
evals = {"Agent "+str(k): np.random.normal(size=n) for k in range(n_agents)}
evals = {"Agent "+str(k): rng.normal(size=n) for k in range(n_agents)}
comparator.partial_compare(evals)
comparator.plot_results_sota()
# plt.savefig('fig2.pdf')
Expand All @@ -46,31 +50,35 @@ def test_plot_sota():

def test_plot_noteq():
n_agents = 3
comparator = MultipleAgentsComparator(n=10, K=K, B=B, alpha=alpha, seed=42, beta = 0.01)
comparator = MultipleAgentsComparator(n=10, K=K, B=B, alpha=alpha, seed=42)
rng = np.random.RandomState(seed)

evals = {}
while not comparator.is_finished:
if len(evals) >0:
for k in range(n_agents):
evals["Agent "+str(k)] = np.hstack([evals["Agent "+str(k)] , k+np.random.normal(size=10)])
evals["Agent "+str(k)] = np.hstack([evals["Agent "+str(k)] , k+ rng.normal(size=10)])
else:
evals = {"Agent "+str(k): np.random.normal(size=10)+k for k in range(n_agents)}
evals = {"Agent "+str(k): rng.normal(size=10)+k for k in range(n_agents)}
comparator.partial_compare(evals)
# plt.savefig('fig2.pdf')
fig, axes= plt.subplots(1,2)
comparator.plot_results(axes=axes)

def test_plot_sota_noteq():
n_agents = 3
rng = np.random.RandomState(seed)

comparisons = np.array([(0,i) for i in [1,2]])
comparator = MultipleAgentsComparator(n=10, K=K, B=B, alpha=alpha,
comparisons=comparisons, seed=42, beta = 0.01)
comparisons=comparisons, seed=42)
evals = {}
while not comparator.is_finished:
if len(evals) >0:
for k in range(n_agents):
evals["Agent "+str(k)] = np.hstack([evals["Agent "+str(k)] ,np.random.normal(size=10)+k])
evals["Agent "+str(k)] = np.hstack([evals["Agent "+str(k)] , rng.normal(size=10)+k])
else:
evals = {"Agent "+str(k): np.random.normal(size=10) for k in range(n_agents)}
evals = {"Agent "+str(k): rng.normal(size=10) for k in range(n_agents)}
comparator.partial_compare(evals)
comparator.plot_results_sota()
# plt.savefig('fig2.pdf')
Expand All @@ -81,14 +89,16 @@ def test_plot_sota_noteq():

def test_plot_noteq2():
n_agents = 3
comparator = MultipleAgentsComparator(n=10, K=K, B=B, alpha=alpha, seed=42, beta = 0.01)
comparator = MultipleAgentsComparator(n=10, K=K, B=B, alpha=alpha, seed=42)
rng = np.random.RandomState(seed)

evals = {}
while not comparator.is_finished:
if len(evals) >0:
for k in range(n_agents):
evals["Agent "+str(k)] = np.hstack([evals["Agent "+str(k)] , np.abs(2*K-k)+np.random.normal(size=10)])
evals["Agent "+str(k)] = np.hstack([evals["Agent "+str(k)] , np.abs(2*K-k)+ rng.normal(size=10)])
else:
evals = {"Agent "+str(k): np.random.normal(size=10)+np.abs(2*K-k) for k in range(n_agents)}
evals = {"Agent "+str(k): rng.normal(size=10)+np.abs(2*K-k) for k in range(n_agents)}
comparator.partial_compare(evals)
# plt.savefig('fig2.pdf')
fig, axes= plt.subplots(1,2)
Expand All @@ -98,14 +108,16 @@ def test_plot_sota_noteq2():
n_agents = 3
comparisons = np.array([(0,i) for i in [1,2]])
comparator = MultipleAgentsComparator(n=10, K=K, B=B, alpha=alpha,
comparisons=comparisons, seed=42, beta = 0.01)
comparisons=comparisons, seed=42)
rng = np.random.RandomState(seed)

evals = {}
while not comparator.is_finished:
if len(evals) >0:
for k in range(n_agents):
evals["Agent "+str(k)] = np.hstack([evals["Agent "+str(k)] ,np.random.normal(size=10)+np.abs(2*K-k)])
evals["Agent "+str(k)] = np.hstack([evals["Agent "+str(k)] , rng.normal(size=10)+np.abs(2*K-k)])
else:
evals = {"Agent "+str(k): np.random.normal(size=10)+np.abs(2*K-k) for k in range(n_agents)}
evals = {"Agent "+str(k): rng.normal(size=10)+np.abs(2*K-k) for k in range(n_agents)}
comparator.partial_compare(evals)
comparator.plot_results_sota()
# plt.savefig('fig2.pdf')
Expand Down