diff --git a/plotly_express/_chart_types.py b/plotly_express/_chart_types.py index c6f2fc0..c44eb09 100644 --- a/plotly_express/_chart_types.py +++ b/plotly_express/_chart_types.py @@ -4,7 +4,7 @@ def scatter( - data_frame, + data_frame=None, x=None, y=None, color=None, @@ -56,7 +56,7 @@ def scatter( def density_contour( - data_frame, + data_frame=None, x=None, y=None, z=None, @@ -166,7 +166,7 @@ def density_heatmap( def line( - data_frame, + data_frame=None, x=None, y=None, line_group=None, @@ -210,7 +210,7 @@ def line( def area( - data_frame, + data_frame=None, x=None, y=None, line_group=None, @@ -254,7 +254,7 @@ def area( def bar( - data_frame, + data_frame=None, x=None, y=None, color=None, @@ -303,7 +303,7 @@ def bar( def histogram( - data_frame, + data_frame=None, x=None, y=None, color=None, @@ -360,7 +360,7 @@ def histogram( def violin( - data_frame, + data_frame=None, x=None, y=None, color=None, @@ -410,7 +410,7 @@ def violin( def box( - data_frame, + data_frame=None, x=None, y=None, color=None, @@ -504,7 +504,7 @@ def strip( def scatter_3d( - data_frame, + data_frame=None, x=None, y=None, z=None, @@ -554,7 +554,7 @@ def scatter_3d( def line_3d( - data_frame, + data_frame=None, x=None, y=None, z=None, @@ -599,7 +599,7 @@ def line_3d( def scatter_ternary( - data_frame, + data_frame=None, a=None, b=None, c=None, @@ -637,7 +637,7 @@ def scatter_ternary( def line_ternary( - data_frame, + data_frame=None, a=None, b=None, c=None, @@ -671,7 +671,7 @@ def line_ternary( def scatter_polar( - data_frame, + data_frame=None, r=None, theta=None, color=None, @@ -714,7 +714,7 @@ def scatter_polar( def line_polar( - data_frame, + data_frame=None, r=None, theta=None, color=None, @@ -753,7 +753,7 @@ def line_polar( def bar_polar( - data_frame, + data_frame=None, r=None, theta=None, color=None, @@ -790,7 +790,7 @@ def bar_polar( def choropleth( - data_frame, + data_frame=None, lat=None, lon=None, locations=None, @@ -829,7 +829,7 @@ def choropleth( def scatter_geo( - data_frame, + data_frame=None, lat=None, lon=None, locations=None, @@ -872,7 +872,7 @@ def scatter_geo( def line_geo( - data_frame, + data_frame=None, lat=None, lon=None, locations=None, @@ -913,7 +913,7 @@ def line_geo( def scatter_mapbox( - data_frame, + data_frame=None, lat=None, lon=None, color=None, @@ -948,7 +948,7 @@ def scatter_mapbox( def line_mapbox( - data_frame, + data_frame=None, lat=None, lon=None, color=None, @@ -978,7 +978,7 @@ def line_mapbox( def scatter_matrix( - data_frame, + data_frame=None, dimensions=None, color=None, symbol=None, @@ -1015,7 +1015,7 @@ def scatter_matrix( def parallel_coordinates( - data_frame, + data_frame=None, dimensions=None, color=None, labels={}, @@ -1039,7 +1039,7 @@ def parallel_coordinates( def parallel_categories( - data_frame, + data_frame=None, dimensions=None, color=None, labels={}, diff --git a/plotly_express/_core.py b/plotly_express/_core.py index 339734e..a26094d 100644 --- a/plotly_express/_core.py +++ b/plotly_express/_core.py @@ -659,6 +659,28 @@ def apply_default_cascade(args): if args["color_discrete_sequence"] is None: args["color_discrete_sequence"] = qualitative.Plotly +def has_value(collection, key): + return collection.get(key, None) is not None + +def build_dataframe(args, attrables): + """ + Constructs an implicit dataframe and modifies `args` in-place. + + `attrables` is a list of keys into `args`, all of whose corresponding + values are converted into columns of a dataframe. + + Used to be support calls to plotting function that elide a dataframe argument; + for example `scatter(x=[1,2], y=[3,4])`. + """ + data_frame_columns = {} + for field in attrables: + if not has_value(args, field): + continue + data_frame_columns[field] = args[field] + # This sets the label of an attribute to be the name of the attribute. + args[field] = field + args["data_frame"] = pandas.DataFrame(data_frame_columns) + return args def infer_config(args, constructor, trace_patch): attrables = ( @@ -669,11 +691,12 @@ def infer_config(args, constructor, trace_patch): ) array_attrables = ["dimensions", "hover_data"] group_attrables = ["animation_frame", "facet_row", "facet_col", "line_group"] - + all_attrables = attrables + group_attrables + ["color"] + if not has_value(args, "data_frame"): + build_dataframe(args, all_attrables) df_columns = args["data_frame"].columns - - for attr in attrables + group_attrables + ["color"]: - if attr in args and args[attr] is not None: + for attr in all_attrables: + if has_value(args, attr): maybe_col_list = [args[attr]] if attr not in array_attrables else args[attr] for maybe_col in maybe_col_list: try: @@ -790,7 +813,6 @@ def get_orderings(args, grouper, grouped): return orders, group_names - def make_figure(args, constructor, trace_patch={}, layout_patch={}): apply_default_cascade(args) trace_specs, grouped_mappings, sizeref, show_colorbar = infer_config(