/*
 * Copyright (c) 2021, Jeffrey Lee
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met: 
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 * 
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */
#ifndef SOA2_H
#define SOA2_H

#include <arm_neon.h>

namespace soa2 {

class vec1f;

/* Forward declaration so that vec1f::operator/ can call recp() before its
 * definition further down this header. The default accuracy is given here
 * only; the later definition must not repeat it. */
inline vec1f recp(vec1f f,int accuracy = 0);

/* Two float lanes held in a NEON float32x2_t. All arithmetic operators
 * work lane-wise. The default constructor deliberately leaves the lanes
 * uninitialised. */
class vec1f
{
private:
	float32x2_t v;
public:
	inline vec1f() {}

	/* Broadcast f to both lanes. */
	inline vec1f(float f)
	{
		v = vdup_n_f32(f);
	}

	inline vec1f(float32x2_t f)
	{
		v = f;
	}

	/* Lane-wise int -> float conversion. */
	inline vec1f(int32x2_t i)
	{
		v = vcvt_f32_s32(i);
	}

	inline operator float32x2_t() const
	{
		return v;
	}

	/* Load a single float from *f and broadcast it to both lanes. */
	static inline vec1f vld1(const float *f)
	{
		return vld1_dup_f32(f);
	}

	/* Lane access. The lane argument of vget_lane_f32/vset_lane_f32 must be
	 * a constant expression; passing the runtime index straight through can
	 * fail to compile (clang always, GCC when the call is not inlined), so
	 * dispatch to a constant lane here. Index 0 selects lane 0, any other
	 * index selects lane 1. */
	inline float operator[](const int i) const
	{
		return i ? vget_lane_f32(v,1) : vget_lane_f32(v,0);
	}

	inline void set_elem(const int i,float f)
	{
		v = i ? vset_lane_f32(f,v,1) : vset_lane_f32(f,v,0);
	}

	/* Vector math */

	inline vec1f operator +(const vec1f &b) const
	{
		return vadd_f32(*this,b);
	}

	inline vec1f operator -(const vec1f &b) const
	{
		return vsub_f32(*this,b);
	}

	inline vec1f operator *(const vec1f &b) const
	{
		return vmul_f32(*this,b);
	}

	/* Approximate division: multiplies by recp(b) at recp's default
	 * accuracy (the raw reciprocal estimate, no refinement steps). */
	inline vec1f operator /(const vec1f &b) const
	{
		return vmul_f32(*this,recp(b));
	}

	inline vec1f& operator +=(const vec1f &b)
	{
		*this = vadd_f32(*this,b);
		return *this;
	}

	inline vec1f& operator -=(const vec1f &b)
	{
		*this = vsub_f32(*this,b);
		return *this;
	}

	inline vec1f& operator *=(const vec1f &b)
	{
		*this = vmul_f32(*this,b);
		return *this;
	}

	/* Same approximation as operator/. */
	inline vec1f& operator /=(const vec1f &b)
	{
		*this = (*this)/b;
		return *this;
	}

	/* Scalar math (b is broadcast to both lanes) */

	inline vec1f operator *(const float b) const
	{
		return vmul_n_f32(*this,b);
	}

	inline vec1f& operator *=(const float b)
	{
		*this = vmul_n_f32(*this,b);
		return *this;
	}

	/* Other operators */

	inline vec1f operator-() const
	{
		return vneg_f32(*this);
	}

	/* Comparisons return per-lane masks (all bits set where the comparison
	 * holds, clear otherwise), e.g. for use with vbsl-based selection. */

	inline uint32x2_t operator<(const vec1f &b) const
	{
		return vclt_f32(*this,b);
	}

	inline uint32x2_t operator>(const vec1f &b) const
	{
		return vcgt_f32(*this,b);
	}

	inline uint32x2_t operator<=(const vec1f &b) const
	{
		return vcle_f32(*this,b);
	}

	inline uint32x2_t operator>=(const vec1f &b) const
	{
		return vcge_f32(*this,b);
	}
};

/* Two int32 lanes held in a NEON int32x2_t; the integer counterpart of
 * vec1f. Arithmetic operators work lane-wise. The default constructor
 * deliberately leaves the lanes uninitialised. */
class vec1i
{
private:
	int32x2_t v;
public:
	inline vec1i() {}

	/* Broadcast i to both lanes. */
	inline vec1i(int i)
	{
		v = vdup_n_s32(i);
	}

	/* Lane-wise float -> int conversion (VCVT, rounds toward zero). */
	inline vec1i(float32x2_t f)
	{
		v = vcvt_s32_f32(f);
	}

	inline vec1i(int32x2_t i)
	{
		v = i;
	}

	inline operator int32x2_t() const
	{
		return v;
	}

	/* Lane access. vget_lane_s32 needs a constant lane number, so a runtime
	 * index is dispatched to a constant lane here (a runtime index passed
	 * straight through can fail to compile). Index 0 selects lane 0, any
	 * other index selects lane 1. */
	inline int operator[](const int i) const
	{
		return i ? vget_lane_s32(v,1) : vget_lane_s32(v,0);
	}

	/* Vector math */

	inline vec1i operator +(const vec1i &b) const
	{
		return vadd_s32(*this,b);
	}

	inline vec1i operator -(const vec1i &b) const
	{
		return vsub_s32(*this,b);
	}

	inline vec1i operator *(const vec1i &b) const
	{
		return vmul_s32(*this,b);
	}

	inline vec1i& operator +=(const vec1i &b)
	{
		*this = vadd_s32(*this,b);
		return *this;
	}

	inline vec1i& operator -=(const vec1i &b)
	{
		*this = vsub_s32(*this,b);
		return *this;
	}

	inline vec1i& operator *=(const vec1i &b)
	{
		*this = vmul_s32(*this,b);
		return *this;
	}

	/* Scalar math (b is broadcast to both lanes) */

	inline vec1i operator *(const int b) const
	{
		return vmul_n_s32(*this,b);
	}

	inline vec1i& operator *=(const int b)
	{
		*this = vmul_n_s32(*this,b);
		return *this;
	}

	/* Other operators */

	inline vec1i operator-() const
	{
		return vneg_s32(*this);
	}
};

/* Multi-component types built from the two-lane scalars.
 * NOTE(review): the vector<> template is not declared in this header; it is
 * presumably provided by a header included before this one - confirm. */
typedef vector<vec1f,2> vec2f;
typedef vector<vec1f,3> vec3f;
typedef vector<vec1f,4> vec4f;
typedef vector<vec1i,2> vec2i;
typedef vector<vec1i,3> vec3i;
typedef vector<vec1i,4> vec4i;

/* Number of lanes in one vec1f / vec1i (float32x2_t / int32x2_t). */
const int scalar_elem = 2;

/* Lane-wise absolute value. */
inline vec1f abs(vec1f f)
{
	const float32x2_t lanes = f;
	return vabs_f32(lanes);
}

/* Returns a vec1f whose lanes hold f, f+step, f+2*step, ... (one value per
 * lane, scalar_elem lanes in total).
 *
 * The lanes are assembled in memory and loaded with a single vld1_f32.
 * This avoids two problems: vset_lane_f32 needs a constant lane number
 * rather than a loop counter, and starting from an uninitialised register
 * reads an indeterminate value. */
inline vec1f vec1f_spread(float f,float step)
{
	float lanes[scalar_elem];
	for(int i=0;i<scalar_elem;i++)
	{
		lanes[i] = f;
		f += step;
	}
	return vld1_f32(lanes);
}

/* Approximate 1/sqrt(f) per lane.
 * Starts from the VRSQRTE estimate and applies `accuracy` Newton-Raphson
 * refinement steps (none by default). Each step multiplies the current
 * estimate r by vrsqrts_f32(f*r, r), i.e. by (3 - f*r*r)/2. */
inline vec1f inversesqrt(vec1f f,int accuracy = 0)
{
	float32x2_t r = vrsqrte_f32(f);
	for(int remaining=accuracy;remaining>0;remaining--)
	{
		const float32x2_t fr = vmul_f32(f,r);
		r = vmul_f32(r,vrsqrts_f32(fr,r));
	}
	return r;
}

/* Approximate square root per lane, computed as f * inversesqrt(f).
 * accuracy is forwarded to inversesqrt() as its number of Newton-Raphson
 * steps. The default of 0 gives the raw estimate, as the one-argument form
 * always did.
 * Zero lanes need special handling: the reciprocal-sqrt estimate of +/-0 is
 * infinite and 0 * inf is NaN. Those lanes are therefore passed through
 * unchanged (sqrt(+/-0) == +/-0). */
inline vec1f SQRT(vec1f f,int accuracy = 0)
{
	float32x2_t root = f*inversesqrt(f,accuracy);
	uint32x2_t is_zero = vceq_f32(f,vdup_n_f32(0.0f));
	return vbsl_f32(is_zero,f,root);
}

/* Approximate 1/f per lane.
 * Starts from the VRECPE estimate and applies `accuracy` Newton-Raphson
 * steps, each multiplying the estimate by vrecps_f32(est, f) = 2 - est*f.
 * The default accuracy is supplied by the forward declaration near the top
 * of this header, so it must not be repeated here. */
inline vec1f recp(vec1f f,int accuracy)
{
	float32x2_t est = vrecpe_f32(f);
	int remaining = accuracy;
	while(remaining > 0)
	{
		est = vmul_f32(est,vrecps_f32(est,f));
		--remaining;
	}
	return est;
}

/* Round each lane toward -infinity.
 * VCVT truncates toward zero. Truncation only rounds the wrong way for
 * negative non-integers, and those are exactly the lanes where the
 * truncated value ends up greater than the input. One is subtracted from
 * those lanes only. The old "f < 0" test also decremented negative whole
 * numbers, so floor(-2.0) gave -3.0.
 * Inputs outside the int32 range are not supported (the conversion
 * saturates). */
inline vec1f floor(vec1f f)
{
	int32x2_t i = vcvt_s32_f32(f);
	uint32x2_t rounded_up = vcgt_f32(vcvt_f32_s32(i),f);
	i = vbsl_s32(rounded_up,vsub_s32(i,vdup_n_s32(1)),i);
	return vcvt_f32_s32(i);
}

/* Round each lane toward zero (truncate) via an int32 round trip. */
inline vec1f floor_tozero(vec1f f)
{
	return vcvt_f32_s32(vcvt_s32_f32(f));
}

/* Fractional part of each lane: f - floor(f). */
inline vec1f fract(vec1f f)
{
	const vec1f whole = floor(f);
	return vsub_f32(f,whole);
}

/* Length of a one-component value, i.e. its magnitude |f| per lane.
 * Returning f itself gave a negative "length" for negative inputs. */
inline vec1f length(vec1f f)
{
	return vabs_f32(f);
}

/* Dot product of one-component values: the lane-wise product a*b. */
inline vec1f dot(vec1f a,vec1f b)
{
	return vmul_f32(a,b);
}

/* Lane-wise minimum / maximum, for both the float and int scalar types. */

inline vec1f min(vec1f a,vec1f b)
{
	const float32x2_t x = a, y = b;
	return vmin_f32(x,y);
}

inline vec1f max(vec1f a,vec1f b)
{
	const float32x2_t x = a, y = b;
	return vmax_f32(x,y);
}

inline vec1i min(vec1i a,vec1i b)
{
	const int32x2_t x = a, y = b;
	return vmin_s32(x,y);
}

inline vec1i max(vec1i a,vec1i b)
{
	const int32x2_t x = a, y = b;
	return vmax_s32(x,y);
}

/* Bitwise select (VBSL): take bits from a where mask bits are set and from
 * b where they are clear. With the all-ones / all-zeros lane masks produced
 * by vec1f's comparison operators, this picks a where the comparison held
 * and b elsewhere. */
inline vec1f select(uint32x2_t mask,vec1f a,vec1f b)
{
	const float32x2_t when_set = a;
	const float32x2_t when_clear = b;
	return vbsl_f32(mask,when_set,when_clear);
}

/* Swap the two lanes of i. For a two-lane int32x2_t, reversing the 32-bit
 * elements within the 64-bit register (VREV64) is the same permutation as
 * rotating by one element (VEXT #1). */
inline vec1i revelem(vec1i i)
{
	return vrev64_s32(i);
}

/* Pack a colour into one int per lane.
 * Each channel is converted to fixed point with 8 fractional bits (i.e.
 * scaled by 256, rounding toward zero) and clamped to at most 255. Only the
 * upper end is clamped; negative channels are not. The channels are placed
 * as col[0] -> bits 16-23, col[1] -> bits 8-15, col[2] -> bits 0-7.
 *
 * vsli_n_s32 needs a constant shift immediate, so the packing is unrolled
 * instead of computing the shift in a loop. Starting from the blue value
 * matches the old loop's first iteration: a shift-0 insert replaces every
 * bit, so it did not need the uninitialised accumulator it read. */
inline vec1i rgb(vec3f col)
{
	vec1i r = min(vec1i(vcvt_n_s32_f32(col[0],8)),255);
	vec1i g = min(vec1i(vcvt_n_s32_f32(col[1],8)),255);
	vec1i b = min(vec1i(vcvt_n_s32_f32(col[2],8)),255);
	int32x2_t packed = b;
	packed = vsli_n_s32(packed,g,8);
	packed = vsli_n_s32(packed,r,16);
	return packed;
}

/* Like rgb() but without clamping. Each channel is converted to fixed point
 * with 8 fractional bits and placed as col[0] -> bits 16-23,
 * col[1] -> bits 8-15, col[2] -> bits 0-7. Inputs >= 1.0 or negative
 * produce corrupted packed values, so callers must keep channels in range.
 *
 * Unrolled so that vsli_n_s32 receives constant shift immediates. The first
 * value is taken directly, which matches the old shift-0 insert without
 * reading an uninitialised accumulator. */
inline vec1i rgb_fast(vec3f col)
{
	int32x2_t packed = vcvt_n_s32_f32(col[2],8);
	packed = vsli_n_s32(packed,vcvt_n_s32_f32(col[1],8),8);
	packed = vsli_n_s32(packed,vcvt_n_s32_f32(col[0],8),16);
	return packed;
}

/* Multiply-accumulate: a + b*c per lane. The overloads taking a plain
 * float/int broadcast c to both lanes (vmla_n_*). */

inline vec1f mla(vec1f a,vec1f b,vec1f c)
{
	const float32x2_t acc = a;
	return vmla_f32(acc,b,c);
}

inline vec1f mla(vec1f a,vec1f b,float c)
{
	const float32x2_t acc = a;
	return vmla_n_f32(acc,b,c);
}

inline vec1i mla(vec1i a,vec1i b,vec1i c)
{
	const int32x2_t acc = a;
	return vmla_s32(acc,b,c);
}

inline vec1i mla(vec1i a,vec1i b,int c)
{
	const int32x2_t acc = a;
	return vmla_n_s32(acc,b,c);
}

/* Multiply-subtract: a - b*c per lane. The overloads taking a plain
 * float/int broadcast c to both lanes (vmls_n_*). */

inline vec1f mls(vec1f a,vec1f b,vec1f c)
{
	const float32x2_t acc = a;
	return vmls_f32(acc,b,c);
}

inline vec1f mls(vec1f a,vec1f b,float c)
{
	const float32x2_t acc = a;
	return vmls_n_f32(acc,b,c);
}

inline vec1i mls(vec1i a,vec1i b,vec1i c)
{
	const int32x2_t acc = a;
	return vmls_s32(acc,b,c);
}

inline vec1i mls(vec1i a,vec1i b,int c)
{
	const int32x2_t acc = a;
	return vmls_n_s32(acc,b,c);
}

} /* namespace soa2 */

#endif
