//    MCL: MiMiC Communication Library
//    Copyright (C) 2015-2025  The MiMiC Authors (see CONTRIBUTORS file for details).
//
//    This file is part of MCL.
//
//    MCL is free software: you can redistribute it and/or modify
//    it under the terms of the GNU Lesser General Public License as
//    published by the Free Software Foundation, either version 3 of
//    the License, or (at your option) any later version.
//
//    MCL is distributed in the hope that it will be useful, but
//    WITHOUT ANY WARRANTY; without even the implied warranty of
//    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//    GNU Lesser General Public License for more details.
//
//    You should have received a copy of the GNU Lesser General Public License
//    along with this program.  If not, see <http://www.gnu.org/licenses/>.

#include "endpoint.h"

#include <filesystem>
#include <string>

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include "data_types.h"
#include "error_codes.h"
#include "env_controls.h"
#include "test_macros.h"
#include "transfer_logger.h"

using namespace mcl;
using ::testing::_;
using ::testing::Exactly;
using ::testing::DoAll;
using ::testing::Return;
using ::testing::Invoke;

class MockTransport : public Transport {
  public:
    MockTransport() : Transport() {}
    ~MockTransport() override = default;

    MOCK_METHOD(int,  Initialize,       (void *args), (override));
    MOCK_METHOD(int,  Finalize,         (), (override));
    MOCK_METHOD(int,  Abort,            (int error_code), (override));
    MOCK_METHOD(int,  Send,             (void* data, int length, int data_type, int tag, int destination), (override));
    MOCK_METHOD(int,  Receive,          (void* data, int length, int data_type, int tag, int source), (override));
    MOCK_METHOD(int,  ProbeSize,        (int data_type, int source), (override));

    MOCK_METHOD(bool, is_server,        (), (const));
    MOCK_METHOD(bool, is_client,        (), (const));
    MOCK_METHOD(int,  num_programs,     (), (const));
    MOCK_METHOD(int,  program_id,       (), (const));
};

/// Mock action for Transport::Receive: stores the compile-time constant `value`
/// into the int pointed to by `data` and reports success (0).
/// Only `data` is used; the remaining parameters exist to match Receive's signature.
/// `static` keeps this test-only helper local to this translation unit.
template<int value>
static int ReceiveValue(void* data,
                        [[maybe_unused]] int count,
                        [[maybe_unused]] int data_type,
                        [[maybe_unused]] int tag,
                        [[maybe_unused]] int source) {
    *static_cast<int*>(data) = value;
    return 0;
}

/// Mock action for Transport::Receive: writes the `source` argument into the int
/// pointed to by `data` and reports success (0). The received value therefore
/// equals the id of the program it was received from.
/// `static` keeps this test-only helper local to this translation unit.
static int ReceiveSource(void* data,
                         [[maybe_unused]] int count,
                         [[maybe_unused]] int data_type,
                         [[maybe_unused]] int tag,
                         int source) {
    *static_cast<int*>(data) = source;
    return 0;
}


class ServerTest : public ::testing::Test {
  protected:
    void SetUp() override {
        is_server_ = true;
        protocol_ = std::make_shared<MockTransport>();
        EXPECT_CALL(*protocol_, program_id()).Times(Exactly(1)).WillRepeatedly(Return(0));
        endpoint_ = new Endpoint(protocol_);
        ASSERT_EQ(endpoint_->id(), 0);
    }

    void TearDown() override { delete endpoint_; }
     
    void ExpectEndpointTypeInquiry(std::shared_ptr<MockTransport>& protocol) {
        EXPECT_CALL(*protocol_, is_server()).Times(Exactly(1)).WillOnce(Return(is_server_));
        EXPECT_CALL(*protocol_, is_client()).Times(Exactly(1)).WillOnce(Return(!is_server_));
    }

    std::shared_ptr<MockTransport> protocol_;
    Endpoint *endpoint_;
    bool is_server_ = true;
    int num_programs_= 5;
};

TEST_F(ServerTest, Initialize) {
    ExpectEndpointTypeInquiry(protocol_);
    EXPECT_CALL(*protocol_, num_programs()).Times(Exactly(1)).WillRepeatedly(Return(num_programs_));
    // One single-int Receive per program id 1 .. num_programs_ - 1.
    // ReceiveSource makes each received value equal to the id it came from.
    for (int source = 1; source < num_programs_; ++source) {
        EXPECT_CALL(*protocol_, Receive(_, 1, kTypeInt, _, source))
            .Times(Exactly(1))
            .WillOnce(ReceiveSource);
    }

    ASSERT_EQ(endpoint_->Initialize(), kSuccess);

    // The endpoint should now list ids 1, 2, ..., num_programs_ - 1 in order.
    const int expected_count = num_programs_ - 1;
    const auto& ids = endpoint_->location_ids();
    ASSERT_EQ(static_cast<int>(ids.size()), expected_count);
    for (int k = 0; k < expected_count; ++k) {
        ASSERT_EQ(k + 1, ids[k]);
    }
}

TEST_F(ServerTest, InitializeFailedReceive) {
    // Two programs reported: the single Receive from program 1 fails, and that
    // error code must be returned unchanged by Initialize().
    const int receive_error = 1;
    EXPECT_CALL(*protocol_, is_server()).Times(Exactly(1)).WillOnce(Return(true));
    EXPECT_CALL(*protocol_, num_programs()).WillOnce(Return(2));
    EXPECT_CALL(*protocol_, Receive(_, 1, kTypeInt, _, 1))
        .Times(Exactly(1))
        .WillOnce(Return(receive_error));
    ASSERT_EQ(endpoint_->Initialize(), receive_error);
}

TEST_F(ServerTest, InitializeFailedClientId) {
    // Two programs reported, but the Receive from program 1 delivers id 5.
    // Initialize() must reject it with kErrEndpointNotFound.
    constexpr int unexpected_id = 5;
    EXPECT_CALL(*protocol_, is_server()).Times(Exactly(1)).WillOnce(Return(true));
    EXPECT_CALL(*protocol_, num_programs()).WillOnce(Return(2));
    EXPECT_CALL(*protocol_, Receive(_, 1, kTypeInt, _, 1))
        .Times(Exactly(1))
        .WillOnce(DoAll(ReceiveValue<unexpected_id>, Return(0)));
    ASSERT_EQ(endpoint_->Initialize(), kErrEndpointNotFound);
}

TEST_F(ServerTest, Finalize) {
    ExpectEndpointTypeInquiry(protocol_);
    EXPECT_CALL(*protocol_, num_programs()).Times(Exactly(1)).WillRepeatedly(Return(num_programs_));
    for (int i = 1; i < num_programs_; ++i) {
        EXPECT_CALL(*protocol_, Receive(_, 1, kTypeInt, _, i)).Times(Exactly(1)).WillOnce(ReceiveSource);
    }
    ASSERT_EQ(endpoint_->Initialize(), kSuccess);
    // Compare against the named success code, as the client-side Finalize test does.
    ASSERT_EQ(endpoint_->Finalize(), kSuccess);
    // Finalize must clear the location ids. size() is unsigned, so compare with an
    // unsigned literal to avoid a signed/unsigned comparison inside ASSERT_EQ.
    ASSERT_EQ(endpoint_->location_ids().size(), 0u);
}


class EndpointTest : public ServerTest {
  protected:
    void SetUp() override {
        ServerTest::SetUp();
        ExpectEndpointTypeInquiry(protocol_);
        EXPECT_CALL(*protocol_, num_programs()).Times(Exactly(1)).WillRepeatedly(Return(num_programs_));
        for (int i = 1; i < num_programs_; ++i) {
            EXPECT_CALL(*protocol_, Receive(_, 1, kTypeInt, _, i)).Times(Exactly(1)).WillOnce(ReceiveSource);
        }
        endpoint_->Initialize();
    }

    int fail_client_ = 5;
};

TEST_F(EndpointTest, Send) {
    // Each registered program must receive exactly one single-int Send routed to it.
    int payload;
    for (int destination = 1; destination < num_programs_; ++destination) {
        EXPECT_CALL(*protocol_, Send(_, 1, kTypeInt, _, destination)).Times(Exactly(1));
        endpoint_->Send(&payload, 1, kTypeInt, 0, destination);
    }
}

TEST_F(EndpointTest, SendFail) {
    // fail_client_ is not a registered program id, so Send must be rejected.
    int payload;
    const auto status = endpoint_->Send(&payload, 1, kTypeInt, 0, fail_client_);
    ASSERT_EQ(status, kErrEndpointNotFound);
}

TEST_F(EndpointTest, Receive) {
    // Receiving from each registered program yields that program's id, because
    // the mocked transport uses ReceiveSource.
    for (int i = 1; i < num_programs_; ++i) {
        int data = 0;
        EXPECT_CALL(*protocol_, Receive(_, 1, kTypeInt, _, i)).Times(Exactly(1)).WillOnce(ReceiveSource);
        endpoint_->Receive(&data, 1, kTypeInt, 0, i);
        // ASSERT_EQ reports both values on mismatch, unlike ASSERT_TRUE(data == i).
        ASSERT_EQ(data, i);
    }
}

TEST_F(EndpointTest, ReceiveFail) {
    // fail_client_ is not a registered program id, so Receive must be rejected.
    int buffer;
    const auto status = endpoint_->Receive(&buffer, 1, kTypeInt, 0, fail_client_);
    ASSERT_EQ(status, kErrEndpointNotFound);
}

TEST_F(EndpointTest, ProbeSize) {
    // For each registered program, the size reported by the transport is passed through.
    const int expected_size = 10;
    for (int source = 1; source < num_programs_; ++source) {
        EXPECT_CALL(*protocol_, ProbeSize(kTypeInt, source))
            .Times(Exactly(1))
            .WillOnce(Return(expected_size));
        ASSERT_EQ(endpoint_->ProbeSize(kTypeInt, source), expected_size);
    }
}

TEST_F(EndpointTest, FailedProbeSize) {
    // No ProbeSize expectation is set: probing an unregistered id must yield -1.
    const auto probed = endpoint_->ProbeSize(kTypeInt, fail_client_);
    ASSERT_EQ(probed, -1);
}


/// Fixture whose Endpoint is built on a transport reporting program id
/// client_id_. With is_server_ false, role inquiries are answered as a client.
class ClientTest : public ServerTest {
  protected:
    int client_id_ = 1;

    // Replaces rather than extends ServerTest::SetUp(), because the transport must
    // report client_id_ instead of 0 as its program id.
    void SetUp() override {
        protocol_ = std::make_shared<MockTransport>();
        is_server_ = false;
        EXPECT_CALL(*protocol_, program_id())
            .Times(Exactly(1))
            .WillRepeatedly(Return(client_id_));
        endpoint_ = new Endpoint(protocol_);
        ASSERT_EQ(endpoint_->id(), client_id_);
    }
};

TEST_F(ClientTest, Initialize) {
    ExpectEndpointTypeInquiry(protocol_);
    // Initialize is expected to perform exactly one single-int Send to program 0.
    EXPECT_CALL(*protocol_, Send(_, 1, kTypeInt, _, 0))
        .Times(Exactly(1))
        .WillOnce(Return(0));
    ASSERT_EQ(endpoint_->Initialize(), kSuccess);

    // Afterwards the only known location is program 0.
    const auto& peers = endpoint_->location_ids();
    ASSERT_EQ(peers.size(), 1u);
    ASSERT_EQ(peers[0], 0);
}

TEST_F(ClientTest, Finalize) {
    // Initialize successfully first, then Finalize must also succeed.
    ExpectEndpointTypeInquiry(protocol_);
    EXPECT_CALL(*protocol_, Send(_, 1, kTypeInt, _, 0))
        .Times(Exactly(1))
        .WillOnce(Return(0));
    ASSERT_EQ(endpoint_->Initialize(), kSuccess);

    const auto finalize_status = endpoint_->Finalize();
    ASSERT_EQ(finalize_status, kSuccess);
}

TEST_F(ClientTest, FailedSend) {
    // The Send to program 0 fails, and Initialize() must return that error code unchanged.
    const int send_error = 1;
    ExpectEndpointTypeInquiry(protocol_);
    EXPECT_CALL(*protocol_, Send(_, 1, kTypeInt, _, 0))
        .Times(Exactly(1))
        .WillOnce(Return(send_error));
    ASSERT_EQ(endpoint_->Initialize(), send_error);
}

/// With communication recording enabled, a client endpoint's Initialize + ProbeSize
/// + Finalize sequence must leave a log "MCL_LOG_<id>". The log is read back with
/// TransferLogger and checked entry by entry.
TEST(RecordTest, Record) {
    SET_ENV_CONTROL(env_record_communication, "ON");
    auto protocol = std::make_shared<MockTransport>();
    const int client_id = 4;
    const int size = 5;
    std::string file_name = "MCL_LOG_" + std::to_string(client_id);

    // Removes the log file on every exit path, including the early return of a
    // failed ASSERT_*. Declared before the endpoint and the logger so that it is
    // destroyed after both of them.
    struct RemoveFileOnExit {
        std::string path;
        ~RemoveFileOnExit() { std::remove(path.c_str()); }
    } remove_log{file_name};

    {
        // Scoped: the endpoint is destroyed before the log is read back, preserving the
        // original delete-before-read order. Presumably this is when the recording is
        // flushed/closed — TODO confirm.
        EXPECT_CALL(*protocol, program_id()).Times(Exactly(1)).WillRepeatedly(Return(client_id));
        Endpoint endpoint(protocol);
        ASSERT_EQ(endpoint.id(), client_id);

        EXPECT_CALL(*protocol, is_server()).Times(Exactly(1)).WillOnce(Return(false));
        EXPECT_CALL(*protocol, is_client()).Times(Exactly(1)).WillOnce(Return(true));
        EXPECT_CALL(*protocol, Send(_, 1, kTypeInt, _, 0)).Times(Exactly(1)).WillOnce(Return(0));
        ASSERT_EQ(endpoint.Initialize(), kSuccess);

        EXPECT_CALL(*protocol, ProbeSize(kTypeInt, 0)).Times(Exactly(1)).WillOnce(Return(size));
        ASSERT_EQ(endpoint.ProbeSize(kTypeInt, 0), size);
        ASSERT_EQ(endpoint.Finalize(), kSuccess);
    }

    TransferLogger log(file_name, READ);

    int value = 0;
    // Pre-set to the opposite of the expected value, so an unwritten buffer fails the check.
    bool value_b = true;
    ASSERT_EQ(log.Read(&value_b, 1, kTypeChar, kSkipSendTag, kSkipSendTag), kSuccess);
    ASSERT_EQ(value_b, false);
    // NOTE(review): 410 is a literal tag, presumably the one Initialize uses to send the
    // client id to program 0. Confirm, and use a named constant if one exists.
    ASSERT_EQ(log.Read(&value, 1, kTypeInt, 410, 0), kSuccess);
    ASSERT_EQ(value, client_id);
    ASSERT_EQ(log.Read(&value, 1, kTypeInt, kProbeTag, 0), kSuccess);
    ASSERT_EQ(value, size);
}

/// Test entry point: initializes gMock (and through it GoogleTest), then runs every
/// registered test. The process exit code is RUN_ALL_TESTS()'s result.
int main(int argc, char** argv) {
    ::testing::InitGoogleMock(&argc, argv);
    return RUN_ALL_TESTS();
}
