
package org.botdesigner.blueprint.ui

import androidx.compose.foundation.background
import androidx.compose.foundation.gestures.awaitEachGesture
import androidx.compose.foundation.gestures.detectTapGestures
import androidx.compose.foundation.gestures.detectTransformGestures
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.offset
import androidx.compose.foundation.layout.requiredSize
import androidx.compose.material3.MaterialTheme
import androidx.compose.runtime.Composable
import androidx.compose.runtime.CompositionLocalProvider
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.derivedStateOf
import androidx.compose.runtime.getValue
import androidx.compose.runtime.key
import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberUpdatedState
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.CacheDrawScope
import androidx.compose.ui.draw.DrawResult
import androidx.compose.ui.draw.drawWithCache
import androidx.compose.ui.geometry.Offset
import androidx.compose.ui.geometry.Size
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.Path
import androidx.compose.ui.graphics.Shape
import androidx.compose.ui.graphics.drawscope.Stroke
import androidx.compose.ui.graphics.drawscope.translate
import androidx.compose.ui.graphics.graphicsLayer
import androidx.compose.ui.input.pointer.PointerEventPass
import androidx.compose.ui.input.pointer.PointerEventType
import androidx.compose.ui.input.pointer.pointerInput
import androidx.compose.ui.platform.LocalDensity
import androidx.compose.ui.platform.LocalFocusManager
import androidx.compose.ui.unit.Density
import androidx.compose.ui.unit.Dp
import androidx.compose.ui.unit.dp
import androidx.compose.ui.util.fastForEach
import androidx.compose.ui.zIndex
import org.botdesigner.blueprint.BlueprintManager
import org.botdesigner.blueprint.ConnectorLinesMap
import org.botdesigner.blueprint.PinLinesMap
import org.botdesigner.blueprint.components.BlueprintNodeStateHolder
import org.botdesigner.blueprint.components.Connector
import org.botdesigner.blueprint.components.Pin
import org.botdesigner.blueprint.scrollDeltaY
import org.botdesigner.blueprint.ui.components.LocalDeviceDensity
import kotlin.math.max
import kotlin.math.min
import kotlin.math.roundToInt

// Side length of the whole field in pixels, identical on every device, so that
// node coordinates stay the same regardless of screen density.
// Kept just below 2^15 (32768), the ~15-bit maximum this value may take,
// and also as a workaround for https://github.com/JetBrains/compose-multiplatform/issues/2807
internal const val FIELD_SIZE_PX = 32_700f

/**
 * Canvas that renders every node of [manager] together with the grid, the field
 * border, pin links and connector links drawn behind them. When the manager has
 * dragging or scaling enabled, it also forwards tap, transform and scroll gestures
 * to [manager].
 *
 * [LocalDensity] is overridden inside this composable so that [fieldSize] always
 * measures [FIELD_SIZE_PX] pixels. The real device density is re-exposed through
 * [LocalDeviceDensity].
 *
 * @param manager source of nodes, links, scale and translation; receives gesture callbacks.
 * @param modifier modifier applied to the root box.
 * @param nodeModifier extra modifier appended to each node's modifier.
 * @param nodeShape shape passed to every node's `Draw`.
 * @param backgroundColor color painted behind the field.
 * @param gridColor color of the grid lines.
 * @param borderColor color of the field border.
 * @param cellSize distance between adjacent grid lines.
 * @param fieldSize full side length of the square field.
 * @param gridVisible whether grid lines are drawn.
 * @param borderVisible whether the field border is drawn.
 */
@Composable
fun BlueprintWidget(
    manager: BlueprintManager,
    modifier: Modifier = Modifier,
    nodeModifier: (BlueprintNodeStateHolder<*>) -> Modifier = { Modifier },
    nodeShape: Shape = MaterialTheme.shapes.medium,
    backgroundColor: Color = Color(0xff1c1c1c),
    gridColor: Color = Color(0xff353535),
    borderColor: Color = Color.Red,
    cellSize: Dp = 50.dp,
    fieldSize: Dp = 30000.dp,
    gridVisible: Boolean = true,
    borderVisible: Boolean = true
) {

    val halfFieldSize = fieldSize / 2

    val topDensity = LocalDensity.current

    // Ratio of the field's real pixel size to FIELD_SIZE_PX. Dividing the device
    // density by it (below) makes `fieldSize` measure exactly FIELD_SIZE_PX pixels.
    val densityScale = remember(halfFieldSize, topDensity) {
        topDensity.run {
            halfFieldSize.toPx() / FIELD_SIZE_PX * 2
        }
    }

    CompositionLocalProvider(
        LocalDeviceDensity provides topDensity,
        LocalDensity provides Density(
            density = topDensity.density / densityScale,
            fontScale = topDensity.fontScale,
        ),
    ) {

        val density = LocalDensity.current

        val halfFieldSizePx = remember(density, halfFieldSize) {
            density.run { halfFieldSize.toPx() }
        }

        val focusManager = LocalFocusManager.current

        // Keyed on the manager as well: a replacement manager must also be told the scale.
        LaunchedEffect(manager, densityScale) {
            manager.onDensityScaleChanged(densityScale)
        }

        CompositionLocalProvider(
            LocalBlueprintSize provides halfFieldSizePx
        ) {

            val cachedConnectorsPath = remember {
                Path()
            }

            val pinLinesPathPool = remember {
                mutableMapOf<Color,Path>()
            }

            // pointerInput restarts its block only when a key changes, so every value a
            // block captures is listed as a key; otherwise it would keep stale references.
            val pointerModifier = if (manager.isDragEnabled || manager.isScaleEnabled)
                Modifier
                    .pointerInput(focusManager) {
                        detectTapGestures {
                            focusManager.clearFocus(force = true)
                        }
                    }
                    .pointerInput(manager, halfFieldSizePx, focusManager) {
                        detectTransformGestures { _, pan, zoom, _ ->
                            focusManager.clearFocus()
                            manager.onDrag(pan / manager.scale, halfFieldSizePx * 2)
                            manager.onScale(zoom)
                        }
                    }
                    .pointerInput(manager, densityScale) {
                        awaitEachGesture {
                            val event = awaitPointerEvent(PointerEventPass.Main)

                            if (event.type == PointerEventType.Scroll) {
                                val change = event.scrollDeltaY() / 25 * densityScale

                                manager.onScale(
                                    (manager.scale - change) / manager.scale,
                                )
                            }
                        }
                    }
            else Modifier

            Box(
                modifier = modifier
                    .background(backgroundColor)
                    .then(pointerModifier)
                    .requiredSize(halfFieldSize * 2)
                    .graphicsLayer {
                        scaleX = manager.scale
                        scaleY = manager.scale
                        translationX = manager.translation.x * manager.scale
                        translationY = manager.translation.y * manager.scale
                    }.drawWithCache {
                        // separate drawWithCache because of different caching cycles
                        drawField(
                            gridVisible = gridVisible,
                            borderVisible = borderVisible,
                            cellSize = cellSize,
                            halfFieldSizePx = halfFieldSizePx,
                            gridColor = gridColor,
                            borderColor = borderColor
                        )
                    }.drawWithCache {
                        drawPinLines(
                            pathPool = pinLinesPathPool,
                            manager = manager,
                            halfFieldSizePx = halfFieldSizePx
                        )
                    }.drawWithCache {
                        drawConnectorLines(
                            cachedPath = cachedConnectorsPath,
                            manager = manager,
                            halfFieldSizePx = halfFieldSizePx
                        )
                    }
            ) {

                val windowSize by rememberUpdatedState(windowSize())

                val iconSize = density.run { 50.dp.toPx() }
                val maxWidth = density.run { 450.dp.toPx() }


                // REQUIRED: workaround for a crash on map load
                if (manager.components.isEmpty())
                    return@Box

                manager.components.forEach { (_, v) ->
                    key(v.id.id) {
                        // Only nodes overlapping the visible area are composed.
                        val isInsideViewport by remember(manager, density, v) {
                            derivedStateOf {
                                v.isInsideViewport(
                                    viewportSize = windowSize / manager.scale,
                                    maxWidth = maxWidth,
                                    iconSize = iconSize,
                                    offset = manager.translation
                                )
                            }
                        }

                        if (isInsideViewport) {
                            v.Draw(
                                shape = nodeShape,
                                fieldSize = halfFieldSizePx,
                                modifier = Modifier
                                    // Recently changed nodes are drawn on top.
                                    .zIndex((v.lastChange.value % Int.MAX_VALUE).toFloat())
                                    .offset(halfFieldSize, halfFieldSize)
                                        then nodeModifier(v)
                            )
                        }
                    }
                }
            }
        }
    }
}
//
//suspend fun AwaitPointerEventScope.awaitFirstDownWithShift(
//    requireUnconsumed: Boolean = true,
//    pass: PointerEventPass = PointerEventPass.Main,
//): PointerInputChange {
//    var event: PointerEvent
//    do {
//        event = awaitPointerEvent(pass)
//    } while (!event.isPrimaryChangedDown(requireUnconsumed) || !event.keyboardModifiers.isShiftPressed)
//    return event.changes[0]
//}

/**
 * Cheap visibility test for a node.
 *
 * The node is approximated by a box with these extents:
 * - It starts [iconSize] above and to the left of its position.
 * - It extends [maxWidth] to the right.
 * - It extends `(rows + 2) * iconSize` downwards. `rows` is the larger of the
 *   output/non-output pin counts plus the larger of the output/non-output
 *   connector counts.
 *
 * The viewport is a [viewportSize] rectangle centred on `-offset`.
 *
 * @return `false` only when the approximated box lies entirely outside the viewport.
 */
fun BlueprintNodeStateHolder<*>.isInsideViewport(
    viewportSize : Size,
    maxWidth : Float,
    iconSize : Float,
    offset: Offset
) : Boolean {
    val nodePos = position.value

    val outPins = pins.count(Pin<*>::isOut)
    val pinRows = outPins.coerceAtLeast(pins.size - outPins)

    val outConnectors = connectors.count(Connector::isOut)
    val connectorRows = outConnectors.coerceAtLeast(connectors.size - outConnectors)

    val approximateHeight = (pinRows + connectorRows + 2) * iconSize

    val halfViewportWidth = viewportSize.width / 2
    val halfViewportHeight = viewportSize.height / 2

    val viewportLeft = -offset.x - halfViewportWidth
    val viewportRight = -offset.x + halfViewportWidth
    val viewportTop = -offset.y - halfViewportHeight
    val viewportBottom = -offset.y + halfViewportHeight

    val isOutside = nodePos.x + maxWidth < viewportLeft ||
        nodePos.x - iconSize > viewportRight ||
        nodePos.y + approximateHeight < viewportTop ||
        nodePos.y - iconSize > viewportBottom

    return !isOutside
}

/**
 * Rebuilds the connector-link path inside the draw cache.
 *
 * @return a callback that strokes the path in white with a 3dp-equivalent width,
 *   translated by [halfFieldSizePx] on both axes.
 */
private fun CacheDrawScope.drawConnectorLines(
    cachedPath: Path,
    manager: BlueprintManager,
    halfFieldSizePx: Float
) : DrawResult {
    val stroke = Stroke(width = 3f * density)
    val linkPath = getConnectorLinesPath(cachedPath, manager.connectorLines)

    return onDrawBehind {
        translate(left = halfFieldSizePx, top = halfFieldSizePx) {
            drawPath(path = linkPath, color = Color.White, style = stroke)
        }
    }
}

/**
 * Rebuilds the per-color pin-link paths inside the draw cache, reusing paths from
 * [pathPool].
 *
 * @return a callback that strokes each path in its own color with a 1dp-equivalent
 *   width, translated by [halfFieldSizePx] on both axes.
 */
private fun CacheDrawScope.drawPinLines(
    pathPool : MutableMap<Color, Path>,
    manager: BlueprintManager,
    halfFieldSizePx: Float
) : DrawResult {
    val stroke = Stroke(width = density)
    val coloredPaths = getPinLinesPaths(pathPool, manager.pinLines)

    return onDrawBehind {
        translate(left = halfFieldSizePx, top = halfFieldSizePx) {
            coloredPaths.fastForEach { (lineColor, linePath) ->
                drawPath(path = linePath, color = lineColor, style = stroke)
            }
        }
    }
}

/**
 * Prepares the static field decorations (grid and border) inside the draw cache.
 *
 * Everything that does not change between frames is allocated here: paths, strokes,
 * and the border rectangle geometry. The returned draw callback therefore allocates
 * nothing.
 *
 * @param gridVisible when false, grid paths are neither built nor drawn.
 * @param borderVisible when false, the border rectangle is not drawn.
 * @param cellSize spacing between adjacent grid lines.
 * @param halfFieldSizePx half of the field side length, in pixels.
 * @param gridColor color used for all three grid line tiers.
 * @param borderColor color of the border rectangle.
 * @return a callback drawing the border, then the grid, translated by
 *   [halfFieldSizePx] on both axes.
 */
private fun CacheDrawScope.drawField(
    gridVisible: Boolean,
    borderVisible: Boolean,
    cellSize: Dp,
    halfFieldSizePx : Float,
    gridColor : Color = Color.DarkGray.copy(alpha = .5f),
    borderColor : Color = Color.Red
) : DrawResult {

    val gridPaths = if (gridVisible)
        gridPaths(fieldSizePx = halfFieldSizePx, cellSize = cellSize)
    else null

    // Stroke widths for the first/second/third path returned by gridPaths.
    val gridStrokes = Triple(
        Stroke(15 * density),
        Stroke(5 * density),
        Stroke(1 * density),
    )

    // Hoisted out of onDrawBehind so the border stroke is not reallocated every frame.
    val borderStroke = Stroke(width = 10.dp.toPx())
    val borderTopLeft = Offset(-halfFieldSizePx, -halfFieldSizePx)
    val borderSize = Size(halfFieldSizePx * 2, halfFieldSizePx * 2)

    return onDrawBehind {
        translate(halfFieldSizePx, halfFieldSizePx) {

            if (borderVisible) {
                drawRect(
                    borderColor,
                    borderTopLeft,
                    borderSize,
                    style = borderStroke
                )
            }
            gridPaths?.let {
                drawPath(it.first, gridColor, style = gridStrokes.first)
                drawPath(it.second, gridColor, style = gridStrokes.second)
                drawPath(it.third, gridColor, style = gridStrokes.third)
            }
        }
    }
}

/**
 * Rebuilds one [Path] per pin color from [pinLines]. Paths are taken from [pathPool],
 * which is keyed by color, and rewound, so they are not reallocated on every cache
 * refresh.
 *
 * @param curveBend horizontal reach of the Bezier control points from each endpoint.
 * @return one (color, path) pair for every color present in [pinLines].
 */
private fun Density.getPinLinesPaths(
    pathPool : MutableMap<Color, Path>,
    pinLines: PinLinesMap,
    curveBend: Dp = 100.dp,
) : List<Pair<Color, Path>> {
    // Converted once instead of four times per link.
    val bendPx = curveBend.toPx()

    return pinLines.entries
        .groupBy { it.key.first.color }
        .map { (color, entries) ->
            color to pathPool.getOrPut(color) { Path() }.apply {
                rewind()
                entries.fastForEach { (k, v) ->
                    // Curve start/end order depends on whether the key's first pin is an output.
                    val (a, b) = if (k.first.isOut)
                        v.second.value to v.first.value
                    else
                        v.first.value to v.second.value

                    appendLinkCurve(a.x, a.y, b.x, b.y, bendPx)
                }
            }
        }
}

/**
 * Rewinds [cachedPath] and fills it with one curve per entry of [connectorLines].
 *
 * @param curveBend horizontal reach of the Bezier control points from each endpoint.
 * @return [cachedPath] itself.
 */
private fun Density.getConnectorLinesPath(
    cachedPath : Path,
    connectorLines: ConnectorLinesMap,
    curveBend: Dp = 100.dp,
) : Path {
    val bendPx = curveBend.toPx()

    return cachedPath.apply {
        rewind()
        connectorLines.forEach { (k, v) ->
            // NOTE(review): destructured as (b, a), the reverse of the pin-line order above.
            val (b, a) = if (k.first.isOut)
                v.second.value to v.first.value
            else
                v.first.value to v.second.value

            appendLinkCurve(a.x, a.y, b.x, b.y, bendPx)
        }
    }
}

/**
 * Appends a cubic Bezier from (ax, ay) to (bx, by).
 *
 * The first control point sits [bendPx] to the right of the start; the second sits
 * [bendPx] to the left of the end. When the start lies left of the end, both control
 * points are clamped so they never pass the opposite endpoint horizontally.
 */
private fun Path.appendLinkCurve(ax: Float, ay: Float, bx: Float, by: Float, bendPx: Float) {
    moveTo(ax, ay)
    val startIsLeft = ax < bx
    cubicTo(
        x1 = if (startIsLeft) min(ax + bendPx, bx) else ax + bendPx, y1 = ay,
        x2 = if (startIsLeft) max(bx - bendPx, ax) else bx - bendPx, y2 = by,
        x3 = bx, y3 = by,
    )
}

/**
 * Builds the grid lines covering the square `[-fieldSizePx, fieldSizePx]` in both axes.
 *
 * Lines are placed every [cellSize] (rounded to whole pixels), starting at the negative
 * edge. For each position, one horizontal and one vertical line are added. Lines are
 * numbered from 1 and split into three paths:
 * - every 50th line goes to `first`;
 * - every other 10th line goes to `second`;
 * - all remaining lines go to `third`.
 *
 * @throws IllegalArgumentException if [cellSize] rounds to less than one pixel.
 */
private fun Density.gridPaths(
    fieldSizePx : Float,
    cellSize: Dp,
) : Triple<Path,Path,Path> {
    val stepPx = cellSize.roundToPx()
    require(stepPx > 0) { "cellSize must round to at least 1px, was $cellSize" }

    val everyFiftieth = Path()
    val everyTenth = Path()
    val regular = Path()

    val from = (-fieldSizePx).roundToInt()
    val to = fieldSizePx.roundToInt()

    (from..to step stepPx).forEachIndexed { index, position ->
        val lineNumber = index + 1
        val target = when {
            lineNumber % 50 == 0 -> everyFiftieth
            lineNumber % 10 == 0 -> everyTenth
            else -> regular
        }
        val coordinate = position.toFloat()

        // Horizontal line at y = coordinate.
        target.moveTo(-fieldSizePx, coordinate)
        target.lineTo(fieldSizePx, coordinate)

        // Vertical line at x = coordinate.
        target.moveTo(coordinate, -fieldSizePx)
        target.lineTo(coordinate, fieldSizePx)
    }
    return Triple(everyFiftieth, everyTenth, regular)
}
