/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

// To be linked into merged libraries.

#define _GNU_SOURCE

#include <assert.h>
#include <ctype.h>
#include <stdlib.h>
#include <string.h>

#include <jni.h>

#include "jni_lib_merge.h"
// We define JNI_OnLoad_Weak directly, so this macro doesn't hurt anything,
// but undef it anyway just to avoid possible future confusion.
#undef JNI_OnLoad
#define PRINT(...) JNI_MERGE_PRINT(__VA_ARGS__)

// Slash-separated JNI name of the Java class that JNI_OnLoad_Weak passes to
// FindClass and then RegisterNatives on (when not invoking all JNI_OnLoads
// directly). The class itself is generated by buck or edited by hand.
static const char* invoke_class_name =
    "com/facebook/react/soloader/"
    "OpenSourceMergedSoMapping";

// Stub pre-merged library to ensure that our custom section gets created
// when we merge a group of native libraries with no JNI_OnLoad. We don't
// want this to be instrumented by ASAN to avoid a redzone around it (which
// would add bogus bytes to the pre_merge_jni_libraries section). We also want
// the linker to retain it (since the section will be traversed at runtime).
//
// JNI_OnLoad_Weak recognizes this entry by address (pmjl == &pmjl_stub) and
// skips it, so its NULL onload_func is never called or registered.
static struct pre_merge_jni_library pmjl_stub
// __retain__ is only applied when the compiler advertises support for it.
#if __has_attribute(__retain__)
    __attribute__((__retain__))
#endif
    __attribute__((__section__("pre_merge_jni_libraries")))
    __attribute__((no_sanitize("address"))) = {
        .name = "jni_lib_merge-stub",
        .onload_func = NULL,
};

// References to custom section bounds, filled in by linker.
// Marking these hidden here takes precedence over the global visibility
// of the generated symbols.  This ensures that they are hidden in
// the shared object, so they can't be accidentally referenced by
// another merged library.
//
// The section is walked as the half-open array [__start, __stop): __start
// is the first entry and __stop is one past the last, so (stop - start) is
// the number of entries, including pmjl_stub.
__attribute__((__visibility__("hidden"))) extern struct pre_merge_jni_library
    __start_pre_merge_jni_libraries;
__attribute__((__visibility__("hidden"))) extern struct pre_merge_jni_library
    __stop_pre_merge_jni_libraries;

// Returns a malloc'ed string.
static char* method_name_for_invoke(const char* soname) {
  char* name = strdup(soname);
  if (!name) {
    // Can't log here, since we don't know if we're depending on Android
    // logging.
    assert(!"Failed to strdup soname.");
    abort();
  }
  for (char* c = name; *c != '\0'; ++c) {
    if (!isalnum(*c) && *c != '_') {
      *c = '_';
    }
  }
  return name;
}

// Replacement for weak JNI_OnLoad.
jint JNI_OnLoad_Weak(JavaVM* vm, void* reserved) {
#ifdef INVOKE_ALL_JNI_ONLOAD
  (void)invoke_class_name;
  (void)method_name_for_invoke;
  PRINT("Entering merged library JNI_OnLoad. Invoking all.\n");
  JNIEnv* env;
  if ((*vm)->GetEnv(vm, (void**)&env, JNI_VERSION_1_2) != JNI_OK) {
    return JNI_ERR;
  }

  struct pre_merge_jni_library* start = &__start_pre_merge_jni_libraries;
  struct pre_merge_jni_library* stop = &__stop_pre_merge_jni_libraries;
  for (struct pre_merge_jni_library* pmjl = start; pmjl != stop; pmjl++) {
    if (pmjl == &pmjl_stub) {
      continue;
    }
    if (pmjl->onload_func(env, NULL /* unused */) < 0) {
      return JNI_ERR;
    }
  }
  return JNI_VERSION_1_6;
#else
  (void)reserved;

  PRINT("Entering merged library JNI_OnLoad.\n");

  // Get the JNI Env so we can register the original JNI_OnLoad methods.
  JNIEnv* env;
  if ((*vm)->GetEnv(vm, (void**)&env, JNI_VERSION_1_2) != JNI_OK) {
    return JNI_ERR;
  }

  // Find the class we need to register with.
  jclass invoke_class = (*env)->FindClass(env, invoke_class_name);
  if (invoke_class == NULL) {
    return JNI_ERR;
  }

  // Construct the argument to RegisterNatives with proper sanitized names
  // and function pointers.
  struct pre_merge_jni_library* start = &__start_pre_merge_jni_libraries;
  struct pre_merge_jni_library* stop = &__stop_pre_merge_jni_libraries;
  size_t num_merged_libraries = stop - start;
  PRINT(
      "Preparing %zu pre-merged libs (including stub)\n", num_merged_libraries);
  JNINativeMethod* natives = calloc(num_merged_libraries, sizeof(*natives));
  assert(natives != NULL);
  if (natives == NULL) {
    abort();
  }
  JNINativeMethod* cur_native = natives;
  struct pre_merge_jni_library* pmjl = start;
  for (size_t i = 0; i < num_merged_libraries; i++, pmjl++) {
    if (pmjl == &pmjl_stub) {
      continue;
    }
    char* name = method_name_for_invoke(pmjl->name);
    PRINT(
        "Preparing to register %s.  onload_func: %p\n",
        name,
        pmjl->onload_func);

    cur_native->name = name;
    cur_native->signature = "()I";
    cur_native->fnPtr = pmjl->onload_func;
    cur_native++;
  }

  size_t num_actual_methods = cur_native - natives;
  PRINT("About to register %zu actual methods.\n", num_actual_methods);
  jint ret =
      (*env)->RegisterNatives(env, invoke_class, natives, num_actual_methods);

  for (size_t i = 0; i < num_actual_methods; i++) {
    free((void*)natives[i].name);
  }
  free(natives);

  if (ret < 0) {
    // Exception already thrown.
    return JNI_ERR;
  }

  return JNI_VERSION_1_6;
#endif // INVOKE_ALL_JNI_ONLOAD
}
