Automatic Mixed Precision package - torch.cuda.amp¶
torch.cuda.amp provides convenience methods for running networks with mixed precision, where some operations use the torch.float32 (float) datatype and other operations use torch.float16 (half). Some operations, like linear layers and convolutions, are much faster in float16. Other operations, like reductions, often require the dynamic range of float32. Networks running in mixed precision try to match each operation to its appropriate datatype.
Warning
torch.cuda.amp.GradScaler is not a complete implementation of automatic mixed precision. GradScaler is only useful if you manually run regions of your model in float16. If you aren’t sure how to choose op precision manually, the master branch and nightly pip/conda builds include a context manager that chooses op precision automatically wherever it’s enabled. See the master documentation for details.
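As a minimal sketch of the situation this warning describes (the layer, tensor, and optimizer names are illustrative and a CUDA device is assumed), one region of the network is run in float16 by hand while the loss is computed in float32, and GradScaler handles the resulting float16 gradients:
import torch

model = torch.nn.Linear(10, 10).cuda().half()             # this region runs in float16
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(8, 10, device="cuda", dtype=torch.float16)
target = torch.randn(8, 10, device="cuda")

output = model(x)                                          # float16 activations
loss = torch.nn.functional.mse_loss(output.float(), target)  # loss computed in float32

scaler.scale(loss).backward()                              # produces float16 gradients
scaler.step(optimizer)
scaler.update()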
Gradient Scaling¶
When training a network with mixed precision, if the forward pass for a particular op has torch.float16 inputs, the backward pass for that op will produce torch.float16 gradients. Gradient values with small magnitudes may not be representable in torch.float16. These values will flush to zero (“underflow”), so the update for the corresponding parameters will be lost.
To prevent underflow, “gradient scaling” multiplies the network’s loss(es) by a scale factor and invokes a backward pass on the scaled loss(es). Gradients flowing backward through the network are then scaled by the same factor. In other words, gradient values have a larger magnitude, so they don’t flush to zero.
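For a concrete sense of the numbers, a minimal sketch showing that a gradient-sized value below float16’s smallest representable magnitude flushes to zero, while the same value survives after multiplying by a scale factor such as 2.**16:
import torch

g = torch.tensor(1e-8)           # representable in float32
print(g.half())                  # 0 in float16: the value underflowed
print((g * 65536.0).half())      # nonzero in float16: the scaled value survives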
The parameters’ gradients (.grad attributes) should be unscaled before the optimizer uses them to update the parameters, so the scale factor does not interfere with the learning rate.
- class torch.cuda.amp.GradScaler(init_scale=65536.0, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True)[source]¶
An instance scaler of GradScaler helps perform the steps of gradient scaling conveniently. scaler.scale(loss) multiplies a given loss by scaler’s current scale factor. scaler.step(optimizer) safely unscales gradients and calls optimizer.step(). scaler.update() updates scaler’s scale factor.
Typical use:
# Creates a GradScaler once at the beginning of training.
scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)

        # Scales the loss, and calls backward() on the scaled loss to create scaled gradients.
        scaler.scale(loss).backward()

        # scaler.step() first unscales the gradients of the optimizer's assigned params.
        # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
        # otherwise, optimizer.step() is skipped.
        scaler.step(optimizer)

        # Updates the scale for next iteration.
        scaler.update()
See the Gradient Scaling Examples for usage in more complex cases like gradient clipping, gradient penalty, and multiple losses/optimizers.
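As a condensed sketch of the multiple-losses case (reusing the names from the typical-use snippet above, and assuming model returns two outputs that share a graph), both losses are scaled by the same scaler, and step()/update() are still called once per iteration:
for input, target in data:
    optimizer.zero_grad()
    output0, output1 = model(input)
    loss0 = loss_fn(output0, target)
    loss1 = loss_fn(output1, target)
    # retain_graph is needed only because the two losses share a graph.
    scaler.scale(loss0).backward(retain_graph=True)
    scaler.scale(loss1).backward()
    scaler.step(optimizer)
    scaler.update()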
scaler dynamically estimates the scale factor each iteration. To minimize gradient underflow, a large scale factor should be used. However, torch.float16 values can “overflow” (become inf or NaN) if the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used without incurring inf or NaN gradient values. scaler approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every scaler.step(optimizer) (or optional separate scaler.unscale_(optimizer), see unscale_()).
If infs/NaNs are found, scaler.step(optimizer) skips the underlying optimizer.step() (so the params themselves remain uncorrupted) and update() multiplies the scale by backoff_factor.
If no infs/NaNs are found, scaler.step(optimizer) runs the underlying optimizer.step() as usual. If growth_interval unskipped iterations occur consecutively, update() multiplies the scale by growth_factor.
The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its value calibrates. scaler.step will skip the underlying optimizer.step() for these iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations).
- Parameters
init_scale (float, optional, default=2.**16) – Initial scale factor.
growth_factor (float, optional, default=2.0) – Factor by which the scale is multiplied during update() if no inf/NaN gradients occur for growth_interval consecutive iterations.
backoff_factor (float, optional, default=0.5) – Factor by which the scale is multiplied during update() if inf/NaN gradients occur in an iteration.
growth_interval (int, optional, default=2000) – Number of consecutive iterations without inf/NaN gradients that must occur for the scale to be multiplied by growth_factor.
enabled (bool, optional, default=True) – If False, disables gradient scaling. step() simply invokes the underlying optimizer.step(), and other methods become no-ops.
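The enabled flag makes it easy to keep a single training loop for both full-precision and mixed-precision runs. A sketch (use_amp is an illustrative flag, not part of the API):
use_amp = True  # set False to run the same loop without gradient scaling
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

# With use_amp=False, scale() returns the loss unmodified and step() simply
# calls optimizer.step(), so the loop body needs no branching.
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()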
- get_scale()[source]¶
Returns a Python float containing the current scale, or 1.0 if scaling is disabled.
Warning
get_scale() incurs a CPU-GPU sync.
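get_scale() can be useful for occasionally logging how the scale factor has calibrated; because of the sync, avoid calling it every iteration in performance-sensitive loops. A sketch (iteration and log_every are illustrative names):
if iteration % log_every == 0:
    print("current loss scale:", scaler.get_scale())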
- load_state_dict(state_dict)[source]¶
Loads the scaler state. If this instance is disabled, load_state_dict() is a no-op.
- Parameters
state_dict (dict) – scaler state. Should be an object returned from a call to state_dict().
- scale(outputs)[source]¶
Multiplies (“scales”) a tensor or list of tensors by the scale factor.
Returns scaled outputs. If this instance of GradScaler is not enabled, outputs are returned unmodified.
- Parameters
outputs (Tensor or iterable of Tensors) – Outputs to scale.
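When a list of tensors is passed, the scaled tensors come back as a list. A sketch of scaling two losses in one call (loss_a and loss_b are illustrative names; scaling them together or separately is a matter of convenience):
scaled_a, scaled_b = scaler.scale([loss_a, loss_b])
(scaled_a + scaled_b).backward()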
- set_backoff_factor(new_factor)[source]¶
- Parameters
new_factor (float) – Value to use as the new scale backoff factor.
- set_growth_factor(new_factor)[source]¶
- Parameters
new_factor (float) – Value to use as the new scale growth factor.
- set_growth_interval(new_interval)[source]¶
- Parameters
new_interval (int) – Value to use as the new growth interval.
- state_dict()[source]¶
Returns the state of the scaler as a dict. It contains five entries:
"scale" - a Python float containing the current scale
"growth_factor" - a Python float containing the current growth factor
"backoff_factor" - a Python float containing the current backoff factor
"growth_interval" - a Python int containing the current growth interval
"_growth_tracker" - a Python int containing the number of recent consecutive unskipped steps.
If this instance is not enabled, returns an empty dict.
Note
If you wish to checkpoint the scaler’s state after a particular iteration, state_dict() should be called after update().
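A sketch of checkpointing the scaler alongside the usual model and optimizer state ("checkpoint.pt" and the surrounding names are illustrative), with state_dict() captured after update() as the note recommends:
checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scaler": scaler.state_dict(),
}
torch.save(checkpoint, "checkpoint.pt")

# When resuming:
checkpoint = torch.load("checkpoint.pt")
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
scaler.load_state_dict(checkpoint["scaler"])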
- step(optimizer, *args, **kwargs)[source]¶
step() carries out the following two operations:
1. Internally invokes unscale_(optimizer) (unless unscale_() was explicitly called for optimizer earlier in the iteration). As part of the unscale_(), gradients are checked for infs/NaNs.
2. If no inf/NaN gradients are found, invokes optimizer.step() using the unscaled gradients. Otherwise, optimizer.step() is skipped to avoid corrupting the params.
*args and **kwargs are forwarded to optimizer.step().
Returns the return value of optimizer.step(*args, **kwargs).
- Parameters
optimizer (torch.optim.Optimizer) – Optimizer that applies the gradients.
args – Any arguments.
kwargs – Any keyword arguments.
Warning
Closure use is not currently supported.
- unscale_(optimizer)[source]¶
Divides (“unscales”) the optimizer’s gradient tensors by the scale factor.
unscale_() is optional, serving cases where you need to modify or inspect gradients between the backward pass(es) and step(). If unscale_() is not called explicitly, gradients will be unscaled automatically during step().
Simple example, using unscale_() to enable clipping of unscaled gradients:
...
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
scaler.step(optimizer)
scaler.update()
- Parameters
optimizer (torch.optim.Optimizer) – Optimizer that owns the gradients to be unscaled.
Note
unscale_() does not incur a CPU-GPU sync.
Warning
unscale_() should only be called once per optimizer per step() call, and only after all gradients for that optimizer’s assigned parameters have been accumulated. Calling unscale_() twice for a given optimizer between each step() triggers a RuntimeError.
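For gradient accumulation, this warning means unscale_() (if used at all) belongs just before step(), after the final backward of the accumulation window, once all gradients for the optimizer’s params have been accumulated. A sketch (iters_to_accumulate is an illustrative hyperparameter; the other names follow the typical-use snippet above):
for i, (input, target) in enumerate(data):
    output = model(input)
    loss = loss_fn(output, target) / iters_to_accumulate
    scaler.scale(loss).backward()
    if (i + 1) % iters_to_accumulate == 0:
        # At most one unscale_() for this optimizer before step().
        scaler.unscale_(optimizer)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()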
- update(new_scale=None)[source]¶
Updates the scale factor.
If any optimizer steps were skipped, the scale is multiplied by backoff_factor to reduce it. If growth_interval unskipped iterations occurred consecutively, the scale is multiplied by growth_factor to increase it.
Passing new_scale sets the scale directly.
- Parameters
new_scale (float or torch.cuda.FloatTensor, optional, default=None) – New scale factor.
Warning
update() should only be called at the end of the iteration, after scaler.step(optimizer) has been invoked for all optimizers used this iteration.
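A condensed sketch of an iteration with two optimizers (optimizer0 and optimizer1 are illustrative names, each owning a disjoint subset of the network’s parameters); update() runs once, after both steps:
scaler.scale(loss).backward()
scaler.step(optimizer0)
scaler.step(optimizer1)
scaler.update()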