From b848853a7d58ceccde6e24b649a47bac2b3247b5 Mon Sep 17 00:00:00 2001 From: jeff1010322 Date: Fri, 1 Nov 2024 14:29:05 -0400 Subject: [PATCH] fix: func reflect impl for named functions, updated tests Co-authored-by: Ethan Lewis --- src/reflect/all_test.go | 4 +++ src/reflect/type.go | 57 ++++++++++++++++++++++++++++++--------- src/reflect/type_test.go | 2 ++ src/reflect/value_test.go | 53 ++++++++++++++++++++++++++++++++++++ 4 files changed, 104 insertions(+), 12 deletions(-) diff --git a/src/reflect/all_test.go b/src/reflect/all_test.go index 436bc00341..da30d314a5 100644 --- a/src/reflect/all_test.go +++ b/src/reflect/all_test.go @@ -3312,6 +3312,8 @@ func TestMethodPkgPath(t *testing.T) { } } +*/ + func TestVariadicType(t *testing.T) { // Test example from Type documentation. var f func(x int, y ...float64) @@ -3335,6 +3337,8 @@ func TestVariadicType(t *testing.T) { t.Error(s) } +/* + type inner struct { x int } diff --git a/src/reflect/type.go b/src/reflect/type.go index 315dbe51e7..97f43e6efb 100644 --- a/src/reflect/type.go +++ b/src/reflect/type.go @@ -608,14 +608,22 @@ func (t *rawType) String() string { s += " }" return s case Func: + isVariadic := t.IsVariadic() f := "func(" for i := 0; i < t.NumIn(); i++ { if i > 0 { f += ", " } - f += t.In(i).String() + + input := t.In(i).String() + if isVariadic && i == t.NumIn()-1 { + f += "..." + input = input[2:] + } + f += input } + f += ") " var rets string @@ -1069,14 +1077,24 @@ func (t *rawType) ConvertibleTo(u Type) bool { } func (t *rawType) IsVariadic() bool { - // need to test if bool mapped to int set by compiler + if t.isNamed() { + named := (*namedType)(unsafe.Pointer(t)) + t = named.elem + } + if t.Kind() != Func { panic("reflect: IsVariadic of non-func type") } + return (*funcType)(unsafe.Pointer(t)).variadic } func (t *rawType) NumIn() int { + if t.isNamed() { + named := (*namedType)(unsafe.Pointer(t)) + return int((*funcType)(unsafe.Pointer(named.elem)).inCount) + } + if t.Kind() != Func { panic("reflect: NumIn of non-func type") } @@ -1084,6 +1102,11 @@ func (t *rawType) NumIn() int { } func (t *rawType) NumOut() int { + if t.isNamed() { + named := (*namedType)(unsafe.Pointer(t)) + return int((*funcType)(unsafe.Pointer(named.elem)).outCount) + } + if t.Kind() != Func { panic("reflect: NumOut of non-func type") } @@ -1163,33 +1186,43 @@ func addChecked(p unsafe.Pointer, x uintptr, whySafe string) unsafe.Pointer { } func (t *rawType) In(i int) Type { + if t.isNamed() { + named := (*namedType)(unsafe.Pointer(t)) + t = named.elem + } + if t.Kind() != Func { panic(errTypeField) } - descriptor := (*funcType)(unsafe.Pointer(t.underlying())) - if uint(i) >= uint(descriptor.inCount) { + fType := (*funcType)(unsafe.Pointer(t)) + if uint(i) >= uint(fType.inCount) { panic("reflect: field index out of range") } - pointer := (unsafe.Add(unsafe.Pointer(&descriptor.fields[0]), uintptr(i)*unsafe.Sizeof(unsafe.Pointer(nil)))) - return (*rawType)(*(**rawType)(pointer)) + pointer := (unsafe.Add(unsafe.Pointer(&fType.fields[0]), uintptr(i)*unsafe.Sizeof(unsafe.Pointer(nil)))) + return (*(**rawType)(pointer)) } func (t *rawType) Out(i int) Type { + if t.isNamed() { + named := (*namedType)(unsafe.Pointer(t)) + t = named.elem + } + if t.Kind() != Func { panic(errTypeField) } - descriptor := (*funcType)(unsafe.Pointer(t.underlying())) - if uint(i) >= uint(descriptor.outCount) { + fType := (*funcType)(unsafe.Pointer(t)) + + if uint(i) >= uint(fType.outCount) { panic("reflect: field index out of range") } // Shift the index by the number of input parameters. - i = i + int((*funcType)(unsafe.Pointer(t)).inCount) - - pointer := (unsafe.Add(unsafe.Pointer(&descriptor.fields[0]), uintptr(i)*unsafe.Sizeof(unsafe.Pointer(nil)))) - return (*rawType)(*(**rawType)(pointer)) + i = i + int(fType.inCount) + pointer := (unsafe.Add(unsafe.Pointer(&fType.fields[0]), uintptr(i)*unsafe.Sizeof(unsafe.Pointer(nil)))) + return (*(**rawType)(pointer)) } // OverflowComplex reports whether the complex128 x cannot be represented by type t. diff --git a/src/reflect/type_test.go b/src/reflect/type_test.go index 75784f9666..dd54b0175c 100644 --- a/src/reflect/type_test.go +++ b/src/reflect/type_test.go @@ -13,6 +13,7 @@ func TestTypeFor(t *testing.T) { type ( mystring string myiface interface{} + myfunc func() ) testcases := []struct { @@ -25,6 +26,7 @@ func TestTypeFor(t *testing.T) { {new(mystring), reflect.TypeFor[mystring]()}, {new(any), reflect.TypeFor[any]()}, {new(myiface), reflect.TypeFor[myiface]()}, + {new(myfunc), reflect.TypeFor[myfunc]()}, } for _, tc := range testcases { want := reflect.ValueOf(tc.wantFrom).Elem().Type() diff --git a/src/reflect/value_test.go b/src/reflect/value_test.go index 508b358ad9..2ef4f0f35d 100644 --- a/src/reflect/value_test.go +++ b/src/reflect/value_test.go @@ -487,6 +487,59 @@ func TestTinyStruct(t *testing.T) { } } +func TestTinyFunc(t *testing.T) { + type barStruct struct { + QuxString string + BazInt int + } + + type foobar func(bar barStruct, x int, v ...string) string + + var fb foobar + + reffb := TypeOf(fb) + + numIn := reffb.NumIn() + if want := 3; numIn != want { + t.Errorf("NumIn=%v, want %v", numIn, want) + } + + numOut := reffb.NumOut() + if want := 1; numOut != want { + t.Errorf("NumOut=%v, want %v", numOut, want) + } + + in0 := reffb.In(0) + if want := TypeOf(barStruct{}); in0 != want { + t.Errorf("In(0)=%v, want %v", in0, want) + } + + in1 := reffb.In(1) + if want := TypeOf(0); in1 != want { + t.Errorf("In(1)=%v, want %v", in1, want) + } + + in2 := reffb.In(2) + if want := TypeOf([]string{}); in2 != want { + t.Errorf("In(2)=%v, want %v", in2, want) + } + + out0 := reffb.Out(0) + if want := TypeOf(""); out0 != want { + t.Errorf("Out(0)=%v, want %v", out0, want) + } + + isVariadic := reffb.IsVariadic() + if want := true; isVariadic != want { + t.Errorf("IsVariadic=%v, want %v", isVariadic, want) + } + + if got, want := reffb.String(), "reflect_test.foobar"; got != want { + t.Errorf("Value.String()=%v, want %v", got, want) + } + +} + func TestTinyZero(t *testing.T) { s := "hello, world" sptr := &s