about | summary | refs | log | tree | commit | diff
path: root/libgomp/testsuite/libgomp.oacc-fortran/gemm.f90
blob: de78148c7b36ca74a2d81313c46a3bd8a2fcbb11 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
! Exercise three levels of OpenACC parallelism using an SGEMM kernel modeled on BLAS.

! { dg-do run }

! Explicitly set vector_length to 128 using a vector_length clause.
!
! Computes C(1:m,1:n) = alpha * A**T * B + beta * C on the accelerator, where
! A is stored as a(k,m) and B as b(k,n): element (i,j) is the dot product of
! column i of a with column j of b.  When beta is exactly zero, C is
! overwritten without being read.
!
! The j loop is partitioned by the combined "parallel loop" construct, the i
! loop by the nested "loop" construct, and the l loop is a sum reduction.
subroutine openacc_sgemm_128 (m, n, k, alpha, a, b, beta, c)
  implicit none
  integer, intent(in) :: m, n, k
  real, intent(in) :: alpha, beta
  real, intent(in) :: a(k,*), b(k,*)
  real, intent(inout) :: c(m,*)

  integer :: i, j, l
  real :: temp

  !$acc parallel loop copy(c(1:m,1:n)) copyin(a(1:k,1:m),b(1:k,1:n)) vector_length (128) firstprivate (temp)
  do j = 1, n
     ! temp accumulates one (i,j) dot product.  Iterations of the i loop may
     ! run concurrently, so each needs its own copy; the inner reduction then
     ! combines into that iteration-private copy.
     !$acc loop private (temp)
     do i = 1, m
        temp = 0.0
        !$acc loop reduction(+:temp)
        do l = 1, k
           temp = temp + a(l,i)*b(l,j)
        end do
        ! With beta == 0 the old value of c(i,j) is ignored entirely.
        if(beta == 0.0) then
           c(i,j) = alpha*temp
        else
           c(i,j) = alpha*temp + beta*c(i,j)
        end if
     end do
  end do
end subroutine openacc_sgemm_128

! Serial host reference: C(1:m,1:n) = alpha * A**T * B + beta * C, where A is
! stored as a(k,m) and B as b(k,n), so element (i,j) is the dot product of
! column i of a with column j of b.  When beta is exactly zero, C is
! overwritten without being read.
subroutine host_sgemm (m, n, k, alpha, a, b, beta, c)
  implicit none
  integer, intent(in) :: m, n, k
  real, intent(in) :: alpha, beta
  real, intent(in) :: a(k,*), b(k,*)
  real, intent(inout) :: c(m,*)

  integer :: i, j, l
  real :: temp

  do j = 1, n
     do i = 1, m
        ! Accumulate the dot product in a fixed, sequential order.
        temp = 0.0
        do l = 1, k
           temp = temp + a(l,i)*b(l,j)
        end do
        ! With beta == 0 the old value of c(i,j) is ignored entirely.
        if(beta == 0.0) then
           c(i,j) = alpha*temp
        else
           c(i,j) = alpha*temp + beta*c(i,j)
        end if
     end do
  end do
end subroutine host_sgemm

! Run the OpenACC SGEMM kernel and the host reference on identical inputs and
! terminate with a non-zero exit status if their results disagree.
program main
  implicit none
  integer, parameter :: M = 100, N = 50, K = 2000
  ! Relative tolerance for the comparison: the device reduction may add the K
  ! products of each dot product in a different order than the host loop.
  ! K * epsilon is a rough bound on that reordering error for same-sign terms.
  real, parameter :: tol = real(K) * epsilon(1.0)
  real :: a(K, M), b(K, N), d(M, N), e(M, N)
  real :: alpha, beta

  a(:,:) = 1.0
  b(:,:) = 0.25

  d(:,:) = 0.0
  e(:,:) = 0.0

  alpha = 1.05
  beta = 1.25

  call openacc_sgemm_128 (M, N, K, alpha, a, b, beta, d)
  call host_sgemm (M, N, K, alpha, a, b, beta, e)

  ! Whole-array check; "error stop" replaces the non-standard "call abort".
  if (any(abs(d - e) > tol * abs(e))) error stop 1
end program main