torch.utils.tensorboard
Warning
This code is EXPERIMENTAL and might change in the future. It also currently does not support all model types for add_graph, which we are actively working on.
Before going further, see https://www.tensorflow.org/tensorboard/ for more details on TensorBoard.
Once you’ve installed TensorBoard, these utilities let you log PyTorch models and metrics into a directory for visualization within the TensorBoard UI. Scalars, images, histograms, graphs, and embedding visualizations are all supported for PyTorch models and tensors as well as Caffe2 nets and blobs.
The SummaryWriter class is your main entry point for logging data for consumption and visualization by TensorBoard. For example:
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
# Writer will output to ./runs/ directory by default
writer = SummaryWriter()
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = datasets.MNIST('mnist_train', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
model = torchvision.models.resnet50(False)
# Have ResNet model take in grayscale rather than RGB
model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
images, labels = next(iter(trainloader))
grid = torchvision.utils.make_grid(images)
writer.add_image('images', grid, 0)
writer.add_graph(model, images)
writer.close()
This can then be visualized with TensorBoard, which should be installable and runnable with:
pip install tensorboard
tensorboard --logdir=runs
class torch.utils.tensorboard.writer.SummaryWriter(log_dir=None, comment='', **kwargs)
Writes entries directly to event files in the log_dir to be consumed by TensorBoard.
The SummaryWriter class provides a high-level API to create an event file in a given directory and add summaries and events to it. The class updates the file contents asynchronously. This allows a training program to call methods to add data to the file directly from the training loop, without slowing down training.
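To illustrate the write-from-the-training-loop pattern described above, here is a minimal sketch. It assumes the tensorboard package is installed, uses add_scalar (a SummaryWriter method not documented in this section), and logs a made-up decaying value in place of a real loss into a temporary run directory chosen only for this example.

```python
import math
import os
import tempfile

from torch.utils.tensorboard import SummaryWriter

# Hypothetical run directory; any writable path works as log_dir.
log_dir = os.path.join(tempfile.mkdtemp(), "demo_run")
writer = SummaryWriter(log_dir=log_dir)

for step in range(100):
    # Stand-in for a real training loss: a smoothly decaying curve.
    fake_loss = math.exp(-step / 30.0)
    # add_scalar queues the event; the event file is written in the background.
    writer.add_scalar("train/loss", fake_loss, global_step=step)

writer.flush()  # push pending events to disk without closing the writer
writer.close()  # flush again and release the event file

# TensorBoard reads files whose names start with "events.out.tfevents".
event_files = [f for f in os.listdir(log_dir) if f.startswith("events.out.tfevents")]
print(event_files)
```

Running tensorboard --logdir pointed at the parent of log_dir should then show a train/loss chart.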
add_histogram(tag, values, global_step=None, bins='tensorflow', walltime=None, max_bins=None)
Add histogram to summary.
Parameters
tag (string) – Data identifier
values (torch.Tensor, numpy.array, or string/blobname) – Values to build histogram
global_step (int) – Global step value to record
bins (string) – One of {'tensorflow', 'auto', 'fd', …}; this determines how the bins are made. You can find other options in: https://docs.scipy.org/doc/numpy/reference/generated/numpy.histogram.html
walltime (float) – Optional override for the event's default walltime (time.time()), in seconds since the epoch
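A common use of add_histogram is tracking how each layer's parameters evolve over training. The sketch below shows one way to do that; parameter_histograms and log_parameter_histograms are hypothetical helpers, and writer is assumed to be a SummaryWriter like the one created in the first example.

```python
import torch

def parameter_histograms(model):
    # One (tag, values) pair per named parameter, detached from autograd
    return [(name, param.detach().cpu()) for name, param in model.named_parameters()]

def log_parameter_histograms(writer, model, step, bins="tensorflow"):
    # writer is assumed to be a SummaryWriter; one histogram per parameter per step
    for tag, values in parameter_histograms(model):
        writer.add_histogram(tag, values, global_step=step, bins=bins)

# Hypothetical small model used only for illustration
model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
pairs = parameter_histograms(model)
print([tag for tag, _ in pairs])
```

Calling log_parameter_histograms(writer, model, step=epoch) once per epoch gives one histogram series per parameter. Passing bins='fd' or bins='auto' lets numpy.histogram choose the bin edges, as the bins parameter above describes.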
add_image(tag, img_tensor, global_step=None, walltime=None, dataformats='CHW')
Add image data to summary.
Note that this requires the pillow package.
Parameters
tag (string) – Data identifier
img_tensor (torch.Tensor, numpy.array, or string/blobname) – Image data
global_step (int) – Global step value to record
walltime (float) – Optional override for the event's default walltime (time.time()), in seconds since the epoch
Shape:
img_tensor: Default is \((3, H, W)\). You can use torchvision.utils.make_grid() to convert a batch of tensors into 3xHxW format, or call add_images and let us do the job. Tensors with shape \((1, H, W)\), \((H, W)\), or \((H, W, 3)\) are also suitable, as long as the corresponding dataformats argument is passed, e.g. CHW, HWC, HW.
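Images loaded with most image libraries arrive as (H, W, 3) arrays. The sketch below contrasts the two options the shape note implies: declaring the layout with dataformats='HWC', or moving the channel axis first yourself so the default 'CHW' applies. hwc_to_chw and log_rgb_image are hypothetical helpers, the gradient image is synthetic, and writer is assumed to be a SummaryWriter.

```python
import numpy as np

def hwc_to_chw(img_hwc):
    # Move the channel axis first: (H, W, 3) -> (3, H, W), the default layout
    return np.transpose(img_hwc, (2, 0, 1))

def log_rgb_image(writer, tag, img_hwc, step):
    # Option 1: pass the HWC array as-is and declare its layout
    writer.add_image(tag, img_hwc, step, dataformats="HWC")
    # Option 2: reorder the axes yourself and rely on the default dataformats="CHW"
    writer.add_image(tag + "/chw", hwc_to_chw(img_hwc), step)

# A synthetic 32x48 RGB gradient with float values in [0, 1]
height, width = 32, 48
img = np.zeros((height, width, 3), dtype=np.float32)
img[:, :, 0] = np.linspace(0.0, 1.0, width)            # red increases left to right
img[:, :, 1] = np.linspace(0.0, 1.0, height)[:, None]  # green increases top to bottom
print(hwc_to_chw(img).shape)  # (3, 32, 48)
```

Option 1 is usually simpler; option 2 is useful when the same array is also fed to code that expects channels first.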
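add_figure(tag, figure, global_step=None, close=True, walltime=None)
Render matplotlib figure into an image and add it to summary.
Note that this requires the matplotlib package.
Parameters
tag (string) – Data identifier
figure (matplotlib.pyplot.figure) – Figure or a list of figures
global_step (int) – Global step value to record
close (bool) – Flag to automatically close the figure
walltime (float) – Optional override for the event's default walltime (time.time()), in seconds since the epoch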
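A typical use of add_figure is logging a plot you already build with matplotlib. This sketch assumes matplotlib is installed and selects the off-screen Agg backend so no display is needed; loss_figure and log_loss_figure are hypothetical helpers, the loss values are made up, and writer is assumed to be a SummaryWriter.

```python
import matplotlib

matplotlib.use("Agg")  # render off-screen; no display required
import matplotlib.pyplot as plt

def loss_figure(losses):
    # Build a simple line plot of per-epoch losses
    fig, ax = plt.subplots(figsize=(4, 3))
    ax.plot(range(len(losses)), losses, marker="o")
    ax.set_xlabel("epoch")
    ax.set_ylabel("loss")
    return fig

def log_loss_figure(writer, losses, step):
    # add_figure renders the figure to an image; close=True (the default) closes it afterwards
    writer.add_figure("loss_curve", loss_figure(losses), global_step=step)

fig = loss_figure([1.0, 0.6, 0.4, 0.3])
```

Leaving close=True avoids accumulating open figures when this is called every epoch.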
add_video(tag, vid_tensor, global_step=None, fps=4, walltime=None)
Add video data to summary.
Note that this requires the moviepy package.
Parameters
tag (string) – Data identifier
vid_tensor (torch.Tensor) – Video data
global_step (int) – Global step value to record
walltime (float) – Optional override for the event's default walltime (time.time()), in seconds since the epoch
Shape:
vid_tensor: \((N, T, C, H, W)\). The values should lie in [0, 255] for type uint8 or [0, 1] for type float.
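To make the (N, T, C, H, W) layout concrete, the sketch below builds a single synthetic clip of a white square sliding diagonally, stored as uint8 in [0, 255]. moving_square_video and log_video are hypothetical helpers, writer is assumed to be a SummaryWriter, and actually logging the clip requires moviepy as noted above.

```python
import torch

def moving_square_video(num_frames=8, size=32, square=8):
    # uint8 tensor of shape (N, T, C, H, W) = (1, num_frames, 3, size, size)
    video = torch.zeros(1, num_frames, 3, size, size, dtype=torch.uint8)
    for t in range(num_frames):
        # Slide a white square diagonally from the top-left to the bottom-right corner
        x = (t * (size - square)) // max(num_frames - 1, 1)
        video[0, t, :, x:x + square, x:x + square] = 255
    return video

def log_video(writer, video, step):
    # writer is assumed to be a SummaryWriter; fps controls playback speed
    writer.add_video("moving_square", video, global_step=step, fps=4)

video = moving_square_video()
print(video.shape, video.dtype)
```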
add_audio(tag, snd_tensor, global_step=None, sample_rate=44100, walltime=None)
Add audio data to summary.
Parameters
tag (string) – Data identifier
snd_tensor (torch.Tensor) – Sound data
global_step (int) – Global step value to record
sample_rate (int) – Sample rate in Hz
walltime (float) – Optional override for the event's default walltime (time.time()), in seconds since the epoch
Shape:
snd_tensor: \((1, L)\). The values should lie between [-1, 1].
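As an illustration of the (1, L) shape and the [-1, 1] range, the sketch below synthesizes a pure tone. sine_tone and log_tone are hypothetical helpers and writer is assumed to be a SummaryWriter; the key point is that the sample_rate passed to add_audio should match the rate the samples were generated at.

```python
import math

import torch

def sine_tone(freq_hz=440.0, seconds=1.0, sample_rate=44100):
    # Samples of a pure tone, shaped (1, L) with values in [-1, 1]
    num_samples = int(seconds * sample_rate)
    t = torch.arange(num_samples, dtype=torch.float32) / sample_rate
    return torch.sin(2 * math.pi * freq_hz * t).unsqueeze(0)

def log_tone(writer, step, sample_rate=44100):
    # Use the same sample_rate for generation and for add_audio
    tone = sine_tone(sample_rate=sample_rate)
    writer.add_audio("tone_440hz", tone, global_step=step, sample_rate=sample_rate)

tone = sine_tone(seconds=0.5)
print(tone.shape)  # torch.Size([1, 22050])
```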
add_text(tag, text_string, global_step=None, walltime=None)
Add text data to summary.
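Parameters
tag (string) – Data identifier
text_string (string) – String to save
global_step (int) – Global step value to record
walltime (float) – Optional override for the event's default walltime (time.time()), in seconds since the epoch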
Examples:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()
writer.add_text('lstm', 'This is an lstm', 0)
writer.add_text('rnn', 'This is an rnn', 10)
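TensorBoard's text dashboard renders text as Markdown, so one convenient pattern is logging a run's settings as a table at step 0. config_to_markdown and log_config are hypothetical helpers, the hyperparameter values are made up, and writer is assumed to be a SummaryWriter.

```python
def config_to_markdown(config):
    # Render a flat dict as a two-column Markdown table, keys sorted
    lines = ["| key | value |", "| --- | --- |"]
    for key, value in sorted(config.items()):
        lines.append(f"| {key} | {value} |")
    return "\n".join(lines)

def log_config(writer, config):
    # Logged once at step 0 so the settings sit alongside the run's other data
    writer.add_text("config", config_to_markdown(config), global_step=0)

hparams = {"lr": 0.001, "batch_size": 64, "optimizer": "adam"}  # hypothetical values
print(config_to_markdown(hparams))
```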
add_graph(model, input_to_model=None, verbose=False, **kwargs)
Add graph data to summary.
Parameters
model (torch.nn.Module) – Model to draw.
input_to_model (torch.Tensor or list of torch.Tensor) – A tensor or a tuple of tensors to be fed to the model.
verbose (bool) – Whether to print graph structure in console.
omit_useless_nodes (bool) – Defaults to True, which eliminates unused nodes.
operator_export_type (string) – One of: "ONNX", "RAW". This determines the optimization level of the graph. If an error happens while exporting the graph, using "RAW" may help.
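add_graph runs the model on input_to_model to record its graph, so a dummy tensor of the right shape is usually enough. The sketch below assumes a hypothetical two-layer TinyNet and a hypothetical log_graph helper; writer is assumed to be a SummaryWriter.

```python
import torch

class TinyNet(torch.nn.Module):
    # A hypothetical two-layer classifier used only to illustrate add_graph
    def __init__(self, in_features=16, hidden=32, classes=4):
        super().__init__()
        self.fc1 = torch.nn.Linear(in_features, hidden)
        self.fc2 = torch.nn.Linear(hidden, classes)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

def log_graph(writer, model, in_features=16, batch_size=2):
    # Only the input's shape and dtype matter for recording the graph here
    dummy = torch.zeros(batch_size, in_features)
    writer.add_graph(model, dummy)

model = TinyNet()
print(model(torch.zeros(2, 16)).shape)  # torch.Size([2, 4])
```

Models whose control flow depends on the input values only have the path taken for the dummy input recorded, so pick an input that exercises the branch you want to see.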
add_embedding(mat, metadata=None, label_img=None, global_step=None, tag='default', metadata_header=None)
Add embedding projector data to summary.
Parameters
mat (torch.Tensor or numpy.array) – A matrix in which each row is the feature vector of a data point
metadata (list) – A list of labels; each element will be converted to a string
label_img (torch.Tensor) – Images corresponding to each data point
global_step (int) – Global step value to record
tag (string) – Name for the embedding
Shape:
mat: \((N, D)\), where N is the number of data points and D is the feature dimension
label_img: \((N, C, H, W)\)
Examples:
import keyword
import torch
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()
meta = []
while len(meta) < 100:
    meta = meta + keyword.kwlist  # get some strings
meta = meta[:100]
for i, v in enumerate(meta):
    meta[i] = v + str(i)
label_img = torch.rand(100, 3, 10, 32)
for i in range(100):
    label_img[i] *= i / 100.0
writer.add_embedding(torch.randn(100, 5), metadata=meta, label_img=label_img)
writer.add_embedding(torch.randn(100, 5), label_img=label_img)
writer.add_embedding(torch.randn(100, 5), metadata=meta)
add_pr_curve(tag, labels, predictions, global_step=None, num_thresholds=127, weights=None, walltime=None)
Adds a precision-recall curve.
Parameters
tag (string) – Data identifier
labels (torch.Tensor, numpy.array, or string/blobname) – Ground truth data. Binary label for each element.
predictions (torch.Tensor, numpy.array, or string/blobname) – The probability that an element is classified as true. Values should be in [0, 1].
global_step (int) – Global step value to record
num_thresholds (int) – Number of thresholds used to draw the curve.
walltime (float) – Optional override for the event's default walltime (time.time()), in seconds since the epoch
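Each point on the curve pairs the precision and recall obtained at one threshold; add_pr_curve sweeps num_thresholds of them for you. The hypothetical precision_recall_at helper below computes a single such point to show what is being plotted. It is not TensorBoard's implementation, and its conventions (predicting positive when the score is >= the threshold, precision 1.0 when nothing is predicted positive) are assumptions. writer is assumed to be a SummaryWriter.

```python
import numpy as np

def precision_recall_at(labels, predictions, threshold):
    # One point of a PR curve: predict positive when score >= threshold (assumed convention)
    labels = np.asarray(labels, dtype=bool)
    predicted = np.asarray(predictions) >= threshold
    tp = np.sum(predicted & labels)
    fp = np.sum(predicted & ~labels)
    fn = np.sum(~predicted & labels)
    precision = tp / (tp + fp) if tp + fp else 1.0  # convention chosen here
    recall = tp / (tp + fn) if tp + fn else 0.0
    return float(precision), float(recall)

def log_pr_curve(writer, labels, predictions, step):
    # labels are binary ground truth; predictions are probabilities in [0, 1]
    writer.add_pr_curve("val/pr", np.asarray(labels), np.asarray(predictions), global_step=step)

labels = [1, 0, 1, 1]
predictions = [0.9, 0.8, 0.4, 0.7]
print(precision_recall_at(labels, predictions, 0.5))
```

At threshold 0.5 the scores 0.9, 0.8 and 0.7 are predicted positive, giving two true positives, one false positive and one false negative, so precision and recall are both 2/3.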
add_custom_scalars(layout)
Create a special chart by collecting chart tags in 'scalars'. Note that this function can only be called once for each SummaryWriter() object. Because it only provides metadata to TensorBoard, the function can be called before or after the training loop.
Parameters
layout (dict) – {categoryName: charts}, where charts is also a dictionary {chartName: ListOfProperties}. The first element in ListOfProperties is the chart's type (one of Multiline or Margin) and the second element should be a list containing the tags you have used with the add_scalar function, which will be collected into the new chart.
Examples:
layout = {'Taiwan': {'twse': ['Multiline', ['twse/0050', 'twse/2330']]},
          'USA': {'dow': ['Margin', ['dow/aaa', 'dow/bbb', 'dow/ccc']],
                  'nasdaq': ['Margin', ['nasdaq/aaa', 'nasdaq/bbb', 'nasdaq/ccc']]}}
writer.add_custom_scalars(layout)
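A custom chart only shows data for tags that are also logged with add_scalar. The sketch below collects every tag referenced by a layout so the training loop can log each one; tags_in_layout and log_layout_scalars are hypothetical helpers, values_by_tag is an assumed dict of current values, and writer is assumed to be a SummaryWriter.

```python
def tags_in_layout(layout):
    # Collect every scalar tag referenced by a custom-scalars layout, in order
    tags = []
    for charts in layout.values():
        for _chart_type, chart_tags in charts.values():
            tags.extend(chart_tags)
    return tags

def log_layout_scalars(writer, layout, values_by_tag, step):
    # Each tag in the layout needs matching add_scalar data to appear in its chart
    for tag in tags_in_layout(layout):
        writer.add_scalar(tag, values_by_tag[tag], global_step=step)

layout = {
    "Taiwan": {"twse": ["Multiline", ["twse/0050", "twse/2330"]]},
    "USA": {"dow": ["Margin", ["dow/aaa", "dow/bbb", "dow/ccc"]]},
}
print(tags_in_layout(layout))
```

Call writer.add_custom_scalars(layout) once, then log_layout_scalars on every step.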