Distributed RPC Framework
The distributed RPC framework provides mechanisms for multi-machine model training through a set of primitives to allow for remote communication, and a higher-level API to automatically differentiate models split across several machines.
Warning
APIs in the RPC package are stable. There are multiple ongoing work items to improve performance and error handling, which will ship in future releases.
Basics
The distributed RPC framework makes it easy to run functions remotely, supports referencing remote objects without copying the real data around, and provides autograd and optimizer APIs to transparently run backward and update parameters across RPC boundaries. These features can be categorized into four sets of APIs.
- Remote Procedure Call (RPC) supports running a function on the specified destination worker with the given arguments and getting the return value back or creating a reference to the return value. There are three main RPC APIs: rpc_sync() (synchronous), rpc_async() (asynchronous), and remote() (asynchronous and returns a reference to the remote return value). Use the synchronous API if the user code cannot proceed without the return value. Otherwise, use the asynchronous API to get a future, and wait on the future when the return value is needed on the caller. The remote() API is useful when the requirement is to create something remotely but never fetch it to the caller. Imagine the case that a driver process is setting up a parameter server and a trainer. The driver can create an embedding table on the parameter server and then share a reference to the embedding table with the trainer, but it never uses the embedding table locally. In this case, rpc_sync() and rpc_async() are not appropriate, as they always imply that the return value will be delivered to the caller, either immediately or in the future.
- Remote Reference (RRef) serves as a distributed shared pointer to a local or remote object. It can be shared with other workers, and reference counting is handled transparently. Each RRef has exactly one owner, and the object lives only on that owner. Non-owner workers holding RRefs can get copies of the object from the owner by explicitly requesting them. This is useful when a worker needs to access some data object but is neither the creator (the caller of remote()) nor the owner of the object. The distributed optimizer, as we will discuss below, is one example of such use cases.
- Distributed Autograd stitches together the local autograd engines on all the workers involved in the forward pass, and automatically reaches out to them during the backward pass to compute gradients. This is especially helpful when the forward pass spans multiple machines, e.g., in distributed model parallel training, parameter-server training, etc. With this feature, user code no longer needs to worry about how to send gradients across RPC boundaries and in which order the local autograd engines should be launched, which can become quite complicated when there are nested and inter-dependent RPC calls in the forward pass.
- Distributed Optimizer's constructor takes an Optimizer() (e.g., SGD(), Adagrad(), etc.) and a list of parameter RRefs, creates an Optimizer() instance on each distinct RRef owner, and updates parameters accordingly when running step(). When you have distributed forward and backward passes, parameters and gradients are scattered across multiple workers, and hence an optimizer is needed on each of the involved workers. Distributed Optimizer wraps all those local optimizers into one, and provides a concise constructor and step() API. (A minimal end-to-end sketch combining these APIs follows this list.)
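To make the division of labor concrete, here is a minimal, hedged sketch that touches all four API sets. It assumes RPC has already been initialized on two workers and that a peer named "worker1" exists (see the RPC section below); it is an illustration, not a prescribed recipe.

>>> import torch
>>> import torch.distributed.rpc as rpc
>>> import torch.distributed.autograd as dist_autograd
>>> from torch import optim
>>> from torch.distributed.optim import DistributedOptimizer
>>>
>>> # RPC and RRef: run torch.add on "worker1" and keep only a reference to the result.
>>> rref = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
>>>
>>> with dist_autograd.context() as context_id:
>>>     t = torch.ones(2, requires_grad=True)
>>>     # Distributed autograd: this rpc_sync call is recorded into the
>>>     # distributed autograd graph identified by context_id.
>>>     loss = rpc.rpc_sync("worker1", torch.add, args=(t, rref.to_here())).sum()
>>>     dist_autograd.backward(context_id, [loss])
>>>     # Distributed optimizer: runs a local SGD step on the owner of each parameter RRef.
>>>     dist_optim = DistributedOptimizer(optim.SGD, [rpc.RRef(t)], lr=0.05)
>>>     dist_optim.step(context_id)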
RPC
Before using RPC and distributed autograd primitives, initialization must take place. To initialize the RPC framework we need to use init_rpc(), which initializes the RPC framework, the RRef framework, and distributed autograd. By default, this also initializes the ProcessGroup (init_process_group()) backend for RPC communication. The ProcessGroup backend internally uses Gloo for communication.
torch.distributed.rpc.init_rpc(name, backend=BackendType.PROCESS_GROUP, rank=-1, world_size=None, rpc_backend_options=None)[source]

Initializes RPC primitives such as the local RPC agent and distributed autograd.

Initializes the local RPC agent, which immediately makes the current process ready to send and receive RPCs. This method also properly initializes a default process group backend that uses Gloo for communication.

- Parameters
  - name (str) – a globally unique name of this node (e.g., Trainer3, ParameterServer2, Master, Worker1). The name can only contain letters, digits, underscores, and/or dashes, and must be shorter than 128 characters.
  - backend (Enum) – type of RPC backend implementation. Currently, the process group backend is the only available backend implementation (default: BackendType.PROCESS_GROUP).
  - rank (int) – a globally unique id/rank of this node.
  - world_size (int) – the number of workers in the group.
  - rpc_backend_options (RpcBackendOptions) – the options passed to the RpcAgent constructor. It contains RpcAgent-specific initialization configurations. By default, for the process group agent, it contains rpc_timeout = timedelta(seconds=60), init_method = "env://", and num_send_recv_threads = 4. If using the default rpc_backend_options, RPC initializes the underlying process group backend with init_method = "env://", meaning that the environment variables MASTER_ADDR and MASTER_PORT need to be set properly. See ProcessGroupRpcBackendOptions for examples.
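As a minimal illustration of the initialization flow (the worker names and the use of torch.multiprocessing here are illustrative assumptions, not requirements), each process sets the rendezvous environment variables and calls init_rpc() before issuing any RPCs:

>>> # Hypothetical launcher: spawns two local processes that join the same RPC group.
>>> import os
>>> import torch.distributed.rpc as rpc
>>> import torch.multiprocessing as mp
>>>
>>> def run_worker(rank, world_size):
>>>     os.environ["MASTER_ADDR"] = "localhost"   # rendezvous address for init_method="env://"
>>>     os.environ["MASTER_PORT"] = "29500"       # rendezvous port
>>>     rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)
>>>     # ... issue rpc_sync / rpc_async / remote calls here ...
>>>     rpc.shutdown()
>>>
>>> if __name__ == "__main__":
>>>     mp.spawn(run_worker, args=(2,), nprocs=2)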
The following APIs allow users to remotely execute functions as well as create references (RRefs) to remote data objects. In these APIs, when passing a Tensor as an argument or a return value, the destination worker will try to create a Tensor with the same meta (i.e., shape, stride, etc.). We intentionally disallow transmitting CUDA tensors because it might crash if the device lists on the source and destination workers do not match. In such cases, applications can always explicitly move the input tensors to CPU on the caller and move them to the desired devices on the callee if necessary.
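For example, one common pattern when the computation lives on GPU is to cross the RPC boundary with CPU tensors only. The helper and worker names below are illustrative assumptions, and the sketch assumes both workers have a GPU available as "cuda:0":

>>> import torch
>>> import torch.distributed.rpc as rpc
>>>
>>> # Hypothetical helper, defined in a module importable on both workers:
>>> def add_one_on_gpu(t_cpu):
>>>     # move the received CPU tensor onto this worker's GPU, compute, return on CPU
>>>     return (t_cpu.to("cuda:0") + 1).cpu()
>>>
>>> # On the caller: explicitly move GPU inputs to CPU before crossing the RPC boundary.
>>> x_gpu = torch.ones(2, device="cuda:0")
>>> ret = rpc.rpc_sync("worker1", add_one_on_gpu, args=(x_gpu.cpu(),))
>>> ret = ret.to("cuda:0")  # move the CPU result back onto the desired local device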
Warning
TorchScript support in RPC is experimental and subject to change.
torch.distributed.rpc.rpc_sync(to, func, args=None, kwargs=None)[source]

Make a blocking RPC call to run function func on worker to. RPC messages are sent and received in parallel to execution of Python code. This method is thread-safe.

- Parameters
  - to (str or WorkerInfo) – id or name of the destination worker.
  - func (callable) – a callable function, such as Python callables, builtin operators (e.g. add()) and annotated TorchScript functions.
  - args (tuple) – the argument tuple for the func invocation.
  - kwargs (dict) – a dictionary of keyword arguments for the func invocation.
- Returns
  The result of running func with args and kwargs.

Warning
Using GPU tensors as arguments or return values of func is not supported since we don’t support sending GPU tensors over the wire. You need to explicitly copy GPU tensors to CPU before using them as arguments or return values of func.

- Example::
Make sure that MASTER_ADDR and MASTER_PORT are set properly on both workers. Refer to the init_process_group() API for more details. For example,

>>> export MASTER_ADDR=localhost
>>> export MASTER_PORT=5678

Then run the following code in two different processes:

>>> # On worker 0:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
>>> rpc.shutdown()

>>> # On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.shutdown()

Below is an example of running a TorchScript function using RPC.

>>> # On both workers:
>>> import torch
>>> @torch.jit.script
>>> def my_script_add(t1, t2):
>>>     return torch.add(t1, t2)

>>> # On worker 0:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> ret = rpc.rpc_sync("worker1", my_script_add, args=(torch.ones(2), 3))
>>> rpc.shutdown()

>>> # On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.shutdown()
torch.distributed.rpc.rpc_async(to, func, args=None, kwargs=None)[source]

Make a non-blocking RPC call to run function func on worker to. RPC messages are sent and received in parallel to execution of Python code. This method is thread-safe. It immediately returns a Future that can be awaited on.

- Parameters
  - to (str or WorkerInfo) – id or name of the destination worker.
  - func (callable) – a callable function, such as Python callables, builtin operators (e.g. add()) and annotated TorchScript functions.
  - args (tuple) – the argument tuple for the func invocation.
  - kwargs (dict) – a dictionary of keyword arguments for the func invocation.
- Returns
  A Future object that can be waited on. When completed, the return value of func on args and kwargs can be retrieved from the Future object.

Warning
Using GPU tensors as arguments or return values of func is not supported since we don’t support sending GPU tensors over the wire. You need to explicitly copy GPU tensors to CPU before using them as arguments or return values of func.

Warning
The rpc_async API does not copy the storages of argument tensors until they are sent over the wire, which could be done by a different thread depending on the RPC backend type. The caller should make sure that the contents of those tensors stay intact until the returned Future completes. (A sketch of the safe pattern follows the examples below.)

- Example::
Make sure that MASTER_ADDR and MASTER_PORT are set properly on both workers. Refer to the init_process_group() API for more details. For example,

>>> export MASTER_ADDR=localhost
>>> export MASTER_PORT=5678

Then run the following code in two different processes:

>>> # On worker 0:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> fut1 = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 3))
>>> fut2 = rpc.rpc_async("worker1", min, args=(1, 2))
>>> result = fut1.wait() + fut2.wait()
>>> rpc.shutdown()

>>> # On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.shutdown()

Below is an example of running a TorchScript function using RPC.

>>> # On both workers:
>>> import torch
>>> @torch.jit.script
>>> def my_script_add(t1, t2):
>>>     return torch.add(t1, t2)

>>> # On worker 0:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> fut = rpc.rpc_async("worker1", my_script_add, args=(torch.ones(2), 3))
>>> ret = fut.wait()
>>> rpc.shutdown()

>>> # On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.shutdown()
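The storage-copy caveat above matters when the caller wants to reuse an argument tensor in place before the Future completes. A minimal sketch of a safe pattern (the worker name is an assumption, and RPC is assumed to be initialized on both workers):

>>> import torch
>>> import torch.distributed.rpc as rpc
>>> t = torch.ones(2)
>>> # Unsafe: mutating t right after the call may race with serialization.
>>> # fut = rpc.rpc_async("worker1", torch.add, args=(t, 3)); t.add_(1)
>>> # Safe: send a clone so later in-place updates cannot affect the message.
>>> fut = rpc.rpc_async("worker1", torch.add, args=(t.clone(), 3))
>>> t.add_(1)          # free to mutate the original now
>>> ret = fut.wait()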
torch.distributed.rpc.remote(to, func, args=None, kwargs=None)[source]

Make a remote call to run func on worker to and return an RRef to the result value immediately. Worker to will be the owner of the returned RRef, and the worker calling remote is a user. The owner manages the global reference count of its RRef, and the owner RRef is only destructed when globally there are no living references to it.

- Parameters
  - to (str or WorkerInfo) – id or name of the destination worker.
  - func (callable) – a callable function, such as Python callables, builtin operators (e.g. add()) and annotated TorchScript functions.
  - args (tuple) – the argument tuple for the func invocation.
  - kwargs (dict) – a dictionary of keyword arguments for the func invocation.
- Returns
  A user RRef instance to the result value. Use the blocking API torch.distributed.rpc.RRef.to_here() to retrieve the result value locally.

Warning
Using GPU tensors as arguments or return values of func is not supported since we don’t support sending GPU tensors over the wire. You need to explicitly copy GPU tensors to CPU before using them as arguments or return values of func.

Warning
The remote API does not copy the storages of argument tensors until they are sent over the wire, which could be done by a different thread depending on the RPC backend type. The caller should make sure that the contents of those tensors stay intact until the returned RRef is confirmed by the owner, which can be checked using the torch.distributed.rpc.RRef.confirmed_by_owner() API.

- Example::
Make sure that MASTER_ADDR and MASTER_PORT are set properly on both workers. Refer to the init_process_group() API for more details. For example,

>>> export MASTER_ADDR=localhost
>>> export MASTER_PORT=5678

Then run the following code in two different processes:

>>> # On worker 0:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
>>> rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
>>> x = rref1.to_here() + rref2.to_here()
>>> rpc.shutdown()

>>> # On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.shutdown()

Below is an example of running a TorchScript function using RPC.

>>> # On both workers:
>>> import torch
>>> @torch.jit.script
>>> def my_script_add(t1, t2):
>>>     return torch.add(t1, t2)

>>> # On worker 0:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> rref = rpc.remote("worker1", my_script_add, args=(torch.ones(2), 3))
>>> rref.to_here()
>>> rpc.shutdown()

>>> # On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.shutdown()
torch.distributed.rpc.get_worker_info(worker_name=None)[source]

Get the WorkerInfo of a given worker name. Use this WorkerInfo to avoid passing an expensive string on every invocation.

- Parameters
  - worker_name (str) – the string name of a worker. If None, return the WorkerInfo of the current worker (default: None).
- Returns
  A WorkerInfo instance for the given worker_name, or the WorkerInfo of the current worker if worker_name is None.
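A brief sketch of the intended use, resolving the name once and reusing the handle (the worker name is an assumption, and RPC is assumed to be initialized):

>>> import torch
>>> import torch.distributed.rpc as rpc
>>> worker1_info = rpc.get_worker_info("worker1")   # resolve the name once
>>> # Reuse the WorkerInfo handle instead of passing the string on every call.
>>> futs = [rpc.rpc_async(worker1_info, torch.add, args=(torch.ones(2), i)) for i in range(3)]
>>> results = [f.wait() for f in futs]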
torch.distributed.rpc.shutdown(graceful=True)[source]

Perform a shutdown of the RPC agent, and then destroy the RPC agent. This stops the local agent from accepting outstanding requests, and shuts down the RPC framework by terminating all RPC threads. If graceful=True, this will block until all local and remote RPC processes reach this method and wait for all outstanding work to complete. Otherwise, if graceful=False, this is a local shutdown, and it does not wait for other RPC processes to reach this method.

- Parameters
  - graceful (bool) – whether to do a graceful shutdown or not. If True, this will 1) wait until there are no pending system messages for UserRRefs and delete them; 2) block until all local and remote RPC processes have reached this method and wait for all outstanding work to complete.
- Example::
Make sure that MASTER_ADDR and MASTER_PORT are set properly on both workers. Refer to the init_process_group() API for more details. For example,

>>> export MASTER_ADDR=localhost
>>> export MASTER_PORT=5678

Then run the following code in two different processes:

>>> # On worker 0:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> # do some work
>>> result = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(1), 1))
>>> # ready to shutdown
>>> rpc.shutdown()

>>> # On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> # wait for worker 0 to finish work, and then shutdown.
>>> rpc.shutdown()
class torch.distributed.rpc.WorkerInfo

A structure that encapsulates information of a worker in the system. Contains the name and ID of the worker. This class is not meant to be constructed directly; rather, an instance can be retrieved through get_worker_info(), and the result can be passed in to functions such as rpc_sync(), rpc_async(), and remote() to avoid copying a string on every invocation.

- property id
  Globally unique id to identify the worker.
- property name
  The name of the worker.
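For instance, a small sketch of reading these properties (assumes RPC is initialized and a peer named "worker1" exists):

>>> import torch.distributed.rpc as rpc
>>> me = rpc.get_worker_info()            # WorkerInfo of the current worker
>>> peer = rpc.get_worker_info("worker1")
>>> print(me.name, me.id)                 # read-only name and globally unique id
>>> print(peer.name, peer.id)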
class torch.distributed.rpc.ProcessGroupRpcBackendOptions

The backend options class for ProcessGroupAgent, which is derived from RpcBackendOptions.

- Parameters
  - num_send_recv_threads (int, optional) – the number of threads in the thread-pool used by ProcessGroupAgent (default: 4).
  - rpc_timeout (datetime.timedelta, optional) – the timeout for RPC requests (default: timedelta(seconds=60)).
  - init_method (str, optional) – the URL to initialize ProcessGroupGloo (default: env://).
- Example::
>>> import datetime, os
>>> from torch.distributed import rpc
>>> os.environ['MASTER_ADDR'] = 'localhost'
>>> os.environ['MASTER_PORT'] = '29500'
>>>
>>> rpc.init_rpc(
>>>     "worker1",
>>>     rank=0,
>>>     world_size=2,
>>>     rpc_backend_options=rpc.ProcessGroupRpcBackendOptions(
>>>         num_send_recv_threads=16,
>>>         rpc_timeout=datetime.timedelta(seconds=20)
>>>     )
>>> )
>>>
>>> # omitting init_rpc invocation on worker2
- property init_method
  URL specifying how to initialize the process group. Default is env://.
- property num_send_recv_threads
  The number of threads in the thread-pool used by ProcessGroupAgent.
- property rpc_timeout
  A datetime.timedelta indicating the timeout to use for all RPCs. If an RPC does not complete in this timeframe, it will complete with an exception indicating that it has timed out.
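A quick illustration of constructing the options object with the keyword parameters listed above and reading the properties back (the values here are arbitrary, and the keyword names follow the parameter list above as an assumption):

>>> import datetime
>>> from torch.distributed import rpc
>>> opts = rpc.ProcessGroupRpcBackendOptions(
>>>     num_send_recv_threads=8,
>>>     rpc_timeout=datetime.timedelta(seconds=30),
>>>     init_method="tcp://localhost:29501",
>>> )
>>> print(opts.num_send_recv_threads, opts.rpc_timeout, opts.init_method)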
RRef

An RRef (Remote REFerence) is a reference to a value of some type T (e.g. Tensor) on a remote worker. This handle keeps the referenced remote value alive on the owner, but there is no implication that the value will be transferred to the local worker in the future. RRefs can be used in multi-machine training by holding references to nn.Modules that exist on other workers, and calling the appropriate functions to retrieve or modify their parameters during training. See Remote Reference Protocol for more details.
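As a hedged sketch of that pattern (the helper functions and the worker name "ps" below are illustrative assumptions, not part of the API), a trainer can hold an RRef to an nn.Module that lives on a parameter server and fetch RRefs to its parameters for a distributed optimizer:

>>> # Hypothetical helpers, defined in a module importable on both workers:
>>> import torch
>>> import torch.nn as nn
>>> import torch.distributed.rpc as rpc
>>> from torch.distributed.rpc import RRef
>>>
>>> def create_linear():
>>>     # runs on the owner ("ps"); the returned module stays there
>>>     return nn.Linear(4, 2)
>>>
>>> def get_param_rrefs(module_rref):
>>>     # runs on the owner; wraps each local parameter in an RRef
>>>     return [RRef(p) for p in module_rref.local_value().parameters()]
>>>
>>> # On the trainer:
>>> module_rref = rpc.remote("ps", create_linear)
>>> param_rrefs = rpc.rpc_sync("ps", get_param_rrefs, args=(module_rref,))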
class torch.distributed.rpc.RRef

A class encapsulating a reference to a value of some type on a remote worker. This handle will keep the referenced remote value alive on the worker. A UserRRef will be deleted when 1) there are no references to it in either the application code or the local RRef context, or 2) the application has called a graceful shutdown. Invoking methods on a deleted RRef leads to undefined behavior. The RRef implementation only offers best-effort error detection, and applications should not use UserRRefs after rpc.shutdown().

Warning
RRefs can only be serialized and deserialized by the RPC module. Serializing and deserializing RRefs without RPC (e.g., Python pickle, torch save()/load(), JIT save()/load(), etc.) will lead to errors.

- Example::
The following examples skip the RPC initialization and shutdown code for simplicity. Refer to the RPC docs for those details.
Create an RRef using rpc.remote

>>> import torch
>>> import torch.distributed.rpc as rpc
>>> rref = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
>>> # get a copy of value from the RRef
>>> x = rref.to_here()

Create an RRef from a local object

>>> import torch
>>> from torch.distributed.rpc import RRef
>>> x = torch.zeros(2, 2)
>>> rref = RRef(x)

Share an RRef with other workers

>>> # On both worker0 and worker1:
>>> def f(rref):
>>>     return rref.to_here() + 1

>>> # On worker0:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> from torch.distributed.rpc import RRef
>>> rref = RRef(torch.zeros(2, 2))
>>> # the following RPC shares the rref with worker1, reference
>>> # count is automatically updated.
>>> rpc.rpc_sync("worker1", f, args=(rref,))
- confirmed_by_owner(self: torch.distributed.rpc.RRef) → bool
  Returns whether this RRef has been confirmed by the owner. OwnerRRef always returns true, while UserRRef only returns true when the owner knows about this UserRRef.

- is_owner(self: torch.distributed.rpc.RRef) → bool
  Returns whether or not the current node is the owner of this RRef.

- local_value(self: torch.distributed.rpc.RRef) → object
  If the current node is the owner, returns a reference to the local value. Otherwise, throws an exception.

- owner(self: torch.distributed.rpc.RRef) → torch.distributed.rpc.WorkerInfo
  Returns the worker information of the node that owns this RRef.

- owner_name(self: torch.distributed.rpc.RRef) → str
  Returns the worker name of the node that owns this RRef.

- to_here(self: torch.distributed.rpc.RRef) → object
  Blocking call that copies the value of the RRef from the owner to the local node and returns it. If the current node is the owner, returns a reference to the local value.
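The following sketch ties these methods together (assumes RPC is initialized and a peer named "worker1" exists):

>>> import torch
>>> import torch.distributed.rpc as rpc
>>> rref = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
>>> rref.is_owner()            # False on the caller; "worker1" owns the value
>>> rref.owner()               # WorkerInfo of the owner
>>> rref.owner_name()          # "worker1"
>>> rref.confirmed_by_owner()  # True once the owner has acknowledged this UserRRef
>>> x = rref.to_here()         # blocking copy of the value to the local node
>>> # rref.local_value() would raise here, since this node is not the owner.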
Distributed Autograd Framework
This module provides an RPC-based distributed autograd framework that can be used for applications such as model parallel training. In short, applications may send and receive gradient recording tensors over RPC. In the forward pass, we record when gradient recording tensors are sent over RPC and during the backward pass we use this information to perform a distributed backward pass using RPC. For more details see Distributed Autograd Design.
class torch.distributed.autograd.context[source]

Context object to wrap forward and backward passes when using distributed autograd. The context_id generated in the with statement is required to uniquely identify a distributed backward pass on all workers. Each worker stores metadata associated with this context_id, which is required to correctly execute a distributed autograd pass.

- Example::
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> import torch.distributed.autograd as dist_autograd
>>> with dist_autograd.context() as context_id:
>>>     t1 = torch.rand((3, 3), requires_grad=True)
>>>     t2 = torch.rand((3, 3), requires_grad=True)
>>>     loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum()
>>>     dist_autograd.backward(context_id, [loss])
torch.distributed.autograd.backward(context_id: int, roots: List[Tensor], retain_graph=False) → None

Kicks off the distributed backward pass using the provided roots. This currently implements the FAST mode algorithm, which assumes all RPC messages sent in the same distributed autograd context across workers are part of the autograd graph during the backward pass.

We use the provided roots to discover the autograd graph and compute the appropriate dependencies. This method blocks until the entire autograd computation is done.

We accumulate the gradients in the appropriate torch.distributed.autograd.context on each of the nodes. The autograd context to be used is looked up given the context_id that is passed in when torch.distributed.autograd.backward() is called. If there is no valid autograd context corresponding to the given ID, we throw an error. You can retrieve the accumulated gradients using the get_gradients() API.

- Parameters
  - context_id (int) – the autograd context id in which gradients should be computed and accumulated.
  - roots (list) – tensors which represent the roots of the autograd computation. All the tensors should be scalars.
  - retain_graph (bool, optional) – if False, the graph used to compute the grad will be freed. Note that in nearly all cases setting this option to True is not needed and often can be worked around in a much more efficient way. Usually, you need to set this to True to run backward multiple times.
- Example::
>>> import torch.distributed.autograd as dist_autograd
>>> with dist_autograd.context() as context_id:
>>>     # model, loss_func, and target are assumed to be defined elsewhere
>>>     pred = model.forward()
>>>     loss = loss_func(pred, target)
>>>     dist_autograd.backward(context_id, [loss])
torch.distributed.autograd.get_gradients(context_id: int) → Dict[Tensor, Tensor]

Retrieves a map from Tensor to the appropriate gradient for that Tensor accumulated in the provided context corresponding to the given context_id as part of the distributed autograd backward pass.

- Parameters
  - context_id (int) – the autograd context id for which we should retrieve the gradients.
- Returns
  A map where the key is the Tensor and the value is the associated gradient for that Tensor.
- Example::
>>> import torch
>>> import torch.distributed.autograd as dist_autograd
>>> with dist_autograd.context() as context_id:
>>>     t1 = torch.rand((3, 3), requires_grad=True)
>>>     t2 = torch.rand((3, 3), requires_grad=True)
>>>     loss = t1 + t2
>>>     dist_autograd.backward(context_id, [loss.sum()])
>>>     grads = dist_autograd.get_gradients(context_id)
>>>     print(grads[t1])
>>>     print(grads[t2])
Distributed Optimizer

torch.distributed.optim exposes DistributedOptimizer, which takes a list of remote parameters (RRef) and runs the optimizer locally on the workers where the parameters live. The distributed optimizer can use any of the local optimizer algorithms to apply the gradients on each worker.
class torch.distributed.optim.DistributedOptimizer(optimizer_class, params_rref, *args, **kwargs)[source]

DistributedOptimizer takes remote references to parameters scattered across workers and applies the given optimizer locally for each parameter.

This class uses get_gradients() in order to retrieve the gradients for specific parameters.

Concurrent calls to step(), either from the same or different clients, will be serialized on each worker, as each worker’s optimizer can only work on one set of gradients at a time. However, there is no guarantee that the full forward-backward-optimizer sequence will execute for one client at a time. This means that the gradients being applied may not correspond to the latest forward pass executed on a given worker. Also, there is no guaranteed ordering across workers.

- Parameters
  - optimizer_class (optim.Optimizer) – the class of optimizer to instantiate on each worker.
  - params_rref (list[RRef]) – list of RRefs to local or remote parameters to optimize.
  - args – arguments to pass to the optimizer constructor on each worker.
  - kwargs – arguments to pass to the optimizer constructor on each worker.
- Example::
>>> import torch
>>> import torch.distributed.autograd as dist_autograd
>>> import torch.distributed.rpc as rpc
>>> from torch import optim
>>> from torch.distributed.optim import DistributedOptimizer
>>>
>>> with dist_autograd.context() as context_id:
>>>     # Forward pass.
>>>     rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
>>>     rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
>>>     loss = rref1.to_here() + rref2.to_here()
>>>
>>>     # Backward pass.
>>>     dist_autograd.backward(context_id, [loss.sum()])
>>>
>>>     # Optimizer.
>>>     dist_optim = DistributedOptimizer(
>>>         optim.SGD,
>>>         [rref1, rref2],
>>>         lr=0.05,
>>>     )
>>>     dist_optim.step(context_id)
step(context_id)[source]

Performs a single optimization step.

This will call torch.optim.Optimizer.step() on each worker containing parameters to be optimized, and will block until all workers return. The provided context_id will be used to retrieve the corresponding context that contains the gradients that should be applied to the parameters.

- Parameters
  - context_id – the autograd context id for which we should run the optimizer step.
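Building on the remote-module sketch from the RRef section above, a hedged illustration of step() driving optimizers that run where the parameters live (remote_forward, module_rref, param_rrefs, and the worker name "ps" are illustrative assumptions, not part of the API):

>>> import torch
>>> import torch.distributed.autograd as dist_autograd
>>> import torch.distributed.rpc as rpc
>>> from torch import optim
>>> from torch.distributed.optim import DistributedOptimizer
>>>
>>> # Hypothetical helper, defined in a module importable on both workers: runs the
>>> # remote module's forward pass on the worker that owns it.
>>> def remote_forward(module_rref, x):
>>>     return module_rref.local_value()(x)
>>>
>>> # module_rref and param_rrefs are assumed to come from the remote-module sketch
>>> # in the RRef section above (the module lives on a worker named "ps").
>>> dist_optim = DistributedOptimizer(optim.SGD, param_rrefs, lr=0.05)
>>> with dist_autograd.context() as context_id:
>>>     pred = rpc.rpc_sync("ps", remote_forward, args=(module_rref, torch.randn(4)))
>>>     dist_autograd.backward(context_id, [pred.sum()])
>>>     dist_optim.step(context_id)  # runs each local optimizer where its parameters live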