diff --git a/intcode.go b/intcode.go index 390f5d0..1efba95 100644 --- a/intcode.go +++ b/intcode.go @@ -1,6 +1,7 @@ package aoc2019 import ( + "context" "log" "reflect" "strconv" @@ -91,24 +92,45 @@ func parseIntcode(code string) ([]int64, error) { return out, nil } +type intcodeParams struct { + // Intcode program to execute + Code []int64 + // Context to execute the program in (program might hang on input if context is closed during input directive) + Context context.Context + // Channel / Callback to query on input directive + In interface{} + // Channel to use for output directive + Out chan int64 +} + func executeIntcode(code []int64, in interface{}, out chan int64) ([]int64, error) { + return executeIntcodeWithParams(intcodeParams{ + Code: code, + Context: context.Background(), + In: in, + Out: out, + }) +} + +func executeIntcodeWithParams(params intcodeParams) ([]int64, error) { var ( + code = params.Code inCB func() (int64, error) pos int64 relativeBase int64 ) - if out != nil { - defer close(out) + if params.Out != nil { + defer close(params.Out) } - switch in.(type) { + switch params.In.(type) { case nil: inCB = func() (int64, error) { return 0, errors.New("No input available") } case chan int64: - inCB = func() (int64, error) { return <-(in.(chan int64)), nil } + inCB = func() (int64, error) { return <-(params.In.(chan int64)), nil } case func() (int64, error): - inCB = in.(func() (int64, error)) + inCB = params.In.(func() (int64, error)) default: return nil, errors.New("Unsupported input type") } @@ -169,6 +191,10 @@ func executeIntcode(code []int64, in interface{}, out chan int64) ([]int64, erro return nil, errors.Errorf("Code position out of bounds: %d (len=%d)", pos, len(code)) } + if err := params.Context.Err(); err != nil { + return nil, errors.Wrap(err, "Context closed") + } + // Position is expected to be an OpCode op := parseOpCode(code[pos]) @@ -195,7 +221,7 @@ func executeIntcode(code []int64, in interface{}, out chan int64) ([]int64, erro pos += 2 case opCodeTypeOutput: // p1 => out - out <- getParamValue(1, op) + params.Out <- getParamValue(1, op) pos += 2 case opCodeTypeJumpIfTrue: // p1 != 0 => jmp