diff --git a/R/layers2traces.R b/R/layers2traces.R
index 0387d34c5c..f85b1d8f3e 100644
--- a/R/layers2traces.R
+++ b/R/layers2traces.R
@@ -613,6 +613,35 @@ to_basic.GeomQuantile <- function(data, prestats_data, layout, params, p, ...){
dat
}
+#' @export
+to_basic.GeomStratum <- function (data, ...) {
+ to_basic.GeomRect(data, ...)
+}
+
+#' @export
+to_basic.GeomAlluvium <- function (data, ...) {
+ prefix_class(transform_alluvium(data), "GeomPolygon")
+}
+
+# transform the alluvium data into the corresponding polygons
+transform_alluvium <- function(data) {
+ data <- data[order(data$x), ]
+
+ if(unique(data$colour) == 0) data$colour <- NULL
+
+ unused_aes <- ! names(data) %in% c("x", "y", "ymin", "ymax")
+
+ row_number <- nrow(data)
+
+ data_rev <- data[rev(seq_len(row_number)), ]
+
+ structure(rbind(
+ cbind(x = data$x, y = data$ymin, data[unused_aes]),
+ cbind(x = data$x[row_number], y = data$ymin[row_number], data[row_number, unused_aes]),
+ cbind(x = data_rev$x, y = data_rev$ymax, data_rev[unused_aes])
+ ), class = class(data))
+}
+
#' @export
to_basic.default <- function(data, prestats_data, layout, params, p, ...) {
data
diff --git a/tests/testthat/_snaps/ggalluvial/alluvium.svg b/tests/testthat/_snaps/ggalluvial/alluvium.svg
new file mode 100644
index 0000000000..cfc05b43fa
--- /dev/null
+++ b/tests/testthat/_snaps/ggalluvial/alluvium.svg
@@ -0,0 +1 @@
+
diff --git a/tests/testthat/_snaps/ggalluvial/stratum-alluvium.svg b/tests/testthat/_snaps/ggalluvial/stratum-alluvium.svg
new file mode 100644
index 0000000000..f50c0f8437
--- /dev/null
+++ b/tests/testthat/_snaps/ggalluvial/stratum-alluvium.svg
@@ -0,0 +1 @@
+
diff --git a/tests/testthat/_snaps/ggalluvial/stratum.svg b/tests/testthat/_snaps/ggalluvial/stratum.svg
new file mode 100644
index 0000000000..e70d5d2762
--- /dev/null
+++ b/tests/testthat/_snaps/ggalluvial/stratum.svg
@@ -0,0 +1 @@
+
diff --git a/tests/testthat/test-ggalluvial.R b/tests/testthat/test-ggalluvial.R
new file mode 100644
index 0000000000..a0a87d74a9
--- /dev/null
+++ b/tests/testthat/test-ggalluvial.R
@@ -0,0 +1,46 @@
+library(ggalluvial)
+
+test_that("using both of `geom_alluvium` and `geom_stratum` gives the correct output", {
+ p <- ggplot(as.data.frame(Titanic),
+ aes(y = Freq,
+ axis1 = Survived, axis2 = Sex, axis3 = Class)) +
+ geom_alluvium(aes(fill = Class),
+ width = 0, knot.pos = 0, reverse = FALSE) +
+ guides(fill = "none") +
+ geom_stratum(width = 1/8, reverse = FALSE) +
+ geom_text(stat = "stratum", aes(label = after_stat(stratum)),
+ reverse = FALSE) +
+ scale_x_continuous(breaks = 1:3, labels = c("Survived", "Sex", "Class")) +
+ coord_flip() +
+ ggtitle("Titanic survival by class and sex")
+ # write_plotly_svg(p, "tests/testthat/_snaps/ggalluvial/stratum-alluvium.svg")
+ expect_doppelganger(ggplotly(p), "stratum-alluvium")
+})
+
+test_that("using `geom_stratum` gives the correct output", {
+ p <- ggplot(as.data.frame(Titanic),
+ aes(y = Freq,
+ axis1 = Survived, axis2 = Sex, axis3 = Class)) +
+ geom_stratum(width = 1/8, reverse = FALSE) +
+ geom_text(stat = "stratum", aes(label = after_stat(stratum)),
+ reverse = FALSE) +
+ scale_x_continuous(breaks = 1:3, labels = c("Survived", "Sex", "Class")) +
+ coord_flip() +
+ ggtitle("Titanic survival by class and sex")
+ #write_plotly_svg(p, "tests/testthat/_snaps/ggalluvial/stratum.svg")
+ expect_doppelganger(ggplotly(p), "stratum")
+})
+
+test_that("using `geom_alluvium` gives the correct output", {
+ p <- ggplot(as.data.frame(Titanic),
+ aes(y = Freq,
+ axis1 = Survived, axis2 = Sex, axis3 = Class)) +
+ geom_alluvium(aes(fill = Class),
+ width = 0, knot.pos = 0, reverse = FALSE) +
+ guides(fill = "none") +
+ scale_x_continuous(breaks = 1:3, labels = c("Survived", "Sex", "Class")) +
+ coord_flip() +
+ ggtitle("Titanic survival by class and sex")
+ #write_plotly_svg(p, "tests/testthat/_snaps/ggalluvial/alluvium.svg")
+ expect_doppelganger(ggplotly(p), "alluvium")
+})
\ No newline at end of file