Broadcasting semantics(广播语义)
广播语义
本文档描述了XLA中的广播语义如何工作。
什么是广播?
广播是使具有不同形状的阵列具有用于算术运算的兼容形状的过程。术语是从Numpy (广播)借来的。
对于不同等级的多维阵列之间的操作或具有不同但可兼容的形状的多维阵列之间的操作可能需要广播。考虑加入X+v
其中X
是矩阵(秩2的阵列)和v
是一个矢量(秩1的阵列)。为了执行元素相加,X
LA需要通过复制特定次数来将该矢量“广播” v
到与矩阵相同的等级。矢量的长度必须与矩阵的至少一个维度相匹配。Xv
例如:
|1 2 3| + |7 8 9|
|4 5 6|
矩阵的维数是(2,3),矢量是(3)。该矢量通过在行上复制来获得:
|1 2 3| + |7 8 9| = |8 10 12|
|4 5 6| |7 8 9| |11 13 15|
在Numpy中,这被称为广播。
原则
XLA是具有XLA语言的低级基础结构,尽可能严格和明确,避免了隐含的和“神奇”的特性,这些特性可能会使得一些计算更容易定义,代价是将更多的假设置于用户代码中长期难以改变。如果需要,可以在客户端级别的包装中添加隐含的和神奇的功能。
在广播方面,需要在不同级别的阵列之间进行明确的广播规范。这与Numpy不同,Numpy在可能的情况下推断了规格。
将较低等级的阵列广播到较高等级的阵列上
标量
可以始终通过数组进行广播,而不需要明确规定广播尺寸。标量
和数组之间的元素明确的二元运算意味着对数组中的每个元素应用标量
运算。例如,将一个标量
添加到矩阵意味着生成一个矩阵,其中的每个元素是标量
与相应输入矩阵元素的和。
|1 2 3| + 7 = |8 9 10|
|4 5 6| |11 12 13|
大多数广播需求可以通过在二进制操作中使用维度元组来捕获。当操作的输入具有不同的等级时,该广播元组指定较高等级
阵列中的哪个维度与较低等级
阵列相匹配。
考虑前面的例子,不是向(2,3)矩阵中添加标量,而是将维度向量(3)添加到维度矩阵(2,3)中。没有指定广播,这个操作是无效的。
为了正确地请求矩阵矢量加法,指定广播维度为(1),这意味着矢量的维度与矩阵的维度1相匹配。在2D中,如果维度0被视为行,维度1被视为列,则这意味着该向量的每个元素都会变成一个匹配矩阵中行数的列:
|7 8 9| ==> |7 8 9|
|7 8 9|
作为一个更复杂的例子,考虑在3x3矩阵(维度(3,3))中添加一个3维向量(dimension(3))。这个例子有两种广播可以发生的方式:
(1)可以使用广播维度1。每个向量元素成为一列,并且该向量针对矩阵中的每一行进行复制。
|7 8 9| ==> |7 8 9|
|7 8 9|
|7 8 9|
(2)可以使用广播维度0。每个向量元素都成为一行,并且该向量针对矩阵中的每列进行复制。
|7| ==> |7 7 7|
|8| |8 8 8|
|9| |9 9 9|
注意:
在向3元素矢量添加2x3矩阵时,广播维度为0是无效的。
广播尺寸可以是描述较小等级形状如何广播为较大等级形状的元组。例如,给定2x3x4长方体和3x4矩阵,广播元组(1,2)意味着将矩阵与长方体的尺寸1和2匹配。
ComputationBuilder
如果broadcast_dimensions
给出参数,则在二进制操作中使用这种类型的广播。例如,请参阅ComputationBuilder
:: Add。在XLA源代码中,这种类型的广播有时被称为“InDim”广播。
正式定义
广播属性允许通过指定要匹配较高级别数组的哪些维度来将较低级别的数组与较高级别的数组进行匹配。例如,对于一个尺寸为MxNxPxQ的数组,可以将一个尺寸为T的矢量匹配如下:
MxNxPxQ
dim 3: T
dim 2: T
dim 1: T
dim 0: T
在每种情况下,T必须等于较高等级阵列的匹配维度。矢量的值然后从匹配的维度广播到所有其他维度。
为了将TxV矩阵匹配到MxNxPxQ阵列上,使用了一对广播维度:
MxNxPxQ
dim 2,3: T V
dim 1,2: T V
dim 0,3: T V
etc...
广播元组中的维度顺序必须是预计较低等级阵列的维度与较高等级阵列的维度相匹配的顺序。元组中的第一个元素说明较高级别数组中的哪个维度必须与较低级别数组中的维度0匹配。维1的第二个元素,依此类推。广播尺寸的顺序必须严格增加。例如,在前面的例子中,将V与N和T匹配到P是非法的; 将V与P和N匹配也是非法的。
用简并维度广播相似等级的阵列
相关的广播问题是广播具有相同等级但不同尺寸大小的两个阵列。与Numpy的规则类似,只有在阵列兼容
时才可能。两个阵列在所有维度兼容
时都是兼容
的。两个维度是兼容
的,如果:
- 他们是平等的,或者
- 其中一个是1(“退化”维度)
当遇到两个兼容的数组时,结果形状在每个维度索引处的两个输入中具有最大值。
例子:
1.(2,1)和(2,3)广播到(2,3)。
2. (1,2,5)和(7,2,5)广播到(7,2,5)
3.(7,2,5)和(7,1,5)广播到(7,2,5)
4. (7,2,5)和(7,2,6)不兼容,不能广播。
出现了一种特殊情况,并且也支持这种情况,其中每个输入数组在不同索引处具有退化维度。在这种情况下,结果是“外部操作”:(2,1)和(1,3)广播到(2,3)。有关更多示例,请参阅有关广播的Numpy文档。
广播组成
将较低等级的阵列广播到较高等级的阵列并
使用退化维度进行广播都可以在相同的二进制操作中执行。例如,可以使用(0)的广播尺寸值将大小为4的矢量和大小为1x2的矩阵相加在一起:
|1 2 3 4| + [5 6] // [5 6] is a 1x2 matrix, not a vector.
首先,使用广播维度将矢量广播到2(矩阵)。广播维度中的单值(0)表示矢量的维度零与矩阵的维度零相匹配。这产生大小为4xM的矩阵,其中值M被选择为匹配1x2阵列中的相应尺寸大小。因此,生成一个4x2矩阵:
|1 1| + [5 6]
|2 2|
|3 3|
|4 4|
然后,“退化维度广播”广播1x2矩阵的维度零以匹配右侧的相应维度大小:
|1 1| + |5 6| |6 7|
|2 2| + |5 6| = |7 8|
|3 3| + |5 6| |8 9|
|4 4| + |5 6| |9 10|
一个更复杂的例子是使用(1,2)的广播尺寸将大小为1×2的矩阵添加到大小为4×3×1的数组。首先,使用广播尺寸将1×2矩阵广播到等级3以产生中间M×1×2阵列,其中尺寸尺寸M由较大操作数(4×3×1阵列)的尺寸确定,产生4×1×2中间阵列。因为当广播维度是(1,2)时,维度1和维度2被映射到原始1x2矩阵的维度,所以M处于维度0(最左边的维度)。可以使用退化维度的广播将此中间数组添加到4x3x1矩阵中,以产生4x3x2数组结果。