diff --git a/src/libfuturize/fixes/__init__.py b/src/libfuturize/fixes/__init__.py index 7de304da..0b562501 100644 --- a/src/libfuturize/fixes/__init__.py +++ b/src/libfuturize/fixes/__init__.py @@ -50,7 +50,7 @@ 'lib2to3.fixes.fix_getcwdu', # 'lib2to3.fixes.fix_imports', # called by libfuturize.fixes.fix_future_standard_library # 'lib2to3.fixes.fix_imports2', # we don't handle this yet (dbm) - 'lib2to3.fixes.fix_input', + # 'lib2to3.fixes.fix_input', # Called conditionally by libfuturize.fixes.fix_input 'lib2to3.fixes.fix_itertools', 'lib2to3.fixes.fix_itertools_imports', 'lib2to3.fixes.fix_filter', @@ -86,6 +86,7 @@ 'libfuturize.fixes.fix_future_builtins', 'libfuturize.fixes.fix_future_standard_library', 'libfuturize.fixes.fix_future_standard_library_urllib', + 'libfuturize.fixes.fix_input', 'libfuturize.fixes.fix_metaclass', 'libpasteurize.fixes.fix_newstyle', 'libfuturize.fixes.fix_object', diff --git a/src/libfuturize/fixes/fix_input.py b/src/libfuturize/fixes/fix_input.py new file mode 100644 index 00000000..8a43882e --- /dev/null +++ b/src/libfuturize/fixes/fix_input.py @@ -0,0 +1,32 @@ +""" +Fixer for input. + +Does a check for `from builtins import input` before running the lib2to3 fixer. +The fixer will not run when the input is already present. + + +this: + a = input() +becomes: + from builtins import input + a = eval(input()) + +and this: + from builtins import input + a = input() +becomes (no change): + from builtins import input + a = input() +""" + +import lib2to3.fixes.fix_input +from lib2to3.fixer_util import does_tree_import + + +class FixInput(lib2to3.fixes.fix_input.FixInput): + def transform(self, node, results): + + if does_tree_import('builtins', 'input', node): + return + + return super(FixInput, self).transform(node, results) diff --git a/tests/test_future/test_futurize.py b/tests/test_future/test_futurize.py index f2201141..0d7c42de 100644 --- a/tests/test_future/test_futurize.py +++ b/tests/test_future/test_futurize.py @@ -436,6 +436,27 @@ def test_import_builtins(self): """ self.convert_check(before, after, ignore_imports=False, run=False) + def test_input_without_import(self): + before = """ + a = input() + """ + after = """ + from builtins import input + a = eval(input()) + """ + self.convert_check(before, after, ignore_imports=False, run=False) + + def test_input_with_import(self): + before = """ + from builtins import input + a = input() + """ + after = """ + from builtins import input + a = input() + """ + self.convert_check(before, after, ignore_imports=False, run=False) + def test_xrange(self): """ The ``from builtins import range`` line was being added to the