Distributed Optimizers
-
class torch.distributed.optim.ZeroRedundancyOptimizer(params, optimizer_class, group=None, parameters_as_bucket_view=False, **defaults)

This class wraps an arbitrary optim.Optimizer and shards its states across ranks in the group as described by ZeRO. The local optimizer instance in each rank is only responsible for updating approximately 1 / world_size parameters and hence only needs to keep 1 / world_size optimizer states. After parameters are updated locally, each rank broadcasts its parameters to all other peers to keep all model replicas in the same state. ZeroRedundancyOptimizer can be used in conjunction with torch.nn.parallel.DistributedDataParallel to reduce per-rank peak memory consumption.

ZeroRedundancyOptimizer uses a sorted-greedy algorithm to pack a number of parameters at each rank. Each parameter belongs to a single rank and is not divided among ranks. The partition is arbitrary and might not match the parameter registration or usage order.

Parameters
params (Iterable) – an Iterable of torch.Tensors giving all parameters, which will be sharded across ranks.

Keyword Arguments
optimizer_class (torch.optim.Optimizer) – the class of the local optimizer.
group (ProcessGroup, optional) – torch.distributed ProcessGroup (default: dist.group.WORLD initialized by torch.distributed.init_process_group()).
parameters_as_bucket_view (bool) – when enabled, parameters are packed into larger buckets to speed up communication, and param.data fields point to bucket views at different offsets; when disabled, each individual parameter is communicated separately, but each param.data stays intact.
**defaults – any trailing arguments, which are forwarded to the local optimizer.
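To make the bucket-view option more concrete, the following pure-Python sketch models a bucket as one flat float32 buffer and each parameter as a view into it at its own offset. It is a toy model, not the PyTorch implementation: the helper bucket_layout and the sample sizes are hypothetical, and memoryview stands in for tensor views. Writing through a parameter view changes the shared bucket, which is what lets a single communication call on the bucket cover every parameter packed into it.

```python
from array import array
from typing import List, Tuple


def bucket_layout(param_numels: List[int]) -> Tuple[int, List[Tuple[int, int]]]:
    """Hypothetical helper: lay parameters out back to back in one flat bucket.

    Returns the total bucket length and a (start, end) slice per parameter.
    """
    slices = []
    cursor = 0
    for numel in param_numels:
        slices.append((cursor, cursor + numel))
        cursor += numel
    return cursor, slices


# Toy parameter sizes (number of elements); not taken from a real model.
numels = [6, 3, 4]
total, slices = bucket_layout(numels)

# One flat float32 buffer plays the role of the bucket.
bucket = array("f", [0.0] * total)
flat = memoryview(bucket)

# Each "param.data" is a view into the bucket at its own offset (no copy).
param_views = [flat[start:end] for start, end in slices]

# Writing through a parameter view updates the shared bucket in place.
param_views[1][0] = 1.5
print(total, slices)         # 13 [(0, 6), (6, 9), (9, 13)]
print(bucket[slices[1][0]])  # 1.5
```

With parameters_as_bucket_view disabled, the analogue would be one separate buffer per parameter, each communicated on its own.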
Example:
>>> import torch
>>> import torch.nn as nn
>>> from torch.distributed.optim import ZeroRedundancyOptimizer
>>> from torch.nn.parallel import DistributedDataParallel as DDP
>>> # assumes the default process group is initialized and `rank` and `inputs` are defined
>>> model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)])
>>> ddp = DDP(model, device_ids=[rank])
>>> opt = ZeroRedundancyOptimizer(
...     ddp.parameters(),
...     optimizer_class=torch.optim.Adam,
...     lr=0.01
... )
>>> ddp(inputs).sum().backward()
>>> opt.step()
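The sorted-greedy packing described above can be pictured with a small pure-Python sketch. It assumes one common reading of "sorted-greedy": visit parameters from largest to smallest and give each one to the rank that currently holds the fewest elements. The helper name, the world size of 4, and the float32 Adam memory estimate are illustrative assumptions, not PyTorch internals; the parameter sizes come from the example model above (20 nn.Linear(2000, 2000) layers).

```python
from typing import List


def sorted_greedy_partition(numels: List[int], world_size: int) -> List[List[int]]:
    """Assign each parameter (by index) to exactly one rank.

    Parameters are visited from largest to smallest, and each one goes to the
    rank whose current total number of elements is smallest (lowest rank on ties).
    """
    order = sorted(range(len(numels)), key=lambda i: numels[i], reverse=True)
    loads = [0] * world_size
    assignment: List[List[int]] = [[] for _ in range(world_size)]
    for i in order:
        rank = min(range(world_size), key=lambda r: loads[r])
        assignment[rank].append(i)
        loads[rank] += numels[i]
    return assignment


# Parameter sizes of the example model: 20 x nn.Linear(2000, 2000),
# each contributing a 2000x2000 weight and a 2000-element bias.
numels = [2000 * 2000, 2000] * 20
world_size = 4  # assumed number of ranks

assignment = sorted_greedy_partition(numels, world_size)
per_rank = [sum(numels[i] for i in idxs) for idxs in assignment]
print(per_rank)  # [20010000, 20010000, 20010000, 20010000]

# Rough Adam state size, assuming float32 and two state tensors
# (exp_avg, exp_avg_sq) per parameter element; the step counter is ignored.
bytes_per_rank = per_rank[0] * 2 * 4
bytes_unsharded = sum(numels) * 2 * 4
print(bytes_per_rank, bytes_unsharded)  # 160080000 640320000
```

Because every weight in this model has the same size, the split comes out perfectly even. With uneven sizes the greedy rule only keeps per-rank totals close, which is why each rank is responsible for approximately 1 / world_size of the parameters rather than exactly that share.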
-
add_param_group(param_group)

Add a parameter group to the Optimizer's param_groups.

This can be useful when fine-tuning a pre-trained network, as frozen layers can be made trainable and added to the Optimizer as training progresses.

Parameters
param_group (dict) – specifies the parameters to be optimized and group-specific optimization options.
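The fine-tuning scenario above can be sketched with a toy model of a param_groups list, where parameters are plain strings. The function below is hypothetical and only mirrors the general torch.optim convention that options missing from the new group fall back to the optimizer defaults and that a parameter may belong to at most one group. It leaves out the sharding step, i.e. deciding which rank owns the newly added parameters.

```python
from typing import Any, Dict, List


def add_param_group(param_groups: List[Dict[str, Any]],
                    defaults: Dict[str, Any],
                    param_group: Dict[str, Any]) -> None:
    """Toy version of appending a group to an optimizer's param_groups."""
    group = dict(param_group)
    group["params"] = list(group["params"])
    # Options the group does not set fall back to the optimizer defaults.
    for name, value in defaults.items():
        group.setdefault(name, value)
    # A parameter may belong to at most one group.
    seen = {p for g in param_groups for p in g["params"]}
    if any(p in seen for p in group["params"]):
        raise ValueError("some parameters appear in more than one parameter group")
    param_groups.append(group)


defaults = {"lr": 0.01, "weight_decay": 0.0}
# Initially only the (hypothetical) head is trainable; the backbone is frozen.
param_groups = [{"params": ["head.weight", "head.bias"], **defaults}]

# Later in training: unfreeze the backbone with a smaller learning rate.
add_param_group(param_groups, defaults, {"params": ["backbone.weight"], "lr": 0.001})
print(param_groups[1])  # {'params': ['backbone.weight'], 'lr': 0.001, 'weight_decay': 0.0}
```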
-
consolidate_state_dict(to=0)

Consolidate a list of state_dicts (one per rank) on the target rank.

Parameters
to (int) – the rank that receives the optimizer states (default: 0).
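Conceptually, consolidation merges the per-rank shards of optimizer state into one full state on the target rank. The sketch below simulates that merge in pure Python: shards[r] stands for what rank r would send, and the state is keyed by a hypothetical global parameter index. No communication happens here; in PyTorch the call involves communication, so every rank takes part, and the full state is then read with state_dict() on the target rank.

```python
from typing import Any, Dict, List


def consolidate(shards: List[Dict[int, Dict[str, Any]]]) -> Dict[int, Dict[str, Any]]:
    """Merge per-rank optimizer states (keyed by global parameter index) into one dict."""
    full: Dict[int, Dict[str, Any]] = {}
    for rank, shard in enumerate(shards):
        for index, state in shard.items():
            # Each parameter is owned by exactly one rank, so keys should never collide.
            if index in full:
                raise ValueError(f"parameter {index} is owned twice (rank {rank})")
            full[index] = state
    # Return entries in global parameter order.
    return dict(sorted(full.items()))


# Toy shards: rank 0 owns parameters 0 and 3, rank 1 owns parameters 1 and 2.
shards = [
    {0: {"step": 3, "exp_avg": 0.1}, 3: {"step": 3, "exp_avg": 0.4}},  # rank 0
    {1: {"step": 3, "exp_avg": 0.2}, 2: {"step": 3, "exp_avg": 0.3}},  # rank 1
]
full = consolidate(shards)
print(list(full))  # [0, 1, 2, 3]
```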
-
load_state_dict(state_dict)

Load the state pertaining to the given rank from the input state_dict, updating the local optimizer as needed.

Parameters
state_dict (dict) – optimizer state; should be an object returned from a call to state_dict().
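Loading works in the opposite direction: each rank keeps only the part of the full state that pertains to the parameters it owns. The toy below assumes the state is a dict from a hypothetical global parameter index to per-parameter state, and leaves out the param_groups entry that a real optimizer state_dict also carries; the partition shown is made up for illustration.

```python
from typing import Any, Dict, List


def local_state(full_state: Dict[int, Dict[str, Any]],
                owned: List[int]) -> Dict[int, Dict[str, Any]]:
    """Keep only the entries for the parameters this rank owns."""
    return {index: full_state[index] for index in owned if index in full_state}


full_state = {
    0: {"step": 3, "exp_avg": 0.1},
    1: {"step": 3, "exp_avg": 0.2},
    2: {"step": 3, "exp_avg": 0.3},
    3: {"step": 3, "exp_avg": 0.4},
}
owned_by_rank = {0: [0, 3], 1: [1, 2]}  # hypothetical partition

for rank, owned in owned_by_rank.items():
    print(rank, sorted(local_state(full_state, owned)))
# 0 [0, 3]
# 1 [1, 2]
```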