/*
 * Copyright 2022-2024 Frederico de Oliveira Linhares
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "graphics_pipeline_3d.hpp"

#include <algorithm>
#include <array>
#include <cstddef>
#include <iterator>
#include <stdexcept>

#include "../int/core.hpp"
#include "skeletal_mesh_vertex.hpp"
#include "uniform_data_object.hpp"

namespace
{

void
load_pipeline(void *obj)
{
  auto self = static_cast<BluCat::GRA::GraphicsPipeline3DSkeletal*>(obj);

  VkPipelineShaderStageCreateInfo vert_shader_stage_info = {};
  vert_shader_stage_info.sType =
      VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
  vert_shader_stage_info.pNext = nullptr;
  vert_shader_stage_info.flags = 0;
  vert_shader_stage_info.stage = VK_SHADER_STAGE_VERTEX_BIT;
  vert_shader_stage_info.module =
    BluCat::INT::core.vk_device_with_swapchain->vert3d_skeletal_shader_module;
  vert_shader_stage_info.pName = "main";
  vert_shader_stage_info.pSpecializationInfo = nullptr;

  VkPipelineShaderStageCreateInfo frag_shader_stage_info = {};
  frag_shader_stage_info.sType =
      VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
  frag_shader_stage_info.pNext = nullptr;
  frag_shader_stage_info.flags = 0;
  frag_shader_stage_info.stage = VK_SHADER_STAGE_FRAGMENT_BIT;
  frag_shader_stage_info.module =
    BluCat::INT::core.vk_device_with_swapchain->frag3d_shader_module;
  frag_shader_stage_info.pName = "main";
  frag_shader_stage_info.pSpecializationInfo = nullptr;

  VkPipelineShaderStageCreateInfo shader_stages[] = {
    vert_shader_stage_info,
    frag_shader_stage_info
  };

  VkVertexInputBindingDescription vertex_input_binding{};
  vertex_input_binding.binding = 0;
  vertex_input_binding.stride = sizeof(BluCat::GRA::SkeletalMeshVertex);
  vertex_input_binding.inputRate = VK_VERTEX_INPUT_RATE_VERTEX;

  std::array<VkVertexInputAttributeDescription, 5> vertex_attribute{};
  // Position.
  vertex_attribute[0].location = 0;
  vertex_attribute[0].binding = 0;
  vertex_attribute[0].format = VK_FORMAT_R32G32B32_SFLOAT;
  vertex_attribute[0].offset = offsetof(BluCat::GRA::SkeletalMeshVertex, position);
  // Normal.
  vertex_attribute[1].location = 1;
  vertex_attribute[1].binding = 0;
  vertex_attribute[1].format = VK_FORMAT_R32G32B32_SFLOAT;
  vertex_attribute[1].offset = offsetof(BluCat::GRA::SkeletalMeshVertex, normal);
  // Texture coordinate.
  vertex_attribute[2].location = 2;
  vertex_attribute[2].binding = 0;
  vertex_attribute[2].format = VK_FORMAT_R32G32_SFLOAT;
  vertex_attribute[2].offset = offsetof(BluCat::GRA::SkeletalMeshVertex, texture_coord);
  // Bones ids.
  vertex_attribute[3].location = 3;
  vertex_attribute[3].binding = 0;
  vertex_attribute[3].format = VK_FORMAT_R32G32B32A32_SINT;
  vertex_attribute[3].offset = offsetof(BluCat::GRA::SkeletalMeshVertex, bone_ids);
  // Bones weights.
  vertex_attribute[4].location = 4;
  vertex_attribute[4].binding = 0;
  vertex_attribute[4].format = VK_FORMAT_R32G32B32A32_SFLOAT;
  vertex_attribute[4].offset = offsetof(BluCat::GRA::SkeletalMeshVertex, bone_weights);

  VkPipelineVertexInputStateCreateInfo vertex_input_info = {};
  vertex_input_info.sType =
      VK_STRUCTURE_TYPE_PIPELINE_VERTEX_INPUT_STATE_CREATE_INFO;
  vertex_input_info.pNext = nullptr;
  vertex_input_info.flags = 0;
  vertex_input_info.vertexBindingDescriptionCount = 1;
  vertex_input_info.pVertexBindingDescriptions = &vertex_input_binding;
  vertex_input_info.vertexAttributeDescriptionCount =
      static_cast<uint32_t>(vertex_attribute.size());
  vertex_input_info.pVertexAttributeDescriptions = vertex_attribute.data();

  VkPipelineInputAssemblyStateCreateInfo input_assembly = {};
  input_assembly.sType =
      VK_STRUCTURE_TYPE_PIPELINE_INPUT_ASSEMBLY_STATE_CREATE_INFO;
  input_assembly.pNext = nullptr;
  input_assembly.flags = 0;
  input_assembly.topology = VK_PRIMITIVE_TOPOLOGY_TRIANGLE_LIST;
  input_assembly.primitiveRestartEnable = VK_FALSE;

  VkViewport viewport = {};
  viewport.x = 0;
  viewport.y = 0;
  viewport.width = BluCat::INT::core.display_width;
  viewport.height = BluCat::INT::core.display_height;
  viewport.minDepth = 0.0f;
  viewport.maxDepth = 1.0f;

  VkRect2D scissor = {};
  scissor.offset = {0, 0};
  scissor.extent = {BluCat::INT::core.display_width, BluCat::INT::core.display_height};

  VkPipelineViewportStateCreateInfo viewport_state = {};
  viewport_state.sType =
      VK_STRUCTURE_TYPE_PIPELINE_VIEWPORT_STATE_CREATE_INFO;
  viewport_state.pNext = nullptr;
  viewport_state.flags = 0;
  viewport_state.viewportCount = 1;
  viewport_state.pViewports = &viewport;
  viewport_state.scissorCount = 1;
  viewport_state.pScissors = &scissor;

  VkPipelineRasterizationStateCreateInfo rasterizer = {};
  rasterizer.sType =
      VK_STRUCTURE_TYPE_PIPELINE_RASTERIZATION_STATE_CREATE_INFO;
  rasterizer.pNext = nullptr;
  rasterizer.flags = 0;
  rasterizer.depthClampEnable = VK_FALSE;
  rasterizer.rasterizerDiscardEnable = VK_FALSE;
  rasterizer.polygonMode = VK_POLYGON_MODE_FILL;
  rasterizer.cullMode = VK_CULL_MODE_NONE;
  rasterizer.frontFace = VK_FRONT_FACE_COUNTER_CLOCKWISE;
  rasterizer.depthBiasEnable = VK_FALSE;
  rasterizer.depthBiasConstantFactor = 0.0f;
  rasterizer.depthBiasClamp = 0.0f;
  rasterizer.depthBiasSlopeFactor = 0.0f;
  rasterizer.lineWidth = 1.0f;

  VkPipelineMultisampleStateCreateInfo multisampling = {};
  multisampling.sType =
      VK_STRUCTURE_TYPE_PIPELINE_MULTISAMPLE_STATE_CREATE_INFO;
  multisampling.rasterizationSamples = VK_SAMPLE_COUNT_1_BIT;
  multisampling.sampleShadingEnable = VK_FALSE;
  multisampling.minSampleShading = 1.0f;
  multisampling.pSampleMask = nullptr;
  multisampling.alphaToCoverageEnable = VK_FALSE;
  multisampling.alphaToOneEnable = VK_FALSE;

  VkPipelineDepthStencilStateCreateInfo depth_stencil = {};
  depth_stencil.sType =
    VK_STRUCTURE_TYPE_PIPELINE_DEPTH_STENCIL_STATE_CREATE_INFO;
  depth_stencil.depthTestEnable = VK_TRUE;
  depth_stencil.depthWriteEnable = VK_TRUE;
  depth_stencil.depthCompareOp = VK_COMPARE_OP_LESS;
  depth_stencil.depthBoundsTestEnable = VK_FALSE;
  depth_stencil.minDepthBounds = 0.0f;
  depth_stencil.maxDepthBounds = 1.0f;
  depth_stencil.stencilTestEnable = VK_FALSE;
  depth_stencil.front = {};
  depth_stencil.back = {};

  VkPipelineColorBlendAttachmentState color_blend_attachment = {};
  color_blend_attachment.blendEnable = VK_FALSE;
  color_blend_attachment.srcColorBlendFactor = VK_BLEND_FACTOR_ONE;
  color_blend_attachment.dstColorBlendFactor = VK_BLEND_FACTOR_ZERO;
  color_blend_attachment.colorBlendOp = VK_BLEND_OP_ADD;
  color_blend_attachment.srcAlphaBlendFactor = VK_BLEND_FACTOR_ONE;
  color_blend_attachment.dstAlphaBlendFactor = VK_BLEND_FACTOR_ZERO;
  color_blend_attachment.alphaBlendOp = VK_BLEND_OP_ADD;
  color_blend_attachment.colorWriteMask =
      VK_COLOR_COMPONENT_R_BIT | VK_COLOR_COMPONENT_G_BIT |
      VK_COLOR_COMPONENT_B_BIT | VK_COLOR_COMPONENT_A_BIT;

  VkPipelineColorBlendStateCreateInfo color_blending = {};
  color_blending.sType =
      VK_STRUCTURE_TYPE_PIPELINE_COLOR_BLEND_STATE_CREATE_INFO;
  color_blending.pNext = nullptr;
  color_blending.flags = 0;
  color_blending.logicOpEnable = VK_FALSE;
  color_blending.logicOp = VK_LOGIC_OP_COPY;
  color_blending.attachmentCount = 1;
  color_blending.pAttachments = &color_blend_attachment;
  color_blending.blendConstants[0] = 0.0f;
  color_blending.blendConstants[1] = 0.0f;
  color_blending.blendConstants[2] = 0.0f;
  color_blending.blendConstants[3] = 0.0f;

  VkDynamicState dynamic_states[] = {
    VK_DYNAMIC_STATE_VIEWPORT,
    VK_DYNAMIC_STATE_LINE_WIDTH
  };

  VkPipelineDynamicStateCreateInfo dynamic_state_info = {};
  dynamic_state_info.sType =
      VK_STRUCTURE_TYPE_PIPELINE_DYNAMIC_STATE_CREATE_INFO;
  dynamic_state_info.dynamicStateCount = 2;
  dynamic_state_info.pDynamicStates = dynamic_states;

  VkGraphicsPipelineCreateInfo pipeline_info{};
  pipeline_info.sType = VK_STRUCTURE_TYPE_GRAPHICS_PIPELINE_CREATE_INFO;
  pipeline_info.pNext = nullptr;
  pipeline_info.flags = 0;
  pipeline_info.stageCount = 2;
  pipeline_info.pStages = shader_stages;
  pipeline_info.pVertexInputState = &vertex_input_info;
  pipeline_info.pInputAssemblyState = &input_assembly;
  pipeline_info.pTessellationState = nullptr;
  pipeline_info.pViewportState = &viewport_state;
  pipeline_info.pRasterizationState = &rasterizer;
  pipeline_info.pMultisampleState = &multisampling;
  pipeline_info.pDepthStencilState = &depth_stencil;
  pipeline_info.pColorBlendState = &color_blending;
  pipeline_info.pDynamicState = &dynamic_state_info;
  pipeline_info.layout = BluCat::INT::core.vk_graphics_pipeline_3d_layout->pipeline;
  pipeline_info.renderPass = BluCat::INT::core.vk_render_pass->pipeline_3d;
  pipeline_info.subpass = 0;
  pipeline_info.basePipelineHandle = VK_NULL_HANDLE;
  pipeline_info.basePipelineIndex = -1;

  if(vkCreateGraphicsPipelines(
       BluCat::INT::core.vk_device_with_swapchain->device, VK_NULL_HANDLE, 1,
       &pipeline_info, nullptr, &self->graphic_pipeline)
     != VK_SUCCESS)
    throw CommandError{"Failed to create graphics pipeline."};
}

/// Destroys the pipeline that load_pipeline stored in `graphic_pipeline`.
///
/// This is the "unload" step of the file-local `loader` CommandChain.
///
/// @param obj pointer to the BluCat::GRA::GraphicsPipeline3DSkeletal that owns
///            the pipeline handle.
void
unload_pipeline(void *obj)
{
  auto *const pipeline_owner =
    static_cast<BluCat::GRA::GraphicsPipeline3DSkeletal*>(obj);
  const auto device = BluCat::INT::core.vk_device_with_swapchain->device;

  vkDestroyPipeline(device, pipeline_owner->graphic_pipeline, nullptr);
}

// Load/unload step pairs for this pipeline. The constructor runs
// loader.execute(this) and the destructor runs loader.revert(this).
const CommandChain loader{{&load_pipeline, &unload_pipeline}};

}

namespace BluCat::GRA
{

/// Builds the skeletal 3D pipeline by executing every step of the file-local
/// `loader` chain on this object.
GraphicsPipeline3DSkeletal::GraphicsPipeline3DSkeletal()
{ loader.execute(this); }

/// Releases the pipeline by reverting the file-local `loader` chain on this
/// object.
GraphicsPipeline3DSkeletal::~GraphicsPipeline3DSkeletal()
{ loader.revert(this); }

/// Records draw commands for every skeletal model queued for `current_frame`.
///
/// For each mesh in the renderer's `skeletal_models_to_draw[current_frame]`,
/// the mesh's vertex and index buffers are bound once. Then, for each instance
/// of that mesh, this:
/// - binds four descriptor sets (world light, view, instance, texture);
/// - records an indexed draw;
/// - advances the instance with `tick(INT::core.delta_time)`;
/// - uploads its base matrix and bone matrices into
///   `uniform_buffers[image_index]`.
///
/// @param view                view whose `descriptor_sets_3d` are bound at set 1.
/// @param draw_command_buffer command buffer being recorded.
/// @param current_frame       index into the renderer's per-frame draw lists.
/// @param image_index         selects descriptor sets and uniform buffers.
void
GraphicsPipeline3DSkeletal::draw(
  std::shared_ptr<View3D> view, const VkCommandBuffer draw_command_buffer,
  const size_t current_frame, const uint32_t image_index)
{
	vkCmdBindPipeline(
		draw_command_buffer, VK_PIPELINE_BIND_POINT_GRAPHICS,
		this->graphic_pipeline);

	// Draw models
	for(auto& [skeletal_mesh, instances]:
				INT::core.vk_renderer->skeletal_models_to_draw[current_frame])
	{
		VkBuffer vertex_buffers[]{skeletal_mesh->vertex_buffer->buffer};
		VkDeviceSize offsets[]{0};

		vkCmdBindVertexBuffers(
			draw_command_buffer, 0, 1, vertex_buffers, offsets);
		vkCmdBindIndexBuffer(
			draw_command_buffer, skeletal_mesh->index_buffer->buffer, 0,
			VK_INDEX_TYPE_UINT32);

		for(auto &instance: instances)
		{ // Object matrix.
			glm::mat4 translation_matrix{1.0f};
			translation_matrix = glm::translate(
				translation_matrix, *instance->position);
			glm::mat4 rotation_matrix{glm::toMat4(*instance->orientation)};

			std::array<VkDescriptorSet, 4> vk_descriptor_sets{
				INT::core.vk_light->descriptor_sets_world[image_index],
				view->descriptor_sets_3d[image_index],
				instance->descriptor_sets[image_index],
				instance->texture->descriptor_sets[image_index]};

			vkCmdBindDescriptorSets(
				draw_command_buffer, VK_PIPELINE_BIND_POINT_GRAPHICS,
				INT::core.vk_graphics_pipeline_3d_layout->pipeline, 0,
				static_cast<uint32_t>(vk_descriptor_sets.size()),
				vk_descriptor_sets.data(), 0, nullptr);

			vkCmdDrawIndexed(
				draw_command_buffer, skeletal_mesh->index_count, 1, 0, 0, 0);

			// NOTE(review): the uniform buffer is filled after the draw is
			// recorded; this relies on the command buffer being submitted only
			// after draw() returns — confirm against the renderer.
			BluCat::GRA::UDOSkeletalModel udo_skeletal_model{};
			instance->tick(INT::core.delta_time);
			udo_skeletal_model.base_matrix = translation_matrix * rotation_matrix;

			// Clamp the copy so an instance with more bone transforms than the
			// UDO can hold never writes past bone_matrices. std::size requires
			// bone_matrices to have a compile-time bound (e.g. a C array).
			const std::size_t bone_count = std::min<std::size_t>(
				instance->bone_transforms.size(),
				std::size(udo_skeletal_model.bone_matrices));
			std::copy_n(instance->bone_transforms.begin(), bone_count,
									udo_skeletal_model.bone_matrices);
			instance->uniform_buffers[image_index].copy_data(&udo_skeletal_model);
		}
	}
}

}