ONNX 类型

可选类型

可选类型表示对元素(可以是张量、序列、映射或稀疏张量)或空值的引用。可选类型出现在模型输入、输出以及中间值中。

用例

可选类型使用户能够在 ONNX 中表示更多动态类型场景。类似于 Python 类型提示中的 Optional[X] 类型提示,它等效于 Union[None, X],ONNX 中的可选类型可以引用单个元素,也可以引用空值。

PyTorch 中的示例

可选类型仅出现在由 jit 脚本编译器生成的 TorchScript 图中。对模型进行脚本编写会捕获动态类型,其中可选值可以分配 None 或值。

  • 示例 1

      class Model(torch.nn.Module):
          def forward(self, x, y:Optional[Tensor]=None):
              if y is not None:
                  return x + y
              return x
    

    相应的 TorchScript 图

      Graph(
          %self : __torch__.Model,
          %x.1 : Tensor,
          %y.1 : Tensor?
      ):
          %11 : int = prim::Constant[value=1]()
          %4 : None = prim::Constant()
          %5 : bool = aten::__isnot__(%y.1, %4)
          %6 : Tensor = prim::If(%5)
              block0():
                  %y.4 : Tensor = prim::unchecked_cast(%y.1)
                  %12 : Tensor = aten::add(%x.1, %y.4, %11)
              -> (%12)
              block1():
              -> (%x.1)
          return (%6)
    

    ONNX 图

      Graph(
          %x.1 : Float(2, 3),
          %y.1 : Float(2, 3)
      ):
          %2 : Bool(1) = onnx::OptionalHasElement(%y.1)
          %5 : Float(2, 3) = onnx::If(%2)
              block0():
                  %3 : Float(2, 3) = onnx::OptionalGetElement(%y.1)
                  %4 : Float(2, 3) = onnx::Add(%x.1, %3)
              -> (%4)
              block1():
                  %x.2 : Float(2, 3) = onnx::Identity(%x.1)
              -> (%x.2)
          return (%5)
    
  • 示例 2

      class Model(torch.nn.Module):
          def forward(
                  self,
                  src_tokens,
                  return_all_hiddens=torch.tensor([False]),
          ):
              encoder_states: Optional[Tensor] = None
              if return_all_hiddens:
                  encoder_states = src_tokens
    
              return src_tokens, encoder_states
    

    相应的 TorchScript 图

      Graph(
          %src_tokens.1 : Float(3, 2, 4,),
          %return_all_hiddens.1 : Bool(1)
      ):
          %3 : None = prim::Constant()
          %encoder_states : Tensor? = prim::If(%return_all_hiddens.1)
              block0():
              -> (%src_tokens.1)
              block1():
              -> (%3)
          return (%src_tokens.1, %encoder_states)
    

    ONNX 图

      Graph(
          %src_tokens.1 : Float(3, 2, 4),
          %return_all_hiddens.1 : Bool(1)
      ):
          %2 : Float(3, 2, 4) = onnx::Optional[type=tensor(float)]()
          %3 : Float(3, 2, 4) = onnx::If(%return_all_hiddens.1)
              block0():
              -> (%src_tokens.1)
              block1():
              -> (%2)
          return (%3)