Custom operators in PyTorch enable extending the framework with specialized operations — implementing functionality not available in standard libraries, optimizing performance-critical code with CUDA kernels, or integrating external libraries for domain-specific needs.
What Are Custom Operators?
- Definition: User-defined operations extending PyTorch.
- Use Cases: Missing ops, CUDA optimization, library integration.
- Levels: Python functions, C++ extensions, CUDA kernels.
- Integration: Works with autograd, torch.compile, export.
Why Custom Operators
- Performance: Fused operations, CUDA optimization.
- Functionality: Operations not in standard PyTorch.
- Integration: Connect external C++/CUDA libraries.
- Research: Implement novel operations.
Custom Op Levels
Complexity Spectrum:
```
Level | Performance | Complexity | Use Case
----------------|-------------|------------|------------------
Python function | Low | Easy | Prototyping
torch.autograd | Medium | Easy | Custom backward
C++ extension | High | Medium | CPU optimization
CUDA extension | Highest | Hard | GPU optimization
Triton kernel | High | Medium | GPU, Python-like
```
Python Custom Function
With Custom Backward:
```python
import torch
from torch.autograd import Function
class MyReLU(Function):
    """ReLU with a hand-written backward pass.

    Forward clamps negatives to zero. Backward passes the incoming
    gradient through unchanged where the saved input was >= 0 and
    zeroes it where the input was negative.
    """

    @staticmethod
    def forward(ctx, input):
        # Save the raw input so backward can rebuild the clamp mask.
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        # Zero the gradient wherever the forward pass clamped.
        grad_input[input < 0] = 0
        return grad_input
# Usage
# autograd Functions are invoked via .apply, never by instantiating the class.
my_relu = MyReLU.apply
# input_tensor is assumed to be provided by the surrounding code.
output = my_relu(input_tensor)
```
C++ Extension
Setup (setup.py):
```python
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension
# Build script: compiles my_ops.cpp into a loadable "my_ops" extension module.
setup(
    name="my_ops",
    ext_modules=[
        # CppExtension adds the PyTorch include paths and C++ flags automatically.
        CppExtension(
            "my_ops",
            ["my_ops.cpp"],
        ),
    ],
    # BuildExtension selects compiler/ABI flags matching the installed torch.
    cmdclass={"build_ext": BuildExtension},
)
```
C++ Implementation (my_ops.cpp):
```cpp
#include <torch/extension.h>
// Element-wise addition with a shape guard; returns a new tensor.
torch::Tensor my_add(torch::Tensor a, torch::Tensor b) {
    TORCH_CHECK(a.sizes() == b.sizes(), "Size mismatch");
    return a + b; // Simple example
}

// Fused tanh-approximation GELU:
//   x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// where 0.79788456 ~= sqrt(2/pi).
torch::Tensor fused_gelu(torch::Tensor x) {
    auto x3 = x * x * x;
    auto inner = 0.79788456 * (x + 0.044715 * x3);
    return x * 0.5 * (1.0 + torch::tanh(inner));
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("my_add", &my_add, "Element-wise addition");
    m.def("fused_gelu", &fused_gelu, "Fused GELU activation");
}
```
Usage:
```python
import torch
import my_ops
x = torch.randn(1000, 1000)
# my_ops is the compiled C++ extension built via setup.py.
y = my_ops.fused_gelu(x)
```
CUDA Extension
CUDA Kernel (my_ops_cuda.cu):
```cuda
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
// One thread per element: tanh-approximation GELU.
// 0.79788456f ~= sqrt(2/pi).
template <typename scalar_t>
__global__ void fused_gelu_kernel(
    const scalar_t* __restrict__ input,
    scalar_t* __restrict__ output,
    size_t size
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < size) {  // guard the ragged final block
        scalar_t x = input[idx];
        scalar_t x3 = x * x * x;
        scalar_t inner = 0.79788456f * (x + 0.044715f * x3);
        output[idx] = x * 0.5f * (1.0f + tanhf(inner));
    }
}

torch::Tensor fused_gelu_cuda(torch::Tensor input) {
    TORCH_CHECK(input.is_cuda(), "fused_gelu expects a CUDA tensor");
    auto output = torch::empty_like(input);
    const int threads = 256;
    // ceil-division: enough blocks to cover every element
    const int blocks = (input.numel() + threads - 1) / threads;
    // scalar_type() replaces the deprecated Tensor::type() in the dispatch macro.
    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "fused_gelu", ([&] {
        fused_gelu_kernel<scalar_t><<<blocks, threads>>>(
            input.data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>(),
            input.numel()
        );
    }));
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fused_gelu", &fused_gelu_cuda, "Fused GELU (CUDA)");
}
```
Triton Alternative
Easier GPU Kernels:
```python
import triton
import triton.language as tl
import torch
@triton.jit
def gelu_kernel(x_ptr, output_ptr, n_elements, BLOCK: tl.constexpr):
    """Elementwise tanh-approximation GELU over a flat buffer."""
    pid = tl.program_id(0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements  # guard the ragged final block
    x = tl.load(x_ptr + offsets, mask=mask)
    # GELU: x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    # 0.79788456 ~= sqrt(2/pi)
    x3 = x * x * x
    inner = 0.79788456 * (x + 0.044715 * x3)
    output = x * 0.5 * (1.0 + tl.libdevice.tanh(inner))
    tl.store(output_ptr + offsets, output, mask=mask)

def fused_gelu_triton(x):
    """Launch gelu_kernel over all elements of x; returns a new tensor."""
    output = torch.empty_like(x)
    n = x.numel()
    # Ceil-division grid: avoids launching a superfluous (fully masked)
    # block when n is an exact multiple of BLOCK.
    gelu_kernel[(triton.cdiv(n, 1024),)](x, output, n, BLOCK=1024)
    return output
```
Registering for torch.compile
```python
import torch
from torch.library import Library
# Define a custom op namespace and declare the op's schema.
my_lib = Library("myops", "DEF")
my_lib.define("fused_gelu(Tensor x) -> Tensor")

# Register the CUDA implementation under that schema.
# fused_gelu_cuda is the compiled extension entry point defined earlier.
@torch.library.impl(my_lib, "fused_gelu", "CUDA")
def fused_gelu_impl(x):
    return fused_gelu_cuda(x)

# The registered op is now visible as torch.ops.myops.fused_gelu and
# traceable by torch.compile without graph breaks.
@torch.compile
def model(x):
    return torch.ops.myops.fused_gelu(x)
```
Custom operators are essential for pushing PyTorch performance boundaries — when standard operations aren't sufficient, custom ops enable the optimizations and integrations that production ML systems require.