import math
import types
import warnings
import weakref
from bisect import bisect_right
from functools import partial, wraps

from torch._six import inf

from .optimizer import Optimizer
class _LRScheduler(object):
def __init__(self, optimizer, last_epoch=-1):
if not isinstance(optimizer, Optimizer):
raise TypeError('{} is not an Optimizer'.format(
type(optimizer).__name__))
self.optimizer = optimizer
if last_epoch == -1:
for group in optimizer.param_groups:
group.setdefault('initial_lr', group['lr'])
last_epoch = 0
else:
for i, group in enumerate(optimizer.param_groups):
if 'initial_lr' not in group:
raise KeyError("param 'initial_lr' is not specified "
"in param_groups[{}] when resuming an optimizer".format(i))
self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
self.last_epoch = last_epoch
# Following https://github.com/pytorch/pytorch/issues/20124
# We would like to ensure that `lr_scheduler.step()` is called after
# `optimizer.step()`
def with_counter(func, opt):
@wraps(func)
def wrapper(*args, **kwargs):
opt._step_count += 1
return func(*args, **kwargs)
wrapper._with_counter = True
return wrapper
self.optimizer.step = with_counter(self.optimizer.step, self.optimizer)
self.optimizer._step_count = 0
self._step_count = 0
self.step(last_epoch)
def state_dict(self):
"""Returns the state of the scheduler as a :class:`dict`.
It contains an entry for every variable in self.__dict__ which
is not the optimizer.
"""
return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
def load_state_dict(self, state_dict):
"""Loads the schedulers state.
Arguments:
state_dict (dict): scheduler state. Should be an object returned
from a call to :meth:`state_dict`.
"""
self.__dict__.update(state_dict)
def get_lr(self):
raise NotImplementedError
def step(self, epoch=None):
# Raise a warning if old pattern is detected
# https://github.com/pytorch/pytorch/issues/20124
if self._step_count == 1:
if not hasattr(self.optimizer.step, "_with_counter"):
warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler "
"initialization. Please, make sure to call `optimizer.step()` before "
"`lr_scheduler.step()`. See more details at "
"https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
# Just check if there were two first lr_scheduler.step() calls before optimizer.step()
elif self.optimizer._step_count < 1:
warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
"In PyTorch 1.1.0 and later, you should call them in the opposite order: "
"`optimizer.step()` before `lr_scheduler.step()`. Failure to do this "
"will result in PyTorch skipping the first value of the learning rate schedule."
"See more details at "
"https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
self._step_count += 1
if epoch is None:
epoch = self.last_epoch + 1
self.last_epoch = epoch
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
param_group['lr'] = lr
class LambdaLR(_LRScheduler):
    """Sets the learning rate of each parameter group to the initial lr
    times a given function. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        lr_lambda (function or list): A function which computes a multiplicative
            factor given an integer parameter epoch, or a list of such
            functions, one for each group in optimizer.param_groups.
        last_epoch (int): The index of last epoch. Default: -1.

    Raises:
        ValueError: if ``lr_lambda`` is a list/tuple whose length differs from
            the number of param groups.

    Example:
        >>> # Assuming optimizer has two groups.
        >>> lambda1 = lambda epoch: epoch // 30
        >>> lambda2 = lambda epoch: 0.95 ** epoch
        >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, lr_lambda, last_epoch=-1):
        n_groups = len(optimizer.param_groups)
        if not isinstance(lr_lambda, (list, tuple)):
            # A single callable is shared by every param group.
            self.lr_lambdas = [lr_lambda] * n_groups
        else:
            if len(lr_lambda) != n_groups:
                raise ValueError("Expected {} lr_lambdas, but got {}".format(
                    n_groups, len(lr_lambda)))
            self.lr_lambdas = list(lr_lambda)
        super(LambdaLR, self).__init__(optimizer, last_epoch)

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        The learning rate lambda functions will only be saved if they are callable objects
        and not if they are functions or lambdas.
        """
        state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', 'lr_lambdas')}
        state_dict['lr_lambdas'] = [None] * len(self.lr_lambdas)

        for idx, fn in enumerate(self.lr_lambdas):
            # Plain functions/lambdas are not saved; callable objects contribute
            # a copy of their instance attributes.
            if not isinstance(fn, types.FunctionType):
                state_dict['lr_lambdas'][idx] = fn.__dict__.copy()

        return state_dict

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        The caller's ``state_dict`` is not modified, so it can be loaded again.

        Arguments:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        # Shallow copy so popping 'lr_lambdas' does not mutate the caller's dict.
        state_dict = dict(state_dict)
        lr_lambdas = state_dict.pop('lr_lambdas')
        self.__dict__.update(state_dict)

        for idx, fn in enumerate(lr_lambdas):
            if fn is not None:
                self.lr_lambdas[idx].__dict__.update(fn)

    def get_lr(self):
        """Return ``base_lr * lr_lambda(last_epoch)`` for each param group."""
        return [base_lr * lmbda(self.last_epoch)
                for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)]
class StepLR(_LRScheduler):
    """Multiply each parameter group's initial lr by ``gamma`` once every
    ``step_size`` epochs. When last_epoch=-1, the initial lr is taken from
    the optimizer's current lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        step_size (int): Period of learning rate decay, in epochs.
        gamma (float): Multiplicative factor of learning rate decay.
            Default: 0.1.
        last_epoch (int): The index of last epoch. Default: -1.

    Example:
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.05     if epoch < 30
        >>> # lr = 0.005    if 30 <= epoch < 60
        >>> # lr = 0.0005   if 60 <= epoch < 90
        >>> # ...
        >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1):
        self.gamma = gamma
        self.step_size = step_size
        super(StepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Closed form: ``initial_lr * gamma ** (last_epoch // step_size)``."""
        completed_periods = self.last_epoch // self.step_size
        factor = self.gamma ** completed_periods
        return [initial * factor for initial in self.base_lrs]
class MultiStepLR(_LRScheduler):
    """Set the learning rate of each parameter group to the initial lr decayed
    by gamma once the number of epoch reaches one of the milestones. When
    last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        milestones (iterable): Epoch indices, in increasing order. Any
            iterable (list, tuple, generator) is accepted and stored as a list.
        gamma (float): Multiplicative factor of learning rate decay.
            Default: 0.1.
        last_epoch (int): The index of last epoch. Default: -1.

    Raises:
        ValueError: if ``milestones`` is not sorted in non-decreasing order.

    Example:
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.05     if epoch < 30
        >>> # lr = 0.005    if 30 <= epoch < 80
        >>> # lr = 0.0005   if epoch >= 80
        >>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1):
        # Materialize once: a one-shot iterator would otherwise be exhausted by
        # the sortedness check before it could be stored.
        milestones = list(milestones)
        if milestones != sorted(milestones):
            raise ValueError('Milestones should be a list of'
                             ' increasing integers. Got {}'.format(milestones))
        self.milestones = milestones
        self.gamma = gamma
        super(MultiStepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Decay by ``gamma`` once per milestone that is ``<= last_epoch``."""
        return [base_lr * self.gamma ** bisect_right(self.milestones, self.last_epoch)
                for base_lr in self.base_lrs]
class ExponentialLR(_LRScheduler):
    """Multiply each parameter group's initial lr by ``gamma`` once per epoch,
    i.e. ``lr = initial_lr * gamma ** last_epoch``. When last_epoch=-1, the
    initial lr is taken from the optimizer's current lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        gamma (float): Multiplicative factor of learning rate decay.
        last_epoch (int): The index of last epoch. Default: -1.
    """

    def __init__(self, optimizer, gamma, last_epoch=-1):
        self.gamma = gamma
        super(ExponentialLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Closed form: ``initial_lr * gamma ** last_epoch`` for every group."""
        decay = self.gamma ** self.last_epoch
        return [lr0 * decay for lr0 in self.base_lrs]
class CosineAnnealingLR(_LRScheduler):
    r"""Anneal each parameter group's learning rate along a half cosine from
    its initial lr (:math:`\eta_{max}`) down to ``eta_min``; :math:`T_{cur}` is
    the number of epochs since the last restart in SGDR:

    .. math::
        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 +
        \cos(\frac{T_{cur}}{T_{max}}\pi))

    When last_epoch=-1, sets initial lr as lr.

    This schedule was proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Only the cosine
    annealing part of SGDR is implemented here, not the restarts.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_max (int): Maximum number of iterations.
        eta_min (float): Minimum learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1):
        self.T_max = T_max
        self.eta_min = eta_min
        super(CosineAnnealingLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Closed-form cosine value for ``last_epoch``, one entry per group."""
        angle = math.pi * self.last_epoch / self.T_max
        cosine_term = 1 + math.cos(angle)
        return [self.eta_min + (peak - self.eta_min) * cosine_term / 2
                for peak in self.base_lrs]
class ReduceLROnPlateau(object):
    """Reduce learning rate when a metric has stopped improving.
    Models often benefit from reducing the learning rate by a factor
    of 2-10 once learning stagnates. This scheduler reads a metrics
    quantity and if no improvement is seen for a 'patience' number
    of epochs, the learning rate is reduced.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        mode (str): One of `min`, `max`. In `min` mode, lr will
            be reduced when the quantity monitored has stopped
            decreasing; in `max` mode it will be reduced when the
            quantity monitored has stopped increasing. Default: 'min'.
        factor (float): Factor by which the learning rate will be
            reduced. new_lr = lr * factor. Default: 0.1.
        patience (int): Number of epochs with no improvement after
            which learning rate will be reduced. For example, if
            `patience = 2`, then we will ignore the first 2 epochs
            with no improvement, and will only decrease the LR after the
            3rd epoch if the loss still hasn't improved then.
            Default: 10.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
        threshold (float): Threshold for measuring the new optimum,
            to only focus on significant changes. Default: 1e-4.
        threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
            dynamic_threshold = best * ( 1 + threshold ) in 'max'
            mode or best * ( 1 - threshold ) in `min` mode.
            In `abs` mode, dynamic_threshold = best + threshold in
            `max` mode or best - threshold in `min` mode. Default: 'rel'.
        cooldown (int): Number of epochs to wait before resuming
            normal operation after lr has been reduced. Default: 0.
        min_lr (float or list): A scalar or a list of scalars. A
            lower bound on the learning rate of all param groups
            or each group respectively. Default: 0.
        eps (float): Minimal decay applied to lr. If the difference
            between new and old lr is smaller than eps, the update is
            ignored. Default: 1e-8.

    Raises:
        ValueError: if ``factor >= 1.0``, if ``min_lr`` is a list/tuple whose
            length differs from the number of param groups, or if ``mode`` /
            ``threshold_mode`` is not recognised.
        TypeError: if ``optimizer`` is not an :class:`Optimizer`.

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = ReduceLROnPlateau(optimizer, 'min')
        >>> for epoch in range(10):
        >>>     train(...)
        >>>     val_loss = validate(...)
        >>>     # Note that step should be called after validate()
        >>>     scheduler.step(val_loss)
    """

    def __init__(self, optimizer, mode='min', factor=0.1, patience=10,
                 verbose=False, threshold=1e-4, threshold_mode='rel',
                 cooldown=0, min_lr=0, eps=1e-8):

        if factor >= 1.0:
            raise ValueError('Factor should be < 1.0.')
        self.factor = factor

        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        if isinstance(min_lr, (list, tuple)):
            if len(min_lr) != len(optimizer.param_groups):
                raise ValueError("expected {} min_lrs, got {}".format(
                    len(optimizer.param_groups), len(min_lr)))
            self.min_lrs = list(min_lr)
        else:
            self.min_lrs = [min_lr] * len(optimizer.param_groups)

        self.patience = patience
        self.verbose = verbose
        self.cooldown = cooldown
        self.cooldown_counter = 0
        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        self.best = None
        self.num_bad_epochs = None
        self.mode_worse = None  # the worse value for the chosen mode
        self.is_better = None
        self.eps = eps
        self.last_epoch = -1
        self._init_is_better(mode=mode, threshold=threshold,
                             threshold_mode=threshold_mode)
        self._reset()

    def _reset(self):
        """Resets num_bad_epochs counter and cooldown counter."""
        self.best = self.mode_worse
        self.cooldown_counter = 0
        self.num_bad_epochs = 0

    def step(self, metrics, epoch=None):
        """Record a new value of the monitored metric and reduce lr if needed.

        Args:
            metrics: The monitored quantity; anything accepted by ``float()``
                (e.g. a Python number or a zero-dim Tensor).
            epoch (int or float, optional): Epoch index to record. Defaults to
                ``self.last_epoch + 1``.
        """
        # convert `metrics` to float, in case it's a zero-dim Tensor
        current = float(metrics)
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch

        if self.is_better(current, self.best):
            self.best = current
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        if self.in_cooldown:
            self.cooldown_counter -= 1
            self.num_bad_epochs = 0  # ignore any bad epochs in cooldown

        if self.num_bad_epochs > self.patience:
            self._reduce_lr(epoch)
            self.cooldown_counter = self.cooldown
            self.num_bad_epochs = 0

    def _reduce_lr(self, epoch):
        """Scale each group's lr by ``factor``, clamped below by its ``min_lr``.

        A group is only updated when the decrease exceeds ``eps``.
        """
        for i, param_group in enumerate(self.optimizer.param_groups):
            old_lr = float(param_group['lr'])
            new_lr = max(old_lr * self.factor, self.min_lrs[i])
            if old_lr - new_lr > self.eps:
                param_group['lr'] = new_lr
                if self.verbose:
                    # `epoch` may be fractional when given explicitly; the
                    # integer format '{:5d}' would raise ValueError for a float.
                    epoch_fmt = '{:5.2f}' if isinstance(epoch, float) else '{:5d}'
                    template = ('Epoch ' + epoch_fmt +
                                ': reducing learning rate of group {} to {:.4e}.')
                    print(template.format(epoch, i, new_lr))

    @property
    def in_cooldown(self):
        """True while the post-reduction cooldown counter is still positive."""
        return self.cooldown_counter > 0

    def _cmp(self, mode, threshold_mode, threshold, a, best):
        """Return True if ``a`` improves on ``best`` under the given mode/threshold."""
        if mode == 'min' and threshold_mode == 'rel':
            rel_epsilon = 1. - threshold
            return a < best * rel_epsilon

        elif mode == 'min' and threshold_mode == 'abs':
            return a < best - threshold

        elif mode == 'max' and threshold_mode == 'rel':
            rel_epsilon = threshold + 1.
            return a > best * rel_epsilon

        else:  # mode == 'max' and threshold_mode == 'abs':
            return a > best + threshold

    def _init_is_better(self, mode, threshold, threshold_mode):
        """Validate the modes and bind ``self.is_better`` / ``self.mode_worse``."""
        if mode not in {'min', 'max'}:
            raise ValueError('mode {} is unknown!'.format(mode))
        if threshold_mode not in {'rel', 'abs'}:
            raise ValueError('threshold mode {} is unknown!'.format(threshold_mode))

        if mode == 'min':
            self.mode_worse = inf
        else:  # mode == 'max':
            self.mode_worse = -inf

        self.is_better = partial(self._cmp, mode, threshold_mode, threshold)

    def state_dict(self):
        """Return every attribute except the optimizer and the bound comparator."""
        return {key: value for key, value in self.__dict__.items() if key not in {'optimizer', 'is_better'}}

    def load_state_dict(self, state_dict):
        """Restore attributes from ``state_dict`` and rebuild the comparator."""
        self.__dict__.update(state_dict)
        self._init_is_better(mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode)
class CyclicLR(_LRScheduler):
    """Sets the learning rate of each parameter group according to
    cyclical learning rate policy (CLR). The policy cycles the learning
    rate between two boundaries with a constant frequency, as detailed in
    the paper `Cyclical Learning Rates for Training Neural Networks`_.
    The distance between the two boundaries can be scaled on a per-iteration
    or per-cycle basis.

    Cyclical learning rate policy changes the learning rate after every batch.
    `step` should be called after a batch has been used for training.

    This class has three built-in policies, as put forth in the paper:

    "triangular":
        A basic triangular cycle w/ no amplitude scaling.
    "triangular2":
        A basic triangular cycle that scales initial amplitude by half each cycle.
    "exp_range":
        A cycle that scales initial amplitude by gamma**(cycle iterations) at each
        cycle iteration.

    This implementation was adapted from the github repo: `bckenstler/CLR`_

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        base_lr (float or list): Initial learning rate which is the
            lower boundary in the cycle for each parameter group.
        max_lr (float or list): Upper learning rate boundaries in the cycle
            for each parameter group. Functionally,
            it defines the cycle amplitude (max_lr - base_lr).
            The lr at any cycle is the sum of base_lr
            and some scaling of the amplitude; therefore
            max_lr may not actually be reached depending on
            scaling function.
        step_size_up (int): Number of training iterations in the
            increasing half of a cycle. Default: 2000
        step_size_down (int): Number of training iterations in the
            decreasing half of a cycle. If step_size_down is None,
            it is set to step_size_up. Default: None
        mode (str): One of {triangular, triangular2, exp_range}.
            Values correspond to policies detailed above.
            If scale_fn is not None, this argument is ignored.
            Default: 'triangular'
        gamma (float): Constant in 'exp_range' scaling function:
            gamma**(cycle iterations)
            Default: 1.0
        scale_fn (function): Custom scaling policy defined by a single
            argument lambda function, where
            0 <= scale_fn(x) <= 1 for all x >= 0.
            If specified, then 'mode' is ignored.
            Default: None
        scale_mode (str): {'cycle', 'iterations'}.
            Defines whether scale_fn is evaluated on
            cycle number or cycle iterations (training
            iterations since start of cycle).
            Default: 'cycle'
        cycle_momentum (bool): If ``True``, momentum is cycled inversely
            to learning rate between 'base_momentum' and 'max_momentum'.
            Default: True
        base_momentum (float or list): Lower momentum boundaries in the cycle
            for each parameter group. Note that momentum is cycled inversely
            to learning rate; at the peak of a cycle, momentum is
            'base_momentum' and learning rate is 'max_lr'.
            Default: 0.8
        max_momentum (float or list): Upper momentum boundaries in the cycle
            for each parameter group. Functionally,
            it defines the cycle amplitude (max_momentum - base_momentum).
            The momentum at any cycle is the difference of max_momentum
            and some scaling of the amplitude; therefore
            base_momentum may not actually be reached depending on
            scaling function. Note that momentum is cycled inversely
            to learning rate; at the start of a cycle, momentum is 'max_momentum'
            and learning rate is 'base_lr'
            Default: 0.9
        last_epoch (int): The index of the last batch. This parameter is used when
            resuming a training job. Since `step()` should be invoked after each
            batch instead of after each epoch, this number represents the total
            number of *batches* computed, not the total number of epochs computed.
            When last_epoch=-1, the schedule is started from the beginning.
            Default: -1

    Raises:
        TypeError: if ``optimizer`` is not an :class:`Optimizer`.
        ValueError: if a per-group list has the wrong length, if ``mode`` is
            invalid while ``scale_fn`` is None, or if ``cycle_momentum`` is
            set for an optimizer without a ``momentum`` default.

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1)
        >>> data_loader = torch.utils.data.DataLoader(...)
        >>> for epoch in range(10):
        >>>     for batch in data_loader:
        >>>         train_batch(...)
        >>>         scheduler.step()


    .. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186
    .. _bckenstler/CLR: https://github.com/bckenstler/CLR
    """

    def __init__(self,
                 optimizer,
                 base_lr,
                 max_lr,
                 step_size_up=2000,
                 step_size_down=None,
                 mode='triangular',
                 gamma=1.,
                 scale_fn=None,
                 scale_mode='cycle',
                 cycle_momentum=True,
                 base_momentum=0.8,
                 max_momentum=0.9,
                 last_epoch=-1):

        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        base_lrs = self._format_param('base_lr', optimizer, base_lr)
        if last_epoch == -1:
            # Fresh start: the base class records these as 'initial_lr'.
            for lr, group in zip(base_lrs, optimizer.param_groups):
                group['lr'] = lr

        self.max_lrs = self._format_param('max_lr', optimizer, max_lr)

        step_size_up = float(step_size_up)
        step_size_down = float(step_size_down) if step_size_down is not None else step_size_up
        self.total_size = step_size_up + step_size_down
        self.step_ratio = step_size_up / self.total_size

        if mode not in ['triangular', 'triangular2', 'exp_range'] \
                and scale_fn is None:
            raise ValueError('mode is invalid and scale_fn is None')

        self.mode = mode
        self.gamma = gamma

        if scale_fn is None:
            if self.mode == 'triangular':
                self.scale_fn = self._triangular_scale_fn
                self.scale_mode = 'cycle'
            elif self.mode == 'triangular2':
                self.scale_fn = self._triangular2_scale_fn
                self.scale_mode = 'cycle'
            elif self.mode == 'exp_range':
                self.scale_fn = self._exp_range_scale_fn
                self.scale_mode = 'iterations'
        else:
            self.scale_fn = scale_fn
            self.scale_mode = scale_mode

        self.cycle_momentum = cycle_momentum
        if cycle_momentum:
            if 'momentum' not in optimizer.defaults:
                raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')

            base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
            if last_epoch == -1:
                for momentum, group in zip(base_momentums, optimizer.param_groups):
                    group['momentum'] = momentum
            # Use the requested lower bounds directly rather than reading them
            # back from the param groups: when resuming, the groups hold the
            # momentum of the last cycled step, not `base_momentum`.
            self.base_momentums = base_momentums
            self.max_momentums = self._format_param('max_momentum', optimizer, max_momentum)

        super(CyclicLR, self).__init__(optimizer, last_epoch)

    def _format_param(self, name, optimizer, param):
        """Return a new list with one lr/momentum value per param group.

        Raises:
            ValueError: if ``param`` is a list/tuple of the wrong length.
        """
        if isinstance(param, (list, tuple)):
            if len(param) != len(optimizer.param_groups):
                raise ValueError("expected {} values for {}, got {}".format(
                    len(optimizer.param_groups), name, len(param)))
            # Copy so later changes to the caller's sequence don't leak in.
            return list(param)
        else:
            return [param] * len(optimizer.param_groups)

    def _triangular_scale_fn(self, x):
        return 1.

    def _triangular2_scale_fn(self, x):
        return 1 / (2. ** (x - 1))

    def _exp_range_scale_fn(self, x):
        return self.gamma**(x)

    def get_lr(self):
        """Calculates the learning rate at batch index. This function treats
        `self.last_epoch` as the last batch index.

        If `self.cycle_momentum` is ``True``, this function has a side effect of
        updating the optimizer's momentum.
        """
        cycle = math.floor(1 + self.last_epoch / self.total_size)
        # Position within the current cycle, in [0, 1).
        x = 1. + self.last_epoch / self.total_size - cycle
        if x <= self.step_ratio:
            scale_factor = x / self.step_ratio
        else:
            scale_factor = (x - 1) / (self.step_ratio - 1)

        lrs = []
        for base_lr, max_lr in zip(self.base_lrs, self.max_lrs):
            base_height = (max_lr - base_lr) * scale_factor
            if self.scale_mode == 'cycle':
                lr = base_lr + base_height * self.scale_fn(cycle)
            else:
                lr = base_lr + base_height * self.scale_fn(self.last_epoch)
            lrs.append(lr)

        if self.cycle_momentum:
            momentums = []
            for base_momentum, max_momentum in zip(self.base_momentums, self.max_momentums):
                base_height = (max_momentum - base_momentum) * scale_factor
                if self.scale_mode == 'cycle':
                    momentum = max_momentum - base_height * self.scale_fn(cycle)
                else:
                    momentum = max_momentum - base_height * self.scale_fn(self.last_epoch)
                momentums.append(momentum)
            for param_group, momentum in zip(self.optimizer.param_groups, momentums):
                param_group['momentum'] = momentum

        return lrs
class CosineAnnealingWarmRestarts(_LRScheduler):
    r"""Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
    is the number of epochs since the last restart and :math:`T_{i}` is the number
    of epochs between two warm restarts in SGDR:

    .. math::
        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 +
        \cos(\frac{T_{cur}}{T_{i}}\pi))

    When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
    When :math:`T_{cur}=0` (after restart), set :math:`\eta_t=\eta_{max}`.

    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_0 (int): Number of iterations for the first restart.
        T_mult (int, optional): A factor increases :math:`T_{i}` after a restart. Default: 1.
        eta_min (float, optional): Minimum learning rate. Default: 0.
        last_epoch (int, optional): The index of last epoch. Default: -1.

    Raises:
        ValueError: if ``T_0`` is not a positive integer or ``T_mult`` is not
            an integer >= 1.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1):
        # Type check first so non-integers raise ValueError, not a TypeError
        # from the comparison.
        if not isinstance(T_0, int) or T_0 <= 0:
            raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
        if not isinstance(T_mult, int) or T_mult < 1:
            raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
        self.T_0 = T_0
        self.T_i = T_0
        self.T_mult = T_mult
        self.eta_min = eta_min
        # Provisional value only: the base constructor calls `self.step(...)`,
        # which derives T_cur/T_i from the epoch. It must not be overwritten
        # afterwards, or a resumed `last_epoch >= T_0` would lose its restart
        # position.
        self.T_cur = last_epoch
        super(CosineAnnealingWarmRestarts, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Cosine value for the current position ``T_cur`` within period ``T_i``."""
        return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2
                for base_lr in self.base_lrs]

    def step(self, epoch=None):
        """Step could be called after every batch update

        Example:
            >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
            >>> iters = len(dataloader)
            >>> for epoch in range(20):
            >>>     for i, sample in enumerate(dataloader):
            >>>         inputs, labels = sample['inputs'], sample['labels']
            >>>         scheduler.step(epoch + i / iters)
            >>>         optimizer.zero_grad()
            >>>         outputs = net(inputs)
            >>>         loss = criterion(outputs, labels)
            >>>         loss.backward()
            >>>         optimizer.step()

        This function can be called in an interleaved way.

        Example:
            >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
            >>> for epoch in range(20):
            >>>     scheduler.step()
            >>> scheduler.step(26)
            >>> scheduler.step() # scheduler.step(27), instead of scheduler(20)

        Raises:
            ValueError: if an explicit ``epoch`` is negative.
        """
        if epoch is None:
            epoch = self.last_epoch + 1
            self.T_cur = self.T_cur + 1
            if self.T_cur >= self.T_i:
                # Warm restart: wrap around and lengthen the next period.
                self.T_cur = self.T_cur - self.T_i
                self.T_i = self.T_i * self.T_mult
        else:
            if epoch < 0:
                raise ValueError("Expected non-negative epoch, but got {}".format(epoch))
            if epoch >= self.T_0:
                if self.T_mult == 1:
                    self.T_cur = epoch % self.T_0
                else:
                    # Periods grow geometrically: T_0, T_0*T_mult, ...; n is the
                    # index of the period containing `epoch`.
                    n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))
                    self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
                    self.T_i = self.T_0 * self.T_mult ** (n)
            else:
                self.T_i = self.T_0
                self.T_cur = epoch
        self.last_epoch = math.floor(epoch)
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr