import { max } from 'd3-array';
import { sankey, sankeyLinkHorizontal } from 'd3-sankey';
import { scaleLinear } from 'd3-scale';
import clone from 'lodash.clonedeep';
import uniqWith from 'lodash.uniqwith';
import * as React from 'react';
import styled from 'styled-components';
import { modularScale } from '../../helpers/modular-scale';

/** Horizontal size of every Sankey node; the layout also reuses it as node padding. */
const nodeWidth = 12;

/**
 * Build an accessor that reads a single key from whatever object it is given.
 *
 * @param {string|number} property - Key to read.
 * @returns {(datum: object) => *} Function returning `datum[property]`.
 */
function getProperty(property) {
  return function readProperty(datum) {
    return datum[property];
  };
}

/** Reads a datum's `id` field (used as the d3-sankey node id). */
const getId = getProperty('id');
/** Reads a datum's `name` field (used as the node label text). */
const getName = getProperty('name');

/**
 * Base SVG `<text>` label used throughout the chart: uppercase text in the
 * font given by the `--roboto-condensed` CSS custom property (defined
 * elsewhere), with font size taken from `modularScale(-1)` — presumably one
 * step below the base size of the project's type scale; confirm in
 * helpers/modular-scale.
 */
const Label = styled('text')`
  font-family: var(--roboto-condensed);
  font-size: ${modularScale(-1)};
  text-transform: uppercase;
  letter-spacing: 0.01em;
  line-height: 1;
`;

/**
 * Variant of `Label` used for axis tick text: inherits all of Label's
 * typography but turns off the uppercase transform.
 */
const AxisLabel = styled(Label)`
  text-transform: none;
`;

class FixedHorizontalNodesSankey extends React.Component {
  state = {
    nodes: [],
    links: [],
    x: scaleLinear(),
    margin: {
      top: 24,
      right: 100,
      bottom: 64,
      left: 64,
    },
  };

  static getDerivedStateFromProps(nextProps, { x, margin }) {
    if (nextProps.width === 0) return null;

    margin = {
      top: nextProps.height * 0.2,
      bottom: nextProps.height * 0.2,
      left: 60,
      right: nextProps.width * 0.4,
    };

    // Create clones because d3.sankey doesn't clone arrays before operations
    let stops = clone(nextProps.stops);
    let legs = clone(nextProps.legs);

    let destinationId = legs.find((l) => l.targetType === 'destination').target;
    let destination = stops.find((s) => getId(s) === destinationId);

    let destinationStops = [];

    legs.forEach((leg) => {
      if (leg.targetType === 'destination') {
        destinationStops.push({
          ...destination,
          id: `${getId(destination)}:${leg.travel_time}`,
        });

        leg.target = `${getId(destination)}:${leg.travel_time}`;
      }
    });

    destinationStops = uniqWith(destinationStops, (da, db) => da.id === db.id);

    stops = [
      ...stops.filter((stop) => getId(stop) !== destinationId),
      ...destinationStops,
    ];

    x.domain([0, max(legs, (d) => d.travel_time)])
      .range([0, nextProps.width - nodeWidth - margin.right - margin.left])
      .nice();

    let generator = sankey(x, 'travel_time')
      .nodeId(getId)
      .nodeWidth(nodeWidth)
      .nodePadding(nodeWidth)
      .size([
        nextProps.width - margin.right - margin.left - nodeWidth,
        nextProps.height - margin.top - margin.bottom,
      ]);

    let graph = generator({
      nodes: stops,
      links: legs.map((leg) => ({
        ...leg,
        value: leg.ratio,
      })),
    });

    // CUSTOM LAYOUT ALGORITHM
    graph.nodes.forEach((node) => {
      let type = node.targetLinks.length
        ? node.targetLinks[0].targetType
        : node.sourceLinks[0].sourceType;

      if (type === 'origin') {
        node.x0 = x(0);
        node.x1 = node.x0 + nodeWidth;
      }

      if (type === 'via' || type === 'destination') {
        node.x0 = x(node.targetLinks[0].travel_time);
        node.x1 = node.x0 + nodeWidth;
      }

      if (type === 'via') {
        node.y0 = node.targetLinks[0].y0 - node.targetLinks[0].width / 2;
        node.y1 = node.y0 + node.targetLinks[0].width;
      }

      node.type = type;
    });

    generator.update(graph);

    return {
      ...graph,
      margin,
      x,
    };
  }

  render() {
    const { width, height } = this.props;
    const { margin, nodes, links, x } = this.state;
    const linkGenerator = sankeyLinkHorizontal();

    return (
      <>
        <g className="y-axis" transform={`translate(0,0)`}>
          <Label
            x={margin.left - 30}
            y={height * 0.5}
            textAnchor={'middle'}
            transform={`rotate(90,${margin.left - 30},${height * 0.5})`}
            fontWeight={600}
          >
            % reizigers
          </Label>
          <line
            x1={margin.left - 10}
            x2={margin.left - 10}
            y1={margin.top}
            y2={height - margin.bottom}
            stroke={'black'}
            strokeWidth={1}
          />
        </g>
        <g
          className="x-axis"
          transform={`translate(${margin.left},${height - margin.bottom + 15})`}
        >
          <Label
            x={(width - margin.right - margin.left) / 2}
            y={50}
            textAnchor={'middle'}
            fontWeight={600}
          >
            Reistijd
          </Label>
          {x.ticks().map((tick, i) => (
            <g key={`tick-${tick}`}>
              <line
                x1={x(x.ticks()[0])}
                x2={x(x.ticks()[x.ticks().length - 1])}
                y1={0}
                y2={0}
                stroke={'black'}
                strokeWidth={1}
              />
              <line
                x1={x(tick)}
                x2={x(tick)}
                y1={0}
                y2={5}
                stroke={'black'}
                strokeWidth={1}
              />
              <AxisLabel
                x={x(tick)}
                y={8}
                dy={15}
                fill={'black'}
                textAnchor={x.ticks().length - 1 === i ? 'end' : 'middle'}
              >
                {tick}m
              </AxisLabel>
            </g>
          ))}
        </g>
        <g
          className="travelers"
          transform={`translate(${margin.left},${margin.top})`}
        >
          {links.map((link) =>
            link.target.type === 'destination' ? (
              <Label
                key={`label-${getId(link.source)}-${getId(link.target)}`}
                y={link.y1}
                fill={'#000'}
              >
                <tspan
                  x={x(link.travel_time)}
                  dx={nodeWidth + 20}
                  fontSize={'1.5em'}
                  fontWeight={600}
                  fill={'var(--red)'}
                >
                  {Math.round(link.ratio * 100)}%
                </tspan>
                <tspan x={x(link.travel_time)} dx={nodeWidth + 20} dy={18}>
                  reizigers
                </tspan>
              </Label>
            ) : null
          )}
        </g>
        <g
          className="sankey"
          transform={`translate(${margin.left},${margin.top})`}
        >
          <g className="links">
            {links.map((link) => (
              <path
                key={`${getId(link.source)}-${getId(link.target)}`}
                d={linkGenerator(link)}
                fill={'none'}
                stroke={'var(--red)'}
                strokeWidth={link.width}
                strokeOpacity={0.75}
              />
            ))}
          </g>
          <g className="labels">
            {nodes.map((node) => (
              <Label
                key={`label-${getId(node)}`}
                x={
                  node.x1 + 6 >
                  (width - margin.left - margin.right - nodeWidth) * 0.45
                    ? node.x0 - 6
                    : node.x1 + 6
                }
                y={Math.round((node.y1 + node.y0) / 2)}
                dy={5}
                fill={'white'}
                textAnchor={
                  node.x1 + 6 >
                  (width - margin.left - margin.right - nodeWidth) * 0.45
                    ? 'end'
                    : 'start'
                }
              >
                {getName(node)}
              </Label>
            ))}
          </g>
          <g className="nodes">
            {nodes.map((node) => (
              <rect
                key={getId(node)}
                x={node.x0}
                y={node.y0}
                height={node.y1 - node.y0}
                width={node.x1 - node.x0}
                fill={'white'}
              />
            ))}
          </g>
        </g>
      </>
    );
  }
}

export default FixedHorizontalNodesSankey;
