Distributed RPC Framework
The distributed RPC framework provides mechanisms for multi-machine model training through a set of primitives to allow for remote communication, and a higher-level API to automatically differentiate models split across several machines.
Warning
The RPC API is experimental and subject to change.
RPC and RRef Framework
Before using RPC and distributed autograd primitives, initialization must take place. To initialize the RPC framework, use init_rpc(), which initializes the RPC framework, the RRef framework, and distributed autograd. By default, this also initializes the ProcessGroup (init_process_group()) backend for RPC communication. The ProcessGroup backend internally uses gloo for communication.
torch.distributed.rpc.init_rpc(name, backend=BackendType.PROCESS_GROUP, rank=-1, world_size=None, rpc_backend_options=None)
Initializes RPC primitives such as the local RPC agent and distributed autograd.
Initializes the local RPC agent, which immediately makes the current process ready to send and receive RPCs. This method also initializes a default process group backend that uses gloo for collective communication.
- Parameters
name (str) – a globally unique name of this node (e.g., Trainer3, ParameterServer2, Master, Worker1). The name may contain only digits, letters, underscores, and dashes, and must be shorter than 128 characters.
backend (Enum) – type of RPC backend implementation. Currently, the process group backend is the only available implementation. (default: BackendType.PROCESS_GROUP)
rank (int) – a globally unique id/rank of this node.
world_size (int) – the number of workers in the group.
rpc_backend_options (RpcBackendOptions) – the options passed to the RpcAgent constructor.
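The naming rule for name can be checked before calling init_rpc(). The sketch below is a hypothetical helper, is_valid_worker_name, that encodes the rule exactly as stated above (ASCII letters, digits, underscores, and dashes; fewer than 128 characters). It is not part of torch.distributed.rpc, and the library's own validation may differ in details; rejecting the empty string is an extra assumption.

```python
import re

# Hypothetical helper, not part of torch.distributed.rpc. It encodes the rule
# stated for init_rpc(name=...): ASCII letters, digits, underscores and dashes
# only, and strictly fewer than 128 characters. Rejecting the empty string is
# an extra assumption of this sketch.
_WORKER_NAME_RE = re.compile(r"[A-Za-z0-9_-]+")


def is_valid_worker_name(name: str) -> bool:
    """Return True if `name` satisfies the documented worker-name rule."""
    return len(name) < 128 and _WORKER_NAME_RE.fullmatch(name) is not None


for candidate in ["Trainer3", "ParameterServer2", "worker-1", "bad name", "x" * 128]:
    print(f"{candidate[:16]!r}: {is_valid_worker_name(candidate)}")
```

Checking names up front turns a bad name into an early, local error instead of a failure during initialization.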
RRef
An RRef (Remote REFerence) is a reference to a value of some type T (e.g. Tensor) on a remote worker. This handle keeps the referenced remote value alive on the owner, but there is no implication that the value will be transferred to the local worker in the future. RRefs can be used in multi-machine training by holding references to nn.Modules that exist on other workers, and calling the appropriate functions to retrieve or modify their parameters during training. See Remote Reference Protocol for more details.
class torch.distributed.rpc.RRef
A class encapsulating a reference to a value of some type on a remote worker. This handle will keep the referenced remote value alive on the worker.
is_owner(self: torch.distributed.rpc.RRef) → bool
Returns whether or not the current node is the owner of this RRef.
local_value(self: torch.distributed.rpc.RRef) → object
If the current node is the owner, returns a reference to the local value. Otherwise, throws an exception.
owner(self: torch.distributed.rpc.RRef) → torch.distributed.rpc.WorkerInfo
Returns worker information of the node that owns this RRef.
to_here(self: torch.distributed.rpc.RRef) → object
Blocking call that copies the value of the RRef from the owner to the local node and returns it. If the current node is the owner, returns a reference to the local value.
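The owner/user distinction behind these four methods can be modeled in a single process. In the sketch below, ToyRRef, ToyWorkerInfo, and STORES are hypothetical stand-ins: each "worker" is just a dictionary, and a deep copy stands in for serializing the value over the network. No real RPC happens.

```python
import copy
from dataclasses import dataclass
from typing import Any, Dict


@dataclass(frozen=True)
class ToyWorkerInfo:
    """Stand-in for WorkerInfo: a worker's name and numeric id."""
    name: str
    id: int


# Per-worker value stores; a real RPC agent would keep these in separate processes.
STORES: Dict[str, Dict[int, Any]] = {"worker0": {}, "worker1": {}}


class ToyRRef:
    """Handle to value `key` stored on `owner`, held by worker `holder`."""

    def __init__(self, owner: ToyWorkerInfo, key: int, holder: str):
        self._owner = owner
        self._key = key
        self._holder = holder

    def is_owner(self) -> bool:
        return self._holder == self._owner.name

    def owner(self) -> ToyWorkerInfo:
        return self._owner

    def local_value(self) -> Any:
        if not self.is_owner():
            raise RuntimeError("local_value() is only valid on the owner")
        return STORES[self._owner.name][self._key]  # the stored object itself

    def to_here(self) -> Any:
        value = STORES[self._owner.name][self._key]
        # Owners get the object itself; users get a deep copy, standing in for
        # the value being serialized and shipped over the network.
        return value if self.is_owner() else copy.deepcopy(value)


worker1 = ToyWorkerInfo("worker1", 1)
STORES["worker1"][0] = [1.0, 2.0]
owner_ref = ToyRRef(worker1, 0, holder="worker1")
user_ref = ToyRRef(worker1, 0, holder="worker0")
fetched = user_ref.to_here()
print(user_ref.is_owner(), owner_ref.is_owner(), fetched)
```

The design point it mirrors is that to_here() on a user produces an independent copy, while the owner works directly on the stored object.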
RPC and RRef primitives
This library provides primitives allowing users to create and modify references (RRefs) to remote data as well as remotely execute functions.
torch.distributed.rpc.rpc_sync(to, func, args=None, kwargs=None)
Make a blocking RPC call to run function func on worker to. RPC messages are sent and received in parallel to execution of Python code. This method is thread-safe.
- Parameters
to (str or WorkerInfo) – id or name of the destination worker.
func (callable) – any callable function. Builtin functions (like torch.add()) can be sent over RPC more efficiently.
args (tuple) – the argument tuple for the func invocation.
kwargs (dict) – a dictionary of keyword arguments for the func invocation.
- Returns
Returns the result of running func on args and kwargs.
Example:
On worker 0:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
>>> rpc.shutdown()
On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.shutdown()
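To make the blocking semantics concrete without a multi-process setup, the following sketch models each destination worker as a single-threaded executor, and a hypothetical toy_rpc_sync submits func there and waits for the result. It is an in-process analogy only: WORKERS and toy_rpc_sync are invented names, and no serialization, networking, or builtin-function fast path is modeled.

```python
import operator
import threading
from concurrent.futures import ThreadPoolExecutor

# Each simulated worker is a single-threaded executor whose thread names start
# with the worker's name.
WORKERS = {
    name: ThreadPoolExecutor(max_workers=1, thread_name_prefix=name)
    for name in ("worker0", "worker1")
}


def toy_rpc_sync(to, func, args=None, kwargs=None):
    """Run func(*args, **kwargs) on worker `to` and block until it returns."""
    future = WORKERS[to].submit(func, *(args or ()), **(kwargs or {}))
    return future.result()  # the caller waits here, as with a blocking RPC


ret = toy_rpc_sync("worker1", operator.add, args=(2, 3))
ran_on = toy_rpc_sync("worker1", lambda: threading.current_thread().name)
desc = toy_rpc_sync("worker1", sorted, args=([3, 1, 2],), kwargs={"reverse": True})
print(ret, ran_on, desc)
```

The ran_on call shows that func executes on the destination's thread, not the caller's, while the caller still receives the return value directly.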
torch.distributed.rpc.rpc_async(to, func, args=None, kwargs=None)
Make a non-blocking RPC call to run function func on worker to. RPC messages are sent and received in parallel to execution of Python code. This method is thread-safe. This method will immediately return a torch.distributed.FutureMessage that can be awaited on.
- Parameters
to (str or WorkerInfo) – id or name of the destination worker.
func (callable) – any callable function. Builtin functions (like torch.add()) can be sent over RPC more efficiently.
args (tuple) – the argument tuple for the func invocation.
kwargs (dict) – a dictionary of keyword arguments for the func invocation.
- Returns
Returns a torch.distributed.FutureMessage object that can be waited on. When completed, the return value of func on args and kwargs can be retrieved from the FutureMessage object.
Example:
On worker 0:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> fut1 = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 3))
>>> fut2 = rpc.rpc_async("worker1", min, args=(1, 2))
>>> result = fut1.wait() + fut2.wait()
>>> rpc.shutdown()
On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.shutdown()
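The non-blocking pattern in the example above (issue several calls, then wait on each future) can be illustrated with standard-library futures. In this sketch, toy_rpc_async and WORKER1 are hypothetical, concurrent.futures.Future stands in for FutureMessage, and result() plays the role of wait(). A threading.Event holds the first call open so the sketch can show deterministically that the caller regains control before the work finishes.

```python
import threading
from concurrent.futures import ThreadPoolExecutor

WORKER1 = ThreadPoolExecutor(max_workers=1)  # simulated destination worker


def toy_rpc_async(worker, func, args=None, kwargs=None):
    """Schedule func on the simulated worker and return a Future immediately."""
    return worker.submit(func, *(args or ()), **(kwargs or {}))


gate = threading.Event()


def slow_add(a, b):
    gate.wait()  # simulate a long-running remote computation
    return a + b


fut1 = toy_rpc_async(WORKER1, slow_add, args=(1, 2))
fut2 = toy_rpc_async(WORKER1, min, args=(1, 2))
returned_early = not fut1.done()  # the caller got control back before completion
gate.set()
result = fut1.result() + fut2.result()  # result() plays the role of wait()
print(returned_early, result)
```

Issuing both calls before waiting is what lets independent remote work overlap with the caller's own execution.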
torch.distributed.rpc.remote(to, func, args=None, kwargs=None)
Make a remote call to run func on worker to and return an RRef to the result value immediately. Worker to will be the owner of the returned RRef, and the worker calling remote is a user. The owner manages the global reference count of its RRef, and the owner RRef is only destructed when globally there are no living references to it.
- Parameters
to (str or WorkerInfo) – id or name of the destination worker.
func (callable) – a builtin function (like torch.add()).
args (tuple) – the argument tuple for the func invocation.
kwargs (dict) – a dictionary of keyword arguments for the func invocation.
- Returns
A user RRef instance to the result value. Use the blocking API torch.distributed.rpc.RRef.to_here() to retrieve the result value locally.
Example:
On worker 0:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
>>> rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
>>> x = rref1.to_here() + rref2.to_here()
>>> rpc.shutdown()
On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.shutdown()
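The ownership rule for remote() (the owner keeps the value alive until no live references remain anywhere) can be sketched as a reference-counting table. ToyOwner, ToyUserRRef, and toy_remote are hypothetical names for this in-process model. It omits the messaging and race handling a real RRef protocol needs, and it only shows that remote() returns a handle immediately while the owner frees the value once the last user releases it.

```python
import itertools
import operator
import threading
from concurrent.futures import ThreadPoolExecutor


class ToyOwner:
    """Owner side: stores results and a global reference count per RRef id."""

    def __init__(self, name):
        self.name = name
        self._executor = ThreadPoolExecutor(max_workers=1)
        self._lock = threading.Lock()
        self._ids = itertools.count()
        self.values = {}     # rref id -> Future that will hold the result
        self.refcounts = {}  # rref id -> number of live user references

    def create(self, func, args):
        with self._lock:
            rid = next(self._ids)
            self.values[rid] = self._executor.submit(func, *args)
            self.refcounts[rid] = 0
        return rid

    def add_ref(self, rid):
        with self._lock:
            self.refcounts[rid] += 1

    def del_ref(self, rid):
        with self._lock:
            self.refcounts[rid] -= 1
            if self.refcounts[rid] == 0:  # no live references anywhere
                del self.values[rid], self.refcounts[rid]


class ToyUserRRef:
    """User side: a handle that registers itself with the owner."""

    def __init__(self, owner, rid):
        self.owner, self.rid = owner, rid
        owner.add_ref(rid)

    def to_here(self):
        return self.owner.values[self.rid].result()  # blocks until computed

    def release(self):
        self.owner.del_ref(self.rid)


def toy_remote(owner, func, args=()):
    """Start func on `owner` and return a user reference without waiting."""
    return ToyUserRRef(owner, owner.create(func, args))


worker1 = ToyOwner("worker1")
rref1 = toy_remote(worker1, operator.add, (1, 3))
rref2 = toy_remote(worker1, operator.add, (1, 1))
x = rref1.to_here() + rref2.to_here()
shared = ToyUserRRef(worker1, rref1.rid)  # a second user of the same value
rref1.release()
alive_after_first_release = rref1.rid in worker1.values
shared.release()
alive_after_last_release = rref1.rid in worker1.values
print(x, alive_after_first_release, alive_after_last_release)
```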
torch.distributed.rpc.get_worker_info(worker_name=None)
Get WorkerInfo of a given worker name. Use this WorkerInfo to avoid passing an expensive string on every invocation.
- Parameters
worker_name (str) – the string name of a worker. If None, return the WorkerInfo of the current worker. (default: None)
- Returns
WorkerInfo instance for the given worker_name, or WorkerInfo of the current worker if worker_name is None.
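The advice to reuse a WorkerInfo amounts to resolving a name once and then addressing messages by a compact id. The sketch below is a hypothetical, in-process model of that lookup: ToyWorkerInfo, the registry, and toy_get_worker_info are invented for illustration and are not the torch.distributed.rpc implementation.

```python
from typing import NamedTuple, Optional


class ToyWorkerInfo(NamedTuple):
    """Stand-in for WorkerInfo: a worker's name and its compact numeric id."""
    name: str
    id: int


# Hypothetical name table an RPC agent could hold after initialization.
_REGISTRY = {
    "worker0": ToyWorkerInfo("worker0", 0),
    "worker1": ToyWorkerInfo("worker1", 1),
}
_CURRENT_WORKER = "worker0"


def toy_get_worker_info(worker_name: Optional[str] = None) -> ToyWorkerInfo:
    """Look up a worker by name, or return the current worker if name is None."""
    return _REGISTRY[_CURRENT_WORKER if worker_name is None else worker_name]


dest = toy_get_worker_info("worker1")  # resolve the string once ...
# ... then address every message by the small integer id instead of the name.
messages = [(dest.id, f"payload {i}") for i in range(3)]
print(toy_get_worker_info(), messages)
```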
torch.distributed.rpc.shutdown(graceful=True)
Perform a shutdown of the RPC agent, and then destroy the RPC agent. This stops the local agent from accepting outstanding requests, and shuts down the RPC framework by terminating all RPC threads. If graceful=True, this will block until all local and remote RPC processes reach this method and wait for all outstanding work to complete. Otherwise, if graceful=False, this is a local shutdown, and it does not wait for other RPC processes to reach this method.
- Parameters
graceful (bool) – Whether to do a graceful shutdown or not. If True, this will block until all local and remote RPC processes have reached this method and wait for all outstanding work to complete.
Example:
On worker 0:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> # do some work
>>> result = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(1), 1))
>>> # ready to shutdown
>>> rpc.shutdown()
On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> # wait for worker 0 to finish work, and then shutdown.
>>> rpc.shutdown()
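The "block until all processes reach this method" behavior of a graceful shutdown resembles a barrier. In the sketch below, threads stand in for RPC processes and a threading.Barrier stands in for the shutdown rendezvous. toy_shutdown and run_worker are hypothetical, and outstanding RPC work is not modeled, only the ordering guarantee that no worker stops before every worker has finished its own tasks.

```python
import threading

NUM_WORKERS = 2
_shutdown_barrier = threading.Barrier(NUM_WORKERS)
_log_lock = threading.Lock()
log = []


def record(event):
    with _log_lock:
        log.append(event)


def toy_shutdown(name, graceful=True):
    """Graceful: wait until every worker has called shutdown, then stop."""
    if graceful:
        _shutdown_barrier.wait()
    record(f"{name} stopped")


def run_worker(name, num_tasks):
    for i in range(num_tasks):
        record(f"{name} task {i}")
    toy_shutdown(name)


threads = [
    threading.Thread(target=run_worker, args=("worker0", 3)),
    threading.Thread(target=run_worker, args=("worker1", 1)),
]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(log)
```

With graceful=False the barrier is skipped, so a fast worker could stop while the other is still working, which corresponds to the local shutdown described above.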
Distributed Autograd Framework
This module provides an RPC-based distributed autograd framework that can be used for applications such as model parallel training. In short, applications may send and receive gradient recording tensors over RPC. In the forward pass, we record when gradient recording tensors are sent over RPC and during the backward pass we use this information to perform a distributed backward pass using RPC. For more details see Distributed Autograd Design.
class torch.distributed.autograd.context
Context object to wrap forward and backward passes when using distributed autograd. The context_id generated in the with statement is required to uniquely identify a distributed backward pass on all workers. Each worker stores metadata associated with this context_id, which is required to correctly execute a distributed autograd pass.
Example:
>>> import torch
>>> import torch.distributed.autograd as dist_autograd
>>> import torch.distributed.rpc as rpc
>>> with dist_autograd.context() as context_id:
>>>     t1 = torch.rand((3, 3), requires_grad=True)
>>>     t2 = torch.rand((3, 3), requires_grad=True)
>>>     loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum()
>>>     dist_autograd.backward([loss])
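Each worker keeps metadata keyed by context_id, and the id must be unique across workers. The sketch below models that bookkeeping on one worker with a hypothetical ToyAutogradContainer. The id scheme (worker id in the high bits, a local counter in the low 48 bits) is an assumption chosen to show one way of avoiding collisions without coordination, and in this toy the metadata is dropped when the with block exits.

```python
import contextlib
import itertools


class ToyAutogradContainer:
    """Per-worker bookkeeping for distributed autograd contexts (toy model)."""

    def __init__(self, worker_id: int):
        self.worker_id = worker_id
        self._counter = itertools.count()
        self.contexts = {}  # context_id -> metadata for that backward pass

    @contextlib.contextmanager
    def context(self):
        # Assumed id scheme: worker id in the high bits, a local counter in the
        # low 48 bits, so ids created on different workers cannot collide.
        context_id = (self.worker_id << 48) | next(self._counter)
        self.contexts[context_id] = {}
        try:
            yield context_id
        finally:
            del self.contexts[context_id]  # drop metadata when the block exits


w0, w1 = ToyAutogradContainer(0), ToyAutogradContainer(1)
with w0.context() as ctx_a, w1.context() as ctx_b:
    live_during = set(w0.contexts) | set(w1.contexts)
print(ctx_a, ctx_b, live_during, w0.contexts, w1.contexts)
```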
torch.distributed.autograd.backward(roots: List[Tensor]) → None
Kicks off the distributed backward pass using the provided roots. This currently implements the FAST mode algorithm, which assumes that all RPC messages sent in the same distributed autograd context across workers are part of the autograd graph during the backward pass.
We use the provided roots to discover the autograd graph and compute appropriate dependencies. This method blocks until the entire autograd computation is done.
We accumulate the gradients in the appropriate torch.distributed.autograd.context on each of the nodes. The autograd context used is the current autograd context of this node when torch.distributed.autograd.backward() is called. If there is no valid autograd context, we throw an error. You can retrieve the accumulated gradients using the get_gradients() API.
- Parameters
roots (list) – Tensors which represent the roots of the autograd computation. All the tensors should be scalars.
Example:
>>> import torch.distributed.autograd as dist_autograd
>>> # model, inputs, target and loss_func are assumed to be defined elsewhere.
>>> with dist_autograd.context() as context_id:
>>>     pred = model.forward(inputs)
>>>     loss = loss_func(pred, target)
>>>     dist_autograd.backward([loss])
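"Use the provided roots to discover the autograd graph and compute appropriate dependencies" describes a two-phase idea that also underlies single-process reverse-mode differentiation: first count how many consumers each node has, then propagate a node's gradient only after all consumers have contributed. The sketch below shows this on a toy scalar graph with hypothetical Node, add, and mul helpers. It does not model RPC, cross-worker messages, or PyTorch's actual engine.

```python
from collections import defaultdict, deque


class Node:
    """A scalar in a toy autograd graph; parents are (node, local_grad) pairs."""

    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents


def add(a, b):
    return Node(a.value + b.value, ((a, 1.0), (b, 1.0)))


def mul(a, b):
    return Node(a.value * b.value, ((a, b.value), (b, a.value)))


def backward(roots):
    """Two-phase reverse pass: count dependencies from the roots, then run."""
    # Phase 1: walk the graph from the roots and count incoming edges per node.
    deps = defaultdict(int)
    seen, stack = set(roots), list(roots)
    while stack:
        node = stack.pop()
        for parent, _ in node.parents:
            deps[parent] += 1
            if parent not in seen:
                seen.add(parent)
                stack.append(parent)
    # Phase 2: a node is processed only once all of its consumers have
    # contributed, so its gradient is complete before it is propagated.
    grads = defaultdict(float)
    for root in roots:
        grads[root] += 1.0
    ready = deque(r for r in roots if deps[r] == 0)
    leaf_grads = {}
    while ready:
        node = ready.popleft()
        if not node.parents:
            leaf_grads[node] = grads[node]
        for parent, local_grad in node.parents:
            grads[parent] += grads[node] * local_grad
            deps[parent] -= 1
            if deps[parent] == 0:
                ready.append(parent)
    return leaf_grads


x, y = Node(2.0), Node(3.0)
loss = add(mul(x, y), x)  # loss = x * y + x
grads = backward([loss])
print(grads[x], grads[y])  # d/dx = y + 1, d/dy = x
```

Without the dependency counts, x could be processed after receiving only the contribution from the addition and before the one from the multiplication, yielding an incomplete gradient.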
torch.distributed.autograd.get_gradients(context_id: int) → Dict[Tensor, Tensor]
Retrieves a map from Tensor to the appropriate gradient for that Tensor accumulated in the provided context_id as part of the distributed autograd backward pass.
- Parameters
context_id (int) – The autograd context id for which we should retrieve the gradients.
- Returns
A map where the key is the Tensor and the value is the associated gradient for that Tensor.
Example:
>>> import torch
>>> import torch.distributed.autograd as dist_autograd
>>> with dist_autograd.context() as context_id:
>>>     t1 = torch.rand((3, 3), requires_grad=True)
>>>     t2 = torch.rand((3, 3), requires_grad=True)
>>>     loss = t1 + t2
>>>     dist_autograd.backward([loss.sum()])
>>>     grads = dist_autograd.get_gradients(context_id)
>>>     print(grads[t1])
>>>     print(grads[t2])
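Because gradients are accumulated per context rather than globally, two backward passes in different contexts do not mix. The sketch below models the map returned by get_gradients() as a per-context dictionary keyed by object identity (torch.Tensor also hashes by identity). Param, accumulate, and toy_get_gradients are hypothetical names, and plain floats stand in for gradient tensors.

```python
from collections import defaultdict


class Param:
    """Stand-in for a tensor; like torch.Tensor, it hashes by identity."""

    def __init__(self, name):
        self.name = name


# context_id -> {param: accumulated gradient}, kept separate per context.
_GRADS_BY_CONTEXT = defaultdict(dict)


def accumulate(context_id, param, grad):
    ctx = _GRADS_BY_CONTEXT[context_id]
    ctx[param] = ctx.get(param, 0.0) + grad


def toy_get_gradients(context_id):
    """Return a copy of the param -> gradient map for one context."""
    return dict(_GRADS_BY_CONTEXT[context_id])


t1, t2 = Param("t1"), Param("t2")
# Context 1: t1 receives two contributions (e.g. it fed two RPCs), t2 one.
accumulate(1, t1, 1.0)
accumulate(1, t1, 1.0)
accumulate(1, t2, 1.0)
# Context 2: an independent backward pass; nothing leaks into context 1.
accumulate(2, t1, 0.5)
g1, g2 = toy_get_gradients(1), toy_get_gradients(2)
print(g1[t1], g1[t2], g2[t1])
```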
Distributed Optimizer
torch.distributed.optim exposes DistributedOptimizer, which takes a list of remote parameters (RRef) and runs the optimizer locally on the workers where the parameters live. The distributed optimizer can use any of the local optimizer algorithms to apply the gradients on each worker.
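The idea of taking references to parameters scattered across workers and updating each one where it lives starts with grouping the references by owner. In the in-process sketch below, a (worker name, parameter name) pair stands in for an RRef, and the update rule is passed in as a function to mirror "any local optimizer algorithm". All names are hypothetical, and the per-owner loop runs locally instead of over RPC.

```python
from collections import defaultdict
from typing import Callable, Dict, List, Tuple

# Hypothetical stand-in for an RRef to a parameter: (owner worker, param name).
ParamRef = Tuple[str, str]


def group_by_owner(param_refs: List[ParamRef]) -> Dict[str, List[str]]:
    """Bucket parameter references by the worker that owns them."""
    groups: Dict[str, List[str]] = defaultdict(list)
    for owner, name in param_refs:
        groups[owner].append(name)
    return dict(groups)


def sgd_update(param: float, grad: float, lr: float) -> float:
    """One local update rule; any other rule with this shape could be passed."""
    return param - lr * grad


def apply_on_owners(groups, params, grads, update: Callable, **hyper):
    # Each owner updates only the parameters it holds. In the real class this
    # work would run on the owner itself; here it is a plain loop.
    for owner, names in groups.items():
        for n in names:
            params[owner][n] = update(params[owner][n], grads[owner][n], **hyper)


params = {"ps1": {"w": 1.0, "b": 0.5}, "ps2": {"v": -2.0}}
grads = {"ps1": {"w": 0.5, "b": -1.0}, "ps2": {"v": 2.0}}
groups = group_by_owner([("ps1", "w"), ("ps2", "v"), ("ps1", "b")])
apply_on_owners(groups, params, grads, sgd_update, lr=0.5)
print(groups, params)
```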
class torch.distributed.optim.DistributedOptimizer(optimizer_class, params_rref, *args, **kwargs)
DistributedOptimizer takes remote references to parameters scattered across workers and applies the given optimizer locally for each parameter.
This class uses get_gradients() in order to retrieve the gradients for specific parameters.
Concurrent calls to step(), either from the same or different clients, will be serialized on each worker, as each worker’s optimizer can only work on one set of gradients at a time. However, there is no guarantee that the full forward-backward-optimizer sequence will execute for one client at a time. This means that the gradients being applied may not correspond to the latest forward pass executed on a given worker. Also, there is no guaranteed ordering across workers.
- Parameters
optimizer_class (optim.Optimizer) – the class of optimizer to instantiate on each worker.
params_rref (list[RRef]) – list of RRefs to local or remote parameters to optimize.
args – arguments to pass to the optimizer constructor on each worker.
kwargs – keyword arguments to pass to the optimizer constructor on each worker.
Example:
>>> import torch
>>> import torch.distributed.autograd as dist_autograd
>>> import torch.distributed.rpc as rpc
>>> from torch import optim
>>> from torch.distributed.optim import DistributedOptimizer
>>>
>>> with dist_autograd.context() as context_id:
>>>     # Forward pass.
>>>     rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
>>>     rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
>>>     loss = rref1.to_here() + rref2.to_here()
>>>
>>>     # Backward pass.
>>>     dist_autograd.backward([loss.sum()])
>>>
>>>     # Optimizer.
>>>     dist_optim = DistributedOptimizer(
>>>         optim.SGD,
>>>         [rref1, rref2],
>>>         lr=0.05,
>>>     )
>>>     dist_optim.step()
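The serialization guarantee for concurrent step() calls can be pictured as a per-worker lock around the update. In the sketch below, a hypothetical ToyWorkerOptimizer guards its read-modify-write with threading.Lock, and eight client threads step it concurrently. The lock ensures that every update is applied exactly once, but, as the text above notes, it says nothing about which client's step runs first.

```python
import threading


class ToyWorkerOptimizer:
    """Optimizer state on one worker; a lock serializes concurrent step() calls."""

    def __init__(self, value: int):
        self.value = value
        self._lock = threading.Lock()

    def step(self, grad: int, lr: int = 1) -> None:
        with self._lock:  # one set of gradients at a time on this worker
            current = self.value
            self.value = current - lr * grad


opt = ToyWorkerOptimizer(value=0)


def client(num_steps: int) -> None:
    for _ in range(num_steps):
        opt.step(grad=1)


clients = [threading.Thread(target=client, args=(1000,)) for _ in range(8)]
for t in clients:
    t.start()
for t in clients:
    t.join()
print(opt.value)  # every update applied exactly once: -8000
```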
step()
Performs a single optimization step.
This will call torch.optim.Optimizer.step() on each worker containing parameters to be optimized, and will block until all workers return. The current distributed autograd context will be used globally.
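The behavior of step() (start a local optimizer step on every worker that owns parameters, then block until all of them return) can be illustrated as a fan-out over executors followed by concurrent.futures.wait. The worker names, parameter values, and toy_step below are hypothetical, and plain SGD stands in for whatever optimizer_class each worker would run.

```python
from concurrent.futures import ThreadPoolExecutor, wait

# Hypothetical parameter-owning workers, each simulated by one executor thread.
WORKERS = {"ps1": ThreadPoolExecutor(max_workers=1), "ps2": ThreadPoolExecutor(max_workers=1)}
PARAMS = {"ps1": [1.0, 2.0], "ps2": [3.0]}
GRADS = {"ps1": [0.5, 0.5], "ps2": [1.0]}


def local_step(worker: str, lr: float) -> str:
    """Runs on one simulated worker: plain SGD over the parameters it owns."""
    PARAMS[worker] = [p - lr * g for p, g in zip(PARAMS[worker], GRADS[worker])]
    return worker


def toy_step(lr: float = 1.0) -> list:
    """Start local_step on every worker, then block until all have returned."""
    futures = [ex.submit(local_step, name, lr) for name, ex in WORKERS.items()]
    done, _ = wait(futures)  # default return_when=ALL_COMPLETED
    return sorted(f.result() for f in done)


finished = toy_step(lr=1.0)
print(finished, PARAMS)
```

Running the per-worker steps concurrently and waiting once means the total time of step() tracks the slowest worker rather than the sum of all workers.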