Commit 4cb73ca

zou3519 and svekars authored
Improve custom ops tutorials (#3020)
Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
1 parent fc016bd commit 4cb73ca

2 files changed: +12 -5 lines changed

advanced_source/cpp_custom_ops.rst

+2
@@ -174,6 +174,8 @@ To add ``torch.compile`` support for an operator, we must add a FakeTensor kerne
 known as a "meta kernel" or "abstract impl"). FakeTensors are Tensors that have
 metadata (such as shape, dtype, device) but no data: the FakeTensor kernel for an
 operator specifies how to compute the metadata of output tensors given the metadata of input tensors.
+The FakeTensor kernel should return dummy Tensors of your choice with
+the correct Tensor metadata (shape/strides/``dtype``/device).

 We recommend that this be done from Python via the `torch.library.register_fake` API,
 though it is possible to do this from C++ as well (see

advanced_source/python_custom_ops.py

+10 -5
@@ -66,7 +66,7 @@ def display(img):
 ######################################################################
 # ``crop`` is not handled effectively out-of-the-box by
 # ``torch.compile``: ``torch.compile`` induces a
-# `"graph break" <https://pytorch.org/docs/stable/torch.compiler_faq.html#graph-breaks>`_
+# `"graph break" <https://pytorch.org/docs/stable/torch.compiler_faq.html#graph-breaks>`_
 # on functions it is unable to handle and graph breaks are bad for performance.
 # The following code demonstrates this by raising an error
 # (``torch.compile`` with ``fullgraph=True`` raises an error if a
@@ -85,9 +85,9 @@ def f(img):
 #
 # 1. wrap the function into a PyTorch custom operator.
 # 2. add a "``FakeTensor`` kernel" (aka "meta kernel") to the operator.
-# Given the metadata (e.g. shapes)
-# of the input Tensors, this function says how to compute the metadata
-# of the output Tensor(s).
+# Given some ``FakeTensors`` inputs (dummy Tensors that don't have storage),
+# this function should return dummy Tensors of your choice with the correct
+# Tensor metadata (shape/strides/``dtype``/device).


 from typing import Sequence
@@ -130,6 +130,11 @@ def f(img):
 # ``autograd.Function`` with PyTorch operator registration APIs can lead to (and
 # has led to) silent incorrectness when composed with ``torch.compile``.
 #
+# If you don't need training support, there is no need to use
+# ``torch.library.register_autograd``.
+# If you end up training with a ``custom_op`` that doesn't have an autograd
+# registration, we'll raise an error message.
+#
 # The gradient formula for ``crop`` is essentially ``PIL.paste`` (we'll leave the
 # derivation as an exercise to the reader). Let's first wrap ``paste`` into a
 # custom operator:
@@ -203,7 +208,7 @@ def setup_context(ctx, inputs, output):
 ######################################################################
 # Mutable Python Custom operators
 # -------------------------------
-# You can also wrap a Python function that mutates its inputs into a custom
+# You can also wrap a Python function that mutates its inputs into a custom
 # operator.
 # Functions that mutate inputs are common because that is how many low-level
 # kernels are written; for example, a kernel that computes ``sin`` may take in
