//====== Copyright © 1996-2007, Valve Corporation, All rights reserved. =======
//
//=============================================================================

// STATIC: "CONVERT_TO_SRGB"				"0..1"	[ps20b][= g_pHardwareConfig->NeedsShaderSRGBConversion()] [PC]
// STATIC: "CONVERT_TO_SRGB"				"0..0"	[= 0] [XBOX]
// STATIC: "BLUR"							"0..1"
// STATIC: "FADEOUTONSILHOUETTE"			"0..1"
// STATIC: "CUBEMAP"						"0..1"
// STATIC: "REFRACTTINTTEXTURE"				"0..1"
// STATIC: "MASKED"							"0..1"
// STATIC: "COLORMODULATE"					"0..1"
// STATIC: "SECONDARY_NORMAL"				"0..1"
// STATIC: "NORMAL_DECODE_MODE"				"0..0"	[XBOX]
// STATIC: "NORMAL_DECODE_MODE"				"0..0"	[PC]
// STATIC: "SHADER_SRGB_READ"				"0..1"	[ps20b]

// DYNAMIC: "PIXELFOGTYPE"					"0..1"
// DYNAMIC: "WRITE_DEPTH_TO_DESTALPHA"		"0..1"	[ps20b] [PC]
// DYNAMIC: "WRITE_DEPTH_TO_DESTALPHA"		"0..0"	[ps20b] [XBOX]

// SKIP: $MASKED && $BLUR

// The WRITE_DEPTH_TO_DESTALPHA dynamic combo is only declared for [ps20b] in the
// combo list above, so pin it to 0 when SHADER_MODEL_PS_2_0 is defined. main()
// still references the symbol in its final FinalOutput() call.
#if defined( SHADER_MODEL_PS_2_0 )
#	define WRITE_DEPTH_TO_DESTALPHA 0
#endif

#include "common_ps_fxc.h"
#include "shader_constant_register_map.h"

// Texture samplers. Each one is pinned to an explicit sampler register (sN).

// Second normal map; sampled with vBumpTexCoord.wz when SECONDARY_NORMAL is on.
sampler NormalSampler2				: register( s1 );
// Texture that gets warped / blurred / tinted to produce the refracted color.
sampler RefractSampler				: register( s2 );
// Primary normal map; sampled with vBumpTexCoord.xy.
sampler NormalSampler				: register( s3 );
#if CUBEMAP
// Cube map used for the additive reflection term.
sampler EnvmapSampler				: register( s4 );
#endif
#if REFRACTTINTTEXTURE
// Per-texel tint, sampled with vBumpTexCoord.xy and multiplied by g_RefractTint.
sampler RefractTintSampler			: register( s5 );
#endif

// DecompressNormal() is always called with an extra sampler argument.
// In NORM_DECODE_ATI2N_ALPHA mode that argument gets its own sampler registers;
// in every other mode the names are aliased to the existing normal-map samplers
// so the calls in main() still compile without binding extra registers.
// NOTE(review): the aliases are crossed relative to the call sites (main() passes
// AlphaMapSampler together with NormalSampler, but the alias maps it to
// NormalSampler2, and vice versa). Presumably harmless because the argument is
// ignored outside ATI2N_ALPHA mode -- confirm against DecompressNormal().
#if NORMAL_DECODE_MODE == NORM_DECODE_ATI2N_ALPHA
sampler AlphaMapSampler				: register( s6 );	// alpha
sampler AlphaMapSampler2			: register( s7 );
#else
#define AlphaMapSampler2			NormalSampler
#define AlphaMapSampler				NormalSampler2
#endif

// Pixel shader constants. Each one is pinned to an explicit constant register.

// Multiplied into the cube map sample (CUBEMAP only).
const float3 g_EnvmapTint			: register( c0 );
// Tint applied to the warped refraction color.
const float3 g_RefractTint			: register( c1 );
// Lerp factor between the reflection color and its square (CUBEMAP only).
const float3 g_EnvmapContrast		: register( c2 );
// Lerp factor between the reflection's greyscale value and its color (CUBEMAP only).
const float3 g_EnvmapSaturation		: register( c3 );
// Packed scalars; the macros below name the components used.
const float4  g_c5					: register( c5 );
// Scales the normal-map offset added to the refraction texture coordinate.
#define g_RefractScale g_c5.x
// Not referenced anywhere in the visible part of this file.
#define g_flTime g_c5.w

// Both are forwarded to CalcPixelFogFactor(); only .z of g_EyePos_SpecExponent is read here.
const float4 g_FogParams				: register( PSREG_FOG_PARAMS );
const float4 g_EyePos_SpecExponent		: register( PSREG_EYEPOS_SPEC_EXPONENT );

// Half-width (in taps) of the generic blur loop in main(); taken from the BLUR combo.
static const int g_BlurCount = BLUR;
// Distance between neighbouring blur taps, in texture-coordinate units.
static const float g_BlurFraction = 1.0f / 512.0f;
// Half a tap; used to place a sample between two taps so bilinear filtering averages them.
static const float g_HalfBlurFraction = 0.5 * g_BlurFraction;
// Not referenced anywhere in the visible part of this file.
static const float4 g_BlurFractionVec = float4( g_BlurFraction, g_HalfBlurFraction, 
                                               -g_BlurFraction,-g_HalfBlurFraction );

// Interpolated inputs consumed by main(). The notes describe how main() uses
// each field; how the values are produced is up to the vertex shader, which is
// not part of this file.
struct PS_INPUT
{
	float4 vBumpTexCoord			: TEXCOORD0;			// NormalMap1 in xy, NormalMap2 in wz
	// Passed to CalcReflectionVectorUnnormalized() as the view vector (CUBEMAP only).
	float3 vTangentVertToEyeVector	: TEXCOORD1;
	// Basis handed to Vec3TangentToWorld(); vWorldNormal also drives the silhouette fade.
	float3 vWorldNormal				: TEXCOORD2; 
	float3 vWorldTangent			: TEXCOORD3; 
	float3 vWorldBinormal			: TEXCOORD4; 
	// xy is divided by z in main() to get the unwarped refraction texture coordinate.
	float3 vRefractXYW				: TEXCOORD5; 
	// Negated and dotted with vWorldNormal for the silhouette fade (FADEOUTONSILHOUETTE only).
	float3 vWorldViewVector			: TEXCOORD6;
#if COLORMODULATE
	// rgb scales the refraction tint; a scales the warp amount and the output alpha.
	float4 ColorModulate            : COLOR0;
#endif

	float4 worldPos_projPosZ		: TEXCOORD7;		// Necessary for pixel fog
	// Not read by main() in this file.
	float4 fogFactorW				: COLOR1;	
};

//-----------------------------------------------------------------------------
// Refraction pixel shader entry point.
//
// Samples RefractSampler at the projected coordinate (vRefractXYW.xy / .z),
// offset by the normal map, then depending on the static combos blurs, tints
// and fades the result, optionally adds a cube map reflection, and returns the
// color through FinalOutput() together with the pixel fog factor.
//
// The alpha channel of the (primary) normal map is used three ways: as the
// per-pixel warp strength, as the reflection strength, and as the output alpha.
//
// i: interpolated inputs, see PS_INPUT.
// Returns the value produced by FinalOutput().
//-----------------------------------------------------------------------------
float4 main( PS_INPUT i ) : COLOR
{
	float3 result;
	
	// Computed up front because both return paths below pass it to FinalOutput().
	float pixelFogFactor = CalcPixelFogFactor( PIXELFOGTYPE, g_FogParams, g_EyePos_SpecExponent.z, i.worldPos_projPosZ.z, i.worldPos_projPosZ.w );

#if FADEOUTONSILHOUETTE
	// blend = saturate( -view . normal )^3. It drops toward 0 as the two vectors
	// become perpendicular, and the lerps below then select the unwarped color.
	float blend = saturate( dot( -i.vWorldViewVector.xyz, i.vWorldNormal.xyz ) );
	blend = blend * blend * blend;
#else
	float blend = 1.0f;
#endif

	// Decompress normal
	float4 vNormal = DecompressNormal( NormalSampler, i.vBumpTexCoord.xy, NORMAL_DECODE_MODE, AlphaMapSampler );

#if SECONDARY_NORMAL
	// Sum the two normals and renormalize. vNormal.a is left as read from the primary map.
	float3 vNormal2 = DecompressNormal( NormalSampler2, i.vBumpTexCoord.wz, NORMAL_DECODE_MODE, AlphaMapSampler2 );
	vNormal.xyz = normalize( vNormal.xyz + vNormal2.xyz );
#endif

#if REFRACTTINTTEXTURE
	// The 2.0 factor means a texel value of 0.5 leaves g_RefractTint unchanged.
	float3 refractTintColor = 2.0 * g_RefractTint * tex2D( RefractTintSampler, i.vBumpTexCoord.xy );
#else
	float3 refractTintColor = g_RefractTint;
#endif

#if COLORMODULATE
	refractTintColor *= i.ColorModulate.rgb;
#endif

	// Perform division by W only once
	float ooW = 1.0f / i.vRefractXYW.z;
	
	// Compute coordinates for sampling refraction
	// NoWarp is the plain projected coordinate; the warped one adds normal.xy
	// scaled by normal alpha, g_RefractScale and (optionally) vertex alpha.
	float2 vRefractTexCoordNoWarp = i.vRefractXYW.xy * ooW;
	float2 vRefractTexCoord = vNormal.xy;
	float scale = vNormal.a * g_RefractScale;
#if COLORMODULATE
	scale *= i.ColorModulate.a;
#endif
	vRefractTexCoord *= scale;
	vRefractTexCoord += vRefractTexCoordNoWarp;

#if (BLUR==1)  // use polyphase magic to convert 9 lookups into 4

	//  basic principle behind this transformation:
	//  [ A  B  C ]
	//  [ D  E  F ]
	//  [ G  H  I ]
	//  use bilinear filtering hardware to weight upper 2x2 samples evenly (0.25* [A + B + D + E]).
	//  scale the upper 2x2 by 4/9 (total area of kernel occupied)
	//  use bilinear filtering hardware to weight right 1x2 samples evenly (0.5*[C + F])
	//  scale right 1x2 by 2/9
	//  use bilinear filtering hardware to weight lower 2x1 samples evenly (0.5*[G + H])
	//  scale bottom 2x1 by 2/9
	//  fetch last sample (I) and scale by 1/9.
	
	// E is the warped coordinate itself; each location below is the midpoint of
	// the taps it stands in for, with taps spaced g_BlurFraction apart.
	float2 upper_2x2_loc = vRefractTexCoord.xy - float2(g_HalfBlurFraction, g_HalfBlurFraction);
	float2 right_1x2_loc = vRefractTexCoord.xy + float2(g_BlurFraction, -g_HalfBlurFraction);
	float2 lower_2x1_loc = vRefractTexCoord.xy + float2(-g_HalfBlurFraction, g_BlurFraction);
	float2 singleton_loc = vRefractTexCoord.xy + float2(g_BlurFraction, g_BlurFraction);
	// Weights are 4/9, 2/9, 2/9 and 1/9 and sum to 1.
	result  = tex2D(RefractSampler, upper_2x2_loc) * 0.4444444;
	result += tex2D(RefractSampler, right_1x2_loc) * 0.2222222;
	result += tex2D(RefractSampler, lower_2x1_loc) * 0.2222222;
	result += tex2D(RefractSampler, singleton_loc) * 0.1111111;

	#if ( SHADER_SRGB_READ == 1 )
	{
		// Just do this once rather than after every blur step, which is wrong, but much more efficient
		result = GammaToLinear( result );
	}
	#endif

	// Unwarped, unblurred, untinted sample: what the pixel fades to as blend goes to 0.
	float3 unblurredColor = tex2D(RefractSampler, vRefractTexCoordNoWarp.xy);
	#if ( SHADER_SRGB_READ == 1 )
	{
		unblurredColor = GammaToLinear( unblurredColor );
	}
	#endif

	result = lerp(unblurredColor, result * refractTintColor, blend);

	// NOTE(review): the BLUR combo is declared as "0..1" at the top of this file,
	// so with that range the generic loop below is never selected (BLUR==1 takes
	// the branch above, BLUR==0 takes the #else). Confirm before relying on it.
#elif (BLUR>0)  // iteratively step through render target
	int x, y;

	// Unweighted box filter over a (2*g_BlurCount+1)^2 grid of taps.
	result = float3( 0.0f, 0.0f, 0.0f );
	for( x = -g_BlurCount; x <= g_BlurCount; x++ )
	{
		for( y = -g_BlurCount; y <= g_BlurCount; y++ )
		{
			result += tex2D( RefractSampler, vRefractTexCoord.xy + float2( g_BlurFraction * x, g_BlurFraction * y ) );
		}
	}

	// Divide the sum by the number of taps.
	int width = g_BlurCount * 2 + 1;
	result *= 1.0f / ( width * width );

	#if ( SHADER_SRGB_READ == 1 )
	{
		// Just do this once rather than after every blur step, which is wrong, but much more efficient
		result = GammaToLinear( result );
	}
	#endif

	// result is the blurred one now. . .now lerp.
	float3 unblurredColor = tex2D( RefractSampler, vRefractTexCoordNoWarp.xy );
	#if ( SHADER_SRGB_READ == 1 )
	{
		unblurredColor = GammaToLinear( unblurredColor );
	}
	#endif

	result = lerp( unblurredColor, result * refractTintColor, blend );
#else
#	if MASKED
	// Masked path: return the warped sample, including its own alpha channel,
	// right away. refractTintColor, blend, the cube map term and the
	// depth-to-dest-alpha arguments used at the end of main() are not applied.
	float4 fMaskedResult = tex2D( RefractSampler, vRefractTexCoord.xy );
	#if ( SHADER_SRGB_READ == 1 )
	{
		// NOTE(review): every other GammaToLinear() call in this file passes a
		// float3; this one passes and assigns a float4. Confirm the overload that
		// gets picked compiles and treats alpha as intended.
		fMaskedResult = GammaToLinear( fMaskedResult );
	}
	#endif

	return FinalOutput( fMaskedResult, pixelFogFactor, PIXELFOGTYPE, TONEMAP_SCALE_NONE );
#	else
	// No blur: one warped sample and one unwarped sample, blended by the silhouette fade.
	float3 colorWarp = tex2D( RefractSampler, vRefractTexCoord.xy );
	#if ( SHADER_SRGB_READ == 1 )
	{
		colorWarp = GammaToLinear( colorWarp );
	}
	#endif
	float3 colorNoWarp = tex2D( RefractSampler, vRefractTexCoordNoWarp.xy );
	#if ( SHADER_SRGB_READ == 1 )
	{
		colorNoWarp = GammaToLinear( colorNoWarp );
	}
	#endif

	// Only the warped sample is tinted.
	colorWarp *= refractTintColor;
	result = lerp( colorNoWarp, colorWarp, blend );
#	endif
#endif

#if CUBEMAP
	// Additive reflection term, scaled by the normal map's alpha.
	float specularFactor = vNormal.a;

	float3 worldSpaceNormal = Vec3TangentToWorld( vNormal.xyz, i.vWorldNormal, i.vWorldTangent, i.vWorldBinormal );

	// NOTE(review): the normal was just moved to world space, but the view vector
	// passed here is named vTangentVertToEyeVector. Check in the vertex shader which
	// space it is actually in; PS_INPUT also carries vWorldViewVector.
	float3 reflectVect = CalcReflectionVectorUnnormalized( worldSpaceNormal, i.vTangentVertToEyeVector );
	float3 specularLighting = texCUBE( EnvmapSampler, reflectVect );
	specularLighting *= specularFactor;
	specularLighting *= g_EnvmapTint;
	// Contrast: lerp from the color toward its square.
	float3 specularLightingSquared = specularLighting * specularLighting;
	specularLighting = lerp( specularLighting, specularLightingSquared, g_EnvmapContrast );
	// Saturation: lerp from the weighted greyscale value toward the full color.
	float3 greyScale = dot( specularLighting, float3( 0.299f, 0.587f, 0.114f ) );
	specularLighting = lerp( greyScale, specularLighting, g_EnvmapSaturation );
	result += specularLighting;
#endif

#if COLORMODULATE
	float resultAlpha = i.ColorModulate.a * vNormal.a;
#else
	float resultAlpha = vNormal.a;
#endif

	// The last two arguments are the depth-to-dest-alpha switch and
	// worldPos_projPosZ.w (the same value given to CalcPixelFogFactor above).
	return FinalOutput( float4( result, resultAlpha ), pixelFogFactor, PIXELFOGTYPE, TONEMAP_SCALE_NONE, (WRITE_DEPTH_TO_DESTALPHA != 0), i.worldPos_projPosZ.w );
}