AffineGrid¶
AffineGrid - 20¶
Version¶
Domain: main
Since version: 20
Function: True
Support level: SupportType.COMMON
Shape inference: True
This version of the operator has been available since version 20.
Summary¶
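Generates a 2D or 3D flow field (sampling grid), given a batch of affine matrices theta (https://pytorch.ac.cn/docs/stable/generated/torch.nn.functional.affine_grid.html). An affine matrix theta is applied to a position tensor represented in its homogeneous expression. Here is an example in 3D: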
[r00, r01, r02, t0]   [x]   [x']
[r10, r11, r12, t1] * [y] = [y']
[r20, r21, r22, t2]   [z]   [z']
[0,   0,   0,   1 ]   [1]   [1 ]
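where (x, y, z) is the position in the original space and (x', y', z') is the position in the output space. The last row is always [0, 0, 0, 1] and is not stored in the affine matrix. Therefore theta has shape (N, 2, 3) for 2D and (N, 3, 4) for 3D.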
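The matrix product above can be sketched in a few lines of plain Python. This is only an illustration: the helper name `apply_affine_3d` and the sample values are hypothetical and are not part of the operator definition.

```python
def apply_affine_3d(theta, position):
    """Apply one 3x4 affine matrix (rows of [r, r, r, t]) to an (x, y, z) position."""
    # Append the homogeneous coordinate 1 so the translation column takes effect.
    homogeneous = [*position, 1.0]
    # The implicit last row [0, 0, 0, 1] only reproduces the trailing 1, so it is skipped.
    return tuple(sum(r * p for r, p in zip(row, homogeneous)) for row in theta)


# Hypothetical matrix: scale x by 2 and shift y by 0.5.
theta = [
    [2.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.5],
    [0.0, 0.0, 1.0, 0.0],
]
result = apply_affine_3d(theta, (0.25, -1.0, 1.0))
print(result)  # (0.5, -0.5, 1.0)
```

Because the last row never changes, storing only the first three rows (or the first two rows in 2D) loses no information.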
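The input size defines a grid of positions evenly spaced in the original 2D or 3D space, with each dimension ranging from -1 to 1. The output grid contains the corresponding positions in the output space.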
When align_corners=1, -1 and 1 refer to the centers of the corner pixels (marked v in the illustration).
v            v            v            v
|-------------------|------------------|
-1                  0                  1
When align_corners=0, -1 and 1 refer to the outer edges of the corner pixels.
     v         v        v         v
|------------------|-------------------|
-1                 0                   1
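The two conventions differ only in where the first sample sits and how far apart samples are. The following plain-Python sketch reproduces the four sample positions marked v in the illustrations, using the same start and step arithmetic as the function body of this operator. The helper name `axis_coordinates` is hypothetical, and the guard for a single-sample axis is an assumption of this sketch (the function body applies that guard only to the depth axis).

```python
def axis_coordinates(length, align_corners):
    """Return `length` evenly spaced coordinates in [-1, 1] along one axis."""
    if align_corners:
        # -1 and 1 are the centers of the corner pixels.
        # Sketch-only guard: with a single sample there is no step to take.
        step = 2.0 / (length - 1) if length > 1 else 0.0
        start = -1.0
    else:
        # -1 and 1 are the outer edges, so the first center sits half a step inside.
        step = 2.0 / length
        start = -1.0 + step / 2.0
    return [start + i * step for i in range(length)]


edges = axis_coordinates(4, align_corners=0)
centers = axis_coordinates(4, align_corners=1)
print(edges)  # [-0.75, -0.25, 0.25, 0.75]
print([round(c, 4) for c in centers])  # [-1.0, -0.3333, 0.3333, 1.0]
```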
Function Body¶
The function definition for this operator.
<
  domain: "",
  opset_import: ["" : 20]
>
AffineGrid <align_corners>(theta, size) => (grid)
{
  one = Constant <value_int: int = 1> ()
  two = Constant <value_int: int = 2> ()
  zero = Constant <value_int: int = 0> ()
  four = Constant <value_int: int = 4> ()
  one_1d = Constant <value_ints: ints = [1]> ()
  zero_1d = Constant <value_ints: ints = [0]> ()
  minus_one = Constant <value_int: int = -1> ()
  minus_one_f = CastLike (minus_one, theta)
  zero_f = CastLike (zero, theta)
  one_f = CastLike (one, theta)
  two_f = CastLike (two, theta)
  constant_align_corners = Constant <value_int: int = @align_corners> ()
  constant_align_corners_equal_zero = Equal (constant_align_corners, zero)
  size_ndim = Size (size)
  condition_is_2d = Equal (size_ndim, four)
  N, C, D, H, W = If (condition_is_2d) <then_branch: graph = g1 () => ( N_then, C_then, D_then, H_then, W_then) {
    N_then, C_then, H_then, W_then = Split <num_outputs: int = 4> (size)
    D_then = Identity (one_1d)
  }, else_branch: graph = g2 () => ( N_else, C_else, D_else, H_else, W_else) {
    N_else, C_else, D_else, H_else, W_else = Split <num_outputs: int = 5> (size)
  }>
  size_NCDHW = Concat <axis: int = 0> (N, C, D, H, W)
  theta_3d = If (condition_is_2d) <then_branch: graph = g3 () => ( theta_then) {
    gather_idx_6 = Constant <value_ints: ints = [0, 1, 2, 0, 1, 2]> ()
    shape_23 = Constant <value_ints: ints = [2, 3]> ()
    gather_idx_23 = Reshape (gather_idx_6, shape_23)
    shape_N23 = Concat <axis: int = 0> (N, shape_23)
    gather_idx_N23 = Expand (gather_idx_23, shape_N23)
    thetaN23 = GatherElements <axis: int = 2> (theta, gather_idx_N23)
    r1, r2 = Split <axis: int = 1, num_outputs: int = 2> (thetaN23)
    r1_ = Squeeze (r1)
    r2_ = Squeeze (r2)
    r11, r12, t1 = Split <axis: int = 1, num_outputs: int = 3> (r1_)
    r21, r22, t2 = Split <axis: int = 1, num_outputs: int = 3> (r2_)
    r11_shape = Shape (r21)
    float_zero_1d_ = ConstantOfShape (r11_shape)
    float_zero_1d = CastLike (float_zero_1d_, theta)
    float_one_1d = Add (float_zero_1d, one_f)
    R1 = Concat <axis: int = 1> (r11, r12, float_zero_1d, t1)
    R2 = Concat <axis: int = 1> (r21, r22, float_zero_1d, t2)
    R3 = Concat <axis: int = 1> (float_zero_1d, float_zero_1d, float_one_1d, float_zero_1d)
    R1_ = Unsqueeze (R1, one_1d)
    R2_ = Unsqueeze (R2, one_1d)
    R3_ = Unsqueeze (R3, one_1d)
    theta_then = Concat <axis: int = 1> (R1_, R2_, R3_)
  }, else_branch: graph = g4 () => ( theta_else) {
    theta_else = Identity (theta)
  }>
  two_1d = Constant <value_ints: ints = [2]> ()
  three_1d = Constant <value_ints: ints = [3]> ()
  five_1d = Constant <value_ints: ints = [5]> ()
  constant_D_H_W_shape = Slice (size_NCDHW, two_1d, five_1d)
  zeros_D_H_W_ = ConstantOfShape (constant_D_H_W_shape)
  zeros_D_H_W = CastLike (zeros_D_H_W_, theta)
  ones_D_H_W = Add (zeros_D_H_W, one_f)
  D_float = CastLike (D, zero_f)
  H_float = CastLike (H, zero_f)
  W_float = CastLike (W, zero_f)
  start_d, step_d, start_h, step_h, start_w, step_w = If (constant_align_corners_equal_zero) <then_branch: graph = h1 () => ( start_d_then, step_d_then, start_h_then, step_h_then, start_w_then, step_w_then) {
    step_d_then = Div (two_f, D_float)
    step_h_then = Div (two_f, H_float)
    step_w_then = Div (two_f, W_float)
    step_d_half = Div (step_d_then, two_f)
    start_d_then = Add (minus_one_f, step_d_half)
    step_h_half = Div (step_h_then, two_f)
    start_h_then = Add (minus_one_f, step_h_half)
    step_w_half = Div (step_w_then, two_f)
    start_w_then = Add (minus_one_f, step_w_half)
  }, else_branch: graph = h2 () => ( start_d_else, step_d_else, start_h_else, step_h_else, start_w_else, step_w_else) {
    D_float_nimus_one = Sub (D_float, one_f)
    H_float_nimus_one = Sub (H_float, one_f)
    W_float_nimus_one = Sub (W_float, one_f)
    D_equals_one = Equal (D, one)
    step_d_else = If (D_equals_one) <then_branch: graph = g5 () => ( step_d_else_then) {
      step_d_else_then = Identity (zero_f)
    }, else_branch: graph = g6 () => ( step_d_else_else) {
      step_d_else_else = Div (two_f, D_float_nimus_one)
    }>
    step_h_else = Div (two_f, H_float_nimus_one)
    step_w_else = Div (two_f, W_float_nimus_one)
    start_d_else = Identity (minus_one_f)
    start_h_else = Identity (minus_one_f)
    start_w_else = Identity (minus_one_f)
  }>
  grid_w_steps_int = Range (zero, W, one)
  grid_w_steps_float = CastLike (grid_w_steps_int, step_w)
  grid_w_steps = Mul (grid_w_steps_float, step_w)
  grid_w_0 = Add (start_w, grid_w_steps)
  grid_h_steps_int = Range (zero, H, one)
  grid_h_steps_float = CastLike (grid_h_steps_int, step_h)
  grid_h_steps = Mul (grid_h_steps_float, step_h)
  grid_h_0 = Add (start_h, grid_h_steps)
  grid_d_steps_int = Range (zero, D, one)
  grid_d_steps_float = CastLike (grid_d_steps_int, step_d)
  grid_d_steps = Mul (grid_d_steps_float, step_d)
  grid_d_0 = Add (start_d, grid_d_steps)
  zeros_H_W_D = Transpose <perm: ints = [1, 2, 0]> (zeros_D_H_W)
  grid_d_1 = Add (zeros_H_W_D, grid_d_0)
  grid_d = Transpose <perm: ints = [2, 0, 1]> (grid_d_1)
  zeros_D_W_H = Transpose <perm: ints = [0, 2, 1]> (zeros_D_H_W)
  grid_h_1 = Add (zeros_D_W_H, grid_h_0)
  grid_h = Transpose <perm: ints = [0, 2, 1]> (grid_h_1)
  grid_w = Add (grid_w_0, zeros_D_H_W)
  grid_w_usqzed = Unsqueeze (grid_w, minus_one)
  grid_h_usqzed = Unsqueeze (grid_h, minus_one)
  grid_d_usqzed = Unsqueeze (grid_d, minus_one)
  ones_D_H_W_usqzed = Unsqueeze (ones_D_H_W, minus_one)
  original_grid = Concat <axis: int = -1> (grid_w_usqzed, grid_h_usqzed, grid_d_usqzed, ones_D_H_W_usqzed)
  constant_shape_DHW_4 = Constant <value_ints: ints = [-1, 4]> ()
  original_grid_DHW_4 = Reshape (original_grid, constant_shape_DHW_4)
  original_grid_4_DHW_ = Transpose (original_grid_DHW_4)
  original_grid_4_DHW = CastLike (original_grid_4_DHW_, theta_3d)
  grid_N_3_DHW = MatMul (theta_3d, original_grid_4_DHW)
  grid_N_DHW_3 = Transpose <perm: ints = [0, 2, 1]> (grid_N_3_DHW)
  N_D_H_W_3 = Concat <axis: int = -1> (N, D, H, W, three_1d)
  grid_3d_else_ = Reshape (grid_N_DHW_3, N_D_H_W_3)
  grid_3d = CastLike (grid_3d_else_, theta_3d)
  grid = If (condition_is_2d) <then_branch: graph = g1 () => ( grid_then) {
    grid_squeezed = Squeeze (grid_3d, one_1d)
    grid_then = Slice (grid_squeezed, zero_1d, two_1d, three_1d)
  }, else_branch: graph = g2 () => ( grid_else) {
    grid_else = Identity (grid_3d)
  }>
}
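For readers who find the graph form hard to follow, here is a rough plain-Python sketch of what the function body computes in the 2D case. It is not the reference implementation: the names `axis_coordinates` and `affine_grid_2d` are hypothetical, nested lists stand in for tensors, and the sketch skips the step where the function body embeds a 2D theta into a 3x4 matrix with D = 1, since the extra z column is zero and does not affect the 2D result. The single-sample guard on every axis is also an assumption of this sketch.

```python
def axis_coordinates(length, align_corners):
    """Base coordinates in [-1, 1] along one axis (start + i * step)."""
    if align_corners:
        # Sketch-only guard for a single sample; the function body guards only D.
        step = 2.0 / (length - 1) if length > 1 else 0.0
        start = -1.0
    else:
        step = 2.0 / length
        start = -1.0 + step / 2.0
    return [start + i * step for i in range(length)]


def affine_grid_2d(theta, size, align_corners=0):
    """Return a nested list of shape (N, H, W, 2) for theta of shape (N, 2, 3)."""
    n, _channels, height, width = size
    xs = axis_coordinates(width, align_corners)   # x varies along W
    ys = axis_coordinates(height, align_corners)  # y varies along H
    grid = []
    for batch in range(n):
        (r11, r12, t1), (r21, r22, t2) = theta[batch]
        rows = []
        for y in ys:
            # Each base position is (x, y, 1) in homogeneous form.
            rows.append([[r11 * x + r12 * y + t1, r21 * x + r22 * y + t2] for x in xs])
        grid.append(rows)
    return grid


# The identity transform returns the base grid itself.
identity = [[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]]
grid = affine_grid_2d(identity, (1, 1, 2, 2), align_corners=0)
print(grid[0])  # [[[-0.5, -0.5], [0.5, -0.5]], [[-0.5, 0.5], [0.5, 0.5]]]
```

Note that the last axis of the output is ordered (x, y), matching the way the function body concatenates the W coordinates before the H coordinates.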
Attributes¶
align_corners - INT (default is '0'): If align_corners=1, -1 and 1 refer to the centers of the corner pixels. If align_corners=0, -1 and 1 refer to the outer edges of the corner pixels.
Inputs¶
theta (heterogeneous) - T1: Input batch of affine matrices with shape (N, 2, 3) for 2D or (N, 3, 4) for 3D.
size (heterogeneous) - T2: Target output image size, (N, C, H, W) for 2D or (N, C, D, H, W) for 3D.
Outputs¶
grid (heterogeneous) - T1: Output tensor of 2D sample coordinates with shape (N, H, W, 2), or of 3D sample coordinates with shape (N, D, H, W, 3).
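The relationship between size and the shape of grid can be written out as a small helper. This is a sketch only: `grid_shape` is a hypothetical name, and the error for other lengths is an assumption, since the operator itself only distinguishes a 4-element size from everything else.

```python
def grid_shape(size):
    """Map the `size` input to the shape of the `grid` output."""
    if len(size) == 4:
        # 2D: (N, C, H, W) -> (N, H, W, 2); the channel count C is not used.
        n, _c, h, w = size
        return (n, h, w, 2)
    if len(size) == 5:
        # 3D: (N, C, D, H, W) -> (N, D, H, W, 3)
        n, _c, d, h, w = size
        return (n, d, h, w, 3)
    raise ValueError("size must have 4 (2D) or 5 (3D) elements")


print(grid_shape((2, 3, 4, 5)))     # (2, 4, 5, 2)
print(grid_shape((2, 3, 6, 4, 5)))  # (2, 6, 4, 5, 3)
```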
Type Constraints¶
T1 in ( tensor(bfloat16), tensor(double), tensor(float), tensor(float16) ): Constrain grid types to float tensors.
T2 in ( tensor(int64) ): Constrain the type of size to int64 tensors.