/*
 * Decompiled with CFR 0.152.
 */
package org.apache.cassandra.analytics.movement;

import java.math.BigInteger;
import java.util.Collections;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.stream.Stream;
import org.apache.cassandra.analytics.DataGenerationUtils;
import org.apache.cassandra.analytics.ResiliencyTestBase;
import org.apache.cassandra.analytics.TestConsistencyLevel;
import org.apache.cassandra.analytics.TestUninterruptibles;
import org.apache.cassandra.distributed.api.ICluster;
import org.apache.cassandra.distributed.api.IInstance;
import org.apache.cassandra.sidecar.testing.QualifiedName;
import org.apache.cassandra.spark.bulkwriter.WriterOptions;
import org.apache.cassandra.testing.ClusterBuilderConfiguration;
import org.apache.cassandra.testing.utils.ClusterUtils;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.params.provider.Arguments;

/**
 * Base class for bulk-writer tests that exercise a cluster while one of its nodes is being moved to a new token.
 *
 * <p>After the cluster is provisioned, {@link #afterClusterProvisioned()} starts a background
 * {@code nodetool move} on {@link #movingNode}. It then waits for the subclass-provided
 * {@link #transitioningStateStart()} latch, and for the seed (node 1) to report the node as {@code Moving}.
 * Tests write through {@link #runMovingNodeTest(TestConsistencyLevel)} and finish with
 * {@link #completeTransitionAndValidateWrites(CountDownLatch, Stream, boolean)}.
 */
abstract class NodeMovementTestBase
extends ResiliencyTestBase {
    /** Index (as passed to {@code cluster.get}) of the node moved when the cluster has a single data center. */
    public static final int SINGLE_DC_MOVING_NODE_IDX = 5;
    /** Index (as passed to {@code cluster.get}) of the node moved when the cluster has more than one data center. */
    public static final int MULTI_DC_MOVING_NODE_IDX = 3;
    /**
     * Number of rows generated, written and validated by these tests.
     * Named distinctively so it cannot hide a base-class constant that subclasses may reference.
     */
    private static final int NODE_MOVEMENT_ROW_COUNT = 1000;

    /** The node being moved; assigned in {@link #afterClusterProvisioned()}. */
    IInstance movingNode;
    /** Generated course data written by each test; assigned in {@link #beforeTestStart()}. */
    Dataset<Row> df;
    /**
     * Per-instance expected data, as produced by {@code generateExpectedInstanceData} with {@link #movingNode}
     * passed as the only special node. Assigned in {@link #beforeTestStart()}.
     */
    Map<? extends IInstance, Set<String>> expectedInstanceData;

    NodeMovementTestBase() {
    }

    /**
     * Bulk-writes {@link #df} into a table dedicated to the given consistency levels, then validates the
     * rows read back and the data held by each instance.
     *
     * <p>Intended to run while {@link #movingNode} is mid-move, i.e. after {@link #afterClusterProvisioned()}
     * and before {@link #completeTransitionAndValidateWrites(CountDownLatch, Stream, boolean)}.
     *
     * @param cl the consistency levels; {@code cl.writeCL} is passed to the bulk writer and {@code cl.readCL}
     *           is used for validation
     */
    protected void runMovingNodeTest(TestConsistencyLevel cl) {
        QualifiedName table = this.movementTestTable(cl);
        this.bulkWriterDataFrameWriter(this.df, table)
            .option(WriterOptions.BULK_WRITER_CL.name(), cl.writeCL.name())
            .save();
        this.validateMovementTestTable(table, cl);
    }

    /**
     * Generates the test data frame and the per-instance expected data.
     *
     * <p>NOTE(review): this reads {@link #movingNode}, so it assumes {@link #afterClusterProvisioned()} has
     * already run. Confirm the base-class lifecycle order.
     */
    @Override
    protected void beforeTestStart() {
        super.beforeTestStart();
        SparkSession spark = this.getOrCreateSparkSession();
        this.df = DataGenerationUtils.generateCourseData(spark, NODE_MOVEMENT_ROW_COUNT);
        this.expectedInstanceData = this.generateExpectedInstanceData((ICluster<? extends IInstance>)this.cluster, Collections.singletonList(this.movingNode), NODE_MOVEMENT_ROW_COUNT);
    }

    /**
     * Selects the node to move and starts {@code nodetool move} on it in a background thread.
     *
     * <p>Returns once {@link #transitioningStateStart()} has been counted down (waiting at most 2 minutes) and
     * node 1 reports the moving node as {@code Moving}.
     */
    protected void afterClusterProvisioned() {
        ClusterBuilderConfiguration configuration = this.testClusterConfiguration();
        int movingNodeIndex = configuration.dcCount > 1 ? MULTI_DC_MOVING_NODE_IDX : SINGLE_DC_MOVING_NODE_IDX;
        this.movingNode = this.cluster.get(movingNodeIndex);
        IInstance seed = this.cluster.get(1);
        // Compute the target on this thread. A bad token configuration then fails setup immediately,
        // instead of silently killing the background thread and leaving the latch wait below to time out.
        long moveTarget = NodeMovementTestBase.calculateMoveTargetToken((ICluster<? extends IInstance>)this.cluster, configuration.dcCount);
        IInstance nodeToMove = this.movingNode;
        // "--" presumably stops a negative token from being parsed as a nodetool option.
        // The move blocks until it completes, hence the separate (named, for diagnosability) thread.
        Thread mover = new Thread(
            () -> nodeToMove.nodetoolResult(new String[]{"move", "--", Long.toString(moveTarget)}).asserts().success(),
            "node-movement-" + movingNodeIndex);
        mover.start();
        TestUninterruptibles.awaitUninterruptiblyOrThrow(this.transitioningStateStart(), 2L, TimeUnit.MINUTES);
        this.cluster.awaitRingState(seed, this.movingNode, "Moving");
    }

    /**
     * Latch awaited by {@link #afterClusterProvisioned()} after starting the move.
     *
     * <p>Presumably counted down by subclass instrumentation once the move reaches the transitional state the
     * subclass wants to test. Confirm against the implementations.
     *
     * @return the latch signalling that the transitional state has started
     */
    protected abstract CountDownLatch transitioningStateStart();

    /**
     * Releases {@code transitionalStateEnd} and re-validates the data written by each test input.
     *
     * <p>When no failure is expected, first waits for the moving node to return to {@code Normal}. When a
     * failure is expected, additionally asserts that the node either still reports {@code Moving} or is
     * {@code Normal} on its configured {@code initial_token}.
     *
     * @param transitionalStateEnd latch to count down before any validation
     * @param testInputs           arguments whose first element is the {@link TestConsistencyLevel} of a table
     *                             written earlier
     * @param expectFailure        whether the move is expected not to have completed
     */
    protected void completeTransitionAndValidateWrites(CountDownLatch transitionalStateEnd, Stream<Arguments> testInputs, boolean expectFailure) {
        // Count down before any assertion, so a failure below cannot leave whatever awaits this latch blocked.
        transitionalStateEnd.countDown();
        Assertions.assertThat((Object)this.movingNode).isNotNull();
        IInstance seed = this.cluster.get(1);
        if (!expectFailure) {
            this.cluster.awaitRingState(seed, this.movingNode, "Normal");
        }
        testInputs.forEach(arguments -> {
            TestConsistencyLevel cl = (TestConsistencyLevel)arguments.get()[0];
            this.validateMovementTestTable(this.movementTestTable(cl), cl);
        });
        if (expectFailure) {
            this.assertMovingNodeDidNotRelocate(seed);
        }
    }

    /**
     * Computes the token targeted by {@code nodetool move}.
     *
     * <p>The target is the midpoint between the {@code initial_token} of node 1 and that of node 2 (single DC)
     * or node 3 (multi DC), rounded towards node 1's token.
     *
     * <p>The arithmetic uses {@link BigInteger} because the difference of two {@code long} tokens overflows
     * when they lie near opposite ends of the range. The result always fits in a {@code long}, because it
     * lies between the two tokens. {@link BigInteger#divide} truncates toward zero, which matches the
     * rounding of the previous {@code long} arithmetic whenever that arithmetic did not overflow.
     *
     * @param cluster the test cluster
     * @param dcCount the number of data centers
     * @return the move target token
     */
    static long calculateMoveTargetToken(ICluster<? extends IInstance> cluster, int dcCount) {
        IInstance seed = cluster.get(1);
        int nextIndex = dcCount > 1 ? 3 : 2;
        long t2 = Long.parseLong(seed.config().getString("initial_token"));
        long t3 = Long.parseLong(cluster.get(nextIndex).config().getString("initial_token"));
        BigInteger start = BigInteger.valueOf(t2);
        BigInteger halfDistance = BigInteger.valueOf(t3).subtract(start).divide(BigInteger.valueOf(2L));
        return start.add(halfDistance).longValueExact();
    }

    /** Returns the table used by these tests for the given consistency levels. */
    private QualifiedName movementTestTable(TestConsistencyLevel cl) {
        return NodeMovementTestBase.uniqueTestTableFullName("spark_test", cl.readCL, cl.writeCL);
    }

    /** Validates both the rows read back at {@code cl.readCL} and the per-instance data of {@code table}. */
    private void validateMovementTestTable(QualifiedName table, TestConsistencyLevel cl) {
        this.validateData(table, cl.readCL, NODE_MOVEMENT_ROW_COUNT);
        this.validateNodeSpecificData(table, this.expectedInstanceData, false);
    }

    /**
     * Asserts, using the ring as seen from {@code observer}, that {@link #movingNode} either still reports
     * {@code Moving}, or reports {@code Normal} while still owning its configured {@code initial_token}.
     */
    private void assertMovingNodeDidNotRelocate(IInstance observer) {
        String initialToken = this.movingNode.config().getString("initial_token");
        String movingAddress = this.movingNode.broadcastAddress().getAddress().getHostAddress();
        Optional<ClusterUtils.RingInstanceDetails> movingInstance = ClusterUtils.ring(observer).stream()
            .filter(i -> i.getAddress().equals(movingAddress))
            .findFirst();
        Assertions.assertThat(movingInstance).isPresent();
        ClusterUtils.RingInstanceDetails details = movingInstance.get();
        String state = details.getState();
        boolean stillMoving = "Moving".equals(state);
        boolean backOnInitialToken = "Normal".equals(state) && details.getToken().equals(initialToken);
        Assertions.assertThat(stillMoving || backOnInitialToken)
            .as("node %s should be Moving, or Normal on initial token %s, but was %s with token %s",
                movingAddress, initialToken, state, details.getToken())
            .isTrue();
    }
}

