openblas/openblas-0.3.26-incompatibletypes.patch

753 lines
20 KiB
Diff
Raw Permalink Normal View History

2024-06-27 11:33:30 +00:00
This patch is a compilation of six upstream commits (PATCH 1/6 through 6/6 below) related to:
https://github.com/OpenMathLib/OpenBLAS/issues/4475
From 63004fa5f76ef1058975271314bc4591e7878726 Mon Sep 17 00:00:00 2001
From: Honza Horak <hhorak@redhat.com>
Date: Fri, 9 Feb 2024 09:49:41 +0100
Subject: [PATCH 1/6] Fix incompatible pointer type in BFLOAT16 mode
Upstream commit:
commit 68d354814f9f846338e1988c4f609c8add419012
Author: Martin Kroeker <martin@ruby.chemie.uni-freiburg.de>
Date: Sun Feb 4 01:14:22 2024 +0100
Fix incompatible pointer type in BFLOAT16 mode
---
interface/gemmt.c | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/interface/gemmt.c b/interface/gemmt.c
index 046432670..2fb9954ad 100644
--- a/interface/gemmt.c
+++ b/interface/gemmt.c
@@ -478,7 +478,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
#endif
// for alignment
buffer_size = (buffer_size + 3) & ~3;
- STACK_ALLOC(buffer_size, FLOAT, buffer);
+ STACK_ALLOC(buffer_size, IFLOAT, buffer);
#ifdef SMP
@@ -567,7 +567,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
#endif
// for alignment
buffer_size = (buffer_size + 3) & ~3;
- STACK_ALLOC(buffer_size, FLOAT, buffer);
+ STACK_ALLOC(buffer_size, IFLOAT, buffer);
#ifdef SMP
--
2.41.0
From edfd4f52f3f22344863c233411ae792fb12aa81b Mon Sep 17 00:00:00 2001
From: Honza Horak <hhorak@redhat.com>
Date: Fri, 9 Feb 2024 09:53:40 +0100
Subject: [PATCH 2/6] Separate the interface for SBGEMMT from GEMMT due to
differences in GEMV arguments
Upstream commit:
commit d4db6a9f16a5c82bbe1860f591cc731c4d83d7c8
Author: Martin Kroeker <martin@ruby.chemie.uni-freiburg.de>
Date: Tue Feb 6 22:23:47 2024 +0100
Separate the interface for SBGEMMT from GEMMT due to differences in GEMV arguments
---
interface/CMakeLists.txt | 1 +
interface/Makefile | 4 +-
interface/sbgemmt.c | 447 +++++++++++++++++++++++++++++++++++++++
3 files changed, 450 insertions(+), 2 deletions(-)
create mode 100644 interface/sbgemmt.c
diff --git a/interface/CMakeLists.txt b/interface/CMakeLists.txt
index 4e082928b..3110f2e90 100644
--- a/interface/CMakeLists.txt
+++ b/interface/CMakeLists.txt
@@ -119,6 +119,7 @@ endif ()
if (BUILD_BFLOAT16)
GenerateNamedObjects("bf16dot.c" "" "sbdot" ${CBLAS_FLAG} "" "" true "BFLOAT16")
GenerateNamedObjects("gemm.c" "" "sbgemm" ${CBLAS_FLAG} "" "" true "BFLOAT16")
+ GenerateNamedObjects("gemmt.c" "" "sbgemmt" ${CBLAS_FLAG} "" "" true "BFLOAT16")
GenerateNamedObjects("sbgemv.c" "" "sbgemv" ${CBLAS_FLAG} "" "" true "BFLOAT16")
GenerateNamedObjects("tobf16.c" "SINGLE_PREC" "sbstobf16" ${CBLAS_FLAG} "" "" true "BFLOAT16")
GenerateNamedObjects("tobf16.c" "DOUBLE_PREC" "sbdtobf16" ${CBLAS_FLAG} "" "" true "BFLOAT16")
diff --git a/interface/Makefile b/interface/Makefile
index 78335357b..d106ca568 100644
--- a/interface/Makefile
+++ b/interface/Makefile
@@ -1301,7 +1301,7 @@ xhpr2.$(SUFFIX) xhpr2.$(PSUFFIX) : zhpr2.c
ifeq ($(BUILD_BFLOAT16),1)
sbgemm.$(SUFFIX) sbgemm.$(PSUFFIX) : gemm.c ../param.h
$(CC) -c $(CFLAGS) $< -o $(@F)
-sbgemmt.$(SUFFIX) sbgemmt.$(PSUFFIX) : gemmt.c ../param.h
+sbgemmt.$(SUFFIX) sbgemmt.$(PSUFFIX) : sbgemmt.c ../param.h
$(CC) -c $(CFLAGS) $< -o $(@F)
endif
@@ -1932,7 +1932,7 @@ cblas_sgemmt.$(SUFFIX) cblas_sgemmt.$(PSUFFIX) : gemmt.c ../param.h
$(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F)
ifeq ($(BUILD_BFLOAT16),1)
-cblas_sbgemmt.$(SUFFIX) cblas_sbgemmt.$(PSUFFIX) : gemmt.c ../param.h
+cblas_sbgemmt.$(SUFFIX) cblas_sbgemmt.$(PSUFFIX) : sbgemmt.c ../param.h
$(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F)
endif
diff --git a/interface/sbgemmt.c b/interface/sbgemmt.c
new file mode 100644
index 000000000..759af4bfb
--- /dev/null
+++ b/interface/sbgemmt.c
@@ -0,0 +1,447 @@
+/*********************************************************************/
+/* Copyright 2024, The OpenBLAS Project. */
+/* All rights reserved. */
+/* */
+/* Redistribution and use in source and binary forms, with or */
+/* without modification, are permitted provided that the following */
+/* conditions are met: */
+/* */
+/* 1. Redistributions of source code must retain the above */
+/* copyright notice, this list of conditions and the following */
+/* disclaimer. */
+/* */
+/* 2. Redistributions in binary form must reproduce the above */
+/* copyright notice, this list of conditions and the following */
+/* disclaimer in the documentation and/or other materials */
+/* provided with the distribution. */
+/* */
+/* THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF TEXAS AT */
+/* AUSTIN ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, */
+/* INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF */
+/* MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE */
+/* DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY OF TEXAS AT */
+/* AUSTIN OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, */
+/* INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES */
+/* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE */
+/* GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR */
+/* BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF */
+/* LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT */
+/* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT */
+/* OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE */
+/* POSSIBILITY OF SUCH DAMAGE. */
+/* */
+/*********************************************************************/
+
+#include <stdio.h>
+#include <stdlib.h>
+#include "common.h"
+
+#define SMP_THRESHOLD_MIN 65536.0
+#define ERROR_NAME "SBGEMMT "
+
+#ifndef GEMM_MULTITHREAD_THRESHOLD
+#define GEMM_MULTITHREAD_THRESHOLD 4
+#endif
+
+#ifndef CBLAS
+
+void NAME(char *UPLO, char *TRANSA, char *TRANSB,
+ blasint * M, blasint * K,
+ FLOAT * Alpha,
+ IFLOAT * a, blasint * ldA,
+ IFLOAT * b, blasint * ldB, FLOAT * Beta, FLOAT * c, blasint * ldC)
+{
+
+ blasint m, k;
+ blasint lda, ldb, ldc;
+ int transa, transb, uplo;
+ blasint info;
+
+ char transA, transB, Uplo;
+ blasint nrowa, nrowb;
+ IFLOAT *buffer;
+ IFLOAT *aa, *bb;
+ FLOAT *cc;
+ FLOAT alpha, beta;
+
+ PRINT_DEBUG_NAME;
+
+ m = *M;
+ k = *K;
+
+ alpha = *Alpha;
+ beta = *Beta;
+
+ lda = *ldA;
+ ldb = *ldB;
+ ldc = *ldC;
+
+ transA = *TRANSA;
+ transB = *TRANSB;
+ Uplo = *UPLO;
+ TOUPPER(transA);
+ TOUPPER(transB);
+ TOUPPER(Uplo);
+
+ transa = -1;
+ transb = -1;
+ uplo = -1;
+
+ if (transA == 'N')
+ transa = 0;
+ if (transA == 'T')
+ transa = 1;
+
+ if (transA == 'R')
+ transa = 0;
+ if (transA == 'C')
+ transa = 1;
+
+ if (transB == 'N')
+ transb = 0;
+ if (transB == 'T')
+ transb = 1;
+
+ if (transB == 'R')
+ transb = 0;
+ if (transB == 'C')
+ transb = 1;
+
+ if (Uplo == 'U')
+ uplo = 0;
+ if (Uplo == 'L')
+ uplo = 1;
+ nrowa = m;
+ if (transa & 1) nrowa = k;
+ nrowb = k;
+ if (transb & 1) nrowb = m;
+
+ info = 0;
+
+ if (ldc < MAX(1, m))
+ info = 13;
+ if (ldb < MAX(1, nrowb))
+ info = 10;
+ if (lda < MAX(1, nrowa))
+ info = 8;
+ if (k < 0)
+ info = 5;
+ if (m < 0)
+ info = 4;
+ if (transb < 0)
+ info = 3;
+ if (transa < 0)
+ info = 2;
+ if (uplo < 0)
+ info = 1;
+
+ if (info != 0) {
+ BLASFUNC(xerbla) (ERROR_NAME, &info, sizeof(ERROR_NAME));
+ return;
+ }
+#else
+
+void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
+ enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB, blasint m,
+ blasint k,
+ FLOAT alpha,
+ IFLOAT * A, blasint LDA,
+ IFLOAT * B, blasint LDB, FLOAT beta, FLOAT * c, blasint ldc)
+{
+ IFLOAT *aa, *bb;
+ FLOAT *cc;
+
+ int transa, transb, uplo;
+ blasint info;
+ blasint lda, ldb;
+ IFLOAT *a, *b;
+ XFLOAT *buffer;
+
+ PRINT_DEBUG_CNAME;
+
+ uplo = -1;
+ transa = -1;
+ transb = -1;
+ info = 0;
+
+ if (order == CblasColMajor) {
+ if (Uplo == CblasUpper) uplo = 0;
+ if (Uplo == CblasLower) uplo = 1;
+
+ if (TransA == CblasNoTrans)
+ transa = 0;
+ if (TransA == CblasTrans)
+ transa = 1;
+
+ if (TransA == CblasConjNoTrans)
+ transa = 0;
+ if (TransA == CblasConjTrans)
+ transa = 1;
+
+ if (TransB == CblasNoTrans)
+ transb = 0;
+ if (TransB == CblasTrans)
+ transb = 1;
+
+ if (TransB == CblasConjNoTrans)
+ transb = 0;
+ if (TransB == CblasConjTrans)
+ transb = 1;
+
+ a = (void *)A;
+ b = (void *)B;
+ lda = LDA;
+ ldb = LDB;
+
+ info = -1;
+
+ blasint nrowa;
+ blasint nrowb;
+ nrowa = m;
+ if (transa & 1) nrowa = k;
+ nrowb = k;
+ if (transb & 1) nrowb = m;
+
+ if (ldc < MAX(1, m))
+ info = 13;
+ if (ldb < MAX(1, nrowb))
+ info = 10;
+ if (lda < MAX(1, nrowa))
+ info = 8;
+ if (k < 0)
+ info = 5;
+ if (m < 0)
+ info = 4;
+ if (transb < 0)
+ info = 3;
+ if (transa < 0)
+ info = 2;
+ if (uplo < 0)
+ info = 1;
+ }
+
+ if (order == CblasRowMajor) {
+
+ a = (void *)B;
+ b = (void *)A;
+
+ lda = LDB;
+ ldb = LDA;
+
+ if (Uplo == CblasUpper) uplo = 0;
+ if (Uplo == CblasLower) uplo = 1;
+
+ if (TransB == CblasNoTrans)
+ transa = 0;
+ if (TransB == CblasTrans)
+ transa = 1;
+
+ if (TransB == CblasConjNoTrans)
+ transa = 0;
+ if (TransB == CblasConjTrans)
+ transa = 1;
+
+ if (TransA == CblasNoTrans)
+ transb = 0;
+ if (TransA == CblasTrans)
+ transb = 1;
+
+ if (TransA == CblasConjNoTrans)
+ transb = 0;
+ if (TransA == CblasConjTrans)
+ transb = 1;
+
+ info = -1;
+
+ blasint ncola;
+ blasint ncolb;
+
+ ncola = m;
+ if (transa & 1) ncola = k;
+ ncolb = k;
+
+ if (transb & 1) {
+ ncolb = m;
+ }
+
+ if (ldc < MAX(1,m))
+ info = 13;
+ if (ldb < MAX(1, ncolb))
+ info = 8;
+ if (lda < MAX(1, ncola))
+ info = 10;
+ if (k < 0)
+ info = 5;
+ if (m < 0)
+ info = 4;
+ if (transb < 0)
+ info = 2;
+ if (transa < 0)
+ info = 3;
+ if (uplo < 0)
+ info = 1;
+ }
+
+ if (info >= 0) {
+ BLASFUNC(xerbla) (ERROR_NAME, &info, sizeof(ERROR_NAME));
+ return;
+ }
+
+#endif
+ int buffer_size;
+ blasint i, j;
+
+#ifdef SMP
+ int nthreads;
+#endif
+
+
+#ifdef SMP
+ static int (*gemv_thread[]) (BLASLONG, BLASLONG, FLOAT, IFLOAT *,
+ BLASLONG, IFLOAT *, BLASLONG, FLOAT,
+ FLOAT *, BLASLONG, int) = {
+ sbgemv_thread_n, sbgemv_thread_t,
+ };
+#endif
+ int (*gemv[]) (BLASLONG, BLASLONG, FLOAT, IFLOAT *, BLASLONG,
+ IFLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG) = {
+ SBGEMV_N, SBGEMV_T,};
+
+
+ if (m == 0)
+ return;
+
+ IDEBUG_START;
+
+ const blasint incb = ((transb & 1) == 0) ? 1 : ldb;
+
+ if (uplo == 1) {
+ for (i = 0; i < m; i++) {
+ j = m - i;
+
+ aa = a + i;
+ bb = b + i * ldb;
+ if (transa & 1) {
+ aa = a + lda * i;
+ }
+ if (transb & 1)
+ bb = b + i;
+ cc = c + i * ldc + i;
+
+#if 0
+ if (beta != ONE)
+ SCAL_K(l, 0, 0, beta, cc, 1, NULL, 0, NULL, 0);
+
+ if (alpha == ZERO)
+ continue;
+#endif
+
+ IDEBUG_START;
+
+ buffer_size = j + k + 128 / sizeof(FLOAT);
+#ifdef WINDOWS_ABI
+ buffer_size += 160 / sizeof(FLOAT);
+#endif
+ // for alignment
+ buffer_size = (buffer_size + 3) & ~3;
+ STACK_ALLOC(buffer_size, IFLOAT, buffer);
+
+#ifdef SMP
+
+ if (1L * j * k < 2304L * GEMM_MULTITHREAD_THRESHOLD)
+ nthreads = 1;
+ else
+ nthreads = num_cpu_avail(2);
+
+ if (nthreads == 1) {
+#endif
+
+ if (!(transa & 1))
+ (gemv[(int)transa]) (j, k, alpha, aa, lda,
+ bb, incb, beta, cc, 1);
+ else
+ (gemv[(int)transa]) (k, j, alpha, aa, lda,
+ bb, incb, beta, cc, 1);
+
+#ifdef SMP
+ } else {
+ if (!(transa & 1))
+ (gemv_thread[(int)transa]) (j, k, alpha, aa,
+ lda, bb, incb, beta, cc,
+ 1, nthreads);
+ else
+ (gemv_thread[(int)transa]) (k, j, alpha, aa,
+ lda, bb, incb, beta, cc,
+ 1, nthreads);
+
+ }
+#endif
+
+ STACK_FREE(buffer);
+ }
+ } else {
+
+ for (i = 0; i < m; i++) {
+ j = i + 1;
+
+ bb = b + i * ldb;
+ if (transb & 1) {
+ bb = b + i;
+ }
+ cc = c + i * ldc;
+
+#if 0
+ if (beta != ONE)
+ SCAL_K(l, 0, 0, beta, cc, 1, NULL, 0, NULL, 0);
+
+ if (alpha == ZERO)
+ continue;
+#endif
+ IDEBUG_START;
+
+ buffer_size = j + k + 128 / sizeof(FLOAT);
+#ifdef WINDOWS_ABI
+ buffer_size += 160 / sizeof(FLOAT);
+#endif
+ // for alignment
+ buffer_size = (buffer_size + 3) & ~3;
+ STACK_ALLOC(buffer_size, IFLOAT, buffer);
+
+#ifdef SMP
+
+ if (1L * j * k < 2304L * GEMM_MULTITHREAD_THRESHOLD)
+ nthreads = 1;
+ else
+ nthreads = num_cpu_avail(2);
+
+ if (nthreads == 1) {
+#endif
+
+ if (!(transa & 1))
+ (gemv[(int)transa]) (j, k, alpha, a, lda, bb,
+ incb, beta, cc, 1);
+ else
+ (gemv[(int)transa]) (k, j, alpha, a, lda, bb,
+ incb, beta, cc, 1);
+
+#ifdef SMP
+ } else {
+ if (!(transa & 1))
+ (gemv_thread[(int)transa]) (j, k, alpha, a, lda,
+ bb, incb, beta, cc, 1,
+ nthreads);
+ else
+ (gemv_thread[(int)transa]) (k, j, alpha, a, lda,
+ bb, incb, beta, cc, 1,
+ nthreads);
+ }
+#endif
+
+ STACK_FREE(buffer);
+ }
+ }
+
+ IDEBUG_END;
+
+ return;
+}
--
2.41.0
From 9a4c2d61a345866e4540f9d6da87eb881419b411 Mon Sep 17 00:00:00 2001
From: Honza Horak <hhorak@redhat.com>
Date: Fri, 9 Feb 2024 09:54:52 +0100
Subject: [PATCH 3/6] fix type conversion warnings
Upstream commit:
commit fb99fc2e6e4ec8ecdcfffe1ca1aeb787464d2825
Author: Martin Kroeker <martin@ruby.chemie.uni-freiburg.de>
Date: Wed Feb 7 13:42:08 2024 +0100
fix type conversion warnings
---
test/compare_sgemm_sbgemm.c | 18 ++++++++++++++----
1 file changed, 14 insertions(+), 4 deletions(-)
diff --git a/test/compare_sgemm_sbgemm.c b/test/compare_sgemm_sbgemm.c
index cf808b56d..4afa8bf93 100644
--- a/test/compare_sgemm_sbgemm.c
+++ b/test/compare_sgemm_sbgemm.c
@@ -81,6 +81,16 @@ float16to32 (bfloat16_bits f16)
return f32.v;
}
+float
+float32to16 (float32_bits f32)
+{
+ bfloat16_bits f16;
+ f16.bits.s = f32.bits.s;
+ f16.bits.e = f32.bits.e;
+ f16.bits.m = (uint32_t) f32.bits.m >> 16;
+ return f32.v;
+}
+
int
main (int argc, char *argv[])
{
@@ -108,16 +118,16 @@ main (int argc, char *argv[])
A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
C[j * k + i] = 0;
- AA[j * k + i].v = *(uint32_t *) & A[j * k + i] >> 16;
- BB[j * k + i].v = *(uint32_t *) & B[j * k + i] >> 16;
+ AA[j * k + i].v = float32to16( A[j * k + i] );
+ BB[j * k + i].v = float32to16( B[j * k + i] );
CC[j * k + i] = 0;
DD[j * k + i] = 0;
}
}
SGEMM (&transA, &transB, &m, &n, &k, &alpha, A,
&m, B, &k, &beta, C, &m);
- SBGEMM (&transA, &transB, &m, &n, &k, &alpha, AA,
- &m, BB, &k, &beta, CC, &m);
+ SBGEMM (&transA, &transB, &m, &n, &k, &alpha, (bfloat16*) AA,
+ &m, (bfloat16*)BB, &k, &beta, CC, &m);
for (i = 0; i < n; i++)
for (j = 0; j < m; j++)
if (fabs (CC[i * m + j] - C[i * m + j]) > 1.0)
--
2.41.0
From 5593507ddbd5d35d088cd4db6285de6b9d84a405 Mon Sep 17 00:00:00 2001
From: Honza Horak <hhorak@redhat.com>
Date: Fri, 9 Feb 2024 09:56:11 +0100
Subject: [PATCH 4/6] fix prototype for c/zaxpby
Upstream commit:
commit b3fa16345d83b723b8984b78dc6a2bb5d9f3d479
Author: Martin Kroeker <martin@ruby.chemie.uni-freiburg.de>
Date: Thu Feb 8 13:15:34 2024 +0100
fix prototype for c/zaxpby
---
common_interface.h | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/common_interface.h b/common_interface.h
index 318827920..1f6cb5f6d 100644
--- a/common_interface.h
+++ b/common_interface.h
@@ -764,8 +764,8 @@ xdouble BLASFUNC(qlamc3)(xdouble *, xdouble *);
void BLASFUNC(saxpby) (blasint *, float *, float *, blasint *, float *, float *, blasint *);
void BLASFUNC(daxpby) (blasint *, double *, double *, blasint *, double *, double *, blasint *);
-void BLASFUNC(caxpby) (blasint *, float *, float *, blasint *, float *, float *, blasint *);
-void BLASFUNC(zaxpby) (blasint *, double *, double *, blasint *, double *, double *, blasint *);
+void BLASFUNC(caxpby) (blasint *, void *, float *, blasint *, void *, float *, blasint *);
+void BLASFUNC(zaxpby) (blasint *, void *, double *, blasint *, void *, double *, blasint *);
void BLASFUNC(somatcopy) (char *, char *, blasint *, blasint *, float *, float *, blasint *, float *, blasint *);
void BLASFUNC(domatcopy) (char *, char *, blasint *, blasint *, double *, double *, blasint *, double *, blasint *);
--
2.41.0
From 42b30ed2c54034b2b1dbb15bb9e3e705e704b6a9 Mon Sep 17 00:00:00 2001
From: Honza Horak <hhorak@redhat.com>
Date: Fri, 9 Feb 2024 09:56:46 +0100
Subject: [PATCH 5/6] fix incompatible pointer types
Upstream commit:
commit 500ac4de5e20596d5cd797d745db97dd0a62ff86
Author: Martin Kroeker <martin@ruby.chemie.uni-freiburg.de>
Date: Thu Feb 8 13:18:34 2024 +0100
fix incompatible pointer types
---
interface/zaxpby.c | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/interface/zaxpby.c b/interface/zaxpby.c
index 3a4db7403..e5065270d 100644
--- a/interface/zaxpby.c
+++ b/interface/zaxpby.c
@@ -39,12 +39,14 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#ifndef CBLAS
-void NAME(blasint *N, FLOAT *ALPHA, FLOAT *x, blasint *INCX, FLOAT *BETA, FLOAT *y, blasint *INCY)
+void NAME(blasint *N, void *VALPHA, FLOAT *x, blasint *INCX, void *VBETA, FLOAT *y, blasint *INCY)
{
blasint n = *N;
blasint incx = *INCX;
blasint incy = *INCY;
+ FLOAT* ALPHA = (FLOAT*) VALPHA;
+ FLOAT* BETA = (FLOAT*) VBETA;
#else
--
2.41.0
From 1c525b6e704523912a04fbd026300a2ff95341f3 Mon Sep 17 00:00:00 2001
From: Honza Horak <hhorak@redhat.com>
Date: Fri, 9 Feb 2024 15:29:17 +0100
Subject: [PATCH 6/6] fix sbgemm bfloat16 conversion errors introduced in PR
4488
Upstream commit:
commit e9f480111e1d5b6f69c8053f79375b0a4242712f
Author: Martin Kroeker <martin@ruby.chemie.uni-freiburg.de>
Date: Wed Feb 7 19:57:18 2024 +0100
fix sbgemm bfloat16 conversion errors introduced in PR 4488
---
test/compare_sgemm_sbgemm.c | 18 ++++++------------
1 file changed, 6 insertions(+), 12 deletions(-)
diff --git a/test/compare_sgemm_sbgemm.c b/test/compare_sgemm_sbgemm.c
index 4afa8bf93..bc74233ab 100644
--- a/test/compare_sgemm_sbgemm.c
+++ b/test/compare_sgemm_sbgemm.c
@@ -81,16 +81,6 @@ float16to32 (bfloat16_bits f16)
return f32.v;
}
-float
-float32to16 (float32_bits f32)
-{
- bfloat16_bits f16;
- f16.bits.s = f32.bits.s;
- f16.bits.e = f32.bits.e;
- f16.bits.m = (uint32_t) f32.bits.m >> 16;
- return f32.v;
-}
-
int
main (int argc, char *argv[])
{
@@ -110,6 +100,8 @@ main (int argc, char *argv[])
float C[m * n];
bfloat16_bits AA[m * k], BB[k * n];
float DD[m * n], CC[m * n];
+ bfloat16 atmp,btmp;
+ blasint one=1;
for (j = 0; j < m; j++)
{
@@ -118,8 +110,10 @@ main (int argc, char *argv[])
A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
C[j * k + i] = 0;
- AA[j * k + i].v = float32to16( A[j * k + i] );
- BB[j * k + i].v = float32to16( B[j * k + i] );
+ sbstobf16_(&one, &A[j*k+i], &one, &atmp, &one);
+ sbstobf16_(&one, &B[j*k+i], &one, &btmp, &one);
+ AA[j * k + i].v = atmp;
+ BB[j * k + i].v = btmp;
CC[j * k + i] = 0;
DD[j * k + i] = 0;
}
--
2.41.0