/******************************************************************************
** Copyright (c) 2015-2019, Intel Corporation                                **
** All rights reserved.                                                      **
**                                                                           **
** Redistribution and use in source and binary forms, with or without        **
** modification, are permitted provided that the following conditions        **
** are met:                                                                  **
** 1. Redistributions of source code must retain the above copyright         **
**    notice, this list of conditions and the following disclaimer.          **
** 2. Redistributions in binary form must reproduce the above copyright      **
**    notice, this list of conditions and the following disclaimer in the    **
**    documentation and/or other materials provided with the distribution.   **
** 3. Neither the name of the copyright holder nor the names of its          **
**    contributors may be used to endorse or promote products derived        **
**    from this software without specific prior written permission.          **
**                                                                           **
** THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS       **
** "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT         **
** LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR     **
** A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT      **
** HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,    **
** SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED  **
** TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR    **
** PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF    **
** LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING      **
** NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS        **
** SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.              **
******************************************************************************/
/* Alexander Heinecke (Intel Corp.)
******************************************************************************/
#include "generator_common.h"
#include "generator_gemm_common.h"
#include "generator_gemm_sse3_avx_avx2_avx512.h"
#include "generator_gemm_avx512_fsdbcst.h"
#include "generator_gemm_noarch.h"
#include "libxsmm_main.h"

#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
#endif
#include <stdlib.h>
#include <string.h>
#include <assert.h>
#include <stdio.h>
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(pop)
#endif

/* @TODO change int based architecture value */
/**
 * Emits a GEMM kernel for the architecture named by i_arch into io_generated_code.
 *
 * A local copy of the descriptor is adapted before code generation:
 *  - the vector length (in elements) is derived from architecture and input precision;
 *  - for I16/BF16 inputs K and LDB are halved (after checking K's divisibility);
 *  - LDA/LDB/LDC are validated against the (adapted) extents;
 *  - the A/C alignment flags are cleared when LDA/LDC are not multiples of the vector length.
 * Any violation is reported via LIBXSMM_HANDLE_ERROR and generation stops early.
 *
 * @param io_generated_code code object receiving the instruction-set check, kernel and flop counter
 * @param i_xgemm_desc      GEMM descriptor; never modified (only a local copy is adapted)
 * @param i_arch            "wsm", "snb", "hsw", "knl", "knm", "skx", "clx", "cpx" or "noarch"
 */
LIBXSMM_API
void libxsmm_generator_gemm_kernel( libxsmm_generated_code*        io_generated_code,
                                    const libxsmm_gemm_descriptor* i_xgemm_desc,
                                    const char*                    i_arch ) {
  /* working copy of the descriptor; alignment overrides and K/LDB folding are applied here */
  libxsmm_gemm_descriptor l_desc = *i_xgemm_desc;
  /* architecture families sharing the same vector-length / precision rules */
  const int l_is_wsm         = ( 0 == strcmp( i_arch, "wsm" ) );
  const int l_is_avx         = ( 0 == strcmp( i_arch, "snb" ) ) || ( 0 == strcmp( i_arch, "hsw" ) );
  const int l_is_knl         = ( 0 == strcmp( i_arch, "knl" ) );
  const int l_is_knm         = ( 0 == strcmp( i_arch, "knm" ) );
  const int l_is_knx         = l_is_knl || l_is_knm;
  const int l_is_avx512_core = ( 0 == strcmp( i_arch, "skx" ) ) ||
                               ( 0 == strcmp( i_arch, "clx" ) ) ||
                               ( 0 == strcmp( i_arch, "cpx" ) );
  const int l_is_noarch      = ( 0 == strcmp( i_arch, "noarch" ) );
  const int l_has_fp_vectors = l_is_wsm || l_is_avx || l_is_knx || l_is_avx512_core;
  /* input precision and transpose-B flag; neither is altered by the adaption below */
  const int l_inp     = LIBXSMM_GETENUM_INP( l_desc.datatype );
  const int l_trans_b = ( ( l_desc.flags & LIBXSMM_GEMM_FLAG_TRANS_B ) > 0 );
  /* vector length in elements; stays 1 for noarch */
  unsigned int l_vl = 1;

  /* add instruction set mismatch check to code, header */
  libxsmm_generator_isa_check_header( io_generated_code, i_arch );

  /* @TODO fix me: vector length depending on architecture and precision */
  if ( 0 != l_is_noarch ) {
    /* any precision is accepted, nothing to adapt */
  } else if ( LIBXSMM_GEMM_PRECISION_F64 == l_inp && 0 != l_has_fp_vectors ) {
    l_vl = ( 0 != l_is_wsm ) ? 2 : ( ( 0 != l_is_avx ) ? 4 : 8 );
  } else if ( LIBXSMM_GEMM_PRECISION_F32 == l_inp && 0 != l_has_fp_vectors ) {
    l_vl = ( 0 != l_is_wsm ) ? 4 : ( ( 0 != l_is_avx ) ? 8 : 16 );
  } else if ( ( LIBXSMM_GEMM_PRECISION_I16  == l_inp && ( l_is_knm || l_is_avx512_core ) ) ||
              ( LIBXSMM_GEMM_PRECISION_BF16 == l_inp && ( l_is_knx || l_is_avx512_core ) ) ) {
    /* K must be a multiple of 8 for I16 on knm, of 2 otherwise (we cannot mask everything) */
    const unsigned int l_k_multiple = ( LIBXSMM_GEMM_PRECISION_I16 == l_inp && 0 != l_is_knm ) ? 8 : 2;
    l_vl = 16;
    if ( 0 != ( l_desc.k % l_k_multiple ) ) {
      LIBXSMM_HANDLE_ERROR( io_generated_code, LIBXSMM_ERR_ARCH_PREC );
      return;
    }
    /* these input types are handled with K (and hence LDB) expressed in pairs of elements */
    l_desc.k   /= 2;
    l_desc.ldb /= 2;
    /* @TODO for now we enforce M==16 for BF16 output (not enforced on knl) */
    if ( LIBXSMM_GEMM_PRECISION_BF16 == l_inp && 0 == l_is_knl &&
         LIBXSMM_GEMM_PRECISION_BF16 == LIBXSMM_GETENUM_OUT( l_desc.datatype ) &&
         0 != ( l_desc.m % 16 ) ) {
      LIBXSMM_HANDLE_ERROR( io_generated_code, LIBXSMM_ERR_ARCH_PREC );
      return;
    }
  } else {
    /* unknown architecture, or precision not supported on it */
    LIBXSMM_HANDLE_ERROR( io_generated_code, LIBXSMM_ERR_ARCH_PREC );
    return;
  }

  /* leading dimensions must cover the (possibly folded) matrix extents */
  if ( l_desc.lda < l_desc.m ) {
    LIBXSMM_HANDLE_ERROR( io_generated_code, LIBXSMM_ERR_LDA );
    return;
  }
  if ( 0 != l_trans_b ) {
    if ( l_desc.ldb < l_desc.n ) {
      LIBXSMM_HANDLE_ERROR( io_generated_code, LIBXSMM_ERR_LDB_TRANS );
      return;
    }
  } else if ( l_desc.ldb < l_desc.k ) {
    LIBXSMM_HANDLE_ERROR( io_generated_code, LIBXSMM_ERR_LDB );
    return;
  }
  if ( l_desc.ldc < l_desc.m ) {
    LIBXSMM_HANDLE_ERROR( io_generated_code, LIBXSMM_ERR_LDC );
    return;
  }

  /* the generator has no transpose-B support for I16, I8 and BF16 inputs */
  if ( 0 != l_trans_b &&
       ( LIBXSMM_GEMM_PRECISION_I16  == l_inp ||
         LIBXSMM_GEMM_PRECISION_I8   == l_inp ||
         LIBXSMM_GEMM_PRECISION_BF16 == l_inp ) ) {
    LIBXSMM_HANDLE_ERROR( io_generated_code, LIBXSMM_ERR_TRANS_B );
    return;
  }

  /* drop alignment hints that cannot hold for the chosen vector length */
  if ( 0 != ( l_desc.lda % l_vl ) ) {
    l_desc.flags &= ~LIBXSMM_GEMM_FLAG_ALIGN_A;
  }
  if ( 0 != ( l_desc.ldc % l_vl ) ) {
    l_desc.flags &= ~LIBXSMM_GEMM_FLAG_ALIGN_C;
  }

  /* call actual kernel generation with revised parameters */
  if ( 0 != l_is_wsm || 0 != l_is_avx ) {
    libxsmm_generator_gemm_sse3_avx_avx2_avx512_kernel( io_generated_code, &l_desc, i_arch );
  } else if ( 0 != l_is_knx ) {
    libxsmm_generator_gemm_avx512_kernel_fsdbcst( io_generated_code, &l_desc, i_arch );
  } else if ( 0 != l_is_avx512_core ) {
    /* I16/BF16 always use the fsdbcst generator; F64/F32 with M of 2, 3 or 4 vectors use the other one */
    const int l_plain_inp = ( LIBXSMM_GEMM_PRECISION_I16  != l_inp &&
                              LIBXSMM_GEMM_PRECISION_BF16 != l_inp );
    const int l_m_matches = ( 16 == l_vl && ( 32 == l_desc.m || 48 == l_desc.m || 64 == l_desc.m ) ) ||
                            (  8 == l_vl && ( 16 == l_desc.m || 24 == l_desc.m || 32 == l_desc.m ) );
    if ( 0 != l_plain_inp && 0 != l_m_matches ) {
      libxsmm_generator_gemm_sse3_avx_avx2_avx512_kernel( io_generated_code, &l_desc, i_arch );
    } else {
      libxsmm_generator_gemm_avx512_kernel_fsdbcst( io_generated_code, &l_desc, i_arch );
    }
  } else if ( 0 != l_is_noarch ) {
    libxsmm_generator_gemm_noarch_kernel( io_generated_code, &l_desc, i_arch );
  } else {
    LIBXSMM_HANDLE_ERROR( io_generated_code, LIBXSMM_ERR_ARCH );
    return;
  }

  /* add instruction set mismatch check to code, footer */
  libxsmm_generator_isa_check_footer( io_generated_code, i_arch );

  /* add flop counter for debug compilation; uses the caller's unfolded descriptor */
  libxsmm_generator_gemm_add_flop_counter( io_generated_code, i_xgemm_desc );
}


/**
 * Generates a GEMM kernel (code_type 0) for i_arch and appends the resulting
 * code text to the file i_file_out.
 *
 * The function signature, kernel body and function epilogue are emitted into a
 * code buffer owned by this function. Nothing is written when code generation
 * reports an error. The buffer is released on every path, including generation
 * errors and I/O failures.
 *
 * @param i_file_out     path of the file to append to (opened in "a" mode)
 * @param i_routine_name name of the emitted function
 * @param i_xgemm_desc   GEMM descriptor describing the kernel
 * @param i_arch         target architecture string
 */
LIBXSMM_API
void libxsmm_generator_gemm_inlineasm( const char*                    i_file_out,
                                       const char*                    i_routine_name,
                                       const libxsmm_gemm_descriptor* i_xgemm_desc,
                                       const char*                    i_arch ) {
  /* init generated code object */
  libxsmm_generated_code l_generated_code;
  l_generated_code.generated_code = NULL;
  l_generated_code.buffer_size = 0;
  l_generated_code.code_size = 0;
  l_generated_code.code_type = 0;
  l_generated_code.last_error = 0;

  /* add signature to code string */
  libxsmm_mmfunction_signature( &l_generated_code, i_routine_name, i_xgemm_desc );

  /* generate the actual kernel code for current description depending on the architecture */
  libxsmm_generator_gemm_kernel( &l_generated_code, i_xgemm_desc, i_arch );

  /* close current function */
  libxsmm_close_function( &l_generated_code );

  if ( l_generated_code.last_error != 0 ) {
    /* report errors during code generation; nothing is written */
    LIBXSMM_HANDLE_ERROR_VERBOSE( &l_generated_code, l_generated_code.last_error );
  } else {
    /* append code to source file */
    FILE *const l_file_handle = fopen( i_file_out, "a" );
    if ( l_file_handle != NULL ) {
      int l_write_failed;
      assert(l_generated_code.generated_code != NULL);
      l_write_failed = ( EOF == fputs( (const char*)l_generated_code.generated_code, l_file_handle ) );
      /* fclose flushes buffered output, so a failing close means lost output */
      if ( 0 != fclose( l_file_handle ) ) {
        l_write_failed = 1;
      }
      if ( 0 != l_write_failed ) {
        fprintf(stderr, "LIBXSMM ERROR libxsmm_generator_gemm_inlineasm failed to write the generated code into destination source file\n");
      }
    } else {
      fprintf(stderr, "LIBXSMM ERROR libxsmm_generator_gemm_inlineasm could not write to into destination source file\n");
    }
  }

  /* free code memory on every path; previously leaked on the error returns */
  free( l_generated_code.generated_code );
}


/**
 * Generates a GEMM kernel (code_type 1) for i_arch and writes the resulting
 * code text to i_file_out, truncating any previous content ("w" mode).
 *
 * "noarch" is rejected up front. Nothing is written when code generation
 * reports an error. The code buffer owned by this function is released on
 * every path after generation, including generation errors and I/O failures.
 *
 * @param i_file_out     path of the destination file
 * @param i_routine_name name of the emitted function
 * @param i_xgemm_desc   GEMM descriptor describing the kernel
 * @param i_arch         target architecture string (must not be "noarch")
 */
LIBXSMM_API
void libxsmm_generator_gemm_directasm( const char*                    i_file_out,
                                       const char*                    i_routine_name,
                                       const libxsmm_gemm_descriptor* i_xgemm_desc,
                                       const char*                    i_arch ) {
  /* init generated code object */
  libxsmm_generated_code l_generated_code;
  l_generated_code.generated_code = NULL;
  l_generated_code.buffer_size = 0;
  l_generated_code.code_size = 0;
  l_generated_code.code_type = 1;
  l_generated_code.last_error = 0;

  /* check if we are not noarch (nothing allocated yet, plain return is fine) */
  if ( strcmp( i_arch, "noarch" ) == 0 ) {
    fprintf(stderr, "LIBXSMM ERROR, libxsmm_generator_gemm_direct: we cannot create ASM when noarch is specified!\n");
    return;
  }

  /* add signature to code string */
  libxsmm_mmfunction_signature( &l_generated_code, i_routine_name, i_xgemm_desc );

  /* generate the actual kernel code for current description depending on the architecture */
  libxsmm_generator_gemm_kernel( &l_generated_code, i_xgemm_desc, i_arch );

  if ( l_generated_code.last_error != 0 ) {
    /* report errors during code generation; nothing is written */
    LIBXSMM_HANDLE_ERROR_VERBOSE( &l_generated_code, l_generated_code.last_error );
  } else {
    /* write code to destination file */
    FILE *const l_file_handle = fopen( i_file_out, "w" );
    if ( l_file_handle != NULL ) {
      int l_write_failed;
      assert(l_generated_code.generated_code != NULL);
      l_write_failed = ( EOF == fputs( (const char*)l_generated_code.generated_code, l_file_handle ) );
      /* fclose flushes buffered output, so a failing close means lost output */
      if ( 0 != fclose( l_file_handle ) ) {
        l_write_failed = 1;
      }
      if ( 0 != l_write_failed ) {
        fprintf(stderr, "LIBXSMM ERROR, libxsmm_generator_gemm_direct: failed to write the generated code into destination source file!\n");
      }
    } else {
      fprintf(stderr, "LIBXSMM ERROR, libxsmm_generator_gemm_direct: could not write to into destination source file!\n");
    }
  }

  /* free code memory on every path; previously leaked on the error returns */
  free( l_generated_code.generated_code );
}

