PyRTL Matrix Library
Hardware implementations of matrix operations, in the PyRTL hardware description language.
PyRTL Matrix Operations
PyRTL implementations of common linear algebra operations.
The operations in this module all use WireMatrix2D as their input and output,
so they can be composed to implement arbitrary matrix calculations. See the
pyrtl_matrix demo for an example that computes x · (y - y_zero) + a.
Warning
These implementations may not be completely general. They have only been tested in the context of dense neural networks.
- pyrtlnet.pyrtl_matrix.make_argmax(a)
Combinationally argmax a matrix `a` by column, returning the index of the row containing the largest value in each column. For example, given the matrix:

    ┌     ┐
    │ 1 5 │
    │ 3 4 │
    └     ┘

`make_argmax` returns a `wire_matrix()` containing the values `[1, 0]`, because the value at row index 1 (3) is the largest in the first column, and the value at row index 0 (5) is the largest in the second column. This implementation is fully combinational (no registers).
- Parameters:
  - a (WireMatrix2D) – Input matrix.
- Return type:
  wire_matrix
- Returns:
  A `wire_matrix()` containing the concatenation of the row indexes of the largest values in each column of `a`, in unsigned binary.
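As a plain-Python cross-check of the example above, the following reference model computes the same column-wise argmax on a list of lists. It is not the hardware implementation, and the function name is hypothetical. This sketch breaks ties toward the lowest row index. That is an assumption, because the entry does not say how the hardware handles ties.

```python
def argmax_by_column(matrix):
    """Return, for each column, the row index of that column's largest value.

    Ties go to the lowest row index (an assumption of this sketch).
    """
    num_rows, num_columns = len(matrix), len(matrix[0])
    result = []
    for col in range(num_columns):
        best_row = 0
        for row in range(1, num_rows):
            # Strict comparison keeps the earliest row when values are equal.
            if matrix[row][col] > matrix[best_row][col]:
                best_row = row
        result.append(best_row)
    return result


print(argmax_by_column([[1, 5], [3, 4]]))  # [1, 0]
```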
- pyrtlnet.pyrtl_matrix.make_elementwise_add(name, a, b, output_bitwidth)
Combinationally add matrices `a` and `b` elementwise. This implementation is fully combinational (no registers).
`b` is allowed to be a column vector with the same number of rows as `a`.
- Parameters:
  - name (str) – The returned WireMatrix2D will be named `{name}.output`.
  - a (WireMatrix2D)
  - b (WireMatrix2D)
- Return type:
  WireMatrix2D
- Returns:
  WireMatrix2D containing `a + b`.
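The following plain-Python model illustrates the broadcasting rule described above, in which `b` may be a single column that is reused across every column of `a`. It is an illustrative sketch, and the function name is hypothetical. It works on unbounded Python integers and ignores `output_bitwidth`, because the entry does not describe the overflow behavior.

```python
def elementwise_add(a, b):
    """Add two 2D lists elementwise; b may instead be a column vector (one value per row)."""
    num_rows, num_columns = len(a), len(a[0])
    if len(b) != num_rows:
        raise ValueError("b must have the same number of rows as a")
    b_is_column = len(b[0]) == 1 and num_columns != 1
    if not b_is_column and len(b[0]) != num_columns:
        raise ValueError("b must match a's shape or be a column vector")
    # A column vector contributes b[r][0] to every element of row r.
    return [[a[r][c] + (b[r][0] if b_is_column else b[r][c])
             for c in range(num_columns)]
            for r in range(num_rows)]


print(elementwise_add([[1, 2], [3, 4]], [[10], [20]]))  # [[11, 12], [23, 24]]
```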
- pyrtlnet.pyrtl_matrix.make_elementwise_normalize(name, a, m0, n, z3, output_bitwidth)
Convert an un-normalized layer output to a normalized output.
This function effectively multiplies the layer's output by its scale factor `m` and adds its zero point `z3`. `m` is a floating-point number, which is represented by a 32-bit fixed-point multiplier `m0` and a bitwise rounding right shift `n`; see `normalization_constants()`. So instead of doing a floating-point multiplication by `m`, we do a fixed-point multiplication by `m0`, followed by a bitwise rounding right shift by `n`.
See Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference for more details. This implements the part of Equation 7 that's outside the parentheses (addition of `z3` and multiplication by `m`).
Layers can have per-axis scale factors, so `m0` and `n` will be vectors of per-row scale factors and shift amounts. See per-axis quantization for details.
For example, if `accumulator_bitwidth` is 32 and `output_bitwidth` is 8, this function can multiply and shift 32-bit `a` values into 8-bit output values, making the most effective use of the limited 8-bit output range.
This implementation is fully combinational (no registers).
- Parameters:
  - name (str) – The returned WireMatrix2D will be named `{name}.output`.
  - a (WireMatrix2D) – Matrix to normalize, with bitwidth `accumulator_bitwidth`.
  - m0 (Fxp) – Vector of per-row fixed-point multipliers.
  - n (ndarray) – Vector of per-row shift amounts.
  - z3 (ndarray) – Vector of per-row zero-point adjustments.
  - output_bitwidth (int) – Number of bits to output for each element. This should generally be 8.
- Return type:
  WireMatrix2D
- Returns:
  `z3 + ((a * m0) >> n)`, where `*` is elementwise fixed-point multiplication and `>>` is a rounding right shift. The return value has the same shape as `a` and bitwidth `output_bitwidth`.
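The arithmetic above can be modeled in plain Python for a single element. This is an assumption-laden sketch, not the library's code, and it makes the following choices:

- It derives `m0` and `n` with a hypothetical `quantize_multiplier()` helper. The documented route is `normalization_constants()`.
- It treats `m0` as a 32-bit integer with 31 fractional bits.
- It only handles `0 < m < 1`.
- It folds the fixed-point rescale and the shift by `n` into one rounding shift.
- It clamps the result to a signed `output_bitwidth` range. The clamping is also an assumption.

```python
def quantize_multiplier(m):
    """Split a real scale 0 < m < 1 into (m0, n) with m ~= (m0 / 2**31) * 2**-n.

    m0 is kept in [2**30, 2**31), i.e. a Q0.31 value in [0.5, 1),
    so the 32-bit multiplier keeps as much precision as possible.
    """
    if not 0.0 < m < 1.0:
        raise ValueError("this sketch only handles 0 < m < 1")
    n = 0
    while m < 0.5:  # scale m into [0.5, 1), counting the shifts
        m *= 2.0
        n += 1
    m0 = round(m * (1 << 31))
    if m0 == 1 << 31:  # rounding reached 1.0, which Q0.31 cannot hold
        m0 //= 2
        n -= 1
    return m0, n


def rounding_right_shift(x, shift):
    """Arithmetic right shift, rounding to nearest with ties toward +infinity."""
    if shift <= 0:
        return x << -shift
    return (x + (1 << (shift - 1))) >> shift


def normalize(acc, m0, n, z3, output_bitwidth=8):
    """Model z3 + ((acc * m0) >> n) for one element, with m0 a Q0.31 integer."""
    # acc * m0 carries 31 fractional bits; drop them and apply the extra
    # shift by n in a single rounding step (hardware might round differently).
    scaled = rounding_right_shift(acc * m0, 31 + n)
    lo = -(1 << (output_bitwidth - 1))
    hi = (1 << (output_bitwidth - 1)) - 1
    return min(max(z3 + scaled, lo), hi)  # clamp to the signed output range


m0, n = quantize_multiplier(0.0123)
print(n)                              # 6
print(normalize(1000, m0, n, z3=-5))  # 7, since 1000 * 0.0123 = 12.3 rounds to 12
```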
- pyrtlnet.pyrtl_matrix.make_elementwise_relu(name, a)
Combinationally ReLU matrix `a`. This computes `max(a, 0)` elementwise. This implementation is fully combinational (no registers).
- Parameters:
  - name (str) – The returned WireMatrix2D will be named `{name}.output`.
  - a (WireMatrix2D) – Input WireMatrix2D to ReLU.
- Return type:
  WireMatrix2D
- Returns:
  `max(a, 0)`, computed elementwise.
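In two's complement, `max(a, 0)` needs no comparator, because the sign bit alone decides whether to output zero or pass the value through. The plain-Python model below illustrates that idea on raw bit patterns. It sketches one plausible combinational structure, not necessarily how `make_elementwise_relu()` is built, and the helper names are hypothetical.

```python
def to_bits(value, bitwidth):
    """Encode a Python int as a bitwidth-bit two's complement pattern."""
    return value & ((1 << bitwidth) - 1)


def relu_bits(bits, bitwidth):
    """ReLU on a two's complement bit pattern, as a select driven by the sign bit.

    If the sign bit is set the value is negative, so output 0;
    otherwise pass the bits through unchanged.
    """
    sign = (bits >> (bitwidth - 1)) & 1
    return 0 if sign else bits


matrix = [[-3, 5], [0, -128]]
print([[relu_bits(to_bits(v, 8), 8) for v in row] for row in matrix])  # [[0, 5], [0, 0]]
```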
- pyrtlnet.pyrtl_matrix.make_input_memblock_data(a, input_bitwidth, addrwidth)
Convert an ndarray to MemBlock data for use with the systolic array.
When `make_systolic_array()` uses a WireMatrix2D with a MemBlock as input, the systolic array reads consecutive addresses of the MemBlock, one address per cycle. The data at each address must contain all the values in `a` that the systolic array will consume in the next cycle. All the values needed in a cycle are concatenated together and stored at one address.
The returned `memblock_data` can be used directly as the `romdata` for a RomBlock. Alternatively, it can be enumerated, converted to a `dict`, and used with a MemBlock via `memory_value_map` in Simulation:

    import pyrtl
    from pyrtlnet import pyrtl_matrix
    memblock_data = pyrtl_matrix.make_input_memblock_data(...)
    memblock_dict = dict(enumerate(memblock_data))
    # memblock is the MemBlock that feeds the systolic array.
    sim = pyrtl.Simulation(memory_value_map={memblock: memblock_dict})

- Parameters:
- Return type:
  list[int]
- Returns:
  A list of integer values, ready for storage in a MemBlock. Each integer contains all the bits from `a` that the systolic array needs in one cycle.
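The per-cycle packing described above can be sketched in plain Python for the left input `a`. The helper names are hypothetical, and the bit layout is an assumption. This sketch places row 0's value in the least significant bits of each word and stores negative values as `input_bitwidth`-bit two's complement fields. The real layout is whatever `make_input_memblock_data()` and `make_systolic_array()` agree on. The sketch also ignores `addrwidth`. For the top input, the same packing would apply to the transposed `b`, following Systolic Array Operation.

```python
def cycle_schedule(a):
    """Values on the left inputs (l0, l1, ...) for each cycle.

    Row r of `a` is delayed by r cycles and the holes are filled with zeros,
    matching the parallelogram-shaped streams in Systolic Array Operation.
    """
    num_rows, num_inner = len(a), len(a[0])
    num_cycles = num_inner + num_rows - 1
    return [[a[r][c - r] if 0 <= c - r < num_inner else 0 for r in range(num_rows)]
            for c in range(num_cycles)]


def pack_cycles(schedule, input_bitwidth):
    """Concatenate each cycle's values into one integer word.

    Hypothetical layout: row 0 occupies the least significant input_bitwidth bits.
    """
    mask = (1 << input_bitwidth) - 1
    words = []
    for values in schedule:
        word = 0
        for r, value in enumerate(values):
            word |= (value & mask) << (r * input_bitwidth)  # two's complement field
        words.append(word)
    return words


words = pack_cycles(cycle_schedule([[1, 2, 3], [4, 5, 6]]), input_bitwidth=8)
memblock_dict = dict(enumerate(words))  # address -> word, as for memory_value_map
print(memblock_dict)  # {0: 1, 1: 1026, 2: 1283, 3: 1536}
```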
- pyrtlnet.pyrtl_matrix.make_systolic_array(name, a, b, b_zero, input_bitwidth, accumulator_bitwidth, initial_delay_cycles=0)
Generate an output-stationary systolic array, computing `a ⋅ (b - b_zero)`.
- Parameters:
  - name (str) – The returned WireMatrix2D will be named `{name}.output`.
  - a (WireMatrix2D | ndarray) – Left input to the systolic array. The types of `a` and `b` do not have to match.
  - b (WireMatrix2D | ndarray) – Right input to the systolic array. The types of `a` and `b` do not have to match.
  - b_zero (int) – Zero point for `b`. Useful for quantized neural network computations. Set it to zero for standard matrix multiplication.
  - input_bitwidth (int) – Bitwidth of each input element.
  - accumulator_bitwidth (int) – Bitwidth used when summing dot products. The systolic array multiplies and accumulates many input elements, so `accumulator_bitwidth` should be larger than `input_bitwidth`.
  - initial_delay_cycles (int, default: 0) – Number of cycles to wait before starting operation. This is a temporary hack that's currently required for correct synthesis with Vivado. Ideally, no delay cycles should be required.
- Return type:
  WireMatrix2D
- Returns:
  A WireMatrix2D representing `a ⋅ (b - b_zero)`.
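The following sketch shows why `accumulator_bitwidth` must exceed `input_bitwidth`. It computes a worst-case accumulator width for `a ⋅ (b - b_zero)`. For illustration only, it assumes that every element of `a` and `b` is a signed `input_bitwidth`-bit integer and that `num_inner` products are summed. The function names are hypothetical.

```python
def signed_bitwidth(lo, hi):
    """Smallest two's complement width that can hold every integer in [lo, hi]."""
    width = 1
    while lo < -(1 << (width - 1)) or hi > (1 << (width - 1)) - 1:
        width += 1
    return width


def worst_case_accumulator_bitwidth(input_bitwidth, num_inner, b_zero=0):
    """Width needed for any dot product of length num_inner in a . (b - b_zero)."""
    lo_in = -(1 << (input_bitwidth - 1))
    hi_in = (1 << (input_bitwidth - 1)) - 1
    b_lo, b_hi = lo_in - b_zero, hi_in - b_zero  # range of one (b - b_zero) element
    # A product of two ranged values reaches its extremes at the range corners.
    products = [x * y for x in (lo_in, hi_in) for y in (b_lo, b_hi)]
    return signed_bitwidth(num_inner * min(products), num_inner * max(products))


print(worst_case_accumulator_bitwidth(8, 1))     # 16
print(worst_case_accumulator_bitwidth(8, 1000))  # 25
```

With 8-bit inputs and `b_zero=0`, a single product already needs 16 bits, and summing 1,000 such products needs 25 bits. Under these assumptions, a 32-bit accumulator leaves headroom.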
Systolic Array Architecture
The systolic array's architecture is shown in the diagram below. `l0'` is `l0`, delayed by one cycle, and `l0''` is `l0`, delayed by two cycles:

                  t0                         t1
                  │                          │
                  ▼                          ▼
             ┌─────────┐      l0'       ┌─────────┐      l0''
    l0 ─────▶│ reg_0_0 │─────┬─────────▶│ reg_0_1 │─────┬───── ...
             └─────────┘     │          └─────────┘     │
                  │          │               │          │
                  │          ▼               │          ▼
                  │      ┌────────┐          │      ┌────────┐
              t0' ├─────▶│ pe_0_0 │      t1' ├─────▶│ pe_0_1 │
                  │      └────────┘          │      └────────┘
                  │                          │
                  ▼                          ▼
             ┌─────────┐      l1'       ┌─────────┐      l1''
    l1 ─────▶│ reg_1_0 │─────┬─────────▶│ reg_1_1 │─────┬───── ...
             └─────────┘     │          └─────────┘     │
                  │          │               │          │
                  │          ▼               │          ▼
                  │      ┌────────┐          │      ┌────────┐
             t0'' ├─────▶│ pe_1_0 │     t1'' ├─────▶│ pe_1_1 │
                  │      └────────┘          │      └────────┘
                  │                          │
                 ...                        ...

The systolic array multiplies matrices `a` and `b`, where `a` has shape `(num_rows, num_inner)` and `b` has shape `(num_inner, num_columns)`.
The systolic array is a 2D array of Register (reg) and processing elements (pe), arranged in `num_rows` rows and `num_columns` columns. Pairs of Register and processing element are grouped into a tile; for example, `reg_0_0` and `pe_0_0` form the tile at `(0, 0)`. Multiple tiles can be wired together to create the full systolic array.
Systolic Array Operation
Matrix `a` streams in the left inputs `(l0, l1, ... ln)` over `(num_inner + num_rows - 1)` cycles.
Matrix `b` streams in the top inputs `(t0, t1, ... tn)` over `(num_inner + num_columns - 1)` cycles.
Data streams from these left and top inputs, through registers `(reg_0_0, reg_0_1, ...)`, to processing elements `(pe_0_0, pe_0_1, ...)`. The processing elements store the matrix multiplication output in accumulator registers. The output does not move through the array, which makes this array “output-stationary.”
The left and top inputs change over time. If the matrix `a` is:

        ┌       ┐
    a = │ 1 2 3 │
        │ 4 5 6 │
        └       ┘

then `num_rows=2` and `num_inner=3`, because matrix `a` has shape `(2, 3)`. There are two left inputs because `num_rows=2`. It will take `4` cycles to stream matrix `a`, because `3 + 2 - 1 = 4`. The left inputs for each cycle are:

       │ cycle
       │ 0 1 2 3
    ───┼────────
    l0 │ 1 2 3 0
    l1 │ 0 4 5 6
Note how `l1` is delayed by one cycle relative to `l0`, and the holes have been filled with zeroes.
If the matrix `b` is:

        ┌             ┐
    b = │  7  8  9 10 │
        │ 11 12 13 14 │
        │ 15 16 17 18 │
        └             ┘

then `num_inner=3` and `num_columns=4`, because matrix `b` has shape `(3, 4)`. There are four top inputs because `num_columns=4`. It will take `6` cycles to stream matrix `b`, because `3 + 4 - 1 = 6`. The top inputs for each cycle are:

       │ cycle
       │  0  1  2  3  4  5
    ───┼──────────────────
    t0 │  7 11 15  0  0  0
    t1 │  0  8 12 16  0  0
    t2 │  0  0  9 13 17  0
    t3 │  0  0  0 10 14 18
Note how matrix `b` has been transposed. `t0` is `[7 11 15]` over the first three cycles, which corresponds to the leftmost column of matrix `b`. `t1` is delayed by one cycle, `t2` by two cycles, and `t3` by three cycles, and the holes have been filled with zeroes.
Compare `t0` and `l0`. `l0` corresponds to the topmost row of matrix `a`, and `t0` corresponds to the leftmost column of matrix `b`. `t0` and `l0` can be generated by following the same procedure, except that matrix `b` is transposed first, while matrix `a` is not.
When there is no more input to stream into the left or top inputs, the corresponding input should be set to zero. The final result will be ready in `(num_rows + num_inner + num_columns)` cycles, and the matrix multiplication result can be read from the `pe_{row}_{col}` registers.
The pyrtl_matrix demo runs this example through the systolic array named `mm0`. These parallelogram-shaped inputs can be seen propagating through the array's `mm0.left` and `mm0.top` inputs in the output from `render_trace()`:

                        ▕0    ▕1    ▕2    ▕3    ▕4    ▕5    ▕6    ▕7    ▕8    ▕9
         mm0.left[0] ──────▏1   ▕ 2   ▕ 3   ▕────────────────────────────────────
         mm0.left[1] ────────────▏4   ▕ 5   ▕ 6   ▕──────────────────────────────
          mm0.top[0] ──────▏7   ▕ 11  ▕ 15  ▕────────────────────────────────────
          mm0.top[1] ────────────▏8   ▕ 12  ▕ 16  ▕──────────────────────────────
          mm0.top[2] ──────────────────▏9   ▕ 13  ▕ 17  ▕────────────────────────
          mm0.top[3] ────────────────────────▏10  ▕ 14  ▕ 18  ▕──────────────────
    mm0.output[0][0] ──────────────────▏7   ▕ 29  ▕ 74
    mm0.output[0][1] ────────────────────────▏8   ▕ 32  ▕ 80
    mm0.output[0][2] ──────────────────────────────▏9   ▕ 35  ▕ 86
    mm0.output[0][3] ────────────────────────────────────▏10  ▕ 38  ▕ 92
    mm0.output[1][0] ────────────────────────▏28  ▕ 83  ▕ 173
    mm0.output[1][1] ──────────────────────────────▏32  ▕ 92  ▕ 188
    mm0.output[1][2] ────────────────────────────────────▏36  ▕ 101 ▕ 203
    mm0.output[1][3] ──────────────────────────────────────────▏40  ▕ 110 ▕ 218
           mm0.state INIT ▕ BUSY                                          ▕ DONE
                                                                               ▁▁▁▁▁▁
    mm0.output.valid ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▏

The `mm0.output` signals show the systolic array's output matrix. For example, `mm0.output[0][0]` shows that the output matrix's final top left value is `74`, which is `1 * 7 + 2 * 11 + 3 * 15`.
The trace shows how the systolic array multiplies and accumulates to execute this matrix multiplication over time. For example, the trace for `mm0.output[0][0]` shows:
- 7 in cycle 3, which is `1 * 7`.
- 29 in cycle 4, which is `1 * 7 + 2 * 11`.
- 74 in cycle 5, which is `1 * 7 + 2 * 11 + 3 * 15`.
The inputs for computing `mm0.output[0][0]` can be found in the `mm0.left[0]` and `mm0.top[0]` traces.
The expected result of multiplying matrices `a` and `b` is:

             ┌                 ┐
    output = │  74  80  86  92 │
             │ 173 188 203 218 │
             └                 ┘

These values can be found on the right side of the `mm0.output` traces.
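The walkthrough above can be reproduced with a small cycle-by-cycle model in plain Python. This is an illustrative sketch, not the PyRTL implementation, and it assumes the following:

- Each tile's register captures its left and top inputs on every cycle.
- Each processing element adds the product of its tile register's two values to its accumulator.
- `b_zero` is subtracted before `b` is streamed in.

Cycle numbers in this model are relative and are not meant to match the trace exactly. The function names are hypothetical.

```python
def skew(rows):
    """Delay stream r by r cycles and fill the holes with zeros."""
    length = len(rows[0]) + len(rows) - 1
    return [[row[c - r] if 0 <= c - r < len(row) else 0 for c in range(length)]
            for r, row in enumerate(rows)]


def input_at(stream, cycle):
    """Value on an input stream at `cycle`; zero once the stream has ended."""
    return stream[cycle] if cycle < len(stream) else 0


def simulate(a, b, b_zero=0):
    """Cycle-by-cycle model of an output-stationary array computing a . (b - b_zero)."""
    num_rows, num_inner, num_columns = len(a), len(b), len(b[0])
    left = skew(a)  # l0, l1, ...: the rows of a
    top = skew([[b[k][j] - b_zero for k in range(num_inner)]  # t0, t1, ...:
                for j in range(num_columns)])                 # the columns of b
    h = [[0] * num_columns for _ in range(num_rows)]    # reg_i_j output to the right
    v = [[0] * num_columns for _ in range(num_rows)]    # reg_i_j output downward
    acc = [[0] * num_columns for _ in range(num_rows)]  # pe_i_j accumulators
    history = []  # pe_0_0's accumulator after each cycle
    for cycle in range(num_rows + num_inner + num_columns):
        # Each pe_i_j multiplies the two values held in its tile's register.
        for i in range(num_rows):
            for j in range(num_columns):
                acc[i][j] += h[i][j] * v[i][j]
        history.append(acc[0][0])
        # All registers then capture from their left/top neighbor (or an input),
        # using the old register values on the right-hand side.
        h = [[input_at(left[i], cycle) if j == 0 else h[i][j - 1]
              for j in range(num_columns)] for i in range(num_rows)]
        v = [[input_at(top[j], cycle) if i == 0 else v[i - 1][j]
              for j in range(num_columns)] for i in range(num_rows)]
    return acc, left, top, history


a = [[1, 2, 3], [4, 5, 6]]
b = [[7, 8, 9, 10], [11, 12, 13, 14], [15, 16, 17, 18]]
output, left, top, history = simulate(a, b)
print(output)       # [[74, 80, 86, 92], [173, 188, 203, 218]]
print(left)         # [[1, 2, 3, 0], [0, 4, 5, 6]]
print(history[:4])  # [0, 7, 29, 74]
```

The `left` and `top` streams match the two cycle tables, and `pe_0_0` accumulates 7, 29, then 74, as in the `mm0.output[0][0]` trace.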
- pyrtlnet.pyrtl_matrix.minimum_bitwidth(a)
Return the minimum number of bits needed to represent each element in `a`.
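The following plain-Python sketch shows one way to compute such a width. It assumes, though the entry does not say, that the result is a single width large enough for every element. It also treats elements as signed two's complement unless `signed=False`. The function name is hypothetical.

```python
def min_bitwidth_reference(values, signed=True):
    """Smallest width that can represent every element of a 2D list of ints."""
    def bits(x):
        if signed:
            # One sign bit plus the magnitude bits (of ~x for negative x).
            return (x if x >= 0 else ~x).bit_length() + 1
        if x < 0:
            raise ValueError("negative values need signed=True")
        return max(x.bit_length(), 1)
    return max(bits(x) for row in values for x in row)


print(min_bitwidth_reference([[127, -128], [0, 5]]))     # 8
print(min_bitwidth_reference([[255, 0]], signed=False))  # 8
```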
- pyrtlnet.pyrtl_matrix.num_systolic_array_cycles(a_shape, b_shape)
Return the number of cycles needed to multiply matrices with shapes `a_shape` and `b_shape` with the systolic array.
When using `make_systolic_array()` with a MemBlock as input, this function is useful for calculating the MemBlock's `addrwidth`.
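The following sketch assumes that `num_systolic_array_cycles()` returns the `(num_rows + num_inner + num_columns)` count given in Systolic Array Operation. Under that assumption, it computes the cycle count and a matching `addrwidth`, meaning enough address bits to give each cycle its own MemBlock entry. Both helper names are hypothetical.

```python
def systolic_cycles(a_shape, b_shape):
    """(num_rows + num_inner + num_columns), as stated in Systolic Array Operation."""
    num_rows, num_inner = a_shape
    b_inner, num_columns = b_shape
    if num_inner != b_inner:
        raise ValueError("a's columns must match b's rows")
    return num_rows + num_inner + num_columns


def addrwidth_for(num_entries):
    """Address bits needed for addresses 0 .. num_entries - 1 (at least 1 bit)."""
    return max(1, (num_entries - 1).bit_length())


cycles = systolic_cycles((2, 3), (3, 4))
print(cycles, addrwidth_for(cycles))  # 9 4
```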
- pyrtlnet.pyrtl_matrix.saturating_truncate(value, bitwidth)
Truncate a signed `value` to `bitwidth`, saturating at the largest and smallest representable values.
If `value` is too large to fit in `bitwidth` (overflow), the output WireVector will have the value `2 ** (bitwidth - 1) - 1`.
If `value` is too small to fit in `bitwidth` (underflow), the output WireVector will have the value `-2 ** (bitwidth - 1)`.
Otherwise, the output WireVector will have the same value as `value`.
- Parameters:
  - value (WireVector) – Value to truncate.
  - bitwidth (int) – Bitwidth to truncate `value` to. Must be less than `value.bitwidth`.
- Return type:
  WireVector
- Returns:
  `value` truncated to `bitwidth`, saturating at the largest and smallest representable values if overflow or underflow occurs.
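The saturation rule above can be checked against a plain-integer model. This sketch clamps a Python `int` rather than a WireVector, and the helper names are hypothetical. It also prints the resulting `bitwidth`-bit two's complement pattern.

```python
def saturating_truncate_int(value, bitwidth):
    """Clamp a signed int to the range of a bitwidth-bit two's complement number."""
    hi = 2 ** (bitwidth - 1) - 1  # largest representable value (overflow result)
    lo = -2 ** (bitwidth - 1)     # smallest representable value (underflow result)
    return min(max(value, lo), hi)


def twos_complement_bits(value, bitwidth):
    """The bitwidth-bit pattern that encodes a signed value."""
    return value & ((1 << bitwidth) - 1)


for v in (300, -300, -5):
    t = saturating_truncate_int(v, 8)
    print(v, t, format(twos_complement_bits(t, 8), "08b"))
# 300 127 01111111
# -300 -128 10000000
# -5 -5 11111011
```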
WireMatrix2D
- class pyrtlnet.wire_matrix_2d.WireMatrix2D(values, shape=(), bitwidth=0, name='', ready=None, valid=None)
WireMatrix2D represents a 2D matrix of WireVector. WireMatrix2D functions like a 2D `wire_matrix()`, with a NumPy-style `shape` tuple and `ready`/`valid` signals. It serves as the input and output type for all operations in the PyRTL Matrix Library. These matrix operations can be composed. For example, when computing `x ⋅ y + a`, there is an intermediate WireMatrix2D that serves as both the output of the multiplication `x ⋅ y` and the input to the addition `_ + a`.
WireMatrix2D supports two underlying representations:
- `self.Matrix`, which is a 2D `wire_matrix()`. `wire_matrix` supports any PyRTL WireVector type, so you could have a `self.Matrix` of Register, for example. This representation is used when the WireMatrix2D is constructed without a MemBlock.
- MemBlock, where the matrix data is stored in a MemBlock or RomBlock. This representation is currently experimental and not completely supported. This representation is used when the WireMatrix2D is constructed with a MemBlock.
ready/valid protocol
WireMatrix2D serves as a shared buffer between an upstream producer that writes data into the WireMatrix2D and a downstream consumer that reads data from the WireMatrix2D. The producer and consumer must coordinate their usage to avoid corrupting this shared resource. For example:
- While the producer is writing data to the WireMatrix2D, it is not safe for the consumer to read data from the WireMatrix2D.
- While the consumer is reading data from the WireMatrix2D, it is not safe for the producer to write new data into the WireMatrix2D.
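As a software analogy for the two hazards above, the toy class below models a one-slot shared buffer. In this model, a `ready` flag means "empty, safe to write" and a `valid` flag means "full, safe to read". It is a hypothetical sketch. WireMatrix2D does not implement this logic. In hardware, the producer and consumer would drive and check these 1-bit signals themselves.

```python
class OneSlotBuffer:
    """Toy shared buffer: ready means empty (safe to write), valid means full (safe to read)."""

    def __init__(self):
        self.data = None
        self.full = False

    @property
    def ready(self):
        return not self.full

    @property
    def valid(self):
        return self.full

    def write(self, data):
        if not self.ready:
            raise RuntimeError("producer would overwrite data the consumer has not read")
        self.data, self.full = data, True

    def read(self):
        if not self.valid:
            raise RuntimeError("consumer would read data the producer has not finished")
        data, self.data, self.full = self.data, None, False
        return data


buf = OneSlotBuffer()
buf.write([[1, 2], [3, 4]])  # producer writes only while ready is set
print(buf.ready, buf.valid)  # False True
print(buf.read())            # [[1, 2], [3, 4]]
print(buf.ready, buf.valid)  # True False
```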
WireMatrix2D provides `ready` and `valid` signals to help the producer and consumer coordinate. `ready` indicates when it is safe for the producer to write new data to the WireMatrix2D, and `valid` indicates when it is safe for the consumer to read data from the WireMatrix2D.
Note
`ready` and `valid` are undriven WireVectors provided as a convenience. WireMatrix2D itself does not assign any values to these wires or inspect the values on these wires. The producer and consumer must set and check these signals appropriately.
- __getitem__(row)
Implements WireMatrix2D's `[]` operator.
If this WireMatrix2D was not constructed with a MemBlock, its elements can be accessed with `self[row][column]`, which returns a WireVector with bitwidth `self.bitwidth`. This method only implements row-level indexing, and returns a `wire_matrix()`. Column-level indexing is implemented by the returned `wire_matrix()`.
Warning
If this WireMatrix2D was constructed with a MemBlock, this method can currently only retrieve a full row of values as `matrix[row]`. Per-element access is currently not supported.
- Parameters:
  - row (WireVector) – Row number to retrieve from the matrix.
- Return type:
  WireVector
- Returns:
  A WireVector containing all the data in the row, concatenated together. If this WireMatrix2D was not constructed with a MemBlock, the returned WireVector is actually a `wire_matrix()`, which can be further indexed with its `__getitem__` operator to retrieve data in a specific column.
- __init__(values, shape=(), bitwidth=0, name='', ready=None, valid=None)
Construct a 2D `wire_matrix()` containing `values`.
- Parameters:
  - values (ndarray | list[list[WireVector]] | WireVector | MemBlock) – Values for the WireMatrix2D. If `None`, creates a WireMatrix2D of Input. `values` can also be an ndarray, a list of lists of WireVector, one large concatenated WireVector containing all the values for the matrix, or a MemBlock.
  - shape (tuple[int, int], default: ()) – Shape of the WireMatrix2D. Must be two-dimensional. If `values` is an ndarray, the shape will be inferred from the ndarray, and this `shape` argument can be omitted.
  - bitwidth (int, default: 0) – The bitwidth of each element.
  - name (str, default: '') – Names for all elements in the WireMatrix2D will be generated based on this prefix. For example, if `name` is `foo`, then the top left element will be named `foo[0][0]`.
  - ready (bool | WireVector, default: None) – A 1-bit signal indicating if the WireMatrix2D can be safely written by the upstream producer.
  - valid (bool | WireVector, default: None) – A 1-bit signal indicating if the WireMatrix2D can be safely read by the downstream consumer.
- inspect(sim)
Collect and return Output values from a Simulation.
Retrieves Output values for `self` from a Simulation, and returns the retrieved values in an ndarray.
Use `make_outputs()` to create the retrieved Output values.
- Parameters:
  - sim (Simulation) – Simulation to retrieve values from.
- Return type:
  ndarray
- Returns:
  Retrieved values as an ndarray.
- make_provided_inputs(values)
Create `provided_inputs` for Simulation.
This should only be used with a WireMatrix2D of Input. This WireMatrix2D should have been constructed with `values=None`.
- ready: WireVector
A 1-bit signal indicating if the WireMatrix2D can be safely written by the upstream producer.
See ready/valid protocol.
- transpose()
Return a transposed version of `self`, as another WireMatrix2D.
Warning
If `self.memblock` is not `None`, this does not reformat the MemBlock's data. It only changes the shape. The MemBlock is assumed to already contain transposed data.
- Return type:
  WireMatrix2D
- Returns:
  A transposed version of `self`.
- valid: WireVector
A 1-bit signal indicating if the WireMatrix2D can be safely read by the downstream consumer.
See ready/valid protocol.