/*******************************************************************************
* Copyright 2021-2022 Intel Corporation.
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
*  Content:
*      dgetrf_batch (group API) OpenMP Offload Example
*******************************************************************************/
#include <stdio.h>
#include <omp.h>
#include "mkl.h"
#include "mkl_omp_offload.h"

// Number of matrix groups in the batch; all matrices of one group share the
// same m, n, lda and the group has its own matrix count.
#define GROUP_COUNT 4

// Lower bounds used when drawing the random per-group parameters
// (matrices per group, rows, columns).
#define MIN_GROUP_SIZE 1
#define MIN_M 1
#define MIN_N 1

// Upper bounds used when drawing the random per-group parameters
// (matrices per group, rows, columns).
#define MAX_GROUP_SIZE 10
#define MAX_M 16
#define MAX_N 16

// printf conversion used for every MKL_INT value printed by this example.
// The long long form is selected when MKL_ILP64 is defined (presumably
// because MKL_INT is a 64-bit integer in that configuration); otherwise the
// plain int form is used.
#ifdef MKL_ILP64
  #define FMT "%4lld"
#else
  #define FMT "%4d"
#endif

int main() {

    // Total number of groups in the matrix batch
    MKL_INT  group_count = GROUP_COUNT;

    // Allocate memory for parameter arrays
    MKL_INT* group_size = (MKL_INT *) mkl_malloc(group_count * sizeof(MKL_INT), 64);
    MKL_INT* m          = (MKL_INT *) mkl_malloc(group_count * sizeof(MKL_INT), 64);
    MKL_INT* n          = (MKL_INT *) mkl_malloc(group_count * sizeof(MKL_INT), 64);
    MKL_INT* lda        = (MKL_INT *) mkl_malloc(group_count * sizeof(MKL_INT), 64);
    if(!group_size || !m || !n || !lda ) {
        printf("\n     ERROR. Failed parameter array memory allocation \n");
        return 1;
    }

    // Generate random matrix dimensions for the matrices in each group
    // m[igrp]          - number of rows in matrices for group igrp (MIN_M <= m[igrp] <= MAX_M)
    // n[igrp]          - number of columns in matrices for group igrp (MIN_N <= n[igrp] <= MAX_N)
    // lda[igrp]        - leading dimension in matrices for group igrp (lda[igrp] >= m[igrp])
    // group_size[igrp] - number of matrices in group igrp (group_size[igrp] >= 0)
    MKL_INT  batch_size = 0;
    printf("\nComputing the LU factorization for the following batch of matrices:\n");
    printf("===================================================================\n");
    for (MKL_INT igrp=0; igrp < group_count; igrp++) {
        m[igrp] = MIN_M + (rand() % MAX_M - MIN_M + 1);
        n[igrp] = MIN_N + (rand() % MAX_N - MIN_N + 1);
        lda[igrp] = m[igrp];
        group_size[igrp] = MIN_GROUP_SIZE + (rand() % MAX_GROUP_SIZE - MIN_GROUP_SIZE + 1);
        batch_size += group_size[igrp];

        printf("  Group="FMT ", m="FMT ", n="FMT ", lda="FMT ", group_size="FMT "\n",
            igrp, m[igrp], n[igrp], lda[igrp], group_size[igrp]);
    }
    printf("  Total number of matrices="FMT "\n", batch_size);
    printf("===================================================================\n");

    // Allocate array of pointers for host and device data
    double**  a         = (double**)  mkl_malloc(batch_size  * sizeof(double*),  64);
    MKL_INT** ipiv      = (MKL_INT**) mkl_malloc(batch_size  * sizeof(MKL_INT*), 64);
    double**  a_dev     = (double**)  mkl_malloc(batch_size  * sizeof(double*),  64);
    MKL_INT** ipiv_dev  = (MKL_INT**) mkl_malloc(batch_size  * sizeof(MKL_INT*), 64);
    MKL_INT*  info      = (MKL_INT *) mkl_malloc(sizeof(MKL_INT) * batch_size, 64);
    if ( !a || !a_dev || !ipiv || !ipiv_dev || !info ) {
        printf("\n ERROR. Failed pointer array memory allocation. \n");
        return 1;
    }

    // Allocate memory and initialize data for each matrix
    MKL_INT* a_size    = (MKL_INT *)mkl_malloc(group_count * sizeof(MKL_INT), 64);
    MKL_INT* ipiv_size = (MKL_INT *)mkl_malloc(group_count * sizeof(MKL_INT), 64);
    for (MKL_INT igrp = 0, idx = 0; igrp < group_count; igrp++) {
        a_size[igrp]    = lda[igrp] * n[igrp];
        ipiv_size[igrp] = (m[igrp] < n[igrp]) ? m[igrp] : n[igrp];
        for (MKL_INT imat = 0; imat < group_size[igrp]; imat++, idx++) {
            // Allocate memory for matrix idx
            a[idx] = (double *) mkl_malloc(sizeof(double) * a_size[igrp], 64);
            if (a[idx] == NULL) {
                printf("Failed to allocate A matrices\n");
                return 1;
            }

            // Allocate pivot array for matrix idx
            ipiv[idx] = (MKL_INT *) mkl_malloc(sizeof(MKL_INT) * ipiv_size[igrp], 64);
            if (ipiv[idx] == NULL) {
                printf("Failed to allocate ipiv arrays\n");
                return 1;
            }

            // Initialize entries of matrix idx
            for (MKL_INT row = 0; row < n[igrp]; row++) {
                for (MKL_INT col = 0; col < m[igrp]; col++) {
                    a[idx][col + row*lda[igrp]] = (double) rand() / (double) RAND_MAX - 0.5;
                }
            }
        }
    }

    // Map each array in A and ipiv to the device and store the corresponding device pointers in a new array
    double* a_ptr;
    MKL_INT* ipiv_ptr;
    for (MKL_INT igrp = 0, idx = 0; igrp < group_count; igrp++) {
        for (MKL_INT imat = 0; imat < group_size[igrp]; imat++, idx++) {
            a_ptr    = a[idx];
            ipiv_ptr = ipiv[idx];
            #pragma omp target enter data map(to:a_ptr[0:a_size[igrp]],ipiv_ptr[0:ipiv_size[igrp]])
            #pragma omp target data use_device_ptr(a_ptr,ipiv_ptr)
            {
                a_dev[idx] = a_ptr;
                ipiv_dev[idx] = ipiv_ptr;
            }
        }
    }

    // Execute getrf_batch on GPU via variant dispatch construct
    #pragma omp target data map(to:a_dev[0:batch_size], ipiv_dev[0:batch_size]) map(from:info[0:batch_size])
    {
      #pragma omp target variant dispatch use_device_ptr(a_dev, ipiv_dev, info)
      {
          dgetrf_batch(m, n, a_dev, lda, ipiv_dev, &group_count, group_size, info);
      }
    }

    // Bring A and ipiv data back to the host
    for (MKL_INT igrp = 0, idx = 0; igrp < group_count; igrp++) {
        for (MKL_INT imat = 0; imat < group_size[igrp]; imat++, idx++) {
            a_ptr    = a[idx];
            ipiv_ptr = ipiv[idx];
            #pragma omp target exit data map(from:a_ptr[0:a_size[igrp]],ipiv_ptr[0:ipiv_size[igrp]])
        }
    }
    printf("\n\nFinished call to dgetrf_batch \n");

    MKL_INT exit_status = 0;
    for (MKL_INT igrp = 0, idx = 0; igrp < group_count; igrp++) {
        for (MKL_INT imat = 0; imat < group_size[igrp]; imat++, idx++) {
            if (info[idx]) {
                printf("dgetrf_batch offload failed: Matrix "FMT " (matrix "FMT " from group "FMT " ) returned with info="FMT"\n",
                    idx, imat, igrp, info[idx]);
                exit_status++;
            }
        }
    }
    if (exit_status) {
        printf("Total number of failures:"FMT "\n", exit_status);
    }

    // Cleanup
    for (MKL_INT idx = 0; idx < batch_size; idx++) {
        mkl_free(a[idx]);
        mkl_free(ipiv[idx]);
    }
    mkl_free(group_size);
    mkl_free(m);
    mkl_free(n);
    mkl_free(lda);
    mkl_free(a);
    mkl_free(ipiv);
    mkl_free(info);
    mkl_free(a_size);
    mkl_free(ipiv_size);

    if (exit_status) {
        printf("\n\n===============================\nExample executed with errors.\n===============================\n\n");
    } else {
        printf("\n\n===============================\nExample executed successfully.\n===============================\n\n");
    }
    return exit_status;

}
