import plotly.graph_objs as go
import plotly.io as pio
from collections import namedtuple, OrderedDict
from ._special_inputs import IdentityMap, Constant, Range
from .trendline_functions import ols, lowess, rolling, expanding, ewm

from _plotly_utils.basevalidators import ColorscaleValidator
from plotly.colors import qualitative, sequential
import math
from packaging import version
import pandas as pd
import numpy as np

from plotly._subplots import (
    make_subplots,
    _set_trace_grid_reference,
    _subplot_type_for_trace_type,
)

pandas_2_2_0 = version.parse(pd.__version__) >= version.parse("2.2.0")

NO_COLOR = "px_no_color_constant"
trendline_functions = dict(
    lowess=lowess, rolling=rolling, ewm=ewm, expanding=expanding, ols=ols
)

# Declare all supported attributes, across all plot types
direct_attrables = (
    ["base", "x", "y", "z", "a", "b", "c", "r", "theta", "size", "x_start", "x_end"]
    + ["hover_name", "text", "names", "values", "parents", "wide_cross"]
    + ["ids", "error_x", "error_x_minus", "error_y", "error_y_minus", "error_z"]
    + ["error_z_minus", "lat", "lon", "locations", "animation_group"]
)
array_attrables = ["dimensions", "custom_data", "hover_data", "path", "wide_variable"]
group_attrables = ["animation_frame", "facet_row", "facet_col", "line_group"]
renameable_group_attrables = [
    "color",  # renamed to marker.color or line.color in infer_config
    "symbol",  # renamed to marker.symbol in infer_config
    "line_dash",  # renamed to line.dash in infer_config
    "pattern_shape",  # renamed to marker.pattern.shape in infer_config
]
all_attrables = (
    direct_attrables + array_attrables + group_attrables + renameable_group_attrables
)

cartesians = [go.Scatter, go.Scattergl, go.Bar, go.Funnel, go.Box, go.Violin]
cartesians += [go.Histogram, go.Histogram2d, go.Histogram2dContour]


class PxDefaults:
    __slots__ = [
        "template",
        "width",
        "height",
        "color_discrete_sequence",
        "color_discrete_map",
        "color_continuous_scale",
        "symbol_sequence",
        "symbol_map",
        "line_dash_sequence",
        "line_dash_map",
        "pattern_shape_sequence",
        "pattern_shape_map",
        "size_max",
        "category_orders",
        "labels",
    ]

    def __init__(self):
        self.reset()

    def reset(self):
        self.template = None
        self.width = None
        self.height = None
        self.color_discrete_sequence = None
        self.color_discrete_map = {}
        self.color_continuous_scale = None
        self.symbol_sequence = None
        self.symbol_map = {}
        self.line_dash_sequence = None
        self.line_dash_map = {}
        self.pattern_shape_sequence = None
        self.pattern_shape_map = {}
        self.size_max = 20
        self.category_orders = {}
        self.labels = {}


defaults = PxDefaults()
del PxDefaults  # `defaults` above is the single shared instance; drop the class itself


MAPBOX_TOKEN = None


def set_mapbox_access_token(token):
    """
    Arguments:
        token: A Mapbox token to be used in `plotly.express.scatter_mapbox` and \
        `plotly.express.line_mapbox` figures. See \
        https://docs.mapbox.com/help/how-mapbox-works/access-tokens/ for more details
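
    Example (the token string here is a placeholder):
        >>> import plotly.express as px
        >>> px.set_mapbox_access_token("pk.my-mapbox-token")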
    """
    global MAPBOX_TOKEN
    MAPBOX_TOKEN = token


def get_trendline_results(fig):
    """
    Extracts fit statistics for trendlines (when applied to figures generated with
    the `trendline` argument set to `"ols"`).

    Arguments:
        fig: the output of a `plotly.express` charting call
    Returns:
        A `pandas.DataFrame` with a column "px_fit_results" containing the `statsmodels`
        results objects, along with columns identifying the subset of the data the
        trendline was fit on.
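
    Example (requires `statsmodels`):
        >>> import plotly.express as px
        >>> fig = px.scatter(px.data.tips(), x="total_bill", y="tip", trendline="ols")
        >>> results = px.get_trendline_results(fig)
        >>> fit = results.px_fit_results.iloc[0]  # a statsmodels results object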
    """
    return fig._px_trendlines


Mapping = namedtuple(
    "Mapping",
    [
        "show_in_trace_name",
        "grouper",
        "val_map",
        "sequence",
        "updater",
        "variable",
        "facet",
    ],
)
TraceSpec = namedtuple("TraceSpec", ["constructor", "attrs", "trace_patch", "marginal"])


def get_label(args, column):
    try:
        return args["labels"][column]
    except Exception:
        return column


def invert_label(args, column):
    """Invert mapping.
    Find key corresponding to value column in dict args["labels"].
    Returns `column` if the value does not exist.
    """
    reversed_labels = {value: key for (key, value) in args["labels"].items()}
    try:
        return reversed_labels[column]
    except Exception:
        return column


def _is_continuous(df, col_name):
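    """True if the column's dtype kind is "i", "f" or "c" (signed integer,
    float or complex)."""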
    return df[col_name].dtype.kind in "ifc"


def get_decorated_label(args, column, role):
    original_label = label = get_label(args, column)
    if "histfunc" in args and (
        (role == "z")
        or (role == "x" and "orientation" in args and args["orientation"] == "h")
        or (role == "y" and "orientation" in args and args["orientation"] == "v")
    ):
        histfunc = args["histfunc"] or "count"
        if histfunc != "count":
            label = "%s of %s" % (histfunc, label)
        else:
            label = "count"

        if "histnorm" in args and args["histnorm"] is not None:
            if label == "count":
                label = args["histnorm"]
            else:
                histnorm = args["histnorm"]
                if histfunc == "sum":
                    if histnorm == "probability":
                        label = "%s of %s" % ("fraction", label)
                    elif histnorm == "percent":
                        label = "%s of %s" % (histnorm, label)
                    else:
                        label = "%s weighted by %s" % (histnorm, original_label)
                elif histnorm == "probability":
                    label = "%s of sum of %s" % ("fraction", label)
                elif histnorm == "percent":
                    label = "%s of sum of %s" % ("percent", label)
                else:
                    label = "%s of %s" % (histnorm, label)

        if "barnorm" in args and args["barnorm"] is not None:
            label = "%s (normalized as %s)" % (label, args["barnorm"])

    return label


def make_mapping(args, variable):
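    """Build a Mapping for `variable`, which is either a grouping argument name
    (e.g. "animation_frame", "facet_row") or a dotted trace attribute path
    (e.g. "marker.color", "line.dash")."""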
    if variable == "line_group" or variable == "animation_frame":
        return Mapping(
            show_in_trace_name=False,
            grouper=args[variable],
            val_map={},
            sequence=[""],
            variable=variable,
            updater=(lambda trace, v: v),
            facet=None,
        )
    if variable == "facet_row" or variable == "facet_col":
        letter = "x" if variable == "facet_col" else "y"
        return Mapping(
            show_in_trace_name=False,
            variable=letter,
            grouper=args[variable],
            val_map={},
            sequence=list(range(1, 1000)),
            updater=(lambda trace, v: v),
            facet="row" if variable == "facet_row" else "col",
        )
    (parent, variable, *other_variables) = variable.split(".")
    vprefix = variable
    arg_name = variable
    if variable == "color":
        vprefix = "color_discrete"
    if variable == "dash":
        arg_name = "line_dash"
        vprefix = "line_dash"
    if variable in ["pattern", "shape"]:
        arg_name = "pattern_shape"
        vprefix = "pattern_shape"
    if args[vprefix + "_map"] == "identity":
        val_map = IdentityMap()
    else:
        val_map = args[vprefix + "_map"].copy()
    return Mapping(
        show_in_trace_name=True,
        variable=variable,
        grouper=args[arg_name],
        val_map=val_map,
        sequence=args[vprefix + "_sequence"],
        updater=lambda trace, v: trace.update(
            {parent: {".".join([variable] + other_variables): v}}
        ),
        facet=None,
    )


def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
    """Populates a dict with arguments to update trace

    Parameters
    ----------
    args : dict
        args to be used for the trace
    trace_spec : NamedTuple
        which kind of trace is to be used (has constructor, marginal etc.
        attributes)
    trace_data : pandas DataFrame
        data for this trace
    mapping_labels : dict
        labels to be used for the hovertemplate
    sizeref : float
        marker sizeref

    Returns
    -------
    trace_patch : dict
        dict to be used to update trace
    fit_results : dict
        fit information to be used for trendlines
    """
    if "line_close" in args and args["line_close"]:
        trace_data = pd.concat([trace_data, trace_data.iloc[:1]])
    trace_patch = trace_spec.trace_patch.copy() or {}
    fit_results = None
    hover_header = ""
    for attr_name in trace_spec.attrs:
        attr_value = args[attr_name]
        attr_label = get_decorated_label(args, attr_value, attr_name)
        if attr_name == "dimensions":
            dims = [
                (name, column)
                for (name, column) in trace_data.items()
                if ((not attr_value) or (name in attr_value))
                and (
                    trace_spec.constructor != go.Parcoords
                    or _is_continuous(args["data_frame"], name)
                )
                and (
                    trace_spec.constructor != go.Parcats
                    or (attr_value is not None and name in attr_value)
                    or len(args["data_frame"][name].unique())
                    <= args["dimensions_max_cardinality"]
                )
            ]
            trace_patch["dimensions"] = [
                dict(label=get_label(args, name), values=column)
                for (name, column) in dims
            ]
            if trace_spec.constructor == go.Splom:
                for d in trace_patch["dimensions"]:
                    d["axis"] = dict(matches=True)
                mapping_labels["%{xaxis.title.text}"] = "%{x}"
                mapping_labels["%{yaxis.title.text}"] = "%{y}"

        elif attr_value is not None:
            if attr_name == "size":
                if "marker" not in trace_patch:
                    trace_patch["marker"] = dict()
                trace_patch["marker"]["size"] = trace_data[attr_value]
                trace_patch["marker"]["sizemode"] = "area"
                trace_patch["marker"]["sizeref"] = sizeref
                mapping_labels[attr_label] = "%{marker.size}"
            elif attr_name == "marginal_x":
                if trace_spec.constructor == go.Histogram:
                    mapping_labels["count"] = "%{y}"
            elif attr_name == "marginal_y":
                if trace_spec.constructor == go.Histogram:
                    mapping_labels["count"] = "%{x}"
            elif attr_name == "trendline":
                if (
                    args["x"]
                    and args["y"]
                    and len(trace_data[[args["x"], args["y"]]].dropna()) > 1
                ):
                    # sorting is bad but trace_specs with "trendline" have no other attrs
                    sorted_trace_data = trace_data.sort_values(by=args["x"])
                    y = sorted_trace_data[args["y"]].values
                    x = sorted_trace_data[args["x"]].values

                    if x.dtype.type == np.datetime64:
                        # convert to unix epoch seconds
                        x = x.astype(np.int64) / 10**9
                    elif x.dtype.type == np.object_:
                        try:
                            x = x.astype(np.float64)
                        except ValueError:
                            raise ValueError(
                                "Could not convert value of 'x' ('%s') into a numeric type. "
                                "If 'x' contains stringified dates, please convert to a datetime column."
                                % args["x"]
                            )
                    if y.dtype.type == np.object_:
                        try:
                            y = y.astype(np.float64)
                        except ValueError:
                            raise ValueError(
                                "Could not convert value of 'y' into a numeric type."
                            )

                    # preserve original values of "x" in case they're dates
                    # otherwise numpy/pandas can mess with the timezones
                    # NB this means trendline functions must output one-to-one with the input series
                    # i.e. we can't do resampling, because then the X values might not line up!
                    non_missing = np.logical_not(
                        np.logical_or(np.isnan(y), np.isnan(x))
                    )
                    trace_patch["x"] = sorted_trace_data[args["x"]][non_missing]
                    trendline_function = trendline_functions[attr_value]
                    y_out, hover_header, fit_results = trendline_function(
                        args["trendline_options"],
                        sorted_trace_data[args["x"]],
                        x,
                        y,
                        args["x"],
                        args["y"],
                        non_missing,
                    )
                    assert len(y_out) == len(
                        trace_patch["x"]
                    ), "missing-data-handling failure in trendline code"
                    trace_patch["y"] = y_out
                    mapping_labels[get_label(args, args["x"])] = "%{x}"
                    mapping_labels[get_label(args, args["y"])] = "%{y} <b>(trend)</b>"
            elif attr_name.startswith("error"):
                error_xy = attr_name[:7]  # "error_x", "error_y" or "error_z"
                arr = "arrayminus" if attr_name.endswith("minus") else "array"
                if error_xy not in trace_patch:
                    trace_patch[error_xy] = {}
                trace_patch[error_xy][arr] = trace_data[attr_value]
            elif attr_name == "custom_data":
                if len(attr_value) > 0:
                    # here we store a data frame in customdata, and it's serialized
                    # as a list of row lists, which is what we want
                    trace_patch["customdata"] = trace_data[attr_value]
            elif attr_name == "hover_name":
                if trace_spec.constructor not in [
                    go.Histogram,
                    go.Histogram2d,
                    go.Histogram2dContour,
                ]:
                    trace_patch["hovertext"] = trace_data[attr_value]
                    if hover_header == "":
                        hover_header = "<b>%{hovertext}</b><br><br>"
            elif attr_name == "hover_data":
                if trace_spec.constructor not in [
                    go.Histogram,
                    go.Histogram2d,
                    go.Histogram2dContour,
                ]:
                    hover_is_dict = isinstance(attr_value, dict)
                    customdata_cols = args.get("custom_data") or []
                    for col in attr_value:
                        if hover_is_dict and not attr_value[col]:
                            continue
                        if col in [
                            args.get("x"),
                            args.get("y"),
                            args.get("z"),
                            args.get("base"),
                        ]:
                            continue
                        try:
                            position = args["custom_data"].index(col)
                        except (ValueError, AttributeError, KeyError):
                            position = len(customdata_cols)
                            customdata_cols.append(col)
                        attr_label_col = get_decorated_label(args, col, None)
                        mapping_labels[attr_label_col] = "%%{customdata[%d]}" % (
                            position
                        )

                    if len(customdata_cols) > 0:
                        # here we store a data frame in customdata, and it's serialized
                        # as a list of row lists, which is what we want
                        trace_patch["customdata"] = trace_data[customdata_cols]
            elif attr_name == "color":
                if trace_spec.constructor in [
                    go.Choropleth,
                    go.Choroplethmap,
                    go.Choroplethmapbox,
                ]:
                    trace_patch["z"] = trace_data[attr_value]
                    trace_patch["coloraxis"] = "coloraxis1"
                    mapping_labels[attr_label] = "%{z}"
                elif trace_spec.constructor in [
                    go.Sunburst,
                    go.Treemap,
                    go.Icicle,
                    go.Pie,
                    go.Funnelarea,
                ]:
                    if "marker" not in trace_patch:
                        trace_patch["marker"] = dict()

                    if args.get("color_is_continuous"):
                        trace_patch["marker"]["colors"] = trace_data[attr_value]
                        trace_patch["marker"]["coloraxis"] = "coloraxis1"
                        mapping_labels[attr_label] = "%{color}"
                    else:
                        trace_patch["marker"]["colors"] = []
                        if args["color_discrete_map"] is not None:
                            mapping = args["color_discrete_map"].copy()
                        else:
                            mapping = {}
                        for cat in trace_data[attr_value]:
                            if mapping.get(cat) is None:
                                mapping[cat] = args["color_discrete_sequence"][
                                    len(mapping) % len(args["color_discrete_sequence"])
                                ]
                            trace_patch["marker"]["colors"].append(mapping[cat])
                else:
                    colorable = "marker"
                    if trace_spec.constructor in [go.Parcats, go.Parcoords]:
                        colorable = "line"
                    if colorable not in trace_patch:
                        trace_patch[colorable] = dict()
                    trace_patch[colorable]["color"] = trace_data[attr_value]
                    trace_patch[colorable]["coloraxis"] = "coloraxis1"
                    mapping_labels[attr_label] = "%%{%s.color}" % colorable
            elif attr_name == "animation_group":
                trace_patch["ids"] = trace_data[attr_value]
            elif attr_name == "locations":
                trace_patch[attr_name] = trace_data[attr_value]
                mapping_labels[attr_label] = "%{location}"
            elif attr_name == "values":
                trace_patch[attr_name] = trace_data[attr_value]
                _label = "value" if attr_label == "values" else attr_label
                mapping_labels[_label] = "%{value}"
            elif attr_name == "parents":
                trace_patch[attr_name] = trace_data[attr_value]
                _label = "parent" if attr_label == "parents" else attr_label
                mapping_labels[_label] = "%{parent}"
            elif attr_name == "ids":
                trace_patch[attr_name] = trace_data[attr_value]
                _label = "id" if attr_label == "ids" else attr_label
                mapping_labels[_label] = "%{id}"
            elif attr_name == "names":
                if trace_spec.constructor in [
                    go.Sunburst,
                    go.Treemap,
                    go.Icicle,
                    go.Pie,
                    go.Funnelarea,
                ]:
                    trace_patch["labels"] = trace_data[attr_value]
                    _label = "label" if attr_label == "names" else attr_label
                    mapping_labels[_label] = "%{label}"
                else:
                    trace_patch[attr_name] = trace_data[attr_value]
            else:
                trace_patch[attr_name] = trace_data[attr_value]
                mapping_labels[attr_label] = "%%{%s}" % attr_name
        elif (trace_spec.constructor == go.Histogram and attr_name in ["x", "y"]) or (
            trace_spec.constructor in [go.Histogram2d, go.Histogram2dContour]
            and attr_name == "z"
        ):
            # ensure that stuff like "count" gets into the hoverlabel
            mapping_labels[attr_label] = "%%{%s}" % attr_name
    if trace_spec.constructor not in [go.Parcoords, go.Parcats]:
        # Modify mapping_labels according to hover_data keys
        # if hover_data is a dict
        mapping_labels_copy = OrderedDict(mapping_labels)
        if args["hover_data"] and isinstance(args["hover_data"], dict):
            for k, v in mapping_labels.items():
                # We need to invert the mapping here
                k_args = invert_label(args, k)
                if k_args in args["hover_data"]:
                    formatter = args["hover_data"][k_args][0]
                    if formatter:
                        if isinstance(formatter, str):
                            mapping_labels_copy[k] = v.replace("}", "%s}" % formatter)
                    else:
                        _ = mapping_labels_copy.pop(k)
        hover_lines = [k + "=" + v for k, v in mapping_labels_copy.items()]
        trace_patch["hovertemplate"] = hover_header + "<br>".join(hover_lines)
        trace_patch["hovertemplate"] += "<extra></extra>"
    return trace_patch, fit_results


def configure_axes(args, constructor, fig, orders):
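    """Dispatch to the axis-configuration helper registered for this trace
    constructor, if any."""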
    configurators = {
        go.Scatter3d: configure_3d_axes,
        go.Scatterternary: configure_ternary_axes,
        go.Scatterpolar: configure_polar_axes,
        go.Scatterpolargl: configure_polar_axes,
        go.Barpolar: configure_polar_axes,
        go.Scattermap: configure_map,
        go.Choroplethmap: configure_map,
        go.Densitymap: configure_map,
        go.Scattermapbox: configure_mapbox,
        go.Choroplethmapbox: configure_mapbox,
        go.Densitymapbox: configure_mapbox,
        go.Scattergeo: configure_geo,
        go.Choropleth: configure_geo,
    }
    for c in cartesians:
        configurators[c] = configure_cartesian_axes
    if constructor in configurators:
        configurators[constructor](args, fig, orders)


def set_cartesian_axis_opts(args, axis, letter, orders):
    log_key = "log_" + letter
    range_key = "range_" + letter
    if log_key in args and args[log_key]:
        axis["type"] = "log"
        if range_key in args and args[range_key]:
            axis["range"] = [math.log(r, 10) for r in args[range_key]]
    elif range_key in args and args[range_key]:
        axis["range"] = args[range_key]

    if args[letter] in orders:
        axis["categoryorder"] = "array"
        axis["categoryarray"] = (
            orders[args[letter]]
            if isinstance(axis, go.layout.XAxis)
            else list(reversed(orders[args[letter]]))  # top down for Y axis
        )


def configure_cartesian_marginal_axes(args, fig, orders):
    if "histogram" in [args["marginal_x"], args["marginal_y"]]:
        fig.layout["barmode"] = "overlay"

    nrows = len(fig._grid_ref)
    ncols = len(fig._grid_ref[0])

    # Set y-axis titles and axis options in the left-most column
    for yaxis in fig.select_yaxes(col=1):
        set_cartesian_axis_opts(args, yaxis, "y", orders)

    # Set x-axis titles and axis options in the bottom-most row
    for xaxis in fig.select_xaxes(row=1):
        set_cartesian_axis_opts(args, xaxis, "x", orders)

    # Configure axis ticks on marginal subplots
    if args["marginal_x"]:
        fig.update_yaxes(
            showticklabels=False, showline=False, ticks="", range=None, row=nrows
        )
        if args["template"].layout.yaxis.showgrid is None:
            fig.update_yaxes(showgrid=args["marginal_x"] == "histogram", row=nrows)
        if args["template"].layout.xaxis.showgrid is None:
            fig.update_xaxes(showgrid=True, row=nrows)

    if args["marginal_y"]:
        fig.update_xaxes(
            showticklabels=False, showline=False, ticks="", range=None, col=ncols
        )
        if args["template"].layout.xaxis.showgrid is None:
            fig.update_xaxes(showgrid=args["marginal_y"] == "histogram", col=ncols)
        if args["template"].layout.yaxis.showgrid is None:
            fig.update_yaxes(showgrid=True, col=ncols)

    # Add axis titles to non-marginal subplots
    y_title = get_decorated_label(args, args["y"], "y")
    if args["marginal_x"]:
        fig.update_yaxes(title_text=y_title, row=1, col=1)
    else:
        for row in range(1, nrows + 1):
            fig.update_yaxes(title_text=y_title, row=row, col=1)

    x_title = get_decorated_label(args, args["x"], "x")
    if args["marginal_y"]:
        fig.update_xaxes(title_text=x_title, row=1, col=1)
    else:
        for col in range(1, ncols + 1):
            fig.update_xaxes(title_text=x_title, row=1, col=col)

    # Configure axis type across all x-axes
    if "log_x" in args and args["log_x"]:
        fig.update_xaxes(type="log")

    # Configure axis type across all y-axes
    if "log_y" in args and args["log_y"]:
        fig.update_yaxes(type="log")

    # Configure matching and axis type for marginal y-axes
    matches_y = "y" + str(ncols + 1)
    if args["marginal_x"]:
        for row in range(2, nrows + 1, 2):
            fig.update_yaxes(matches=matches_y, type=None, row=row)

    if args["marginal_y"]:
        for col in range(2, ncols + 1, 2):
            fig.update_xaxes(matches="x2", type=None, col=col)


def configure_cartesian_axes(args, fig, orders):
    if ("marginal_x" in args and args["marginal_x"]) or (
        "marginal_y" in args and args["marginal_y"]
    ):
        configure_cartesian_marginal_axes(args, fig, orders)
        return

    # Set y-axis titles and axis options in the left-most column
    y_title = get_decorated_label(args, args["y"], "y")
    for yaxis in fig.select_yaxes(col=1):
        yaxis.update(title_text=y_title)
        set_cartesian_axis_opts(args, yaxis, "y", orders)

    # Set x-axis titles and axis options in the bottom-most row
    x_title = get_decorated_label(args, args["x"], "x")
    for xaxis in fig.select_xaxes(row=1):
        if "is_timeline" not in args:
            xaxis.update(title_text=x_title)
        set_cartesian_axis_opts(args, xaxis, "x", orders)

    # Configure axis type across all x-axes
    if "log_x" in args and args["log_x"]:
        fig.update_xaxes(type="log")

    # Configure axis type across all y-axes
    if "log_y" in args and args["log_y"]:
        fig.update_yaxes(type="log")

    if "is_timeline" in args:
        fig.update_xaxes(type="date")

    if "ecdfmode" in args:
        if args["orientation"] == "v":
            fig.update_yaxes(rangemode="tozero")
        else:
            fig.update_xaxes(rangemode="tozero")


def configure_ternary_axes(args, fig, orders):
    fig.update_ternaries(
        aaxis=dict(title_text=get_label(args, args["a"])),
        baxis=dict(title_text=get_label(args, args["b"])),
        caxis=dict(title_text=get_label(args, args["c"])),
    )


def configure_polar_axes(args, fig, orders):
    patch = dict(
        angularaxis=dict(direction=args["direction"], rotation=args["start_angle"]),
        radialaxis=dict(),
    )

    for var, axis in [("r", "radialaxis"), ("theta", "angularaxis")]:
        if args[var] in orders:
            patch[axis]["categoryorder"] = "array"
            patch[axis]["categoryarray"] = orders[args[var]]

    radialaxis = patch["radialaxis"]
    if args["log_r"]:
        radialaxis["type"] = "log"
        if args["range_r"]:
            radialaxis["range"] = [math.log(x, 10) for x in args["range_r"]]
    else:
        if args["range_r"]:
            radialaxis["range"] = args["range_r"]

    if args["range_theta"]:
        patch["sector"] = args["range_theta"]
    fig.update_polars(patch)


def configure_3d_axes(args, fig, orders):
    patch = dict(
        xaxis=dict(title_text=get_label(args, args["x"])),
        yaxis=dict(title_text=get_label(args, args["y"])),
        zaxis=dict(title_text=get_label(args, args["z"])),
    )

    for letter in ["x", "y", "z"]:
        axis = patch[letter + "axis"]
        if args["log_" + letter]:
            axis["type"] = "log"
            if args["range_" + letter]:
                axis["range"] = [math.log(x, 10) for x in args["range_" + letter]]
        else:
            if args["range_" + letter]:
                axis["range"] = args["range_" + letter]
        if args[letter] in orders:
            axis["categoryorder"] = "array"
            axis["categoryarray"] = orders[args[letter]]
    fig.update_scenes(patch)


def configure_mapbox(args, fig, orders):
    center = args["center"]
    if not center and "lat" in args and "lon" in args:
        center = dict(
            lat=args["data_frame"][args["lat"]].mean(),
            lon=args["data_frame"][args["lon"]].mean(),
        )
    fig.update_mapboxes(
        accesstoken=MAPBOX_TOKEN,
        center=center,
        zoom=args["zoom"],
        style=args["mapbox_style"],
    )


def configure_map(args, fig, orders):
    center = args["center"]
    if not center and "lat" in args and "lon" in args:
        center = dict(
            lat=args["data_frame"][args["lat"]].mean(),
            lon=args["data_frame"][args["lon"]].mean(),
        )
    fig.update_maps(
        center=center,
        zoom=args["zoom"],
        style=args["map_style"],
    )


def configure_geo(args, fig, orders):
    fig.update_geos(
        center=args["center"],
        scope=args["scope"],
        fitbounds=args["fitbounds"],
        visible=args["basemap_visible"],
        projection=dict(type=args["projection"]),
    )


def configure_animation_controls(args, constructor, fig):
    def frame_args(duration):
        return {
            "frame": {"duration": duration, "redraw": constructor != go.Scatter},
            "mode": "immediate",
            "fromcurrent": True,
            "transition": {"duration": duration, "easing": "linear"},
        }

    if "animation_frame" in args and args["animation_frame"] and len(fig.frames) > 1:
        fig.layout.updatemenus = [
            {
                "buttons": [
                    {
                        "args": [None, frame_args(500)],
                        "label": "&#9654;",
                        "method": "animate",
                    },
                    {
                        "args": [[None], frame_args(0)],
                        "label": "&#9724;",
                        "method": "animate",
                    },
                ],
                "direction": "left",
                "pad": {"r": 10, "t": 70},
                "showactive": False,
                "type": "buttons",
                "x": 0.1,
                "xanchor": "right",
                "y": 0,
                "yanchor": "top",
            }
        ]
        fig.layout.sliders = [
            {
                "active": 0,
                "yanchor": "top",
                "xanchor": "left",
                "currentvalue": {
                    "prefix": get_label(args, args["animation_frame"]) + "="
                },
                "pad": {"b": 10, "t": 60},
                "len": 0.9,
                "x": 0.1,
                "y": 0,
                "steps": [
                    {
                        "args": [[f.name], frame_args(0)],
                        "label": f.name,
                        "method": "animate",
                    }
                    for f in fig.frames
                ],
            }
        ]


def make_trace_spec(args, constructor, attrs, trace_patch):
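    """Build the list of TraceSpecs for a figure: the base trace (possibly
    upgraded to a WebGL variant), plus any marginal and per-trace trendline
    specs."""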
    if constructor in [go.Scatter, go.Scatterpolar]:
        if "render_mode" in args and (
            args["render_mode"] == "webgl"
            or (
                args["render_mode"] == "auto"
                and len(args["data_frame"]) > 1000
                and args.get("line_shape") != "spline"
                and args["animation_frame"] is None
            )
        ):
            if constructor == go.Scatter:
                constructor = go.Scattergl
                if "orientation" in trace_patch:
                    del trace_patch["orientation"]
            else:
                constructor = go.Scatterpolargl
    # Create base trace specification
    result = [TraceSpec(constructor, attrs, trace_patch, None)]

    # Add marginal trace specifications
    for letter in ["x", "y"]:
        if "marginal_" + letter in args and args["marginal_" + letter]:
            trace_spec = None
            axis_map = dict(
                xaxis="x1" if letter == "x" else "x2",
                yaxis="y1" if letter == "y" else "y2",
            )
            if args["marginal_" + letter] == "histogram":
                trace_spec = TraceSpec(
                    constructor=go.Histogram,
                    attrs=[letter, "marginal_" + letter],
                    trace_patch=dict(opacity=0.5, bingroup=letter, **axis_map),
                    marginal=letter,
                )
            elif args["marginal_" + letter] == "violin":
                trace_spec = TraceSpec(
                    constructor=go.Violin,
                    attrs=[letter, "hover_name", "hover_data"],
                    trace_patch=dict(scalegroup=letter),
                    marginal=letter,
                )
            elif args["marginal_" + letter] == "box":
                trace_spec = TraceSpec(
                    constructor=go.Box,
                    attrs=[letter, "hover_name", "hover_data"],
                    trace_patch=dict(notched=True),
                    marginal=letter,
                )
            elif args["marginal_" + letter] == "rug":
                symbols = {"x": "line-ns-open", "y": "line-ew-open"}
                trace_spec = TraceSpec(
                    constructor=go.Box,
                    attrs=[letter, "hover_name", "hover_data"],
                    trace_patch=dict(
                        fillcolor="rgba(255,255,255,0)",
                        line={"color": "rgba(255,255,255,0)"},
                        boxpoints="all",
                        jitter=0,
                        hoveron="points",
                        marker={"symbol": symbols[letter]},
                    ),
                    marginal=letter,
                )
            if "color" in attrs or "color" not in args:
                if "marker" not in trace_spec.trace_patch:
                    trace_spec.trace_patch["marker"] = dict()
                first_default_color = args["color_continuous_scale"][0]
                trace_spec.trace_patch["marker"]["color"] = first_default_color
            result.append(trace_spec)

    # Add trendline trace specifications
    if args.get("trendline") and args.get("trendline_scope", "trace") == "trace":
        result.append(make_trendline_spec(args, constructor))
    return result


def make_trendline_spec(args, constructor):
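    """Build the TraceSpec for a trendline overlay trace, honoring any
    `trendline_color_override`."""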
    trace_spec = TraceSpec(
        constructor=go.Scattergl
        if constructor == go.Scattergl  # could be contour
        else go.Scatter,
        attrs=["trendline"],
        trace_patch=dict(mode="lines"),
        marginal=None,
    )
    if args["trendline_color_override"]:
        trace_spec.trace_patch["line"] = dict(color=args["trendline_color_override"])
    return trace_spec


def one_group(x):
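    """Grouper that assigns every value to a single, unnamed group."""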
    return ""


def apply_default_cascade(args):
    # first we apply px.defaults to unspecified args

    for param in defaults.__slots__:
        if param in args and args[param] is None:
            args[param] = getattr(defaults, param)

    # load the default template if set, otherwise "plotly"
    if args["template"] is None:
        if pio.templates.default is not None:
            args["template"] = pio.templates.default
        else:
            args["template"] = "plotly"

    try:
        # retrieve the actual template if we were given a name
        args["template"] = pio.templates[args["template"]]
    except Exception:
        # otherwise try to build a real template
        args["template"] = go.layout.Template(args["template"])

    # if colors not set explicitly or in px.defaults, defer to a template
    # if the template doesn't have one, we set some final fallback defaults
    if "color_continuous_scale" in args:
        if (
            args["color_continuous_scale"] is None
            and args["template"].layout.colorscale.sequential
        ):
            args["color_continuous_scale"] = [
                x[1] for x in args["template"].layout.colorscale.sequential
            ]
        if args["color_continuous_scale"] is None:
            args["color_continuous_scale"] = sequential.Viridis

    if "color_discrete_sequence" in args:
        if args["color_discrete_sequence"] is None and args["template"].layout.colorway:
            args["color_discrete_sequence"] = args["template"].layout.colorway
        if args["color_discrete_sequence"] is None:
            args["color_discrete_sequence"] = qualitative.D3

    # if symbol_sequence/line_dash_sequence not set explicitly or in px.defaults,
    # see if we can defer to template. If not, set reasonable defaults
    if "symbol_sequence" in args:
        if args["symbol_sequence"] is None and args["template"].data.scatter:
            args["symbol_sequence"] = [
                scatter.marker.symbol for scatter in args["template"].data.scatter
            ]
        if not args["symbol_sequence"] or not any(args["symbol_sequence"]):
            args["symbol_sequence"] = ["circle", "diamond", "square", "x", "cross"]

    if "line_dash_sequence" in args:
        if args["line_dash_sequence"] is None and args["template"].data.scatter:
            args["line_dash_sequence"] = [
                scatter.line.dash for scatter in args["template"].data.scatter
            ]
        if not args["line_dash_sequence"] or not any(args["line_dash_sequence"]):
            args["line_dash_sequence"] = [
                "solid",
                "dot",
                "dash",
                "longdash",
                "dashdot",
                "longdashdot",
            ]

    if "pattern_shape_sequence" in args:
        if args["pattern_shape_sequence"] is None and args["template"].data.bar:
            args["pattern_shape_sequence"] = [
                bar.marker.pattern.shape for bar in args["template"].data.bar
            ]
        if not args["pattern_shape_sequence"] or not any(
            args["pattern_shape_sequence"]
        ):
            args["pattern_shape_sequence"] = ["", "/", "\\", "x", "+", "."]


def _check_name_not_reserved(field_name, reserved_names):
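    """Return `field_name` unchanged unless it collides with a reserved
    (data-frame-derived) column name, in which case raise a NameError."""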
    if field_name not in reserved_names:
        return field_name
    else:
        raise NameError(
            "A name conflict was encountered for argument '%s'. "
            "A column or index with name '%s' is ambiguous." % (field_name, field_name)
        )


def _get_reserved_col_names(args):
    """
    This function builds a list of columns of the data_frame argument used
    as arguments, either as str/int arguments or given as columns
    (pandas series type).
    """
    df = args["data_frame"]
    reserved_names = set()
    for field in args:
        if field not in all_attrables:
            continue
        names = args[field] if field in array_attrables else [args[field]]
        if names is None:
            continue
        for arg in names:
            if arg is None:
                continue
            elif isinstance(arg, str):  # no need to add ints since keyword args are not ints
                reserved_names.add(arg)
            elif isinstance(arg, pd.Series):
                arg_name = arg.name
                if arg_name and hasattr(df, arg_name):
                    in_df = arg is df[arg_name]
                    if in_df:
                        reserved_names.add(arg_name)
            elif arg is df.index and arg.name is not None:
                reserved_names.add(arg.name)

    return reserved_names


def _is_col_list(columns, arg):
    """Returns True if arg looks like it's a list of columns or references to columns
    in df_input, and False otherwise (in which case it's assumed to be a single column
    or reference to a column).
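
    For example, given columns pd.Index(["a", "b"]):
        _is_col_list(columns, ["a", "b"])  # True: a list of column names
        _is_col_list(columns, "a")         # False: a single column reference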
    """
    if arg is None or isinstance(arg, (str, int)):
        return False
    if isinstance(arg, pd.MultiIndex):
        return False  # just to keep existing behaviour for now
    try:
        iter(arg)
    except TypeError:
        return False  # not iterable
    for c in arg:
        if isinstance(c, (str, int)):
            if columns is None or c not in columns:
                return False
        else:
            try:
                iter(c)
            except TypeError:
                return False  # not iterable
    return True


def _isinstance_listlike(x):
    """Returns True if x is an iterable which can be transformed into a pandas Series,
    False for the other types of possible values of a `hover_data` dict.
    A tuple of length 2 is a special case corresponding to a (format, data) tuple.
    """
    if (
        isinstance(x, str)
        or (isinstance(x, tuple) and len(x) == 2)
        or isinstance(x, bool)
        or x is None
    ):
        return False
    else:
        return True


def _escape_col_name(columns, col_name, extra):
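    """Prefix `col_name` with underscores until it collides with neither the
    data frame's `columns` nor the names in `extra`."""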
    while columns is not None and (col_name in columns or col_name in extra):
        col_name = "_" + col_name
    return col_name


def to_unindexed_series(x, name=None):
    """
    assuming x is list-like or an existing pd.Series, return a new pd.Series
    with a default RangeIndex, without extracting the data from the existing
    Series via numpy (which seems to mangle datetime columns). Stripping the
    index from an existing pd.Series is required to get things to line up
    in the new DataFrame we're building
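
    e.g. to_unindexed_series(pd.Series([1, 2], index=[5, 6])) is re-indexed
    to the default RangeIndex [0, 1]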
    """
    return pd.Series(x, name=name).reset_index(drop=True)


def process_args_into_dataframe(args, wide_mode, var_name, value_name):
    """
    After this function runs, the `all_attrables` keys of `args` all contain only
    references to columns of `df_output`. This function handles the extraction of data
    from `args["attrable"]` and column-name-generation as appropriate, and adds the
    data to `df_output` and then replaces `args["attrable"]` with the appropriate
    reference.
    """

    df_input = args["data_frame"]
    df_provided = df_input is not None

    # we use a dict instead of a dataframe directly so that repeatedly setting
    # columns doesn't trigger a pandas PerformanceWarning.
    # a dict is used instead of a list because columns may need to be overwritten.
    df_output = {}
    constants = {}
    ranges = []
    wide_id_vars = set()
    reserved_names = _get_reserved_col_names(args) if df_provided else set()

    # Case of functions with a "dimensions" kw: scatter_matrix, parcats, parcoords
    if "dimensions" in args and args["dimensions"] is None:
        if not df_provided:
            raise ValueError(
                "No data were provided. Please provide data either with the `data_frame` or with the `dimensions` argument."
            )
        else:
            df_output = {col: series for col, series in df_input.items()}

    # hover_data is a dict
    hover_data_is_dict = (
        "hover_data" in args
        and args["hover_data"]
        and isinstance(args["hover_data"], dict)
    )
    # If dict, convert all values of hover_data to tuples to simplify processing
    if hover_data_is_dict:
        for k in args["hover_data"]:
            if _isinstance_listlike(args["hover_data"][k]):
                args["hover_data"][k] = (True, args["hover_data"][k])
            if not isinstance(args["hover_data"][k], tuple):
                args["hover_data"][k] = (args["hover_data"][k], None)
            if df_provided and args["hover_data"][k][1] is not None and k in df_input:
                raise ValueError(
                    "Ambiguous input: values for '%s' appear both in hover_data and data_frame"
                    % k
                )
    # Loop over possible arguments
    for field_name in all_attrables:
        # Massaging variables
        argument_list = (
            [args.get(field_name)]
            if field_name not in array_attrables
            else args.get(field_name)
        )
        # argument not specified, continue
        if argument_list is None or argument_list == [None]:
            continue
        # Argument name: field_name if the argument is not a list
        # Else we give names like ["hover_data_0", "hover_data_1"] etc.
        field_list = (
            [field_name]
            if field_name not in array_attrables
            else [field_name + "_" + str(i) for i in range(len(argument_list))]
        )
        # argument_list and field_list ready, iterate over them
        # Core of the loop starts here
        for i, (argument, field) in enumerate(zip(argument_list, field_list)):
            length = len(df_output[next(iter(df_output))]) if len(df_output) else 0
            if argument is None:
                continue
            col_name = None
            # Case of multiindex
            if isinstance(argument, pd.MultiIndex):
                raise TypeError(
                    "Argument '%s' is a pandas MultiIndex. "
                    "pandas MultiIndex is not supported by plotly express "
                    "at the moment." % field
                )
            # ----------------- argument is a special value ----------------------
            if isinstance(argument, Constant) or isinstance(argument, Range):
                col_name = _check_name_not_reserved(
                    str(argument.label) if argument.label is not None else field,
                    reserved_names,
                )
                if isinstance(argument, Constant):
                    constants[col_name] = argument.value
                else:
                    ranges.append(col_name)
            # ----------------- argument is likely a col name ----------------------
            elif isinstance(argument, str) or not hasattr(argument, "__len__"):
                if (
                    field_name == "hover_data"
                    and hover_data_is_dict
                    and args["hover_data"][str(argument)][1] is not None
                ):
                    # hover_data has onboard data
                    # previously-checked to have no name-conflict with data_frame
                    col_name = str(argument)
                    real_argument = args["hover_data"][col_name][1]

                    if length and len(real_argument) != length:
                        raise ValueError(
                            "All arguments should have the same length. "
                            "The length of hover_data key `%s` is %d, whereas the "
                            "length of previously-processed arguments %s is %d"
                            % (
                                argument,
                                len(real_argument),
                                str(list(df_output.keys())),
                                length,
                            )
                        )
                    df_output[col_name] = to_unindexed_series(real_argument, col_name)
                elif not df_provided:
                    raise ValueError(
                        "String or int arguments are only possible when a "
                        "DataFrame or an array is provided in the `data_frame` "
                        "argument. No DataFrame was provided, but argument "
                        "'%s' is of type str or int." % field
                    )
                # Check validity of column name
                elif argument not in df_input.columns:
                    if wide_mode and argument in (value_name, var_name):
                        continue
                    else:
                        err_msg = (
                            "Value of '%s' is not the name of a column in 'data_frame'. "
                            "Expected one of %s but received: %s"
                            % (field, str(list(df_input.columns)), argument)
                        )
                        if argument == "index":
                            err_msg += "\n To use the index, pass it in directly as `df.index`."
                        raise ValueError(err_msg)
                elif length and len(df_input[argument]) != length:
                    raise ValueError(
                        "All arguments should have the same length. "
                        "The length of column argument `df[%s]` is %d, whereas the "
                        "length of  previously-processed arguments %s is %d"
                        % (
                            field,
                            len(df_input[argument]),
                            str(list(df_output.keys())),
                            length,
                        )
                    )
                else:
                    col_name = str(argument)
                    df_output[col_name] = to_unindexed_series(
                        df_input[argument], col_name
                    )
            # ----------------- argument is likely a column / array / list.... -------
            else:
                if df_provided and hasattr(argument, "name"):
                    if argument is df_input.index:
                        if argument.name is None or argument.name in df_input:
                            col_name = "index"
                        else:
                            col_name = argument.name
                        col_name = _escape_col_name(
                            df_input, col_name, [var_name, value_name]
                        )
                    else:
                        if (
                            argument.name is not None
                            and argument.name in df_input
                            and argument is df_input[argument.name]
                        ):
                            col_name = argument.name
                if col_name is None:  # numpy array, list...
                    col_name = _check_name_not_reserved(field, reserved_names)

                if length and len(argument) != length:
                    raise ValueError(
                        "All arguments should have the same length. "
                        "The length of argument `%s` is %d, whereas the "
                        "length of  previously-processed arguments %s is %d"
                        % (field, len(argument), str(list(df_output.keys())), length)
                    )
                df_output[str(col_name)] = to_unindexed_series(argument, str(col_name))

            # Finally, update argument with column name now that column exists
            assert col_name is not None, (
                "Data-frame processing failure, likely due to a internal bug. "
                "Please report this to "
                "https://github.com/plotly/plotly.py/issues/new and we will try to "
                "replicate and fix it."
            )
            if field_name not in array_attrables:
                args[field_name] = str(col_name)
            elif isinstance(args[field_name], dict):
                pass  # a hover_data dict is already keyed by column name
            else:
                args[field_name][i] = str(col_name)
            if field_name != "wide_variable":
                wide_id_vars.add(str(col_name))

    length = len(df_output[next(iter(df_output))]) if len(df_output) else 0
    df_output.update(
        {col_name: to_unindexed_series(range(length), col_name) for col_name in ranges}
    )
    df_output.update(
        {
            # a constant is a single value; repeat it `length` times to avoid
            # creating NaNs when concatenating
            col_name: to_unindexed_series([constants[col_name]] * length, col_name)
            for col_name in constants
        }
    )

    df_output = pd.DataFrame(df_output)
    return df_output, wide_id_vars


def build_dataframe(args, constructor):
    """
    Constructs a dataframe and modifies `args` in-place.

    The argument values in `args` can be either strings corresponding to
    existing columns of a dataframe, or data arrays (lists, numpy arrays,
    pandas columns, series).

    Parameters
    ----------
    args : OrderedDict
        arguments passed to the px function and subsequently modified
    constructor : graph_object trace class
        the trace type selected for this figure
    """

    # make copies of all the fields via dict() and list()
    for field in args:
        if field in array_attrables and args[field] is not None:
            if isinstance(args[field], dict):
                args[field] = dict(args[field])
            elif field in ["custom_data", "hover_data"] and isinstance(
                args[field], str
            ):
                args[field] = [args[field]]
            else:
                args[field] = list(args[field])

    # Cast data_frame argument to DataFrame (it could be a numpy array, dict etc.)
    df_provided = args["data_frame"] is not None
    needs_interchanging = False
    if df_provided and not isinstance(args["data_frame"], pd.DataFrame):
        if hasattr(args["data_frame"], "__dataframe__") and version.parse(
            pd.__version__
        ) >= version.parse("2.0.2"):
            import pandas.api.interchange

            df_not_pandas = args["data_frame"]
            args["data_frame"] = df_not_pandas.__dataframe__()
            # According to the interchange protocol, `column_names(self) -> Iterable[str]`
            # may return, for example, a generator.
            # The easiest way is to convert `columns` to `pandas.Index` so that the
            # type matches the types in the other code branches.
            columns = pd.Index(args["data_frame"].column_names())
            needs_interchanging = True
        elif hasattr(args["data_frame"], "to_pandas"):
            args["data_frame"] = args["data_frame"].to_pandas()
            columns = args["data_frame"].columns
        elif hasattr(args["data_frame"], "toPandas"):
            args["data_frame"] = args["data_frame"].toPandas()
            columns = args["data_frame"].columns
        elif hasattr(args["data_frame"], "to_pandas_df"):
            args["data_frame"] = args["data_frame"].to_pandas_df()
            columns = args["data_frame"].columns
        else:
            args["data_frame"] = pd.DataFrame(args["data_frame"])
            columns = args["data_frame"].columns
    elif df_provided:
        columns = args["data_frame"].columns
    else:
        columns = None

    df_input = args["data_frame"]

    # now we handle special cases like wide-mode or x-xor-y specification
    # by rearranging args to tee things up for process_args_into_dataframe to work
    no_x = args.get("x") is None
    no_y = args.get("y") is None
    wide_x = False if no_x else _is_col_list(columns, args["x"])
    wide_y = False if no_y else _is_col_list(columns, args["y"])

    wide_mode = False
    var_name = None  # will likely be "variable" in wide_mode
    wide_cross_name = None  # will likely be "index" in wide_mode
    value_name = None  # will likely be "value" in wide_mode
    hist2d_types = [go.Histogram2d, go.Histogram2dContour]
    hist1d_orientation = constructor == go.Histogram or "ecdfmode" in args
    if constructor in cartesians:
        if wide_x and wide_y:
            raise ValueError(
                "Cannot accept list of column references or list of columns for both `x` and `y`."
            )
        if df_provided and no_x and no_y:
            wide_mode = True
            if isinstance(columns, pd.MultiIndex):
                raise TypeError(
                    "Data frame columns is a pandas MultiIndex. "
                    "pandas MultiIndex is not supported by plotly express "
                    "at the moment."
                )
            args["wide_variable"] = list(columns)
            if isinstance(columns, pd.Index):
                var_name = columns.name
            else:
                var_name = None
            if var_name in [None, "value", "index"] or var_name in columns:
                var_name = "variable"
            if constructor == go.Funnel:
                wide_orientation = args.get("orientation") or "h"
            else:
                wide_orientation = args.get("orientation") or "v"
            args["orientation"] = wide_orientation
            args["wide_cross"] = None
        elif wide_x != wide_y:
            wide_mode = True
            args["wide_variable"] = args["y"] if wide_y else args["x"]
            if df_provided and args["wide_variable"] is columns:
                var_name = columns.name
            if isinstance(args["wide_variable"], pd.Index):
                args["wide_variable"] = list(args["wide_variable"])
            if var_name in [None, "value", "index"] or (
                df_provided and var_name in columns
            ):
                var_name = "variable"
            if hist1d_orientation:
                wide_orientation = "v" if wide_x else "h"
            else:
                wide_orientation = "v" if wide_y else "h"
            args["y" if wide_y else "x"] = None
            args["wide_cross"] = None
            if not no_x and not no_y:
                wide_cross_name = "__x__" if wide_y else "__y__"

    if wide_mode:
        value_name = _escape_col_name(columns, "value", [])
        var_name = _escape_col_name(columns, var_name, [])

    if needs_interchanging:
        try:
            if wide_mode or not hasattr(args["data_frame"], "select_columns_by_name"):
                args["data_frame"] = pd.api.interchange.from_dataframe(
                    args["data_frame"]
                )
            else:
                # Save precious resources by only interchanging columns that are
                # actually going to be plotted.
                necessary_columns = {
                    i for i in args.values() if isinstance(i, str) and i in columns
                }
                for field in args:
                    if args[field] is not None and field in array_attrables:
                        necessary_columns.update(i for i in args[field] if i in columns)
                columns = list(necessary_columns)
                args["data_frame"] = pd.api.interchange.from_dataframe(
                    args["data_frame"].select_columns_by_name(columns)
                )
        except (ImportError, NotImplementedError) as exc:
            # Temporary workaround: developers of third-party libraries should
            # themselves fall back to a different implementation, if one is
            # available. For example:
            # def __dataframe__(self, ...):
            #     if not some_condition:
            #         return self.to_pandas(...)
            if hasattr(df_not_pandas, "toPandas"):
                args["data_frame"] = df_not_pandas.toPandas()
            elif hasattr(df_not_pandas, "to_pandas_df"):
                args["data_frame"] = df_not_pandas.to_pandas_df()
            elif hasattr(df_not_pandas, "to_pandas"):
                args["data_frame"] = df_not_pandas.to_pandas()
            else:
                raise exc

    df_input = args["data_frame"]

    missing_bar_dim = None
    if (
        constructor in [go.Scatter, go.Bar, go.Funnel] + hist2d_types
        and not hist1d_orientation
    ):
        if not wide_mode and (no_x != no_y):
            for ax in ["x", "y"]:
                if args.get(ax) is None:
                    args[ax] = df_input.index if df_provided else Range()
                    if constructor == go.Bar:
                        missing_bar_dim = ax
                    else:
                        if args["orientation"] is None:
                            args["orientation"] = "v" if ax == "x" else "h"
        if wide_mode and wide_cross_name is None:
            if no_x != no_y and args["orientation"] is None:
                args["orientation"] = "v" if no_x else "h"
            if df_provided:
                if isinstance(df_input.index, pd.MultiIndex):
                    raise TypeError(
                        "Data frame index is a pandas MultiIndex. "
                        "pandas MultiIndex is not supported by plotly express "
                        "at the moment."
                    )
                args["wide_cross"] = df_input.index
            else:
                args["wide_cross"] = Range(
                    label=_escape_col_name(df_input, "index", [var_name, value_name])
                )

    no_color = False
    if isinstance(args.get("color"), str) and args["color"] == NO_COLOR:
        no_color = True
        args["color"] = None
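    # NO_COLOR is a sentinel meaning "suppress color inference": color is
    # nulled out here so the wide-mode defaults below can run, then forced back
    # to None at the very end (see `if no_color:` below) so that any default
    # color assigned in the meantime is undone.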
    # now that things have been prepped, we do the systematic rewriting of `args`

    df_output, wide_id_vars = process_args_into_dataframe(
        args, wide_mode, var_name, value_name
    )

    # now that `df_output` exists and `args` contains only references, we complete
    # the special-case and wide-mode handling by further rewriting args and/or mutating
    # df_output

    count_name = _escape_col_name(df_output, "count", [var_name, value_name])
    if not wide_mode and missing_bar_dim and constructor == go.Bar:
        # now that we've populated df_output, we check whether the non-missing
        # dimension is categorical: if so, setting the missing dimension to a
        # constant 1 is saner than defaulting it to the index, and we let the
        # normal auto-orientation code do its thing later
        other_dim = "x" if missing_bar_dim == "y" else "y"
        if not _is_continuous(df_output, args[other_dim]):
            args[missing_bar_dim] = count_name
            df_output[count_name] = 1
        else:
            # on the other hand, if the non-missing dimension is continuous, then we
            # can use this information to override the normal auto-orientation code
            if args["orientation"] is None:
                args["orientation"] = "v" if missing_bar_dim == "x" else "h"
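    # Illustrative sketch (toy data, not executed): px.bar(df, x="fruit") with
    # a categorical "fruit" column arrives here with `y` missing, so a constant
    # count column of 1s supplies the bar length; were "fruit" continuous, the
    # index would stay on the other axis and only the orientation would be
    # inferred.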

    if constructor in hist2d_types:
        del args["orientation"]

    if wide_mode:
        # at this point, `df_output` is semi-long/semi-wide, but we know which columns
        # are which, so we melt it and reassign `args` to refer to the newly-tidy
        # columns, keeping track of various names and manglings set up above
        wide_value_vars = [c for c in args["wide_variable"] if c not in wide_id_vars]
        del args["wide_variable"]
        if wide_cross_name == "__x__":
            wide_cross_name = args["x"]
        elif wide_cross_name == "__y__":
            wide_cross_name = args["y"]
        else:
            wide_cross_name = args["wide_cross"]
        del args["wide_cross"]
        dtype = None
        for v in wide_value_vars:
            v_dtype = df_output[v].dtype.kind
            v_dtype = "number" if v_dtype in ["i", "f", "u"] else v_dtype
            if dtype is None:
                dtype = v_dtype
            elif dtype != v_dtype:
                raise ValueError(
                    "Plotly Express cannot process wide-form data with columns of different type."
                )
        df_output = df_output.melt(
            id_vars=wide_id_vars,
            value_vars=wide_value_vars,
            var_name=var_name,
            value_name=value_name,
        )
        assert len(df_output.columns) == len(set(df_output.columns)), (
            "Wide-mode name-inference failure, likely due to an internal bug. "
            "Please report this to "
            "https://github.com/plotly/plotly.py/issues/new and we will try to "
            "replicate and fix it."
        )
        df_output[var_name] = df_output[var_name].astype(str)
        orient_v = wide_orientation == "v"

        if hist1d_orientation:
            args["x" if orient_v else "y"] = value_name
            args["y" if orient_v else "x"] = wide_cross_name
            args["color"] = args["color"] or var_name
        elif constructor in [go.Scatter, go.Funnel] + hist2d_types:
            args["x" if orient_v else "y"] = wide_cross_name
            args["y" if orient_v else "x"] = value_name
            if constructor != go.Histogram2d:
                args["color"] = args["color"] or var_name
            if "line_group" in args:
                args["line_group"] = args["line_group"] or var_name
        elif constructor == go.Bar:
            if _is_continuous(df_output, value_name):
                args["x" if orient_v else "y"] = wide_cross_name
                args["y" if orient_v else "x"] = value_name
                args["color"] = args["color"] or var_name
            else:
                args["x" if orient_v else "y"] = value_name
                args["y" if orient_v else "x"] = count_name
                df_output[count_name] = 1
                args["color"] = args["color"] or var_name
        elif constructor in [go.Violin, go.Box]:
            args["x" if orient_v else "y"] = wide_cross_name or var_name
            args["y" if orient_v else "x"] = value_name

    if hist1d_orientation and constructor == go.Scatter:
        if args["x"] is not None and args["y"] is not None:
            args["histfunc"] = "sum"
        elif args["x"] is None:
            args["histfunc"] = None
            args["orientation"] = "h"
            args["x"] = count_name
            df_output[count_name] = 1
        else:
            args["histfunc"] = None
            args["orientation"] = "v"
            args["y"] = count_name
            df_output[count_name] = 1

    if no_color:
        args["color"] = None
    args["data_frame"] = df_output
    return args
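
# Illustrative sketch of the wide-mode path above (toy data, not executed):
#
#   df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
#   px.line(df)
#
# arrives with no `x`/`y`, so build_dataframe melts df into long form with
# columns like "index", "variable" and "value", then points args["x"],
# args["y"] and args["color"] at those columns.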


def _check_dataframe_all_leaves(df):
    df_sorted = df.sort_values(by=list(df.columns))
    null_mask = df_sorted.isnull()
    df_sorted = df_sorted.astype(str)
    null_indices = np.nonzero(null_mask.any(axis=1).values)[0]
    for null_row_index in null_indices:
        row = null_mask.iloc[null_row_index]
        i = np.nonzero(row.values)[0][0]
        if not row[i:].all():
            raise ValueError(
                "None entries cannot have non-None children",
                df_sorted.iloc[null_row_index],
            )
    df_sorted[null_mask] = ""
    row_strings = list(df_sorted.apply(lambda x: "".join(x), axis=1))
    for i, row in enumerate(row_strings[:-1]):
        if row_strings[i + 1] in row and (i + 1) in null_indices:
            raise ValueError(
                "Non-leaf rows are not permitted in the dataframe\n",
                df_sorted.iloc[i + 1],
                "is not a leaf.",
            )
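
# Illustrative sketch (toy rows, not executed): with path columns
# ["continent", "country"], the pair of rows ("Asia", "Japan") and
# ("Asia", None) triggers the non-leaf error, since the null row has a more
# specific row beneath it, while a single row (None, "Japan") triggers
# "None entries cannot have non-None children".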


def process_dataframe_hierarchy(args):
    """
    Build dataframe for sunburst, treemap, or icicle when the path argument is provided.
    """
    df = args["data_frame"]
    path = args["path"][::-1]
    _check_dataframe_all_leaves(df[path[::-1]])
    discrete_color = False

    new_path = []
    for col_name in path:
        new_col_name = col_name + "_path_copy"
        new_path.append(new_col_name)
        df[new_col_name] = df[col_name]
    path = new_path
    # ------------ Define aggregation functions --------------------------------

    def aggfunc_discrete(x):
        uniques = x.unique()
        if len(uniques) == 1:
            return uniques[0]
        else:
            return "(?)"

    agg_f = {}
    aggfunc_color = None
    if args["values"]:
        try:
            df[args["values"]] = pd.to_numeric(df[args["values"]])
        except ValueError:
            raise ValueError(
                "Column `%s` of `df` could not be converted to a numerical data type."
                % args["values"]
            )

        if args["color"]:
            if args["color"] == args["values"]:
                new_value_col_name = args["values"] + "_sum"
                df[new_value_col_name] = df[args["values"]]
                args["values"] = new_value_col_name
        count_colname = args["values"]
    else:
        # we need a count column for the first groupby and for the weighted mean
        # of the color column; to be sure the column name is unused, fall back to
        # the concatenation of all existing column names
        count_colname = (
            "count"
            if "count" not in df.columns
            else "".join([str(el) for el in list(df.columns)])
        )
        # we can modify df because it's a copy of the px argument
        df[count_colname] = 1
        args["values"] = count_colname
    agg_f[count_colname] = "sum"

    if args["color"]:
        if not _is_continuous(df, args["color"]):
            aggfunc_color = aggfunc_discrete
            discrete_color = True
        else:

            def aggfunc_continuous(x):
                return np.average(x, weights=df.loc[x.index, count_colname])

            aggfunc_color = aggfunc_continuous
        agg_f[args["color"]] = aggfunc_color

    #  Other columns (for color, hover_data, custom_data etc.)
    cols = list(set(df.columns).difference(path))
    for col in cols:  # for hover_data, custom_data etc.
        if col not in agg_f:
            agg_f[col] = aggfunc_discrete
    # Avoid collisions with reserved names - columns in the path have been copied already
    cols = list(set(cols) - set(["labels", "parent", "id"]))
    # ----------------------------------------------------------------------------
    df_all_trees = pd.DataFrame(columns=["labels", "parent", "id"] + cols)
    #  Set column type here (useful for continuous vs discrete colorscale)
    for col in cols:
        df_all_trees[col] = df_all_trees[col].astype(df[col].dtype)
    for i, level in enumerate(path):
        df_tree = pd.DataFrame(columns=df_all_trees.columns)
        dfg = df.groupby(path[i:]).agg(agg_f)
        dfg = dfg.reset_index()
        # Path label massaging
        df_tree["labels"] = dfg[level].copy().astype(str)
        df_tree["parent"] = ""
        df_tree["id"] = dfg[level].copy().astype(str)
        if i < len(path) - 1:
            j = i + 1
            while j < len(path):
                df_tree["parent"] = (
                    dfg[path[j]].copy().astype(str) + "/" + df_tree["parent"]
                )
                df_tree["id"] = dfg[path[j]].copy().astype(str) + "/" + df_tree["id"]
                j += 1

        df_tree["parent"] = df_tree["parent"].str.rstrip("/")
        if cols:
            df_tree[cols] = dfg[cols]
        df_all_trees = pd.concat([df_all_trees, df_tree], ignore_index=True)

    # we want to make sure that "(?)" is the first color of the sequence
    if args["color"] and discrete_color:
        sort_col_name = "sort_color_if_discrete_color"
        while sort_col_name in df_all_trees.columns:
            sort_col_name += "0"
        df_all_trees[sort_col_name] = df[args["color"]].astype(str)
        df_all_trees = df_all_trees.sort_values(by=sort_col_name)

    # Now modify arguments
    args["data_frame"] = df_all_trees
    args["path"] = None
    args["ids"] = "id"
    args["names"] = "labels"
    args["parents"] = "parent"
    if args["color"]:
        if not args["hover_data"]:
            args["hover_data"] = [args["color"]]
        elif isinstance(args["hover_data"], dict):
            if not args["hover_data"].get(args["color"]):
                args["hover_data"][args["color"]] = (True, None)
        else:
            args["hover_data"].append(args["color"])
    return args
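
# Illustrative sketch (toy data, not executed):
#
#   px.sunburst(df, path=["continent", "country"], values="pop")
#
# routes through process_dataframe_hierarchy, which rebuilds args["data_frame"]
# with one row per tree node, e.g. labels="Japan", id="Asia/Japan",
# parent="Asia", plus an aggregated row per continent (id="Asia", parent="").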


def process_dataframe_timeline(args):
    """
    Massage input for bar traces for px.timeline()
    """
    args["is_timeline"] = True
    if args["x_start"] is None or args["x_end"] is None:
        raise ValueError("Both x_start and x_end are required")

    try:
        x_start = pd.to_datetime(args["data_frame"][args["x_start"]])
        x_end = pd.to_datetime(args["data_frame"][args["x_end"]])
    except (ValueError, TypeError):
        raise TypeError(
            "Both x_start and x_end must refer to data convertible to datetimes."
        )

    # note that we are not adding any columns to the data frame here, so no risk of overwrite
    args["data_frame"][args["x_end"]] = (x_end - x_start).astype(
        "timedelta64[ns]"
    ) / np.timedelta64(1, "ms")
    args["x"] = args["x_end"]
    del args["x_end"]
    args["base"] = args["x_start"]
    del args["x_start"]
    return args
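
# Illustrative sketch (toy data, not executed): for
#
#   px.timeline(df, x_start="Start", x_end="Finish", y="Task")
#
# this rewrites args so that args["base"] is the "Start" column and args["x"]
# is the "Finish" column, which now holds each bar's duration in milliseconds.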


def process_dataframe_pie(args, trace_patch):
    names = args.get("names")
    if names is None:
        return args, trace_patch
    order_in = args["category_orders"].get(names, {}).copy()
    if not order_in:
        return args, trace_patch
    df = args["data_frame"]
    trace_patch["sort"] = False
    trace_patch["direction"] = "clockwise"
    uniques = list(df[names].unique())
    order = [x for x in OrderedDict.fromkeys(list(order_in) + uniques) if x in uniques]
    args["data_frame"] = df.set_index(names).loc[order].reset_index()
    return args, trace_patch
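
# Illustrative sketch (toy data, not executed): with
#
#   px.pie(df, names="kind", values="n", category_orders={"kind": ["b", "a"]})
#
# the rows of args["data_frame"] are reordered to ["b", "a", <remaining kinds
# in data order>] and trace-level sorting is disabled, so the slices honor the
# requested order clockwise.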


def infer_config(args, constructor, trace_patch, layout_patch):
    attrs = [k for k in direct_attrables + array_attrables if k in args]
    grouped_attrs = []

    # Compute sizeref
    sizeref = 0
    if "size" in args and args["size"]:
        sizeref = args["data_frame"][args["size"]].max() / args["size_max"] ** 2
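    # (With area-based marker sizing downstream, dividing size values by this
    # sizeref makes the largest `size` value correspond to a marker diameter on
    # the order of `size_max` pixels.)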

    # Compute color attributes and grouping attributes
    if "color" in args:
        if "color_continuous_scale" in args:
            if "color_discrete_sequence" not in args:
                attrs.append("color")
            else:
                if args["color"] and _is_continuous(args["data_frame"], args["color"]):
                    attrs.append("color")
                    args["color_is_continuous"] = True
                elif constructor in [go.Sunburst, go.Treemap, go.Icicle]:
                    attrs.append("color")
                    args["color_is_continuous"] = False
                else:
                    grouped_attrs.append("marker.color")
        elif "line_group" in args or constructor == go.Histogram2dContour:
            grouped_attrs.append("line.color")
        elif constructor in [go.Pie, go.Funnelarea]:
            attrs.append("color")
            if args["color"]:
                if args["hover_data"] is None:
                    args["hover_data"] = []
                args["hover_data"].append(args["color"])
        else:
            grouped_attrs.append("marker.color")

        show_colorbar = bool(
            "color" in attrs
            and args["color"]
            and constructor not in [go.Pie, go.Funnelarea]
            and (
                constructor not in [go.Treemap, go.Sunburst, go.Icicle]
                or args.get("color_is_continuous")
            )
        )
    else:
        show_colorbar = False

    if "line_dash" in args:
        grouped_attrs.append("line.dash")

    if "symbol" in args:
        grouped_attrs.append("marker.symbol")

    if "pattern_shape" in args:
        if constructor in [go.Scatter]:
            grouped_attrs.append("fillpattern.shape")
        else:
            grouped_attrs.append("marker.pattern.shape")

    if "orientation" in args:
        has_x = args["x"] is not None
        has_y = args["y"] is not None
        if args["orientation"] is None:
            if constructor in [go.Histogram, go.Scatter]:
                if has_y and not has_x:
                    args["orientation"] = "h"
            elif constructor in [go.Violin, go.Box, go.Bar, go.Funnel]:
                if has_x and not has_y:
                    args["orientation"] = "h"

        if args["orientation"] is None and has_x and has_y:
            x_is_continuous = _is_continuous(args["data_frame"], args["x"])
            y_is_continuous = _is_continuous(args["data_frame"], args["y"])
            if x_is_continuous and not y_is_continuous:
                args["orientation"] = "h"
            if y_is_continuous and not x_is_continuous:
                args["orientation"] = "v"

        if args["orientation"] is None:
            args["orientation"] = "v"

        if constructor == go.Histogram:
            if has_x and has_y and args["histfunc"] is None:
                args["histfunc"] = trace_patch["histfunc"] = "sum"

            orientation = args["orientation"]
            nbins = args["nbins"]
            trace_patch["nbinsx"] = nbins if orientation == "v" else None
            trace_patch["nbinsy"] = None if orientation == "v" else nbins
            trace_patch["bingroup"] = "x" if orientation == "v" else "y"
        trace_patch["orientation"] = args["orientation"]

        if constructor in [go.Violin, go.Box]:
            mode = "boxmode" if constructor == go.Box else "violinmode"
            if layout_patch[mode] is None and args["color"] is not None:
                if args["y"] == args["color"] and args["orientation"] == "h":
                    layout_patch[mode] = "overlay"
                elif args["x"] == args["color"] and args["orientation"] == "v":
                    layout_patch[mode] = "overlay"
            if layout_patch[mode] is None:
                layout_patch[mode] = "group"

    if (
        constructor == go.Histogram2d
        and args["z"] is not None
        and args["histfunc"] is None
    ):
        args["histfunc"] = trace_patch["histfunc"] = "sum"

    if args.get("text_auto", False) is not False:
        if constructor in [go.Histogram2d, go.Histogram2dContour]:
            letter = "z"
        elif constructor == go.Bar:
            letter = "y" if args["orientation"] == "v" else "x"
        else:
            letter = "value"
        if args["text_auto"] is True:
            trace_patch["texttemplate"] = "%{" + letter + "}"
        else:
            trace_patch["texttemplate"] = "%{" + letter + ":" + args["text_auto"] + "}"

    if constructor in [go.Histogram2d, go.Densitymap, go.Densitymapbox]:
        show_colorbar = True
        trace_patch["coloraxis"] = "coloraxis1"

    if "opacity" in args:
        if args["opacity"] is None:
            if "barmode" in args and args["barmode"] == "overlay":
                trace_patch["marker"] = dict(opacity=0.5)
        elif constructor in [
            go.Densitymap,
            go.Densitymapbox,
            go.Pie,
            go.Funnel,
            go.Funnelarea,
        ]:
            trace_patch["opacity"] = args["opacity"]
        else:
            trace_patch["marker"] = dict(opacity=args["opacity"])
    if (
        "line_group" in args or "line_dash" in args
    ):  # px.line, px.line_*, px.area, px.ecdf
        modes = set()
        if args.get("lines", True):
            modes.add("lines")
        if args.get("text") or args.get("symbol") or args.get("markers"):
            modes.add("markers")
        if args.get("text"):
            modes.add("text")
        if len(modes) == 0:
            modes.add("lines")
        trace_patch["mode"] = "+".join(sorted(modes))
    elif constructor != go.Splom and (
        "symbol" in args or constructor in [go.Scattermap, go.Scattermapbox]
    ):
        trace_patch["mode"] = "markers" + ("+text" if args["text"] else "")
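    # e.g. px.line(..., markers=True, text="label") ends up with
    # trace_patch["mode"] == "lines+markers+text" via the sorted-set join above.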

    if "line_shape" in args:
        trace_patch["line"] = dict(shape=args["line_shape"])
    elif "ecdfmode" in args:
        trace_patch["line"] = dict(
            shape="vh" if args["ecdfmode"] == "reversed" else "hv"
        )

    if "geojson" in args:
        trace_patch["featureidkey"] = args["featureidkey"]
        trace_patch["geojson"] = (
            args["geojson"]
            if not hasattr(args["geojson"], "__geo_interface__")  # for geopandas
            else args["geojson"].__geo_interface__
        )

    # Compute marginal attribute: copy to appropriate marginal_*
    if "marginal" in args:
        position = "marginal_x" if args["orientation"] == "v" else "marginal_y"
        other_position = "marginal_x" if args["orientation"] == "h" else "marginal_y"
        args[position] = args["marginal"]
        args[other_position] = None

    # Ignore facet rows and columns when data frame is empty so as to prevent nrows/ncols equaling 0
    if len(args["data_frame"]) == 0:
        args["facet_row"] = args["facet_col"] = None

    # If both marginals and faceting are specified, faceting wins
    if args.get("facet_col") is not None and args.get("marginal_y") is not None:
        args["marginal_y"] = None

    if args.get("facet_row") is not None and args.get("marginal_x") is not None:
        args["marginal_x"] = None

    # facet_col_wrap only works if no marginals or row faceting is used
    if (
        args.get("marginal_x") is not None
        or args.get("marginal_y") is not None
        or args.get("facet_row") is not None
    ):
        args["facet_col_wrap"] = 0

    if "trendline" in args and args["trendline"] is not None:
        if args["trendline"] not in trendline_functions:
            raise ValueError(
                "Value '%s' for `trendline` must be one of %s"
                % (args["trendline"], trendline_functions.keys())
            )

    if "trendline_options" in args and args["trendline_options"] is None:
        args["trendline_options"] = dict()

    if "ecdfnorm" in args:
        if args.get("ecdfnorm", None) not in [None, "percent", "probability"]:
            raise ValueError(
                "`ecdfnorm` must be one of None, 'percent' or 'probability'. "
                + "'%s' was provided." % args["ecdfnorm"]
            )
        args["histnorm"] = args["ecdfnorm"]

    # Compute applicable grouping attributes
    for k in group_attrables:
        if k in args:
            grouped_attrs.append(k)

    # Create grouped mappings
    grouped_mappings = [make_mapping(args, a) for a in grouped_attrs]

    # Create trace specs
    trace_specs = make_trace_spec(args, constructor, attrs, trace_patch)
    return trace_specs, grouped_mappings, sizeref, show_colorbar


def get_groups_and_orders(args, grouper):
    """
    `orders` is the user-supplied ordering, with the remaining data-frame-supplied
    ordering appended for each column used for grouping. It includes everything the
    user gave, for any variable, including values not present in the dataset. It's
    a dict keyed by the grouped column names.

    `groups` is a dict of groups, ordered as above. Its keys are tuples with one
    entry per grouping dimension, e.g. ("value1", ""), where "" stands for a
    one_group placeholder; its values are the matching data-frame subsets.
    """
    orders = {} if "category_orders" not in args else args["category_orders"].copy()

    # figure out orders and what the single group name would be if there were one
    single_group_name = []
    unique_cache = dict()
    for col in grouper:
        if col == one_group:
            single_group_name.append("")
        else:
            if col not in unique_cache:
                unique_cache[col] = list(args["data_frame"][col].unique())
            uniques = unique_cache[col]
            if len(uniques) == 1:
                single_group_name.append(uniques[0])
            if col not in orders:
                orders[col] = uniques
            else:
                orders[col] = list(OrderedDict.fromkeys(list(orders[col]) + uniques))
    df = args["data_frame"]
    if len(single_group_name) == len(grouper):
        # we have a single group, so we can skip all group-by operations!
        groups = {tuple(single_group_name): df}
    else:
        required_grouper = [g for g in grouper if g != one_group]
        grouped = df.groupby(
            required_grouper, sort=False, observed=True
        )  # skip one_group groupers
        group_indices = grouped.indices
        sorted_group_names = [
            g if len(required_grouper) != 1 else (g,) for g in group_indices
        ]

        for i, col in reversed(list(enumerate(required_grouper))):
            sorted_group_names = sorted(
                sorted_group_names,
                key=lambda g: orders[col].index(g[i]) if g[i] in orders[col] else -1,
            )

        # calculate the full group_names by inserting "" at the tuple position of each one_group grouper
        full_sorted_group_names = [list(t) for t in sorted_group_names]
        for i, col in enumerate(grouper):
            if col == one_group:
                for g in full_sorted_group_names:
                    g.insert(i, "")
        full_sorted_group_names = [tuple(g) for g in full_sorted_group_names]

        groups = {}
        for sf, s in zip(full_sorted_group_names, sorted_group_names):
            if len(s) > 1:
                groups[sf] = grouped.get_group(s)
            else:
                if pandas_2_2_0:
                    groups[sf] = grouped.get_group((s[0],))
                else:
                    groups[sf] = grouped.get_group(s[0])
    return groups, orders
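
# Illustrative sketch (toy data, not executed): with
#
#   args["data_frame"] = pd.DataFrame({"c": ["a", "a", "b"]})
#   grouper = ["c", one_group]
#
# get_groups_and_orders returns orders == {"c": ["a", "b"]} (user-supplied
# category_orders first, then first-appearance order from the data) and groups
# with keys ("a", "") and ("b", "") mapping to the matching row subsets.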


def make_figure(args, constructor, trace_patch=None, layout_patch=None):
    trace_patch = trace_patch or {}
    layout_patch = layout_patch or {}
    apply_default_cascade(args)

    args = build_dataframe(args, constructor)
    if constructor in [go.Treemap, go.Sunburst, go.Icicle] and args["path"] is not None:
        args = process_dataframe_hierarchy(args)
    if constructor in [go.Pie]:
        args, trace_patch = process_dataframe_pie(args, trace_patch)
    if constructor == "timeline":
        constructor = go.Bar
        args = process_dataframe_timeline(args)

    trace_specs, grouped_mappings, sizeref, show_colorbar = infer_config(
        args, constructor, trace_patch, layout_patch
    )
    grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group]
    groups, orders = get_groups_and_orders(args, grouper)

    col_labels = []
    row_labels = []
    nrows = ncols = 1
    for m in grouped_mappings:
        if m.grouper not in orders:
            m.val_map[""] = m.sequence[0]
        else:
            sorted_values = orders[m.grouper]
            if m.facet == "col":
                prefix = get_label(args, args["facet_col"]) + "="
                col_labels = [prefix + str(s) for s in sorted_values]
                ncols = len(col_labels)
            if m.facet == "row":
                prefix = get_label(args, args["facet_row"]) + "="
                row_labels = [prefix + str(s) for s in sorted_values]
                nrows = len(row_labels)
            for val in sorted_values:
                if val not in m.val_map:  # always False if it's an IdentityMap
                    m.val_map[val] = m.sequence[len(m.val_map) % len(m.sequence)]

    subplot_type = _subplot_type_for_trace_type(constructor().type)

    trace_names_by_frame = {}
    frames = OrderedDict()
    trendline_rows = []
    trace_name_labels = None
    facet_col_wrap = args.get("facet_col_wrap", 0)
    for group_name, group in groups.items():
        mapping_labels = OrderedDict()
        trace_name_labels = OrderedDict()
        frame_name = ""
        for col, val, m in zip(grouper, group_name, grouped_mappings):
            if col != one_group:
                key = get_label(args, col)
                if not isinstance(m.val_map, IdentityMap):
                    mapping_labels[key] = str(val)
                    if m.show_in_trace_name:
                        trace_name_labels[key] = str(val)
                if m.variable == "animation_frame":
                    frame_name = val
        trace_name = ", ".join(trace_name_labels.values())
        if frame_name not in trace_names_by_frame:
            trace_names_by_frame[frame_name] = set()
        trace_names = trace_names_by_frame[frame_name]

        for trace_spec in trace_specs:
            # Create the trace
            trace = trace_spec.constructor(name=trace_name)
            if trace_spec.constructor not in [
                go.Parcats,
                go.Parcoords,
                go.Choropleth,
                go.Choroplethmap,
                go.Choroplethmapbox,
                go.Densitymap,
                go.Densitymapbox,
                go.Histogram2d,
                go.Sunburst,
                go.Treemap,
                go.Icicle,
            ]:
                trace.update(
                    legendgroup=trace_name,
                    showlegend=(trace_name != "" and trace_name not in trace_names),
                )
            if trace_spec.constructor in [go.Bar, go.Violin, go.Box, go.Histogram]:
                trace.update(alignmentgroup=True, offsetgroup=trace_name)
            trace_names.add(trace_name)

            # Init subplot row/col
            trace._subplot_row = 1
            trace._subplot_col = 1

            for i, m in enumerate(grouped_mappings):
                val = group_name[i]
                try:
                    m.updater(trace, m.val_map[val])  # covers most cases
                except ValueError:
                    # this catches some odd cases like marginals
                    if (
                        trace_spec != trace_specs[0]
                        and (
                            trace_spec.constructor in [go.Violin, go.Box]
                            and m.variable in ["symbol", "pattern", "dash"]
                        )
                        or (
                            trace_spec.constructor in [go.Histogram]
                            and m.variable in ["symbol", "dash"]
                        )
                    ):
                        pass
                    elif (
                        trace_spec != trace_specs[0]
                        and trace_spec.constructor in [go.Histogram]
                        and m.variable == "color"
                    ):
                        trace.update(marker=dict(color=m.val_map[val]))
                    elif (
                        trace_spec.constructor
                        in [go.Choropleth, go.Choroplethmap, go.Choroplethmapbox]
                        and m.variable == "color"
                    ):
                        trace.update(
                            z=[1] * len(group),
                            colorscale=[m.val_map[val]] * 2,
                            showscale=False,
                            showlegend=True,
                        )
                    else:
                        raise

                # Find row for trace, handling facet_row and marginal_x
                if m.facet == "row":
                    row = m.val_map[val]
                else:
                    if (
                        args.get("marginal_x") is not None  # there is a marginal
                        and trace_spec.marginal != "x"  # and we're not it
                    ):
                        row = 2
                    else:
                        row = 1

                # Find col for trace, handling facet_col and marginal_y
                if m.facet == "col":
                    col = m.val_map[val]
                    if facet_col_wrap:  # assumes no facet_row, no marginals
                        row = 1 + ((col - 1) // facet_col_wrap)
                        col = 1 + ((col - 1) % facet_col_wrap)
                else:
                    if trace_spec.marginal == "y":
                        col = 2
                    else:
                        col = 1

                if row > 1:
                    trace._subplot_row = row

                if col > 1:
                    trace._subplot_col = col
            if (
                trace_specs[0].constructor == go.Histogram2dContour
                and trace_spec.constructor == go.Box
                and trace.line.color
            ):
                trace.update(marker=dict(color=trace.line.color))

            if "ecdfmode" in args:
                base = args["x"] if args["orientation"] == "v" else args["y"]
                var = args["x"] if args["orientation"] == "h" else args["y"]
                ascending = args.get("ecdfmode", "standard") != "reversed"
                group = group.sort_values(by=base, ascending=ascending)
                group_sum = group[var].sum()  # compute here before next line mutates
                group[var] = group[var].cumsum()
                if not ascending:
                    group = group.sort_values(by=base, ascending=True)

                if args.get("ecdfmode", "standard") == "complementary":
                    group[var] = group_sum - group[var]

                if args["ecdfnorm"] == "probability":
                    group[var] = group[var] / group_sum
                elif args["ecdfnorm"] == "percent":
                    group[var] = 100.0 * group[var] / group_sum
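
                # Summary of the block above: the running cumsum of the weight
                # column (a constant 1 per row unless a value column was mapped)
                # gives the ECDF; "complementary" flips it around the total, and
                # `ecdfnorm` rescales to a probability or a percentage.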

            patch, fit_results = make_trace_kwargs(
                args, trace_spec, group, mapping_labels.copy(), sizeref
            )
            trace.update(patch)
            if fit_results is not None:
                trendline_rows.append(mapping_labels.copy())
                trendline_rows[-1]["px_fit_results"] = fit_results
            if frame_name not in frames:
                frames[frame_name] = dict(data=[], name=frame_name)
            frames[frame_name]["data"].append(trace)
    frame_list = [f for f in frames.values()]
    if len(frame_list) > 1:
        frame_list = sorted(
            frame_list, key=lambda f: orders[args["animation_frame"]].index(f["name"])
        )

    if show_colorbar:
        colorvar = (
            "z"
            if constructor in [go.Histogram2d, go.Densitymap, go.Densitymapbox]
            else "color"
        )
        range_color = args["range_color"] or [None, None]

        colorscale_validator = ColorscaleValidator("colorscale", "make_figure")
        layout_patch["coloraxis1"] = dict(
            colorscale=colorscale_validator.validate_coerce(
                args["color_continuous_scale"]
            ),
            cmid=args["color_continuous_midpoint"],
            cmin=range_color[0],
            cmax=range_color[1],
            colorbar=dict(
                title_text=get_decorated_label(args, args[colorvar], colorvar)
            ),
        )
    for v in ["height", "width"]:
        if args[v]:
            layout_patch[v] = args[v]
    layout_patch["legend"] = dict(tracegroupgap=0)
    if trace_name_labels:
        layout_patch["legend"]["title_text"] = ", ".join(trace_name_labels)
    if args["title"]:
        layout_patch["title_text"] = args["title"]
    elif args["template"].layout.margin.t is None:
        layout_patch["margin"] = {"t": 60}
    if (
        "size" in args
        and args["size"]
        and args["template"].layout.legend.itemsizing is None
    ):
        layout_patch["legend"]["itemsizing"] = "constant"

    if facet_col_wrap:
        nrows = math.ceil(ncols / facet_col_wrap)
        ncols = min(ncols, facet_col_wrap)

    if args.get("marginal_x") is not None:
        nrows += 1

    if args.get("marginal_y") is not None:
        ncols += 1

    fig = init_figure(
        args, subplot_type, frame_list, nrows, ncols, col_labels, row_labels
    )

    # Position traces in subplots
    for frame in frame_list:
        for trace in frame["data"]:
            if isinstance(trace, go.Splom):
                # Special case that is not compatible with make_subplots
                continue

            _set_trace_grid_reference(
                trace,
                fig.layout,
                fig._grid_ref,
                nrows - trace._subplot_row + 1,
                trace._subplot_col,
            )

    # Add traces, layout and frames to figure
    fig.add_traces(frame_list[0]["data"] if len(frame_list) > 0 else [])
    fig.update_layout(layout_patch)
    if "template" in args and args["template"] is not None:
        fig.update_layout(template=args["template"], overwrite=True)
    for f in frame_list:
        f["name"] = str(f["name"])
    fig.frames = frame_list if len(frames) > 1 else []

    if args.get("trendline") and args.get("trendline_scope", "trace") == "overall":
        trendline_spec = make_trendline_spec(args, constructor)
        trendline_trace = trendline_spec.constructor(
            name="Overall Trendline", legendgroup="Overall Trendline", showlegend=False
        )
        if "line" not in trendline_spec.trace_patch:  # no color override
            for m in grouped_mappings:
                if m.variable == "color":
                    next_color = m.sequence[len(m.val_map) % len(m.sequence)]
                    trendline_spec.trace_patch["line"] = dict(color=next_color)
        patch, fit_results = make_trace_kwargs(
            args, trendline_spec, args["data_frame"], {}, sizeref
        )
        trendline_trace.update(patch)
        fig.add_trace(
            trendline_trace, row="all", col="all", exclude_empty_subplots=True
        )
        fig.update_traces(selector=-1, showlegend=True)
        if fit_results is not None:
            trendline_rows.append(dict(px_fit_results=fit_results))

    fig._px_trendlines = pd.DataFrame(trendline_rows)

    configure_axes(args, constructor, fig, orders)
    configure_animation_controls(args, constructor, fig)
    return fig
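
# make_figure is the shared backend behind the public px chart functions: each
# of them (e.g. px.scatter) gathers its keyword arguments into `args` and calls
# make_figure with the matching graph-object constructor (e.g. go.Scatter).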


def init_figure(args, subplot_type, frame_list, nrows, ncols, col_labels, row_labels):
    # Build subplot specs
    specs = [[dict(type=subplot_type or "domain")] * ncols for _ in range(nrows)]

    # Default to uniform column widths and row heights
    column_widths = [1.0] * ncols
    row_heights = [1.0] * nrows
    facet_col_wrap = args.get("facet_col_wrap", 0)

    # Build column_widths/row_heights
    if subplot_type == "xy":
        if args.get("marginal_x") is not None:
            if args["marginal_x"] == "histogram" or ("color" in args and args["color"]):
                main_size = 0.74
            else:
                main_size = 0.84

            row_heights = [main_size] * (nrows - 1) + [1 - main_size]
            vertical_spacing = 0.01
        elif facet_col_wrap:
            vertical_spacing = args.get("facet_row_spacing") or 0.07
        else:
            vertical_spacing = args.get("facet_row_spacing") or 0.03

        if args.get("marginal_y") is not None:
            if args["marginal_y"] == "histogram" or ("color" in args and args["color"]):
                main_size = 0.74
            else:
                main_size = 0.84

            column_widths = [main_size] * (ncols - 1) + [1 - main_size]
            horizontal_spacing = 0.005
        else:
            horizontal_spacing = args.get("facet_col_spacing") or 0.02
    else:
        # Other subplot types:
        #   'scene', 'geo', 'polar', 'ternary', 'mapbox', 'domain', None
        #
        # We can customize subplot spacing per type once we enable faceting
        # for all plot types
        if facet_col_wrap:
            vertical_spacing = args.get("facet_row_spacing") or 0.07
        else:
            vertical_spacing = args.get("facet_row_spacing") or 0.03
        horizontal_spacing = args.get("facet_col_spacing") or 0.02

    if facet_col_wrap:
        subplot_labels = [None] * nrows * ncols
        while len(col_labels) < nrows * ncols:
            col_labels.append(None)
        for i in range(nrows):
            for j in range(ncols):
                subplot_labels[i * ncols + j] = col_labels[(nrows - 1 - i) * ncols + j]
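        # The (nrows - 1 - i) flip accounts for start_cell="bottom-left" in
        # make_subplots below: grid row 0 is the bottom row, while the wrapped
        # facet labels were generated top-to-bottom.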

    def _spacing_error_translator(e, direction, facet_arg):
        """
        Translates the spacing errors thrown by the underlying make_subplots
        routine into one that describes an argument adjustable through px.
        """
        if ("%s spacing" % (direction,)) in e.args[0]:
            e.args = (
                e.args[0]
                + """
Use the {facet_arg} argument to adjust this spacing.""".format(
                    facet_arg=facet_arg
                ),
            )
            raise e

    # Create figure with subplots
    try:
        fig = make_subplots(
            rows=nrows,
            cols=ncols,
            specs=specs,
            shared_xaxes="all",
            shared_yaxes="all",
            row_titles=[] if facet_col_wrap else list(reversed(row_labels)),
            column_titles=[] if facet_col_wrap else col_labels,
            subplot_titles=subplot_labels if facet_col_wrap else [],
            horizontal_spacing=horizontal_spacing,
            vertical_spacing=vertical_spacing,
            row_heights=row_heights,
            column_widths=column_widths,
            start_cell="bottom-left",
        )
    except ValueError as e:
        _spacing_error_translator(e, "Horizontal", "facet_col_spacing")
        _spacing_error_translator(e, "Vertical", "facet_row_spacing")

    # Remove explicit font size of row/col titles so template can take over
    for annot in fig.layout.annotations:
        annot.update(font=None)

    return fig
