#include "common.h"
#include "night_vision.h"

// Interpolated inputs consumed by the sky pixel shader below.
struct         v2p
{
	// factor.rgb scales the blended sky colour; factor.w is the s_sky0 -> s_sky1 blend weight.
	float4	factor	: COLOR0;        // for SM3 - factor.rgb - tonemap-prescaled
	float3	tc0		: TEXCOORD0;     // lookup direction used to sample s_sky0
	float3	tc1		: TEXCOORD1;     // lookup direction used to sample s_sky1
	float4  sv    : SV_Position;     // sv.xy is divided by screen_res.xy to get screen-space coordinates
};
// Pixel shader outputs: the two render targets filled by tonemap() in main.
struct        _out
{
	float4	low		: SV_Target0;    // first tonemap() output
	float4	high	: SV_Target1;    // second tonemap() output
};


// The two sky cube maps that main blends together using v2p.factor.w.
TextureCube	s_sky0	:register(t0);
TextureCube	s_sky1	:register(t1);

//////////////////////////////////////////////////////////////////////////////////////////
// Pixel
// Sky pixel shader.
//   1. Samples both sky cube maps and blends them by I.factor.w.
//   2. Scales the result by I.factor.rgb and tone-maps it into the low/high targets.
//   3. When night vision is enabled (generation > 0) and the pixel lies inside the
//      lens mask, replaces the colour with a red-channel luminance value, applies
//      gain and sky brightness, clamps to [0,1] and darkens it with the vignette.
// I       - interpolated inputs (see v2p).
// returns - the low/high render-target values.
_out main( v2p I )
{
	// Sample() yields a float4; take .rgb explicitly instead of relying on implicit truncation.
	float3	s0		= s_sky0.Sample( smp_rtlinear, I.tc0 ).rgb;
	float3	s1		= s_sky1.Sample( smp_rtlinear, I.tc1 ).rgb;

	// Blend the two skies, then apply the per-vertex colour factor.
	float3	sky		= lerp( s0, s1, I.factor.w );
	sky				*= I.factor.rgb;

	// final tone-mapping
	_out	o;
	tonemap( o.low, o.high, sky, 1 );	// factor contains tonemapping

	// shader_param_8.x packs two values: integer part = generation, fraction * 10 = tube count.
	float	nvg_generation	= floor( shader_param_8.x );				// 1, 2, or 3

	// Nothing below is needed unless night vision is active.
	if ( nvg_generation > 0 )
	{
		float2	screen_uv		= I.sv.xy / screen_res.xy;
		float	nvg_num_tubes	= frac( shader_param_8.x ) * 10.0f;		// 1, 2, 4, 1.1, or 1.2

		// The original tested the same coordinate twice (A || A); one test is equivalent.
		// The exact == 1.0f comparison is kept as-is to preserve behaviour.
		if ( compute_lens_mask( aspect_ratio_correction( screen_uv ), nvg_num_tubes ) == 1.0f )
		{
			float	nvg_gain		= floor( shader_param_8.y ) / 10.0f;
			float	vignette_amount	= floor( shader_param_8.z ) / 100.0f;

			// Collapse colour to a boosted luminance stored in the red channel only.
			o.low.r		= dot( o.low.rgb * 5.0f, luma_conversion_coeff );
			o.high.r	= dot( o.high.rgb * 5.0f, luma_conversion_coeff );
			o.low.gb	= 0.0f;
			o.high.gb	= 0.0f;

			// Apply tube gain and sky brightness (scales every channel, alpha included).
			o.low		*= nvg_gain * sky_brightness_factor;
			o.high		*= nvg_gain * sky_brightness_factor;

			o.high		= clamp( o.high, 0.0, 1.0 );
			o.low		= clamp( o.low, 0.0, 1.0 );

			// The original multiplied two identical vignette evaluations together;
			// compute it once and square it so the output is unchanged.
			float	vignette	= calc_vignette( nvg_num_tubes, screen_uv, vignette_amount );
			float	vignette_sq	= vignette * vignette;
			o.low		*= vignette_sq;
			o.high		*= vignette_sq;
		}
	}

	return	o;
}