Custom Operations System¶
The minitensor library provides a powerful extensibility framework that lets users define their own tensor operations with full automatic differentiation support. This system enables developers to extend the library’s functionality without modifying the core codebase.
Overview¶
The custom operations system consists of several key components:
CustomOp Trait: Defines the interface for custom operations
CustomOpRegistry: Manages registration and execution of custom operations
CustomOpBuilder: Provides a convenient builder pattern for creating operations
Automatic Differentiation Integration: Seamless integration with the gradient computation system
Python Bindings: Full Python API for registering and executing custom operations
Core Components¶
CustomOp Trait¶
The CustomOp trait defines the interface that all custom operations must implement:
pub trait CustomOp: Send + Sync {
fn name(&self) -> &str;
fn validate_inputs(&self, inputs: &[&Tensor]) -> Result<()>;
fn forward(&self, inputs: &[&Tensor]) -> Result<Tensor>;
fn create_gradient_function(&self, inputs: &[&Tensor], output: &Tensor) -> Option<Arc<dyn GradientFunction>>;
fn num_inputs(&self) -> usize;
fn output_shape(&self, input_shapes: &[&Shape]) -> Result<Shape>;
fn output_dtype(&self, input_dtypes: &[DataType]) -> Result<DataType>;
fn output_device(&self, input_devices: &[&Device]) -> Result<Device>;
}
CustomOpBuilder¶
The builder pattern provides a convenient way to create custom operations:
let op = CustomOpBuilder::new("my_operation", 2)
.forward(|inputs| {
// Forward pass implementation
Ok(result_tensor)
})
.backward(|grad_output, input_ids, input_shapes, input_dtypes, input_devices| {
// Backward pass implementation
Ok(gradients_map)
})
.validate(|inputs| {
// Input validation
Ok(())
})
.build()?;
Example Custom Operations¶
The library includes several example custom operations to demonstrate the system.
They can be registered from Python via mt.register_example_custom_ops():
1. Swish Activation Function¶
pub fn create_swish_op() -> Result<Arc<dyn CustomOp>> {
CustomOpBuilder::new("swish", 1)
.forward(|inputs| {
let x = inputs[0];
let sigmoid_x = activation::sigmoid(x)?;
arithmetic::mul(x, &sigmoid_x)
})
.backward(|grad_output, input_ids, input_shapes, input_dtypes, input_devices| {
// Swish gradient implementation
let mut gradients = HashMap::new();
// ... gradient computation logic
Ok(gradients)
})
.build()
}
2. GELU Activation Function¶
pub fn create_gelu_op() -> Result<Arc<dyn CustomOp>> {
CustomOpBuilder::new("gelu", 1)
.forward(|inputs| {
let x = inputs[0];
// GELU implementation using existing operations
let tanh_x = activation::tanh(x)?;
let one = Tensor::ones(x.shape().clone(), x.dtype(), x.device(), false);
let one_plus_tanh = arithmetic::add(&one, &tanh_x)?;
arithmetic::mul(x, &one_plus_tanh)
})
.build()
}
3. Element-wise Power Operation¶
pub fn create_power_op() -> Result<Arc<dyn CustomOp>> {
CustomOpBuilder::new("power", 2)
.forward(|inputs| {
let base = inputs[0];
let exponent = inputs[1];
// Power operation implementation
arithmetic::mul(base, exponent) // Simplified
})
.validate(|inputs| {
if inputs[0].shape() != inputs[1].shape() {
return Err(MinitensorError::shape_mismatch(
inputs[0].shape().dims().to_vec(),
inputs[1].shape().dims().to_vec()
));
}
Ok(())
})
.build()
}
Python API¶
The custom operations system is fully accessible from Python:
Registration¶
import minitensor as mt
# Register example custom operations
mt.register_example_custom_ops()
# List registered operations
ops = mt.list_custom_ops_py()
print("Available operations:", ops)
# Check if an operation is registered
is_registered = mt.is_custom_op_registered_py("swish")
print("Swish registered:", is_registered)
Execution¶
import minitensor as mt
# Create input tensor
x = mt.Tensor([[1.0, 2.0, -1.0]], requires_grad=True)
# Execute custom operation (result is a core tensor)
res_core = mt.execute_custom_op_py("swish", [x._tensor])
result = mt.Tensor.__new__(mt.Tensor)
result._tensor = res_core
print("Swish result:", result)
Advanced Usage¶
# Power operation with two inputs
base = mt.Tensor([[2.0, 3.0], [4.0, 5.0]], requires_grad=True)
exponent = mt.Tensor([[2.0, 2.0], [3.0, 2.0]], requires_grad=True)
power_core = mt.execute_custom_op_py("power", [base._tensor, exponent._tensor])
power = mt.Tensor.__new__(mt.Tensor)
power._tensor = power_core
print("Power result:", power)
# Layer normalization with three inputs
input_tensor = mt.Tensor([[1.0, 2.0, 3.0, 4.0]], requires_grad=True)
weight = mt.Tensor([1.0, 1.0, 1.0, 1.0], requires_grad=True)
bias = mt.Tensor([0.0, 0.0, 0.0, 0.0], requires_grad=True)
norm_core = mt.execute_custom_op_py(
"layer_norm", [input_tensor._tensor, weight._tensor, bias._tensor]
)
normalized = mt.Tensor.__new__(mt.Tensor)
normalized._tensor = norm_core
print("Layer norm result:", normalized)
Features¶
1. Automatic Differentiation Integration¶
Custom operations seamlessly integrate with the automatic differentiation system:
Operations can define custom gradient functions
Gradients flow through custom operations just like built-in operations
Support for complex gradient computations with multiple inputs/outputs
2. Type Safety and Validation¶
The system provides comprehensive validation:
Input tensor count validation
Shape compatibility checking
Data type validation
Device compatibility verification
Custom validation logic support
3. Performance Optimization¶
Custom operations are designed for performance:
Zero-copy tensor passing where possible
Efficient memory management
Integration with SIMD and GPU backends
Lazy evaluation support
4. Error Handling¶
Robust error handling throughout the system:
Clear error messages with actionable suggestions
Proper error propagation from Rust to Python
Validation errors with detailed context
Runtime error detection and reporting
Best Practices¶
1. Operation Design¶
Keep operations focused and composable
Use descriptive names that avoid conflicts
Implement proper input validation
Consider numerical stability
2. Gradient Implementation¶
Ensure gradient correctness through testing
Handle edge cases (zeros, infinities, NaNs)
Consider gradient clipping for stability
Test gradient computation with finite differences
3. Performance Considerations¶
Minimize memory allocations in hot paths
Reuse tensors where possible
Consider vectorization opportunities
Profile custom operations for bottlenecks
4. Testing¶
Test with various input shapes and types
Verify gradient correctness
Test error conditions and edge cases
Include performance benchmarks
Future Extensions¶
The custom operations system is designed to be extensible:
Plugin System: Load operations from external libraries
JIT Compilation: Compile custom operations for better performance
Graph Optimization: Fuse custom operations with built-in operations
Distributed Operations: Support for multi-device custom operations
Serialization: Save and load custom operations with models