Shortcuts

torch.dynamic-value

constrain_as_size_example

Note

Tags: torch.escape-hatch, torch.dynamic-value

Support Level: SUPPORTED

Original source code:

import torch
from torch._export.constraints import constrain_as_size



def constrain_as_size_example(x):
    """
    Build a tensor whose leading dimension is a value read out of ``x``.

    If a value is not known at tracing time, you can provide a hint so that
    tracing can continue. See the constrain_as_value and constrain_as_size
    APIs. constrain_as_size is used for values that NEED to be used for
    constructing a tensor.

    Args:
        x: Object whose ``.item()`` result is used as the first dimension of
            the output (presumably a scalar integer tensor -- the exported
            graph for this example takes an ``i64[]`` input).

    Returns:
        The result of ``torch.ones((a, 5))``, where ``a = x.item()``.
    """
    # `a` is read from the data in `x`, so its concrete value is not
    # available while tracing.
    a = x.item()
    # Hint that `a` lies in the range [0, 5] before it is used as a size.
    constrain_as_size(a, min=0, max=5)
    return torch.ones((a, 5))

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: i64[]):
                _local_scalar_dense: Sym(i4) = torch.ops.aten._local_scalar_dense.default(arg0_1);  arg0_1 = None

                ge: Sym(i4 >= 0) = _local_scalar_dense >= 0
            scalar_tensor: f32[] = torch.ops.aten.scalar_tensor.default(ge);  ge = None
            _assert_async = torch.ops.aten._assert_async.msg(scalar_tensor, '_local_scalar_dense is outside of inline constraint [0, 5].');  scalar_tensor = None
            le: Sym(i4 <= 5) = _local_scalar_dense <= 5
            scalar_tensor_1: f32[] = torch.ops.aten.scalar_tensor.default(le);  le = None
            _assert_async_1 = torch.ops.aten._assert_async.msg(scalar_tensor_1, '_local_scalar_dense is outside of inline constraint [0, 5].');  scalar_tensor_1 = None

                sym_constrain_range_for_size = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense, min = 0, max = 5)

                full: f32[i4, 5] = torch.ops.aten.full.default([_local_scalar_dense, 5], 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False);  _local_scalar_dense = None

                sym_size: Sym(i4) = torch.ops.aten.sym_size.int(full, 0)
            ge_1: Sym(i4 >= 0) = sym_size >= 0
            scalar_tensor_2: f32[] = torch.ops.aten.scalar_tensor.default(ge_1);  ge_1 = None
            _assert_async_2 = torch.ops.aten._assert_async.msg(scalar_tensor_2, 'full.shape[0] is outside of inline constraint [0, 5].');  scalar_tensor_2 = None
            le_1: Sym(i4 <= 5) = sym_size <= 5;  sym_size = None
            scalar_tensor_3: f32[] = torch.ops.aten.scalar_tensor.default(le_1);  le_1 = None
            _assert_async_3 = torch.ops.aten._assert_async.msg(scalar_tensor_3, 'full.shape[0] is outside of inline constraint [0, 5].');  scalar_tensor_3 = None
            return (full,)

Graph signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['full'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Range constraints: {i0: RangeConstraint(min_val=2, max_val=5), i1: RangeConstraint(min_val=2, max_val=5), i2: RangeConstraint(min_val=2, max_val=5), i3: RangeConstraint(min_val=2, max_val=5), i4: RangeConstraint(min_val=2, max_val=5)}
Equality constraints: []

constrain_as_value_example

Note

Tags: torch.escape-hatch, torch.dynamic-value

Support Level: SUPPORTED

Original source code:

import torch
from torch._export.constraints import constrain_as_value



def constrain_as_value_example(x, y):
    """
    Branch on a value read out of ``x`` and apply ``sin`` or ``cos`` to ``y``.

    If a value is not known at tracing time, you can provide a hint so that
    tracing can continue. See the constrain_as_value and constrain_as_size
    APIs. constrain_as_value is used for values that don't need to be used
    for constructing a tensor.

    Args:
        x: Object whose ``.item()`` result is the value being constrained
            (presumably a scalar integer tensor -- the exported graph for
            this example takes an ``i64[]`` input).
        y: Object providing ``sin()`` and ``cos()`` methods (an ``f32[5, 5]``
            tensor in the exported graph for this example).

    Returns:
        ``y.sin()`` if the value read from ``x`` is less than 6, otherwise
        ``y.cos()``.
    """
    # `a` is read from the data in `x`, so its concrete value is not
    # available while tracing.
    a = x.item()
    # Hint that `a` lies in the range [0, 5].
    constrain_as_value(a, min=0, max=5)

    # Under the [0, 5] hint this comparison always holds; the exported graph
    # shown for this example contains only the sin call.
    if a < 6:
        return y.sin()
    return y.cos()

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: i64[], arg1_1: f32[5, 5]):
                _local_scalar_dense: Sym(i4) = torch.ops.aten._local_scalar_dense.default(arg0_1);  arg0_1 = None

                ge: Sym(i4 >= 0) = _local_scalar_dense >= 0
            scalar_tensor: f32[] = torch.ops.aten.scalar_tensor.default(ge);  ge = None
            _assert_async = torch.ops.aten._assert_async.msg(scalar_tensor, '_local_scalar_dense is outside of inline constraint [0, 5].');  scalar_tensor = None
            le: Sym(i4 <= 5) = _local_scalar_dense <= 5
            scalar_tensor_1: f32[] = torch.ops.aten.scalar_tensor.default(le);  le = None
            _assert_async_1 = torch.ops.aten._assert_async.msg(scalar_tensor_1, '_local_scalar_dense is outside of inline constraint [0, 5].');  scalar_tensor_1 = None

                sym_constrain_range = torch.ops.aten.sym_constrain_range.default(_local_scalar_dense, min = 0, max = 5);  _local_scalar_dense = None

                sin: f32[5, 5] = torch.ops.aten.sin.default(arg1_1);  arg1_1 = None
            return (sin,)

Graph signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1'], user_outputs=['sin'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Range constraints: {i0: RangeConstraint(min_val=0, max_val=5), i1: RangeConstraint(min_val=0, max_val=5), i2: RangeConstraint(min_val=0, max_val=5), i3: RangeConstraint(min_val=0, max_val=5), i4: RangeConstraint(min_val=0, max_val=5)}
Equality constraints: []

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources