/* Copyright (C) 2024 Wildfire Games.
 * This file is part of 0 A.D.
 *
 * 0 A.D. is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 2 of the License, or
 * (at your option) any later version.
 *
 * 0 A.D. is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with 0 A.D.  If not, see <http://www.gnu.org/licenses/>.
 */

#include "precompiled.h"

#include "renderer/DebugRenderer.h"

#include "graphics/Camera.h"
#include "graphics/Color.h"
#include "graphics/ShaderManager.h"
#include "graphics/ShaderProgram.h"
#include "lib/hash.h"
#include "maths/BoundingBoxAligned.h"
#include "maths/Brush.h"
#include "maths/Matrix3D.h"
#include "maths/Vector3D.h"
#include "ps/CStrInternStatic.h"
#include "renderer/backend/IDeviceCommandContext.h"
#include "renderer/Renderer.h"
#include "renderer/SceneRenderer.h"

#include <cmath>

void CDebugRenderer::Initialize()
{
	// All debug geometry emitted by this renderer consists of bare positions:
	// a single per-vertex POSITION attribute made of three 32-bit floats.
	const std::array<Renderer::Backend::SVertexAttributeFormat, 1> positionOnlyFormat{{
		{
			Renderer::Backend::VertexAttributeStream::POSITION,
			Renderer::Backend::Format::R32G32B32_SFLOAT,
			0,
			sizeof(float) * 3,
			Renderer::Backend::VertexAttributeRate::PER_VERTEX,
			0
		}
	}};
	// Cache the layout so every Draw* call can bind it without rebuilding it.
	m_VertexInputLayout = g_Renderer.GetVertexInputLayout(positionOnlyFormat);
}

void CDebugRenderer::DrawLine(
	const CVector3D& from, const CVector3D& to, const CColor& color,
	const float width, const bool depthTestEnabled)
{
	if (from == to)
		return;

	DrawLine({from, to}, color, width, depthTestEnabled);
}

void CDebugRenderer::DrawLine(
	const std::vector<CVector3D>& line, const CColor& color,
	const float width, const bool depthTestEnabled)
{
	if (line.size() <= 1)
		return;

	Renderer::Backend::IDeviceCommandContext* deviceCommandContext =
		g_Renderer.GetDeviceCommandContext();

	CShaderTechniquePtr debugLineTech =
		GetShaderTechnique(str_debug_line, color, depthTestEnabled);
	deviceCommandContext->SetGraphicsPipelineState(
		debugLineTech->GetGraphicsPipelineState());
	deviceCommandContext->BeginPass();

	const CCamera& viewCamera = g_Renderer.GetSceneRenderer().GetViewCamera();

	Renderer::Backend::IShaderProgram* debugLineShader = debugLineTech->GetShader();
	const CMatrix3D transform = viewCamera.GetViewProjection();
	deviceCommandContext->SetUniform(
		debugLineShader->GetBindingSlot(str_transform), transform.AsFloatArray());
	deviceCommandContext->SetUniform(
		debugLineShader->GetBindingSlot(str_color), color.AsFloatArray());

	const CVector3D cameraIn = viewCamera.GetOrientation().GetIn();

	std::vector<float> vertices;
	vertices.reserve(line.size() * 6 * 3);
#define ADD(position) \
	vertices.emplace_back((position).X); \
	vertices.emplace_back((position).Y); \
	vertices.emplace_back((position).Z);

	for (size_t idx = 1; idx < line.size(); ++idx)
	{
		const CVector3D from = line[idx - 1];
		const CVector3D to = line[idx];
		const CVector3D direction = (to - from).Normalized();
		const CVector3D view = direction.Dot(cameraIn) > 0.9f ?
			CVector3D(0.0f, 1.0f, 0.0f) :
			cameraIn;
		const CVector3D offset = view.Cross(direction).Normalized() * width;

		ADD(from + offset)
		ADD(to - offset)
		ADD(to + offset)
		ADD(from + offset)
		ADD(from - offset)
		ADD(to - offset)
	}

#undef ADD

	deviceCommandContext->SetVertexInputLayout(m_VertexInputLayout);
	deviceCommandContext->SetVertexBufferData(
		0, vertices.data(), vertices.size() * sizeof(vertices[0]));

	deviceCommandContext->Draw(0, vertices.size() / 3);

	deviceCommandContext->EndPass();
}

void CDebugRenderer::DrawCircle(const CVector3D& origin, const float radius, const CColor& color)
{
	CShaderTechniquePtr debugCircleTech =
		GetShaderTechnique(str_debug_line, color);

	Renderer::Backend::IDeviceCommandContext* deviceCommandContext =
		g_Renderer.GetDeviceCommandContext();

	deviceCommandContext->SetGraphicsPipelineState(
		debugCircleTech->GetGraphicsPipelineState());
	deviceCommandContext->BeginPass();

	const CCamera& camera = g_Renderer.GetSceneRenderer().GetViewCamera();

	Renderer::Backend::IShaderProgram* debugCircleShader = debugCircleTech->GetShader();

	const CMatrix3D transform = camera.GetViewProjection();
	deviceCommandContext->SetUniform(
		debugCircleShader->GetBindingSlot(str_transform), transform.AsFloatArray());
	deviceCommandContext->SetUniform(
		debugCircleShader->GetBindingSlot(str_color), color.AsFloatArray());

	const CVector3D cameraUp = camera.GetOrientation().GetUp();
	const CVector3D cameraLeft = camera.GetOrientation().GetLeft();

	std::vector<float> vertices;
#define ADD(position) \
	vertices.emplace_back((position).X); \
	vertices.emplace_back((position).Y); \
	vertices.emplace_back((position).Z);

	constexpr size_t segments = 16;
	for (size_t idx = 0; idx <= segments; ++idx)
	{
		const float angle = M_PI * 2.0f * idx / segments;
		const CVector3D offset = cameraUp * sin(angle) - cameraLeft * cos(angle);
		const float nextAngle = M_PI * 2.0f * (idx + 1) / segments;
		const CVector3D nextOffset = cameraUp * sin(nextAngle) - cameraLeft * cos(nextAngle);
		ADD(origin)
		ADD(origin + offset * radius)
		ADD(origin + nextOffset * radius)
	}

#undef ADD

	deviceCommandContext->SetVertexInputLayout(m_VertexInputLayout);
	deviceCommandContext->SetVertexBufferData(
		0, vertices.data(), vertices.size() * sizeof(vertices[0]));

	deviceCommandContext->Draw(0, vertices.size() / 3);

	deviceCommandContext->EndPass();
}

void CDebugRenderer::DrawCameraFrustum(const CCamera& camera, const CColor& color, int intermediates, bool wireframe)
{
	CCamera::Quad nearPoints{camera.GetViewQuad(camera.GetNearPlane())};
	for (CVector3D& point : nearPoints)
		point = camera.m_Orientation.Transform(point);
	CCamera::Quad farPoints{camera.GetViewQuad(camera.GetFarPlane())};
	for (CVector3D& point : farPoints)
		point = camera.m_Orientation.Transform(point);

	CShaderTechniquePtr overlayTech =
		GetShaderTechnique(str_debug_line, color, true, wireframe);

	Renderer::Backend::IDeviceCommandContext* deviceCommandContext =
		g_Renderer.GetDeviceCommandContext();
	deviceCommandContext->SetGraphicsPipelineState(
		overlayTech->GetGraphicsPipelineState());
	deviceCommandContext->BeginPass();

	Renderer::Backend::IShaderProgram* overlayShader = overlayTech->GetShader();

	const CMatrix3D transform = g_Renderer.GetSceneRenderer().GetViewCamera().GetViewProjection();
	deviceCommandContext->SetUniform(
		overlayShader->GetBindingSlot(str_transform), transform.AsFloatArray());
	deviceCommandContext->SetUniform(
		overlayShader->GetBindingSlot(str_color), color.AsFloatArray());

	std::vector<float> vertices;
#define ADD(position) \
	vertices.emplace_back((position).X); \
	vertices.emplace_back((position).Y); \
	vertices.emplace_back((position).Z);

	// Near plane.
	ADD(nearPoints[0]);
	ADD(nearPoints[1]);
	ADD(nearPoints[2]);
	ADD(nearPoints[0]);
	ADD(nearPoints[2]);
	ADD(nearPoints[3]);

	// Far plane.
	ADD(farPoints[0]);
	ADD(farPoints[1]);
	ADD(farPoints[2]);
	ADD(farPoints[0]);
	ADD(farPoints[2]);
	ADD(farPoints[3]);

	// Intermediate planes.
	CVector3D intermediatePoints[4];
	for (int i = 0; i < intermediates; ++i)
	{
		const float t = (i + 1.0f) / (intermediates + 1.0f);

		for (int j = 0; j < 4; ++j)
			intermediatePoints[j] = nearPoints[j] * t + farPoints[j] * (1.0f - t);

		ADD(intermediatePoints[0]);
		ADD(intermediatePoints[1]);
		ADD(intermediatePoints[2]);
		ADD(intermediatePoints[0]);
		ADD(intermediatePoints[2]);
		ADD(intermediatePoints[3]);
	}

	deviceCommandContext->SetVertexInputLayout(m_VertexInputLayout);
	deviceCommandContext->SetVertexBufferData(
		0, vertices.data(), vertices.size() * sizeof(vertices[0]));

	deviceCommandContext->Draw(0, vertices.size() / 3);

	vertices.clear();

	// Connection lines.
	for (int i = 0; i < 4; ++i)
	{
		const int nextI = (i + 1) % 4;
		ADD(nearPoints[i]);
		ADD(farPoints[nextI]);
		ADD(farPoints[i]);
		ADD(nearPoints[i]);
		ADD(nearPoints[nextI]);
		ADD(farPoints[nextI]);
	}

	deviceCommandContext->SetVertexInputLayout(m_VertexInputLayout);
	deviceCommandContext->SetVertexBufferData(
		0, vertices.data(), vertices.size() * sizeof(vertices[0]));

	deviceCommandContext->Draw(0, vertices.size() / 3);
#undef ADD

	deviceCommandContext->EndPass();
}

void CDebugRenderer::DrawBoundingBox(
	const CBoundingBoxAligned& boundingBox, const CColor& color,
	bool wireframe)
{
	DrawBoundingBox(
		boundingBox, color,
		g_Renderer.GetSceneRenderer().GetViewCamera().GetViewProjection(), wireframe);
}

void CDebugRenderer::DrawBoundingBox(
	const CBoundingBoxAligned& boundingBox, const CColor& color,
	const CMatrix3D& transform, bool wireframe)
{
	CShaderTechniquePtr shaderTech =
		GetShaderTechnique(str_debug_line, color, true, wireframe);

	Renderer::Backend::IDeviceCommandContext* deviceCommandContext =
		g_Renderer.GetDeviceCommandContext();
	deviceCommandContext->SetGraphicsPipelineState(
		shaderTech->GetGraphicsPipelineState());
	deviceCommandContext->BeginPass();

	Renderer::Backend::IShaderProgram* shader = shaderTech->GetShader();

	deviceCommandContext->SetUniform(
		shader->GetBindingSlot(str_transform), transform.AsFloatArray());
	deviceCommandContext->SetUniform(
		shader->GetBindingSlot(str_color), color.AsFloatArray());

	std::vector<float> data;

#define ADD_FACE(x, y, z) \
	ADD_PT(0, 0, x, y, z); ADD_PT(1, 0, x, y, z); ADD_PT(1, 1, x, y, z); \
	ADD_PT(1, 1, x, y, z); ADD_PT(0, 1, x, y, z); ADD_PT(0, 0, x, y, z);
#define ADD_PT(u_, v_, x, y, z) \
	STMT(int u = u_; int v = v_; \
		data.push_back(boundingBox[x].X); \
		data.push_back(boundingBox[y].Y); \
		data.push_back(boundingBox[z].Z); \
	)

	ADD_FACE(u, v, 0);
	ADD_FACE(0, u, v);
	ADD_FACE(u, 0, 1-v);
	ADD_FACE(u, 1-v, 1);
	ADD_FACE(1, u, 1-v);
	ADD_FACE(u, 1, v);

#undef ADD_FACE

	deviceCommandContext->SetVertexInputLayout(m_VertexInputLayout);
	deviceCommandContext->SetVertexBufferData(
		0, data.data(), data.size() * sizeof(data[0]));

	deviceCommandContext->Draw(0, 6 * 6);

	deviceCommandContext->EndPass();
}

void CDebugRenderer::DrawBrush(const CBrush& brush, const CColor& color, bool wireframe)
{
	CShaderTechniquePtr shaderTech =
		GetShaderTechnique(str_debug_line, color, true, wireframe);

	Renderer::Backend::IDeviceCommandContext* deviceCommandContext =
		g_Renderer.GetDeviceCommandContext();
	deviceCommandContext->SetGraphicsPipelineState(
		shaderTech->GetGraphicsPipelineState());
	deviceCommandContext->BeginPass();

	Renderer::Backend::IShaderProgram* shader = shaderTech->GetShader();

	const CMatrix3D transform = g_Renderer.GetSceneRenderer().GetViewCamera().GetViewProjection();
	deviceCommandContext->SetUniform(
		shader->GetBindingSlot(str_transform), transform.AsFloatArray());
	deviceCommandContext->SetUniform(
		shader->GetBindingSlot(str_color), color.AsFloatArray());

	std::vector<float> data;

	std::vector<std::vector<size_t>> faces;
	brush.GetFaces(faces);

#define ADD_VERT(a) \
	STMT( \
		data.push_back(brush.GetVertices()[faces[i][a]].X); \
		data.push_back(brush.GetVertices()[faces[i][a]].Y); \
		data.push_back(brush.GetVertices()[faces[i][a]].Z); \
	)

	for (size_t i = 0; i < faces.size(); ++i)
	{
		// Triangulate into (0,1,2), (0,2,3), ...
		for (size_t j = 1; j < faces[i].size() - 2; ++j)
		{
			ADD_VERT(0);
			ADD_VERT(j);
			ADD_VERT(j+1);
		}
	}

#undef ADD_VERT

	deviceCommandContext->SetVertexInputLayout(m_VertexInputLayout);
	deviceCommandContext->SetVertexBufferData(
		0, data.data(), data.size() * sizeof(data[0]));

	deviceCommandContext->Draw(0, data.size() / 5);

	deviceCommandContext->EndPass();
}

size_t CDebugRenderer::ShaderTechniqueKeyHash::operator()(
	const ShaderTechniqueKey& key) const
{
	// Mixes every field compared by ShaderTechniqueKeyEqual, in a fixed order,
	// so that equal keys always produce equal hashes.
	size_t result = 0;
	const auto mix = [&result](const auto& value) { hash_combine(result, value); };
	mix(key.name.GetHash());
	mix(key.transparent);
	mix(key.depthTestEnabled);
	mix(key.wireframe);
	return result;
}

bool CDebugRenderer::ShaderTechniqueKeyEqual::operator()(
	const ShaderTechniqueKey& lhs, const ShaderTechniqueKey& rhs) const
{
	// Keys match only when the technique name and all three state flags agree.
	if (!(lhs.name == rhs.name))
		return false;
	if (lhs.transparent != rhs.transparent)
		return false;
	if (lhs.depthTestEnabled != rhs.depthTestEnabled)
		return false;
	return lhs.wireframe == rhs.wireframe;
}

const CShaderTechniquePtr& CDebugRenderer::GetShaderTechnique(
	const CStrIntern name, const CColor& color, const bool depthTestEnabled,
	const bool wireframe)
{
	// Techniques are cached per (name, transparency, depth test, wireframe).
	// Any color whose alpha is not exactly 1 is treated as transparent.
	const ShaderTechniqueKey key{
		name, color.a != 1.0f, depthTestEnabled, wireframe};
	CShaderTechniquePtr& cachedTechnique = m_ShaderTechniqueMapping[key];
	if (!cachedTechnique)
	{
		// No technique cached for this key yet: load the effect and adjust its
		// pipeline state to match the key.
		const auto adjustPipelineState =
			[key](Renderer::Backend::SGraphicsPipelineStateDesc& pipelineStateDesc)
		{
			auto& blendState = pipelineStateDesc.blendState;
			blendState.enabled = key.transparent;
			if (key.transparent)
			{
				// Standard alpha blending for both the color and alpha channels.
				blendState.srcColorBlendFactor = Renderer::Backend::BlendFactor::SRC_ALPHA;
				blendState.srcAlphaBlendFactor = Renderer::Backend::BlendFactor::SRC_ALPHA;
				blendState.dstColorBlendFactor = Renderer::Backend::BlendFactor::ONE_MINUS_SRC_ALPHA;
				blendState.dstAlphaBlendFactor = Renderer::Backend::BlendFactor::ONE_MINUS_SRC_ALPHA;
				blendState.colorBlendOp = Renderer::Backend::BlendOp::ADD;
				blendState.alphaBlendOp = Renderer::Backend::BlendOp::ADD;
			}

			pipelineStateDesc.depthStencilState.depthTestEnabled = key.depthTestEnabled;

			auto& rasterizationState = pipelineStateDesc.rasterizationState;
			if (key.wireframe)
				rasterizationState.polygonMode = Renderer::Backend::PolygonMode::LINE;
			// Debug geometry is visible from both sides.
			rasterizationState.cullMode = Renderer::Backend::CullMode::NONE;
		};
		cachedTechnique = g_Renderer.GetShaderManager().LoadEffect(
			name, {}, adjustPipelineState);
	}
	return cachedTechnique;
}
