序列化

保存模型和任何 Proto 类

此 ONNX 图需要序列化为一个连续的内存缓冲区。方法 SerializeToString 在每个 ONNX 对象中都可用。

with open("model.onnx", "wb") as f:
    f.write(onnx_model.SerializeToString())

此方法具有以下签名。

class onnx.ModelProto
SerializeToString()

将消息序列化为字符串,仅适用于已初始化的消息。

每个 Proto 类都实现了方法 SerializeToString。因此,以下代码适用于页面 Protos 中描述的任何类。

with open("proto.pb", "wb") as f:
    f.write(proto.SerializeToString())

以下示例演示如何保存一个 NodeProto.

from onnx import NodeProto

node = NodeProto()
node.name = "example-type-proto"
node.op_type = "Add"
node.input.extend(["X", "Y"])
node.output.extend(["Z"])

with open("node.pb", "wb") as f:
    f.write(node.SerializeToString())

加载模型

以下函数仅自动加载类 ModelProto。下一节将展示如何恢复任何其他 proto 类。

onnx.load(f: IO[bytes] | str | PathLike, format: Literal['protobuf', 'textproto', 'onnxtxt', 'json'] | str | None = None, load_external_data: bool = True) ModelProto

将序列化的 ModelProto 加载到内存中。

参数:
  • f – 可以是文件类对象(具有“read”函数)或包含文件名字符串/PathLike

  • format – 序列化格式。当未指定时,它从文件扩展名推断(当 f 是路径时)。如果未指定 _且_ f 不是路径,则使用“protobuf”。当格式为文本格式时,编码假定为“utf-8”。

  • load_external_data – 是否加载外部数据。如果数据在模型的同一目录下,则设置为 True。如果不是,用户需要使用目录调用 load_external_data_for_model() 来加载外部数据。

返回值:

加载到内存中的 ModelProto。

from onnx import load

onnx_model = load("model.onnx")

from onnx import load

with open("model.onnx", "rb") as f:
    onnx_model = load(f)

下一个函数从字节数组中执行相同的操作。

onnx.load_model_from_string(s: bytes | str, format: Literal['protobuf', 'textproto', 'onnxtxt', 'json'] | str = 'protobuf') ModelProto[source]

加载包含序列化的 ModelProto 的二进制字符串(字节)。

参数:
  • s – 一个字符串,包含序列化的 ModelProto

  • format – 序列化格式。当未指定时,它从文件扩展名推断(当 f 是路径时)。如果未指定 _且_ f 不是路径,则使用“protobuf”。当格式为文本格式时,编码假定为“utf-8”。

返回值:

加载到内存中的 ModelProto。

加载 Proto

这里“Proto”指的是任何包含数据的类型,包括模型、张量、稀疏张量以及Protos页面列出的任何类。用户必须知道需要还原的数据类型,然后调用方法ParseFromStringprotobuf 不存储有关保存数据的类类型的任何信息。因此,必须在还原对象之前知道该类。

class onnx.ModelProto
ParseFromString()

将序列化消息解析到当前消息。

以下示例显示了如何还原 NodeProto

from onnx import NodeProto

tp2 = NodeProto()
with open("node.pb", "rb") as f:
    content = f.read()

tp2.ParseFromString(content)

print(tp2)
input: "X"
input: "Y"
output: "Z"
name: "example-type-proto"
op_type: "Add"

对于 TensorProto,存在一个快捷方法。

onnx.load_tensor_from_string(s: bytes, format: Literal['protobuf', 'textproto', 'onnxtxt', 'json'] | str = 'protobuf') TensorProto[source]

加载包含序列化 TensorProto 的二进制字符串(字节)。

参数:
  • s – 包含序列化 TensorProto 的字符串

  • format – 序列化格式。当未指定时,它从文件扩展名推断(当 f 是路径时)。如果未指定 _且_ f 不是路径,则使用“protobuf”。当格式为文本格式时,编码假定为“utf-8”。

返回值:

加载到内存中的 TensorProto。