# Source code for actorcritic.kfac_utils

"""Contains utilities that concern K-FAC."""

import kfac
import tensorflow as tf


class ColdStartPeriodicInvUpdateKfacOpt(kfac.KfacOptimizer):
    """A :obj:`~kfac.KfacOptimizer` variant with a cold-start phase and
    periodic inverse updates.

    For the first few training steps (the `cold updates`) a plain SGD-style
    optimizer (the `cold optimizer`) is applied instead of K-FAC, letting the
    parameters settle before the heavy second-order updates kick in. Once the
    cold phase is over, the covariance statistics are refreshed on every step,
    while the (expensive) inverse computation only runs periodically.

    See Also:
        * :obj:`kfac.PeriodicInvCovUpdateKfacOpt`
        * The idea is taken from the `original ACKTR implementation
          <https://github.com/openai/baselines/blob/master/baselines/acktr/kfac.py>`_.
    """

    def __init__(self, num_cold_updates, cold_optimizer, invert_every,
                 **kwargs):
        """
        Args:
            num_cold_updates (:obj:`int`): How many `cold updates` to perform
                before switching to the actual K-FAC optimizer.
            cold_optimizer (:obj:`tf.train.Optimizer`): Optimizer applied
                during the `cold updates`.
            invert_every (:obj:`int`): Period (in steps, counted after the
                cold phase) at which the inverse operation is executed.
        """
        self._num_cold_updates = num_cold_updates
        self._cold_optimizer = cold_optimizer
        self._invert_every = invert_every
        self._counter = None
        super().__init__(**kwargs)

    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
        cov_thunks, inv_thunks = self.make_vars_and_create_op_thunks()

        # Thunk-builders for tf.cond branches; ops are only created when the
        # branch function is invoked inside the dependency contexts below.
        def _cold_branch():
            return self._cold_optimizer.apply_gradients(grads_and_vars,
                                                        global_step)

        def _cov_branch():
            return tf.group([thunk() for thunk in cov_thunks])

        def _inv_branch():
            return tf.group([thunk() for thunk in inv_thunks])

        with tf.control_dependencies([global_step]):
            # Cold phase: delegate to the cold optimizer. Afterwards: refresh
            # the covariance estimates on every step instead.
            cold_or_cov_op = tf.cond(
                tf.less(global_step, self._num_cold_updates),
                _cold_branch,
                _cov_branch)
        with tf.control_dependencies([cold_or_cov_op]):
            # Run the inverse computation only past the cold phase and only
            # every `invert_every`-th step (relative to the end of the phase).
            past_cold_phase = tf.greater(global_step, self._num_cold_updates)
            on_schedule = tf.equal(
                tf.mod(global_step - self._num_cold_updates,
                       self._invert_every),
                0)
            inv_op = tf.cond(
                tf.logical_and(past_cold_phase, on_schedule),
                _inv_branch,
                tf.no_op)
        with tf.control_dependencies([inv_op]):
            return super().apply_gradients(grads_and_vars=grads_and_vars,
                                           global_step=global_step,
                                           name=name)