/*
 * Copyright (c) 2021, Jeffrey Lee
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met: 
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 * 
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */
#ifndef SOA4_H
#define SOA4_H

#include <arm_neon.h>

namespace soa4 {

class vec1f;

/* Lane-wise reciprocal; defined later in this header. Declared here because
   vec1f::operator/ needs it. accuracy is the number of Newton-Raphson
   refinement steps applied to the hardware estimate (0 = estimate only). */
inline vec1f recp(vec1f f,int accuracy = 0);

/* Four independent floats held in one NEON float32x4_t (structure-of-arrays
   layout). All arithmetic and comparisons operate lane-wise. */
class vec1f
{
private:
	float32x4_t v;
public:
	/* Lanes are left uninitialised, exactly like a plain float32x4_t. */
	inline vec1f() {}

	/* Broadcast f to all four lanes. */
	inline vec1f(float f)
	{
		v = vdupq_n_f32(f);
	}

	inline vec1f(float32x4_t f)
	{
		v = f;
	}

	/* Lane-wise int32 -> float conversion. */
	inline vec1f(int32x4_t i)
	{
		v = vcvtq_f32_s32(i);
	}

	inline operator float32x4_t() const
	{
		return v;
	}

	/* Load the single float at *f and broadcast it to all four lanes. */
	static inline vec1f vld1(const float *f)
	{
		return vld1q_dup_f32(f);
	}

	/* Read one lane; i must be 0..3 (not checked). The lanes are spilled to
	   a small array instead of using vgetq_lane_f32, whose lane argument must
	   be a compile-time constant - so i may be a runtime value here. */
	inline float operator[](const int i) const
	{
		float lanes[4];
		vst1q_f32(lanes,v);
		return lanes[i];
	}

	/* Overwrite one lane with f; i must be 0..3 (not checked). Goes through
	   memory for the same reason as operator[] (vsetq_lane_f32 needs a
	   constant lane index). */
	inline void set_elem(const int i,float f)
	{
		float lanes[4];
		vst1q_f32(lanes,v);
		lanes[i] = f;
		v = vld1q_f32(lanes);
	}

	/* Vector math (lane-wise) */

	inline vec1f operator +(const vec1f &b) const
	{
		return vaddq_f32(*this,b);
	}

	inline vec1f operator -(const vec1f &b) const
	{
		return vsubq_f32(*this,b);
	}

	inline vec1f operator *(const vec1f &b) const
	{
		return vmulq_f32(*this,b);
	}

	/* Computed as *this * recp(b) with recp's default accuracy, i.e. from the
	   unrefined reciprocal estimate: this is an approximation, not IEEE
	   division. Use (*this) * recp(b,n) when more precision is needed. */
	inline vec1f operator /(const vec1f &b) const
	{
		return vmulq_f32(*this,recp(b));
	}

	inline vec1f& operator +=(const vec1f &b)
	{
		*this = vaddq_f32(*this,b);
		return *this;
	}

	inline vec1f& operator -=(const vec1f &b)
	{
		*this = vsubq_f32(*this,b);
		return *this;
	}

	inline vec1f& operator *=(const vec1f &b)
	{
		*this = vmulq_f32(*this,b);
		return *this;
	}

	/* Same approximation as operator/. */
	inline vec1f& operator /=(const vec1f &b)
	{
		*this = (*this)/b;
		return *this;
	}

	/* Scalar math: multiply every lane by the same float. */

	inline vec1f operator *(const float b) const
	{
		return vmulq_n_f32(*this,b);
	}

	inline vec1f& operator *=(const float b)
	{
		*this = vmulq_n_f32(*this,b);
		return *this;
	}

	/* Other operators */

	inline vec1f operator-() const
	{
		return vnegq_f32(*this);
	}

	/* Comparisons return a per-lane mask (all bits set where the comparison
	   holds, clear otherwise), suitable for select(). */

	inline uint32x4_t operator<(const vec1f &b) const
	{
		return vcltq_f32(*this,b);
	}

	inline uint32x4_t operator>(const vec1f &b) const
	{
		return vcgtq_f32(*this,b);
	}

	inline uint32x4_t operator<=(const vec1f &b) const
	{
		return vcleq_f32(*this,b);
	}

	inline uint32x4_t operator>=(const vec1f &b) const
	{
		return vcgeq_f32(*this,b);
	}
};

/* Four independent int32 values held in one NEON int32x4_t; the integer
   counterpart of vec1f. All arithmetic operates lane-wise. */
class vec1i
{
private:
	int32x4_t v;
public:
	/* Lanes are left uninitialised, exactly like a plain int32x4_t. */
	inline vec1i() {}

	/* Broadcast i to all four lanes. */
	inline vec1i(int i)
	{
		v = vdupq_n_s32(i);
	}

	/* Lane-wise float -> int32 conversion (VCVT, rounding toward zero). */
	inline vec1i(float32x4_t f)
	{
		v = vcvtq_s32_f32(f);
	}

	inline vec1i(int32x4_t i)
	{
		v = i;
	}

	inline operator int32x4_t() const
	{
		return v;
	}

	/* Read one lane; i must be 0..3 (not checked). Spills to a small array
	   instead of using vgetq_lane_s32, whose lane argument must be a
	   compile-time constant - so i may be a runtime value here. */
	inline int operator[](const int i) const
	{
		int32_t lanes[4];
		vst1q_s32(lanes,v);
		return lanes[i];
	}

	/* Overwrite one lane with n; i must be 0..3 (not checked). Provided for
	   parity with vec1f::set_elem. */
	inline void set_elem(const int i,int n)
	{
		int32_t lanes[4];
		vst1q_s32(lanes,v);
		lanes[i] = n;
		v = vld1q_s32(lanes);
	}

	/* Vector math (lane-wise) */

	inline vec1i operator +(const vec1i &b) const
	{
		return vaddq_s32(*this,b);
	}

	inline vec1i operator -(const vec1i &b) const
	{
		return vsubq_s32(*this,b);
	}

	inline vec1i operator *(const vec1i &b) const
	{
		return vmulq_s32(*this,b);
	}

	inline vec1i& operator +=(const vec1i &b)
	{
		*this = vaddq_s32(*this,b);
		return *this;
	}

	inline vec1i& operator -=(const vec1i &b)
	{
		*this = vsubq_s32(*this,b);
		return *this;
	}

	inline vec1i& operator *=(const vec1i &b)
	{
		*this = vmulq_s32(*this,b);
		return *this;
	}

	/* Scalar math: multiply every lane by the same int. */

	inline vec1i operator *(const int b) const
	{
		return vmulq_n_s32(*this,b);
	}

	inline vec1i& operator *=(const int b)
	{
		*this = vmulq_n_s32(*this,b);
		return *this;
	}

	/* Other operators */

	inline vec1i operator-() const
	{
		return vnegq_s32(*this);
	}
};

/* Multi-component SoA vectors built from the four-lane scalars above
   (presumably N components, each a vec1f / vec1i).
   NOTE(review): the `vector` template is not declared in this header, so it
   must already be visible wherever this header is included - confirm. */
typedef vector<vec1f,2> vec2f;
typedef vector<vec1f,3> vec3f;
typedef vector<vec1f,4> vec4f;
typedef vector<vec1i,2> vec2i;
typedef vector<vec1i,3> vec3i;
typedef vector<vec1i,4> vec4i;

/* Number of lanes in one vec1f / vec1i (a float32x4_t / int32x4_t). */
const int scalar_elem = 4;

/* Lane-wise absolute value (NEON VABS). */
inline vec1f abs(vec1f f)
{
	const float32x4_t magnitude = vabsq_f32(f);
	return magnitude;
}

/* Return the lanes {f, f+step, (f+step)+step, ...}: each lane is the previous
   lane plus step, accumulated exactly as a scalar loop would round it. */
inline vec1f vec1f_spread(float f,float step)
{
	/* Build the lanes in memory rather than with vsetq_lane_f32: that
	   intrinsic needs a compile-time constant lane index, and starting from an
	   unset vector would read an indeterminate value. */
	float lanes[scalar_elem];
	for(int i=0;i<scalar_elem;i++)
	{
		lanes[i] = f;
		f += step;
	}
	return vld1q_f32(lanes);
}

/* Lane-wise 1/sqrt(f). Starts from the NEON reciprocal-square-root estimate
   (VRSQRTE) and applies `accuracy` Newton-Raphson steps; each step computes
   y = y * (3 - (f*y)*y) / 2 using VRSQRTS. accuracy <= 0 returns the raw
   estimate. */
inline vec1f inversesqrt(vec1f f, int accuracy = 0)
{
	float32x4_t y = vrsqrteq_f32(f);
	for(int steps=accuracy;steps>0;steps--)
	{
		const float32x4_t fy = vmulq_f32(f,y);
		y = vmulq_f32(y,vrsqrtsq_f32(fy,y));
	}
	return y;
}

/* Lane-wise square root, computed as f * inversesqrt(f, accuracy).
   accuracy is forwarded to inversesqrt() (number of refinement steps; the
   default 0 uses the bare estimate, as before).
   Lanes equal to zero are passed through unchanged. The reciprocal-sqrt
   estimate of 0 is infinite, and 0 * inf would otherwise give NaN.
   NOTE: a +inf input still yields NaN (inf * 0). */
inline vec1f SQRT(vec1f f, int accuracy = 0)
{
	const float32x4_t product = vmulq_f32(f,inversesqrt(f,accuracy));
	const uint32x4_t zero_lane = vceqq_f32(f,vdupq_n_f32(0.0f));
	return vbslq_f32(zero_lane,f,product);
}

/* Lane-wise reciprocal 1/f. Starts from the NEON reciprocal estimate
   (VRECPE) and applies `accuracy` Newton-Raphson steps; each step computes
   r = r * (2 - r*f) using VRECPS. The default for accuracy (0, estimate
   only) lives on the forward declaration near the top of this header. */
inline vec1f recp(vec1f f, int accuracy)
{
	float32x4_t estimate = vrecpeq_f32(f);
	for(int steps=accuracy;steps>0;steps--)
		estimate = vmulq_f32(estimate,vrecpsq_f32(estimate,f));
	return estimate;
}

/* Lane-wise floor (round toward -inf).
   VCVT truncates toward zero. Truncation overshoots floor() only for
   negative non-integers, which is exactly when f < trunc(f), so 1 is
   subtracted only in those lanes. Negative integers such as -2.0 are
   returned unchanged. The adjustment is done in float so that lanes clamped
   to INT_MIN by the conversion cannot wrap around. Lanes outside the int32
   range therefore saturate rather than wrap. */
inline vec1f floor(vec1f f)
{
	const float32x4_t truncated = vcvtq_f32_s32(vcvtq_s32_f32(f));
	const uint32x4_t overshoot = vcltq_f32(f,truncated);
	return vbslq_f32(overshoot,vsubq_f32(truncated,vdupq_n_f32(1.0f)),truncated);
}

/* Round each lane toward zero (truncate) by converting to int32 and back.
   Cheaper than floor(), but differs from it for negative non-integers. */
inline vec1f floor_tozero(vec1f f)
{
	return vcvtq_f32_s32(vcvtq_s32_f32(f));
}

/* Lane-wise fractional part, defined as f - floor(f). */
inline vec1f fract(vec1f f)
{
	const vec1f whole = floor(f);
	return f - whole;
}

/* Length of a one-component vector, i.e. |f| per lane. This matches
   length(float) in GLSL, whose naming this header follows (dot, fract,
   inversesqrt). The result is never negative. */
inline vec1f length(vec1f f)
{
	return vabsq_f32(f);
}

/* Dot product of two one-component vectors: just the lane-wise product. */
inline vec1f dot(vec1f a,vec1f b)
{
	return vec1f(vmulq_f32(a,b));
}

/* Lane-wise minimum / maximum, for float and int32 vectors. */
inline vec1f min(vec1f a,vec1f b)
{
	const float32x4_t lo = vminq_f32(a,b);
	return lo;
}

inline vec1f max(vec1f a,vec1f b)
{
	const float32x4_t hi = vmaxq_f32(a,b);
	return hi;
}

inline vec1i min(vec1i a,vec1i b)
{
	const int32x4_t lo = vminq_s32(a,b);
	return lo;
}

inline vec1i max(vec1i a,vec1i b)
{
	const int32x4_t hi = vmaxq_s32(a,b);
	return hi;
}

/* Bitwise select (VBSL): every bit set in mask takes the bit from a, every
   clear bit takes it from b. With a whole-lane mask, such as those returned
   by vec1f's comparison operators, this picks entire lanes. */
inline vec1f select(uint32x4_t mask,vec1f a,vec1f b)
{
	const float32x4_t when_set = a;
	const float32x4_t when_clear = b;
	return vec1f(vbslq_f32(mask,when_set,when_clear));
}

/* Reverse lane order: {e0,e1,e2,e3} -> {e3,e2,e1,e0}. */
inline vec1i revelem(vec1i i)
{
	/* VREV64 swaps the two words inside each doubleword, giving
	   {e1,e0,e3,e2}; exchanging the two halves then finishes the reversal. */
	const int32x4_t pairs_swapped = vrev64q_s32(i);
	return vcombine_s32(vget_high_s32(pairs_swapped),vget_low_s32(pairs_swapped));
}

/* Pack three float channels into one int per lane as
   (col[0] << 16) | (col[1] << 8) | col[2].
   Each channel is scaled by 256 (VCVT with 8 fraction bits, truncating) and
   clamped to 0..255, so out-of-range inputs - including negative ones -
   cannot spill into neighbouring channels or the top byte.
   See rgb_fast() for the unclamped variant. */
inline vec1i rgb(vec3f col)
{
	vec1i c0 = vcvtq_n_s32_f32(col[0],8);
	vec1i c1 = vcvtq_n_s32_f32(col[1],8);
	vec1i c2 = vcvtq_n_s32_f32(col[2],8);
	c0 = max(min(c0,255),0);
	c1 = max(min(c1,255),0);
	c2 = max(min(c2,255),0);
	/* VSLI shifts its second operand left and inserts it above the bits
	   already present. Its shift must be an immediate, so the packing is
	   written out rather than looped. */
	int32x4_t packed = c2;
	packed = vsliq_n_s32(packed,c1,8);
	packed = vsliq_n_s32(packed,c0,16);
	return packed;
}

/* Same packing as rgb() - (col[0] << 16) | (col[1] << 8) | col[2], each
   channel scaled by 256 with truncation - but without clamping. The caller
   must keep channels in a range whose scaled value fits in 8 bits.
   Otherwise: a lower channel keeps only its low 8 bits (e.g. 1.0 -> 256
   becomes 0), and col[0] can set bits above 23. */
inline vec1i rgb_fast(vec3f col)
{
	/* VSLI's shift must be an immediate, hence the unrolled packing. */
	int32x4_t packed = vcvtq_n_s32_f32(col[2],8);
	packed = vsliq_n_s32(packed,vcvtq_n_s32_f32(col[1],8),8);
	packed = vsliq_n_s32(packed,vcvtq_n_s32_f32(col[0],8),16);
	return packed;
}

/* Multiply-accumulate: every overload returns a + b*c lane-wise (NEON VMLA).
   Overloads taking a scalar c multiply every lane of b by that same value. */
inline vec1f mla(vec1f a,vec1f b,vec1f c)
{
	const float32x4_t sum = vmlaq_f32(a,b,c);
	return sum;
}

inline vec1f mla(vec1f a,vec1f b,float c)
{
	const float32x4_t sum = vmlaq_n_f32(a,b,c);
	return sum;
}

inline vec1i mla(vec1i a,vec1i b,vec1i c)
{
	const int32x4_t sum = vmlaq_s32(a,b,c);
	return sum;
}

inline vec1i mla(vec1i a,vec1i b,int c)
{
	const int32x4_t sum = vmlaq_n_s32(a,b,c);
	return sum;
}

/* Multiply-subtract: every overload returns a - b*c lane-wise (NEON VMLS).
   Overloads taking a scalar c multiply every lane of b by that same value. */
inline vec1f mls(vec1f a,vec1f b,vec1f c)
{
	const float32x4_t diff = vmlsq_f32(a,b,c);
	return diff;
}

inline vec1f mls(vec1f a,vec1f b,float c)
{
	const float32x4_t diff = vmlsq_n_f32(a,b,c);
	return diff;
}

inline vec1i mls(vec1i a,vec1i b,vec1i c)
{
	const int32x4_t diff = vmlsq_s32(a,b,c);
	return diff;
}

inline vec1i mls(vec1i a,vec1i b,int c)
{
	const int32x4_t diff = vmlsq_n_s32(a,b,c);
	return diff;
}

} /* namespace soa4 */

#endif
