Custom operators

Keywords: custom operator, extension, pytorch, cuda, c++, kernel, triton

Custom operators in PyTorch let you extend the framework with specialized operations: implementing functionality not available in standard libraries, optimizing performance-critical code with CUDA kernels, or integrating external libraries for domain-specific needs.

What Are Custom Operators?

- Definition: User-defined operations extending PyTorch.
- Use Cases: Missing ops, CUDA optimization, library integration.
- Levels: Python functions, C++ extensions, CUDA kernels.
- Integration: Works with autograd, torch.compile, export.

Why Custom Operators

- Performance: Fused operations, CUDA optimization.
- Functionality: Operations not in standard PyTorch.
- Integration: Connect external C++/CUDA libraries.
- Research: Implement novel operations.

Custom Op Levels

Complexity Spectrum:
```
Level           | Performance | Complexity | Use Case
----------------|-------------|------------|------------------
Python function | Low         | Easy       | Prototyping
torch.autograd  | Medium      | Easy       | Custom backward
C++ extension   | High        | Medium     | CPU optimization
CUDA extension  | Highest     | Hard       | GPU optimization
Triton kernel   | High        | Medium     | GPU, Python-like
```

Python Custom Function

With Custom Backward:
```python
import torch
from torch.autograd import Function

class MyReLU(Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input

# Usage
my_relu = MyReLU.apply
output = my_relu(input_tensor)
```
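
A quick way to verify that the custom backward matches the analytical gradient is `torch.autograd.gradcheck`, which compares it against numerical differentiation. A minimal sketch (note that `gradcheck` expects double-precision inputs):

```python
import torch

# gradcheck compares the Function's backward against numerical gradients;
# it expects float64 inputs with requires_grad=True
x = torch.randn(8, 8, dtype=torch.double, requires_grad=True)
assert torch.autograd.gradcheck(MyReLU.apply, (x,))
```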

C++ Extension

Setup (setup.py):
```python
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension

setup(
    name="my_ops",
    ext_modules=[
        CppExtension(
            "my_ops",
            ["my_ops.cpp"],
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
)
```

C++ Implementation (my_ops.cpp):
```cpp
#include <torch/extension.h>

torch::Tensor my_add(torch::Tensor a, torch::Tensor b) {
    TORCH_CHECK(a.sizes() == b.sizes(), "Size mismatch");
    return a + b;  // Simple example
}

torch::Tensor fused_gelu(torch::Tensor x) {
    // Fused GELU: x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    auto x3 = x * x * x;
    auto inner = 0.79788456 * (x + 0.044715 * x3);
    return x * 0.5 * (1.0 + torch::tanh(inner));
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("my_add", &my_add, "Element-wise addition");
    m.def("fused_gelu", &fused_gelu, "Fused GELU activation");
}
```

Usage:
```python
import torch
import my_ops  # available after building, e.g. `pip install -e .`

x = torch.randn(1000, 1000)
y = my_ops.fused_gelu(x)
```
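
For faster iteration, `torch.utils.cpp_extension.load` compiles the extension just-in-time at import, skipping the separate setup.py build. A sketch, assuming `my_ops.cpp` sits next to the script:

```python
import torch
from torch.utils.cpp_extension import load

# JIT-compiles the C++ source and imports it as a Python module;
# the first call compiles, subsequent calls reuse the cached build
my_ops = load(name="my_ops", sources=["my_ops.cpp"], verbose=True)
y = my_ops.fused_gelu(torch.randn(1000, 1000))
```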

CUDA Extension

CUDA Kernel (my_ops_cuda.cu):
```cuda
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>

template <typename scalar_t>
__global__ void fused_gelu_kernel(
    const scalar_t* __restrict__ input,
    scalar_t* __restrict__ output,
    size_t size
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < size) {
        scalar_t x = input[idx];
        scalar_t x3 = x * x * x;
        scalar_t inner = 0.79788456f * (x + 0.044715f * x3);
        output[idx] = x * 0.5f * (1.0f + tanhf(inner));
    }
}

torch::Tensor fused_gelu_cuda(torch::Tensor input) {
    auto output = torch::empty_like(input);

    const int threads = 256;
    const int blocks = (input.numel() + threads - 1) / threads;

    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "fused_gelu", ([&] {
        fused_gelu_kernel<scalar_t><<<blocks, threads>>>(
            input.data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>(),
            input.numel()
        );
    }));

    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fused_gelu", &fused_gelu_cuda, "Fused GELU (CUDA)");
}
```
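
Building a .cu file uses `CUDAExtension` instead of `CppExtension`, so that nvcc is invoked for the device code. A minimal setup.py sketch:

```python
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="my_ops_cuda",
    ext_modules=[
        # CUDAExtension routes .cu sources through nvcc
        # and links against the CUDA runtime
        CUDAExtension(
            "my_ops_cuda",
            ["my_ops_cuda.cu"],
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
)
```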

Triton Alternative

Easier GPU Kernels:
```python
import triton
import triton.language as tl
import torch

@triton.jit
def gelu_kernel(x_ptr, output_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)

    # GELU computation
    x3 = x * x * x
    inner = 0.79788456 * (x + 0.044715 * x3)
    output = x * 0.5 * (1.0 + tl.math.tanh(inner))  # tl.libdevice.tanh in older Triton

    tl.store(output_ptr + offsets, output, mask=mask)

def fused_gelu_triton(x):
    output = torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, 1024),)
    gelu_kernel[grid](x, output, n, BLOCK=1024)
    return output
```
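
Sanity-checking the kernel against PyTorch's own tanh-approximation GELU catches indexing and masking bugs early. A quick check, assuming a CUDA device is available:

```python
import torch
import torch.nn.functional as F

x = torch.randn(4096, device="cuda")
out = fused_gelu_triton(x)
ref = F.gelu(x, approximate="tanh")  # same tanh approximation as the kernel
torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)
```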

Registering for torch.compile

```python
import torch
from torch.library import Library

# Define custom library
my_lib = Library("myops", "DEF")

# Register schema
my_lib.define("fused_gelu(Tensor x) -> Tensor")

# Register implementation
@torch.library.impl(my_lib, "fused_gelu", "CUDA")
def fused_gelu_impl(x):
    return fused_gelu_cuda(x)

# Now works with torch.compile
@torch.compile
def model(x):
    return torch.ops.myops.fused_gelu(x)
```
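
One caveat: torch.compile traces with fake tensors, so the op also needs a shape-propagation rule. Registering a Meta implementation that only allocates the output covers this; a sketch against the `my_lib` defined above (newer PyTorch also offers `torch.library.register_fake` for the same purpose):

```python
# Meta kernel: computes output metadata (shape/dtype) without real data,
# which torch.compile uses when tracing the graph
@torch.library.impl(my_lib, "fused_gelu", "Meta")
def fused_gelu_meta(x):
    return torch.empty_like(x)
```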

Custom operators are essential for pushing PyTorch performance boundaries: when standard operations aren't sufficient, custom ops enable the optimizations and integrations that production ML systems require.
