Skip to content
Merged
177 changes: 142 additions & 35 deletions ffi/callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func validateCallbackSignature(typ reflect.Type) {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Uintptr, reflect.Float32, reflect.Float64,
reflect.Pointer, reflect.UnsafePointer, reflect.Bool:
reflect.Pointer, reflect.UnsafePointer, reflect.Bool, reflect.Struct:
// Valid types
default:
panic("ffi: unsupported callback argument type: " + argType.Kind().String())
Expand Down Expand Up @@ -192,6 +192,46 @@ func callbackWrap(a *callbackArgs) {
var intIdx int // Current integer register index (0-5)
stackIdx := numFloatRegs + numIntRegs // Stack arguments start after registers

getFloat := func() float64 {
// Float64 comes from XMM register
if floatIdx < numFloatRegs {
bits := frame[floatIdx]
floatIdx++
return *(*float64)(unsafe.Pointer(&bits))
} else {
bits := frame[stackIdx]
stackIdx++
return *(*float64)(unsafe.Pointer(&bits))
}
}

getInt := func() uintptr {
var value uintptr
if intIdx < numIntRegs {
value = frame[numFloatRegs+intIdx]
intIdx++
} else {
// All register slots are used: value is on the stack
value = frame[stackIdx]
stackIdx++
}
return value
}

// writePartial stores exactly the low `size` bytes of value at dest (at most
// one machine word), so that a trailing partial eightbyte never writes past
// the end of the destination buffer.
//
// The previous width-rounded stores wrote 4 bytes for a 3-byte tail and
// 8 bytes for 5-7 byte tails, overrunning buffers sized to the struct.
//
// The low bytes are taken from the start of value's in-memory representation,
// which relies on a little-endian target (true for amd64).
writePartial := func(dest unsafe.Pointer, size uintptr, value uintptr) {
	n := min(size, unsafe.Sizeof(value))
	copy(unsafe.Slice((*byte)(dest), n), unsafe.Slice((*byte)(unsafe.Pointer(&value)), n))
}

// Build argument slice for reflection Call
args := make([]reflect.Value, numArgs)

Expand All @@ -201,34 +241,10 @@ func callbackWrap(a *callbackArgs) {

switch argType.Kind() {
case reflect.Float32:
// Float32 comes from XMM register (stored as float64)
if floatIdx < numFloatRegs {
// Read as uintptr, reinterpret as float64 bits
bits := frame[floatIdx]
f64 := *(*float64)(unsafe.Pointer(&bits))
val = reflect.ValueOf(float32(f64))
floatIdx++
} else {
// From stack
bits := frame[stackIdx]
f64 := *(*float64)(unsafe.Pointer(&bits))
val = reflect.ValueOf(float32(f64))
stackIdx++
}
val = reflect.ValueOf(float32(getFloat()))

case reflect.Float64:
// Float64 comes from XMM register
if floatIdx < numFloatRegs {
bits := frame[floatIdx]
f64 := *(*float64)(unsafe.Pointer(&bits))
val = reflect.ValueOf(f64)
floatIdx++
} else {
bits := frame[stackIdx]
f64 := *(*float64)(unsafe.Pointer(&bits))
val = reflect.ValueOf(f64)
stackIdx++
}
val = reflect.ValueOf(getFloat())

case reflect.Bool:
// Bool comes from integer register
Expand Down Expand Up @@ -272,16 +288,65 @@ func callbackWrap(a *callbackArgs) {
stackIdx++
}

case reflect.Struct:
sz := argType.Size()
structData := make([]byte, max(sz, 8))
var valPtr unsafe.Pointer
if sz > 0 {
valPtr = unsafe.Pointer(&structData[0])
}

switch {
case sz == 0:
// Zero-size struct: no fields to populate

case sz <= 8:
// Single eightbyte: INTEGER if any member is not float/double, else SSE.
if isStructAllFloats(argType) {
*(*float64)(valPtr) = getFloat()
} else {
writePartial(valPtr, sz, getInt())
}
case sz <= 16:
// Two eightbytes: classify each independently.
// System V ABI §3.2.3: INTEGER wins over SSE within an eightbyte.
if classifyEightbyte(argType, 0, 8) {
*(*float64)(valPtr) = getFloat()
} else {
*(*uintptr)(valPtr) = getInt()
}
remaining := sz - 8
valPtr = unsafe.Add(valPtr, 8)
if classifyEightbyte(argType, 8, sz) {
*(*float64)(valPtr) = getFloat()
} else {
writePartial(valPtr, remaining, getInt())
}
default:
// MEMORY class (> 16 bytes): copy from the stack in 8-byte chunks.
// Per SysV ABI §3.2.3: MEMORY class structs bypass registers entirely.
nChunks := (sz + 7) / 8
for k := range nChunks {
chunkPtr := unsafe.Add(valPtr, k*8)
chunk := frame[stackIdx]
bytesLeft := sz - k*8
if bytesLeft >= 8 {
*(*uintptr)(chunkPtr) = chunk
} else {
writePartial(chunkPtr, bytesLeft, chunk)
}
stackIdx++
}
}
val = reflect.New(argType)
valByteSlice := unsafe.Slice((*byte)(val.UnsafePointer()), sz)
copy(valByteSlice, structData)
val = val.Elem()

default:
// All other integer types (int, uint, int32, etc.)
if intIdx < numIntRegs {
pos := numFloatRegs + intIdx
val = reflect.NewAt(argType, unsafe.Pointer(&frame[pos])).Elem()
intIdx++
} else {
val = reflect.NewAt(argType, unsafe.Pointer(&frame[stackIdx])).Elem()
stackIdx++
}
value := getInt()
val = reflect.NewAt(argType, unsafe.Pointer(&value)).Elem()
}

args[i] = val
Expand Down Expand Up @@ -322,3 +387,45 @@ func callbackWrap(a *callbackArgs) {
//go:linkname _callbackTrampoline github.com/go-webgpu/goffi/ffi.callbackTrampoline
var _callbackTrampoline byte
var trampolineBaseAddr = uintptr(unsafe.Pointer(&_callbackTrampoline))

// isStructAllFloats returns true if every member of a flat struct is float or double.
// Per System V AMD64 ABI §3.2.3: if any member in an eightbyte is INTEGER class,
// the entire eightbyte is classified as INTEGER (INTEGER wins over SSE).
func isStructAllFloats(structType reflect.Type) bool {
if structType.NumField() == 0 {
return false
}

for i := range structType.NumField() {
field := structType.Field(i)
if field.Type.Kind() == reflect.Struct {
if !isStructAllFloats(field.Type) {
return false
}
} else if field.Type.Kind() != reflect.Float32 && field.Type.Kind() != reflect.Float64 {
return false
}
}
return true
}

// classifyEightbyte returns true if all struct fields whose offset falls within
// [startOff, endOff) are SSE types (float or double).
// Returns false if any field in the range is INTEGER class, or if no fields lie in the range.
//
// CAUTION: Does not currently support nested structs.
func classifyEightbyte(structType reflect.Type, startOff, endOff uintptr) bool {
allFloat := true
hasField := false
for i := range structType.NumField() {
field := structType.Field(i)
if field.Offset >= startOff && field.Offset < endOff {
hasField = true
if field.Type.Kind() != reflect.Float32 && field.Type.Kind() != reflect.Float64 {
allFloat = false
break
}
}
}
return hasField && allFloat
}
Loading
Loading