/**
 * @license
 * Copyright 2021 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

// sampled from [@tensorflow/tfjs] tfjs-backend-webgpu/src/conv_backprop_webgpu.ts

import { DataType } from '../../../../wasm-common';
import { LOG_DEBUG } from '../../../log';
import { TensorView } from '../../../tensor-view';
import { ShapeUtil } from '../../../util';
import { ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../../types';
import {
  createTensorShapeVariables,
  inputVariable,
  outputVariable,
  ShaderHelper,
  tensorTypeToWsglStorageType,
  UniformsArrayType,
} from '../common';
import { ConvTransposeAttributes } from '../conv-transpose';

/**
 * Generates the WGSL source for a 2-D ConvTranspose (gradient of a convolution w.r.t. its input).
 *
 * Two code paths can be emitted:
 *  - scalar (`isVec4 === false`): one invocation per output element, with support for groups and dilations;
 *  - vec4 (`isVec4 === true`): 4 output channels x `workPerThread` output columns per invocation. The caller in this
 *    file currently hard-codes `isVec4 = false`, so this path is not exercised from there.
 *
 * @param shaderHelper helper used to declare uniforms/variables and the shader entry point.
 * @param inputs `[dy, w]` or `[dy, w, bias]` (dy = inputs[0], W = inputs[1], bias = inputs[2]).
 * @param outputShape shape of the result tensor; indices 0..3 are read, so rank 4 is assumed.
 * @param hasBias whether a bias tensor is present in `inputs[2]` and added to every output.
 * @param is1DimensionDispatch vec4 path only: index by `global_id` instead of `workgroup_id`.
 * @param isVec4 emit the vec4 path instead of the scalar path.
 * @param dataType WGSL storage type of the tensors (for example `f32` or `f16`).
 * @param uniforms uniform declarations; must declare every `uniforms.*` name referenced by the shader.
 * @param isChannelsLast NHWC layout when true, NCHW otherwise.
 * @returns the complete WGSL shader source.
 */
const createConvTranspose2DOpProgramShaderSource = (
  shaderHelper: ShaderHelper,
  inputs: readonly TensorView[],
  outputShape: readonly number[],
  hasBias: boolean,
  is1DimensionDispatch: boolean,
  isVec4 = false,
  dataType: string,
  uniforms: UniformsArrayType,
  isChannelsLast = false,
): string => {
  const rowDim = isChannelsLast ? 1 : 2;
  const colDim = isChannelsLast ? 2 : 3;
  const channelDim = isChannelsLast ? 3 : 1;
  const workPerThread = isVec4 ? 2 : 1;

  let declareFunctions = `
  fn setOutputAtIndex(flatIndex : u32, value : ${isVec4 ? `vec4<${dataType}>` : dataType}) {
    result[flatIndex] = ${isVec4 ? `vec4<${dataType}>` : dataType}(value);
  }`;
  if (hasBias) {
    declareFunctions += `
    fn getBiasByOutputCoords(coords : vec4<u32>) -> ${isVec4 ? `vec4<${dataType}>` : dataType} {
      return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}];
    }`;
  }
  const components = isVec4 ? 4 : 1;
  const w = inputVariable('W', inputs[1].dataType, inputs[1].dims.length, components);
  const dy = inputVariable('Dy', inputs[0].dataType, inputs[0].dims.length, components);
  const inputVariables = [dy, w];
  if (hasBias) {
    inputVariables.push(inputVariable('bias', inputs[2].dataType, [outputShape[channelDim]].length, components));
  }
  const output = outputVariable('result', inputs[0].dataType, outputShape.length, components);

  // Vec4 path (only emitted when isVec4 is true).
  // NOTE(review): the bias term below indexes `bias[c+i]`, where `c` is an output column, not an output channel;
  // verify the bias indexing (and the float-based index math) before enabling this path.
  const codeSnippet4 = `{
        let batch: u32 = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} / uniforms.result_shape[1];
        let r = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} % uniforms.result_shape[1];
        let c = ${is1DimensionDispatch ? 'global_id.y' : 'workgroup_id.y'} * ${workPerThread};
        let d1: u32 = ${is1DimensionDispatch ? 'global_id.x' : 'workgroup_id.x'} * 4;

        let dyCorner = vec2<i32>(i32(r), i32(c)) - vec2<i32>(uniforms.pads);

        // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1).
        // ? = to be determined. : = across all values in that axis.
        var dotProd: array<vec4<${dataType}>, ${workPerThread}>;
        for (var i = 0; i < ${workPerThread}; i++) {
          dotProd[i] = vec4<${dataType}>(0.0);
        }
        for (var wR: u32 = 0; wR < uniforms.filter_dims[0]; wR = wR + 1) {
          var dyR = (${dataType}(dyCorner.x) + ${dataType}(wR)) / ${dataType}(uniforms.strides.x);
          let wRPerm = uniforms.filter_dims[0] - 1 - wR;
          if (dyR < 0.0 || dyR >= ${dataType}(uniforms.Dy_shape[1]) ||
              fract(dyR) > 0.0 || wRPerm < 0) {
            continue;
          }
          let idyR: u32 = u32(dyR);

          for (var wC: u32 = 0; wC < uniforms.filter_dims[1]; wC = wC + 1) {
            let dyC = (${dataType}(dyCorner.y) + ${dataType}(wC)) / ${dataType}(uniforms.strides.y);
            let dyC2 = (${dataType}(dyCorner.y) + 1.0 + ${dataType}(wC)) / ${dataType}(uniforms.strides.y);
            let wCPerm = uniforms.filter_dims[1] - 1 - wC;
            if (wCPerm < 0) {
              continue;
            }
            var bDyCVal = true;
            var bDyCVal2 = true;
            if (dyC < 0.0 || dyC >= ${dataType}(uniforms.Dy_shape[2]) ||
                fract(dyC) > 0.0) {
              bDyCVal = false;
            }
            if (dyC2 < 0.0 || dyC2 >= ${dataType}(uniforms.Dy_shape[2]) ||
                fract(dyC2) > 0.0) {
              bDyCVal2 = false;
            }

            let idyC: u32 = u32(dyC);
            let idyC2: u32 = u32(dyC2);
            if (bDyCVal && bDyCVal2) {
              let d2Length = uniforms.Dy_shape[3];
              for (var d2 :u32 = 0; d2 < d2Length; d2 = d2 + 4) {
                let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')};
                let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')};
                let wValue2 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 2', 'd2')};
                let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')};

                var xValue = ${dy.get('batch', 'idyR', 'idyC', 'd2')};
                let tmpval = vec4<${dataType}>(dot(xValue, wValue0),
                                      dot(xValue, wValue1),
                                      dot(xValue, wValue2),
                                      dot(xValue, wValue3));
                dotProd[0] = dotProd[0] + tmpval;

                xValue =  ${dy.get('batch', 'idyR', 'idyC2', 'd2')};

                dotProd[1] = dotProd[1] + vec4<${dataType}>(dot(xValue, wValue0),
                                                    dot(xValue, wValue1),
                                                    dot(xValue, wValue2),
                                                    dot(xValue, wValue3));
              }
            } else if (bDyCVal) {
              let d2Length = uniforms.Dy_shape[${channelDim}];
              for (var d2: u32 = 0; d2 < d2Length; d2 = d2 + 4) {
                let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')};
                let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')};
                let wValue2 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 2', 'd2')};
                let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')};

                var xValue = ${dy.get('batch', 'idyR', 'idyC', 'd2')};
                let tmpval = vec4<${dataType}>(dot(xValue, wValue0),
                                      dot(xValue, wValue1),
                                      dot(xValue, wValue2),
                                      dot(xValue, wValue3));
                dotProd[0] = dotProd[0] + tmpval;
              }
            } else if (bDyCVal2) {
              let d2Length = uniforms.Dy_shape[3];
              for (var d2: u32 = 0; d2 < d2Length; d2 = d2 + 4) {
                let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')};
                let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')};
                let wValue2 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 2', 'd2')};
                let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')};

                var xValue = ${dy.get('batch', 'idyR', 'idyC2', 'd2')};
                let tmpval = vec4<${dataType}>(dot(xValue, wValue0),
                                      dot(xValue, wValue1),
                                      dot(xValue, wValue2),
                                      dot(xValue, wValue3));
                dotProd[1] = dotProd[1] + tmpval;
              }
            }
          }
        }

        for (var i: u32 = 0; i < ${workPerThread}; i = i + 1) {
          let value = dotProd[i] + ${hasBias ? 'bias[c+i]' : `vec4<${dataType}>(0.0)`};
          ${output.set('batch', 'r', 'c + i', 'd1', 'value')};
        }
      }`;

  // Scalar path: one invocation per output element.
  // The dy row/column for a kernel tap is (corner + tap) / stride, and the tap contributes only when that division
  // is exact and the result lies inside dy. This is evaluated with integer arithmetic. Doing it in `${dataType}`
  // floating point is unreliable: f16 cannot represent every integer above 2048, and WGSL float division is not
  // exact, so an exact quotient could fail a `fract(x) > 0.0` test.
  const codeSnippet = `
          let outputIndices = ${output.offsetToIndices('global_idx')};
          let batch = ${output.indicesGet('outputIndices', 0)};
          let d1 = ${output.indicesGet('outputIndices', channelDim)};
          let r = ${output.indicesGet('outputIndices', rowDim)};
          let c = ${output.indicesGet('outputIndices', colDim)};
          let dyCorner = vec2<i32>(i32(r), i32(c)) - uniforms.pads;
          let dyRCorner = dyCorner.x;
          let dyCCorner = dyCorner.y;
          let strideR = i32(uniforms.strides.x);
          let strideC = i32(uniforms.strides.y);
          let groupId = d1 / uniforms.output_channels_per_group;
          let wOutChannel = d1 - groupId * uniforms.output_channels_per_group;
          // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1).
          // ? = to be determined. : = across all values in that axis.
          var dotProd = ${dataType}(0.0);
          for (var wR: u32 = 0; wR < uniforms.effective_filter_dims.x; wR = wR + 1) {
            if (wR % uniforms.dilations.x != 0) {
              continue;
            }
            let dyRNumerator = dyRCorner + i32(wR);
            if (dyRNumerator < 0 || dyRNumerator % strideR != 0) {
              continue;
            }
            let idyR = u32(dyRNumerator / strideR);
            if (idyR >= uniforms.Dy_shape[${rowDim}]) {
              continue;
            }
            let wRPerm = uniforms.filter_dims.x - 1 - wR / uniforms.dilations.x;

            for (var wC: u32 = 0; wC < uniforms.effective_filter_dims.y; wC = wC + 1) {
              if (wC % uniforms.dilations.y != 0) {
                continue;
              }
              let dyCNumerator = dyCCorner + i32(wC);
              if (dyCNumerator < 0 || dyCNumerator % strideC != 0) {
                continue;
              }
              let idyC = u32(dyCNumerator / strideC);
              if (idyC >= uniforms.Dy_shape[${colDim}]) {
                continue;
              }
              let wCPerm = uniforms.filter_dims.y - 1 - wC / uniforms.dilations.y;
              var inputChannel = groupId * uniforms.input_channels_per_group;
              for (var d2: u32 = 0; d2 < uniforms.input_channels_per_group; d2 = d2 + 1) {
                let xValue = ${
                  isChannelsLast
                    ? dy.get('batch', 'idyR', 'idyC', 'inputChannel')
                    : dy.get('batch', 'inputChannel', 'idyR', 'idyC')
                };
                let wValue = ${w.get('inputChannel', 'wOutChannel', 'wRPerm', 'wCPerm')};
                dotProd = dotProd + xValue * wValue;
                inputChannel = inputChannel + 1;
              }
            }
          }
          let value = dotProd + ${hasBias ? 'bias[d1]' : `${dataType}(0.0)`};
          ${output.setByOffset('global_idx', 'value')};
        `;

  return `
  ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}
  ${declareFunctions}

    ${shaderHelper.mainStart()}
    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')};
  ${isVec4 ? codeSnippet4 : codeSnippet}}`;
};

/**
 * Builds the WebGPU program for a 2-D ConvTranspose computed directly, with one invocation per output element.
 *
 * @param inputs `[dy, w]` or `[dy, w, bias]`; a bias is assumed present when more than two inputs are given.
 * @param attributes ConvTranspose attributes. `outputShape`, `strides`, `dilations`, `kernelShape`, `pads`, `group`,
 *   `format` and `cacheKey` are read.
 * @param squeezeOutputShapeFunction optional reshaping applied to the reported output dims. The shader itself always
 *   works on `attributes.outputShape`.
 * @returns the program info: shader generator, uniforms, dispatch size and output descriptor.
 */
export const createConvTranspose2DProgramInfo = (
  inputs: readonly TensorView[],
  attributes: ConvTransposeAttributes,
  squeezeOutputShapeFunction?: (shape: readonly number[]) => number[],
): ProgramInfo => {
  const hasBias = inputs.length > 2;
  const isChannelsLast = attributes.format === 'NHWC';
  const outputShape = attributes.outputShape;
  const outputSize = ShapeUtil.size(outputShape);

  // TODO Enable isVec4 for performance
  // Disabled due to weight matrix layout issue
  // const inChannels = inputs[0].dims[isChannelsLast ? 3 : 1];
  // const isVec4 = attributes.group === 1 && isChannelsLast && inChannels % 4 === 0 && outChannels % 4 === 0;
  const isVec4 = false;

  const dispatch = [Math.ceil(outputSize / 64), 1, 1];
  LOG_DEBUG('verbose', () => `[conv2d_backprop_webgpu] dispatch = ${dispatch}`);

  const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
  const strides = [attributes.strides[0], attributes.strides[1]];
  const filterDims = [attributes.kernelShape[isChannelsLast ? 1 : 2], attributes.kernelShape[isChannelsLast ? 2 : 3]];
  const dilations = [attributes.dilations[0], attributes.dilations[1]];
  // Spatial extent of the dilated kernel: k + (k - 1) * (d - 1) when d > 1, otherwise k.
  const effectiveFilterDims = [
    filterDims[0] + (dilations[0] <= 1 ? 0 : (filterDims[0] - 1) * (dilations[0] - 1)),
    filterDims[1] + (dilations[1] <= 1 ? 0 : (filterDims[1] - 1) * (dilations[1] - 1)),
  ];
  // Offset subtracted from the output (row, col) to get the "corner" used by the shader. Both axes floor the
  // halved begin+end pad sum. The division must be inside Math.floor: otherwise an odd sum would produce a
  // fractional value that is silently truncated when packed into the i32 uniform.
  // NOTE(review): the shader maps output index r to dy index (r - pad + tap) / stride with kernel tap
  // K - 1 - tap / dilation. That equals the ONNX ConvTranspose definition when pad = effectiveFilterDim - 1 - padBegin.
  // The averaged form below matches this only when padEnd - padBegin is 0 or 1 (symmetric padding or SAME_UPPER).
  // TODO confirm how asymmetric pads reach this kernel.
  const pads = [
    effectiveFilterDims[0] - 1 - Math.floor((attributes.pads[0] + attributes.pads[2]) / 2),
    effectiveFilterDims[1] - 1 - Math.floor((attributes.pads[1] + attributes.pads[3]) / 2),
  ];

  const group = attributes.group;
  const wShape = inputs[1].dims;
  const inputChannelsPerGroup = wShape[0] / group;
  const outputChannelsPerGroup = wShape[1];

  // The order must match the `uniforms` declaration list in getShaderSource below.
  const programUniforms: ProgramUniform[] = [
    { type: DataType.uint32, data: outputSize },
    { type: DataType.uint32, data: strides },
    { type: DataType.uint32, data: filterDims },
    { type: DataType.uint32, data: dilations },
    { type: DataType.uint32, data: effectiveFilterDims },
    { type: DataType.int32, data: pads },
    { type: DataType.uint32, data: inputChannelsPerGroup },
    { type: DataType.uint32, data: outputChannelsPerGroup },
    ...createTensorShapeVariables(inputs[0].dims, inputs[1].dims),
  ];
  if (hasBias) {
    programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
    inputDependencies.push('rank');
  }
  programUniforms.push(...createTensorShapeVariables(outputShape));

  const is1DimensionDispatch = dispatch[1] === 1 && dispatch[2] === 1;
  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const uniforms: UniformsArrayType = [
      { name: 'output_size', type: 'u32' },
      { name: 'strides', type: 'u32', length: strides.length },
      { name: 'filter_dims', type: 'u32', length: filterDims.length },
      { name: 'dilations', type: 'u32', length: dilations.length },
      { name: 'effective_filter_dims', type: 'u32', length: effectiveFilterDims.length },
      { name: 'pads', type: 'i32', length: pads.length },
      { name: 'input_channels_per_group', type: 'u32' },
      { name: 'output_channels_per_group', type: 'u32' },
    ];
    const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
    return `${createConvTranspose2DOpProgramShaderSource(
      shaderHelper,
      inputs,
      outputShape,
      hasBias,
      is1DimensionDispatch,
      isVec4,
      dataType,
      uniforms,
      isChannelsLast,
    )}`;
  };
  return {
    name: 'ConvTranspose2D',
    shaderCache: { hint: `${attributes.cacheKey};`, inputDependencies },
    getRunData: () => ({
      dispatchGroup: { x: dispatch[0], y: dispatch[1], z: dispatch[2] },
      outputs: [
        {
          dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape,
          dataType: inputs[0].dataType,
        },
      ],
      programUniforms,
    }),
    getShaderSource,
  };
};
