#include "pyramidavs.h"
#include <math.h>

// AviSynth video filter that blurs every plane of its child clip by running
// `iterations` box-blur passes per axis. The blur amount is either fixed at
// construction time or obtained per frame from a named script function.
class FastBlur : public GenericVideoFilter {
private:
	// Working buffers, one entry per plane. Both are built from the child clip in
	// the constructor; they differ only in the fifth constructor argument
	// (false for `pyramid`, true for `transpose`).
	PyramidAVS* pyramid;
	PyramidAVS* transpose;
//	PyramidAVS* _pyramid2;
	// Not referenced by any code in this file.
	std::vector<Pyramid*> benchmarked;
	// Box radii (as returned by SigmaToBoxRadius) used for non-chroma planes.
	float blur_x;
	float blur_y;
	// Box radii used for the PLANAR_U / PLANAR_V planes.
	float blur_x_c;
	float blur_y_c;
	// Number of X+Y blur passes applied to each plane in GetFrame.
	int iterations;
	// Forwarded to the PyramidAVS constructors.
	bool dither;
	// Combined with each plane's own gamma flag before Copy()/Out() in GetFrame.
	bool gamma;
	// Converts a blur sigma into the box radius needed when `iterations` passes are applied.
	float SigmaToBoxRadius(double s, int iterations);
	// Name of the script function queried for blur values, or null when fixed values are used.
	const char* blur_func;
	// Recomputes blur_x/blur_y/blur_x_c/blur_y_c from the given sigmas.
	void SetupBlurs(double _blur_x, double _blur_y);
	// Invokes blur_func for frame n and returns the (x, y) values it produced.
	std::pair<float, float> FuncBlur(int n, IScriptEnvironment* env);
	bool has_at_least_v8; // v8 interface frameprop copy support

//	PVideoFrame test;

public:
	// Pass a non-null _blur_func to drive the blur per frame; otherwise _xblur/_yblur are used.
	FastBlur(PClip _child, double _xblur, double _yblur, const char* _blur_func, int iterations, bool _dither, bool _gamma, float _threads, IScriptEnvironment* env);
	~FastBlur();
	// Produces the blurred version of frame n.
	PVideoFrame __stdcall GetFrame(int n, IScriptEnvironment* env);
	// Answers the CACHE_GET_MTMODE query with MT_MULTI_INSTANCE; every other hint returns 0.
	int __stdcall SetCacheHints(int cachehints, int frame_range) { return cachehints == CACHE_GET_MTMODE ? MT_MULTI_INSTANCE : 0; }
};

// Converts a blur sigma into the (fractional) box radius to use when the box
// blur is applied `iterations` times.
//
//   s          - requested sigma; only s*s is used, so the sign is irrelevant.
//   iterations - number of box passes that will be applied.
//
// Returns the integer radius l plus a fractional extension a. Returns 0 when
// iterations is not positive: the original expression divided by `iterations`,
// producing inf/NaN (or a negative value under the square root) that was then
// cast to int, which is undefined behaviour.
float FastBlur::SigmaToBoxRadius(double s, int iterations) {
	if (iterations <= 0) return 0.0f;

	// Variance each individual pass has to contribute.
	double q = s * s*1.0 / iterations;
	// Largest integer radius whose box variance l*(l+1)/3 does not exceed q.
	int l = (int)floor((sqrt(12 * q + 1) - 1) * 0.5);
	// Fractional part that accounts for the variance still missing at radius l.
	// NOTE(review): looks like the extended-box-filter parameterisation — confirm against its source.
	float a = (float)((2 * l + 1) * (l * (l + 1) - 3 * q) / (6 * (q - (l + 1) * (l + 1))));
	return l + a;
}

void FastBlur::SetupBlurs(double _blur_x, double _blur_y) {
	double _blur_x_c, _blur_y_c;
//	if (_blur_y == -1) _blur_y = _blur_x;

	if (vi.IsYUV()) { //!vi.IsRGB() && vi.NumComponents() > 1) {
		_blur_x_c = _blur_x / (1 << vi.GetPlaneWidthSubsampling(PLANAR_U));
		_blur_y_c = _blur_y / (1 << vi.GetPlaneHeightSubsampling(PLANAR_U));
	} else {
		_blur_x_c = _blur_x;
		_blur_y_c = _blur_y;
	}

	blur_x = SigmaToBoxRadius(_blur_x, iterations);
	blur_y = SigmaToBoxRadius(_blur_y, iterations);
	blur_x_c = SigmaToBoxRadius(_blur_x_c, iterations);
	blur_y_c = SigmaToBoxRadius(_blur_y_c, iterations);
}

std::pair<float, float> FastBlur::FuncBlur(int n, IScriptEnvironment* env) {
	auto result = env->Invoke(blur_func, AVSValue(n));

	if (!result.IsArray()) result = AVSValue(result, 1);

	for (int i = 0; i < result.ArraySize(); ++i) {
		if (!result[i].IsFloat()) throw("Blur function didn't return a valid result");
	}

	return std::pair(result[0].AsFloatf(), result[std::min(1, result.ArraySize() - 1)].AsFloatf());
}

// Builds the filter.
//
//   _child      - source clip.
//   _blur_x/_y  - fixed blur sigmas; must be non-negative. Ignored when _blur_func is set.
//   _blur_func  - name of a script function returning the blur per frame, or null.
//   _iterations - number of blur passes per plane.
//   _dither     - forwarded to the PyramidAVS buffers.
//   _gamma      - enables gamma handling for planes that support it.
//   _threads    - <0: automatic; 0: passed through as 0; (0,1): fraction of the
//                 hardware threads; >=1: absolute count capped at the hardware threads.
//
// Errors are reported through env->ThrowError with a "FastBlur: " prefix.
FastBlur::FastBlur(PClip _child, double _blur_x, double _blur_y, const char* _blur_func, int _iterations, bool _dither, bool _gamma, float _threads, IScriptEnvironment* env) : GenericVideoFilter(_child), blur_func(_blur_func), iterations(_iterations), dither(_dither), gamma(_gamma) {
	// Give every member a defined value: the radii are otherwise untouched on the
	// blur_func path until the first GetFrame, and the pointers must be null so
	// the cleanup below is safe.
	pyramid = nullptr;
	transpose = nullptr;
	blur_x = blur_y = blur_x_c = blur_y_c = 0;

	has_at_least_v8 = true;
	try { env->CheckVersion(8); } catch (const AvisynthError&) { has_at_least_v8 = false; }

	const int hw_threads = (int)std::thread::hardware_concurrency();
	int threads;

	if (_threads < 0) {
		// Automatic: half the hardware threads minus one, at least 2, never more than available.
		threads = (std::min)(hw_threads, (std::max)(2, (int)floor(hw_threads * 0.5 - 1)));
	} else if (_threads == 0) {
		threads = 0;
	} else if (_threads < 1) {
		// Fraction of the available hardware threads, at least 1.
		threads = (std::max)(1, (int)floor(hw_threads * _threads));
	} else {
		threads = (std::min)((int)floor(_threads), hw_threads);
	}

	if (blur_func) {
		// Call the function once for frame 0 so a missing or malformed function is
		// reported when the filter is created; the returned values are discarded.
		try {
			FastBlur::FuncBlur(0, env);
		} catch (const AvisynthError& e) {
			env->ThrowError("FastBlur: %s", e.msg);
		} catch (const char* e) {
			env->ThrowError("FastBlur: %s", e);
		} catch (IScriptEnvironment::NotFound) {
			env->ThrowError("FastBlur: Function '%s' not found", blur_func);
		}
	} else {
		if (_blur_x < 0 || _blur_y < 0) env->ThrowError("FastBlur: Blur radii must be non-negative");

		SetupBlurs(_blur_x, _blur_y);
	}

	// The destructor does not run when a constructor exits by exception, so a
	// failure while building `transpose` must release `pyramid` here.
	try {
		pyramid = new PyramidAVS(child, 1, dither, NULL, false, threads);
		transpose = new PyramidAVS(child, 1, dither, NULL, true, threads);
	} catch (const char *e) {
		delete pyramid;
		pyramid = nullptr;
		env->ThrowError("FastBlur: %s", e);
	} catch (...) {
		delete pyramid;
		pyramid = nullptr;
		throw;
	}
}

// Releases the two PyramidAVS buffers allocated by the constructor.
FastBlur::~FastBlur() {
	delete pyramid;
	delete transpose;
}

// Produces the blurred version of frame n.
//
// The source frame is reused as the output when it is writable; otherwise a
// new frame is allocated (copying frame properties when the v8 interface is
// available). When blur_func is set, the blur values are re-read for every
// frame; negative values are clamped to 0 and any failure of the function
// falls back to a blur of 0.
PVideoFrame __stdcall FastBlur::GetFrame(int n, IScriptEnvironment* env) {
	int i, p;
	float bx, by;

	PVideoFrame src = child->GetFrame(n, env);

	// Per the AviSynth SDK a frame is writable only while a single PVideoFrame
	// refers to it. Binding `dst` as a reference keeps `src` the sole handle in
	// the in-place case; copying it into a second PVideoFrame would take that
	// guarantee away before GetWritePtr is called.
	const bool in_place = src->IsWritable();
	PVideoFrame fresh;
	if (!in_place) fresh = has_at_least_v8 ? env->NewVideoFrameP(vi, &src) : env->NewVideoFrame(vi); // frame property support
	PVideoFrame& dst = in_place ? src : fresh;

	if (blur_func) {
		float _blur_x, _blur_y;
		try {
			std::tie(_blur_x, _blur_y) = FuncBlur(n, env);
			_blur_x = std::max(0.0f, _blur_x);
			_blur_y = std::max(0.0f, _blur_y);
		} catch (...) {
			// Deliberate best effort: a failing blur function yields an unblurred frame.
			_blur_x = _blur_y = 0;
		}

		SetupBlurs(_blur_x, _blur_y);
	}
	
	for (p = 0; p < (int)pyramid->planes.size(); p++) {
		uint8_t* dst_p = dst->GetWritePtr(pyramid->planes[p].plane_id);
		uint8_t* src_p = (uint8_t*)src->GetReadPtr(pyramid->planes[p].plane_id);
		int dst_pitch = dst->GetPitch(pyramid->planes[p].plane_id);
		int src_pitch = src->GetPitch(pyramid->planes[p].plane_id);

		// Gamma handling needs both the user option and the plane's own flag.
		bool _gamma = (gamma && pyramid->planes[p].gamma);

		// U and V use the chroma radii; every other plane uses the full-resolution radii.
		if (pyramid->planes[p].plane_id == PLANAR_U || pyramid->planes[p].plane_id == PLANAR_V) {
			bx = blur_x_c;
			by = blur_y_c;
		} else {
			bx = blur_x;
			by = blur_y;
		}

		Pyramid* a = pyramid->planes[p].pyramid;
		Pyramid* b = transpose->planes[p].pyramid;

		// Accumulated scale factor still to be divided out of `a`.
		double div = 1.0;
		// Factor added by one X+Y pass: the product of the two box widths.
		// NOTE(review): presumably BlurXTranspose leaves its sums unnormalised — confirm in Pyramid.
		double m = ((double)bx * 2 + 1) * ((double)by * 2 + 1);

		pyramid->Copy(p, src_p, src_pitch, _gamma);

		for (i = 0; i < iterations; i++) {
			// a -> b with radius bx, then b -> a with radius by.
			a->BlurXTranspose(bx, b);
			b->BlurXTranspose(by, a);

			div *= m;
			// Rescale whenever the pending factor exceeds 1e6 rather than letting it keep growing.
			if (div > 1000000) {
				a->Multiply(0, float(1.0 / div));
				div = 1.0;
			}
		}

		if (div > 1) a->Multiply(0, float(1.0 / div));

		// The final argument always evaluates to true because of the `| true`.
		pyramid->Out(p, dst_p, dst_pitch, _gamma, (vi.BitsPerComponent() == 32) | true); // clamping required for float output to avoid accumulation errors resulting in negative values
	}

	return dst;
}

// Script-facing factory shared by both FastBlur overloads. A non-null
// user_data selects the overload whose second argument is a function name;
// otherwise the second and third arguments are numeric blur values.
// YUY2 input is converted to YV16 for processing and converted back afterwards.
AVSValue __cdecl Create_FastBlur(AVSValue args, void* user_data, IScriptEnvironment* env) {
	PClip input = args[0].AsClip();
	const bool was_yuy2 = input->GetVideoInfo().IsYUY2();

	if (was_yuy2) {
		input = env->Invoke("ConverttoYV16", input).AsClip();
	}

	AVSValue filtered;

	if (!user_data) {
		// Numeric form: y_blur defaults to the x blur value.
		filtered = new FastBlur(
			input,
			args[1].AsFloat(),
			args[2].AsFloat(args[1].AsFloatf()),
			nullptr,
			args[3].AsInt(3),
			args[4].AsBool(false),
			args[5].AsBool(true),
			args[6].AsFloatf(-1),
			env
		);
	} else {
		// Function form: the fixed blur values are unused, so -1 is passed as a placeholder.
		// NOTE(review): threads defaults to 0 here but to -1 in the numeric form — confirm this is intended.
		filtered = new FastBlur(
			input,
			-1,
			-1,
			args[1].AsString(),
			args[2].AsInt(3),
			args[3].AsBool(false),
			args[4].AsBool(true),
			args[5].AsFloatf(0),
			env
		);
	}

	if (was_yuy2) {
		filtered = env->Invoke("ConverttoYUY2", filtered.AsClip()).AsClip();
	}

	return filtered;
}

// Linkage table pointer; null until AvisynthPluginInit3 stores the vectors passed in by the host.
const AVS_Linkage* AVS_linkage = 0;

// Plugin entry point: stores the linkage vectors and registers both FastBlur
// overloads. Returns the plugin's version string.
extern "C" __declspec(dllexport) const char* __stdcall AvisynthPluginInit3(IScriptEnvironment* env, const AVS_Linkage* const vectors) {
	AVS_linkage = vectors;

	// Numeric overload: null user_data.
	env->AddFunction("FastBlur", "cf[y_blur]f[iterations]i[dither]b[gamma]b[threads]f", Create_FastBlur, nullptr);

	// Function-name overload: a non-null user_data tells Create_FastBlur which form was called.
	env->AddFunction("FastBlur", "cs[iterations]i[dither]b[gamma]b[threads]f", Create_FastBlur, reinterpret_cast<void*>(1));

	return "FastBlur 0.4";
}
