Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
ac559f3
Adding basic support for batch optimisation to the framework
javdrher Jul 30, 2017
52e37a8
Merge branch 'master' into batch_support
javdrher Jul 31, 2017
b1a77f0
Merge branch 'master' into batch_support
javdrher Aug 23, 2017
82642dd
Add missing axis parameters & introduced a batch function in domain
javdrher Aug 23, 2017
76b14e8
Put axis parameter in the wrong function
javdrher Aug 23, 2017
0af4ded
Merge branch 'master' into batch_support
javdrher Aug 23, 2017
648f0ed
Merge branch 'master' into batch_support
javdrher Sep 6, 2017
418e716
Implementation of ParallelBatchAcquisition
javdrher Sep 7, 2017
a0b6d0e
Making inverse acquisition available as a private method
javdrher Sep 7, 2017
226b9eb
Minor interface fix
javdrher Sep 7, 2017
c29c35b
Following the re-design, ParallelBatch is now the parent class.
javdrher Sep 7, 2017
8a786fe
Merge branch 'master' into batch_support
javdrher Sep 15, 2017
6bd2561
Major rework of acquisition framework. Introduced an interface class,…
javdrher Sep 16, 2017
ec477e3
Reworked some interfaces, started testing generating operators and re…
javdrher Sep 17, 2017
7c7649e
Cleaned up file and added some documentation
javdrher Sep 17, 2017
c9206ce
Restructuring exposed a bug
javdrher Sep 17, 2017
83c300d
Merge branch 'master' into batch_support
javdrher Sep 17, 2017
62ce738
Exluding abstract methods from the coverage report
javdrher Sep 18, 2017
e14f7f2
Further testing of operators
javdrher Sep 18, 2017
9a054c6
Merge branch 'master' into batch_support
javdrher Nov 9, 2017
ffbe499
Fixed abstract classes in python 3
javdrher Nov 9, 2017
4e74915
Inital qEI implementation (doesn't work yet)
gpfins Nov 9, 2017
7ac4a3d
Merge remote-tracking branch 'origin/batch_support' into batch_support
gpfins Nov 9, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion gpflowopt/acquisition/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
# limitations under the License.

# Framework components and interfaces
from .acquisition import Acquisition, AcquisitionAggregation, AcquisitionProduct, AcquisitionSum, MCMCAcquistion
from .acquisition import (Acquisition, ParallelBatchAcquisition, AcquisitionAggregation, AcquisitionProduct, \
AcquisitionSum, MCMCAcquistion, IAcquisition, ParToSeqAcquisitionWrapper, setup_required)

# Single objective
from .ei import ExpectedImprovement
Expand All @@ -24,5 +25,8 @@
# Multiobjective
from .hvpoi import HVProbabilityOfImprovement

# Batch
from .qei import qExpectedImprovement

# Black-box constraint
from .pof import ProbabilityOfFeasibility
627 changes: 444 additions & 183 deletions gpflowopt/acquisition/acquisition.py

Large diffs are not rendered by default.

85 changes: 85 additions & 0 deletions gpflowopt/acquisition/qei.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright 2017 Joachim van der Herten
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .acquisition import ParallelBatchAcquisition

from gpflow.model import Model
from gpflow.param import DataHolder
from gpflow import settings

import numpy as np
import tensorflow as tf
import tensorflow.contrib.distributions as ds

stability = settings.numerics.jitter_level
float_type = settings.dtypes.float_type


class qExpectedImprovement(ParallelBatchAcquisition):
def __init__(self, model, batch_size=4):
"""
:param model: GPflow model (single output) representing our belief of the objective
"""
super(qExpectedImprovement, self).__init__(model, batch_size=batch_size)
self.fmin = DataHolder(np.zeros(1))
self._setup()

def _setup(self):
super(qExpectedImprovement, self)._setup()
# Obtain the lowest posterior mean for the previous - feasible - evaluations
feasible_samples = self.data[0][self.highest_parent.feasible_data_index(), :]
samples_mean, _ = self.models[0].predict_f(feasible_samples)
self.fmin.set_data(np.min(samples_mean, axis=0))

def build_acquisition(self, *args):
# Obtain predictive distributions for candidates
N, D = tf.shape(args[0])[0], tf.shape(args[0])[1]
q = self.batch_size

Xcand = tf.transpose(tf.stack(args, axis=0), perm=[1, 0, 2])
m, sig = tf.map_fn(lambda x: self.models[0].build_predict(x, full_cov=True), Xcand,
dtype=(float_type, float_type)) # N x q x 1, N x q x q

eye = tf.tile(tf.expand_dims(tf.eye(q, dtype=float_type), 0), [q, 1, 1])
A = eye
A = A - tf.transpose(eye, perm=[1, 2, 0])
A = A - tf.transpose(eye, perm=[2, 0, 1]) # q x q x q (k x q x q)

mk = tf.tensordot(A, m, [[2], [1]]) # N x q(k) x q Mean of Zk (k x q)
sigk = tf.tensordot(A, sig, [[2], [1]]) # N x q x q x q
sigk = tf.reduce_sum(tf.expand_dims(sigk, 3) * tf.expand_dims(tf.expand_dims(A, 0), 2), axis=-1)# N x q(k) x q x q

a = tf.tile(tf.expand_dims(tf.eye(self.batch_size, self.batch_size, dtype=float_type), 0), [self.batch_size, 1, 1])
A1 = tf.gather_nd(a, [[[i, j + (j >= i)] for j in range(self.batch_size)] for i in range(self.batch_size)])
# q(i) x (q-1) x q

bk = -self.fmin * tf.eye(q, dtype=float_type) # q(k) x q

Sigk_t = tf.expand_dims(tf.transpose(tf.tensordot(sigk, A1, [[-2], [2]]), [0, 1, 3, 4, 2]), -2) # N x q(k) x q(i) x (q-1) x 1 x q
Sigk = tf.reduce_sum(tf.expand_dims(tf.expand_dims(tf.expand_dims(A1, 0), 0), -3)* Sigk_t, axis=-1)
#Sigk = tf.einsum('ijklm,lnk->ijlmn', Sigk_t, A1) # N x q(k) x q(i) x q-1 x q-1
c = tf.tensordot(tf.expand_dims(bk, 0) - mk, A1, [[2], [2]])
c = tf.tensordot(tf.expand_dims(bk, 0) - mk, A1, [[2], [2]]) # N x q(k) x q(i) x q-1
F = tf.expand_dims(tf.expand_dims(tf.expand_dims(bk, 0) - mk, -1) / tf.matrix_diag_part(sigk),
-1) # N x q(k) x q(i) x 1
F *= tf.transpose(tf.squeeze(tf.matrix_diag_part(tf.transpose(Sigk_t, [0, 1, 3, 4, 5, 2])), 4), [0, 1, 3, 2])
c -= F

MVN = ds.MultivariateNormalFullCovariance(loc=mk, covariance_matrix=sigk)
MVN2 = ds.MultivariateNormalFullCovariance(loc=tf.zeros(tf.shape(c), dtype=float_type), covariance_matrix=Sigk)
UVN = ds.MultivariateNormalDiag(loc=mk, scale_diag=tf.sqrt(tf.matrix_diag_part(sigk)))
t1 = tf.reduce_sum((self.fmin - m) * MVN.cdf(bk), axis=1)
sigkk = tf.transpose(tf.matrix_diag_part(tf.transpose(sigk, perm=[0, 3, 1, 2])), perm=[0, 2, 1])
t2 = tf.reduce_sum(sigkk * UVN.pdf(-tf.expand_dims(bk, 0)) * MVN2.cdf(c), axis=[1, 2])
return tf.add(t1, t2, name=self.__class__.__name__)
17 changes: 9 additions & 8 deletions gpflowopt/bo.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import tensorflow as tf
from gpflow.gpr import GPR

from .acquisition import Acquisition, MCMCAcquistion
from .acquisition import IAcquisition, MCMCAcquistion
from .design import Design, EmptyDesign
from .objective import ObjectiveWrapper
from .optim import Optimizer, SciPyOptimizer
Expand Down Expand Up @@ -89,20 +89,24 @@ def __init__(self, domain, acquisition, optimizer=None, initial=None, scaling=Tr
:class:`~.Acquisition` this allows several scenarios: do the optimization manually from the callback
(optimize_restarts equals 0), or choose the starting point + some random restarts (optimize_restarts > 0).
"""
assert isinstance(acquisition, Acquisition)
assert isinstance(acquisition, IAcquisition)
assert hyper_draws is None or hyper_draws > 0
assert optimizer is None or isinstance(optimizer, Optimizer)
assert initial is None or isinstance(initial, Design)
super(BayesianOptimizer, self).__init__(domain, exclude_gradient=True)

# Configure MCMC
self._scaling = scaling
if self._scaling:
acquisition.enable_scaling(domain)

self.acquisition = acquisition if hyper_draws is None else MCMCAcquistion(acquisition, hyper_draws)

# Setup optimizer
self.optimizer = optimizer or SciPyOptimizer(domain)
self.optimizer.domain = domain

# Setup initial evaluations
initial = initial or EmptyDesign(domain)
self.set_initial(initial.generate())

Expand Down Expand Up @@ -225,17 +229,14 @@ def _optimize(self, fx, n_iter):
# Remove initial design for additional calls to optimize to proceed optimization
self.set_initial(EmptyDesign(self.domain).generate())

def inverse_acquisition(x):
return tuple(map(lambda r: -r, self.acquisition.evaluate_with_gradients(np.atleast_2d(x))))

# Optimization loop
for i in range(n_iter):
# If a callback is specified, and acquisition has the setup flag enabled (indicating an upcoming
# compilation), run the callback.
# setup), run the callback.
if self._model_callback and self.acquisition._needs_setup:
self._model_callback([m.wrapped for m in self.acquisition.models])
result = self.optimizer.optimize(inverse_acquisition)
self._update_model_data(result.x, fx(result.x))
Xnew = self.acquisition.get_suggestion(self.optimizer)
self._update_model_data(Xnew, fx(Xnew))

return self._create_bo_result(True, "OK")

Expand Down
4 changes: 4 additions & 0 deletions gpflowopt/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import numpy as np
import copy
from itertools import chain
from gpflow.param import Parentable

Expand Down Expand Up @@ -53,6 +54,9 @@ def size(self):
"""
return sum(map(lambda param: param.size, self._parameters))

def batch(self, size):
return np.sum([copy.deepcopy(self) for i in range(size)])

def __setattr__(self, key, value):
super(Domain, self).__setattr__(key, value)
if key is not '_parent':
Expand Down
Loading