Skip to content

Commit

Permalink
gemm: Add --section flag and burst size alignment
Browse files Browse the repository at this point in the history
  • Loading branch information
colluca committed Oct 24, 2023
1 parent 082dd6a commit dd1cfce
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
6 changes: 4 additions & 2 deletions sw/blas/gemm/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,18 @@ MK_DIR := $(dir $(realpath $(lastword $(MAKEFILE_LIST))))
DATA_DIR := $(realpath $(MK_DIR)/data)
SRC_DIR := $(realpath $(MK_DIR)/src)

DATA_CFG ?= $(DATA_DIR)/params.hjson
SECTION ?=

APP ?= gemm
SRCS ?= $(realpath $(SRC_DIR)/main.c)
INCDIRS ?= $(DATA_DIR) $(SRC_DIR)

DATA_CFG ?= $(DATA_DIR)/params.hjson
DATAGEN_PY = $(DATA_DIR)/datagen.py
DATA_H = $(DATA_DIR)/data.h

$(DATA_H): $(DATAGEN_PY) $(DATA_CFG)
$< -c $(DATA_CFG) > $@
$< -c $(DATA_CFG) --section="$(SECTION)" > $@

.PHONY: clean-data clean

Expand Down
18 changes: 15 additions & 3 deletions sw/blas/gemm/data/datagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@
'fp8alt': {'exp': 4, 'mant': 3}
}

# AXI splits bursts crossing 4KB address boundaries. To minimize
# the occurrence of these splits the data should be aligned to 4KB
BURST_ALIGNMENT = 4096


def golden_model(alpha, a, b, beta, c):
return alpha * np.matmul(a, b) + beta * c
Expand Down Expand Up @@ -96,9 +100,12 @@ def emit_header(**kwargs):
data_str += [format_scalar_definition('uint32_t', 'BETA', kwargs['beta'])]
data_str += [format_scalar_definition('uint32_t', 'dtype_size', kwargs['prec']//8)]
data_str += [format_scalar_definition('uint32_t', 'expand', kwargs['expand'])]
data_str += [format_vector_definition(C_TYPES[str(kwargs['prec'])], 'a', a.flatten())]
data_str += [format_vector_definition(C_TYPES[str(kwargs['prec'])], 'b', b.flatten())]
data_str += [format_vector_definition(C_TYPES[str(kwargs['prec'])], 'c', c.flatten())]
data_str += [format_vector_definition(C_TYPES[str(kwargs['prec'])], 'a', a.flatten(),
alignment=BURST_ALIGNMENT, section=kwargs['section'])]
data_str += [format_vector_definition(C_TYPES[str(kwargs['prec'])], 'b', b.flatten(),
alignment=BURST_ALIGNMENT, section=kwargs['section'])]
data_str += [format_vector_definition(C_TYPES[str(kwargs['prec'])], 'c', c.flatten(),
alignment=BURST_ALIGNMENT, section=kwargs['section'])]
if kwargs['prec'] == 8:
result_def = format_vector_definition(C_TYPES['64'], 'result', result.flatten())
else:
Expand All @@ -120,11 +127,16 @@ def main():
required=True,
help='Select param config file kernel'
)
parser.add_argument(
'--section',
type=str,
help='Section to store matrices in')
args = parser.parse_args()

# Load param config file
with args.cfg.open() as f:
param = hjson.loads(f.read())
param['section'] = args.section

# Emit header file
print(emit_header(**param))
Expand Down

0 comments on commit dd1cfce

Please sign in to comment.