Skip to content

Commit f264623

Browse files
authored
Merge pull request #43 from manta1130/feature/twosat
Implement twosat
2 parents 5f6bdac + d6a667c commit f264623

File tree

2 files changed

+114
-0
lines changed

2 files changed

+114
-0
lines changed

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,4 @@ pub use string::{
3232
lcp_array, lcp_array_arbitrary, suffix_array, suffix_array_arbitrary, suffix_array_manual,
3333
z_algorithm, z_algorithm_arbitrary,
3434
};
35+
pub use twosat::TwoSat;

src/twosat.rs

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,114 @@
1+
use crate::internal_scc;
12

3+
pub struct TwoSat {
4+
n: usize,
5+
scc: internal_scc::SccGraph,
6+
answer: Vec<bool>,
7+
}
8+
impl TwoSat {
9+
pub fn new(n: usize) -> Self {
10+
TwoSat {
11+
n,
12+
answer: vec![false; n],
13+
scc: internal_scc::SccGraph::new(2 * n),
14+
}
15+
}
16+
pub fn add_clause(&mut self, i: usize, f: bool, j: usize, g: bool) {
17+
assert!(i < self.n && j < self.n);
18+
self.scc.add_edge(2 * i + !f as usize, 2 * j + g as usize);
19+
self.scc.add_edge(2 * j + !g as usize, 2 * i + f as usize);
20+
}
21+
pub fn satisfiable(&mut self) -> bool {
22+
let id = self.scc.scc_ids().1;
23+
for i in 0..self.n {
24+
if id[2 * i] == id[2 * i + 1] {
25+
return false;
26+
}
27+
self.answer[i] = id[2 * i] < id[2 * i + 1];
28+
}
29+
true
30+
}
31+
pub fn answer(&self) -> &[bool] {
32+
&self.answer
33+
}
34+
}
35+
36+
#[cfg(test)]
37+
mod tests {
38+
#![allow(clippy::many_single_char_names)]
39+
use super::*;
40+
#[test]
41+
fn solve_alpc_h_sample1() {
42+
// https://atcoder.jp/contests/practice2/tasks/practice2_h
43+
44+
let (n, d) = (3, 2);
45+
let x = [1, 2, 0i32];
46+
let y = [4, 5, 6];
47+
48+
let mut t = TwoSat::new(n);
49+
50+
for i in 0..n {
51+
for j in i + 1..n {
52+
if (x[i] - x[j]).abs() < d {
53+
t.add_clause(i, false, j, false);
54+
}
55+
if (x[i] - y[j]).abs() < d {
56+
t.add_clause(i, false, j, true);
57+
}
58+
if (y[i] - x[j]).abs() < d {
59+
t.add_clause(i, true, j, false);
60+
}
61+
if (y[i] - y[j]).abs() < d {
62+
t.add_clause(i, true, j, true);
63+
}
64+
}
65+
}
66+
assert!(t.satisfiable());
67+
let answer = t.answer();
68+
let mut res = vec![];
69+
for (i, &v) in answer.iter().enumerate() {
70+
if v {
71+
res.push(x[i])
72+
} else {
73+
res.push(y[i]);
74+
}
75+
}
76+
77+
//Check the min distance between flags
78+
res.sort();
79+
let mut min_distance = i32::max_value();
80+
for i in 1..res.len() {
81+
min_distance = std::cmp::min(min_distance, res[i] - res[i - 1]);
82+
}
83+
assert!(min_distance >= d);
84+
}
85+
86+
#[test]
87+
fn solve_alpc_h_sample2() {
88+
// https://atcoder.jp/contests/practice2/tasks/practice2_h
89+
90+
let (n, d) = (3, 3);
91+
let x = [1, 2, 0i32];
92+
let y = [4, 5, 6];
93+
94+
let mut t = TwoSat::new(n);
95+
96+
for i in 0..n {
97+
for j in i + 1..n {
98+
if (x[i] - x[j]).abs() < d {
99+
t.add_clause(i, false, j, false);
100+
}
101+
if (x[i] - y[j]).abs() < d {
102+
t.add_clause(i, false, j, true);
103+
}
104+
if (y[i] - x[j]).abs() < d {
105+
t.add_clause(i, true, j, false);
106+
}
107+
if (y[i] - y[j]).abs() < d {
108+
t.add_clause(i, true, j, true);
109+
}
110+
}
111+
}
112+
assert!(!t.satisfiable());
113+
}
114+
}

0 commit comments

Comments
 (0)