Skip to content

[mypy] Fixes typing errors in other/dpll #5759

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 3, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 18 additions & 17 deletions other/davisb_putnamb_logemannb_loveland.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from __future__ import annotations

import random
from typing import Iterable


class Clause:
Expand All @@ -27,12 +28,12 @@ class Clause:
True
"""

def __init__(self, literals: list[int]) -> None:
def __init__(self, literals: list[str]) -> None:
"""
Represent the literals and an assignment in a clause."
"""
# Assign all literals to None initially
self.literals = {literal: None for literal in literals}
self.literals: dict[str, bool | None] = {literal: None for literal in literals}

def __str__(self) -> str:
"""
Expand All @@ -52,7 +53,7 @@ def __len__(self) -> int:
"""
return len(self.literals)

def assign(self, model: dict[str, bool]) -> None:
def assign(self, model: dict[str, bool | None]) -> None:
"""
Assign values to literals of the clause as given by model.
"""
Expand All @@ -68,7 +69,7 @@ def assign(self, model: dict[str, bool]) -> None:
value = not value
self.literals[literal] = value

def evaluate(self, model: dict[str, bool]) -> bool:
def evaluate(self, model: dict[str, bool | None]) -> bool | None:
"""
Evaluates the clause with the assignments in model.
This has the following steps:
Expand Down Expand Up @@ -97,7 +98,7 @@ class Formula:
{{A1, A2, A3'}, {A5', A2', A1}} is ((A1 v A2 v A3') and (A5' v A2' v A1))
"""

def __init__(self, clauses: list[Clause]) -> None:
def __init__(self, clauses: Iterable[Clause]) -> None:
"""
Represent the number of clauses and the clauses themselves.
"""
Expand Down Expand Up @@ -139,14 +140,14 @@ def generate_formula() -> Formula:
"""
Randomly generate a formula.
"""
clauses = set()
clauses: set[Clause] = set()
no_of_clauses = random.randint(1, 10)
while len(clauses) < no_of_clauses:
clauses.add(generate_clause())
return Formula(set(clauses))
return Formula(clauses)


def generate_parameters(formula: Formula) -> (list[Clause], list[str]):
def generate_parameters(formula: Formula) -> tuple[list[Clause], list[str]]:
"""
Return the clauses and symbols from a formula.
A symbol is the uncomplemented form of a literal.
Expand All @@ -173,8 +174,8 @@ def generate_parameters(formula: Formula) -> (list[Clause], list[str]):


def find_pure_symbols(
clauses: list[Clause], symbols: list[str], model: dict[str, bool]
) -> (list[str], dict[str, bool]):
clauses: list[Clause], symbols: list[str], model: dict[str, bool | None]
) -> tuple[list[str], dict[str, bool | None]]:
"""
Return pure symbols and their values to satisfy clause.
Pure symbols are symbols in a formula that exist only
Expand All @@ -198,11 +199,11 @@ def find_pure_symbols(
{'A1': True, 'A2': False, 'A3': True, 'A5': False}
"""
pure_symbols = []
assignment = dict()
assignment: dict[str, bool | None] = dict()
literals = []

for clause in clauses:
if clause.evaluate(model) is True:
if clause.evaluate(model):
continue
for literal in clause.literals:
literals.append(literal)
Expand All @@ -225,8 +226,8 @@ def find_pure_symbols(


def find_unit_clauses(
clauses: list[Clause], model: dict[str, bool]
) -> (list[str], dict[str, bool]):
clauses: list[Clause], model: dict[str, bool | None]
) -> tuple[list[str], dict[str, bool | None]]:
"""
Returns the unit symbols and their values to satisfy clause.
Unit symbols are symbols in a formula that are:
Expand Down Expand Up @@ -263,7 +264,7 @@ def find_unit_clauses(
Ncount += 1
if Fcount == len(clause) - 1 and Ncount == 1:
unit_symbols.append(sym)
assignment = dict()
assignment: dict[str, bool | None] = dict()
for i in unit_symbols:
symbol = i[:2]
assignment[symbol] = len(i) == 2
Expand All @@ -273,8 +274,8 @@ def find_unit_clauses(


def dpll_algorithm(
clauses: list[Clause], symbols: list[str], model: dict[str, bool]
) -> (bool, dict[str, bool]):
clauses: list[Clause], symbols: list[str], model: dict[str, bool | None]
) -> tuple[bool | None, dict[str, bool | None] | None]:
"""
Returns the model if the formula is satisfiable, else None
This has the following steps:
Expand Down