aboutsummaryrefslogtreecommitdiff
path: root/libgomp/testsuite/libgomp.oacc-fortran/gemm-2.f90
blob: fe108732a5fd5fe8f82dbae684b173311410b532 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
! Exercise three levels of OpenACC parallelism (gang, worker, vector) using an
! SGEMM-style kernel, C := alpha*A**T*B + beta*C, modeled on the BLAS routine.

! { dg-do run }
! { dg-additional-options "-fopenacc-dim=::128" }

! No vector_length clause is given on the parallel construct; the vector
! length is set implicitly to 128 by the -fopenacc-dim=::128 option above.
subroutine openacc_sgemm (m, n, k, alpha, a, b, beta, c)
  ! Offloaded kernel computing C := alpha * A**T * B + beta * C, i.e. for
  ! i = 1..m and j = 1..n:
  !   c(i,j) = alpha * sum(a(l,i)*b(l,j), l = 1..k)        if beta == 0.0
  !   c(i,j) = alpha * sum(a(l,i)*b(l,j), l = 1..k)
  !            + beta * c(i,j)                             otherwise
  !
  ! m, n, k     - problem sizes: C is m x n, A is indexed a(1:k,1:m) and
  !               B is indexed b(1:k,1:n).
  ! alpha, beta - scalar multipliers; when beta == 0.0 the incoming values
  !               of C do not enter the result (c(i,j) is simply overwritten).
  ! a, b        - input matrices, assumed-size with leading dimension k.
  ! c           - updated in place, assumed-size with leading dimension m.
  integer :: m, n, k
  real :: alpha, beta
  real :: a(k,*), b(k,*), c(m,*)

  integer :: i, j, l
  real :: temp

  ! C is copied to the device and back; A and B are only copied in.  None of
  ! the three loops carries a gang/worker/vector clause, so how they map onto
  ! the levels of parallelism is left to the implementation.
  !$acc parallel loop copy(c(1:m,1:n)) copyin(a(1:k,1:m),b(1:k,1:n)) firstprivate (temp)
  do j = 1, n
     ! NOTE(review): temp is only firstprivate on the parallel construct; if
     ! this i loop is worker-partitioned, confirm each iteration gets its own
     ! temp (e.g. private(temp) here) - TODO confirm against the OpenACC spec.
     !$acc loop
     do i = 1, m
        ! Dot product of column i of A with column j of B, accumulated into
        ! temp through a "+" reduction over the innermost loop.
        temp = 0.0
        !$acc loop reduction(+:temp)
        do l = 1, k
           temp = temp + a(l,i)*b(l,j)
        end do
        ! Only read c(i,j) when beta is non-zero.
        if(beta == 0.0) then
           c(i,j) = alpha*temp
        else
           c(i,j) = alpha*temp + beta*c(i,j)
        end if
     end do
  end do
end subroutine openacc_sgemm

subroutine host_sgemm (m, n, k, alpha, a, b, beta, c)
  ! Serial host reference for C := alpha * A**T * B + beta * C.
  !
  ! m, n, k     - C is m x n; A is indexed a(1:k,1:m), B is indexed b(1:k,1:n).
  ! alpha, beta - scalar multipliers; a zero beta means the previous contents
  !               of C are ignored and each c(i,j) is overwritten.
  ! a, b        - inputs, assumed-size with leading dimension k.
  ! c           - result, updated in place (leading dimension m).
  implicit none
  integer, intent(in) :: m, n, k
  real, intent(in) :: alpha, beta
  real, intent(in) :: a(k,*), b(k,*)
  real, intent(inout) :: c(m,*)

  integer :: row, col, p
  real :: acc

  do col = 1, n
     do row = 1, m
        ! Plain serial summation in ascending p order.
        acc = 0.0
        do p = 1, k
           acc = acc + a(p,row)*b(p,col)
        end do
        if (beta /= 0.0) then
           c(row,col) = alpha*acc + beta*c(row,col)
        else
           c(row,col) = alpha*acc
        end if
     end do
  end do
end subroutine host_sgemm

program main
  ! Test driver: run the OpenACC kernel and the serial host reference on the
  ! same inputs and stop with a non-zero status if the results disagree.
  implicit none
  integer, parameter :: M = 100, N = 50, K = 2000
  real :: a(K, M), b(K, N), c(M, N), e(M, N)
  real :: alpha, beta, tol

  a(:,:) = 1.0
  b(:,:) = 0.25

  ! Device result (c) and host reference (e) start from identical contents;
  ! beta is non-zero, so the alpha*temp + beta*c(i,j) update path is used.
  c(:,:) = 0.0
  e(:,:) = 0.0

  alpha = 1.05
  beta = 1.25

  call openacc_sgemm (M, N, K, alpha, a, b, beta, c)
  call host_sgemm (M, N, K, alpha, a, b, beta, e)

  ! The device "+" reduction may sum the K products in a different order
  ! than the serial host loop, so compare with a relative tolerance scaled by
  ! the reduction length instead of requiring bit-identical results.
  tol = real(K) * epsilon(1.0)
  if (any(abs(c - e) > tol * max(1.0, abs(e)))) &
     error stop "gemm-2: device and host results differ"
end program main