import {
  ShaderMaterial,
  Color,
  DataTexture,
  Math as _Math,
  RGBFormat,
  FloatType,
} from "three";
import { extend } from "react-three-fiber";

/**
 * ShaderMaterial producing a bend + glitch / RGB-split transition effect.
 *
 * Vertex stage: shifts each vertex along x by
 * `sin(uv.y * PI) * shift * 2.0 * 0.125`, so the bend grows with `shift`.
 *
 * Fragment stage, when `hasTexture == 1.0`:
 * - mixes `tex` (`map`) and `prevTex` (`prevMap`) by `transition`;
 * - splits red (+offset) and blue (-offset) channels along a seed-derived
 *   angle, with an offset length proportional to `shift * shift`;
 * - adds per-fragment noise scaled by `shift`;
 * - when `shift > 0.2`, also remaps UVs inside seed-dependent bands and
 *   displaces them with the random `texDisp` height map.
 * Without a texture it outputs the flat `color` uniform with `opacity`.
 */
class CustomMaterial extends ShaderMaterial {
  constructor() {
    super({
      vertexShader: `uniform float shift;
      varying vec2 vUv;
      void main() {
        vec3 pos = position;
        pos.x = pos.x + ((sin(uv.y * 3.1415926535897932384626433832795) * shift * 2.0) * 0.125);
        vUv = uv;
        gl_Position = projectionMatrix * modelViewMatrix * vec4(pos,1.);
      }`,
      // Red samples at p + offset and blue at p - offset, for BOTH the
      // current and the previous texture, so the channel split stays
      // consistent throughout a transition.
      fragmentShader: `uniform sampler2D tex;
      uniform sampler2D prevTex;
      uniform sampler2D texDisp;
      uniform float hasTexture;
      uniform float shift;
      uniform vec3 color;
      uniform float opacity;
      uniform float seed;
      uniform float col_s;
      uniform float transition;

      varying vec2 vUv;

      float rand(vec2 co) {
        return fract(sin(dot(co.xy, vec2(12.9898, 78.233))) * 43758.5453);
      }

      void main() {
        vec2 p = vUv;

        float distortion_x = fract(sin(seed*480.3311)*787.3402);
        float distortion_y = fract(sin(seed*770.5555)*244.4653);
        float seed_x = fract(sin(seed*155.8068)*306.9539) * 0.6 - 0.3;
        float seed_y = fract(sin(seed*22.7735)*295.1813) * 0.6 - 0.3;
        float angle = fract(sin(seed*410.7249)*58.7764) * 2. * 3.1415926538 - 3.1415926538;

        float xs = floor(gl_FragCoord.x / 0.5);
        float ys = floor(gl_FragCoord.y / 0.5);
        vec4 normal = texture2D(texDisp, p*seed*seed);

        if (shift > 0.2) {
          if (p.y < distortion_x+col_s && p.y > distortion_x-col_s * seed) {
            if (seed_x > 0.) {
              p.y = 1. - (p.y + distortion_y);
            } else {
              p.y = distortion_y;
            }
          }
          if (p.x < distortion_y + col_s && p.x > distortion_y-col_s * seed) {
            if (seed_y > 0.) {
              p.x = distortion_x;
            } else {
              p.x = 1. - (p.x + distortion_x);
            }
          }
          p.x += normal.x * seed_x * (seed/5.);
          p.y += normal.y * seed_y * (seed/5.);
        }

        vec2 offset = shift * shift * vec2( cos(angle), sin(angle)) * 0.1;

        vec4 cr = transition * texture2D(tex, p + offset) + (1. - transition) * texture2D(prevTex, p + offset);
        vec4 cga = transition * texture2D(tex, p) + (1. - transition) * texture2D(prevTex, p);
        vec4 cb = transition * texture2D(tex, p - offset) + (1. - transition) * texture2D(prevTex, p - offset);

        vec4 snow = 10. * shift * vec4(rand(vec2(xs, ys * 50.)) * 0.05);

        if (hasTexture == 1.0) gl_FragColor = vec4(cr.r, cga.g, cb.b, cga.a) + snow;
        else gl_FragColor = vec4(color, opacity);
      }`,
      uniforms: {
        tex: { value: null },
        prevTex: { value: null },
        hasTexture: { value: 0 },
        shift: { value: 0 },
        opacity: { value: 1 },
        color: { value: new Color("white") },
        seed: { value: 0.02 },
        col_s: { value: 0.05 },
        transition: { value: 1.0 }
      },
    });

    // Added after super(): generateHeightMap is an instance method and
    // `this` cannot be used inside the super() argument list. This texture is
    // owned by the material and is released in dispose().
    this.uniforms["texDisp"] = {
      value: this.generateHeightMap(),
    };
  }

  /**
   * Builds a square grey-scale float texture of random heights.
   * Each texel's R, G and B are set to the same value from
   * `_Math.randFloat(0, 1)`.
   *
   * NOTE(review): `Math` (renamed `MathUtils`) and `RGBFormat` were removed in
   * later three.js releases — confirm the pinned three version before upgrading.
   *
   * @param {number} [dt_size=16] - texture width and height in texels.
   * @returns {DataTexture} RGB / FloatType texture flagged for upload.
   */
  generateHeightMap(dt_size = 16) {
    const length = dt_size * dt_size;
    const data_arr = new Float32Array(length * 3);

    for (let i = 0; i < length; i++) {
      const val = _Math.randFloat(0, 1);
      data_arr[i * 3 + 0] = val;
      data_arr[i * 3 + 1] = val;
      data_arr[i * 3 + 2] = val;
    }

    const texture = new DataTexture(
      data_arr,
      dt_size,
      dt_size,
      RGBFormat,
      FloatType
    );
    texture.needsUpdate = true;
    return texture;
  }

  /** Effect intensity: drives the vertex bend, channel split, noise and the `> 0.2` distortion branch. */
  set shift(value) {
    this.uniforms.shift.value = value;
  }

  get shift() {
    return this.uniforms.shift.value;
  }

  /** Seed for the fragment shader's pseudo-random distortion parameters and split angle. */
  set seed(value) {
    this.uniforms.seed.value = value;
  }

  get seed() {
    return this.uniforms.seed.value;
  }

  /** Band size used by the fragment shader's UV-distortion range tests (`col_s` uniform). */
  set col_s(value) {
    this.uniforms.col_s.value = value;
  }

  get col_s() {
    return this.uniforms.col_s.value;
  }

  /** Cross-fade factor: 1 shows only `map`, 0 shows only `prevMap`. */
  set transition(value) {
    this.uniforms.transition.value = value;
  }

  get transition() {
    return this.uniforms.transition.value;
  }

  /**
   * Current texture. Also updates the `hasTexture` float uniform, which the
   * shader compares with `== 1.0`. An explicit 1/0 is stored so the uniform
   * keeps the numeric type of its initial value.
   */
  set map(value) {
    this.uniforms.hasTexture.value = value ? 1 : 0;
    this.uniforms.tex.value = value;
  }

  get map() {
    return this.uniforms.tex.value;
  }

  /** Texture being transitioned away from; owned by the caller. */
  set prevMap(value) {
    this.uniforms.prevTex.value = value;
  }

  get prevMap() {
    return this.uniforms.prevTex.value;
  }

  /** Flat colour used when no texture is set (a three.js `Color`). */
  get color() {
    return this.uniforms.color.value;
  }

  /**
   * Copies `value` (anything `Color#set` accepts) into the existing colour.
   * The uniform's Color instance is kept stable rather than replaced.
   */
  set color(value) {
    this.uniforms.color.value.set(value);
  }

  get opacity() {
    return this.uniforms.opacity.value;
  }

  /**
   * Opacity used when no texture is set. The guard is needed because the
   * setter can run before `uniforms` exists — presumably from three's Material
   * constructor assigning `opacity` during super().
   */
  set opacity(value) {
    if (this.uniforms) this.uniforms.opacity.value = value;
  }

  /**
   * Releases this material and the height-map texture created by the
   * constructor. `map` and `prevMap` are supplied by the caller and are
   * intentionally left alone.
   */
  dispose() {
    const heightMap = this.uniforms.texDisp && this.uniforms.texDisp.value;
    if (heightMap) heightMap.dispose();
    super.dispose();
  }
}

// Register CustomMaterial with react-three-fiber so it can be used from JSX.
// NOTE(review): following r3f's `extend` naming convention it is presumably
// available as <customMaterial /> — confirm against the pinned r3f version.
const r3fElements = { CustomMaterial };
extend(r3fElements);
