/*******************************************************************************
* Copyright 2021-2022 Intel Corporation.
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
*   Content : Intel(R) oneAPI Math Kernel Library (oneMKL) Sparse BLAS C OpenMP
*             offload example for mkl_sparse_sp2m and mkl_sparse_d_export_csr.
*
********************************************************************************
*
* Consider the matrix A (see 'Sparse Storage Formats for Sparse BLAS Level 2
* and Level 3' in the Intel oneMKL Reference Manual)
*
*                 |   1       -1      0   -3     0   |
*                 |  -2        5      0    0     0   |
*   A    =        |   0        0      4    6     4   |,
*                 |  -4        0      2    7     0   |
*                 |   0        8      0    0    -5   |
*
* and the matrix B
*
*                 |  10     11      0     0     0   |
*                 |   0      0     12    13     0   |
*   B    =        |  14      0      0     0    15   |,
*                 |   0     16     17     0     0   |
*                 |   0      0      0    18    19   |
*
*  The matrices A and B are represented in a zero-based compressed sparse row (CSR)
*  storage scheme with three arrays (see 'Sparse Matrix Storage Schemes' in the
*  Intel oneMKL Reference Manual) as follows:
*
*         values_A   =  ( 1 -1 -3 -2  5  4  6  4 -4  2  7  8 -5 )
*         columns_A  =  ( 0  1  3  0  1  2  3  4  0  2  3  1  4 )
*         rowIndex_A =  ( 0        3     5        8       11    13 )
*
*         values_B   = ( 10  11  12  13  14  15  16  17  18  19 )
*         columns_B  = (  0   1   2   3   0   4   1   2   3   4 )
*         rowIndex_B = (  0       2       4       6       8      10 )
*
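*  The test computes the following operation:
*
*       C = A*B with mkl_sparse_sp2m, first on the host and then via OpenMP offload,
*       where A is an MxK and B is a KxN general sparse matrix.
*
*  The values of the offloaded result are then compared against the host result.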
*
********************************************************************************
*/
#include <assert.h>
#include <omp.h>
#include <stdio.h>
#include <stdlib.h>

#include "common_for_sparse_examples.h"
#include "mkl.h"
#include "mkl_omp_offload.h"

/*
 * CALL_AND_CHECK_IE_STATUS(call, msg)
 *
 * Evaluates `call` (an expression yielding a sparse_status_t) and stores the result
 * in the caller's `ie_status`. On anything other than SPARSE_STATUS_SUCCESS it:
 *   - prints `msg` followed by ": <status>" (`msg` is pasted into the printf format,
 *     so it must be a string literal);
 *   - releases the caller's buffers via free_allocated_memories;
 *   - returns the failing status from the enclosing function.
 *
 * The identifiers `ie_status`, `pointer_array` and `num_pointers` must be in
 * scope wherever the macro is used.
 */
#define CALL_AND_CHECK_IE_STATUS(call, msg)                               \
    do {                                                                  \
        if ((ie_status = (call)) != SPARSE_STATUS_SUCCESS) {              \
            printf(msg ": %d\n", ie_status);                              \
            free_allocated_memories(pointer_array, num_pointers);         \
            return ie_status;                                             \
        }                                                                 \
    } while (0)

/*
 * Example driver.
 *
 * Computes C = A * B with the two-stage mkl_sparse_sp2m
 * (SPARSE_STAGE_NNZ_COUNT, then SPARSE_STAGE_FINALIZE_MULT), first on the host
 * and then through OpenMP offload on device devNum. It prints both results and
 * compares the offloaded values against the host values.
 *
 * Returns 0 on success. Otherwise it returns a non-zero value:
 *   - 1 if an input array cannot be allocated;
 *   - the failing sparse_status_t if a host-side MKL call fails (early return
 *     through CALL_AND_CHECK_IE_STATUS);
 *   - otherwise the bitwise OR of the offload, validation and cleanup statuses.
 */
int main()
{
//*******************************************************************************
//     Declaration and initialization of parameters for sparse representation of
//     the matrix A in the CSR format:
//     Assume A is M-by-K matrix and B is K-by-N matrix.
//*******************************************************************************
#define M 5      // nRows of A & C
#define K 5      // nCols of A & nRows of B.
#define N 5      // nCols of B & C.
#define NNZ_A 13 // NNZ of A.
#define NNZ_B 10 // NNZ of B.
#define ALIGN 64

    //*******************************************************************************
    //    Declaration of local variables:
    //*******************************************************************************
    double *values_A = NULL, *values_B = NULL, *values_C = NULL, *values_C_gpu = NULL;
    MKL_INT *columns_A = NULL, *columns_B = NULL, *columns_C = NULL, *columns_C_gpu = NULL;
    MKL_INT *rowIndex_A = NULL, *rowIndex_B = NULL, *rowStart_C = NULL, *rowEnd_C = NULL;
    MKL_INT *rowStart_C_gpu = NULL, *rowEnd_C_gpu = NULL;

    MKL_INT nrows_C = 0, ncols_C = 0, nrows_C_gpu = 0, ncols_C_gpu = 0, i, nnz = 0;
    MKL_INT nnz_C = 0; // nnz of the host result; bounds the value comparison below

    sparse_index_base_t indexing = SPARSE_INDEX_BASE_ZERO;
    sparse_index_base_t indexing_C, indexing_C_gpu;
    struct matrix_descr mat_descr_A;
    struct matrix_descr mat_descr_B;
    sparse_operation_t opA = SPARSE_OPERATION_NON_TRANSPOSE;
    sparse_operation_t opB = SPARSE_OPERATION_NON_TRANSPOSE;

    sparse_matrix_t csrA = NULL, csrB = NULL, csrC = NULL;

    values_A   = (double *)mkl_malloc(sizeof(double) * NNZ_A, ALIGN);
    columns_A  = (MKL_INT *)mkl_malloc(sizeof(MKL_INT) * NNZ_A, ALIGN);
    rowIndex_A = (MKL_INT *)mkl_malloc(sizeof(MKL_INT) * (M + 1), ALIGN);

    values_B   = (double *)mkl_malloc(sizeof(double) * NNZ_B, ALIGN);
    columns_B  = (MKL_INT *)mkl_malloc(sizeof(MKL_INT) * NNZ_B, ALIGN);
    rowIndex_B = (MKL_INT *)mkl_malloc(sizeof(MKL_INT) * (K + 1), ALIGN);

    // setting matrix_descr as general matrix.
    mat_descr_A.type = SPARSE_MATRIX_TYPE_GENERAL;
    mat_descr_B.type = SPARSE_MATRIX_TYPE_GENERAL;

    // Every mkl_malloc'd input array. They are released together by
    // free_allocated_memories on each exit path, including the early returns
    // inside CALL_AND_CHECK_IE_STATUS.
    // A fixed-size initialized array avoids the VLA that `const int` bounds
    // produce in C.
    void *pointer_array[] = {values_A, columns_A, rowIndex_A, values_B, columns_B, rowIndex_B};
    const int num_pointers = (int)(sizeof(pointer_array) / sizeof(pointer_array[0]));

    if (!values_A || !columns_A || !rowIndex_A || !values_B || !columns_B || !rowIndex_B) {
        free_allocated_memories(pointer_array, num_pointers);
        return 1;
    }

    //*******************************************************************************
    //    Sparse representation of the matrix A
    //*******************************************************************************
    double init_values_A[NNZ_A] = {1.0, -1.0, -3.0, -2.0, 5.0, 4.0, 6.0,
                                   4.0, -4.0, 2.0,  7.0,  8.0, -5.0};
    MKL_INT init_columns_A[NNZ_A]   = {0, 1, 3, 0, 1, 2, 3, 4, 0, 2, 3, 1, 4};
    MKL_INT init_row_index_A[M + 1] = {0, 3, 5, 8, 11, 13};

    for (i = 0; i < NNZ_A; i++) {
        values_A[i]  = init_values_A[i];
        columns_A[i] = init_columns_A[i];
    }
    for (i = 0; i < M + 1; i++) {
        rowIndex_A[i] = init_row_index_A[i];
    }

    //*******************************************************************************
    //    Sparse representation of the matrix B
    //*******************************************************************************
    // values 10..19, two entries per row, columns cycling 0..4 ...
    for (i = 0; i < NNZ_B; i++) {
        values_B[i] = i + 10;
    }
    for (i = 0; i < NNZ_B; i++) {
        columns_B[i] = i % 5;
    }
    rowIndex_B[0] = 0;
    for (i = 1; i < K + 1; i++) {
        rowIndex_B[i] = rowIndex_B[i - 1] + 2;
    }

    // ... except row 2, which holds 14 in column 0 and 15 in column 4.
    columns_B[4] = 0;
    columns_B[5] = 4;

    /* Printing usable data */
    printf("\n\n_______________Example program for MKL_SPARSE_SP2M_________________\n\n");
    printf(" COMPUTE  op(A) * op(B) = C, where matrices are stored in CSR format\n");
    printf("\n MATRIX A:\nrow# : list of (value, column) pairs\n");
    fflush(0);

    print_csr_matrix_double(rowIndex_A, columns_A, values_A, M);
    printf("\n MATRIX B:\nrow# : list of (value, column) pairs\n");
    print_csr_matrix_double(rowIndex_B, columns_B, values_B, K);

    printf("\n EXAMPLE PROGRAM FOR mkl_sparse_sp2m omp_offload \n");
    printf("---------------------------------------------------\n");
    printf("\n");
    printf("   INPUT DATA FOR mkl_sparse_sp2m omp offload    \n");
    printf("   WITH GENERAL SPARSE MATRIX     \n");
    printf("   SPARSE_OPERATION_NON_TRANSPOSE \n");
    printf("   SPARSE_INDEX_BASE_ZERO         \n");
    printf("   SPARSE_STAGE_NNZ_COUNT         \n");
    printf("   SPARSE_STAGE_FINALIZE_MULT     \n");

    sparse_status_t ie_status;

    //*******************************************************************************
    //    Host reference computation
    //*******************************************************************************
    // Create handle with matrix stored in CSR format
    CALL_AND_CHECK_IE_STATUS(mkl_sparse_d_create_csr(&csrA, indexing, M, K, rowIndex_A,
                                                     rowIndex_A + 1, columns_A, values_A),
                             "Error in mkl_sparse_d_create_csr for csrA");

    CALL_AND_CHECK_IE_STATUS(mkl_sparse_d_create_csr(&csrB, indexing, K, N, rowIndex_B,
                                                     rowIndex_B + 1, columns_B, values_B),
                             "Error in mkl_sparse_d_create_csr for csrB");

    // Stage 1 sizes the result; stage 2 fills it into the same csrC handle.
    sparse_request_t request = SPARSE_STAGE_NNZ_COUNT;
    CALL_AND_CHECK_IE_STATUS(
            mkl_sparse_sp2m(opA, mat_descr_A, csrA, opB, mat_descr_B, csrB, request, &csrC),
            "Error in mkl_sparse_sp2m with NNZ_COUNT request");

    request = SPARSE_STAGE_FINALIZE_MULT;
    CALL_AND_CHECK_IE_STATUS(
            mkl_sparse_sp2m(opA, mat_descr_A, csrA, opB, mat_descr_B, csrB, request, &csrC),
            "Error in mkl_sparse_sp2m with FINALIZE_MULT request");

    // sort csr matrix of C matrix.
    CALL_AND_CHECK_IE_STATUS(mkl_sparse_order(csrC), "Error in mkl_sparse_order");

    CALL_AND_CHECK_IE_STATUS(mkl_sparse_d_export_csr(csrC, &indexing_C, &nrows_C, &ncols_C, &rowStart_C,
                                                     &rowEnd_C, &columns_C, &values_C),
                             "Error in mkl_sparse_d_export_csr");

    printf("\n RESULTANT MATRIX C:\nrow# : list of (value, column) pairs\n");
    print_csr_matrix_double(rowStart_C, columns_C, values_C, nrows_C);

    MKL_INT start_ind_C = (indexing_C == SPARSE_INDEX_BASE_ZERO) ? 0 : 1;
    // The exported rows_start array is only guaranteed to hold nrows_C entries,
    // so take nnz from the end of the last row instead of rowStart_C[nrows_C].
    nnz_C = (nrows_C > 0) ? rowEnd_C[nrows_C - 1] - start_ind_C : 0;
    printf(" NNZ of C = " INT_PRINT_FORMAT "\n", nnz_C);
    printf("_____________________________________________________________________  \n");
    fflush(0);

    // Release matrix handle and deallocate matrix
    CALL_AND_CHECK_IE_STATUS(mkl_sparse_destroy(csrA), "Error in mkl_sparse_destroy, csrA");
    CALL_AND_CHECK_IE_STATUS(mkl_sparse_destroy(csrB), "Error in mkl_sparse_destroy, csrB");

    //*******************************************************************************
    //    Same computation through OpenMP offload
    //*******************************************************************************
    const int devNum = 0;

    sparse_matrix_t csrA_gpu = NULL;
    sparse_matrix_t csrB_gpu = NULL;
    sparse_matrix_t csrC_gpu = NULL;

    sparse_status_t status_create_A       = SPARSE_STATUS_SUCCESS;
    sparse_status_t status_create_B       = SPARSE_STATUS_SUCCESS;
    sparse_status_t status_sp2m_nnz       = SPARSE_STATUS_SUCCESS;
    sparse_status_t status_sp2m_final     = SPARSE_STATUS_SUCCESS;
    sparse_status_t status_export         = SPARSE_STATUS_SUCCESS;
    sparse_status_t status_destroy_A      = SPARSE_STATUS_SUCCESS;
    sparse_status_t status_destroy_B      = SPARSE_STATUS_SUCCESS;
    sparse_status_t status_destroy_C      = SPARSE_STATUS_SUCCESS;
    sparse_status_t status_destroy_C_host = SPARSE_STATUS_SUCCESS;

// call create_csr/sp2m/export/destroy via omp_offload.
#pragma omp target data map(to : rowIndex_A[0 : M + 1], columns_A[0 : NNZ_A], values_A[0 : NNZ_A], \
                                 rowIndex_B[0 : K + 1], columns_B[0 : NNZ_B], values_B[0 : NNZ_B]) \
        device(devNum)
    {
        printf("Create CSR matrix via omp_offload\n");

#pragma omp target variant dispatch device(devNum) use_device_ptr(rowIndex_A, columns_A, values_A)
        status_create_A = mkl_sparse_d_create_csr(&csrA_gpu, indexing, M, K, rowIndex_A,
                                                  rowIndex_A + 1, columns_A, values_A);

#pragma omp target variant dispatch device(devNum) use_device_ptr(rowIndex_B, columns_B, values_B)
        status_create_B = mkl_sparse_d_create_csr(&csrB_gpu, indexing, K, N, rowIndex_B,
                                                  rowIndex_B + 1, columns_B, values_B);

        printf("Compute mkl_sparse_sp2m via omp_offload: NNZ_COUNT\n");

        request = SPARSE_STAGE_NNZ_COUNT;

#pragma omp target variant dispatch device(devNum) nowait
        status_sp2m_nnz = mkl_sparse_sp2m(opA, mat_descr_A, csrA_gpu, opB, mat_descr_B, csrB_gpu,
                                          request, &csrC_gpu);
// We should handle external dependency here:
// the second-stage should start after the first-stage finishes.
#pragma omp taskwait

        printf("Compute mkl_sparse_sp2m via omp_offload: FINALIZE_MULT\n");

        request = SPARSE_STAGE_FINALIZE_MULT;
#pragma omp target variant dispatch device(devNum) nowait
        status_sp2m_final = mkl_sparse_sp2m(opA, mat_descr_A, csrA_gpu, opB, mat_descr_B, csrB_gpu,
                                            request, &csrC_gpu);
#pragma omp taskwait

        printf("Export mkl_sparse_sp2m resultant matrix via omp_offload\n");

#pragma omp target variant dispatch device(devNum)
        status_export = mkl_sparse_d_export_csr(csrC_gpu, &indexing_C_gpu, &nrows_C_gpu, &ncols_C_gpu,
                                                &rowStart_C_gpu, &rowEnd_C_gpu, &columns_C_gpu,
                                                &values_C_gpu);

        printf("Destroy the CSR matrix via omp_offload\n");

#pragma omp target variant dispatch device(devNum)
        status_destroy_A = mkl_sparse_destroy(csrA_gpu);

#pragma omp target variant dispatch device(devNum)
        status_destroy_B = mkl_sparse_destroy(csrB_gpu);
    }

    int flps_per_value    = 10;
    int validation_status = 0;

    int status_offload = status_create_A | status_create_B | status_sp2m_nnz | status_sp2m_final |
                         status_export | status_destroy_A | status_destroy_B;
    if (status_offload != 0) {
        printf("\tERROR: status_create_A = %d, status_create_B = %d, status_sp2m_nnz = %d, "
               "status_sp2m_final = %d, status_export = %d, status_destroy_A = %d, "
               "status_destroy_B = %d\n",
               status_create_A, status_create_B, status_sp2m_nnz, status_sp2m_final, status_export,
               status_destroy_A, status_destroy_B);
        goto cleanup;
    }

    printf("\n RESULTANT MATRIX C from offload:\nrow# : list of (value, column) pairs\n");
    print_csr_matrix_double(rowStart_C_gpu, columns_C_gpu, values_C_gpu, nrows_C_gpu);

    start_ind_C = (indexing_C_gpu == SPARSE_INDEX_BASE_ZERO) ? 0 : 1;
    nnz         = (nrows_C_gpu > 0) ? rowEnd_C_gpu[nrows_C_gpu - 1] - start_ind_C : 0;
    printf(" NNZ of C gpu = " INT_PRINT_FORMAT "\n", nnz);
    printf("_____________________________________________________________________  \n");
    fflush(0);

    // Only compare values once the structure agrees. Otherwise indexing values_C
    // with the offload nnz could read past the end of the host result.
    // NOTE(review): values are compared position by position, which assumes the
    // offload result's within-row column order matches the sorted host result
    // (only the host C went through mkl_sparse_order) -- confirm.
    if (nrows_C_gpu != nrows_C || ncols_C_gpu != ncols_C || nnz != nnz_C) {
        printf("\tERROR: offload result (" INT_PRINT_FORMAT " x " INT_PRINT_FORMAT
               ", nnz = " INT_PRINT_FORMAT ") does not match host result (" INT_PRINT_FORMAT
               " x " INT_PRINT_FORMAT ", nnz = " INT_PRINT_FORMAT ")\n",
               nrows_C_gpu, ncols_C_gpu, nnz, nrows_C, ncols_C, nnz_C);
        validation_status = 1;
    } else {
        validation_status = validation_result_double(values_C, values_C_gpu, nnz, flps_per_value);
    }

cleanup:
    // Both result handles are destroyed on every path, including offload
    // failure, so neither the host nor the device copy of C leaks.
    // csrC_gpu stays NULL if the offload sp2m never created it.
    if (csrC_gpu != NULL) {
#pragma omp target variant dispatch device(devNum)
        status_destroy_C = mkl_sparse_destroy(csrC_gpu);
        if (status_destroy_C != SPARSE_STATUS_SUCCESS) {
            printf(" Error in mkl_sparse_destroy offload: status_destroy_C = %d\n", status_destroy_C);
        }
    }

    status_destroy_C_host = mkl_sparse_destroy(csrC);
    if (status_destroy_C_host != SPARSE_STATUS_SUCCESS) {
        printf("Error in mkl_sparse_destroy, csrC: %d\n", status_destroy_C_host);
    }

    free_allocated_memories(pointer_array, num_pointers);

    const int status_all =
            validation_status | status_offload | status_destroy_C | status_destroy_C_host;
    printf("Test %s\n", status_all == 0 ? "PASSED" : "FAILED");
    fflush(stdout);

    return status_all;
}
