/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.flink.table.planner.runtime.batch.sql

import org.apache.flink.api.common.typeinfo.BasicTypeInfo.LONG_TYPE_INFO
import org.apache.flink.api.common.typeinfo.LocalTimeTypeInfo
import org.apache.flink.api.java.typeutils.{RowTypeInfo, TypeExtractor}
import org.apache.flink.configuration.Configuration
import org.apache.flink.table.api.DataTypes
import org.apache.flink.table.catalog.DataTypeFactory
import org.apache.flink.table.data.StringData
import org.apache.flink.table.functions.{ScalarFunction, TableFunction}
import org.apache.flink.table.planner.expressions.utils.{Func1, Func18, RichFunc2}
import org.apache.flink.table.planner.runtime.utils.BatchTestBase
import org.apache.flink.table.planner.runtime.utils.BatchTestBase.row
import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedTableFunctions.JavaTableFunc0
import org.apache.flink.table.planner.runtime.utils.TestData._
import org.apache.flink.table.planner.runtime.utils.UserDefinedFunctionTestUtils.{MyPojo, MyPojoFunc}
import org.apache.flink.table.planner.utils._
import org.apache.flink.table.planner.utils.DateTimeTestUtil._
import org.apache.flink.table.types.DataType
import org.apache.flink.table.types.inference.{InputTypeStrategies, TypeInference, TypeStrategies}
import org.apache.flink.types.Row

import org.junit.jupiter.api.{BeforeEach, Test}

/**
 * Batch SQL integration tests for correlated joins (`LATERAL TABLE(...)`) against user-defined
 * table functions: inner and left outer joins, filters on function output, POJO and hierarchical
 * result types, scalar calls nested in or around table function calls, constructor parameters,
 * variable arguments and functions configured through global job parameters.
 */
class CorrelateITCase extends BatchTestBase {

  /**
   * Registers the shared inputs `inputT` and `inputTWithNull` (rows from [[TableFunctionITCase]])
   * and `SmallTable3` (rows from `smallData3`), all with type `type3` and columns `a, b, c`.
   */
  @BeforeEach
  override def before(): Unit = {
    super.before()
    registerCollection("inputT", TableFunctionITCase.testData, type3, "a, b, c")
    registerCollection("inputTWithNull", TableFunctionITCase.testDataWithNull, type3, "a, b, c")
    registerCollection("SmallTable3", smallData3, type3, "a, b, c")
  }

  /**
   * Inner lateral join: every `c` value is expanded into the rows produced by `func`. The
   * "nosharp" input row does not appear in the expected result.
   */
  @Test
  def testTableFunction(): Unit = {
    tEnv.createTemporarySystemFunction("func", new TableFunc1)
    checkResult(
      "select c, s from inputT, LATERAL TABLE(func(c)) as T(s)",
      Seq(
        row("Jack#22", "Jack"),
        row("Jack#22", "22"),
        row("John#19", "John"),
        row("John#19", "19"),
        row("Anna#44", "Anna"),
        row("Anna#44", "44")
      )
    )
  }

  /**
   * `LEFT JOIN ... ON TRUE` keeps input rows for which the function yields nothing: "nosharp" is
   * expected with null `s` and `l`.
   */
  @Test
  def testLeftOuterJoin(): Unit = {
    tEnv.createTemporarySystemFunction("func", new TableFunc2)
    checkResult(
      "select c, s, l from inputT LEFT JOIN LATERAL TABLE(func(c)) as T(s, l) ON TRUE",
      Seq(
        row("Jack#22", "Jack", 4),
        row("Jack#22", "22", 2),
        row("John#19", "John", 4),
        row("John#19", "19", 2),
        row("Anna#44", "Anna", 4),
        row("Anna#44", "44", 2),
        row("nosharp", null, null)
      )
    )
  }

  /** Filters on a column produced by the table function (`T.age > 20`). */
  @Test
  def testWithFilter(): Unit = {
    tEnv.createTemporarySystemFunction("func", new TableFunc0)
    checkResult(
      "select c, name, age from inputT, LATERAL TABLE(func(c)) as T(name, age) WHERE T.age > 20",
      Seq(row("Jack#22", "Jack", 22), row("Anna#44", "Anna", 44)))
  }

  /** The three result fields of `HierarchyTableFunction` are aliased as `(name, adult, len)`. */
  @Test
  def testHierarchyType(): Unit = {
    tEnv.createTemporarySystemFunction("func", new HierarchyTableFunction)
    checkResult(
      "select c, name, adult, len from inputT, LATERAL TABLE(func(c)) as T(name, adult, len)",
      Seq(
        row("Jack#22", "Jack", true, 22),
        row("John#19", "John", false, 19),
        row("Anna#44", "Anna", true, 44))
    )
  }

  /** T(name, age) must have the right order with TypeInfo of PojoUser. */
  @Test
  def testPojoType(): Unit = {
    tEnv.createTemporarySystemFunction("func", new PojoTableFunc)
    checkResult(
      "select c, name, age from inputT, LATERAL TABLE(func(c)) as T(name, age) WHERE T.age > 20",
      Seq(row("Jack#22", "Jack", 22), row("Anna#44", "Anna", 44)))
  }

  /** The table function argument is the result of a built-in scalar call (`SUBSTRING(c, 2)`). */
  @Test
  def testUserDefinedTableFunctionWithScalarFunction(): Unit = {
    tEnv.createTemporarySystemFunction("func", new TableFunc1)
    checkResult(
      "select c, s from inputT, LATERAL TABLE(func(SUBSTRING(c, 2))) as T(s)",
      Seq(
        row("Jack#22", "ack"),
        row("Jack#22", "22"),
        row("John#19", "ohn"),
        row("John#19", "19"),
        row("Anna#44", "nna"),
        row("Anna#44", "44")
      )
    )
  }

  /**
   * The WHERE clause applies scalar UDFs to both an input column (`a`) and to columns produced by
   * the table function (`name`, `age`).
   */
  @Test
  def testUserDefinedTableFunctionWithScalarFunctionInCondition(): Unit = {
    tEnv.createTemporarySystemFunction("func", new TableFunc0)
    tEnv.createTemporarySystemFunction("func18", Func18)
    tEnv.createTemporarySystemFunction("func1", Func1)
    checkResult(
      "select c, name, age from inputT, LATERAL TABLE(func(c)) as T(name, age) " +
        "where func18(name, 'J') and func1(a) < 3 and func1(age) > 20",
      Seq(
        row("Jack#22", "Jack", 22)
      )
    )
  }

  /**
   * Passes LocalDate, Long and LocalDateTime columns to a Java table function. In the expected
   * rows, 7591 is the epoch-day number of 1990-10-14 and 655906210000 equals 1990-10-14 12:10:10
   * expressed as epoch milliseconds (computed as UTC).
   */
  @Test
  def testLongAndTemporalTypes(): Unit = {
    registerCollection(
      "myT",
      Seq(row(localDate("1990-10-14"), 1000L, localDateTime("1990-10-14 12:10:10"))),
      new RowTypeInfo(
        LocalTimeTypeInfo.LOCAL_DATE,
        LONG_TYPE_INFO,
        LocalTimeTypeInfo.LOCAL_DATE_TIME),
      "x, y, z"
    )
    tEnv.createTemporarySystemFunction("func", new JavaTableFunc0)
    checkResult(
      "select s from myT, LATERAL TABLE(func(x, y, z)) as T(s)",
      Seq(
        row(1000L),
        row(655906210000L),
        row(7591L)
      ))
  }

  /**
   * Sets the global job parameter `word_separator` to "#" before running the query.
   * NOTE(review): presumably RichTableFunc1 reads this parameter to split `c` — confirm against
   * RichTableFunc1.
   */
  @Test
  def testUserDefinedTableFunctionWithParameter(): Unit = {
    tEnv.createTemporarySystemFunction("func", new RichTableFunc1)
    val conf = new Configuration()
    conf.setString("word_separator", "#")
    env.getConfig.setGlobalJobParameters(conf)
    checkResult(
      "select a, s from inputT, LATERAL TABLE(func(c)) as T(s)",
      Seq(row(1, "Jack"), row(1, "22"), row(2, "John"), row(2, "19"), row(3, "Anna"), row(3, "44")))
  }

  /**
   * Nests a rich scalar function inside a rich table function; both are configured through global
   * job parameters (`word_separator`, `string.value`).
   * NOTE(review): judging by the expected rows, RichFunc2 appears to append a "#"-separated
   * `string.value` to its input — confirm against RichFunc2.
   */
  @Test
  def testUserDefinedTableFunctionWithScalarFunctionWithParameters(): Unit = {
    tEnv.createTemporarySystemFunction("func", new RichTableFunc1)
    tEnv.createTemporarySystemFunction("func2", new RichFunc2)
    val conf = new Configuration()
    conf.setString("word_separator", "#")
    conf.setString("string.value", "test")
    env.getConfig.setGlobalJobParameters(conf)
    checkResult(
      "select a, s from SmallTable3, LATERAL TABLE(func(func2(c))) as T(s)",
      Seq(
        row(1, "Hi"),
        row(1, "test"),
        row(2, "Hello"),
        row(2, "test"),
        row(3, "Hello world"),
        row(3, "test"))
    )
  }

  /**
   * Registers three TableFunc3 instances built with different constructor arguments and joins all
   * of them. The expected rows show the argument prefixed to the emitted name, and no prefix for
   * `null`.
   */
  @Test
  def testTableFunctionConstructorWithParams(): Unit = {
    tEnv.createTemporarySystemFunction("func30", new TableFunc3(null))
    tEnv.createTemporarySystemFunction("func31", new TableFunc3("OneConf_"))
    tEnv.createTemporarySystemFunction("func32", new TableFunc3("TwoConf_"))
    checkResult(
      "select c, d, f, h, e, g, i from inputT, " +
        "LATERAL TABLE(func30(c)) as T0(d, e), " +
        "LATERAL TABLE(func31(c)) as T1(f, g)," +
        "LATERAL TABLE(func32(c)) as T2(h, i)",
      Seq(
        row("Anna#44", "Anna", "OneConf_Anna", "TwoConf_Anna", 44, 44, 44),
        row("Jack#22", "Jack", "OneConf_Jack", "TwoConf_Jack", 22, 22, 22),
        row("John#19", "John", "OneConf_John", "TwoConf_John", 19, 19, 19)
      )
    )
  }

  /**
   * Calls a varargs table function with two literals and a column; one output row per argument
   * is expected.
   */
  @Test
  def testTableFunctionWithVariableArguments(): Unit = {
    tEnv.createTemporarySystemFunction("func", new VarArgsFunc0)
    checkResult(
      "select c, d from inputT, LATERAL TABLE(func('1', '2', c)) as T0(d) where c = 'Jack#22'",
      Seq(
        row("Jack#22", 1),
        row("Jack#22", 2),
        row("Jack#22", "Jack#22")
      )
    )
  }

  /** Uses a POJO-typed column both as a table function argument and as a table function result. */
  @Test
  def testPojoField(): Unit = {
    val data = Seq(row(new MyPojo(5, 105)), row(new MyPojo(6, 11)), row(new MyPojo(7, 12)))
    registerCollection(
      "MyTable",
      data,
      new RowTypeInfo(TypeExtractor.createTypeInfo(classOf[MyPojo])),
      "a")

    // 1. external type for udtf parameter
    tEnv.createTemporarySystemFunction("pojoTFunc", new MyPojoTableFunc)
    checkResult(
      "select s from MyTable, LATERAL TABLE(pojoTFunc(a)) as T(s)",
      Seq(row(105), row(11), row(12)))

    // 2. external type return in udtf
    tEnv.createTemporarySystemFunction("pojoFunc", MyPojoFunc)
    tEnv.createTemporarySystemFunction("toPojoTFunc", new MyToPojoTableFunc)
    checkResult(
      "select b from MyTable, LATERAL TABLE(toPojoTFunc(pojoFunc(a))) as T(b, c)",
      Seq(row(105), row(11), row(12)))
  }

  /**
   * Registers RichTableFuncWithFinish by class (not by instance). The expected output is every
   * input `c` value, including "nosharp", passed through unchanged.
   */
  @Test
  def testTableFunctionWithFinishMethod(): Unit = {
    registerTemporarySystemFunction("udtfWithFinish", classOf[RichTableFuncWithFinish])
    checkResult(
      "select s from inputT, LATERAL TABLE(udtfWithFinish(c)) as T(s)",
      Seq(row("Jack#22"), row("John#19"), row("Anna#44"), row("nosharp"))
    )
  }

// TODO support dynamic type
//  @Test
//  def testDynamicTypeWithSQL(): Unit = {
//    registerFunction("funcDyna0", new UDTFWithDynamicType0)
//    registerFunction("funcDyna1", new UDTFWithDynamicType0)
//    checkResult(
//      "SELECT c,name,len0,len1,name1,len10 FROM inputT JOIN " +
//        "LATERAL TABLE(funcDyna0(c, 'string,int,int')) AS T1(name,len0,len1) ON TRUE JOIN " +
//        "LATERAL TABLE(funcDyna1(c, 'string,int')) AS T2(name1,len10) ON TRUE " +
//        "where c = 'Anna#44'",
//      Seq(
//        row("Anna#44,44,2,2,44,2"),
//        row("Anna#44,44,2,2,Anna,4"),
//        row("Anna#44,Anna,4,4,44,2"),
//        row("Anna#44,Anna,4,4,Anna,4")
//      ))
//  }
//
//  @Test
//  def testDynamicTypeWithSQLAndVariableArgs(): Unit = {
//    registerFunction("funcDyna0", new UDTFWithDynamicTypeAndVariableArgs)
//    registerFunction("funcDyna1", new UDTFWithDynamicTypeAndVariableArgs)
//    checkResult(
//      "SELECT c,name,len0,len1,name1,len10 FROM inputT JOIN " +
//        "LATERAL TABLE(funcDyna0(c, 'string,int,int', 'a', 'b', 'c')) " +
//        "AS T1(name,len0,len1) ON TRUE JOIN " +
//        "LATERAL TABLE(funcDyna1(c, 'string,int', 'a', 'b', 'c')) AS T2(name1,len10) ON TRUE " +
//        "where c = 'Anna#44'",
//      Seq(
//        row("Anna#44,44,2,2,44,2"),
//        row("Anna#44,44,2,2,Anna,4"),
//        row("Anna#44,Anna,4,4,44,2"),
//        row("Anna#44,Anna,4,4,Anna,4")
//      ))
//  }
//
//  @Test
//  def testDynamicTypeWithSQLAndVariableArgsWithMultiEval(): Unit = {
//    val funcDyna0 = new UDTFWithDynamicTypeAndVariableArgs
//    registerFunction("funcDyna0", funcDyna0)
//    checkResult(
//      "SELECT a, b, c, d, e FROM inputT JOIN " +
//        "LATERAL TABLE(funcDyna0(a)) AS T1(d, e) ON TRUE " +
//        "where c = 'Anna#44'",
//      Seq(
//        row("3,2,Anna#44,3,3"),
//        row("3,2,Anna#44,3,3")
//      ))
//  }
}

/** Scalar function that returns its string argument unchanged (including `null`). */
@SerialVersionUID(1L)
object StringUdFunc extends ScalarFunction {
  def eval(s: String): String = {
    val unchanged = s
    unchanged
  }
}

/** Shared input rows of shape (Int, Long, String) used by the correlate integration tests. */
object TableFunctionITCase {

  /** Four rows; the first three carry "name#age" strings, the last one contains no '#'. */
  lazy val testData: Seq[Row] = Seq(
    (1, 1L, "Jack#22"),
    (2, 2L, "John#19"),
    (3, 2L, "Anna#44"),
    (4, 3L, "nosharp")
  ).map { case (first, second, third) => row(first, second, third) }

  /** Like `testData`, but the second row has a null string and the third an empty one. */
  lazy val testDataWithNull: Seq[Row] = Seq(
    (1, 1L, "Jack#22"),
    (2, 2L, null),
    (3, 2L, ""),
    (4, 3L, "nosharp")
  ).map { case (first, second, third) => row(first, second, third) }
}

/**
 * Table function taking a [[MyPojo]] argument, declared as a STRUCTURED type with INT fields
 * `f1` and `f2`, and emitting the value of its `f2` field as a single INT row.
 */
@SerialVersionUID(1L)
class MyPojoTableFunc extends TableFunction[Int] {
  def eval(s: MyPojo): Unit = {
    val emitted = s.f2
    collect(emitted)
  }

  override def getTypeInference(typeFactory: DataTypeFactory): TypeInference = {
    // The argument is described explicitly so the planner passes the external MyPojo class.
    val pojoArgumentType = DataTypes.STRUCTURED(
      classOf[MyPojo],
      DataTypes.FIELD("f1", DataTypes.INT()),
      DataTypes.FIELD("f2", DataTypes.INT()))
    val argumentStrategy =
      InputTypeStrategies.sequence(InputTypeStrategies.explicit(pojoArgumentType))
    val resultStrategy = TypeStrategies.explicit(DataTypes.INT())

    TypeInference
      .newBuilder()
      .inputTypeStrategy(argumentStrategy)
      .outputTypeStrategy(resultStrategy)
      .build()
  }
}

/**
 * Table function taking an INT `s` and emitting one [[MyPojo]] whose two fields are both `s`.
 * The result is declared as a STRUCTURED type with INT fields `f1` and `f2`.
 */
@SerialVersionUID(1L)
class MyToPojoTableFunc extends TableFunction[MyPojo] {
  def eval(s: Int): Unit = {
    val pojo = new MyPojo(s, s)
    collect(pojo)
  }

  override def getTypeInference(typeFactory: DataTypeFactory): TypeInference = {
    val pojoResultType = DataTypes.STRUCTURED(
      classOf[MyPojo],
      DataTypes.FIELD("f1", DataTypes.INT()),
      DataTypes.FIELD("f2", DataTypes.INT()))

    TypeInference
      .newBuilder()
      .typedArguments(DataTypes.INT())
      .outputTypeStrategy(TypeStrategies.explicit(pojoResultType))
      .build()
  }
}

/**
 * Table function with a configurable result type: takes an INT and emits it either as a string
 * or as the integer itself, depending on the declared output type `t`.
 *
 * @param t declared output data type; only STRING and INT (of any nullability) are supported
 * @tparam T Scala type of the emitted values, expected to match `t`
 */
@SerialVersionUID(1L)
class GenericTableFunc[T](t: DataType) extends TableFunction[T] {

  /**
   * Emits `s` converted according to `t`: `s.toString` for STRING, `s` itself for INT.
   *
   * @throws RuntimeException if `t` is null or any other data type
   */
  def eval(s: Int): Unit = {
    // Neither branch ever emits null, so nullability of `t` is irrelevant: compare against the
    // nullable variant so that e.g. DataTypes.STRING().notNull() is accepted as well.
    val declared: Option[DataType] = Option(t).map(_.nullable())
    if (declared.contains(DataTypes.STRING())) {
      collect(s.toString.asInstanceOf[T])
    } else if (declared.contains(DataTypes.INT())) {
      collect(s.asInstanceOf[T])
    } else {
      throw new RuntimeException("Unsupported output data type for GenericTableFunc: " + t)
    }
  }

  override def getTypeInference(typeFactory: DataTypeFactory): TypeInference = {
    TypeInference.newBuilder
      .typedArguments(DataTypes.INT())
      .outputTypeStrategy(TypeStrategies.explicit(t))
      .build
  }
}

/**
 * Table function that works on Flink's internal `StringData` representation: it takes two
 * strings and emits a single row containing both, in argument order.
 */
@SerialVersionUID(1L)
class BinaryStringTableFunc extends TableFunction[Row] {
  def eval(s: StringData, cons: StringData): Unit = {
    val output = Row.of(s, cons)
    collect(output)
  }

  override def getTypeInference(typeFactory: DataTypeFactory): TypeInference = {
    // STRING bridged to StringData, so values are exchanged without conversion to java.lang.String.
    val internalString = DataTypes.STRING().bridgedTo(classOf[StringData])
    val resultType = DataTypes.ROW(internalString, internalString)

    TypeInference
      .newBuilder()
      .typedArguments(internalString, internalString)
      .outputTypeStrategy(TypeStrategies.explicit(resultType))
      .build()
  }
}
