@@ -539,6 +539,226 @@ func TestDisconnect(t *testing.T) {
539
539
}
540
540
}
541
541
542
+ type mockKeyingTransport struct {
543
+ packetConn
544
+ kexInitAllowed chan struct {}
545
+ kexInitSent chan struct {}
546
+ }
547
+
548
+ func (n * mockKeyingTransport ) prepareKeyChange (* algorithms , * kexResult ) error {
549
+ return nil
550
+ }
551
+
552
+ func (n * mockKeyingTransport ) writePacket (packet []byte ) error {
553
+ if packet [0 ] == msgKexInit {
554
+ <- n .kexInitAllowed
555
+ n .kexInitSent <- struct {}{}
556
+ }
557
+ return n .packetConn .writePacket (packet )
558
+ }
559
+
560
+ func (n * mockKeyingTransport ) readPacket () ([]byte , error ) {
561
+ return n .packetConn .readPacket ()
562
+ }
563
+
564
+ func (n * mockKeyingTransport ) setStrictMode () error { return nil }
565
+
566
+ func (n * mockKeyingTransport ) setInitialKEXDone () {}
567
+
568
+ func TestHandshakePendingPacketsWait (t * testing.T ) {
569
+ a , b := memPipe ()
570
+
571
+ trS := & mockKeyingTransport {
572
+ packetConn : a ,
573
+ kexInitAllowed : make (chan struct {}, 2 ),
574
+ kexInitSent : make (chan struct {}, 2 ),
575
+ }
576
+ // Allow the first KEX.
577
+ trS .kexInitAllowed <- struct {}{}
578
+
579
+ trC := & mockKeyingTransport {
580
+ packetConn : b ,
581
+ kexInitAllowed : make (chan struct {}, 2 ),
582
+ kexInitSent : make (chan struct {}, 2 ),
583
+ }
584
+ // Allow the first KEX.
585
+ trC .kexInitAllowed <- struct {}{}
586
+
587
+ clientConf := & ClientConfig {
588
+ HostKeyCallback : InsecureIgnoreHostKey (),
589
+ }
590
+ clientConf .SetDefaults ()
591
+
592
+ v := []byte ("version" )
593
+ client := newClientTransport (trC , v , v , clientConf , "addr" , nil )
594
+
595
+ serverConf := & ServerConfig {}
596
+ serverConf .AddHostKey (testSigners ["ecdsa" ])
597
+ serverConf .AddHostKey (testSigners ["rsa" ])
598
+ serverConf .SetDefaults ()
599
+ server := newServerTransport (trS , v , v , serverConf )
600
+
601
+ if err := server .waitSession (); err != nil {
602
+ t .Fatalf ("server.waitSession: %v" , err )
603
+ }
604
+ if err := client .waitSession (); err != nil {
605
+ t .Fatalf ("client.waitSession: %v" , err )
606
+ }
607
+
608
+ <- trC .kexInitSent
609
+ <- trS .kexInitSent
610
+
611
+ // Allow and request new KEX server side.
612
+ trS .kexInitAllowed <- struct {}{}
613
+ server .requestKeyExchange ()
614
+ // Wait until the KEX init is sent.
615
+ <- trS .kexInitSent
616
+ // The client is not allowed to respond to the KEX, so writes will be
617
+ // blocked on the server side once the packets queue is full.
618
+ for i := 0 ; i < maxPendingPackets ; i ++ {
619
+ p := []byte {msgRequestSuccess , byte (i )}
620
+ if err := server .writePacket (p ); err != nil {
621
+ t .Errorf ("unexpected write error: %v" , err )
622
+ }
623
+ }
624
+ // The packets queue is now full, the next write will block.
625
+ server .mu .Lock ()
626
+ if len (server .pendingPackets ) != maxPendingPackets {
627
+ t .Errorf ("unexpected pending packets size; got: %d, want: %d" , len (server .pendingPackets ), maxPendingPackets )
628
+ }
629
+ server .mu .Unlock ()
630
+
631
+ writeDone := make (chan struct {})
632
+ go func () {
633
+ defer close (writeDone )
634
+
635
+ p := []byte {msgRequestSuccess , byte (65 )}
636
+ // This write will block until KEX completes.
637
+ err := server .writePacket (p )
638
+ if err != nil {
639
+ t .Errorf ("unexpected write error: %v" , err )
640
+ }
641
+ }()
642
+
643
+ // Consume packets on the client side
644
+ readDone := make (chan bool )
645
+ go func () {
646
+ defer close (readDone )
647
+
648
+ for {
649
+ if _ , err := client .readPacket (); err != nil {
650
+ if err != io .EOF {
651
+ t .Errorf ("unexpected read error: %v" , err )
652
+ }
653
+ break
654
+ }
655
+ }
656
+ }()
657
+
658
+ // Allow the client to reply to the KEX and so unblock the write goroutine.
659
+ trC .kexInitAllowed <- struct {}{}
660
+ <- trC .kexInitSent
661
+ <- writeDone
662
+ // Close the client to unblock the read goroutine.
663
+ client .Close ()
664
+ <- readDone
665
+ server .Close ()
666
+ }
667
+
668
+ func TestHandshakePendingPacketsError (t * testing.T ) {
669
+ a , b := memPipe ()
670
+
671
+ trS := & mockKeyingTransport {
672
+ packetConn : a ,
673
+ kexInitAllowed : make (chan struct {}, 2 ),
674
+ kexInitSent : make (chan struct {}, 2 ),
675
+ }
676
+ // Allow the first KEX.
677
+ trS .kexInitAllowed <- struct {}{}
678
+
679
+ trC := & mockKeyingTransport {
680
+ packetConn : b ,
681
+ kexInitAllowed : make (chan struct {}, 2 ),
682
+ kexInitSent : make (chan struct {}, 2 ),
683
+ }
684
+ // Allow the first KEX.
685
+ trC .kexInitAllowed <- struct {}{}
686
+
687
+ clientConf := & ClientConfig {
688
+ HostKeyCallback : InsecureIgnoreHostKey (),
689
+ }
690
+ clientConf .SetDefaults ()
691
+
692
+ v := []byte ("version" )
693
+ client := newClientTransport (trC , v , v , clientConf , "addr" , nil )
694
+
695
+ serverConf := & ServerConfig {}
696
+ serverConf .AddHostKey (testSigners ["ecdsa" ])
697
+ serverConf .AddHostKey (testSigners ["rsa" ])
698
+ serverConf .SetDefaults ()
699
+ server := newServerTransport (trS , v , v , serverConf )
700
+
701
+ if err := server .waitSession (); err != nil {
702
+ t .Fatalf ("server.waitSession: %v" , err )
703
+ }
704
+ if err := client .waitSession (); err != nil {
705
+ t .Fatalf ("client.waitSession: %v" , err )
706
+ }
707
+
708
+ <- trC .kexInitSent
709
+ <- trS .kexInitSent
710
+
711
+ // Allow and request new KEX server side.
712
+ trS .kexInitAllowed <- struct {}{}
713
+ server .requestKeyExchange ()
714
+ // Wait until the KEX init is sent.
715
+ <- trS .kexInitSent
716
+ // The client is not allowed to respond to the KEX, so writes will be
717
+ // blocked on the server side once the packets queue is full.
718
+ for i := 0 ; i < maxPendingPackets ; i ++ {
719
+ p := []byte {msgRequestSuccess , byte (i )}
720
+ if err := server .writePacket (p ); err != nil {
721
+ t .Errorf ("unexpected write error: %v" , err )
722
+ }
723
+ }
724
+ // The packets queue is now full, the next write will block.
725
+ writeDone := make (chan struct {})
726
+ go func () {
727
+ defer close (writeDone )
728
+
729
+ p := []byte {msgRequestSuccess , byte (65 )}
730
+ // This write will block until KEX completes.
731
+ err := server .writePacket (p )
732
+ if err != io .EOF {
733
+ t .Errorf ("unexpected write error: %v" , err )
734
+ }
735
+ }()
736
+
737
+ // Consume packets on the client side
738
+ readDone := make (chan bool )
739
+ go func () {
740
+ defer close (readDone )
741
+
742
+ for {
743
+ if _ , err := client .readPacket (); err != nil {
744
+ if err != io .EOF {
745
+ t .Errorf ("unexpected read error: %v" , err )
746
+ }
747
+ break
748
+ }
749
+ }
750
+ }()
751
+
752
+ // Close the server to unblock the write after an error
753
+ server .Close ()
754
+ <- writeDone
755
+ // Unblock the pending write and close the client to unblock the read
756
+ // goroutine.
757
+ trC .kexInitAllowed <- struct {}{}
758
+ client .Close ()
759
+ <- readDone
760
+ }
761
+
542
762
func TestHandshakeRekeyDefault (t * testing.T ) {
543
763
clientConf := & ClientConfig {
544
764
Config : Config {
0 commit comments