diff --git a/tests/test_cli.py b/tests/test_cli.py index 1d5882fe5ce..f688fcaf0fd 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -491,6 +491,14 @@ def test_import_tmuxinator(cli_args, inputs, tmpdir, monkeypatch): ['\n', 'y\n', './la.yaml\n', 'y\n'], ), (['freeze'], ['\n', 'y\n', './exists.yaml\n', './la.yaml\n', 'y\n']), # Exists + ( # Create a new one + ['freeze', 'mysession', '--force'], + ['\n', 'y\n', './la.yaml\n', 'y\n'] + ), + ( # Imply current session if not entered + ['freeze', '--force'], + ['\n', 'y\n', './la.yaml\n', 'y\n'], + ), ], ) def test_freeze(server, cli_args, inputs, tmpdir, monkeypatch): @@ -508,6 +516,34 @@ def test_freeze(server, cli_args, inputs, tmpdir, monkeypatch): assert tmpdir.join('la.yaml').check() +@pytest.mark.parametrize( + "cli_args,inputs", + [ + ( # Overwrite + ['freeze', 'mysession', '--force'], + ['\n', 'y\n', './exists.yaml\n', 'y\n'], + ), + ( # Imply current session if not entered + ['freeze', '--force'], + ['\n', 'y\n', './exists.yaml\n', 'y\n'] + ), + ], +) +def test_freeze_overwrite(server, cli_args, inputs, tmpdir, monkeypatch): + monkeypatch.setenv('HOME', str(tmpdir)) + tmpdir.join('exists.yaml').ensure() + + server.new_session(session_name='mysession') + + with tmpdir.as_cwd(): + runner = CliRunner() + # Use tmux server (socket name) used in the test + cli_args = cli_args + ['-L', server.socket_name] + out = runner.invoke(cli.cli, cli_args, input=''.join(inputs)) + print(out.output) + assert tmpdir.join('exists.yaml').check() + + def test_get_abs_path(tmpdir): expect = str(tmpdir) with tmpdir.as_cwd(): diff --git a/tmuxp/cli.py b/tmuxp/cli.py index 35374493568..0af95f25fd9 100644 --- a/tmuxp/cli.py +++ b/tmuxp/cli.py @@ -659,7 +659,8 @@ def startup(config_dir): @click.argument('session_name', nargs=1, required=False) @click.option('-S', 'socket_path', help='pass-through for tmux -S') @click.option('-L', 'socket_name', help='pass-through for tmux -L') -def command_freeze(session_name, socket_name, socket_path): +@click.option('--force', 'force', help='overwrite the config file', is_flag=True) +def command_freeze(session_name, socket_name, socket_path, force): """Snapshot a session into a config. If SESSION_NAME is provided, snapshot that session. Otherwise, use the @@ -716,7 +717,7 @@ def command_freeze(session_name, socket_name, socket_path): dest_prompt = click.prompt( 'Save to: %s' % save_to, value_proc=get_abs_path, default=save_to ) - if os.path.exists(dest_prompt): + if not force and os.path.exists(dest_prompt): print('%s exists. Pick a new filename.' % dest_prompt) continue